
# missionaries.py
# Andrew Davison, Dec 2023, ad@coe.psu.ac.th

# https://en.wikipedia.org/wiki/Missionaries_and_cannibals_problem
'''
On the West bank of a river are three missionaries and three cannibals. 

They have one boat, which can hold at most two people and cannot cross the river empty.

If the cannibals ever outnumber the missionaries on either of the river's banks, the missionaries on that bank will get eaten. (A bank with no missionaries on it is safe.)

How can the boat be used to safely carry all the missionaries and cannibals across the river?

  state = (mWest, cWest, mEast, cEast, boatPos)
    where 
     mWest = west bank missionaries
     cWest = west bank cannibals
     mEast = east bank missionaries
     cEast = east bank cannibals
     boatPos = "west" or "east"
'''

import sys
import searchUtils
import matplotlib.pyplot as plt

MAX_NUM = 3

# problem-specific ---------------

def isGoal(state):
  '''
    Return True when every missionary and every cannibal
    (MAX_NUM of each) is on the east bank.
  '''
  _, _, eastM, eastC, _ = state
  return eastM == MAX_NUM and eastC == MAX_NUM


def nextStates(state):
  '''
    Return every valid state reachable from `state` with one boat crossing.

    The boat leaves from the bank it is on and carries, in this order:
    2 missionaries, 1 missionary, 2 cannibals, 1 cannibal, or 1 of each.
    A load is only tried if the departure bank has enough people for it.
    The boat ends up on the opposite bank. Successors rejected by
    isValid() are dropped.
  '''
  mWest, cWest, mEast, cEast, boatPos = state

  # (missionaries, cannibals) carried per crossing, in generation order
  loads = ((2, 0), (1, 0), (0, 2), (0, 1), (1, 1))

  def canSupply(have, need):
    # a bank can provide `need` people when it has more than need-1 of
    # them; a load that needs nobody of this kind is never restricted
    return need == 0 or have > need - 1

  if boatPos == "west":
    # people leave the west bank (sign -1) and arrive on the east bank
    fromM, fromC, sign, landing = mWest, cWest, -1, "east"
  else:
    # people leave the east bank (sign +1 for the west counts)
    fromM, fromC, sign, landing = mEast, cEast, 1, "west"

  candidates = [
    (mWest + sign*dm, cWest + sign*dc, mEast - sign*dm, cEast - sign*dc, landing)
    for dm, dc in loads
    if canSupply(fromM, dm) and canSupply(fromC, dc)
  ]
  return [s for s in candidates if isValid(s)]


def isValid(state):
  '''
    Return True if no missionaries get eaten in `state`.

    A bank is safe when it has no missionaries, or when its missionaries
    are at least as many as its cannibals. The state is valid only if
    both banks are safe.
  '''
  mWest, cWest, mEast, cEast, _ = state
  westSafe = mWest <= 0 or mWest >= cWest
  eastSafe = mEast <= 0 or mEast >= cEast
  return westSafe and eastSafe


def printPath(path):
  '''
    Print a path of states, start state first.

    Prints the path length and then the first state. For each later state
    it prints the boat crossing that produced it, followed by the state.
    An empty path prints only its length.

    path: sequence of (mWest, cWest, mEast, cEast, boatPos) tuples
  '''
  numStates = len(path)
  print("Path length:", numStates)
  if numStates == 0:   # nothing more to show; path[0] would raise IndexError
    return
  printState(path[0])

  # walk consecutive pairs: each crossing is the change between
  # the previous state and the current one
  for prev, state in zip(path, path[1:]):
    mWest, cWest, mEast, cEast, boatPos = state
    if boatPos == "west":
      # boat has just landed west: those moved are what the east bank lost
      print(f"  East -> West: {prev[2]-mEast}m {prev[3]-cEast}c")
    else:
      # boat has just landed east: those moved are what the west bank lost
      print(f"  West -> East: {prev[0]-mWest}m {prev[1]-cWest}c")
    printState(state)


def printState(state):
  '''
    Print one state on a single line: the head-counts on each bank,
    then the bank the boat is on.
  '''
  westM, westC, eastM, eastC, boat = state
  print(f"West: {westM}m {westC}c; East: {eastM}m {eastC}c | Boat: {boat}")


# override the imported dummy functions
# (presumably searchUtils.bfs() calls these module-level hooks -- confirm in searchUtils)
searchUtils.isGoal = isGoal
searchUtils.nextStates = nextStates


# -------- main -------------

# search from the start state: everyone, and the boat, on the west bank
path, numVisited = searchUtils.bfs((MAX_NUM, MAX_NUM, 0, 0, "west"))

print("No. of states visited:", numVisited)
if path is None:
  print("No path found")
  sys.exit(1)   # non-zero exit status signals that the search failed

printPath(path)


# ---------- plot the solution path  ---------------

# Draw the solution path as a 3D line, one point per state on the path:
#   x = missionaries on the bank where the boat currently is
#   y = cannibals on the bank where the boat currently is
#   z = boat position (0 = West, 1 = East)
fig = plt.figure()
ax = fig.add_subplot(projection='3d')

ms = []
cs = []
bps = []
for mWest, cWest, mEast, cEast, boatPos in path:
  if boatPos == "west":
    ms.append(mWest)
    cs.append(cWest)
    bps.append(0)
  else:
    ms.append(mEast)
    cs.append(cEast)
    bps.append(1)

ax.plot3D(ms, cs, bps, 'green')

# one tick per possible head-count, so the axes follow MAX_NUM
counts = list(range(MAX_NUM + 1))
ax.set_xticks(counts)
ax.set_yticks(counts)
ax.set_zticks([0, 1], labels=["West", "East"])

ax.set_title('Missionaries and Cannibals')
ax.set_xlabel('missionaries')
ax.set_ylabel('cannibals')
ax.set_zlabel('boat')
plt.show()
