# mbxS.py
# Andrew Davison, ad@coe.psu.ac.th, April 2026

'''
Minimal BASIC Interpreter (with 'S' (screen) parsing)

  program   ::= { [label] statement }
  statement ::= 'P' (var | string | '!' | '-')   # P! prints a newline; P- prints nothing
              | 'I' var
              | var expr
              | 'J' atom label   # jump to label if atom != 0
              | 'S' int '-' int [ '-' int ]      # NEW
              | 'S' '-'

  expr      ::= atom [ op atom ]     
  atom      ::= var | int
  op        ::= '+' | '-' | '*' | '/' | '&' | '|'
                | '=' | '!'  | '<' | '>' 
    # '/' is floor division: -7 / 2 = -4
    # '=' is == and '!' is !=
    # '&' is logical and '|' is logical or
    # All logical operations return 1 (true) or 0 (false)

  var       ::= a-z
  label     ::= A-Z (excluding P, J, I, S)
  int       ::= (0-9)+
  string    ::= '"' [^"]* '"'  (a trailing ! prints a newline instead of a space)

Usage:
  python mbxS.py fact.bas

  python mbxS.py "In a1 i1 Abi>n JbE aa*i ii+1 J1A EPa"
  python mbxS.py "Ina1i1Abi>nJbEaa*iii+1J1AEPa"

  python mbxS.py "Ia;Ib;ca+b;Pc"      # the ';' separators are optional
  python mbxS.py "Ia Ib ca+b Pc" 
  python mbxS.py "IaIbca+bPc"
  python mbxS.py "IaIbca+bPcS40-40Ia"
'''

import sys, os
from Screen import Screen

# Tokens
# Token kinds produced by the lexer (each constant is its own name as a
# string). The keyword letters P, J, I and S use the letter itself as kind.
INT, VAR, LABEL, OP, STRING, EOF = (
  'INT', 'VAR', 'LABEL', 'OP', 'STRING', 'EOF')


# ------------------------------
# Tokenizer

# Single-character operator tokens recognised by the lexer. This matches the
# grammar's `op` rule and the keys of OP_LAMS ('-' and '!' are also used by
# the P and S statements). '.' was dropped: no grammar rule or evaluator
# entry uses it, so it could only ever fail later at run time.
OPS = ('+', '-', '*', '/', '=', '!', '<', '>', '&', '|')

def lexer(src):
  '''Split BASIC source text into a list of (kind, value, lineNum) tokens.

  Handling of non-token text:
    - Blanks (space, tab, carriage return) are skipped.
    - A newline only advances the line counter.
    - '#' starts a comment that runs to the end of the line.

  Digits and letters are recognised by ASCII range rather than with
  str.isdigit()/str.isalpha(). Those methods also accept characters such
  as '²' or 'é', which the grammar (int ::= (0-9)+, var ::= a-z,
  label ::= A-Z) does not allow. Such characters are now reported as
  unexpected instead of crashing int() or becoming variables/labels.

  The returned list always ends with an (EOF, None, lineNum) token.

  Raises SyntaxError (via lexError) for any character the grammar does
  not allow.
  '''
  toks = []
  i = 0
  lineNum = 1

  while i < len(src):
    ch = src[i]
    if ch in (' ', '\t', '\r'):
      i += 1
      continue

    if ch == '\n':
      lineNum += 1
      i += 1
      continue

    if ch == '#':
      while i < len(src) and src[i] != '\n':
        i += 1
      continue

    if ch == '"':
      i, token = lexString(src, i, lineNum)
      toks.append(token)
      # a literal may span lines; count them so later line numbers stay right
      lineNum += token[1].count('\n')
      continue

    if '0' <= ch <= '9':
      i, token = lexInteger(src, i, lineNum)
      toks.append(token)
      continue

    if ('a' <= ch <= 'z') or ('A' <= ch <= 'Z'):
      i, token = lexAlpha(src, i, lineNum)
      toks.append(token)
      continue

    if ch in OPS:
      toks.append((OP, ch, lineNum))
      i += 1
      continue

    lexError(lineNum, f'unexpected character {ch}')

  toks.append((EOF, None, lineNum))
  return toks


def lexString(src, i, lineNum):
  '''Scan a string literal whose opening quote is at src[i].

  string ::= '"' [^"]* '"'

  Returns (index just past the closing quote, (STRING, body, lineNum)).
  The body is the raw text between the quotes: it is not unescaped and may
  contain newlines. Raises SyntaxError via lexError when no closing quote
  follows.
  '''
  close = src.find('"', i + 1)
  if close == -1:
    lexError(lineNum, 'unterminated string literal')
  body = src[i + 1:close]
  return close + 1, (STRING, body, lineNum)


def lexInteger(src, i, lineNum):
  '''Scan an unsigned decimal integer starting at src[i].

  int ::= (0-9)+

  Only ASCII digits are consumed. str.isdigit() also matches characters
  such as '²' that int() rejects with a ValueError carrying no line number.

  Returns (index just past the digits, (INT, value, lineNum)).
  Raises SyntaxError via lexError if src[i] is not an ASCII digit.
  '''
  start = i
  while i < len(src) and '0' <= src[i] <= '9':
    i += 1
  if i == start:
    # guard against being called on a non-digit: int('') would raise ValueError
    lexError(lineNum, f'unexpected character {src[start:start + 1]}')
  return i, (INT, int(src[start:i]), lineNum)


def lexAlpha(src, i, lineNum):
  ''' var    ::= a-z
      label  ::= A-Z (excluding P, J, I, S)

  Classify the single letter at src[i] and return (i+1, token):
    - the keyword capitals P, J, I and S become tokens whose kind is the
      letter itself;
    - any other capital becomes a LABEL token;
    - a lower-case letter becomes a VAR token.

  Only ASCII letters are accepted. str.isupper()/str.isalpha() would also
  let letters such as 'É' or 'é' through. Anything else raises SyntaxError
  via lexError.
  '''
  ch = src[i]
  if 'A' <= ch <= 'Z':
    if ch in ('P', 'J', 'I', 'S'):
      return i+1, (ch, ch, lineNum)
    return i+1, (LABEL, ch, lineNum)
  if 'a' <= ch <= 'z':
    return i+1, (VAR, ch, lineNum)
  lexError(lineNum, f'unexpected character {ch}')


def lexError(lnNum, msg):
  '''Abort tokenizing: raise SyntaxError tagged with the source line.'''
  text = 'lex error on line {}: {}'.format(lnNum, msg)
  raise SyntaxError(text)


# ------------------------------
# Parser

def lineOf(tok):
  '''Return the line number stored in a token, or '?' when there is none.'''
  if not tok or len(tok) <= 2:
    return '?'
  return tok[2]


def expect(pos, toks, kind, val=None):
  '''Consume the token at toks[pos], insisting on its kind (and value).

  Returns (newPos, token) on success. Otherwise raises SyntaxError
  "line N: expected KIND ['VAL'], got FOUND". FOUND is the kind of the
  token actually seen, or 'EOF' when toks is exhausted.
  '''
  newPos, tok = advance(pos, toks)
  if tok and tok[0] == kind and (val is None or tok[1] == val):
    return newPos, tok

  # with no token, advance() did not move; look back one token for a line
  where = lineOf(tok) if tok else lineOf(peek(newPos - 1, toks))
  wanted = str(kind)
  if val is not None:
    wanted += f" '{val}'"
  found = tok[0] if tok else 'EOF'
  raise SyntaxError(f"line {where}: expected {wanted}, got {found}")


def advance(pos, toks):
  '''Return (nextPos, tok) for the token at toks[pos].

  The position moves forward only when a token exists, so past the end
  (pos, None) is returned.
  '''
  tok = peek(pos, toks)
  nextPos = pos + 1 if tok else pos
  return nextPos, tok


def peek(pos, toks):
  '''Return toks[pos] without consuming it, or None when pos is past the end.'''
  if pos >= len(toks):
    return None
  return toks[pos]


def parseProgram(pos, toks):
  '''Parse the whole token list:  program ::= { [label] statement }

  Returns (pos, lines, labelIdx):
    - lines is a list of (label or None, stmt, lineNum) triples;
    - labelIdx maps each label to the index in lines of the statement it
      precedes.

  Parsing stops at the EOF token (or the end of toks). A label defined
  twice raises SyntaxError.
  '''
  labelIdx = {}
  lines = []

  while True:
    tok = peek(pos, toks)
    if not tok or tok[0] == EOF:
      break

    lnNum = lineOf(tok)
    label = None
    if tok[0] == LABEL:
      pos, lblTok = advance(pos, toks)
      label = lblTok[1]
      if label in labelIdx:
        raise SyntaxError(f'line {lnNum}: duplicate label "{label}"')
      labelIdx[label] = len(lines)

    pos, stmt = parseStmt(pos, toks)
    lines.append((label, stmt, lnNum))

  return pos, lines, labelIdx


def parseStmt(pos, toks):
  '''Parse one statement starting at toks[pos].

  statement ::= 'P' (var | string | '!' | '-')
              | 'I' var | var expr
              | 'J' atom label
              | 'S' int '-' int [ '-' int ] | 'S' '-'

  Returns (newPos, node), where node is one of:
    ('print', 'var', name)     ('print', 'string', text)
    ('input', name)            ('assign', name, exprNode)
    ('ifGoto', atomNode, label)
    or the node built by parseScreen for an 'S' statement.
  'P!' and 'P-' are stored as the strings '!' and '-'.
  Raises SyntaxError on malformed input.
  '''
  tok = peek(pos, toks)
  if not tok:
    raise SyntaxError("Unexpected end of input")
  kind = tok[0]

  if kind == 'P':
    pos, _ = advance(pos, toks)
    arg = peek(pos, toks)
    if arg and arg[0] == VAR:
      pos, varTok = expect(pos, toks, VAR)
      return pos, ('print', 'var', varTok[1])
    if arg and arg[0] == STRING:
      pos, strTok = expect(pos, toks, STRING)
      return pos, ('print', 'string', strTok[1])
    if arg and arg[1] in ('!', '-'):
      pos, _ = expect(pos, toks, OP)
      return pos, ('print', 'string', arg[1])
    raise SyntaxError(f"line {lineOf(tok)}: expected variable or string after P")

  if kind == 'I':
    pos, _ = advance(pos, toks)
    pos, varTok = expect(pos, toks, VAR)
    return pos, ('input', varTok[1])

  if kind == VAR:
    pos, varTok = expect(pos, toks, VAR)
    pos, exprNode = parseExpr(pos, toks)
    return pos, ('assign', varTok[1], exprNode)

  if kind == 'J':
    pos, _ = advance(pos, toks)
    pos, condNode = parseAtom(pos, toks)
    pos, lblTok = expect(pos, toks, LABEL)
    return pos, ('ifGoto', condNode, lblTok[1])

  if kind == 'S':
    return parseScreen(pos, toks)

  raise SyntaxError(f"line {lineOf(tok)}: expected statement, got {tok[0]} ({tok[1]!r})")


def parseExpr(pos, toks):
  '''Parse  expr ::= atom [ op atom ]  (at most one operator, no precedence).

  Returns (newPos, node). node is the atom node alone, or
  ('binop', opChar, leftNode, rightNode) when an OP token follows it.
  '''
  pos, lhs = parseAtom(pos, toks)
  nxt = peek(pos, toks)
  if not (nxt and nxt[0] == OP):
    return pos, lhs

  pos, opTok = advance(pos, toks)
  pos, rhs = parseAtom(pos, toks)
  return pos, ('binop', opTok[1], lhs, rhs)



def parseScreen(pos, toks):
  '''Parse a screen statement; toks[pos] is the 'S' keyword.

    'S' int '-' int [ '-' int ]   ->  ('screen', 'draw', x, y, val)
    'S' '-'                       ->  ('screen', '-')

  When the third integer is omitted, val defaults to 1.
  Returns (newPos, node); raises SyntaxError for any other form.
  '''
  sTok = peek(pos, toks)
  lnNum = lineOf(sTok)
  pos, _ = advance(pos, toks)   # step over 'S'

  arg = peek(pos, toks)
  if not arg:
    raise SyntaxError(f"line {lnNum}: incomplete S statement")

  isDash = arg[0] == OP and arg[1] == '-'
  if isDash:
    pos, dashTok = advance(pos, toks)
    return pos, ('screen', dashTok[1])

  if arg[0] != INT:
    raise SyntaxError(f"line {lnNum}: invalid S statement")

  pos, xTok = expect(pos, toks, INT)
  pos, _ = expect(pos, toks, OP, '-')
  pos, yTok = expect(pos, toks, INT)

  val = 1   # default cell value; presumably black -- confirm with Screen.draw
  after = peek(pos, toks)
  if after and after[0] == OP and after[1] == '-':
    pos, _ = advance(pos, toks)   # optional third field
    pos, vTok = expect(pos, toks, INT)
    val = vTok[1]
  return pos, ('screen', 'draw', xTok[1], yTok[1], val)


def parseAtom(pos, toks):
  '''Parse  atom ::= var | int.

  Returns (newPos, ('var', name)) or (newPos, ('int', value)).
  Raises SyntaxError on end of input or any other token.
  '''
  tok = peek(pos, toks)
  if not tok:
    raise SyntaxError("Unexpected end of input")

  nodeTag = {VAR: 'var', INT: 'int'}.get(tok[0])
  if nodeTag is None:
    raise SyntaxError(f"line {lineOf(tok)}: expected variable or integer")

  pos, atomTok = advance(pos, toks)
  return pos, (nodeTag, atomTok[1])


# ------------------------------
# Interpreter

def run(lines, labelIdx):
  '''Execute a parsed program.

  lines    -- list of (label, stmt, lineNum) triples built by parseProgram
  labelIdx -- maps each label letter to the index in lines it labels

  Variables live in a dict local to this call. pc indexes lines, and the
  program stops when pc runs past the last statement.

  All run-time failures are raised as RuntimeError via runError. This
  includes end of file on stdin during an 'I' statement, which would
  otherwise surface as an EOFError that the top-level handler does not
  catch.
  '''
  env = {}
  pc  = 0

  while pc < len(lines):
    _, stmt, lnNum = lines[pc]
    tag, *rest = stmt

    if tag == 'assign':
      v, expr = rest
      env[v] = evalExpr(expr, env, lnNum)
      pc += 1

    elif tag == 'print':
      ptype, pval = rest
      if ptype == 'var':
        print(getVar(env, pval, lnNum), end=" ")
      elif ptype == 'string':
        if pval.endswith('!'):
          # a trailing '!' (including 'P!' itself) ends the output line
          print(pval[:-1])
        elif pval == "" or pval == "-":
          # an empty string or 'P-' prints nothing
          print("", end="")
        else:
          print(pval, end=" ")
      pc += 1

    elif tag == 'input':
      v = rest[0]
      try:
        val = int(input("?? "))
      except ValueError:
        runError(lnNum, 'invalid integer input')
      except EOFError:
        runError(lnNum, 'unexpected end of input')
      env[v] = val
      pc += 1

    elif tag == 'ifGoto':
      cond_node, lbl = rest
      if evalExpr(cond_node, env, lnNum) != 0:
        # labels are only checked when the jump is actually taken
        if lbl not in labelIdx:
          runError(lnNum, f'unknown {lbl}')
        pc = labelIdx[lbl]
      else:
        pc += 1

    elif tag == 'screen':
      evalScreen(rest, lnNum)
      pc += 1

    else:
      runError(lnNum, f'unknown statement {tag}')


def evalScreen(args, lnNum):
  '''Carry out a parsed screen statement on the module-level `grid`.

  args is the statement node without its 'screen' tag:
    ('-',)               prints "[SCREEN CLEAR]" then calls grid.clear()
    ('draw', x, y, val)  calls grid.draw(x, y, val)
  Empty or unrecognised args raise RuntimeError via runError.
  '''
  if not args:
    runError(lnNum, 'invalid screen arguments')

  action = args[0]
  if action == 'draw':
    _, x, y, val = args
    grid.draw(x, y, val)
  elif action == '-':
    print("[SCREEN CLEAR]")
    grid.clear()
  else:
    runError(lnNum, f'unknown screen operation {action!r}')


def evalExpr(node, env, lnNum):
  '''Evaluate an expression node and return its value.

  node is ('int', n), ('var', name) or ('binop', op, left, right). For a
  binop both operands are evaluated, left first, then combined by applyOp.
  An unknown node tag raises RuntimeError via runError.
  '''
  tag, *rest = node

  if tag == 'binop':
    op, lhs, rhs = rest
    lval = evalExpr(lhs, env, lnNum)
    rval = evalExpr(rhs, env, lnNum)
    return applyOp(op, lval, rval, lnNum)

  if tag == 'var':
    return getVar(env, rest[0], lnNum)

  if tag == 'int':
    return rest[0]

  runError(lnNum, f'internal error: unknown node {tag}')


def getVar(env, v, lnNum):
  '''Return the value bound to variable v; unbound raises RuntimeError.'''
  if v in env:
    return env[v]
  runError(lnNum, f'undefined {v}')


# Binary operator character -> implementation. Comparison and logical
# operators yield 1 (true) or 0 (false); any non-zero operand counts as true.
OP_LAMS = {
  # arithmetic; '/' floors, so -7 / 2 == -4
  '+': lambda a, b: a + b,
  '-': lambda a, b: a - b,
  '*': lambda a, b: a * b,
  '/': lambda a, b: a // b,   # a zero divisor is rejected first by applyOp

  # comparisons ('!' means "not equal")
  '=': lambda a, b: int(a == b),
  '!': lambda a, b: int(a != b),
  '<': lambda a, b: int(a < b),
  '>': lambda a, b: int(a > b),

  # logical and / or
  '&': lambda a, b: int(a != 0 and b != 0),
  '|': lambda a, b: int(a != 0 or b != 0),
}


def applyOp(op, a, b, lnNum):
  '''Combine a and b with the operator character op, using OP_LAMS.

  Division by zero is checked before the operator lookup. Both it and an
  operator missing from OP_LAMS raise RuntimeError via runError.
  '''
  if op == '/' and b == 0:
    runError(lnNum, 'division by zero')
  impl = OP_LAMS.get(op)
  if impl is None:
    runError(lnNum, f'unknown operator {op!r} applied to {a}, {b}')
  return impl(a, b)


def runError(lnNum, msg):
  '''Abort execution: raise RuntimeError tagged with the source line.'''
  text = 'run error on line {}: {}'.format(lnNum, msg)
  raise RuntimeError(text)


# ------------------------------

def dropSemicolons(text):
  '''Turn the optional ';' separators of a command-line program into spaces.

  The contents of "..." string literals are left untouched; a plain
  str.replace would also rewrite them, so P"a;b" would print "a b".
  After splitting on '"', the even-indexed pieces lie outside literals
  and the odd-indexed pieces are literal bodies.
  '''
  pieces = text.split('"')
  return '"'.join(p if k % 2 else p.replace(';', ' ')
                  for k, p in enumerate(pieces))


# Command-line entry point. The guard keeps an import of this module free
# of side effects. `grid` is still bound at module level because
# evalScreen reads it as a global.
if __name__ == '__main__':
  if len(sys.argv) != 2:
    print('usage: python mbxS.py <file.bas>|"<string>"')
    sys.exit(1)

  arg = sys.argv[1]
  if os.path.isfile(arg):
    if not arg.endswith('.bas'):
      print("Error: input file must have a .bas extension")
      sys.exit(1)
    with open(arg) as f:
      src = f.read()
  elif arg.endswith('.bas'):
    # a mistyped file name would otherwise be run as program text
    print(f"Error: file {arg} not found")
    sys.exit(1)
  else:
    src = dropSemicolons(arg)

  try:
    toks = lexer(src)
    pos, lines, labelIdx = parseProgram(0, toks)

    grid = Screen(cellSize=10, gridCells=80)
    run(lines, labelIdx)
  except (SyntaxError, RuntimeError, ValueError) as e:
    print(f"Error: {e}")
    sys.exit(1)   # non-zero status, like the other error exits above