
# exprCompiler.py
# Andrew Davison, ad@coe.psu.ac.th, April 2026

'''
High-level to Low-level Expression Compiler
Working, but not a complete language.

High-level BNF
===============
program   ::= { statement }
statement ::=  var '=' expr

expr    ::= logOr

logOr   ::= logAnd
            | logOr '||' logAnd

logAnd  ::= equal
            | logAnd '&&' equal

equal   ::= rel
            | equal '==' rel
            | equal '!=' rel

rel     ::= add
            | rel '<'  add
            | rel '<=' add
            | rel '>'  add
            | rel '>=' add

add     ::= mult
            | add '+' mult
            | add '-' mult

mult    ::= power
            | mult '*'  power
            | mult '/'  power
            | mult '%'  power

power   ::= unary
           | unary '^' power    # right-associative

unary   ::= pri
            | '-' unary
            | '!' unary          # logical NOT

pri     ::= var | int | string
            | '(' expr ')'

var       ::= a-z
int       ::= (0-9)+
string    ::= '"' [^"]* '"'          # grammar only: not yet tokenized or parsed

--------------------------------------

Low-level BNF
==============
oprogram   ::= { ostatement }
ostatement ::=  var '=' oexpr

oexpr      ::= atom [ op atom ]     
atom      ::= var | int
op        ::=  ... # binary operators only


Usage:
  python exprCompiler.py test1.txt out.txt

'''

import sys


# globals

# every variable name the source program mentions; filled in by
# parseStatement() and parsePrimary(), and read by compileProgram() so
# that temporaries never reuse a name the program itself uses
usedVars = set()

# set to True by newTemp() the first time it has to invent a zN name,
# so that its warning is printed only once
varsExhausted = False

# number that newTemp() will use for the next zN name
zCounter = 0

# every name newTemp() has handed out as a temporary; compileStatement()
# checks it before folding a temporary into the statement's own variable
tempVars = set()


# -------------------------------------------
# tokenizer

def tokenize(text):
  '''
  Split the source text into a list of (kind, text, line) tuples.

  kind is one of:
    'INT'     a run of digits 0-9
    'ID'      a single lower-case letter a-z
    'OP'      an arithmetic, comparison or logical operator
    'LPAREN'  an opening parenthesis
    'RPAREN'  a closing parenthesis
    'EQ'      the assignment sign
  line is the 1-based number of the source line holding the token.
  Whitespace only separates tokens and produces none itself.

  Raises SyntaxError, naming the line, for any other character.
  '''
  # two-character operators are tried before single characters, so that
  # '==' is never read as two assignment signs
  doubleOps = ('||', '&&', '==', '!=', '<=', '>=')

  # token kind of every character that is a complete token by itself
  singleKinds = {sym: 'OP' for sym in '+-*/%<>!^'}
  singleKinds['('] = 'LPAREN'
  singleKinds[')'] = 'RPAREN'
  singleKinds['='] = 'EQ'

  tokens = []
  lineNo = 1
  end = len(text)
  at = 0

  while at < end:
    ch = text[at]
    pair = text[at:at+2]     # only one character long at the very end

    if ch.isspace():
      # a newline is whitespace too, but also moves to the next line
      if ch == '\n':
        lineNo += 1
      at += 1

    elif isDigit(ch):
      stop = at + 1
      while stop < end and isDigit(text[stop]):
        stop += 1
      tokens.append(('INT', text[at:stop], lineNo))
      at = stop

    elif isLower(ch):
      # variables are exactly one letter long
      tokens.append(('ID', ch, lineNo))
      at += 1

    elif pair in doubleOps:
      tokens.append(('OP', pair, lineNo))
      at += 2

    elif ch in singleKinds:
      tokens.append((singleKinds[ch], ch, lineNo))
      at += 1

    else:
      raise SyntaxError(f"line {lineNo}: unexpected character '{ch}'")

  return tokens


def isDigit(c):
  '''Return True if the character c is an ASCII digit, 0-9.'''
  return '0' <= c and c <= '9'
def isLower(c):
  '''Return True if the character c is an ASCII lower-case letter, a-z.'''
  return 'a' <= c and c <= 'z'


def expect(pos, toks, kind, val=None):
  '''
  Consume the token at pos and check that it is the one required.

  pos, toks : index of the current token, and the token list
  kind      : required token kind ('ID', 'EQ', 'LPAREN', ...)
  val       : required token text, or None to accept any text of that kind

  Returns (new pos, token).
  Raises SyntaxError if the input has run out or the token does not match.
  The message says what was expected and what was found, and gives the
  source line when the token carries one.
  '''
  pos, tok = advance(pos, toks)
  wanted = kind if val is None else f"{kind} '{val}'"

  if not tok:
    raise SyntaxError(
      f"parse error: expected {wanted} but reached end of input")

  if (tok[0] != kind) or ((val is not None) and (tok[1] != val)):
    # tokens from tokenize() hold their line number in the third slot
    where = f"line {tok[2]}: " if len(tok) > 2 else ""
    raise SyntaxError(
      f"{where}parse error: expected {wanted} but found '{tok[1]}'")

  return pos, tok


def advance(pos, toks):
  '''
  Step past the token at pos.

  Returns (pos + 1, token) when there is a token at pos, and
  (pos, None) when the input has run out.
  '''
  tok = peek(pos, toks)
  return (pos + 1, tok) if tok else (pos, tok)


def peek(pos, toks):
  '''Return the token at pos without consuming it, or None past the end.'''
  if pos < len(toks):
    return toks[pos]
  return None


# ------------------------------------------
# parser

def parseProgram(pos, toks):
  '''
  program ::= { statement }

  Parses statements until the tokens run out.
  Returns (new pos, list of statement nodes).
  '''
  program = []
  while True:
    if not peek(pos, toks):
      return pos, program
    pos, stmt = parseStatement(pos, toks)
    program.append(stmt)


def parseStatement(pos, toks):
  '''
  statement ::= var '=' expr

  Returns (new pos, ('assign', var, expr)).  The assigned variable is
  recorded in the module-level usedVars set.
  '''
  pos, target = expect(pos, toks, 'ID')
  name = target[1]
  usedVars.add(name)

  pos, _ = expect(pos, toks, 'EQ')
  pos, rhs = parseExpr(pos, toks)
  return pos, ('assign', name, rhs)


def parseExpr(pos, toks):
  '''expr ::= logOr.  Entry point for parsing any expression.'''
  return parseLogOr(pos, toks)


def parseLogOr(pos, toks):
  '''
  logOr ::= logAnd { '||' logAnd }

  Left-associative: each further operand is folded into the tree built
  so far.  Returns (new pos, tree).
  '''
  pos, tree = parseLogAnd(pos, toks)
  nxt = peek(pos, toks)
  while nxt and nxt[1] == '||':
    # pos + 1 steps over the operator that was just peeked at
    pos, rhs = parseLogAnd(pos + 1, toks)
    tree = ('||', tree, rhs)
    nxt = peek(pos, toks)
  return pos, tree


def parseLogAnd(pos, toks):
  '''
  logAnd ::= equal { '&&' equal }

  Left-associative: each further operand is folded into the tree built
  so far.  Returns (new pos, tree).
  '''
  pos, tree = parseEqual(pos, toks)
  nxt = peek(pos, toks)
  while nxt and nxt[1] == '&&':
    # pos + 1 steps over the operator that was just peeked at
    pos, rhs = parseEqual(pos + 1, toks)
    tree = ('&&', tree, rhs)
    nxt = peek(pos, toks)
  return pos, tree


def parseEqual(pos, toks):
  '''
  equal ::= rel { ('==' | '!=') rel }

  Left-associative: each further operand is folded into the tree built
  so far.  Returns (new pos, tree).
  '''
  pos, tree = parseRel(pos, toks)
  nxt = peek(pos, toks)
  while nxt and nxt[1] in ('==','!='):
    # pos + 1 steps over the operator that was just peeked at
    pos, rhs = parseRel(pos + 1, toks)
    tree = (nxt[1], tree, rhs)
    nxt = peek(pos, toks)
  return pos, tree


def parseRel(pos, toks):
  '''
  rel ::= add { ('<' | '<=' | '>' | '>=') add }

  Left-associative: each further operand is folded into the tree built
  so far.  Returns (new pos, tree).
  '''
  pos, tree = parseAdd(pos, toks)
  nxt = peek(pos, toks)
  while nxt and nxt[1] in ('<','<=','>','>='):
    # pos + 1 steps over the operator that was just peeked at
    pos, rhs = parseAdd(pos + 1, toks)
    tree = (nxt[1], tree, rhs)
    nxt = peek(pos, toks)
  return pos, tree


def parseAdd(pos, toks):
  '''
  add ::= mult { ('+' | '-') mult }

  Left-associative: each further operand is folded into the tree built
  so far.  Returns (new pos, tree).
  '''
  pos, tree = parseMult(pos, toks)
  nxt = peek(pos, toks)
  while nxt and nxt[1] in ('+','-'):
    # pos + 1 steps over the operator that was just peeked at
    pos, rhs = parseMult(pos + 1, toks)
    tree = (nxt[1], tree, rhs)
    nxt = peek(pos, toks)
  return pos, tree


def parseMult(pos, toks):
  '''
  mult ::= power { ('*' | '/' | '%') power }

  Left-associative: each further operand is folded into the tree built
  so far.  Returns (new pos, tree).
  '''
  pos, tree = parsePower(pos, toks)
  nxt = peek(pos, toks)
  while nxt and nxt[1] in ('*','/','%'):
    # pos + 1 steps over the operator that was just peeked at
    pos, rhs = parsePower(pos + 1, toks)
    tree = (nxt[1], tree, rhs)
    nxt = peek(pos, toks)
  return pos, tree


def parsePower(pos, toks):
  '''
  power ::= unary [ '^' power ]

  Right-associative: the exponent is parsed by a recursive call, so
  a ^ b ^ c becomes ('^', a, ('^', b, c)).  Returns (new pos, tree).
  '''
  pos, base = parseUnary(pos, toks)
  nxt = peek(pos, toks)
  if not nxt or nxt[1] != '^':
    return pos, base

  # pos + 1 steps over the '^' that was just peeked at
  pos, exponent = parsePower(pos + 1, toks)
  return pos, ('^', base, exponent)


def parseUnary(pos, toks):
  '''
  unary ::= pri | '-' unary | '!' unary

  Neither prefix operator gets a node of its own; both are rewritten
  as binary operations:
    -x   becomes   0 - x
    !x   becomes   x == 0
  Returns (new pos, tree).
  '''
  tok = peek(pos, toks)
  prefix = tok[1] if tok else None

  if prefix == '-':
    # pos + 1 steps over the prefix operator
    pos, operand = parseUnary(pos + 1, toks)
    return pos, ('-', ('int', '0'), operand)

  if prefix == '!':
    pos, operand = parseUnary(pos + 1, toks)
    return pos, ('==', operand, ('int', '0'))

  return parsePrimary(pos, toks)


def parsePrimary(pos, toks):
  '''
  pri ::= var | int | '(' expr ')'

  Returns (new pos, node): ('int', digits) for a number, ('var', letter)
  for a variable, or the node of the inner expression for a parenthesised
  one.  Every variable seen is recorded in the module-level usedVars set.

  Raises SyntaxError if no primary starts at pos.  The message names the
  offending token, with its source line when the token carries one, or
  says that the input ended.
  '''
  tok = peek(pos, toks)
  if not tok:
    raise SyntaxError("invalid primary: unexpected end of input")

  kind = tok[0]
  if kind == 'INT':
    return pos+1, ('int', tok[1])

  if kind == 'ID':
    usedVars.add(tok[1])
    return pos+1, ('var', tok[1])

  if kind == 'LPAREN':
    pos, _ = expect(pos, toks, 'LPAREN')
    pos, node = parseExpr(pos, toks)
    pos, _ = expect(pos, toks, 'RPAREN')
    return pos, node

  # tokens from tokenize() hold their line number in the third slot
  where = f"line {tok[2]}: " if len(tok) > 2 else ""
  raise SyntaxError(f"{where}invalid primary: unexpected '{tok[1]}'")


# --------------------------------------------
# LOWERING: <= and >= expanded to < / > and ==

def lowerProgram(ast):
  '''
  Apply lowerExpr() to the expression of every statement.

  ast is a list of (tag, var, expr) statements; the result is a new list
  of ('assign', var, lowered expr) tuples in the same order.
  '''
  lowered = []
  for stmt in ast:
    _, var, expr = stmt
    lowered.append(('assign', var, lowerExpr(expr)))
  return lowered

def lowerExpr(node):
  '''
  Return a copy of the expression tree with no '<=' or '>=' in it.

    a <= b   becomes   (a < b) || (a == b)
    a >= b   becomes   (a > b) || (a == b)

  ('int', ...) and ('var', ...) leaves are returned as they are; every
  other node is an (op, left, right) triple whose operands are lowered
  first.
  '''
  tag = node[0]
  if tag == 'int' or tag == 'var':
    return node

  op, lhs, rhs = node
  lhs = lowerExpr(lhs)
  rhs = lowerExpr(rhs)

  # the strict comparison that stands in for each non-strict one
  strict = {'<=': '<', '>=': '>'}.get(op)
  if strict is None:
    return (op, lhs, rhs)
  return ('||', (strict, lhs, rhs), ('==', lhs, rhs))


# ------------------------------------------
# COMPILER

# map from internal operator names to output symbols
# - compileExpr() emits any operator that is not listed here unchanged
# - '<=' and '>=' need no entry: compileProgram() runs lowerProgram()
#   first, which rewrites them in terms of '<', '>', '==' and '||'
OP_SYMBOLS = {
  '==': '=',
  '!=': '!',
  '&&': '&',
  '||': '|',
}

def compileProgram(ast):
  '''
  Compile a parsed program into a list of low-level statement strings.

  The program is lowered first, then each statement is compiled in turn.
  Temporaries are drawn from the letters a-z that the program does not
  use itself (those not in the module-level usedVars set).
  '''
  lowered = lowerProgram(ast)

  pool = []
  for letter in 'abcdefghijklmnopqrstuvwxyz':
    if letter not in usedVars:
      pool.append(letter)

  output = []
  for stmt in lowered:
    compileStatement(stmt, pool, output)
  return output


def compileStatement(stmt, varsAvail, lines):
  '''
  Compile one (tag, var, expr) statement, adding its code to lines.

  stmt      : the statement tuple; its tag is ignored
  varsAvail : list of names still free for use as temporaries
  lines     : output list of low-level statements, extended in place

  When the last line emitted assigns the expression's result to a
  temporary, that line is rewritten to assign to var directly and the
  temporary goes back into varsAvail; otherwise "var = result" is added.
  '''
  _, target, expr = stmt
  result = compileExpr(expr, varsAvail, lines)

  if not lines:
    lines.append(f"{target} = {result}")
    return

  lastLhs, lastRhs = splitAssign(lines[-1])
  if lastLhs == result and lastLhs in tempVars:
    # on-the-fly optimization: drop the temporary that would only be copied
    lines[-1] = f"{target} = {lastRhs}"
    varsAvail.append(lastLhs)
  else:
    lines.append(f"{target} = {result}")


def compileExpr(node, varsAvail, lines):
  '''
  Emit code for an expression tree and return the name holding its value.

  An ('int', ...) or ('var', ...) leaf emits nothing; its text is returned
  as it is.  For an (op, left, right) node both operands are compiled
  first, left before right, then one line "temp = left op right" is added
  to lines and the new temporary's name is returned.
  '''
  if node[0] in ('int', 'var'):
    return node[1]

  op, left, right = node
  lhs = compileExpr(left, varsAvail, lines)
  rhs = compileExpr(right, varsAvail, lines)

  # the temporary is taken only after both operands have been compiled
  temp = newTemp(varsAvail)
  symbol = OP_SYMBOLS[op] if op in OP_SYMBOLS else op
  lines.append(f"{temp} = {lhs} {symbol} {rhs}")
  return temp


def newTemp(varsAvail):
  '''
  Hand out a name for a new temporary variable.

  While varsAvail has names left, its first one is removed and used.
  After that the names z0, z1, ... are generated, and a warning is
  printed the first time this happens.  Every name handed out is also
  recorded in the module-level tempVars set.
  '''
  global varsExhausted, zCounter

  if varsAvail:
    name = varsAvail.pop(0)
  else:
    if not varsExhausted:
      print("Warning: switching to z0, z1, ...")
      varsExhausted = True
    name = f"z{zCounter}"
    zCounter += 1

  tempVars.add(name)
  return name


def splitAssign(line):
  '''
  Split "lhs = rhs" at its first '=' sign.

  Returns (lhs, rhs) with surrounding whitespace removed, or (None, None)
  when the line has no '=' in it.
  '''
  lhs, sep, rhs = line.partition('=')
  if not sep:
    return None, None
  return lhs.strip(), rhs.strip()


# =============================================

# Command-line driver: compile the program in the file named by the first
# argument and write the low-level code, one statement per line, to the
# file named by the second.
if len(sys.argv) != 3:
  print("Usage: python exprCompiler.py input.txt output.txt")
  sys.exit(1)

# 'with' closes the input file as soon as it has been read
with open(sys.argv[1]) as f:
  text = f.read()

toks = tokenize(text)
_, ast = parseProgram(0, toks)
code = compileProgram(ast)

with open(sys.argv[2], 'w') as f:
  for line in code:
    f.write(line + '\n')
