
# subleq.py
# Andrew Davison, ad@coe.psu.ac.th, April 2026
'''
Based on subleq.py from 
  https://esolangs.org/wiki/Subleq/subleq.py

Changes:
 * removed class structuring
 * input file can have comments on code lines
 * added extra output ops (b == -1, -2, -3) to print ASCII
 * added input op (b == -4)
 * made the C field of the (A B C) triplet optional;
   its absence means ?+1, i.e. go to the next triplet
 * save XXX.asm to XXX.tri
 * assemble flag shortened to asm
 * added listTri.py
 * sx.bat combines asm and run calls to subleq.py


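Triplet meaning (see run()):
SUBLEQ(pc) =
  A = mem[pc]; B = mem[pc+1]; C = mem[pc+2]
  mem[B] = mem[B] - mem[A]
  if mem[B] <= 0:
    pc = C    # goto C
  else:
    pc += 3   # next triplet
A negative B selects an I/O operation instead of the subtraction
(see the OUTPUT_* and INPUT_PORT constants); execution then continues at C.
Execution halts when pc becomes negative.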

----
Usage: 

> python subleq.py asm hi.asm 
> python subleq.py run hi.tri
Hi

> python subleq.py trace hi.tri

> python subleq.py asm hello.asm
> python subleq.py run hello.tri
Hello, World!

> sx hi.asm
'''

import sys, os

# Special (negative) values of the B field of a triplet. When B is one of
# these, run() performs I/O instead of the subtract-and-branch, then jumps to C.
OUTPUT_INT = -1    # print mem[A] as an integer (followed by a space)
OUTPUT_VASC = -2   # print mem[A] as an ASCII char
OUTPUT_ASC = -3    # print the address A itself as an ASCII char

INPUT_PORT  = -4   # read an integer from the user into mem[A]


def assemble(src):
  '''Assemble subleq source text into a flat list of integer cells.

  Two passes are needed so labels can be used before they are defined:
  collectLabels() first records the address of every label, then
  writeContents() resolves each token and appends 3 cells per triplet.
  '''
  labelAddrs = collectLabels(src)          # pass 1: label -> address
  cells = []
  writeContents(src, labelAddrs, cells)    # pass 2: fills cells in place
  return cells


def collectLabels(src):
  '''Pass 1 of the assembler: map every label to the address it names.

  Each non-blank, non-comment line is one triplet of 2 or 3 tokens and
  occupies 3 memory cells (a missing C field still takes a cell).
  A token written as "label:value" defines `label` as the address of that
  token's own cell: addr, addr+1 or addr+2 for the A, B or C position.

  Lines starting with '#' are skipped, exactly as writeContents() skips
  them, so both passes agree on the address of every triplet.

  Exits with status 1 (after printing the offending line) when a line
  does not hold 2 or 3 tokens, or when a label is defined twice.
  Returns a dict of label -> int address.
  '''
  lbls = {}
  addr = 0
  for line in src.split("\n"):
    line = line.strip()
    if not line or line.startswith('#'):
      continue
    argsList = line.split()
    if len(argsList) not in (2, 3):
      print("Missing values on:", line)
      sys.exit(1)
    for offset, token in enumerate(argsList):
      if ':' in token:
        # partition() takes the text before the first ':' and never fails
        # to unpack, unlike split(':') on a token such as "a:b:c"
        label = token.partition(':')[0]
        if label in lbls:
          # a second definition would silently redirect every use
          print(f"Duplicate label '{label}' on:", line)
          sys.exit(1)
        lbls[label] = addr + offset
    addr += 3   # a 2-token line still reserves the implicit ?+1 cell

  return lbls


def writeContents(src, lbls, mem):
  '''Pass 2 of the assembler: resolve every token and append it to mem.

  src:  assembler source (one triplet per non-blank, non-'#' line).
  lbls: label -> address mapping produced by collectLabels().
  mem:  list that receives the resolved cells; modified in place.

  A 2-token line gets an implicit '?+1' third field (jump to the next
  triplet). Exits with status 1 if a line has neither 2 nor 3 tokens.
  '''
  addr = 0
  for raw in src.split("\n"):
    stripped = raw.strip()
    if stripped == "" or stripped[0] == '#':
      continue
    tokens = stripped.split()
    if len(tokens) not in (2, 3):
      print("Missing values on:", stripped)
      sys.exit(1)
    if len(tokens) == 2:
      tokens.append('?+1')   # default C: fall through to the next triplet
    for token in tokens:
      resolve(token, addr, lbls, mem)
    addr += 3


def resolve(v, addr, lbls, mem):
  '''Resolve one source token to an integer and append it to mem.

  v:    token text. For a "label:value" token only the value part is used;
        the label itself was already recorded by collectLabels().
  addr: address of the first cell of the current triplet.
  lbls: label -> address mapping from collectLabels().
  mem:  list of cells being built; modified in place.

  The value may be a label name (replaced by its address), the special
  '?+1' (the address of the next triplet, i.e. addr + 3), or an integer
  literal. Raises ValueError naming the value when it is none of these.
  '''
  if ':' in v:
    # partition() keeps everything after the first ':', so a malformed
    # token such as "a:b:c" reaches the clear error below instead of
    # failing with a tuple-unpacking error
    v = v.partition(':')[2]
  if v in lbls:
    mem.append(int(lbls[v]))
    return
  if v == '?+1':
    mem.append(addr + 3)
    return
  try:
    mem.append(int(v))
  except ValueError:
    raise ValueError(f"Unknown label or bad value: {v!r}") from None


def saveTriplets(mem, f):
  '''Write mem to the text stream f as one "A B C" line per triplet.'''
  for start in range(0, len(mem), 3):
    first, second, third = mem[start], mem[start + 1], mem[start + 2]
    f.write(f"{first} {second} {third}\n")


def loadTriples(f):
  '''Read triplets (as written by saveTriplets) into a flat list of ints.

  f: an open text file, or any iterable of lines. Each non-blank line
     must hold exactly three integers "A B C".

  Blank and whitespace-only lines are skipped, so a trailing spacer line
  left by an editor does not break loading.
  Raises ValueError on a line that does not contain three integers.
  '''
  mem = []
  for lineNo, line in enumerate(f, start=1):
    fields = line.split()
    if not fields:
      continue
    if len(fields) != 3:
      raise ValueError(f"Expected 3 values on line {lineNo}: {line.strip()!r}")
    mem.extend(int(v) for v in fields)
  return mem



def _readInt(prompt=">> "):
  # Read one integer for the INPUT_PORT op. Ask again on anything int()
  # rejects, instead of aborting the program with a traceback.
  # EOFError (stdin closed) still propagates to the caller.
  while True:
    text = input(prompt)
    try:
      return int(text)
    except ValueError:
      print("Please enter an integer.")


def run(mem, isTrace=False):
  '''Execute the subleq program held in mem, modifying mem in place.

  Execution starts at pc 0. It stops when pc becomes negative, or when pc
  no longer points at a complete triplet inside mem (the program jumped or
  fell off the end of memory).

  For each triplet (a, b, c):
    * b >= 0:           mem[b] -= mem[a]; jump to c if the result is <= 0,
                        otherwise continue with the next triplet.
    * b == OUTPUT_INT:  print mem[a] as an integer, then jump to c.
    * b == OUTPUT_VASC: print mem[a] as an ASCII char, then jump to c.
    * b == OUTPUT_ASC:  print a itself as an ASCII char, then jump to c.
    * b == INPUT_PORT:  read an integer into mem[a], then jump to c.
  Any other negative b does nothing and jumps to c.

  isTrace: when True, print each triplet, with the values stored at a and
  b, before executing it.
  '''
  def cellStr(addr):
    # mem[addr] as text, or "--" when addr is outside this program's memory.
    # Reads the mem passed to run(), not a module-level global.
    if 0 <= addr < len(mem):
      return str(mem[addr])
    return "--"

  if isTrace:
    print("-------- Execution ----------")
  pc = 0
  if isTrace:
    print(f"\t pc: (  a,   b,   c)   mem[a]  mem[b]\n")
  # Halt on a negative pc (the normal way to stop), and also once pc no
  # longer addresses a full triplet, rather than dying with an IndexError.
  while pc >= 0 and pc + 2 < len(mem):
    a = mem[pc]
    b = mem[pc+1]
    c = mem[pc+2]
    if isTrace:
      memA = cellStr(a)
      memB = cellStr(b)
      print(f"\t{pc:3d}: ({a:3d}, {b:3d}, {c:3d})   {memA:>4}    {memB:>4}")
    result = 0     # I/O ops leave result at 0, so they always jump to c
    if b >= 0:
      result = mem[b] - mem[a]
      mem[b] = result
    elif b == OUTPUT_INT:  # -1
      # print mem[a] value as an integer
      print(printInt(mem[a]), end='')
    elif b == OUTPUT_VASC:  # -2
      # print mem[a] value as ASCII char
      print(printAscii(mem[a]), end='')
    elif b == OUTPUT_ASC:  # -3
      # print address a as ASCII char
      print(printAscii(a), end='')
    elif b == INPUT_PORT:   # -4
      mem[a] = _readInt(">> ")

    pc = c if result <= 0 else pc + 3    # branch to c, or next triplet

def memStr(addr, memory=None):
  '''Return memory[addr] as a string, or "--" if addr is not a valid index.

  Meant for trace output, where a or b may be a negative I/O code or
  otherwise point outside memory. The explicit range check also stops a
  negative addr from silently indexing from the end of the list.

  memory: the list to read. When omitted, falls back to the module-level
  `mem` set by the command-line code, which is the original behaviour.
  '''
  if memory is None:
    memory = mem
  if 0 <= addr < len(memory):
    return str(memory[addr])
  return "--"


# Use -10 to indicate a newline in the print funcs

def printInt(n):
  '''Format n for OUTPUT_INT: -10 means newline, other values get a trailing space.'''
  return '\n' if n == -10 else f"{n} "

def printAscii(n):
  '''Format n as a character for the ASCII output ops.

  -10 means newline; printable codes 32..126 become their character;
  anything else is shown as its decimal digits.
  '''
  if n == -10:
    return '\n'
  if not 32 <= n <= 126:
    return f"{n}"
  return chr(n)

# ----------------------------------------------

if __name__ == "__main__":
  # Command-line entry point: python subleq.py (asm|run|trace) <file>.
  # The guard lets other scripts import the assembler/interpreter functions
  # without running the CLI. Names assigned here (args, fnm, mem) remain
  # module-level globals.
  if len(sys.argv) != 3:
    print("Usage: python subleq.py asm in.asm")
    print("   or: python subleq.py run in.tri")
    print("   or: python subleq.py trace in.tri")
    sys.exit(1)

  args = sys.argv[1:]

  # Reject an unknown command first, with a failing exit status like the
  # other usage errors.
  if args[0] not in ('asm', 'run', 'trace'):
    print("Usage must be 'asm', 'run', or 'trace'")
    sys.exit(1)

  fnm = args[1]
  if not os.path.isfile(fnm):
    print(f"File '{fnm}' not found.")
    sys.exit(1)

  if args[0] == 'asm':
    if not fnm.endswith('.asm'):
      print("Error: assembler must have a .asm extension")
      sys.exit(1)

    # remove all text from # to EOL on each line
    with open(fnm, 'r') as f:
      lines = []
      for line in f:
        line = line.split('#', 1)[0]
        lines.append(line.rstrip())
      src = "\n".join(lines)
    mem = assemble(src)

    # derive output filename: remove extension, add .tri
    base, _ = os.path.splitext(fnm)
    outFile = base + ".tri"
    print("Saving subleq instruction to", outFile)
    with open(outFile, 'w') as f:
      saveTriplets(mem, f)

  else:   # 'run' or 'trace'
    if not fnm.endswith('.tri'):
      print("Error: run/trace must have a .tri extension")
      sys.exit(1)
    with open(fnm, 'r') as f:
      mem = loadTriples(f)
    isTrace = (args[0] == 'trace')
    run(mem, isTrace)
