
# sweep.py

import math 
import matplotlib.pyplot as plt

import heapq
from treeset import TreeSet


EPS = 1e-8

# ---------------------------------------------------

class EvPoint():
  '''
    An event point for the sweep: an endpoint of a segment, or the
    intersection point of two segments.

    Attributes:
      id      -- integer index of the point (None for a blank point,
                 -1 for a temporary lookup point)
      x, y    -- coordinates
      ptType  -- "left", "right", or "intersect"
      segment -- the Segment an endpoint belongs to, or a
                 (Segment, Segment) tuple for an intersection point
  '''

  def __init__(self, id=None, x=None, y=None, 
                          ptType=None, segment=None):
    self.id = id
    self.x = x
    self.y = y
    self.ptType = ptType # left, right, or intersect
    self.segment = segment


  def update(self, other):
    '''Overwrite every field of this point with the fields of `other`.'''
    self.id = other.id
    self.x = other.x
    self.y = other.y
    self.ptType = other.ptType
    self.segment = other.segment


  def get(self):
    '''Return the coordinates as an (x, y) tuple.'''
    return (self.x, self.y)


  def __str__(self):
    return f"(id:{self.id}, ({self.x:.1f},{self.y:.1f}), type:{self.ptType})"

  def __hash__(self):
    # Hash on coordinates, while __eq__ compares ids.
    # NOTE(review): this is only consistent if equal ids imply equal
    # coordinates -- confirm callers never reuse an id at another location.
    return hash((self.x, self.y))


  def __eq__(self, other):
    '''Two event points are equal when they have the same id.'''
    if other is None:
      return False
    if not isinstance(other, EvPoint):
      return NotImplemented
    return self.id == other.id

  def __lt__(self, other):
    '''
      Event-queue order: by x, then by y. At a shared location the order
      is intersection events first, then right (or untyped) endpoints,
      then left endpoints.
    '''
    if self.x != other.x:
      return self.x < other.x
    if self.y != other.y:
      return self.y < other.y

    # Same location: break the tie on point type.
    if self.ptType == "intersect":
      return True
    elif self.ptType == "left":
      return other.ptType == "left"
    else:
      return other.ptType != "intersect"


  def __gt__(self, other):
    # a > b exactly when b < a; "not a < b" would wrongly be True for ties.
    return other.__lt__(self)


  def draw(self, color="b", label=None):
    '''Scatter-plot this point, optionally annotated with `label`.'''
    plt.scatter(self.x, self.y, color=color)
    if label is not None:
      ax = plt.gca()
      ax.annotate(" "+label, (self.x, self.y), fontsize=12)




# ---------------------------------------------------

class Segment():
  '''
    A line segment between two event points p and q.

    The constructor stores the endpoints so that p is the smaller one in
    event order (the left end). It tags p as "left" and q as "right", and
    points both endpoints' `segment` attribute back at this Segment.
  '''

  def __init__(self, p, q, sweep=None):
    if p < q:
      self.p = p
      self.q = q
    else:
      self.p = q
      self.q = p
    self.p.ptType = "left"
    self.q.ptType = "right"

    # The sweeper this segment belongs to; getY() reads its currEvPt and
    # intersectPt() calls its getPoint().
    self.sweep = sweep
    self.p.segment = self.q.segment = self
    self.slope = self.calcSlope()


  def setP(self, pt):
    '''Replace the left endpoint and recompute the slope.'''
    self.p = pt
    self.slope = self.calcSlope()


  def setQ(self, pt):
    '''Replace the right endpoint and recompute the slope.'''
    self.q = pt
    self.slope = self.calcSlope()


  def calcSlope(self):
    '''Return dy/dx, or None for a vertical segment.'''
    if self.p.x != self.q.x:
      return (self.p.y - self.q.y) / (self.p.x - self.q.x)
    else:
      return None


  def getY(self):
    '''
      Return the y used to order this segment in the sweep status.

      This is the y of whichever endpoint is the sweeper's current event
      point, and otherwise the left endpoint's y.
      NOTE(review): this is an endpoint y, not the segment's y at the sweep
      line's current x -- confirm this ordering is what the tree expects.
    '''
    curr = self.sweep.currEvPt
    if curr == self.p:
      return self.p.y
    if curr == self.q:
      return self.q.y
    return self.p.y


  def get(self):
    '''Return [(px, py), (qx, qy)].'''
    return [ self.p.get(), self.q.get() ]

  def xs(self):
    return [self.p.x, self.q.x]

  def ys(self):
    return [self.p.y, self.q.y]


  def __str__(self):
    return f"[{self.p.id}, {self.q.id}]"


  def __eq__(self, other):
    '''Segments are equal when both endpoints are equal (by point id).'''
    if other is None:
      return False
    return self.p == other.p and self.q == other.q

  def _slopeKey(self):
    # Vertical segments have slope None; rank them above every finite slope
    # so that comparing two slopes never compares against None.
    return math.inf if self.slope is None else self.slope

  def __lt__(self, other):
    '''Order by getY(); break ties on slope (vertical = steepest).'''
    y_self = self.getY()
    y_other = other.getY()
    if y_self == y_other:
      return self._slopeKey() < other._slopeKey()
    return y_self < y_other


  def __gt__(self, other):
    return not self.__lt__(other) and not self == other



  def __adj__(self, other):
    '''True if the two segments share an endpoint (compared by point id).'''
    return self.p == other.p or self.p == other.q or \
           self.q == other.p or self.q == other.q


  def length(self):
    '''Euclidean length of the segment.'''
    return math.sqrt((self.p.x - self.q.x)**2 + 
                     (self.p.y - self.q.y)**2)


  def ccw(self, p, q, r):
    '''
      Orientation of the ordered triple (p, q, r). Returns 0, 1, or -1:
        0 if p, q, r are collinear (|cross| < EPS)
        1 if r is clockwise (to the right) of p-->q
       -1 if r is counter-clockwise (to the left) of p-->q
      Adapted from
      https://www.geeksforgeeks.org/orientation-3-ordered-points/amp/
    '''
    cross = (q.y - p.y) * (r.x - p.x) - \
          (q.x - p.x) * (r.y - p.y)
    if abs(cross) < EPS:
      return 0
    elif cross > 0:
      return 1
    else:
      return -1


  def onSegment(self, p):
    '''
      Return True if point p lies within this segment's bounding box.
      intersects() only calls this after ccw() reports p collinear with the
      segment. In that case it means p lies on the segment.
    '''
    return (min(self.p.x, self.q.x) <= p.x <= max(self.p.x, self.q.x) and
            min(self.p.y, self.q.y) <= p.y <= max(self.p.y, self.q.y))

  def intersects(self, other):
    '''
      Return True if this segment and `other` intersect. Segments that
      share an endpoint are reported as not intersecting.
    '''
    if self.__adj__(other):
      return False

    o1 = self.ccw(self.p, self.q, other.p)
    o2 = self.ccw(self.p, self.q, other.q)
    o3 = self.ccw(other.p, other.q, self.p)
    o4 = self.ccw(other.p, other.q, self.q)

    # General case: each segment's endpoints lie on opposite sides of the other.
    if (o1 != o2) and (o3 != o4):
      return True
    # Special cases: an endpoint collinear with, and lying on, the other segment.
    return ((o1 == 0 and self.onSegment(other.p)) or
            (o2 == 0 and self.onSegment(other.q)) or
            (o3 == 0 and other.onSegment(self.p)) or
            (o4 == 0 and other.onSegment(self.q)))


  def intersectPt(self, other):
    '''
      Return an "intersect" EvPoint where the lines through the two segments
      cross. Its segment is the tuple (self, other). Returns None when the
      lines are parallel (zero determinant).

      The point is obtained from (or registered with) the sweeper via
      getPoint(). Segment bounds are not checked here, so call intersects()
      first.
    '''
    # Segment AB represented as a1x + b1y = c1
    a1 = self.q.y - self.p.y
    b1 = self.p.x - self.q.x
    c1 = self.p.y * b1 + self.p.x * a1

    # Segment CD represented as a2x + b2y = c2
    a2 = other.q.y - other.p.y
    b2 = other.p.x - other.q.x
    c2 = other.p.y * b2 + other.p.x * a2

    det = a1 * b2 - a2 * b1
    if det != 0:
      x = (b2 * c1 - b1 * c2) / det
      y = (a1 * c2 - a2 * c1) / det
      pt = self.sweep.getPoint(x,y)
      return EvPoint(pt.id, pt.x, pt.y, "intersect", (self, other))

    return None

  def draw(self, label=None):
    '''Plot the segment in green with both endpoints marked.'''
    plt.plot(self.xs(), self.ys(), "g")
    self.p.draw(label=label)
    self.q.draw()



# ---------------------------------------------------

class Sweeper(object):
  '''
    Finds the intersection points of a set of segments by moving a sweep
    line from left to right over a priority queue of event points.

    The sweep status is a TreeSet from the local `treeset` module.
    NOTE(review): its addHighLow/lower/higher/swap/remove are assumed to
    order segments with Segment's comparison operators -- confirm.
  '''

  def __init__(self, nodes, edges):
    '''
      nodes -- sequence of (x, y) coordinates; node i becomes point id i
      edges -- sequence of (i, j) node-index pairs, one segment per pair
    '''
    self.pts = [EvPoint(i, nd[0], nd[1]) for i, nd in enumerate(nodes)]
    self.numOrigNodes = len(self.pts)
    # Single reused object holding a copy of the event being processed.
    self.currEvPt = EvPoint()
    self.segments = [Segment(self.pts[e[0]], self.pts[e[1]], self)
                     for e in edges]


  def getPoint(self, x, y):
    '''
      Return the stored point located exactly at (x, y). If none exists,
      create one with the next free id, store it, and return it.
    '''
    for nd in self.pts:
      # Compare coordinates directly: equal hashes do not imply equal
      # coordinates (e.g. hash(-1) == hash(-2) in CPython).
      if nd.x == x and nd.y == y:
        return nd   # found existing point

    # store a new point for (x,y)
    target = EvPoint(len(self.pts), x, y)
    self.pts.append(target)
    return target


  def _pushIntersection(self, ptsPQ, seen, icoords, seg, other):
    '''
      If `seg` and `other` intersect and this pair has not been scheduled
      before, push their intersection event onto ptsPQ and append its
      coordinates to icoords.
    '''
    # Two non-collinear segments cross at most once. Scheduling a pair
    # twice would swap it twice in the tree and count it twice.
    key = frozenset((id(seg), id(other)))
    if key in seen or not seg.intersects(other):
      return
    ipt = seg.intersectPt(other)
    if ipt is None:   # parallel/collinear: no single crossing point
      return
    seen.add(key)
    heapq.heappush(ptsPQ, ipt)
    icoords.append(ipt.get())


  def _handleIntersection(self, tree, ptsPQ, seen, icoords):
    '''
      Process the current "intersect" event: swap the two segments in the
      status tree, then test each against its new outer neighbour.
    '''
    s1, s2 = self.currEvPt.segment
    print("Intersecting segments:", s1, s2)
    tree.swap(s1, s2)
    # Temporarily start both segments at the intersection point, so the
    # neighbour checks only see the parts from the intersection onward.
    old_s1 = s1.p
    old_s2 = s2.p
    s1.setP(self.currEvPt)
    s2.setP(self.currEvPt)
    try:
      if s1 is tree.lower(s2):  # ... s1, s2, ...
        low = tree.lower(s1)
        if low is not None:
          self._pushIntersection(ptsPQ, seen, icoords, s1, low)
        high = tree.higher(s2)
        if high is not None:
          self._pushIntersection(ptsPQ, seen, icoords, s2, high)

      elif s2 is tree.lower(s1):  # ... s2, s1, ...
        high = tree.higher(s1)
        if high is not None:
          self._pushIntersection(ptsPQ, seen, icoords, s1, high)
        low = tree.lower(s2)
        if low is not None:
          self._pushIntersection(ptsPQ, seen, icoords, s2, low)

      else:
        print("Intersection point error")
    finally:
      # Always restore the real left endpoints.
      s1.setP(old_s1)
      s2.setP(old_s2)


  def sweepLine(self):
    '''
      Run the sweep. Return a list of (x, y) tuples, one for each
      intersecting segment pair found.

      Intersection points added to self.pts during the sweep are discarded
      before returning.
    '''
    self.currEvPt = EvPoint()
    tree = TreeSet()
    ptsPQ = []
    pushAll(ptsPQ, [seg.p for seg in self.segments])
    pushAll(ptsPQ, [seg.q for seg in self.segments])
    icoords = []
    seen = set()   # frozensets of id() pairs already scheduled

    while ptsPQ:
      self.currEvPt.update(heapq.heappop(ptsPQ))
      currSeg = self.currEvPt.segment
      ptType = self.currEvPt.ptType

      if ptType == 'left':
        low, high = tree.addHighLow(currSeg)
        if low is not None:
          self._pushIntersection(ptsPQ, seen, icoords, currSeg, low)
        if high is not None:
          self._pushIntersection(ptsPQ, seen, icoords, currSeg, high)

      elif ptType == "right":
        # Removing currSeg makes its two neighbours adjacent.
        low = tree.lower(currSeg)
        high = tree.higher(currSeg)
        if (low is not None) and (high is not None):
          self._pushIntersection(ptsPQ, seen, icoords, low, high)
        tree.remove(currSeg)

      elif ptType == "intersect":
        self._handleIntersection(tree, ptsPQ, seen, icoords)

      else:
        print("Node without ptType")

    self.pts = self.pts[:self.numOrigNodes]   # discard intersection pts
    return icoords



  def plot(self):
    '''Draw every segment with its endpoint ids and show the figure.'''
    _, ax = plt.subplots()

    for seg in self.segments:
      px, py = seg.p.get()
      pid = seg.p.id
      qx, qy = seg.q.get()
      qid = seg.q.id
      seg.draw()
      ax.annotate(str(pid), (px, py), fontsize=12)
      ax.annotate(str(qid), (qx, qy), fontsize=12)

    plt.show()


# ------------------------------------------

def pushAll(pq, elems):
  '''Push each item of `elems`, in order, onto the heap list `pq` (in place).'''
  push = heapq.heappush
  for item in elems:
    push(pq, item)


# --------------- main ---------------------------------

# Demo input: node coordinates and the node-index pairs forming segments.
nodes = [(10,10), (40,30), (20,40), (60,10), 
         (30,50), (60,30), (50,10), (70,50)]
edges = [ (0,1), (2,3), (4,5), (6,7)]

if __name__ == "__main__":
  # Run the demo only as a script, so importing this module does not run
  # the sweep or open a plot window.
  sw = Sweeper(nodes, edges)

  icoords = sw.sweepLine()
  print("No. of Intersections:", len(icoords))
  for coord in icoords:
    print(f"  ({coord[0]:.2f}, {coord[1]:.2f})")

  sw.plot()


