
# tbc.py
# Andrew Davison, ad@coe.psu.ac.th, April 2026

'''
Compiler from a simple BASIC-like language (.tb) to a low-level
jump-based output format (.bas).

INPUT LANGUAGE
==============
program     ::= statement*        # '#' starts a comment to end of line
statement   ::= assignment
              | print | input
              | while | for | if
              | BREAK | CONTINUE

assignment  ::= var '=' expr

print       ::= PRINT print_item (',' print_item)*
print_item  ::= atom | string
input       ::= INPUT [string] var

while       ::= WHILE expr statement* WEND
for         ::= FOR var '=' atom TO atom [STEP atom] 
                  statement* NEXT var
if          ::= IF expr THEN statement* [ELSE statement*] ENDIF

expr        ::= atom [ op atom ]     
op          ::= '+' | '-' | '*' | '/' | '&' | '|'
                | '==' | '!=' | '<' | '<=' | '>' | '>=' 

OUTPUT FORMAT
=============
Minimal BASIC Interpreter

  oprogram   ::= { [label] ostatement }
  ostatement ::= 'P' (var | string | '!' | '-')   # '!' prints a newline, '-' prints nothing
              | 'I' var
              | 'J' atom label    # jump to label if atom != 0
              | var oexpr         # assignment: var gets the value of oexpr

  oexpr      ::= atom [ oop atom ]     
  atom      ::= var | int
  oop        ::= '+' | '-' | '*' | '/' | '&' | '|'
                  | '=' | '!'  | '<' | '>' 
    # '/' is floor division: -7 / 2 = -4
    # '=' is == and '!' is !=
    # '&' is logical and '|' is logical or
    # All logical operations return 1 (true) or 0 (false)

  var       ::= a-z
  label     ::= A-Z (excluding P, J, I)
  int       ::= (0-9)+
  string    ::= '"' [^"]* '"'    (a trailing ! means newline)

'''

import string, sys, os

# -------------------------
# Tokens

# Token types; every token is a (type, value, lineNum) tuple.
KEYWORD = 'KEYWORD'
VAR     = 'VAR'
INTEGER = 'INTEGER'
STRING  = 'STRING'
OP      = 'OP'
ASSIGN  = 'ASSIGN'
COMMA   = 'COMMA'

# Reserved words of the source language (all upper case).
KEYWORDS = {
  'PRINT', 'INPUT',
  'WHILE', 'WEND',
  'FOR', 'TO', 'STEP', 'NEXT',
  'IF', 'THEN', 'ELSE', 'ENDIF',
  'BREAK', 'CONTINUE',
}

# Operators that are always exactly one character. '<', '>', '=' and '!'
# may combine with a following '=' and are handled separately by lexOp.
SIMPLE_OPS = ('+', '-', '*', '/', '&', '|')


# -------------------------
# Tokenizer

def lexer(src):
  '''Convert source text into a list of tokens.

  Returns (toks, userVars):
    toks     -- list of (type, value, lineNum) tuples in source order
    userVars -- set of single-letter variable names that appear, so the
                generator can choose temporaries that do not clash

  Spaces, tabs and carriage returns are skipped; '#' starts a comment
  that runs to the end of the line. Raises SyntaxError (via lexError) on
  any character that cannot start a token.
  '''
  toks = []
  userVars = set()
  i = 0
  lineNum = 1
  while i < len(src):
    ch = src[i]

    if ch in ' \t\r':
      i += 1

    elif ch == '\n':
      lineNum += 1
      i += 1

    elif ch == '#':
      # Stop at (not past) the newline so the branch above still counts it.
      while i < len(src) and src[i] != '\n':
        i += 1

    # Character classes are ASCII-only: str.islower()/isdigit()/isupper()
    # also accept letters and digits from other scripts, which the output
    # format (var ::= a-z, int ::= 0-9+) cannot represent.
    elif ch in string.ascii_lowercase:
      toks.append((VAR, ch, lineNum))
      userVars.add(ch)
      i += 1

    elif ch in string.digits:
      i, tok = lexInteger(src, i, lineNum)
      toks.append(tok)

    elif ch == '"':
      i, tok = lexString(src, i, lineNum)
      toks.append(tok)

    elif ch in SIMPLE_OPS or ch in '<>!=':
      i, tok = lexOp(src, i, lineNum)
      toks.append(tok)

    elif ch == ',':
      toks.append((COMMA, ',', lineNum))
      i += 1

    elif ch in string.ascii_uppercase:
      i, tok = lexKeyword(src, i, lineNum)
      toks.append(tok)

    else:
      lexError(lineNum, f"Unknown character '{ch}'")
  return toks, userVars


def lexInteger(src, i, lineNum):
  '''Lex a run of ASCII digits starting at src[i].

  Returns (index after the last digit, (INTEGER, value, lineNum)).
  Only '0'-'9' are accepted: str.isdigit() also matches characters such
  as '²' that int() rejects, or digits from other scripts that int()
  silently converts.
  '''
  start = i
  while i < len(src) and src[i] in string.digits:
    i += 1
  return i, (INTEGER, int(src[start:i]), lineNum)


def lexString(src, i, lineNum):
  '''Lex a double-quoted string literal whose opening quote is at src[i].

  Returns (index after the closing quote, (STRING, text, lineNum)); there
  are no escape sequences. A string may not contain a newline: the
  output is one statement per line, and the lexer's line counter would
  not see newlines inside the literal. A newline or end of input before
  the closing quote raises SyntaxError (via lexError).
  '''
  start = i + 1  # skip the opening quote
  end = start
  while end < len(src) and src[end] not in ('"', '\n'):
    end += 1
  if end >= len(src) or src[end] != '"':
    lexError(lineNum, "Unterminated string literal")
  return end + 1, (STRING, src[start:end], lineNum)


def lexOp(src, i, lineNum):
  '''Lex an operator (or the assignment '=') starting at src[i].

  '<', '>', '=' and '!' combine with a following '=' into '<=', '>=',
  '==' and '!='. A lone '=' is the ASSIGN token and a lone '!' is an
  error. Returns (index after the operator, token); raises SyntaxError
  (via lexError) for an unrecognised operator character.
  '''
  ch = src[i]
  followedByEq = i + 1 < len(src) and src[i + 1] == '='

  if ch in ('<', '>', '=', '!') and followedByEq:
    return i + 2, (OP, ch + '=', lineNum)
  if ch in ('<', '>'):
    return i + 1, (OP, ch, lineNum)
  if ch == '=':
    return i + 1, (ASSIGN, '=', lineNum)
  if ch == '!':
    lexError(lineNum, "Expected != but got lone !")
  if ch in SIMPLE_OPS:
    return i + 1, (OP, ch, lineNum)

  lexError(lineNum, f"Unknown operator character '{ch}'")


def lexKeyword(src, i, lineNum):
  '''Lex a run of ASCII upper-case letters starting at src[i] as a keyword.

  Returns (index after the word, (KEYWORD, word, lineNum)). Raises
  SyntaxError (via lexError) if the word is not in KEYWORDS. The check is
  ASCII-only because str.isupper() also accepts non-ASCII capitals.
  '''
  j = i
  while j < len(src) and src[j] in string.ascii_uppercase:
    j += 1
  w = src[i:j]

  if w not in KEYWORDS:
    lexError(lineNum, f"Unknown keyword '{w}'")
  return j, (KEYWORD, w, lineNum)


def lexError(lnNum, msg):
  '''Abort lexing by raising SyntaxError tagged with the source line.'''
  message = f'lex error on line {lnNum}: {msg}'
  raise SyntaxError(message)

# -------------------------
# Parser

def lineOf(tok):
  '''Return the line number stored in a (type, value, lineNum) token.

  Gives '?' when tok is None/empty or has no third field.
  '''
  if not tok or len(tok) <= 2:
    return '?'
  return tok[2]


def expect(pos, toks, typ, val=None):
  '''Consume the next token, requiring type typ (and value val if given).

  Returns (new pos, token). Raises SyntaxError (via parseError) when the
  token is missing or does not match. The message names both the
  expected and the actual token; at end of input it cites the line of
  the last token.
  '''
  pos, tok = advance(pos, toks)
  if tok and tok[0] == typ and (val is None or tok[1] == val):
    return pos, tok

  if tok:
    ln = lineOf(tok)
  else:
    # pos > 0 guard: peek(-1, []) would index an empty list.
    ln = lineOf(peek(pos-1, toks)) if pos > 0 else '?'
  wanted = str(typ) if val is None else f"{typ} '{val}'"
  got = f"{tok[0]} '{tok[1]}'" if tok else 'EOF'
  parseError(ln, f"Expected {wanted}, got {got}")


def advance(pos, toks):
  '''Return (next pos, token at pos), or (pos, None) past the end.

  The position only moves forward when a (truthy) token was read.
  '''
  current = toks[pos] if pos < len(toks) else None
  nextPos = pos + 1 if current else pos
  return nextPos, current


def peek(pos, toks):
  '''Return the token at pos without consuming it, or None past the end.'''
  if pos >= len(toks):
    return None
  return toks[pos]



def parseProgram(pos, toks):
  '''Parse statements until the tokens run out: program ::= statement*

  Returns (final pos, list of statement AST tuples).
  '''
  program = []
  while True:
    if not peek(pos, toks):
      return pos, program
    pos, node = parseStmt(pos, toks)
    program.append(node)


def parseStmt(pos, toks):
  ''' statement ::= assignment
              | print | input
              | while | for | if
              | BREAK | CONTINUE

  Returns (new pos, statement AST tuple). Raises SyntaxError (via
  parseError) at end of input, or when the token cannot start a statement.
  '''
  tok = peek(pos, toks)
  if not tok:
    ln = lineOf(peek(pos-1, toks)) if pos > 0 else '?'
    parseError(ln, "Unexpected end of input, expected a statement")

  if tok[0] == VAR:
    return parseAssign(pos, toks)
  if tok[0] != KEYWORD:
    parseError(lineOf(tok), f"Unexpected token {tok}")

  kw = tok[1]
  if kw in ('BREAK', 'CONTINUE'):
    pos, _ = advance(pos, toks)
    return pos, (kw,)

  parsers = {'PRINT': parsePrint, 'INPUT': parseInput,
             'WHILE': parseWhile, 'FOR': parseFor, 'IF': parseIf}
  if kw not in parsers:
    # The lexer only produces known keywords, so reaching here means a
    # word such as WEND/NEXT/ELSE/ENDIF/THEN/TO/STEP is out of place.
    parseError(lineOf(tok), f"Unexpected keyword '{kw}'")
  return parsers[kw](pos, toks)


def parseAssign(pos, toks):
  '''assignment ::= var '=' expr   ->   ('LET', var, expr)'''
  pos, target = expect(pos, toks, VAR)
  pos, _ = expect(pos, toks, ASSIGN)
  pos, value = parseExpr(pos, toks)
  return pos, ('LET', target[1], value)



def parsePrint(pos, toks):
  '''print ::= PRINT print_item (',' print_item)*   ->   ('PRINT', items)

  Raises SyntaxError if PRINT is not followed by an item. The error cites
  the line of the PRINT keyword, where the item is missing, rather than
  '?' or the line of whatever token happens to come next.
  '''
  pos, printTok = expect(pos, toks, KEYWORD, 'PRINT')
  tok = peek(pos, toks)
  if not tok or tok[0] not in (VAR, INTEGER, STRING):
    parseError(lineOf(printTok), "PRINT requires at least one item")

  pos, item = parsePrintItem(pos, toks)
  items = [item]
  while peek(pos, toks) and peek(pos, toks)[0] == COMMA:
    pos, _ = advance(pos, toks)
    pos, item = parsePrintItem(pos, toks)
    items.append(item)
  return pos, ('PRINT', items)


def parsePrintItem(pos, toks):
  '''print_item ::= atom | string   ->   atom node or ('STR', text)'''
  nxt = peek(pos, toks)
  if not (nxt and nxt[0] == STRING):
    return parseAtom(pos, toks)
  pos, _ = advance(pos, toks)
  return pos, ('STR', nxt[1])


def parseInput(pos, toks):
  '''input ::= INPUT [string] var   ->   ('INPUT', prompt or None, var)'''
  pos, _ = expect(pos, toks, KEYWORD, 'INPUT')

  prompt = None
  nxt = peek(pos, toks)
  if nxt and nxt[0] == STRING:
    pos, promptTok = advance(pos, toks)
    prompt = promptTok[1]

  pos, target = expect(pos, toks, VAR)
  return pos, ('INPUT', prompt, target[1])


def parseWhile(pos, toks):
  '''while ::= WHILE expr statement* WEND   ->   ('WHILE', expr, body)'''
  pos, _ = expect(pos, toks, KEYWORD, 'WHILE')
  pos, cond = parseExpr(pos, toks)

  body = []
  while True:
    nxt = peek(pos, toks)
    if not nxt or (nxt[0] == KEYWORD and nxt[1] == 'WEND'):
      break
    pos, stmt = parseStmt(pos, toks)
    body.append(stmt)

  pos, _ = expect(pos, toks, KEYWORD, 'WEND')
  return pos, ('WHILE', cond, body)


def parseFor(pos, toks):
  ''' for ::= FOR var '=' atom TO atom [STEP atom]
                 statement* NEXT var

  Returns (new pos, ('FOR', var, start, end, step, body)); the step
  defaults to ('INT', 1). Raises SyntaxError when the NEXT variable
  differs from the FOR variable.
  '''
  pos, _ = expect(pos, toks, KEYWORD, 'FOR')
  pos, loopTok = expect(pos, toks, VAR)
  pos, _ = expect(pos, toks, ASSIGN)
  pos, start = parseAtom(pos, toks)
  pos, _ = expect(pos, toks, KEYWORD, 'TO')
  pos, end = parseAtom(pos, toks)

  nxt = peek(pos, toks)
  if nxt and nxt[0] == KEYWORD and nxt[1] == 'STEP':
    pos, _ = advance(pos, toks)
    pos, step = parseAtom(pos, toks)
  else:
    step = ('INT', 1)

  body = []
  while True:
    nxt = peek(pos, toks)
    if not nxt or (nxt[0] == KEYWORD and nxt[1] == 'NEXT'):
      break
    pos, stmt = parseStmt(pos, toks)
    body.append(stmt)

  pos, _ = expect(pos, toks, KEYWORD, 'NEXT')
  pos, closeTok = expect(pos, toks, VAR)
  var, closeVar = loopTok[1], closeTok[1]
  if closeVar != var:
    parseError(lineOf(closeTok), f"NEXT {closeVar} does not match FOR {var}")
  return pos, ('FOR', var, start, end, step, body)


def parseIf(pos, toks):
  '''if ::= IF expr THEN statement* [ELSE statement*] ENDIF

  Returns (new pos, ('IF', expr, thenBody, elseBody)); elseBody is []
  when there is no ELSE.
  '''
  pos, _ = expect(pos, toks, KEYWORD, 'IF')
  pos, cond = parseExpr(pos, toks)
  pos, _ = expect(pos, toks, KEYWORD, 'THEN')

  def collect(p, stopWords):
    # Parse statements until end of input or a KEYWORD in stopWords.
    stmts = []
    while True:
      nxt = peek(p, toks)
      if not nxt or (nxt[0] == KEYWORD and nxt[1] in stopWords):
        return p, stmts
      p, stmt = parseStmt(p, toks)
      stmts.append(stmt)

  pos, thenBody = collect(pos, ('ELSE', 'ENDIF'))

  nxt = peek(pos, toks)
  if nxt and nxt[0] == KEYWORD and nxt[1] == 'ELSE':
    pos, _ = advance(pos, toks)
    pos, elseBody = collect(pos, ('ENDIF',))
  else:
    elseBody = []

  pos, _ = expect(pos, toks, KEYWORD, 'ENDIF')
  return pos, ('IF', cond, thenBody, elseBody)


def parseExpr(pos, toks):
  '''expr ::= atom [op atom]   ->   atom node or ('BINOP', op, left, right)'''
  pos, lhs = parseAtom(pos, toks)
  nxt = peek(pos, toks)
  if not nxt or nxt[0] != OP:
    return pos, lhs
  pos, opTok = advance(pos, toks)
  pos, rhs = parseAtom(pos, toks)
  return pos, ('BINOP', opTok[1], lhs, rhs)


def parseAtom(pos, toks):
  '''atom ::= var | int   ->   ('VAR', name) or ('INT', value)

  Returns (new pos, atom node). Raises SyntaxError (via parseError) for
  any other token. At end of input the error cites the line of the last
  token read instead of '?'.
  '''
  tok = peek(pos, toks)
  if not tok:
    ln = lineOf(peek(pos-1, toks)) if pos > 0 else '?'
    parseError(ln, "Unexpected end of input, expected atom")
  if tok[0] == VAR:
    pos, _ = advance(pos, toks)
    return pos, ('VAR', tok[1])
  if tok[0] == INTEGER:
    pos, _ = advance(pos, toks)
    return pos, ('INT', tok[1])
  parseError(lineOf(tok), f"Expected atom (VAR or INT), got {tok}")


def parseError(lnNum, msg):
  '''Abort parsing by raising SyntaxError tagged with the source line.'''
  message = f'parse error on line {lnNum}: {msg}'
  raise SyntaxError(message)


# -------------------------
# AST Printer


def printAst(node, indent=0):
  '''Pretty-print an AST node, or a list of nodes, to stdout (debug aid).

  Each indent level is two spaces. Loop bodies are nested one level under
  their header; IF bodies two levels, beneath THEN/ELSE marker lines.
  Non-tuple, non-list values are printed with repr().
  '''
  pad = '  ' * indent

  if isinstance(node, list):
    for child in node:
      printAst(child, indent)
    return

  if not isinstance(node, tuple):
    print(f"{pad}{node!r}")
    return

  tag, *rest = node
  if tag == 'LET':
    name, value = rest
    print(f"{pad}(LET {name} = {fmtExpr(value)})")

  elif tag == 'PRINT':
    shownItems = ', '.join(fmtExpr(item) for item in rest[0])
    print(f"{pad}(PRINT {shownItems})")

  elif tag == 'INPUT':
    prompt, name = rest
    shown = name if prompt is None else f'"{prompt}" {name}'
    print(f"{pad}(INPUT {shown})")

  elif tag == 'IF':
    cond, thenBody, elseBody = rest
    print(f"{pad}(IF {fmtExpr(cond)}")
    print(f"{pad}  THEN")
    printAst(thenBody, indent + 2)
    if elseBody:
      print(f"{pad}  ELSE")
      printAst(elseBody, indent + 2)
    print(f"{pad})")

  elif tag == 'WHILE':
    cond, body = rest
    print(f"{pad}(WHILE {fmtExpr(cond)}")
    printAst(body, indent + 1)
    print(f"{pad})")

  elif tag == 'FOR':
    name, start, end, step, body = rest
    header = (f"(FOR {name} = {fmtExpr(start)} TO {fmtExpr(end)} "
              f"STEP {fmtExpr(step)}")
    print(pad + header)
    printAst(body, indent + 1)
    print(f"{pad})")

  else:
    # BREAK, CONTINUE and any unrecognised tag: just the tag.
    print(f"{pad}({tag})")


def fmtExpr(node):
  '''Render an expression node as text: `x`, `42`, `"str"`, `(a op b)`.

  Unrecognised nodes fall back to repr().
  '''
  kind = node[0]
  if kind == 'BINOP':
    lhs = fmtExpr(node[2])
    rhs = fmtExpr(node[3])
    return f'({lhs} {node[1]} {rhs})'
  elif kind == 'STR':
    return f'"{node[1]}"'
  elif kind == 'INT':
    return str(node[1])
  elif kind == 'VAR':
    return node[1]
  else:
    return repr(node)


# -------------------------
# Generator
# Operators available in the output format:
#   oop ::= '+' | '-' | '*' | '/' | '&' | '|'
#         | '=' | '!'  | '<' | '>'
# Source operators outside this set must be rewritten.

# Source comparisons with no single output operator, each mapped to the
# (strict comparison, equality) pair it combines.
COMPOUND_OPS = {
  '<=': ('<', '='),
  '>=': ('>', '='),
}

# Source equality operators and their one-character output spellings.
REMAP_OPS = {'==': '=', '!=': '!'}

# Logical inverse of each operator, used to build "jump if the condition
# is false" tests; None marks operators with no single inverse operator.
NEGATE_OP = {
  '<': '>=',  '>': '<=',
  '<=': '>',  '>=': '<',
  '==': '!=', '!=': '==',
  '&': None,  '|': None,
}



def generate(stmts, userVars):
  '''Translate a parsed program into a list of output lines.

  Labels are handed out from A-Z, skipping P, S, J and I. Temporaries are
  lowercase letters the program does not use, handed out from z downward.
  Each letter is handed out at most once. Raises RuntimeError (via
  getNext) if either pool runs dry.
  '''
  labelPool = (c for c in string.ascii_uppercase if c not in 'PSJI')
  tempPool = (c for c in string.ascii_lowercase[::-1] if c not in userVars)

  def labels():
    return getNext(labelPool, 'labels')

  def temps():
    return getNext(tempPool, 'temporaries')

  output = []
  genStmts(stmts, output, [], labels, temps)
  return output


def getNext(gen, kind):
  '''Return the next item from gen, a pool of labels or temporaries.

  Raises RuntimeError (never StopIteration) once the pool is exhausted.
  The StopIteration context is suppressed with `from None` because it
  adds only noise to the traceback.
  '''
  try:
    return next(gen)
  except StopIteration:
    raise RuntimeError(f"Ran out of {kind}: program is too complex ") from None



def genStmts(stmts, code, loopStack, labels, temps):
  '''Generate code for a sequence of statements, in source order.'''
  for node in stmts:
    genStmt(node, code, loopStack, labels, temps)


def genStmt(stmt, code, loopStack, labels, temps):
  '''Append the output lines for one AST statement to code.

  Args:
    stmt:      AST tuple from the parser; tag is one of 'PRINT', 'LET',
               'INPUT', 'BREAK', 'CONTINUE', 'IF', 'WHILE', 'FOR'.
    code:      list of output lines, extended in place through emit().
    loopStack: one (startLabel, endLabel, continueLines) entry per
               enclosing loop, innermost last. BREAK jumps to endLabel;
               CONTINUE emits continueLines (a FOR loop's increment) and
               then jumps to startLabel.
    labels:    zero-argument callable returning an unused label letter.
    temps:     zero-argument callable returning an unused temporary.

  Raises:
    ValueError: BREAK/CONTINUE outside a loop, or an unknown statement tag.
  '''
  tag, *rest = stmt

  if tag == 'PRINT':
    for item in rest[0]:
      if item[0] == 'VAR':
        emit(code, f"P {item[1]}")
      else:  # INT or STR: printed as a quoted string literal
        emit(code, f"P \"{item[1]}\"")
    emit(code, "P !")  # finish the PRINT with a newline

  elif tag == 'LET':
    var, expr = rest
    emitExpr(var, expr, code, temps)

  elif tag == 'INPUT':
    prompt, var = rest
    if prompt is not None:
      emit(code, f"P \"{prompt}\"")
    emit(code, f"I {var}")

  elif tag == 'BREAK':
    if not loopStack:
      raise ValueError("BREAK outside loop")
    _, lblEnd, _ = loopStack[-1]
    emit(code, f"J 1 {lblEnd}")

  elif tag == 'CONTINUE':
    if not loopStack:
      raise ValueError("CONTINUE outside loop")
    lblStart, _, contLines = loopStack[-1]
    # A FOR loop must step its variable before re-testing; jumping
    # straight back to the test would repeat the same iteration forever.
    for line in contLines:
      emit(code, line)
    emit(code, f"J 1 {lblStart}")

  elif tag == 'IF':
    expr, thenBody, elseBody = rest
    _genIf(expr, thenBody, elseBody, code, loopStack, labels, temps)

  elif tag == 'WHILE':
    expr, body = rest
    _genWhile(expr, body, code, loopStack, labels, temps)

  elif tag == 'FOR':
    var, start, end, step, body = rest
    _genFor(var, start, end, step, body, code, loopStack, labels, temps)

  else:
    raise ValueError("Unknown stmt type "+tag)


def _genCondJump(expr, lbl, code, temps):
  '''Emit code that jumps to lbl when condition expr evaluates to 0.'''
  tmp = temps()
  if expr[0] == 'BINOP' and NEGATE_OP.get(expr[1]) is not None:
    # Comparisons invert directly: `a < b` is false exactly when `a >= b`.
    emitExpr(tmp, negateExpr(expr), code, temps)
  else:
    # Atoms, arithmetic, '&' and '|' have no single inverse operator, so
    # evaluate the condition and test the result against 0 instead.
    emitExpr(tmp, expr, code, temps)
    emit(code, f"{tmp} {tmp} = 0")
  emit(code, f"J {tmp} {lbl}")


def _genIf(expr, thenBody, elseBody, code, loopStack, labels, temps):
  '''Emit an IF statement: run thenBody if expr is true, else elseBody.'''
  if not elseBody:
    # No ELSE: a single label suffices, and labels are a small fixed pool.
    lblEnd = labels()
    _genCondJump(expr, lblEnd, code, temps)
    genStmts(thenBody, code, loopStack, labels, temps)
    emit(code, f"{lblEnd}:")
    return

  lblElse = labels()
  lblEnd  = labels()
  _genCondJump(expr, lblElse, code, temps)
  genStmts(thenBody, code, loopStack, labels, temps)
  emit(code, f"J 1 {lblEnd}")
  emit(code, f"{lblElse}:")
  genStmts(elseBody, code, loopStack, labels, temps)
  emit(code, f"{lblEnd}:")


def _genWhile(expr, body, code, loopStack, labels, temps):
  '''Emit a WHILE loop: test at the top, jump back after the body.'''
  lblStart = labels()
  lblEnd   = labels()
  emit(code, f"{lblStart}:")
  _genCondJump(expr, lblEnd, code, temps)
  loopStack.append((lblStart, lblEnd, ()))
  genStmts(body, code, loopStack, labels, temps)
  loopStack.pop()
  emit(code, f"J 1 {lblStart}")
  emit(code, f"{lblEnd}:")


def _genFor(var, start, end, step, body, code, loopStack, labels, temps):
  '''Emit a FOR loop: var runs from start to end (inclusive) by step.'''
  lblStart  = labels()
  lblEnd    = labels()
  stepStr   = atomStr(step)
  endStr    = atomStr(end)
  increment = f"{var} {var} + {stepStr}"

  emit(code, f"{var} {atomStr(start)}")
  emit(code, f"{lblStart}:")
  tmp = temps()
  if step[0] == 'VAR':
    # The step's sign is only known at run time. Scale (var - end) by +1
    # for a non-negative step or -1 for a negative one; the loop is done
    # once the scaled difference is positive, i.e. var has passed end.
    sgn = temps()
    emit(code, f"{sgn} {stepStr} < 0")
    emit(code, f"{sgn} {sgn} * 2")
    emit(code, f"{sgn} 1 - {sgn}")
    emit(code, f"{tmp} {var} - {endStr}")
    emit(code, f"{tmp} {tmp} * {sgn}")
    emit(code, f"{tmp} {tmp} > 0")
  else:
    # Integer literals are never negative, so the loop counts upward.
    emit(code, f"{tmp} {var} > {endStr}")
  emit(code, f"J {tmp} {lblEnd}")

  loopStack.append((lblStart, lblEnd, (increment,)))
  genStmts(body, code, loopStack, labels, temps)
  loopStack.pop()
  emit(code, increment)
  emit(code, f"J 1 {lblStart}")
  emit(code, f"{lblEnd}:")


def emit(code, line):
  '''Append one output line to code.

  A line ending in ':' is a label definition ("A:"). Labels must prefix
  a statement (oprogram ::= { [label] ostatement }), so the label is
  paired with `P -`, the output format's print-nothing NOP. The quoted
  form `P "-"` would be a string and print a literal '-'. Other lines are
  indented three spaces so their statements line up with labelled ones.
  '''
  if line.endswith(':'):
    code.append(line[:-1] + '  P -')
  else:
    code.append('   ' + line)



def emitExpr(dest, node, code, temps):
  """Emit code that stores the value of expression `node` into `dest`.

  `node` is an atom ('VAR' / 'INT') or ('BINOP', op, left, right) with
  atom operands. Source operators missing from the output format are
  rewritten to match the BNF:
    '==', '!='  ->  '=', '!'
    '<=', '>='  ->  the opposite strict comparison, then tested against 0
                    (a <= b  is  (a > b) = 0)
  The '<=' / '>=' rewrite reuses `dest` rather than allocating
  temporaries, which are never recycled. `temps` is kept for interface
  compatibility.

  Returns dest. Raises ValueError if `node` is not an atom or BINOP.
  """
  if node[0] in ('VAR', 'INT'):
    emit(code, f"{dest} {atomStr(node)}")
    return dest

  if node[0] != 'BINOP':
    raise ValueError(f"Unexpected expr node {node[0]}")

  _, op, left, right = node
  lStr = atomStr(left)
  rStr = atomStr(right)

  if op in COMPOUND_OPS:
    # Both operands are read before dest is written, so dest may safely
    # be one of the operands (e.g. `x = x <= y`).
    emit(code, f"{dest} {lStr} {NEGATE_OP[op]} {rStr}")
    emit(code, f"{dest} {dest} = 0")
  elif op in REMAP_OPS:
    emit(code, f"{dest} {lStr} {REMAP_OPS[op]} {rStr}")
  else:
    emit(code, f"{dest} {lStr} {op} {rStr}")

  return dest


def atomStr(node):
  '''Return the output text of an atom node: a variable name or integer.

  Raises ValueError for any other node kind.
  '''
  kind = node[0]
  if kind == 'INT':
    return str(node[1])
  if kind == 'VAR':
    return node[1]
  raise ValueError(f"Expected VAR or INT atom, got {kind}")


def negateExpr(node):
  '''Return the logical inverse of a comparison BINOP node.

  Raises ValueError if node is not a BINOP, or if its operator has no
  entry (or a None entry) in NEGATE_OP.
  '''
  kind = node[0]
  if kind != 'BINOP':
    raise ValueError(f"Cannot negate non-binop node: {kind}")

  _, op, lhs, rhs = node
  inverse = NEGATE_OP.get(op)
  if inverse is None:
    raise ValueError(f"Cannot directly negate operator '{op}'")
  return ('BINOP', inverse, lhs, rhs)


# --------------------------------------------

def main():
  '''Command-line entry point: compile <file.tb> into <file>.bas.

  The generated lines are printed to stdout and written next to the input
  file. Usage and compile errors go to stderr, with exit status 1.
  '''
  if len(sys.argv) != 2:
    print("Usage: python tbc.py <file.tb>", file=sys.stderr)
    sys.exit(1)

  fnm = sys.argv[1]
  if not fnm.endswith('.tb'):
    print("Error: input file must have a .tb extension", file=sys.stderr)
    sys.exit(1)

  if not os.path.isfile(fnm):
    print(f"File '{fnm}' not found.", file=sys.stderr)
    sys.exit(1)

  try:
    # Explicit encoding so results do not depend on the platform default;
    # a UnicodeDecodeError is a ValueError and is reported below.
    with open(fnm, encoding='utf-8') as f:
      src = f.read()
    toks, userVars = lexer(src)
    _, ast = parseProgram(0, toks)
    # To dump the parse tree while debugging: printAst(ast)
    output = generate(ast, userVars)
  except (SyntaxError, RuntimeError, ValueError) as e:
    print(f"Error: {e}", file=sys.stderr)
    sys.exit(1)

  for line in output:
    print(line)
  print("---------------------")

  base, _ = os.path.splitext(fnm)
  outFnm = f"{base}.bas"
  with open(outFnm, 'w', encoding='utf-8') as f:
    f.write('\n'.join(output)+'\n')
  print(f"Written to {outFnm}")


if __name__ == '__main__':
  main()