
# complexUtils.py
# Andrew Davison, ad@coe.psu.ac.th, Nov. 2025

import math, cmath
import colorsys
import matplotlib.pyplot as plt


# Number of evenly spaced samples taken along each axis (real and
# imaginary) when showPhase() builds its evaluation grid, i.e. the
# grid holds STEPS x STEPS points.
STEPS = 150


def castComplexToFloats(xs):
  '''
  Return a list of floats built from xs: complex entries are replaced
  by their modulus abs(x); every other entry is converted with float().
  '''
  def toFloat(x):
    if isinstance(x, complex):
      return abs(x)
    return float(x)

  return list(map(toFloat, xs))



def getRoots(z, n):
  '''
  Return the n complex nth roots of z, computed as

      |z|**(1/n) * exp(i * (arg(z) + 2*pi*k) / n),   k = 0, 1, ..., n-1

  where k = 0 gives the principal root. z == 0 is special-cased and
  yields the single-element list [0j].
  '''
  if (z.real, z.imag) == (0, 0):
    return [complex(0, 0)]

  modulus = abs(z)**(1/n)
  arg = cmath.phase(z)
  return [modulus*cmath.exp(1j*((arg + 2*math.pi*k)/n))
                  for k in range(n)]


def getRealRoot(z, n):
  '''
  Return a real nth root of z as a float, or math.nan if there is none.

  Candidate roots come from getRoots(z, n) and are scanned in the order
  getRoots() returns them (k = 0, 1, ..., n-1). The first candidate that
  counts as real (see below) is returned as its real part. If none
  qualifies, "No real root found" is printed and math.nan is returned.

  A candidate counts as real if isReal() accepts it, or if its imaginary
  part is at most 1e-10 times its modulus. The relative test matters for
  large |z|. getRoots() builds each root as r*exp(i*angle), and the
  rounding error in exp() is scaled by r: exp(i*pi) has an imaginary
  part of about 1.2e-16. So the cube root -1e7 of -1e21 carries an
  imaginary part of roughly 1e-9, which is above isReal()'s absolute
  1e-10 cutoff.
  '''
  for root in getRoots(z, n):
    if isReal(root) or abs(root.imag) <= 1e-10 * abs(root):
      return root.real if isinstance(root, complex) else float(root)
  # print(), not printf(): Python has no printf builtin, so the original
  # call raised NameError instead of reporting the missing root.
  print("No real root found")
  return math.nan
    

def isReal(r):
  '''
  Decide whether r should be treated as a real number.

  Any non-complex value is considered real. A complex value is real
  when the magnitude of its imaginary part is below 1e-10.
  '''
  return not isinstance(r, complex) or abs(r.imag) < 1e-10


# --------------- cmath function plot info -------

def loadCMFuncs(path="cmFuncs.txt"):
  '''
  Read per-function plot ranges from a text file and return them as a dict.

  Blank lines and lines starting with '#' are ignored. Every other line
  must contain exactly five whitespace-separated fields:

      name  v1  v2  v3  v4

  The name is lower-cased and v1..v4 are parsed as floats. A line is
  accepted only if v1 < v2 and v3 < v4. Any other line is skipped with a
  printed warning. The four values are presumably a plot extent
  (realMin, realMax, imagMin, imagMax), as unpacked by showPhase();
  confirm against the caller.

  path -- file to read. Defaults to "cmFuncs.txt", resolved relative to
          the current working directory.

  Returns a dict mapping name -> [v1, v2, v3, v4]. If the file does not
  exist, an error is printed and an empty dict is returned.
  '''
  data = {}
  try:
    with open(path, "r") as fh:
      lines = fh.readlines()
  except FileNotFoundError:
    print(f"Error: {path} not found.")
    return data

  for rawLine in lines:
    line = rawLine.strip()
    if not line or line.startswith('#'):
      continue

    parts = line.split()
    if len(parts) != 5:
      print(f"Warning: Skipping malformed line: {line}")
      continue

    name = parts[0].lower()
    try:
      floats = [float(x) for x in parts[1:]]
    except ValueError:
      print(f"Warning: Skipping invalid numeric line: {line}")
      continue

    # Each (min, max) pair must be strictly increasing.
    if floats[0] < floats[1] and floats[2] < floats[3]:
      data[name] = floats
    else:
      print(f"Warning: Skipping line with invalid order: {line}")

  return data



def linspace(start, stop, num):
  '''
  Return num evenly spaced values from start to stop, both inclusive.

  This is a pure-Python analogue of numpy.linspace(start, stop, num).
  num <= 0 gives [] and num == 1 gives [start]. The final value is set
  to stop exactly, so rounding in start + step*i cannot leave the last
  sample slightly short of, or past, the requested end point.
  '''
  if num <= 0:
    return []
  if num == 1:
    return [start]
  step = (stop - start) / (num - 1)
  values = [start + step * i for i in range(num - 1)]
  values.append(stop)   # pin the endpoint exactly
  return values



def evalFunc(func, realVs, imagVs):
  '''
  Evaluate func at every grid point complex(re, im).

  Returns a list of rows: one row per value in imagVs, and each row holds
  one result per value in realVs.

  If func raises ValueError (e.g. a math domain error), ZeroDivisionError
  or OverflowError (e.g. cmath.exp of a large argument) at a point, that
  point is recorded as complex(nan, nan). One bad point therefore does
  not abort the whole grid.
  '''
  nanZ = complex(float('nan'), float('nan'))

  def safeEval(z):
    # Evaluate func at one point, mapping expected numeric failures to NaN.
    try:
      return func(z)
    except (ValueError, ZeroDivisionError, OverflowError):
      return nanZ

  return [ [safeEval(complex(re, im)) for re in realVs]
                  for im in imagVs ]


# ------------ phase coloring ------------


def showPhase(func, extent, title):
  '''
  Sample func over extent and display its phase diagram.

  extent is (realMin, realMax, imagMin, imagMax). Each axis is sampled
  at STEPS evenly spaced points with both end points included. The
  resulting colors, magnitudes and phases are passed to plotPhase(),
  together with extent and title.
  '''
  reLo, reHi, imLo, imHi = extent
  realAxis = linspace(reLo, reHi, STEPS)
  imagAxis = linspace(imLo, imHi, STEPS)
  colors, mags, phases = getPhaseInfo(func, realAxis, imagAxis)
  plotPhase(realAxis, imagAxis, extent, colors, mags, phases, title)



def getPhaseInfo(func, realVs, imagVs):
  '''
  Sample func at every grid point complex(re, im).

  Returns (colors, mags, phases): three lists of rows, one row per value
  in imagVs, each row holding one entry per value in realVs.
    colors -- (r, g, b) tuples from complexToHsbColorLog2()
    mags   -- |f(z)|
    phases -- arg(f(z)) from cmath.phase(), in [-pi, pi]

  If func raises ValueError, ZeroDivisionError or OverflowError at a
  point, that point gets NaN magnitude and phase. If |f(z)| itself
  overflows, the magnitude is stored as inf.

  Any point whose magnitude is not finite is painted black rather than
  passed to the color function. colorsys.hsv_to_rgb() calls int() on
  the hue and raises ValueError when the hue is NaN, so previously every
  undefined point crashed the whole plot.
  '''
  # Finite values never produce black with complexToHsbColorLog2,
  # since its value channel is at least 0.6.
  undefColor = (0.0, 0.0, 0.0)

  colors = []; mags = []; phases = []
  for im in imagVs:
    rowColors = []; rowMag = []; rowPhase = []
    for re in realVs:
      z = complex(re, im)
      try:
        w = func(z)
      except (ValueError, ZeroDivisionError, OverflowError):
        w = complex(float('nan'), float('nan'))

      try:
        mag = abs(w)
      except OverflowError:    # e.g. abs(complex(1e308, 1e308))
        mag = math.inf
      phase = cmath.phase(w)

      # Alternative schemes: complexToHlsColor, complexToHsbColor,
      # complexToHsbColorLog1 (same (phase, mag) signature).
      if math.isfinite(mag):
        color = complexToHsbColorLog2(phase, mag)
      else:
        color = undefColor

      rowColors.append(color)
      rowMag.append(mag)
      rowPhase.append(phase)

    colors.append(rowColors)
    mags.append(rowMag)
    phases.append(rowPhase)

  return colors, mags, phases


# ----------------------------------------
# Convert phase and magnitude to RGB 

def complexToHlsColor(phase, mag):
  '''
  Map (phase, mag) to an (r, g, b) tuple using HLS domain coloring.

  hue       -- (phase + pi) / (2*pi), mapping [-pi, pi] onto [0, 1]
               (arg 0 -> hue 0.5, cyan)
  lightness -- 1 / (1 + mag**0.3): mag 0 is white, and the color darkens
               toward black as mag grows
  saturation is fixed at 1.0.
  '''
  twoPi = 2 * cmath.pi
  hue = (phase + cmath.pi) / twoPi
  lightness = 1 / (1 + mag**0.3)
  return tuple(colorsys.hls_to_rgb(hue, lightness, 1.0))


def complexToHsbColor(phase, mag):
  '''
  Map (phase, mag) to an (r, g, b) tuple using HSV (HSB) domain coloring.

  hue        -- (phase + pi) / (2*pi), mapping [-pi, pi] onto [0, 1]
                (arg 0 -> hue 0.5, cyan)
  brightness -- 1 / (1 + mag**0.3): mag 0 is full brightness, and the
                color fades toward black as mag grows
  saturation is fixed at 1.0.
  '''
  twoPi = 2 * cmath.pi
  hue = (phase + cmath.pi) / twoPi
  brightness = 1 / (1 + mag**0.3)
  return tuple(colorsys.hsv_to_rgb(hue, 1.0, brightness))


def complexToHsbColorLog1(phase, mag):
  '''
  Map (phase, mag) to an (r, g, b) tuple using HSV coloring with periodic
  bands in log-magnitude.

  hue        -- (phase/(2*pi) + 0.5) mod 1, so arg 0 -> hue 0.5 (cyan)
                and arg +-pi -> hue 0 (red)
  brightness -- 0.5 + 0.5*(sin(3*log10(mag)) + 1)/2, which lies in
                [0.5, 1.0] and repeats every 2*pi/3 decades of mag.
                A larger factor than 3 gives denser bands. mag == 0
                gives 1.0.
  saturation is fixed at 1.0.
  '''
  hue = (phase / (2 * cmath.pi) + 0.5) % 1.0
  brightness = (1.0 if mag == 0
                else 0.5 + 0.5*(math.sin(3*math.log10(mag)) + 1)/2)
  return tuple(colorsys.hsv_to_rgb(hue, 1.0, brightness))



def complexToHsbColorLog2(phase, mag):
  '''
  Map (phase, mag) to an (r, g, b) tuple of floats using HSV domain
  coloring. This is the scheme getPhaseInfo() uses.

  hue        -- (phase / (2*pi)) mod 1: arg 0 -> red, and the hue
                advances around the color wheel as arg increases
  value      -- fractional part of log10(mag + 1), so the pattern
                repeats each time mag + 1 grows tenfold. It is then
                brightened to 0.6 + 0.4*v**0.6, giving [0.6, 1.0];
                mag == 0 gives full brightness.
  saturation -- effectively 1.0 (see the note below)

  Non-finite input (NaN or infinite phase or mag, e.g. at a point where
  the function could not be evaluated) returns black, (0.0, 0.0, 0.0).
  Finite input never maps to black because value >= 0.6.
  '''
  if not (math.isfinite(phase) and math.isfinite(mag)):
    # colorsys.hsv_to_rgb() calls int(h*6.0), which raises ValueError for
    # a NaN hue, and an infinite mag would make the value channel NaN.
    return (0.0, 0.0, 0.0)

  # Hue: phase of f(z), arg=0 -> red, increasing CCW
  hue = (phase / (2 * math.pi)) % 1.0

  # Value/Brightness: logarithmic, repeating every decade of (mag + 1)
  if mag == 0:
    value = 1.0
  else:
    value = math.log10(mag + 1) % 1.0

  # Brighten the base and apply a gamma curve for visual pop
  value = 0.6 + 0.4 * (value**0.6)

  # NOTE(review): intended as a slight desaturation for large mags (to
  # make branch cuts stand out), but value >= 0.6 at this point, so
  # (value < 0.1) is always False and saturation is always 1.0.
  # Kept unchanged to preserve the current colors. Decide whether the
  # test should use the pre-brightening value instead.
  saturation = 1.0 - 0.3 * (value < 0.1)

  r, g, b = colorsys.hsv_to_rgb(hue, saturation, value)
  return (r, g, b)


# -----------------------------------------
'''
A domain colored phase diagram with contour lines 
for phase and modulus
  - phase lines are black, modulus lines are white

With the default coloring (complexToHsbColorLog2, as used by getPhaseInfo), positive real values are colored red and negative real values are colored cyan. Zeros and poles occur at the points where all the colors meet.

At a zero, the colors run from red through yellow and green to blue in counter-clockwise order around the point.

At a pole, the same sequence of colors runs clockwise.

'''

def plotPhase(reals, imags, extent, colors, mags, phases, title):
  '''
  Draw a domain-colored phase diagram and display it with plt.show().

  colors  -- grid of (r, g, b) tuples, drawn with imshow over extent
             with the origin at the lower left
  mags    -- grid of |f(z)|, drawn as solid white contours at fixed
             levels and labelled to one decimal place
  phases  -- grid of arg(f(z)), drawn as dashed black contours every
             pi/8 from -pi to pi and labelled in degrees via degreeFmt
  reals, imags -- axis coordinates of the mags/phases grids
  title   -- appended to "Phase Diagram: " for the figure title
  '''
  plt.figure(figsize=(10, 7))
  plt.imshow(colors, origin='lower', extent=extent,
             interpolation='bilinear', aspect='equal')

  # |f(z)| contours: solid white, labelled
  modulusLevels = [0.2, 0.5, 1, 2, 3, 4, 5, 10, 15]
  modulusLines = plt.contour(reals, imags, mags, levels=modulusLevels,
                             colors='white', linewidths=0.7)
  plt.clabel(modulusLines, inline=True, fontsize=8, fmt='%.1f')

  # arg(f(z)) contours: dashed black, every pi/8 from -pi to pi
  # inclusive, i.e. 17 levels (both -pi and pi are included)
  argLevels = [-math.pi + k * math.pi/8 for k in range(17)]
  argLines = plt.contour(reals, imags, phases, levels=argLevels,
                         colors='black', linestyles='dashed',
                         linewidths=0.5)
  plt.clabel(argLines, inline=True, fontsize=8, fmt=degreeFmt)

  plt.title("Phase Diagram: " + title)
  plt.xlabel("Re(z)")
  plt.ylabel("Im(z)")
  plt.show()


def degreeFmt(angle):
  '''Format an angle in radians as whole degrees, e.g. pi -> "180".'''
  return "{:.0f}".format(math.degrees(angle))


def cubic(z, a, b, c, d):
  '''Evaluate the cubic polynomial a*z**3 + b*z**2 + c*z + d at z.'''
  cubeTerm = a * z**3
  squareTerm = b * z**2
  linearTerm = c * z
  # Summed left to right, in the same order as the expanded expression.
  return cubeTerm + squareTerm + linearTerm + d
