
# bezUtils.py
# Andrew Davison, Sept. 2025, ad@coe.psu.ac.th

'''
  Generate Bézier curves using:
    * Bernstein polynomials (2D control points)
    * Bernstein polynomials (3D control points)
    * rational Bézier curves: Bernstein polynomials plus weights
    * De Casteljau's algorithm (repeated linear interpolation, LERP)
    * Matplotlib Path objects
  and provide
    drawControlPts() for drawing the control points, the control
    polygon lines, and optional weight labels.
'''

import math
import matplotlib.pyplot as plt
from matplotlib.path import Path



def bezierCurve(ctrlPts, nSamps=100):
  """
  Compute 2D Bézier curve points for any number of control points
  using Bernstein polynomials.

  ctrlPts: sequence of (x, y) control points (at least one).
  nSamps: number of sample intervals along the curve; values below 2
          are raised to 2. The curve is evaluated at
          u = 0, 1/nSamps, ..., 1, so nSamps + 1 points are returned.
  Returns (xs, ys): two lists of curve coordinates.
  Raises ValueError if ctrlPts is empty.
  """
  # len() rather than truthiness so NumPy arrays of points also work
  if len(ctrlPts) == 0:
    raise ValueError("at least one control point is required")
  n = len(ctrlPts) - 1   # curve degree
  xs = []
  ys = []
  if nSamps < 2:
    nSamps = 2
  for step in range(nSamps + 1):
    u = step / nSamps
    qx, qy = 0, 0
    for i in range(n + 1):
      b = bernstein(n, i, u)
      qx += b * ctrlPts[i][0]
      qy += b * ctrlPts[i][1]
    xs.append(qx)
    ys.append(qy)
  return xs, ys


def bernstein(n, i, u):
  """Evaluate the Bernstein basis polynomial B_{i,n} at parameter u.

  B_{i,n}(u) = C(n, i) * u^i * (1 - u)^(n - i)
  """
  binom = math.comb(n, i)
  rising = u ** i
  falling = (1 - u) ** (n - i)
  return binom * rising * falling


def bezier3D(ctrlPts, nSamps=100):
  """
  Compute 3D Bézier curve points for any number of control points
  using Bernstein polynomials.

  ctrlPts: sequence of (x, y, z) control points (at least one).
  nSamps: number of sample intervals along the curve; values below 2
          are raised to 2. The curve is evaluated at
          u = 0, 1/nSamps, ..., 1, so nSamps + 1 points are returned.
  Returns (xs, ys, zs): three lists of curve coordinates.
  Raises ValueError if ctrlPts is empty.
  """
  # len() rather than truthiness so NumPy arrays of points also work
  if len(ctrlPts) == 0:
    raise ValueError("at least one control point is required")
  n = len(ctrlPts) - 1   # curve degree
  xs = []
  ys = []
  zs = []
  if nSamps < 2:
    nSamps = 2
  for step in range(nSamps + 1):
    u = step / nSamps
    qx, qy, qz = 0, 0, 0
    for i in range(n + 1):
      b = bernstein(n, i, u)
      qx += b * ctrlPts[i][0]
      qy += b * ctrlPts[i][1]
      qz += b * ctrlPts[i][2]
    xs.append(qx)
    ys.append(qy)
    zs.append(qz)
  return xs, ys, zs



# -------------------------------------------

def ratBezier(ctrlPts, weights, nSamps=100):
  """
  Compute 2D rational Bézier curve points.

  Each curve point is the weighted Bernstein combination of the control
  points divided by the sum of the weighted basis values:
      C(u) = sum(B_i(u) * w_i * P_i) / sum(B_i(u) * w_i)

  ctrlPts: sequence of (x, y) control points (at least one).
  weights: one weight per control point.
  nSamps: number of sample intervals (values below 2 are raised to 2);
          nSamps + 1 points are returned for u = 0, 1/nSamps, ..., 1.
  Returns (xs, ys): two lists of curve coordinates.
  Raises ValueError if ctrlPts is empty or its length differs from
  len(weights). Division by the weighted basis sum is unguarded, so a
  zero sum at some sample (e.g. a zero end weight) is not handled here.
  """
  if len(ctrlPts) != len(weights):
    raise ValueError("control points and weights must have the same length")
  # Without this, empty input reaches qx/denom with denom == 0.
  if len(ctrlPts) == 0:
    raise ValueError("at least one control point is required")
  n = len(ctrlPts) - 1   # curve degree
  xs = []
  ys = []
  if nSamps < 2:
    nSamps = 2
  for s in range(nSamps + 1):
    u = s / nSamps
    denom = 0
    qx, qy = 0, 0
    for i in range(n + 1):
      b = bernstein(n, i, u) * weights[i]   # weighted basis value
      denom += b
      qx += b * ctrlPts[i][0]
      qy += b * ctrlPts[i][1]
    xs.append(qx / denom)
    ys.append(qy / denom)
  return xs, ys


# -------------------------------------------

def bezierCurveDC(ctrlPts, nSamps=100):
  """
  Compute 2D Bézier curve points using De Casteljau's algorithm.

  ctrlPts: sequence of (x, y) control points (at least one).
  nSamps: number of sample intervals (values below 2 are raised to 2);
          nSamps + 1 points are returned for u = 0, 1/nSamps, ..., 1.
          Defaults to 100, the same default as bezierCurve().
  Returns (xs, ys): two lists of curve coordinates.
  Raises ValueError if ctrlPts is empty.
  """
  if len(ctrlPts) == 0:
    raise ValueError("at least one control point is required")
  xs = []   # points making up the curve
  ys = []
  if nSamps < 2:
    nSamps = 2
  for s in range(nSamps + 1):
    u = s / nSamps
    qx, qy = bezLerp(ctrlPts, u)
    xs.append(qx)  # one point for each sample
    ys.append(qy)
  return xs, ys


def bezLerp(ctrlPts, u):
  """
  Evaluate a Bézier curve of any degree at parameter u using iterative
  De Casteljau interpolation.

  ctrlPts: sequence of (x, y) control points (at least one); the caller's
           sequence is not modified. Only the first two coordinates of
           each point are interpolated.
  u: curve parameter (bezierCurveDC passes values in [0, 1]).
  Returns the curve point as an (x, y) tuple; with a single control
  point, that point is returned unchanged.
  Raises ValueError if ctrlPts is empty.
  """
  if len(ctrlPts) == 0:
    raise ValueError("at least one control point is required")
  pts = list(ctrlPts)  # working copy; each level overwrites pts[0..n-r-1]
  n = len(pts)
  for r in range(1, n):   # r = interpolation depth
    for i in range(n - r):
      # pts[i+1] still holds the previous level's value when pts[i] is replaced
      x = lerp(pts[i][0], pts[i+1][0], u)
      y = lerp(pts[i][1], pts[i+1][1], u)
      pts[i] = (x, y)
  return pts[0]


def lerp(a, b, u):
  """Linearly interpolate between a (at u = 0) and b (at u = 1)."""
  delta = b - a
  return a + delta * u


# -------------------------------------------


def bezPath(ctrlPts):
  """
  Build a Matplotlib Path for a linear, quadratic or cubic Bézier curve.

  ctrlPts: sequence of 2, 3 or 4 (x, y) control points:
    2 points -> straight segment (MOVETO, LINETO), i.e. a degree-1 Bézier
    3 points -> quadratic Bézier (MOVETO + two CURVE3 codes)
    4 points -> cubic Bézier (MOVETO + three CURVE4 codes)
  Returns a matplotlib.path.Path.
  Raises ValueError for any other number of control points.
  """
  n = len(ctrlPts)
  if n == 2:     # a degree-1 Bezier is a line segment
    codes = [Path.MOVETO, Path.LINETO]
  elif n == 3:   # CURVE3 is for a quadratic Bezier
    codes = [Path.MOVETO, Path.CURVE3, Path.CURVE3]
  elif n == 4:   # CURVE4 is for a cubic Bezier
    codes = [Path.MOVETO, Path.CURVE4,
             Path.CURVE4, Path.CURVE4]
  else:
    raise ValueError(
      f"bezPath needs 2, 3 or 4 control points, got {n}")
  return Path(ctrlPts, codes)



# -------------------------------------------


def drawControlPts(plt, ctrlPts, weights=None):
  """
  Draw the control points joined by the control polygon, optionally
  labelling each point with its weight.

  plt: object with Matplotlib-style plot() and annotate() methods
       (presumably matplotlib.pyplot or an Axes).
  ctrlPts: sequence of (x, y) control points; nothing is drawn if empty.
  weights: optional sequence of weights (a list or a NumPy array); each
           is shown as "w=<value>" next to its point. Labelling stops at
           the shorter of ctrlPts and weights.
  """
  if len(ctrlPts) == 0:   # zip(*[]) cannot be unpacked into xs, ys
    return
  xs, ys = zip(*ctrlPts)
  plt.plot(xs, ys, 'ro--', markersize=5, lw=1,
           label="Control polygon")

  # `is not None`, not `!= None`: with a NumPy array, `!= None` compares
  # elementwise and `if` then raises "truth value ... is ambiguous".
  if weights is not None:
    for (x, y), w in zip(ctrlPts, weights):
      plt.annotate(f"w={w:.3f}", xy=(x, y), xytext=(4, -6),
                   textcoords="offset points", fontsize=9)

