# macro3.py
# Andrew Davison, ad@coe.psu.ac.th, May 2026

"""
Reads a macro source file whose name ends in ".m3" and writes
the expanded text to a file with the same base name and the
extension ".txt".

Usage:
  python3 macro3.py define.m3
  python3 macro3.py -trace changeq.m3

Converted from macro3.pas in
  "Software Tools in Pascal"
  Brian W. Kernighan, P. J. Plauger
  Addison-Wesley, 1981
  Chapter 8

Supported built-in macros
 * define(name, body)
    – define a macro; in the body, $1..$9 expand to the
      call's arguments and $0 to the macro name
 * expr(e)
    – evaluate an integer arithmetic expression
      (+, -, *, /, %, parentheses)
 * substr(s, from, n)
    – substring of s starting at 0-based index from, n chars
 * ifelse(a, b, t, f)
    – if a==b expand to t, else f
 * len(s)
    – length of s
 * changeq(lq, rq)
    – change quote characters (default ` and ')
"""

import sys
import re
from collections import deque


# Kinds of symbol-table entry.  Every symTable value is a
# (definition, kind) pair and doEval dispatches on the kind.
DEFTYPE = 'define'     # built-in define(name, body)
MACTYPE = 'macro'      # user-defined macro created by define
IFTYPE = 'if'          # built-in ifelse(a, b, then, else)
SUBTYPE = 'substr'     # built-in substr(s, from, n)
EXPRTYPE = 'expr'      # built-in expr(e)
LENTYPE = 'len'        # built-in len(s)
CHQTYPE = 'changeq'    # built-in changeq(lq, rq)


# Symbol table: maps a macro name to a (definition, kind) tuple.
# Built-ins are registered with an empty definition; user macros
# (kind MACTYPE) store their body text.
symTable = {}

# Stack of macro calls whose argument lists are being read.  Each
# frame is a dict holding the parser state of one call:
#   'name'    -- the macro name that started the call
#   'kind'    -- that name's symbol-table kind
#   'args'    -- argument strings completed so far
#   'plev'    -- parenthesis nesting depth inside the call
#   'current' -- text of the argument currently being collected
#
# A stack (rather than one frame variable) keeps the state of each
# call separate.  The Pascal original spread this state over the
# parallel arrays callstk, typestk and plev, indexed by a shared call
# pointer cp, with a shared argstk; frames[-1]['current'] plays the
# role of its evalstk for the innermost call.
frames = []

# Current quote characters; the changeq built-in reassigns them.
lquote = '`'
rquote = "'"


def macro(src):
  """ Expand every macro in src and return the resulting text.

      src is scanned token by token:
        * an identifier may start a macro call (handleIdent),
        * a left quote starts quoted text (handleQuotes),
        * any other token is copied to the output, or, while a call's
          argument list is open, fed to handlePunct.

      macro() is re-entered (through expandStr) whenever a built-in
      expands one of its arguments.  Each invocation therefore works
      on its own call stack and restores the caller's stack on exit,
      and the built-ins are only registered if absent, so that a
      user's redefinition of a built-in name survives nested
      expansions.

      Raises RuntimeError when the input ends inside an unfinished
      macro call (handleQuotes raises it for an unterminated quote).
      Errors raised by built-ins (e.g. a bad number in substr)
      propagate unchanged.
  """
  global frames
  outerFrames = frames
  frames = []

  # setdefault, not assignment: re-assigning here on every (nested)
  # call would silently undo e.g. define(len, ...).
  for name, kind in [('define', DEFTYPE),
          ('expr', EXPRTYPE), ('substr', SUBTYPE),
          ('ifelse', IFTYPE), ('len', LENTYPE),
          ('changeq', CHQTYPE)]:
    symTable.setdefault(name, ('', kind))

  inQ = deque(src)
  outChs = []

  try:
    while True:
      token = getToken(inQ)
      if token == '':
        break

      first = token[0]
      if first.isalpha() or first == '_':
        handleIdent(token, inQ, outChs)
      elif first == lquote:
        handleQuotes(inQ, outChs)
      elif not frames:
        outChs.append(token)
      else:
        handlePunct(token, inQ, outChs)

    if frames:
      raise RuntimeError('macro: unexpected end of input')
  finally:
    # give the caller (an enclosing macro() call, if any) its stack back
    frames = outerFrames

  return ''.join(outChs)



def handleIdent(token, inQ, outChs):
  """ Handle an identifier token read by macro().

      - While a call's argument list is open (top frame has plev > 0)
        the identifier is appended verbatim to the current argument.
      - An identifier that is not in symTable is copied to outChs.
      - A known name followed, after optional blanks, tabs or
        newlines, by '(' starts a new call frame; the '(' is pushed
        back so that handlePunct reads it next.
      - A user macro (MACTYPE) whose body contains no '$' can be used
        without parentheses: its body is pushed back onto inQ to be
        rescanned.
      - Any other known name not followed by '(' is copied to outChs
        together with the whitespace that was skipped.

      token  -- the identifier just returned by getToken
      inQ    -- deque of pending input characters (read and pushed back)
      outChs -- list of output text fragments (appended to)
  """
  global frames

  # If already collecting args, treat token as text
  if frames and frames[-1]['plev'] > 0:
    frames[-1]['current'] += token
    return

  entry = symTable.get(token)
  if entry is None:
    outChs.append(token)
    return

  if isTracing:
    print(f"TRACE: Potential macro '{token}'")

  # Peek at next character for macro call '('; whitespace between the
  # name and '(' is remembered so it can be restored if no call follows
  skipped = []
  nextch = getChar(inQ)

  while nextch in (' ', '\t', '\n'):
    skipped.append(nextch)
    nextch = getChar(inQ)

  if nextch != '(':
    defn, kind = entry
    if kind == MACTYPE and '$' not in defn:
      # Zero-arg user macro expansion.  Restore the peeked character
      # and the skipped whitespace first, then push the body in front
      # of them, so the input reads: body, whitespace, next char.
      if nextch:
        inQ.appendleft(nextch)
      inQ.extendleft(reversed(skipped))
      putback(inQ, defn)
    else:
      # Not a macro call — restore consumed chars.  The name and the
      # skipped whitespace go straight to the output; only the peeked
      # character is returned to the input for rescanning.
      outChs.append(token)
      for ch in skipped:
        outChs.append(ch)
      if nextch:
        inQ.appendleft(nextch)
  else:
    # Start new macro frame.  The '(' is pushed back (and plev left at
    # 0) so handlePunct's '(' branch raises plev to 1 when it is read.
    inQ.appendleft('(')

    _, kind = entry
    frames.append({
      'name': token,
      'kind': kind,
      'args': [],
      'plev': 0,
      'current': ''
    })



def handleQuotes(inQ, outChs):
  """ Consume quoted text after an opening quote has been read.

      Tokens are read up to the matching right quote.  Nested quote
      pairs are kept in the text; the outermost pair is dropped.  The
      text is added to the argument being collected when a call's
      argument list is open, otherwise it goes to the output.

      Raises RuntimeError if the input ends before the quote closes.
  """
  pieces = []
  depth = 1

  while True:
    tok = getToken(inQ)
    if not tok:
      raise RuntimeError('macro: missing right quote')
    if tok == rquote:
      depth -= 1
      if depth == 0:
        break          # the closing outer quote itself is dropped
    elif tok == lquote:
      depth += 1
    pieces.append(tok)

  text = ''.join(pieces)
  collecting = bool(frames) and frames[-1]['plev'] > 0
  if collecting:
    frames[-1]['current'] += text
  else:
    outChs.append(text)



def handlePunct(token, inQ, outChs):
  """ Feed a non-identifier token to the innermost open macro call.

      '(' and ')' track the parenthesis depth (plev); the call's own
      outer parentheses are not part of any argument.  A ',' at depth 1
      ends the current argument.  The ')' that brings the depth back to
      0 completes the last argument, pops the frame and evaluates the
      call.  Every other token is appended to the current argument.
  """
  top = frames[-1]
  lead = token[0]

  if lead == '(':
    top['plev'] += 1
    if top['plev'] > 1:
      top['current'] += token
    return

  if lead == ')':
    top['plev'] -= 1
    if top['plev'] != 0:
      top['current'] += token
      return
    top['args'].append(top['current'].lstrip())
    frames.pop()
    doEval(top, inQ, outChs)
    return

  if lead == ',' and top['plev'] == 1:
    top['args'].append(top['current'].lstrip())
    top['current'] = ''
    return

  top['current'] += token



def doEval(frame, inQ, outChs):
  """ Run the macro call described by a completed frame.

      frame  -- finished call frame with 'name', 'kind' and 'args'
      inQ    -- input deque; the handlers push their results onto it
      outChs -- output list (no current handler writes to it)

      Built-ins are selected by the frame's kind.  For a user macro
      (MACTYPE) the name's current definition is fetched from
      symTable.  Frames of any other kind are ignored.
  """
  name, kind, args = frame['name'], frame['kind'], frame['args']

  if isTracing:
    depth = len(frames)
    print(f"TRACE: {'  '*depth}Eval '{name}' args: {args}")

  handlers = {
    DEFTYPE:  lambda: dodef(args),
    EXPRTYPE: lambda: doexpr(args, inQ),
    SUBTYPE:  lambda: dosub(args, inQ),
    IFTYPE:   lambda: doif(args, inQ),
    LENTYPE:  lambda: dolen(args, inQ),
    CHQTYPE:  lambda: dochq(args),
    MACTYPE:  lambda: doUserMacro(symTable[name][0], name, args, inQ),
  }
  action = handlers.get(kind)
  if action is not None:
    action()



# Built-in macro handlers ----------------------

def dodef(args):
  """ define(name, body) -- store body as the definition of name.

      Extra arguments come from unquoted commas in the body and are
      joined back together with ','.  Any whitespace after those commas
      has already been removed by argument collection (lstrip).  A call
      with no arguments does nothing.
  """
  if not args:
    return
  macroName = args[0].strip()
  macroBody = ",".join(args[1:])
  symTable[macroName] = (macroBody, MACTYPE)


def doexpr(args, inQ):
  """ expr(e) -- push back the value of the integer expression e.

      e is macro-expanded before it is evaluated, so macros inside it
      work.  Any failure during expansion or evaluation deliberately
      pushes back "0" instead.  With no arguments nothing is pushed.
  """
  if not args:
    return
  try:
    value = evalExpr(expandStr(args[0]))
    putback(inQ, str(value))
  except Exception:
    # best-effort fallback value
    putback(inQ, "0")


def expandStr(s):
  """ Return s with its macros expanded.

      Re-enters the expander macro() on s.  The built-ins use this to
      expand their arguments before they use them.
  """
  expanded = macro(s)
  return expanded


def dosub(args, inQ):
  """ substr(s, from, n) -- push back n characters of s starting at
      the 0-based index from.

      from defaults to 0 and n to the length of s.  All arguments are
      macro-expanded first; from and n are then evaluated with evalExpr,
      whose errors propagate to the caller.  A range that runs past the
      end of s is truncated.  A negative from or a non-positive n
      yields nothing (as in the Pascal original, where an out-of-range
      start produced an empty result).
  """
  if not args:
    return
  s = expandStr(args[0])
  frm = evalExpr(expandStr(args[1])) if len(args) > 1 else 0
  nc = evalExpr(expandStr(args[2])) if len(args) > 2 else len(s)

  # Python slicing treats negative indexes as offsets from the END of
  # s, so s[frm:frm + nc] would return unrelated characters here
  # (e.g. substr(abcdef, -5, 3) -> "bcd"); report an empty result.
  if frm < 0 or nc <= 0:
    return
  putback(inQ, s[frm:frm + nc])


def doif(args, inQ):
  """ ifelse(a, b, then, else) -- choose a branch by comparing a and b.

      a and b are stripped and macro-expanded before they are compared.
      The chosen branch is pushed back unexpanded, so it gets rescanned.
      When a != b and no else branch was given, nothing is pushed.
      Fewer than three arguments do nothing.
  """
  if len(args) < 3:
    return
  left = expandStr(args[0].strip())
  right = expandStr(args[1].strip())
  if left == right:
    chosen = args[2]
  elif len(args) >= 4:
    chosen = args[3]
  else:
    return
  putback(inQ, chosen)


def dolen(args, inQ):
  """ len(s) -- push back the length of s after macro expansion.

      Pushes '0' when called with no arguments.
  """
  length = len(expandStr(args[0])) if args else 0
  putback(inQ, str(length))


def dochq(args):
  """ changeq -- replace the module-level quote characters.

      changeq(l, r) -- lquote = l, rquote = r (each stripped); an empty
                       argument restores its default (` or ')
      changeq(lr)   -- lquote = lr[0], rquote = lr[1], or lr[0] for both
                       when lr has one character; empty restores both
      changeq()     -- restore both defaults

      NOTE(review): macro() compares only token[0] against lquote, so
      a multi-character lquote set by the two-argument form can never
      be matched.
  """
  global lquote, rquote
  defaultPair = ('`', "'")

  if not args:
    lquote, rquote = defaultPair
    return

  if len(args) == 1:
    spec = args[0].strip()
    if not spec:
      lquote, rquote = defaultPair
    else:
      lquote = spec[0]
      rquote = spec[1] if len(spec) > 1 else spec[0]
    return

  left, right = args[0].strip(), args[1].strip()
  lquote = left or defaultPair[0]
  rquote = right or defaultPair[1]


def doUserMacro(defn, name, args, inQ):
  """ Expand a call of the user macro name and push the result back.

      In defn, $1..$9 become the matching call argument, or nothing
      when the call supplied fewer arguments.  $0 becomes the macro
      name.  The substituted text is pushed onto inQ so it is rescanned.
  """
  def argFor(m):
    idx = int(m.group(1))
    if idx == 0:
      return name
    if idx <= len(args):
      return args[idx - 1]
    return ""

  putback(inQ, re.sub(r'\$(\d)', argFor, defn))



# ----------- expression evaluator ---------------
# Integer expression grammar used by the expr and substr built-ins:
#   expr   := term { (+|-) term }
#   term   := factor { (*|/|%) factor }
#   factor := number
#          |  ( expr )
#          |  - factor
#
# '/' and '%' follow C semantics, matching the Pascal `div` this file
# was ported from.  The quotient is truncated toward zero and the
# remainder takes the sign of the dividend, so -7/2 == -3 and
# -7%2 == -1.  Python's // and % round toward minus infinity instead
# and would give -4 and 1.

def evalExpr(s):
  """ Evaluate the integer expression s and return its value.

      The whole of s must be a single expression; surrounding blanks,
      tabs and newlines are allowed.
      Raises ValueError for malformed input, including text left over
      after a complete expression (e.g. "1 2" or "1+2)").
      Raises ZeroDivisionError for division or remainder by zero.
  """
  v, i = parseExpr(s, 0)
  if skipWhites(s, i) != len(s):
    raise ValueError('macro: unexpected text in expr')
  return v


def _truncDiv(a, b):
  """ Integer quotient a/b rounded toward zero.

      Raises ZeroDivisionError when b == 0.
  """
  q = abs(a) // abs(b)
  return q if (a < 0) == (b < 0) else -q


def parseExpr(s, i):
  """ expr := term { (+|-) term }

      Parse from index i; return (value, index after the expression).
  """
  v, i = parseTerm(s, i)
  while True:
    i = skipWhites(s, i)
    if i < len(s) and s[i] == '+':
      t, i = parseTerm(s, i + 1)
      v += t
    elif i < len(s) and s[i] == '-':
      t, i = parseTerm(s, i + 1)
      v -= t
    else:
      break

  return v, i


def parseTerm(s, i):
  """ term := factor { (*|/|%) factor }

      Parse from index i; return (value, index after the term).
      '/' and '%' truncate toward zero (see the grammar note above).
  """
  v, i = parseFactor(s, i)
  while True:
    i = skipWhites(s, i)
    if i < len(s) and s[i] == '*':
      t, i = parseFactor(s, i + 1)
      v *= t
    elif i < len(s) and s[i] == '/':
      t, i = parseFactor(s, i + 1)
      v = _truncDiv(v, t)
    elif i < len(s) and s[i] == '%':
      t, i = parseFactor(s, i + 1)
      v = v - t * _truncDiv(v, t)
    else:
      break

  return v, i


def parseFactor(s, i):
  """ factor := number | ( expr ) | - factor

      Parse from index i; return (value, index after the factor).
      Raises ValueError on a missing ')' or a missing number.
  """
  i = skipWhites(s, i)

  # unary minus applies to any factor: -3, - 3, -(2+3), --4
  if i < len(s) and s[i] == '-':
    v, i = parseFactor(s, i + 1)
    return -v, i

  if i < len(s) and s[i] == '(':
    v, i = parseExpr(s, i + 1)
    i = skipWhites(s, i)
    if i >= len(s) or s[i] != ')':
      raise ValueError('macro: missing ) in expr')
    return v, i + 1

  start = i
  while i < len(s) and s[i].isdigit():
    i += 1
  if start == i:
    raise ValueError('macro: expected number in expr')

  return int(s[start:i]), i


def skipWhites(s, i):
  """ Return the first index >= i in s that is not a blank, tab or
      newline (len(s) if there is none). """
  while i < len(s) and s[i] in ' \t\n':
    i += 1
  return i


# Queue / I/O helpers ------------------

def getToken(inQ):
  """ Read the next token from inQ.

      A token is either a maximal run of alphanumeric or '_'
      characters, or any other single character.  Returns '' at end of
      input.  The character that terminates a word is put back at the
      front of inQ.
  """
  def isWordChar(ch):
    return ch.isalnum() or ch == '_'

  first = getChar(inQ)
  if first == '':
    return ''
  if not isWordChar(first):
    return first

  word = [first]
  while True:
    ch = getChar(inQ)
    if ch == '':
      break
    if not isWordChar(ch):
      inQ.appendleft(ch)
      break
    word.append(ch)
  return ''.join(word)


def getChar(inQ):
  """ Remove and return the first character of inQ ('' if it is empty). """
  if not inQ:
    return ''
  return inQ.popleft()


def putback(inQ, s):
  """ Push s onto the front of inQ so its characters are read next,
      in their original order.  Traces non-empty pushes when tracing. """
  if isTracing and s:
    print(f"TRACE: Pushing back: [{s}]")

  # insert last character first so s ends up in order at the front
  for ch in reversed(s):
    inQ.appendleft(ch)



# main() ----------

# Set by the -trace command-line flag; read by handleIdent, doEval and
# putback to print expansion traces.
isTracing = False


def main(argv):
  """ Command-line entry point.

      Usage: python3 macro3.py [-trace] <file.m3>

      Expands <file.m3>, echoes the result to stdout and writes it to
      the same base name with the extension ".txt".  Errors are
      reported on stderr.  Returns the exit status: 0 on success,
      1 on any error.
  """
  global isTracing

  args = argv[1:]
  if '-trace' in args:
    isTracing = True
    # drop every occurrence so a repeated flag is not taken as the file
    args = [a for a in args if a != '-trace']

  if not args:
    print('Usage: python3 macro3.py [-trace] <file.m3>', file=sys.stderr)
    return 1

  infile = args[0]
  if not infile.endswith('.m3'):
    print('Error: input file must have a .m3 extension', file=sys.stderr)
    return 1

  try:
    with open(infile, 'r') as fh:
      src = fh.read()
  except (OSError, UnicodeDecodeError):
    # missing file, but also permission errors, directories, bad bytes
    print(f'Error: cannot open {infile!r}', file=sys.stderr)
    return 1

  try:
    expandedSrc = macro(src)
  except (RuntimeError, ValueError, ArithmeticError) as err:
    # malformed input: unterminated call or quote, bad substr() number
    print(f'Error: {err}', file=sys.stderr)
    return 1

  print("\n--------- Expansion ---------\n")
  print(expandedSrc)
  print("-----------------------------\n")

  outfile = infile[:-3] + '.txt'
  with open(outfile, 'w') as fh:
    fh.write(expandedSrc)
  print(f'Expanded output written to {outfile!r}')
  return 0


if __name__ == '__main__':
  sys.exit(main(sys.argv))