
# Mat.py
# Andrew Davison, June 2025, ad@coe.psu.ac.th

''' Simple Matrix class. 

  The functions include:
    * row(), col(): extract specific row or column as lists
    * submatrix(): submatrix consisting of specified rows and columns
    * maxElement(), minElement()
    * +, -, *, scalar*, ==
    * transpose()
    * isSquare()
    * inverse() using Gauss-Jordan elimination
    * determinant() using recursive expansion by minors (Laplace expansion)
    * isSingular()

  Static methods:
    * identity()
    * fill(), fillSq(): fill a matrix with a number
    * read(), readSq()

The current determinant method has a time complexity of O(n!). 
This is inefficient for matrices larger than 4x4 or 5x5. 
The code includes a commented-out version of determinant using 
Gaussian elimination, which is O(n^3).

The current pivoting strategy in inverse (and in the commented-out 
rank) swaps in the first available non-zero pivot. For better numerical 
stability, especially with matrices that are ill-conditioned or 
have entries of vastly different magnitudes, it would be better 
to choose the pivot with the largest absolute value (partial pivoting).
'''

import math

# Absolute tolerance for treating a float as zero or two floats as equal.
# Used by Mat.__eq__, by the pivot search in Mat.inverse(), and by
# Mat.isSingular() (|det| < EPS counts as singular).
EPS = 1e-9

class Mat:
  ''' A small dense matrix stored as a list of row lists (self.data).

    Elements are read and written as mat[i][j] (row i, column j).
    Arithmetic methods and operators return a new Mat and leave
    their operands unchanged.
  '''

  def __init__(self, data):
    ''' Build a matrix from a non-empty sequence of non-empty,
        equal-length rows (lists or tuples of numbers).

        Each row is copied into a fresh list, so later changes to
        `data` do not affect the matrix, and rows are always mutable.
        Raises ValueError if there are no rows, the first row is
        empty, or the rows have different lengths.
    '''
    if not data or len(data[0]) == 0 or \
       not all(len(row) == len(data[0]) for row in data):
      # zero-column matrices are rejected too: transpose(), maxElement()
      # and minElement() cannot handle them
      raise ValueError("Badly formed matrix")
    # list(row) rather than row[:] so tuple rows become mutable lists;
    # inverse() and mat[i][j] = v both rely on that
    self.data = [list(row) for row in data]
    self.nRows = len(self.data)
    self.nCols = len(self.data[0])
    self.fmt = "{:> .2f}"   # default: 2 decimal places, space for sign

  def setFormat(self, n):
    ''' Print elements with n decimal places; a negative n is ignored. '''
    if n >= 0:
      self.fmt = "{:> ." + str(n) + "f}"


  def __str__(self):
    ''' One line per row: the elements formatted with self.fmt and
        separated by tabs, each line ending with a newline. '''
    return "".join("\t".join(self.fmt.format(val) for val in row) + "\n"
                                                   for row in self.data)

  def __repr__(self):
    ''' Unambiguous form, e.g. Mat([[1, 2], [3, 4]]). '''
    return f"Mat({self.data!r})"

  def __getitem__(self, index):
    ''' slicing/indexing. Allows mat[i][j] instead 
        of mat.data[i][j]. The returned row is the matrix's own
        list, so mat[i][j] = v modifies the matrix in place. '''
    return self.data[index]


  def shape(self):
    ''' Return the dimensions as (nRows, nCols). '''
    return (self.nRows, self.nCols)


  def copy(self):
    ''' Return an independent copy of the matrix data
        (the print format is reset to the default). '''
    return Mat([row[:] for row in self.data])

  def toList(self):
    ''' Return the elements as a new list of row lists. '''
    return [row[:] for row in self.data]


  # Extract specific row or column as lists
  def row(self, i):
    ''' Return a copy of row i as a list. '''
    return self.data[i][:]

  def col(self, j):
    ''' Return column j as a new list. '''
    return [self.data[i][j] for i in range(self.nRows)]


  def submatrix(self, rowIndices, colIndices):
    ''' Return the submatrix made of the given rows and columns,
        in the order listed.
        Raises IndexError for an out-of-range index, and ValueError
        (from the constructor) if either index list is empty.
    '''
    if not all(0 <= r < self.nRows for r in rowIndices):
      raise IndexError("Row index out of range.")
    if not all(0 <= c < self.nCols for c in colIndices):
      raise IndexError("Column index out of range.")

    subData = [[self.data[r][c] for c in colIndices] 
                                     for r in rowIndices]
    return Mat(subData)


  # Max/min across matrix
  def maxElement(self):
    ''' Largest element in the matrix. '''
    return max(max(row) for row in self.data)

  def minElement(self):
    ''' Smallest element in the matrix. '''
    return min(min(row) for row in self.data)


  def add(self, other):
    ''' Element-wise sum; raises ValueError if the shapes differ. '''
    if self.nRows != other.nRows or \
       self.nCols != other.nCols:
      raise ValueError("Wrong dimensions to add.")
    result = [[self.data[i][j] + other.data[i][j] 
                 for j in range(self.nCols)]
                  for i in range(self.nRows) ]
    return Mat(result)


  def subtract(self, other):
    ''' Element-wise difference; raises ValueError if the shapes differ. '''
    if self.nRows != other.nRows or \
       self.nCols != other.nCols:
      raise ValueError("Wrong dimensions to subtract.")
    result = [ [self.data[i][j] - other.data[i][j] 
                  for j in range(self.nCols)]
                   for i in range(self.nRows) ]
    return Mat(result)


  def scalarMultiply(self, scalar):
    ''' Multiply every element by scalar. '''
    result = [ [scalar * self.data[i][j] 
                  for j in range(self.nCols)]
                    for i in range(self.nRows) ]
    return Mat(result)


  def multiply(self, other):
    ''' Matrix product self x other.
        Raises ValueError unless self.nCols == other.nRows. '''
    if self.nCols != other.nRows:
      raise ValueError("Wrong dimensions for multiplication.")
    result = [ [ sum(self.data[i][k] * other.data[k][j] 
                     for k in range(self.nCols))
                       for j in range(other.nCols)
           ] for i in range(self.nRows) ]
    return Mat(result)


  def __add__(self, other):
    if not isinstance(other, Mat):
      return NotImplemented   # Python then raises TypeError
    return self.add(other)

  def __sub__(self, other):
    if not isinstance(other, Mat):
      return NotImplemented   # Python then raises TypeError
    return self.subtract(other)

  def __mul__(self, other):
    ''' mat * scalar (int/float) or mat * Mat (matrix product). '''
    if isinstance(other, (int, float)):
      return self.scalarMultiply(other)
    if isinstance(other, Mat):
      return self.multiply(other)
    # Unsupported operand: let the other type's __rmul__ have a go;
    # if none applies Python raises TypeError (never return an
    # exception object as if it were a result)
    return NotImplemented

  def __rmul__(self, other):
    ''' scalar * mat '''
    if isinstance(other, (int, float)):
      return self.scalarMultiply(other)
    return NotImplemented  # Important for operator chaining

  def __eq__(self, other):
    ''' Same shape and every pair of elements within EPS of each other. '''
    if not isinstance(other, Mat):
      return False
    if self.shape() != other.shape():
      return False
    return all(abs(a - b) <= EPS
                 for selfRow, otherRow in zip(self.data, other.data)
                   for a, b in zip(selfRow, otherRow))

  # Defining __eq__ already makes instances unhashable; say so explicitly
  # (Mat is mutable through mat[i][j] = v, so it must not be a dict key)
  __hash__ = None


  def transpose(self):
    ''' Return the transpose (rows become columns). '''
    result = [ [self.data[j][i] for j in range(self.nRows)]
                                  for i in range(self.nCols) ]
    return Mat(result)


  def isSquare(self):
    return self.nRows == self.nCols


  def inverse(self):
    ''' Return the inverse using Gauss-Jordan elimination:
           https://en.wikipedia.org/wiki/Gaussian_elimination

        Works on the augmented matrix [A | I]; once the left half has
        been reduced to I, the right half holds the inverse. A row swap
        happens only when the diagonal entry is (near) zero, taking the
        first lower row whose entry in that column exceeds EPS.
        Raises ValueError if the matrix is not square or is singular.
    '''
    if not self.isSquare():
      raise ValueError("Only square matrices can be inverted")

    n = self.nRows
    # Augmented matrix [A | I] built from fresh lists, so self is untouched
    aug = [self.data[i][:] + [1 if i == j else 0 for j in range(n)]
                                               for i in range(n)]

    for i in range(n):
      # Pivoting: keep the diagonal entry unless it is (near) zero,
      # otherwise swap in the first lower row with a usable entry
      if abs(aug[i][i]) < EPS:
        pivotRowIdx = next((j for j in range(i + 1, n)
                                  if abs(aug[j][i]) > EPS), None)
        if pivotRowIdx is None:
          raise ValueError("Matrix is singular and cannot be inverted (pivot not found).")
        aug[i], aug[pivotRowIdx] = aug[pivotRowIdx], aug[i]

      # Normalize the pivot row so the pivot becomes 1
      pivot = aug[i][i]
      aug[i] = [val / pivot for val in aug[i]]

      # Eliminate the current column from every other row
      pivotRow = aug[i]
      for k in range(n):
        if k != i:
          factor = aug[k][i]
          aug[k] = [a - factor * p for a, p in zip(aug[k], pivotRow)]

    # The right half of the augmented matrix is the inverse
    return Mat([row[n:] for row in aug])


  '''
  def determinant(self):
    """ The determinant is the product of the
      pivots found during the elimination process, 
      adjusted for row swaps.
    """
    if not self.isSquare():
      raise ValueError("Only square matrices have a determinant")

    n = self.nRows
    mat = [row[:] for row in self.data]  # Work on a copy

    det = 1.0
    for i in range(n):
      # 1. Pivoting logic
      pivotRow = i
      if abs(mat[i][i]) < EPS:
        for j in range(i + 1, n):
          if abs(mat[j][i]) > EPS:
            pivotRow = j
            break
        else:  # the matrix is singular
          return 0.0

      if pivotRow != i:
        mat[i], mat[pivotRow] = mat[pivotRow], mat[i]
        det *= -1  # change the sign of the determinant
      
      # 2. The pivot is the diagonal element
      pivot = mat[i][i]
      
      # 3. Accumulate pivot's value in determinant
      det *= pivot

      # 4. Perform elimination to find next correct pivot
      for j in range(i + 1, n):
        factor = mat[j][i] / pivot
        for k in range(i, n):
          mat[j][k] -= factor * mat[i][k]
              
    return det
  '''


  def determinant(self):
    """Determinant by recursive expansion by minors along row 0.

    Exact for integer entries, but O(n!) time, so only practical
    for small matrices. Raises ValueError if the matrix is not square.
    """
    if not self.isSquare():
      raise ValueError("Only square matrices have a determinant")

    n = self.nRows
    if n == 1:
      return self.data[0][0]
    if n == 2:
      return self.data[0][0] * self.data[1][1] - \
             self.data[0][1] * self.data[1][0]

    det = 0
    for col in range(n):
      sign = (-1) ** col     # cofactor sign for entry (0, col)
      subData = self._getSubmatrix(0, col)
      subMat = Mat(subData)
      det += sign * self.data[0][col] * subMat.determinant()
    return det


  def _getSubmatrix(self, excludeRow, excludeCol):
    ''' Rows/cols of self.data without one row and one column (a minor),
        returned as a list of lists. '''
    return [
      [self.data[i][j] 
         for j in range(self.nCols) if j != excludeCol]
           for i in range(self.nRows) if i != excludeRow]


  def isSingular(self):
    """True if the matrix is singular (non-invertible), i.e. |det| < EPS.
    Raises ValueError if the matrix is not square."""
    if not self.isSquare():
      raise ValueError("Only square matrices can be tested for singularity")
    return abs(self.determinant()) < EPS


  '''
  def rank(self):
    """Computes the rank of the matrix using row-reduction
      echelon form (RREF)."""
    mat = [row[:] for row in self.data]  # deep copy
    nRows, nCols = self.nRows, self.nCols
    rank = 0
    row = 0

    for col in range(nCols):
      if row >= nRows:
        break

      # Find non-zero pivot in column
      pivotRow = None
      for r in range(row, nRows):
        if abs(mat[r][col]) > EPS:
          pivotRow = r
          break

      if pivotRow is None:
        continue

      # Swap to move pivot row to current row
      if pivotRow != row:
        mat[row], mat[pivotRow] = mat[pivotRow], mat[row]

      # Normalize pivot row
      pivotVal = mat[row][col]
      mat[row] = [val / pivotVal for val in mat[row]]

      # Eliminate below and above
      for r in range(nRows):
        if r != row and abs(mat[r][col]) > EPS:
          factor = mat[r][col]
          mat[r] = [mat[r][i] - factor * mat[row][i] 
                                  for i in range(nCols)]

      row += 1
      rank += 1
    return rank


  def isOrthogonal(self):
    if not self.isSquare():
      return False
    trans = self.transpose()
    prod = self.multiply(trans)
    identity = Mat.identity(self.nRows)
    return prod == identity


  def trace(self):
    """the trace of a square matrix."""
    if not self.isSquare():
        raise ValueError("Trace is only defined for square matrices.")
    return sum(self.data[i][i] for i in range(self.nRows))
  '''



  # ---------------  static methods  --------

  @staticmethod
  def identity(n):
    ''' The n x n identity matrix (integer 1s and 0s). '''
    return Mat([[1 if i == j else 0 for j in range(n)] 
                                        for i in range(n)])

  @staticmethod
  def fill(rows, cols, val):
    ''' A rows x cols matrix with every element equal to val. '''
    return Mat([[val for _ in range(cols)] for _ in range(rows)])

  @staticmethod
  def fillSq(n, val):
    ''' An n x n matrix with every element equal to val. '''
    return Mat.fill(n, n, val)


  @staticmethod
  def read(fnm):
    ''' Read a matrix from a text file: one row per line, values
        separated by whitespace; blank lines are skipped.

        Raises FileNotFoundError if the file is missing, and
        ValueError for a non-numeric value, a file with no data,
        or rows of different lengths (from the constructor).
    '''
    print("Reading:", fnm)
    m = []
    try:
      with open(fnm, 'r') as f:
        for line in f:
          ln = line.strip()
          if ln:  # Skip blank lines
            try:
              m.append([float(v) for v in ln.split()])
            except ValueError as err:
              raise ValueError(f"Non-numeric value found in line: {line}") from err
    except FileNotFoundError as err:
      raise FileNotFoundError(f"File '{fnm}' not found.") from err
    if not m:
      raise ValueError("File is empty or contains no valid data.")
    return Mat(m)
  

  @staticmethod
  def readSq(fnm):
    ''' Read n*n whitespace-separated numbers (line breaks are ignored)
        and arrange them row by row into an n x n matrix.

        Raises FileNotFoundError if the file is missing, and ValueError
        for a non-numeric value, a file with no values, or a count of
        values that is not a perfect square.
    '''
    print("Reading square:", fnm)
    try:
      with open(fnm, 'r') as f:
        lines = f.readlines()
    except FileNotFoundError as err:
      raise FileNotFoundError(f"File '{fnm}' not found.") from err

    fvals = []
    for line in lines:
      for part in line.split():
        try:
          fvals.append(float(part))
        except ValueError as err:
          raise ValueError(f"Invalid float value found: '{part}'") from err

    nVals = len(fvals)
    if nVals == 0:
      raise ValueError("File is empty or contains no valid data.")
    # round() (not int()) so a square root computed as k - tiny still
    # gives k; the product check below then confirms an exact square
    size = round(math.sqrt(nVals))
    if size * size != nVals:
      raise ValueError("The floats do not form a square matrix.")

    return Mat([fvals[i * size:(i + 1) * size] for i in range(size)])


# ------------- Test rig -------------

if __name__ == "__main__":
  # Demo / smoke test of the Mat class. Matrix C is read from 'm1.txt';
  # if that file is missing or malformed, the C checks are skipped
  # instead of aborting the rest of the demo.
  print("Matrix A:")
  a = Mat([[4, 7], [2, 6]])
  print(a)

  print("A*5:")
  print(a*5)
  print("5*A:")
  print(5*a)

  print(f"Shape of A: {a.shape()}\n")

  print("Transpose of A:")
  print(a.transpose())

  print("Determinant of A:")
  det = a.determinant()
  print(f"{det:.2f}\n") # Expected: 4*6 - 7*2 = 10.00
  
  print("Inverse of A:")
  aInv = a.inverse()
  print(aInv)

  print("A * A_inv (Should be Identity):")
  identity = a.multiply(aInv)
  print(identity)


  Acopy = a.copy()
  print(f"A == A? {a == a}")
  print(f"A == Acopy (same data)? {a == Acopy}")
  Acopy[0][1] = 5
  print("Changed Acopy:")
  print(Acopy)
  print(f"A == Acopy? {a == Acopy}\n")


  print("Static identity:")
  print( Mat.identity(5))

  try:
    Mat([[1,2],[3,4,5]])   # ragged rows must be rejected
    print("Error: Bad matrix init did not raise error.")
  except ValueError as e:
    print(f"Error: {e}")
  print()

  print("Matrix B:")
  b = Mat([[1, 0], [0, 1]])
  print(b)

  print("A + B:")
  print(a.add(b))
  print(a + b)

  print("A - B:")
  print(a.subtract(b))
  print(a - b)

  print("A * B:")
  print(a.multiply(b))
  print(a * b)

  # Matrix C depends on an external data file
  try:
    c = Mat.readSq('m1.txt')
  except (FileNotFoundError, ValueError) as err:
    c = None
    print(f"Skipping Matrix C tests: {err}\n")

  if c is not None:
    print("Matrix C:")
    print(c)

    print(f"Row 0 of C: {c.row(0)}")
    print(f"Column 1 of C: {c.col(1)}")
    print(f"Max element of C: {c.maxElement()}")
    print(f"Min element of C: {c.minElement()}\n")

  dd = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] # Singular, Det = 0
  d = Mat(dd)
  print(f"Matrix D (Singular):\n{d}")
  det_d = d.determinant()
  print(f"Determinant of D: {det_d:.2f} (Expected: 0.00)\n")


  print(f"Is A singular? {a.isSingular()} (Expected: False)")
  if c is not None:
    print(f"Is C singular? {c.isSingular()} (Expected: False)")
  print(f"Is D singular? {d.isSingular()} (Expected: True)\n")

  try:
    print("Try to invert D:")
    d_inv = d.inverse()
    print(f"Inverse of D (should not compute):\n{d_inv}")
  except ValueError as err:
    print(f"Error: {err}")
  print()

  # Test Non-Square Matrix Multiplication
  e23 = Mat([[1, 2, 3], [4, 5, 6]]) # 2x3
  f32 = Mat([[7, 8], [9, 1], [2, 3]]) # 3x2
  print(f"Matrix E (2x3):\n{e23}")
  print(f"Matrix F (3x2):\n{f32}")
  ef = e23 * f32 # Result should be 2x2
  print(f"E * F (2x2):\n{ef}")

  try:
    print("Try to add E and F:")
    eplusf = e23 + f32
    print(f"E+F (should not compute):\n{eplusf}")
  except ValueError as err:
    print(f"Error: {err}")
  print()


  # Test fill, fillSq
  fm = Mat.fill(2, 3, 7.5)
  print(f"Mat.fill(2, 3, 7.5):\n{fm}")
  fsq = Mat.fillSq(3, -1)
  print(f"Mat.fillSq(3, -1):\n{fsq}")

  g = Mat([[1, 2, 3],
           [4, 5, 6],
           [7, 8, 9]])
  print("Original Matrix G:")
  print(g)
  
  sub = g.submatrix([0, 2], [1, 2])
  print("Submatrix with rows [0, 2] and cols [1, 2]:")
  print(sub)

