
# MarkovLib.py
# Andrew Davison, June 2025, ad@coe.psu.ac.th

'''
  Collection of functions for manipulating Markov chains.
  Uses Mat.py and the graphviz module.
'''

import graphviz
from Mat import *

# Filename that graph() passes to graphviz's render().
TEMP_FNM = 'temp_graph'

# Stationary probabilities below this are treated as zero (transient
# states) by steady().
EPS = 1e-8


def readMarkov(fnm):
  '''
    Read a Markov transition matrix from the text file fnm + ".txt".

    Blank lines and lines starting with '#' are ignored. If the first
    remaining line is "labels: A B C ..." (case-insensitive prefix), the
    next len(labels) lines are the matrix rows. Otherwise the number of
    values on the first line fixes the size n, the first n lines are
    the rows, and the states are labelled "1", "2", ..., "n".
    Lines after the n matrix rows are ignored.

    Returns (labels, Mat(rows)) where labels is a list of strings.
    Raises ValueError if the file holds no usable data, the labels line
    is empty, the rows do not form an n x n matrix, or an entry is not
    a number.
  '''
  path = fnm + ".txt"
  print('Reading:', path)
  with open(path, 'r') as file:
    # Remove blank lines and comments
    stripped = (line.strip() for line in file)
    lines = [s for s in stripped if s and not s.startswith('#')]
  if not lines:
    raise ValueError("File contains no usable data.")

  if lines[0].lower().startswith('labels:'):
    labels = lines[0][len('labels:'):].split()
    if not labels:
      raise ValueError("No labels given.")
    rowLines = lines[1:len(labels)+1]
  else:
    # Matrix without labels; use 1, 2, ....
    n = len(lines[0].split())
    labels = [str(i + 1) for i in range(n)]
    rowLines = lines[:n]

  # Validate both layouts identically: a short file or a ragged row
  # would otherwise yield a non-square Mat.
  n = len(labels)
  if len(rowLines) != n:
    raise ValueError("Matrix must be square.")
  matrix = []
  for line in rowLines:
    toks = line.split()
    if len(toks) != n:
      raise ValueError("Matrix must be square.")
    matrix.append([float(token) for token in toks])
  return labels, Mat(matrix)


def powers(labels, mat, n=100):
  '''
    Print mat^1, mat^2, ..., mat^n, each as a labelled table with
    4 decimal places.

    labels: state names used for both rows and columns
    mat:    the matrix to raise to powers (e.g. a transition matrix)
    n:      highest power printed; if n < 1 only a blank line is printed
  '''
  print()
  q = mat.copy()
  for power in range(1, n+1):
    print(f"mat^{power}:")
    matLabels(q, labels, fmt="{:.4f}")
    if power < n:   # don't compute mat^(n+1); it is never printed
      q = mat*q



def absorb(Q, R):
  '''
    Analyse an absorbing chain whose transition matrix has the
    canonical form
       P = | Q  R |
           | 0  I |
    Q = prob of transitioning from some transient state to another
    R = prob of transitioning from some transient state to some absorbing state.

    Returns (N, B, t) where
      N = (I - Q)^-1   the fundamental matrix
      B = N R          absorption probabilities
      t = N 1          expected steps before absorption (column vector)
    (definitions as in the article linked below).
  '''
  # https://en.wikipedia.org/wiki/Absorbing_Markov_chain
  size = Q.nRows
  fundamental = (Mat.identity(size) - Q).inverse()
  ones = Mat.fill(size, 1, 1)   # column vector of 1's
  steps = fundamental*ones
  absorbProbs = fundamental*R
  return fundamental, absorbProbs, steps


def steady(P):
  '''
    Long-run quantities for the chain with transition matrix P.

    Returns (Z, pi, r, M):
      Z  = the matrix returned by getZPi(P)
      pi = stationary distribution (a list)
      r  = mean recurrence times, r[i] = 1/pi[i] (a list)
      M  = mean first passage times, M[i][j] = (Z[j][j] - Z[i][j]) / pi[j];
           the diagonal is therefore 0
    A state whose pi is below EPS is treated as transient: its r entry
    and its whole column of M are set to 0.
  '''
  Z, pi = getZPi(P)
  n = P.nRows
  # Flag transient states once; tested as "pi < EPS" to match exactly.
  transient = [pi[k] < EPS for k in range(n)]
  r = [0 if transient[k] else 1.0 / pi[k] for k in range(n)]
  M = Mat.fillSq(n, 0)
  for i in range(n):
    for j in range(n):
      M[i][j] = 0 if transient[j] else (Z[j][j] - Z[i][j]) / pi[j]
  return Z, pi, r, M
 

def getZPi(P):
  '''
    Return (Z, pi) for transition matrix P, where

      Z  = (I - P + e*beta)^-1
      pi = beta*Z, returned as a plain list

    Here e is an n x 1 column of 1's and beta is the 1 x n row
    (1, 0, ..., 0), so beta*e == 1. With that choice beta*Z is the
    stationary distribution, provided I - P + e*beta is invertible.

    NOTE(review): Z is a generalized inverse of I - P. In general it is
    not the Kemeny-Snell fundamental matrix (I - P + e*pi)^-1, although
    the mean-first-passage formula (Z[j][j] - Z[i][j]) / pi[j] gives
    the same result with either.
  '''
  n = P.nRows
  ones = Mat.fill(n, 1, 1)    # e: column of 1's
  selector = Mat.fill(1, n, 0)  # beta: picks out the first row
  selector[0][0] = 1
  Z = (Mat.identity(n) - P + (ones*selector)).inverse()
  piRow = selector*Z
  return Z, piRow.row(0)


def matLabels(matrix, rowLabels,
                      colLabels=None, fmt="{:> .2f}"):
  '''
    Print matrix as a tab-separated table with row and column labels,
    followed by a blank line.

    matrix:    object whose .data is a list of rows (e.g. a Mat)
    rowLabels: one label per row
    colLabels: one label per column; defaults to rowLabels
    fmt:       str.format pattern applied to every entry
    Labels need not be strings; both row and column labels are
    converted with str().
  '''
  if colLabels is None:
    colLabels = rowLabels
  # str() so non-string column labels work, as row labels already did
  print("\t" + "\t".join(str(lbl) for lbl in colLabels))
  for i, row in enumerate(matrix.data):
    formattedRow = [fmt.format(value) for value in row]
    print(f"{rowLabels[i]}\t" + "\t".join(formattedRow))
  print()



def vecLabels(vector, labels, fmt="{:> .2f}"):
  '''
    Print vector as a single labelled row: labels across the top,
    an empty row label, and fmt applied to each entry.
  '''
  asRow = Mat([vector])   # wrap the vector as a one-row Mat
  matLabels(asRow, rowLabels=[""], colLabels=labels, fmt=fmt)



def graph(title, labels, matrix, isLinear=True):
  '''
    Draw matrix as a left-to-right graphviz digraph captioned with title,
    with one edge per non-zero entry labelled with its value.

    If isLinear, invisible edges labels[0] -> labels[1] -> ... are added
    first so the nodes are laid out in a line. The graph is rendered
    in 'png' format to the file name TEMP_FNM with view=True and
    cleanup=True.
  '''
  rows = matrix.data
  dot = graphviz.Digraph(format='png')
  dot.attr(rankdir='LR')
  dot.attr(label=title)

  if isLinear:
    # invisible chain of edges pulls the nodes into a single row
    for k in range(1, len(rows)):
      dot.edge(labels[k-1], labels[k], style='invis')
  makeEdges(dot, labels, rows)
  dot.render(filename=TEMP_FNM, view=True, cleanup=True)


def makeEdges(g, labels, mat):
  '''
    Add an edge labels[r] -> labels[c] to g for every non-zero entry
    mat[r][c], with the entry written to 2 d.p. as the edge label.
    Rows and columns are both scanned over range(len(mat)), row-major.
  '''
  n = len(mat)
  cells = ((r, c) for r in range(n) for c in range(n))
  for r, c in cells:
    weight = mat[r][c]
    if weight == 0:
      continue
    g.edge(labels[r], labels[c], label=f'{weight:.2f}')


#  --------------------------------------------
if __name__ == "__main__":
  # Absorbing-chain demo: 2 transient states (Q) and 2 absorbing
  # states (R); each row of [Q R] sums to 1.
  Q = Mat([[0.5, 0.3], [0.2, 0.6]])
  R = Mat([[0.2, 0.0], [0.1, 0.1]])
  N, B, t = absorb(Q, R)
  print("Fundamental Matrix N:\n", N)
  print("Matrix B:\n", B)
  print("Expected Times to Absorption:\n", t)

  print()
  # Two-state chain with all-positive entries.
  P = Mat([[0.5, 0.5], [0.4, 0.6]])
  Z, pi, r, M = steady(P)
  print("Fundamental Matrix Z:\n", Z)
  print("Stationary Distribution:", pi)
  print("Mean Recurrence Times:", r)
  print("Mean First Passage Times:\n", M)
