
# viewCMath.py
# Andrew Davison, ad@coe.psu.ac.th, Nov. 2025

'''
Visualizes cmath complex functions by plotting magnitude, real part, imaginary part, and phase over a 2D domain in the complex plane as 3D surfaces. Shows all four plots simultaneously in a 2x2 grid.

The user enters the function name at a prompt, and the axis ranges for that function are read from a funcs.txt file.
'''

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import cmath

import complexUtils
import numpy as np   # used for meshgrid() and array()


def extractCoords(plotType, zs):
  '''
  Extract one real-valued component from a 2D grid of complex values.

  plotType selects the component:
    'm' -> magnitude, abs(z)
    'r' -> real part, z.real
    'i' -> imaginary part, z.imag
    'p' -> phase, cmath.phase(z)

  zs is a sequence of rows, each row a sequence of complex values.

  Returns (label, vs): label is the axis label for the chosen component
  and vs is a list of lists laid out like zs.

  Raises ValueError if plotType is not one of the four codes. The check
  is made before zs is examined, so an empty grid still gets the proper
  label and a bad plotType is never silently accepted.
  '''
  # plotType -> (axis label, function applied to every grid value)
  extractors = {
    'm': ('|f(z)|', abs),
    'r': ('Re(f(z))', lambda z: z.real),
    'i': ('Im(f(z))', lambda z: z.imag),
    'p': ('arg(f(z))', cmath.phase),
  }
  if plotType not in extractors:
    raise ValueError("plotType must be 'm', 'r', 'i', or 'p'")

  label, extract = extractors[plotType]
  vs = [[extract(z) for z in row] for row in zs]
  return label, vs


# --------------------------------
steps = 150   # sample count passed to linspace() for both axes

# Table of axis ranges indexed by function name; nothing to do if it is empty
data = complexUtils.loadCMFuncs()
if not data:
  raise SystemExit

funcName = input("Enter function name: ").strip().lower()
if funcName not in data:
  print(f"No entry found for '{funcName}'.")
  raise SystemExit

# The name must resolve to something callable in cmath; constants such
# as cmath.pi are attributes too, so hasattr() alone is not enough.
func = getattr(cmath, funcName, None)
if not callable(func):
  print(f"No cmath function named '{funcName}' found.")
  raise SystemExit

print(f"Plotting {funcName}(z)")
realMin, realMax, imagMin, imagMax = data[funcName]
print(f"Real range: [{realMin}, {realMax}]")
print(f"Imaginary range: [{imagMin}, {imagMax}]")

reals = complexUtils.linspace(realMin, realMax, steps)
imags = complexUtils.linspace(imagMin, imagMax, steps)
zs = complexUtils.evalFunc(func, reals, imags)

# Create meshgrid once for all plots
realGrid, imagGrid = np.meshgrid(reals, imags)

# Create figure with 2x2 subplots
fig = plt.figure(figsize=(10, 8))

# Define plot types and their positions in the 2x2 grid
plotTypes = [('m', 1), ('r', 2), ('i', 3), ('p', 4)]
plotNames = {'m': 'Magnitude', 'r': 'Real Part',
              'i': 'Imaginary Part', 'p': 'Phase'}

for plotType, pos in plotTypes:
  ax = fig.add_subplot(2, 2, pos, projection='3d')

  zlabel, vs = extractCoords(plotType, zs)
  Z = np.array(vs)

  # Plot surface
  ax.plot_surface(realGrid, imagGrid, Z,
            cmap='viridis', alpha=0.9,
            edgecolor='none', antialiased=True)

  ax.set_xlabel('Re')
  ax.set_ylabel('Im')
  ax.set_zlabel(zlabel)
  ax.set_title(f'{funcName}(z): {plotNames[plotType]}')

  # Set viewing angle
  ax.view_init(elev=30, azim=225)   # 45

  # Add border around the subplot
  for spine in ax.spines.values():
    spine.set_edgecolor('lightgray')
    spine.set_linewidth(2)
  ax.patch.set_edgecolor('lightgray')
  ax.patch.set_linewidth(2)

plt.suptitle(f'Complex Function Visualization: {funcName}(z)', y=0.98)
plt.tight_layout()
plt.show()
