# Quat.py
# Andrew Davison, ad@coe.psu.ac.th, August 2025

'''
Quaternions extend complex numbers. 
Instead of just one value whose square is -1 
we have three -- i, j and k such that:
  i^2 = j^2 = k^2 = -1  
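These are related as follows:
  1. i . j = k
  2. j . k = i
  3. k . i = j
Multiplication is not commutative: reversing the order 
flips the sign (j . i = -k, k . j = -i, i . k = -j).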

A quaternion has the form: 
  q = s+ai+bj+ck with s,a,b,c in R

Often quaternions are separated into their scalar 
and vector terms:  
  q = s+v or q = [s,v] where s in R and v = ai+bj+ck 
(The Quat class below stores s as `w` and v as `v`.)

For v, I am using my Python Vec class.

The cross and dot products on v simplify quaternion 
multiplication quite a bit (see __mul__()).
'''


import math
from Vec import Vec

class Quat:
  """Quaternion q = w + v with a float scalar part `w` and a Vec vector
  part `v` holding the i, j, k coefficients.

  Binary operators return new Quat objects.  The in-place operators
  (+=, -=, *=, /=) rebind self.w and self.v and return self.
  The rotation helpers (rotateVec, toAxisAngle, toMatrix3, toEuler, ...)
  expect a unit quaternion.

  Because __eq__ is defined without __hash__, Quat objects are unhashable.
  """

  def __init__(self, w=1.0, v=None):
    """Build w + v; the defaults give the identity quaternion [1, (0,0,0)].

    w: int or float scalar part (stored as a float).
    v: Vec vector part, or None for the zero vector.  A supplied Vec is
       stored as-is, not copied.
    Raises TypeError for any other argument types.
    """
    if not isinstance(w, (int, float)):
      raise TypeError("w must be a number")
    if v is not None and not isinstance(v, Vec):
      raise TypeError("v must be a Vec instance or None")

    self.w = float(w)
    self.v = v if v is not None else Vec(0.0, 0.0, 0.0)

  def __repr__(self):
    return f"Quat({self.w:.3f}, {self.v})"

  def __eq__(self, other):
    """Approximate equality.

    Compares w with math.isclose (default tolerances; abs_tol is 0, so a
    w of exactly 0 only equals another 0) and v with Vec.isclose.
    Returns NotImplemented for non-Quat operands.
    """
    if not isinstance(other, Quat):
      return NotImplemented
    return math.isclose(self.w, other.w) and self.v.isclose(other.v)

  def __neg__(self):
    """Return -q (the same rotation as q when q is a unit quaternion)."""
    return Quat(-self.w, -self.v)

  def copy(self):
    """Return an independent copy (the Vec part is copied too)."""
    return Quat(self.w, self.v.copy())


  def __add__(self, other):
    """Component-wise sum."""
    return Quat(self.w + other.w, self.v + other.v)

  def __iadd__(self, other):
    """In-place addition: self += other.

    self.v is rebound to a new Vec instead of being updated in place, so a
    Vec shared with another object is never modified.
    """
    self.w += other.w
    self.v = self.v + other.v
    return self

  def __sub__(self, other):
    """Component-wise difference."""
    return Quat(self.w - other.w, self.v - other.v)

  def __isub__(self, other):
    """In-place subtraction: self -= other (rebinds self.v, see __iadd__)."""
    self.w -= other.w
    self.v = self.v - other.v
    return self

  def __mul__(self, other):
    """Hamilton product with another Quat, or scaling by an int/float.

    (w1,v1)(w2,v2) = (w1*w2 - v1.v2, w1*v2 + w2*v1 + v1 x v2); the
    product is not commutative.  Raises TypeError for other operand types.
    """
    if isinstance(other, Quat):
      w1, v1 = self.w, self.v
      w2, v2 = other.w, other.v
      return Quat( w1*w2 - v1.dot(v2),
                   w1*v2 + w2*v1 + v1.cross(v2) )
    elif isinstance(other, (int, float)): # Scalar mult
      return Quat(self.w*other, self.v*other)
    else:
      raise TypeError("Unsupported operand for multiplication")

  def __rmul__(self, other):
    """scalar * Quat.  A left-hand Quat is handled by its own __mul__, so
    only scalars (or unsupported types) arrive here."""
    return self.__mul__(other)

  def __imul__(self, other):
    """In-place product: self = self * other (Quat or int/float).

    self.v is rebound rather than updated in place (see __iadd__).
    Raises TypeError for other operand types.
    """
    if isinstance(other, Quat):
      w1, v1 = self.w, self.v
      w2, v2 = other.w, other.v
      # Both new parts are computed from the saved locals, so `q *= q`
      # is safe even though self.w is rebound first.
      self.w = w1*w2 - v1.dot(v2)
      self.v = w1*v2 + w2*v1 + v1.cross(v2)
    elif isinstance(other, (int, float)):
      self.w *= other
      self.v = self.v * other
    else:
      raise TypeError("Unsupported operand for in-place multiplication")
    return self

  def __truediv__(self, scalar):
    """Divide both parts by a scalar."""
    return Quat(self.w/scalar, self.v/scalar)

  def __itruediv__(self, scalar):
    """In-place division by a scalar only (rebinds self.v, see __iadd__).

    Raises ZeroDivisionError when scalar == 0.
    """
    if scalar == 0:
      raise ZeroDivisionError("Cannot divide by zero")
    self.w /= scalar
    self.v = self.v / scalar
    return self


  def conjugate(self):
    """Return [w, -v]."""
    return Quat(self.w, -self.v)

  def normSquared(self):
    """Return w^2 + |v|^2."""
    return self.w**2 + self.v.normSquared()

  def norm(self):  # also called magnitude
    """Return the Euclidean length sqrt(w^2 + |v|^2)."""
    return math.sqrt(self.normSquared())

  def normalize(self):
    """Return a unit-length copy of this quaternion.

    Raises ZeroDivisionError for the zero quaternion.
    """
    n2 = self.normSquared()
    if n2 == 0:
      raise ZeroDivisionError("Cannot normalize a zero quaternion")
    if math.isclose(n2, 1.0, rel_tol=1e-12):  # Already normalized
      return self.copy()
    return self / math.sqrt(n2)

  def inverse(self):
    """Return q^-1 = conjugate(q) / |q|^2.

    Raises ZeroDivisionError for the zero quaternion.
    """
    n2 = self.normSquared()
    if n2 == 0:
      raise ZeroDivisionError("Cannot invert a zero quaternion")
    return self.conjugate()/n2


  def isUnit(self, tolerance=1e-6):
    """True if the norm is 1 within the relative `tolerance`."""
    return math.isclose(self.norm(), 1.0, rel_tol=tolerance)

  def dot(self, other):
    """4D dot product w1*w2 + v1.v2."""
    return self.w * other.w + self.v.dot(other.v)

  def distance(self, other):
    """Angle in radians, within [0, pi/2], between the two quaternions as
    points on the unit 4-sphere, with q and -q treated as the same point.

    Both inputs are normalized first.  The corresponding 3D rotation
    angle between them is twice this value.
    """
    return math.acos(min(1.0, abs(self.normalize().dot(other.normalize()))))

  def distanceDegrees(self, other):
    """distance() converted to degrees."""
    return math.degrees(self.distance(other))

  def power(self, t):
    """Raise to the real power t (useful for scaling a rotation).

    With a zero vector part the result is [w**t, 0] via math.pow, which
    raises ValueError for a negative w and non-integer t.  Otherwise
    q**t = exp(t * log(q)); since log() normalizes, that result is a unit
    quaternion and the magnitude of self is ignored.
    """
    if self.v.isZero():
      return Quat(math.pow(self.w, t), Vec(0, 0, 0))

    log_q = self.log()
    return (log_q * t).exp()


  def rotateVec(self, vec):
    """Rotate a Vec: return the vector part of p * [0, vec] * conj(p).

    The conjugate equals the inverse only for a unit quaternion, so a
    non-unit self prints a warning but still computes the product.
    Raises TypeError if vec is not a Vec.
    """
    if not isinstance(vec, Vec):
      raise TypeError("Input must be a Vec instance")

    # Warn but don't fail for non-unit quaternions
    if not self.isUnit(1e-6):
      print("Warning: Quaternion is not unit length. Consider normalizing first.")

    v = Quat(0, vec)
    return (self * v * self.conjugate()).v


  def log(self):
    """Logarithm of the unit quaternion in the direction of self.

    self is normalized first, so its magnitude is ignored.  For a unit
    q = cos(theta) + u sin(theta), with u a unit vector and theta in
    [0, pi], log(q) = [0, u*theta].  Returns the zero quaternion when the
    vector part is zero.  Raises ZeroDivisionError for a zero quaternion.
    """
    q = self.normalize()
    vMag = q.v.magnitude()
    if vMag == 0:
      return Quat(0, Vec(0, 0, 0))
    # atan2(|v|, w) equals acos(w) for a unit quaternion.  Unlike acos, it
    # cannot raise a domain error when rounding leaves |w| just above 1,
    # and it stays accurate near w = +/-1.
    theta = math.atan2(vMag, q.w)
    # For a unit quaternion sin(theta) == |v|, so u*theta == v*(theta/|v|).
    return Quat(0, q.v * (theta / vMag))


  def exp(self):
    """Exponential: exp(w + v) = e^w * (cos|v| + (v/|v|) sin|v|).

    For a pure quaternion (w == 0, as produced by log() and used by
    power() and spline()) the e^w factor is exactly 1, giving the unit
    quaternion cos|v| + (v/|v|) sin|v|.
    """
    scale = math.exp(self.w)
    vMag = self.v.magnitude()
    w = scale * math.cos(vMag)
    if vMag == 0:
      return Quat(w, Vec(0, 0, 0))
    return Quat(w, self.v * (scale * math.sin(vMag) / vMag))



  @staticmethod
  def lerps(q0, q1, nSteps):
    """nSteps lerp samples from q0 toward q1 (see interps)."""
    return Quat.interps(Quat.lerp, q0, q1, nSteps)

  @staticmethod
  def slerps(q0, q1, nSteps):
    """nSteps slerp samples from q0 toward q1 (see interps)."""
    return Quat.interps(Quat.slerp, q0, q1, nSteps)

  @staticmethod
  def slerpsNInv(q0, q1, nSteps):
    """nSteps slerpNInv samples from q0 toward q1 (see interps)."""
    return Quat.interps(Quat.slerpNInv, q0, q1, nSteps)


  @staticmethod
  def interps(fn, q0, q1, nSteps):
    """Return [fn(q0, q1, t)] for t = 0, 1/nSteps, 2/nSteps, ... while
    t < 1.  The end point t = 1 is not included.

    For an integer nSteps this gives exactly nSteps samples.  It returns
    an empty list when nSteps <= 0.
    """
    # Derive each t from its index instead of summing a float step.
    # Repeated addition drifts: ten additions of 0.1 give
    # 0.9999999999999999, which is still < 1 and added an 11th sample.
    # ceil() keeps non-integer nSteps working (samples with i < nSteps).
    return [fn(q0, q1, i / nSteps) for i in range(math.ceil(nSteps))]


  @staticmethod
  def lerp(q0, q1, t):
    """Normalized linear interpolation: normalize((1-t)*q0 + t*q1).

    t is clamped to [0, 1].  Normalizing keeps the result a valid rotation
    quaternion.  It raises ZeroDivisionError if the blend is the zero
    quaternion (e.g. q1 == -q0 at t = 0.5).
    """
    t = max(0, min(1, t))
    return ((q0 * (1 - t)) + (q1 * t)).normalize()


  @staticmethod
  def slerp(q0, q1, t):
    """Spherical linear interpolation between q0 and q1 (Shoemake 85).

    Takes the shorter arc by negating q1 when q0 . q1 < 0.  t is clamped
    to [0, 1], and both inputs are normalized first.  Falls back to lerp
    when the inputs are nearly identical.
    """
    t = max(0, min(1, t))
    q0 = q0.normalize()
    q1 = q1.normalize()

    # cosine of the angle between the quats
    dotCos = q0.w * q1.w + q0.v.dot(q1.v)

    if dotCos < 0:
      # negate q1 to take shorter path
      q1 = -q1
      dotCos = -dotCos

    # Clamp dotCos to avoid numerical issues with acos
    dotCos = max(-1.0, min(1.0, dotCos))

    # If quaternions are very close, use linear interpolation to avoid div by zero
    if dotCos > 0.9995:
      return Quat.lerp(q0, q1, t)

    theta = math.acos(dotCos)
    sinTheta = math.sin(theta)
    t_theta = theta * t
    # s0 = sin((1-t)*theta)/sin(theta), expanded; s1 = sin(t*theta)/sin(theta)
    s0 = math.cos(t_theta) - (math.sin(t_theta) * dotCos)/sinTheta
    s1 = math.sin(t_theta) / sinTheta
    return (s0*q0 + s1*q1).normalize()


  @staticmethod
  def spline(qPrev, q, qNext):
    """Compute control point for quaternion spline interpolation (Eberly equ 31):
    q * exp(-(log(q^-1 qPrev) + log(q^-1 qNext)) / 4)."""
    qInv = q.inverse()
    delta = ((qInv * qPrev).log() +
             (qInv * qNext).log()) * -0.25
    return q * delta.exp()


  @staticmethod
  def slerpNInv(q0, q1, t):
    """Slerp WITHOUT the shortest-path check (q1 is never negated).

    When q0 . q1 < 0 the result follows the longer arc.  This is the form
    squad() requires.  t is clamped to [0, 1], and both inputs are
    normalized first.  Exactly opposite inputs (q1 == -q0) have no unique
    arc, so the result is then numerically meaningless, or normalize()
    may raise ZeroDivisionError.
    """
    t = max(0, min(1, t))
    q0 = q0.normalize()
    q1 = q1.normalize()
    dotCos = q0.w * q1.w + q0.v.dot(q1.v)

    # Clamp dotCos to avoid numerical issues with acos
    dotCos = max(-1.0, min(1.0, dotCos))

    # Nearly identical quaternions: sin(theta) ~ 0, and lerp is accurate.
    if dotCos > 0.9995:
      return Quat.lerp(q0, q1, t)

    # theta must come from the SIGNED cosine because q1 is not negated.
    # Using abs(dotCos) would keep the end points but sweep the arc at a
    # non-uniform speed whenever q0 . q1 < 0.
    theta = math.acos(dotCos)
    sinTheta = math.sin(theta)
    t_theta = theta * t
    s0 = math.cos(t_theta) - \
         (math.sin(t_theta) * dotCos)/sinTheta
    s1 = math.sin(t_theta) / sinTheta
    return (s0*q0 + s1*q1).normalize()


  @staticmethod
  def squad(q0, q1, a, b, t):
    """Spherical cubic interpolation (Shoemake 87).

    a and b form a quadrangle with q0 and q1 that gives smooth
    interpolation over a list of rotation keyframes (see spline()).
    """
    c = Quat.slerpNInv(q0, q1, t)
    d = Quat.slerpNInv(a, b, t)
    return Quat.slerpNInv(c, d, 2*t*(1 - t))



  # ----------- conversion utils -----------


  @staticmethod
  def fromLatLon(latDeg, lonDeg):
    """Convert latitude and longitude (in degrees) to the pure quaternion
    q = [0, x, y, z] of that point on the unit sphere centered at origin.
    """
    latRad = math.radians(latDeg)
    lonRad = math.radians(lonDeg)

    x = math.cos(latRad) * math.cos(lonRad)
    y = math.cos(latRad) * math.sin(lonRad)
    z = math.sin(latRad)
    return Quat(0, Vec(x, y, z))


  def toLatLon(self):
    """Convert q = [0, x, y, z] back to (latDeg, lonDeg).

    The vector part is scaled to unit length first, so it need not lie
    exactly on the unit sphere.  Raises ValueError if it is zero.
    """
    axis = self.v
    r = axis.magnitude()
    if r == 0:
      raise ValueError("Quaternion has zero vector part")

    # Clamp so rounding in z/r can never push asin outside its domain.
    sinLat = max(-1.0, min(1.0, axis.z/r))
    latDeg = math.degrees(math.asin(sinLat))
    lonDeg = math.degrees(math.atan2(axis.y, axis.x))
    return latDeg, lonDeg



  @staticmethod
  def fromAxisAngle(axis, angleRad):
    """Rotation of angleRad radians about `axis` (normalized here):
    q = [cos(angle/2), axis*sin(angle/2)]."""
    half = angleRad/2
    return Quat(math.cos(half),
                axis.normalize()*math.sin(half))


  def toAxisAngle(self):
    """Return (axis, angleRad) for this (assumed unit) quaternion.

    angle = 2*acos(w) lies in [0, 2*pi].  When the rotation is
    (numerically) the identity, the axis is undefined, so Vec(1, 0, 0)
    and an angle of 0.0 are returned.
    """
    # Clamp: rounding can push |w| of a unit quaternion just past 1,
    # which would make acos/sqrt raise a math domain error.
    w = max(-1.0, min(1.0, self.w))
    angle = 2*math.acos(w)
    sinHalf = math.sqrt(1 - w**2)
    # Avoid division by zero or very small numbers
    if sinHalf < 1e-8:
      return Vec(1, 0, 0), 0.0
    axis = self.v/sinHalf
    return axis, angle


  @staticmethod
  def fromEuler(roll, pitch, yaw):
    """Convert Euler angles (in radians) to a quaternion.

    The result equals qz(yaw) * qy(pitch) * qx(roll).  Applied to a
    vector, roll about x acts first, then pitch about y, then yaw about z
    (all about the fixed axes).
    """
    cy = math.cos(yaw/2)
    sy = math.sin(yaw/2)
    cp = math.cos(pitch/2)
    sp = math.sin(pitch/2)
    cr = math.cos(roll/2)
    sr = math.sin(roll/2)

    w = cr*cp*cy + sr*sp*sy
    x = sr*cp*cy - cr*sp*sy
    y = cr*sp*cy + sr*cp*sy
    z = cr*cp*sy - sr*sp*cy
    return Quat(w, Vec(x, y, z))


  def toEuler(self):
    """Convert to Euler angles (roll, pitch, yaw) in radians, using the
    same convention as fromEuler().  Assumes a unit quaternion.

    When |sin(pitch)| >= 1 (gimbal lock, or rounding), pitch is set to
    +/-pi/2.
    """
    x, y, z = self.v.x, self.v.y, self.v.z
    w = self.w

    # Roll (x-axis rotation)
    sinrCosp = 2 * (w * x + y * z)
    cosrCosp = 1 - 2 * (x * x + y * y)
    roll = math.atan2(sinrCosp, cosrCosp)

    # Pitch (y-axis rotation)
    sinp = 2 * (w * y - z * x)
    if abs(sinp) >= 1:
      # Out of asin's range: use +/-90 degrees.  The math module has no
      # sign(); copysign transfers sinp's sign onto pi/2.
      pitch = math.copysign(math.pi / 2, sinp)
    else:
      pitch = math.asin(sinp)

    # Yaw (z-axis rotation)
    sinyCosp = 2 * (w * z + x * y)
    cosyCosp = 1 - 2 * (y * y + z * z)
    yaw = math.atan2(sinyCosp, cosyCosp)

    return roll, pitch, yaw



  def toMatrix3(self):
    """Return the 3x3 rotation matrix as a list of 3 row lists.
    Assumes a unit quaternion.
    """
    x, y, z = self.v.x, self.v.y, self.v.z
    w = self.w

    return [
      [1 - 2*(y*y + z*z),     2*(x*y - z*w),     2*(x*z + y*w)],
      [    2*(x*y + z*w), 1 - 2*(x*x + z*z),     2*(y*z - x*w)],
      [    2*(x*z - y*w),     2*(y*z + x*w), 1 - 2*(x*x + y*y)] ]


  def toMatrix4(self):
    """Return the 4x4 homogeneous rotation matrix: toMatrix3() with a
    zero translation column appended and a bottom row of [0, 0, 0, 1].
    """
    m = self.toMatrix3()
    return [ m[0] + [0],
             m[1] + [0],
             m[2] + [0],
             [0, 0, 0, 1]  ]


# -------------------


def quatCurve(quats, segSteps):
  """
  Build a smooth rotation curve through a list of key quaternions.

  Every consecutive pair of keys (start, end) is one segment.  The
  segment's two control points come from Quat.spline(), using the keys on
  either side.  At the ends of the list the missing neighbour is replaced
  by the end key itself.  Each segment contributes segSteps samples
  from Quat.squad() at t = 0, 1/segSteps, ..., (segSteps-1)/segSteps.
  The last key is then appended, so the curve finishes exactly on it.

  An empty input gives [], and a single key gives a one-element list.
  """
  if not quats:
    return []
  if len(quats) == 1:
    return [quats[0]]

  lastIdx = len(quats) - 1
  curve = []
  for seg in range(lastIdx):
    start, end = quats[seg], quats[seg + 1]
    # Neighbours, clamped to the valid index range.
    before = quats[seg - 1] if seg > 0 else quats[0]
    after = quats[seg + 2] if seg + 2 <= lastIdx else quats[lastIdx]

    ctrlA = Quat.spline(before, start, end)
    ctrlB = Quat.spline(start, end, after)
    curve.extend(Quat.squad(start, end, ctrlA, ctrlB, k / segSteps)
                 for k in range(segSteps))

  # Close the curve on the final key.
  curve.append(quats[-1])
  return curve



# --- test rig -------------

if __name__ == "__main__":
  # Demo 1: a 90-degree turn about z should carry the x-axis onto y.
  zAxis = Vec(0, 0, 1)
  print("Create quat from Axis, Angle:\n", zAxis, ",", 90)
  rot = Quat.fromAxisAngle(zAxis, math.pi/2)
  print("Result:", rot)
  xAxis = Vec(1, 0, 0)
  rotated = rot.rotateVec(xAxis)
  print(f"Quat rotates {xAxis} --> {rotated}")

  # Demo 2: round-trip Euler angles through a quaternion, then show
  # the equivalent rotation matrix.
  print("\nRoll (x), pitch (y), yaw (z): 30, 0, 0")
  eulerRad = [math.radians(deg) for deg in (30, 0, 0)]
  q = Quat.fromEuler(*eulerRad)
  print("Quat version:", q)
  recovered = [round(math.degrees(ang), 1) for ang in q.toEuler()]
  print("Recovered roll/pitch/yaw:", *recovered)
  print("As a matrix:")
  for row in q.toMatrix3():
    print(" ", row)

