
# bernsteins.py
# Andrew Davison, Sept. 2025, ad@coe.psu.ac.th
# input: the basis degree n, entered at the prompt (e.g. 4 or 5)

import math
import matplotlib.pyplot as plt


def bernsteinBasis(n):
  """Build the Bernstein basis of degree n.

  Returns a list of one-argument callables, one per index k = 0..n,
  where entry k evaluates B_k^n(t) = C(n,k) * t**k * (1-t)**(n-k).
  """
  def termFor(k, coeff):
    # k and coeff are bound as parameters so that every returned
    # callable keeps its own values (no late-binding closure surprise).
    def term(t):
      return coeff * (t**k) * ((1-t)**(n-k))
    return term

  return [termFor(k, math.comb(n, k)) for k in range(n+1)]


def sampleFns(funcs, nSamps):
  """Sample each function on [0,1] over a uniform grid.

  The grid is t = 0, 1/nSamps, 2/nSamps, ..., 1, i.e. nSamps equal
  intervals and therefore nSamps+1 points (both endpoints included).

  Args:
    funcs: iterable of one-argument callables f(t).
    nSamps: number of intervals on [0,1]; must be >= 1.

  Returns:
    (ts, ys): ts is the list of sample points; ys holds, for each
    function in funcs (in order), the list of its values at ts.

  Raises:
    ValueError: if nSamps < 1 (0 would divide by zero and a negative
      value would silently produce an empty grid).
  """
  if nSamps < 1:
    raise ValueError(f"nSamps must be >= 1, got {nSamps}")
  ts = [i/nSamps for i in range(nSamps+1)]
  ys = [[f(t) for t in ts] for f in funcs]
  return ts, ys


# Ask for the polynomial degree. Keep asking until we get a
# non-negative integer: a non-numeric entry would otherwise crash
# int(), and a negative n gives an empty basis, which makes the
# "sum" plot below fail with an obscure shape-mismatch error.
while True:
  try:
    n = int(input("Degree of the Bernstein basis? "))
  except ValueError:
    print("Please enter a whole number, e.g. 4 or 5.")
    continue
  if n >= 0:
    break
  print("The degree must be non-negative.")

basis = bernsteinBasis(n)
ts, ysList = sampleFns(basis, 400)   # 400 intervals -> 401 points on [0,1]

plt.figure()

# plot all the Bernstein functions B_i,n(t)
for i, ys in enumerate(ysList):
  label = f'$B_{{{i},{n}}}(t)$'
  plt.plot(ts, ys, label=label)

# plot their pointwise sum; it should be (numerically) 1 everywhere,
# illustrating that the basis forms a partition of unity
sumYs = [sum(vs) for vs in zip(*ysList)]
plt.plot(ts, sumYs, linestyle="--", lw=2, label="sum")

plt.title(f"Bernstein Basis Polynomials (degree n={n})")
plt.xlabel("t")
plt.ylabel('$B_{i,n}(t)$')
plt.legend(loc="best")
plt.grid(True)
plt.show()
