
# Mat.py
# Andrew Davison, June 2025, ad@coe.psu.ac.th

''' Simple Matrix class.

  The methods include:
    * row(), col(): extract a specific row or column as a list
    * submatrix(): submatrix consisting of the specified rows and columns
    * submatrixExcept(): submatrix without the specified rows and columns
    * maxElement(), minElement()
    * +, -, * (matrix or scalar, on either side), @ (matrix product),
      unary -, ==
    * transpose()
    * isSquare()
    * inverse() using Gauss-Jordan elimination
    * determinant() using recursive expansion by minors
    * isSingular()

  Static methods:
    * identity()
    * fill(), fillSq(): fill a matrix with a number
    * read(): read a matrix from a text file, one row per line
    * readSq(): read all the numbers in a text file into a square matrix

The current determinant method has a time complexity of O(n!),
which makes it impractical for matrices larger than about 5x5.
The code includes a commented-out version of determinant using
Gaussian elimination, which is O(n^3).
'''

import math

EPS = 1e-9   # absolute tolerance for float comparisons (==, pivot size, singularity)


class Mat:
  ''' A dense matrix stored row by row as a list of lists in self.data.
      nRows/nCols hold the dimensions; fmt is the per-element format
      string used by __str__ (see setFormat()).
  '''

  def __init__(self, data):
    ''' Create a matrix from a non-empty list of equal-length, non-empty
        rows (each a list or tuple). Every row is copied into a new list,
        so later changes to `data` do not affect the matrix.
        Raises ValueError if `data` is badly formed.
    '''
    if not self.isValid(data):
      raise ValueError("Badly formed matrix: empty, contains an empty row, or rows are inconsistent in length")
    # list(row) copies each row and also turns tuple rows into lists,
    # so element assignment (mat[i][j] = v) works for every matrix.
    self.data = [list(row) for row in data]
    self.nRows = len(self.data)
    self.nCols = len(self.data[0])
    self.fmt = "{:> .2f}"


  def isValid(self, data):
    ''' Return True if data is a non-empty list of non-empty rows (lists
        or tuples) that all have the same length. Does not use self.
    '''
    if not isinstance(data, list) or not data:
      return False
    # Rows must be lists/tuples; otherwise (e.g. Mat([1, 2, 3])) the len()
    # calls below would raise TypeError instead of reporting invalid data.
    if not all(isinstance(row, (list, tuple)) for row in data):
      return False
    rowLen = len(data[0])
    return rowLen > 0 and all(len(row) == rowLen for row in data)


  def setFormat(self, n):
    ''' Show n decimal places per element in __str__; negative n is ignored. '''
    if n >= 0:
      self.fmt = "{:> ." + str(n) + "f}"


  def __str__(self):
    ''' One line per row, elements formatted with self.fmt and tab-separated. '''
    return "".join("\t".join(self.fmt.format(val) for val in row) + "\n"
                   for row in self.data)

  def __repr__(self):
    return f"Mat({self.data!r})"

  def __getitem__(self, index):
    ''' slicing/indexing. Allows mat[i][j] instead
        of mat.data[i][j]. The returned row is the live row list,
        so mat[i][j] = v modifies the matrix. '''
    return self.data[index]


  def order(self):
    ''' Return the dimensions as (nRows, nCols). '''
    return (self.nRows, self.nCols)


  def copy(self):
    ''' Return an independent copy with the same data and display format. '''
    dup = Mat(self.data)      # __init__ copies every row
    dup.fmt = self.fmt
    return dup

  def toList(self):
    ''' Return the data as a new list of new row lists. '''
    return [row[:] for row in self.data]


  # Extract specific row or column as lists
  def row(self, i):
    return self.data[i][:]

  def col(self, j):
    return [self.data[i][j] for i in range(self.nRows)]


  def submatrix(self, rowIndices, colIndices):
    ''' Return the submatrix made of the given rows and columns, in the
        given order. Raises IndexError for an index outside the matrix
        (negative indices are rejected), and ValueError if either index
        list is empty. '''
    if not all(0 <= r < self.nRows for r in rowIndices):
      raise IndexError("Row index out of range.")
    if not all(0 <= c < self.nCols for c in colIndices):
      raise IndexError("Column index out of range.")

    subData = [[self.data[r][c] for c in colIndices]
                                     for r in rowIndices]
    return Mat(subData)


  def submatrixExcept(self, rowIndicesX, colIndicesX):
    ''' Return the submatrix left after deleting the given rows and columns.
        Raises IndexError for an out-of-range index, and ValueError if every
        row or every column is deleted. '''
    if not all(0 <= r < self.nRows for r in rowIndicesX):
      raise IndexError("Row index out of range.")
    if not all(0 <= c < self.nCols for c in colIndicesX):
      raise IndexError("Column index out of range.")

    skipRows = set(rowIndicesX)     # sets give O(1) membership tests
    skipCols = set(colIndicesX)
    rowIndices = [i for i in range(self.nRows) if i not in skipRows]
    colIndices = [j for j in range(self.nCols) if j not in skipCols]
    return self.submatrix(rowIndices, colIndices)


  # Max/min across matrix
  def maxElement(self):
    return max(max(row) for row in self.data)

  def minElement(self):
    return min(min(row) for row in self.data)


  def add(self, other):
    ''' Element-wise sum; raises ValueError if the orders differ. '''
    if self.nRows != other.nRows or \
       self.nCols != other.nCols:
      raise ValueError("Wrong dimensions to add.")
    result = [[x + y for x, y in zip(r1, r2)]
                for r1, r2 in zip(self.data, other.data)]
    return Mat(result)


  def subtract(self, other):
    ''' Element-wise difference; raises ValueError if the orders differ. '''
    if self.nRows != other.nRows or \
       self.nCols != other.nCols:
      raise ValueError("Wrong dimensions to subtract.")
    result = [[x - y for x, y in zip(r1, r2)]
                for r1, r2 in zip(self.data, other.data)]
    return Mat(result)


  def scalarMultiply(self, scalar):
    return Mat([[scalar * v for v in row] for row in self.data])


  def multiply(self, other):
    ''' Matrix product self x other. Raises ValueError unless
        self.nCols == other.nRows. '''
    if self.nCols != other.nRows:
        raise ValueError("Wrong dimensions for multiplication.")

    # Initialize the result matrix with zeros
    result = [[0 for _ in range(other.nCols)]
                         for _ in range(self.nRows)]

    # Explicit loops for matrix multiplication
    for i in range(self.nRows):
      for j in range(other.nCols):
        sumVal = 0
        for k in range(self.nCols): # or other.nRows
          sumVal += self.data[i][k] * other.data[k][j]
        result[i][j] = sumVal
    return Mat(result)

  '''
  def multiply(self, other):
    if self.nCols != other.nRows:
      raise ValueError("Wrong dimensions for multiplication.")
    result = [ [ sum(self.data[i][k] * other.data[k][j]
                     for k in range(self.nCols))
                       for j in range(other.nCols)
           ] for i in range(self.nRows) ]
    return Mat(result)
  '''

  def __add__(self, other):
    return self.add(other)

  def __sub__(self, other):
    return self.subtract(other)

  def __mul__(self, other):
    ''' Mat * number (scalar product) or Mat * Mat (matrix product). '''
    if isinstance(other, (int, float)):
      return self.scalarMultiply(other)
    elif isinstance(other, Mat):
      return self.multiply(other)
    else:
      raise ValueError("Only scalar or Mat supported")

  def __rmul__(self, other):
    ''' number * Mat. '''
    if isinstance(other, (int, float)):
      return self.scalarMultiply(other)
    elif isinstance(other, Mat):
      return other.multiply(self)
    # Let Python try other options / raise TypeError.
    return NotImplemented


  # overload @ to mean matrix multiplication
  def __matmul__(self, other):
    if isinstance(other, Mat):
      return self.multiply(other)
    return NotImplemented

  def __rmatmul__(self, other):
    if isinstance(other, Mat):
      return other.multiply(self)
    return NotImplemented


  def __eq__(self, other):
    ''' True if other is a Mat of the same order whose elements all
        differ from ours by at most EPS. '''
    if not isinstance(other, Mat):
      return False
    if self.order() != other.order():
      return False
    for i in range(self.nRows):
      for j in range(self.nCols):
        if abs(self.data[i][j] - other.data[i][j]) > EPS:
          return False
    return True

  def __neg__(self):
    return self.scalarMultiply(-1)


  def transpose(self):
    return Mat([list(col) for col in zip(*self.data)])


  def isSquare(self):
    return self.nRows == self.nCols


  def inverse(self):
    ''' Return the inverse, computed by Gauss-Jordan elimination with
        partial pivoting:
        https://en.wikipedia.org/wiki/Gaussian_elimination
        Raises ValueError if the matrix is not square, or if a pivot has
        absolute value below EPS (the matrix is treated as singular).
    '''
    if not self.isSquare():
      raise ValueError("Only square matrices can be inverted")

    n = self.nRows
    # Augmented matrix [A | I].
    aug = [self.data[i][:] + [1 if i == j else 0 for j in range(n)]
                                              for i in range(n)]

    for i in range(n):
      # Partial pivoting: pick the row (from i downwards) with the largest
      # |value| in column i; max() keeps the first such row on ties.
      pivotRowIdx = max(range(i, n), key=lambda r: abs(aug[r][i]))
      if abs(aug[pivotRowIdx][i]) < EPS:
        raise ValueError("Matrix is singular and cannot be inverted (pivot too small).")
      if pivotRowIdx != i:
        aug[i], aug[pivotRowIdx] = aug[pivotRowIdx], aug[i]  # Swap rows

      # Normalize the pivot row so its pivot becomes 1.
      pivot = aug[i][i]
      aug[i] = [v / pivot for v in aug[i]]

      # Eliminate column i from all the other rows.
      pivotRow = aug[i]
      for r in range(n):
        if r != i:
          factor = aug[r][i]
          aug[r] = [v - factor * p for v, p in zip(aug[r], pivotRow)]

    # The right half of the augmented matrix is now the inverse.
    return Mat([row[n:] for row in aug])


  '''
  def determinant(self):
    # Determinant using Gaussian elimination (O(n^3)) with partial pivoting.
    # Returns a float; a column whose largest |value| is below EPS is
    # treated as singular (returns 0.0).
    if not self.isSquare():
      raise ValueError("Only square matrices have a determinant")

    n = self.nRows
    # Work on a copy of the matrix data.
    mat = [row[:] for row in self.data]
    det = 1.0

    for i in range(n):
      # Partial pivoting: choose the row with the largest absolute value in column i.
      pivotRowIdx = i
      maxVal = abs(mat[i][i])
      for r in range(i + 1, n):
        if abs(mat[r][i]) > maxVal:
          maxVal = abs(mat[r][i])
          pivotRowIdx = r

      if maxVal < EPS:
        return 0.0  # The matrix is singular.

      if pivotRowIdx != i:
        # Swap rows and adjust the sign of the determinant.
        mat[i], mat[pivotRowIdx] = mat[pivotRowIdx], mat[i]
        det *= -1

      pivot = mat[i][i]
      det *= pivot

      # Eliminate rows below the pivot row.
      for r in range(i + 1, n):
        factor = mat[r][i] / pivot
        for j in range(i, n):
          mat[r][j] -= factor * mat[i][j]

    return det
  '''


  def determinant(self):
    ''' Determinant by recursive cofactor (Laplace) expansion along row 0:
          det(A) = sum_j (-1)**j * A[0][j] * det(minor(0, j))
        where minor(0, j) is A without row 0 and column j.
        Uses only +, - and *, so integer input gives an exact integer result.
        Time is O(n!). Raises ValueError if the matrix is not square.
    '''
    if not self.isSquare():
      raise ValueError("Only square matrices have a determinant")

    n = self.nRows
    # base cases
    if n == 1:
      return self.data[0][0]
    if n == 2:
      return self.data[0][0] * self.data[1][1] - \
             self.data[0][1] * self.data[1][0]

    det = 0
    for j in range(n):
      # The element and its minor must refer to the same (row 0, column j).
      minor = self.submatrixExcept([0], [j])
      sign = 1 if j % 2 == 0 else -1
      det += sign * self.data[0][j] * minor.determinant()
    return det


  def isSingular(self):
    ''' True if |determinant| < EPS (non-invertible). Raises ValueError
        for a non-square matrix. '''
    if not self.isSquare():
      raise ValueError("Only square matrices can be tested for singularity")
    return abs(self.determinant()) < EPS


  # ---------------  static methods  --------

  @staticmethod
  def identity(n):
    return Mat([[1 if i == j else 0 for j in range(n)]
                                        for i in range(n)])

  @staticmethod
  def fill(rows, cols, val):
    return Mat([[val for _ in range(cols)] for _ in range(rows)])

  @staticmethod
  def fillSq(n, val):
    return Mat.fill(n, n, val)


  @staticmethod
  def read(fnm):
    ''' Read a matrix from text file fnm: one row per line, values separated
        by whitespace, blank lines skipped. Prints a "Reading:" message first.
        Raises FileNotFoundError if the file is missing, and ValueError if a
        value is not numeric, there are no rows, or the rows differ in length.
    '''
    print("Reading:", fnm)
    m = []
    try:
      with open(fnm, 'r') as f:
        for line in f:
          ln = line.strip()
          if not ln:
            continue   # Skip blank lines
          try:
            m.append([float(v) for v in ln.split()])
          except ValueError as err:
            # Report the stripped line so the message has no stray newline.
            raise ValueError(f"Non-numeric value found in line: {ln}") from err
    except FileNotFoundError as err:
      raise FileNotFoundError(f"File '{fnm}' not found.") from err
    if not m:
      raise ValueError("File is empty or contains no valid data.")
    return Mat(m)


  @staticmethod
  def readSq(fnm):
    ''' Read every whitespace-separated number in file fnm (line breaks are
        ignored) and arrange them row by row into a square matrix.
        Prints a "Reading square:" message first.
        Raises FileNotFoundError if the file is missing, and ValueError if a
        value is not a float, the file holds no values, or the number of
        values is not a perfect square.
    '''
    print("Reading square:", fnm)
    try:
      with open(fnm, 'r') as f:
        lines = f.readlines()
    except FileNotFoundError as err:
      raise FileNotFoundError(f"File '{fnm}' not found.") from err

    fvals = []
    for line in lines:
      for part in line.split():
        try:
          fvals.append(float(part))
        except ValueError as err:
          raise ValueError(f"Invalid float value found: '{part}'") from err

    nVals = len(fvals)
    if nVals == 0:
      raise ValueError("File is empty or contains no valid data.")
    # math.isqrt is exact for any int, unlike int(nVals ** 0.5),
    # which goes through a float.
    size = math.isqrt(nVals)
    if size * size != nVals:
      raise ValueError("The floats do not form a square matrix.")

    return Mat([fvals[i * size:(i + 1) * size] for i in range(size)])



# ------------- Test rig -------------

if __name__ == "__main__":
  # Demonstration / smoke-test of the Mat class; prints results together
  # with their expected values.
  print("Matrix A:")
  a = Mat([[4, 7], [2, 6]])
  print(a)

  print("A*5:")
  print(a*5)
  print("5*A:")
  print(5*a)

  print(f"Shape of A: {a.order()}\n")

  print("Transpose of A:")
  print(a.transpose())

  print("Determinant of A:")
  det = a.determinant()
  print(f"{det:.2f}\n") # Expected: 4*6 - 7*2 = 10.00

  print("Inverse of A:")
  aInv = a.inverse()
  print(aInv)

  print("A * A_inv (Should be Identity):")
  identity = a.multiply(aInv)
  print(identity)


  Acopy = a.copy()
  print(f"A == A? {a == a}")
  print(f"A == Acopy (same data)? {a == Acopy}")
  Acopy[0][1] = 5
  print("Changed Acopy:")
  print(Acopy)
  print(f"A == Acopy? {a == Acopy}\n")
  print(f"A != Acopy? {a != Acopy}\n")


  print("Static identity:")
  print( Mat.identity(5))

  try:
    bad = Mat([[1,2],[3,4,5]])
    print("Error: Bad matrix init did not raise error.")
  except ValueError as e:
    print(f"Error: {e}")
  print()

  print("Matrix B:")
  b = Mat([[1, 0], [0, 1]])
  print(b)

  print("A + B:")
  print(a.add(b))
  print(a + b)

  print("A - B:")
  print(a.subtract(b))
  print(a - b)

  print("A * B:")
  print(a.multiply(b))
  print(a * b)
  print(a @ b)

  # m1.txt is an external data file; if it is missing or malformed,
  # report it and carry on with the remaining tests.
  c = None
  try:
    c = Mat.readSq('m1.txt')
  except (FileNotFoundError, ValueError) as err:
    print(f"Error: {err}\n")

  if c is not None:
    print("Matrix C:")
    print(c)

    print(f"Row 0 of C: {c.row(0)}")
    print(f"Column 1 of C: {c.col(1)}")
    print(f"Max element of C: {c.maxElement()}")
    print(f"Min element of C: {c.minElement()}\n")

  dd = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] # Singular, Det = 0
  d = Mat(dd)
  print(f"Matrix D (Singular):\n{d}")
  det_d = d.determinant()
  print(f"Determinant of D: {det_d:.2f} (Expected: 0.00)\n")

  # H is not symmetric, so a row/column mix-up in the cofactor
  # expansion of determinant() shows up as a wrong result here.
  h = Mat([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
  print(f"Matrix H:\n{h}")
  print(f"Determinant of H: {h.determinant():.2f} (Expected: -3.00)\n")


  print(f"Is A singular? {a.isSingular()} (Expected: False)")
  if c is not None:
    print(f"Is C singular? {c.isSingular()} (Expected: False)")
  print(f"Is D singular? {d.isSingular()} (Expected: True)\n")

  try:
    print("Try to invert D:")
    d_inv = d.inverse()
    print(f"Inverse of D (should not compute):\n{d_inv}")
  except ValueError as err:
    print(f"Error: {err}")
  print()

  # Test Non-Square Matrix Multiplication
  e23 = Mat([[1, 2, 3], [4, 5, 6]]) # 2x3
  f32 = Mat([[7, 8], [9, 1], [2, 3]]) # 3x2
  print(f"Matrix E (2x3):\n{e23}")
  print(f"Matrix F (3x2):\n{f32}")
  ef = e23 * f32 # Result should be 2x2
  print(f"E * F (2x2):\n{ef}")

  try:
    print("Try to add E and F:")
    ef = e23 + f32
    print(f"E+F (should not compute):\n{ef}")
  except ValueError as err:
    print(f"Error: {err}")
  print()


  # Test fill, fillSq
  fm = Mat.fill(2, 3, 7.5)
  print(f"Mat.fill(2, 3, 7.5):\n{fm}")
  fsq = Mat.fillSq(3, -1)
  print(f"Mat.fillSq(3, -1):\n{fsq}")

  g = Mat([[1, 2, 3],
           [4, 5, 6],
           [7, 8, 9]])
  print("Original Matrix G:")
  print(g)

  sub = g.submatrix([0, 2], [1, 2])
  print("Submatrix with rows [0, 2] and cols [1, 2]:")
  print(sub)

  sub = g.submatrixExcept([0], [2])
  print("Submatrix without row [0] and col [2]:")
  print(sub)
