
# cubicRoots.py
# Andrew Davison, ad@coe.psu.ac.th, Nov. 2025

'''
Solves equations of the form: ax^3 + bx^2 + cx + d = 0
https://en.wikipedia.org/wiki/Cubic_equation

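Combines Cardano's formula (one real root plus two complex-conjugate
roots) with Viete's trigonometric method (three distinct real roots);
repeated real roots are handled in closed form.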

'''


import math, cmath
import cubicUtils, complexUtils


def findCubicRoots(coefs):
  '''Solve, report on, and sanity-check a single cubic equation.

  coefs: a 4-item sequence (a, b, c, d) for a*x^3 + b*x^2 + c*x + d = 0.

  Prints the equation (as formatted by cubicUtils.toString), the
  discriminant with its interpretation, and the roots.  If the roots
  do not reproduce the coefficients (see checkRoots), "**ERROR**" is
  printed too.  A ValueError from cvRoots propagates when a == 0.
  Returns None.
  '''
  a, b, c, d = coefs
  equationText = cubicUtils.toString(a, b, c, d)
  print("Equation:", equationText)

  discriminant, foundRoots = cvRoots(a, b, c, d)
  reportDiscriminant(discriminant)
  printRoots(foundRoots)

  rootsAgree = checkRoots(foundRoots, a, b, c, d)
  if not rootsAgree:
    print("**ERROR**")




def cvRoots(a, b, c, d):
  '''Find the three roots of a*x^3 + b*x^2 + c*x + d = 0.

  The substitution x = t - b/(3a) gives the depressed cubic
  t^3 + p*t = q.  Note that q here is the NEGATION of the textbook
  q in t^3 + p*t + q = 0.  The sign of
      disc = (q/2)^2 + (p/3)^3
  (opposite in sign to the usual discriminant) picks the method:
    abs(disc) < 1e-12 -> repeated real roots, closed form
    disc > 0          -> Cardano: one real root + two complex conjugates
    disc < 0          -> Viete (trigonometric): three distinct real roots
  The Cardano and Viete branches print "Cardano" / "Vieta" respectively.

  Args:
    a, b, c, d: real coefficients; a must be non-zero.

  Returns:
    (disc, roots): roots is a list of three values.  In the Cardano
    branch they are complex (the first with zero imaginary part);
    otherwise they are floats.

  Raises:
    ValueError: if a == 0 (the equation is not a cubic).
  '''
  if a == 0:
    raise ValueError("Coef 'a' must not be zero")

  # Depressed cubic via x = t + shift, shift = -b/(3a):  t^3 + p*t = q
  p = (3*a*c - b*b)/(3*a*a)
  q = -(2*b**3 - 9*a*b*c + 27*a*a*d)/(27*a**3)
  disc = (q/2)**2 + (p/3)**3
  shift = -b/(3*a)

  if abs(disc) < 1e-12:  # disc == 0
    # repeated real roots
    if abs(p) < 1e-12 and abs(q) < 1e-12:
      # Triple root; p == q == 0, so t = 0.
      # (0 + shift also turns a -0.0 shift into 0.0 for printing.)
      t = 0
      roots = [t + shift, t + shift, t + shift]
    else:
      # One simple real root (2s) and a double root (-s), s = cbrt(q/2)
      s = cubeRootReal(q/2)
      roots = [2*s + shift, -s + shift, -s + shift]

  elif disc > 0:
    # Cardano's formula: one real root and two complex conjugates.
    print("Cardano")
    sqrtD = math.sqrt(disc)
    # Real root is t - u with t = cbrt(q/2 + sqrtD), u = cbrt(-q/2 + sqrtD),
    # both REAL cube roots, hence t*u = cbrt(sqrtD^2 - q^2/4) = p/3.
    # Take the cube root of whichever radicand adds like-signed terms
    # (no cancellation; it is >= sqrtD > 0 here) and derive the other
    # value from t*u = p/3.  This also avoids Python's complex principal
    # root for negative bases, which was the source of a sign error.
    if q >= 0:
      t = cubeRootReal(q/2 + sqrtD)
      u = p/(3*t)
    else:
      u = cubeRootReal(-q/2 + sqrtD)
      t = p/(3*u)
    x1 = t - u
    # The other two roots use the primitive cube root of unity
    # w = e^(2*pi*i/3).
    w = cmath.exp(2*math.pi*1j/3)
    x2 = w*t - w*w*u
    x3 = w*w*t - w*u
    roots = [complex(x1 + shift, 0),
             x2 + shift, x3 + shift]

  else:  # disc < 0, which forces p < 0
    # Viete's trigonometric method: three distinct real roots.
    print("Vieta")
    arg = q/(2*math.sqrt(-(p/3)**3))
    # Mathematically |arg| < 1 here; clamp so floating-point rounding
    # near disc ~ 0 cannot push it outside acos's domain (ValueError).
    arg = max(-1.0, min(1.0, arg))
    theta = math.acos(arg)/3
    m = 2*math.sqrt(-p/3)
    roots = [
      m*math.cos(theta) + shift,
      m*math.cos(theta + 2*math.pi/3) + shift,
      m*math.cos(theta - 2*math.pi/3) + shift ]

  return disc, roots



def cubeRootReal(x):
  '''Return the real cube root of a real number x, keeping its sign.

  Unlike x**(1/3), which yields a complex principal root for negative
  x, this returns a negative float for negative x.  Returns 0 for 0.
  '''
  if x == 0:
    return 0
  return math.copysign(abs(x)**(1/3), x)


def reportDiscriminant(disc, tol=1e-12):
  '''Print the discriminant and what its sign says about the roots.

  Args:
    disc: the value returned by cvRoots, (q/2)^2 + (p/3)^3 of the
      depressed cubic (its sign is opposite to the usual discriminant).
    tol: values with abs(disc) < tol are treated as zero.  The default
      matches the threshold cvRoots uses to choose its branch, so the
      message printed here agrees with the kind of roots it computed.
  '''
  print(f"Discriminant: {disc:.4f}")
  if abs(disc) < tol:   # treat as 0
    print("All roots are real, and at least two are equal")
  elif disc > 0:
    print("One real root and two complex conjugate roots")
  else: # < 0
    print("All three roots are real and distinct")



def printRoots(roots):
  '''Print roots one per line, indented two spaces, to 4 decimal places.

  A complex value whose imaginary part is below 1e-10 in magnitude is
  printed as its real part only; other complex values are printed as
  "re+imj".  An empty list prints "No roots found".
  '''
  if roots == []:
    print("No roots found")
    return

  print("Roots:")
  for r in roots:
    isComplex = isinstance(r, complex)
    if isComplex and abs(r.imag) < 1e-10:   # effectively real
      print(f"  {r.real: .4f}")
    elif isComplex:
      print(f"  {r.real: .4f}{r.imag:+.4f}j")
    else:
      print(f"  {r: .4f}")



def checkRoots(roots, a, b, c, d):
  '''Check three roots against the coefficients using Vieta's formulas.

  For a*x^3 + b*x^2 + c*x + d with roots z1, z2, z3:
    -(z1 + z2 + z3)          == b/a
    z1*z2 + z1*z3 + z2*z3    == c/a
    -z1*z2*z3                == d/a
  Each comparison uses cmath.isclose with abs_tol=1e-10.  On failure the
  three reconstructed values are printed.  Returns True/False.
  '''
  r1, r2, r3 = roots
  negSum = -(r1 + r2 + r3)
  pairSum = r1*r2 + r1*r3 + r2*r3
  negProd = -r1 * r2 * r3

  # all() stops at the first mismatch, and coef/a is evaluated lazily.
  pairs = ((negSum, b), (pairSum, c), (negProd, d))
  agree = all(cmath.isclose(value, coef/a, abs_tol=1e-10)
              for value, coef in pairs)
  if not agree:
    print(negSum, pairSum, negProd)
  return agree


def getReals(xs):
  '''Return a list of the real parts of xs as floats.

  Complex values contribute their .real part; anything else is
  converted with float().
  '''
  reals = []
  for value in xs:
    if isinstance(value, complex):
      reals.append(value.real)
    else:
      reals.append(float(value))
  return reals


# ------------------------------
if __name__ == "__main__":
  # Each tuple is (a, b, c, d) for a*x^3 + b*x^2 + c*x + d = 0.
  # NOTE(review): (1, 0, -3, 1) appears both first and last; kept so the
  # printed output is unchanged -- drop one if the repeat is unintended.
  sampleCubics = (
    (1, 0, -3, 1),   # Three distinct real roots (Vieta branch)
    (1, -6, 11, -6), # Three distinct real roots (1,2,3)
    (1, -3, 3, -1),  # Triple root at 1 -> (x-1)^3
    (1, 3, 3, 1),    # Triple root at -1 -> (x+1)^3
    (1, 0, -3, 2),   # One double + one single real root
    (1, 2, 3, 4),    # One real + two complex
    (1, 0, 1, 0),    # One real + two complex
    (1, 0, 2, 2),    # One real + two complex
    (1, 0, 1, 1),    # One real + two complex
    (1, 1, 1, 1),    # One real + two complex
    (2, -4, 6, 8),   # One real + two complex (scaled)
    (-1, 2, 3, 4),   # One real + two complex (negative a)
    (-2, 2, 3, 4),   # One real + two complex (negative, scaled a)
    (1, 0, 6, -20),  # One real + two complex (section example)
    (1, 0, -15, -4), # Three distinct real roots (section example)
    (1, 0, -3, 1),   # Three distinct real roots (section example)
  )

  for coefficients in sampleCubics:
    findCubicRoots(coefficients)
    print()   # blank line between reports

