
# searchUtils.py
# Andrew Davison, Dec 2023, ad@coe.psu.ac.th

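'''
 Search algorithms:
    * BFS
    * AStar
    * DFS using a stack
    * DFS using recursion with a depth limit

  All of the algorithms store node objects, and make use of
  'dummy' functions that should be redefined in the code for
  the problem:
    isGoal(s)
    nextStates(s)
    goalDist(s)  # only used by AStar

  The search functions look these names up in this module's globals,
  so they must be redefined here (e.g. searchUtils.isGoal = myIsGoal).
  Each search returns a (path, numVisited) pair: path is the list of
  states from the start state to a goal state (None if no goal was
  found), and numVisited is the number of states the search examined.
'''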

import heapq
from collections import deque


class node:
  '''A search-tree node: a state plus a link to the node it was reached from.

  state  -- the problem state held by this node
  parent -- the parent node, or None for the start node
  cost   -- ordering key for aStar(); filled in by setCost()
  '''

  def __init__(self, state, parent, cost=0):
    self.state = state
    self.parent = parent  # parent node or None
    self.cost = cost

  def _numMoves(self):
    # number of ancestors == number of moves made from the start state;
    # same value as len(self.getPath())-1 without building the path list
    moves = 0
    par = self.parent
    while par is not None:
      moves += 1
      par = par.parent
    return moves

  def setCost(self):
    # no. of moves made so far + distance to goal
    self.cost = self._numMoves() + goalDist(self.state)

  def __lt__(self, nd):
    # only used by aStar()
    # priority queue of nodes is ordered based on cost
    return self.cost < nd.cost

  def getPath(self):
    '''Return the list of states from the start node down to this node.'''
    path = []
    nd = self
    while nd is not None:
      path.append(nd.state)   # collected leaf-to-root ...
      nd = nd.parent
    path.reverse()            # ... then flipped: O(depth), not O(depth**2)
    return path


# ---------------------------------


def bfs(startState):
  '''Breadth-first search from startState.

  Returns (path, numVisited): path is the list of states from startState
  to the first goal state found (None if no reachable state is a goal);
  numVisited is how many states were taken off the queue and examined.
  '''
  queue = deque([node(startState, None)])   # no parent
  visited = []   # states already examined, in examination order

  while queue:
    currNode = queue.popleft()   # remove from start; O(1) unlike list.pop(0)
    visited.append(currNode.state)

    if isGoal(currNode.state):
      return currNode.getPath(), len(visited)

    for s in nextStates(currNode.state):
      # compare against the queued *states*, not the node objects;
      # lists (not sets) are searched so unhashable states still work
      if s not in visited and s not in (nd.state for nd in queue):
        queue.append(node(s, currNode))

  return None, len(visited)


# ------------------------------------


def aStar(startState):
  '''A* search from startState.

  Nodes are kept in a heap ordered by node.cost, which setCost() sets to
  the number of moves made so far plus goalDist(state).

  Returns (path, numVisited): path is the list of states from startState
  to the first goal state popped (None if none is reachable); numVisited
  is how many distinct states were expanded.
  '''
  priQueue = []
  startNode = node(startState, None)
  startNode.setCost()
  heapq.heappush(priQueue, startNode)
  visited = []

  while priQueue:
    minNode = heapq.heappop(priQueue)
    if minNode.state in visited:
      # a stale duplicate entry: this state was already expanded from an
      # earlier heap entry; expanding it again would repeat work and
      # count the state twice
      continue
    visited.append(minNode.state)

    if isGoal(minNode.state):
      return minNode.getPath(), len(visited)

    for s in nextStates(minNode.state):
      if s not in visited:
        # allow duplicate states to be inserted since they are sorted by
        # cost; the extra copies are discarded when popped (see above)
        nd = node(s, minNode)
        nd.setCost()
        heapq.heappush(priQueue, nd)

  return None, len(visited)


# -------------------------------------------


def dfsIter(state):
  '''Depth-first search from state using an explicit stack (no depth limit).

  Returns (path, numVisited): path is the list of states from the start
  state to the first goal state found (None if none is reachable);
  numVisited is how many states were popped and examined.
  '''
  stack = [node(state, None)]    # no parent
  visited = []

  while stack:
    currNode = stack.pop()  # remove from end
    visited.append(currNode.state)

    if isGoal(currNode.state):
      return currNode.getPath(), len(visited)

    # reversed so the first successor from nextStates() is popped first;
    # list() lets nextStates() return any iterable, not only a sequence
    for s in reversed(list(nextStates(currNode.state))):
      # compare against the stacked *states*, not the node objects
      if s not in visited and s not in (nd.state for nd in stack):
        stack.append(node(s, currNode))

  return None, len(visited)


# ------------------------------------

# Default depth limit for dfs().
MAX = 20

# States examined by dfsDepth(); a module-level list whose length dfs()
# reports as the number of visited states.
visited = list()

def dfs(state, max=MAX):
  '''Depth-limited recursive depth-first search from state.

  The start state counts as depth 1 and deeper calls are cut off once the
  depth exceeds max, so a found path holds at most max states (max-1 moves).
  Very large limits can run into Python's recursion limit.
  The parameter name max (which shadows the builtin) is kept for callers
  that pass it by keyword.

  Returns (path, numVisited): path is the list of states from the start
  state to a goal (None if none was found within the limit); numVisited
  is the number of states examined during this call.
  '''
  # dfsDepth() appends to the module-level visited list. Empty it in place
  # (not rebind) so states left over from a previous call neither block
  # this search nor inflate its count.
  visited.clear()
  path = dfsDepth(node(state, None), 1, max)
  return path, len(visited)


def dfsDepth(currNode, depth, max):
  '''Recursive helper for dfs().

  Records currNode's state in the module-level visited list, then tries
  each unvisited successor one level deeper. Returns the path to the first
  goal found, or None if the goal is not reached before depth exceeds max.
  '''
  if depth <= max:   # constrain search depth
    here = currNode.state
    visited.append(here)
    if isGoal(here):
      return currNode.getPath()

    for nxt in nextStates(here):
      if nxt in visited:   # no stack to examine, only visited states
        continue
      found = dfsDepth(node(nxt, currNode), depth + 1, max)
      if found is not None:
        return found
  return None
    
# ------------------------------------


def isGoal(s):
  '''Placeholder goal test, to be replaced by the problem's own version.

  Prints a reminder and always reports that s is not a goal.
  '''
  print("Redefine isGoal()")
  notGoal = False
  return notGoal

def nextStates(s):
  '''Placeholder successor function, to be replaced for the problem.

  Prints a reminder and returns an empty list (no successors).
  '''
  print("Redefine nextStates()")
  return list()

def goalDist(s):
  '''Placeholder distance-to-goal estimate; only aStar() (via
  node.setCost()) uses it.

  Prints a reminder and returns 0 for every state.
  '''
  print("Redefine goalDist()")
  estimate = 0
  return estimate
