import datetime
import io
import os
import sys
from math import sqrt

import numpy as np
import pyqtgraph as pg
import pyqtgraph.exporters
from pyqtgraph import PlotWidget, plot
from pyqtgraph.Qt import QtGui, QtCore, QtWidgets

from PlotServer import PlotServer


def infinity():
    """Return positive floating-point infinity.

    Used as the starting "best distance" in nearest-point searches.
    """
    return float("infinity")


def pointToPointDistanceSq(p1, p2) -> float:
    """Return the squared Euclidean distance between two points.

    Both points only need ``x()`` and ``y()`` accessors (e.g. QPointF).
    Skipping the square root makes this the cheap choice for comparisons.
    """
    deltaX = p2.x() - p1.x()
    deltaY = p2.y() - p1.y()
    return deltaX * deltaX + deltaY * deltaY


def pointToPointDistance(p1, p2) -> float:
    """Return the Euclidean distance between two points with ``x()``/``y()``."""
    squared = pointToPointDistanceSq(p1, p2)
    return sqrt(squared)


class Crosshair:
    """Vertical/horizontal guide lines plus a value label tied to one curve.

    The label is a pg.TextItem parented to a pg.CurvePoint. A CurvePoint
    cannot be built on an empty curve, so the pair is created lazily by
    _createCurvePoint() once the curve has data.
    """

    def __init__(self, curveItem, plotItem):
        self.curveItem = curveItem
        self.plotItem = plotItem

        vertical = pg.InfiniteLine(angle=90, movable=False)
        horizontal = pg.InfiniteLine(angle=0, movable=False)
        self.lines = [vertical, horizontal]
        for guide in self.lines:
            # ignoreBounds keeps the infinite lines out of auto-range.
            plotItem.addItem(guide, ignoreBounds=True)

        # Label objects; None until the curve holds at least one point.
        self.curvePoint = None
        self.textItem = None

        self.setVisible(False)

    def setVisible(self, visible: bool):
        """Show or hide the guide lines and, if it exists, the value label."""
        for guide in self.lines:
            guide.setVisible(visible)

        if not self._createCurvePoint():
            return
        self.curvePoint.setVisible(visible)
        self.textItem.setVisible(visible)

    @property
    def isVisible(self):
        """Visibility of the crosshair, taken from its vertical line."""
        vertical = self.lines[0]
        return vertical.isVisible()

    def setPosition(self, position):
        """Move the vertical line to position.x() and the horizontal to position.y()."""
        vertical, horizontal = self.lines
        vertical.setPos(position.x())
        horizontal.setPos(position.y())

    def setColor(self, pen):
        """Apply *pen* to both guide lines."""
        for guide in self.lines:
            guide.setPen(pen)

    def updateText(self, position: float, value):
        """Place the label at fractional data index *position* and show *value*.

        Axis units of the 'bottom' and 'left' axes are appended in parentheses
        when non-empty. Does nothing while the curve has no data.
        """
        if not self._createCurvePoint():
            return

        # HACK: CurvePoint.event handles 'index' as a float, while
        # CurvePoint.setIndex casts it to int and would snap the label to a
        # data point. Setting the Qt property keeps the fractional position.
        self.curvePoint.setProperty('index', position)

        bottomUnits = self.plotItem.getAxis('bottom').labelUnits
        leftUnits = self.plotItem.getAxis('left').labelUnits
        xSuffix = ' (' + bottomUnits + ')' if bottomUnits != '' else ''
        ySuffix = ' (' + leftUnits + ')' if leftUnits != '' else ''

        template = '<span style="font-size: 14pt;">x = %.6f%s,<br>y = %.6f%s</span>'
        self.textItem.setHtml(template % (value.x(), xSuffix, value.y(), ySuffix))

    def _createCurvePoint(self):
        """Ensure the label objects exist; return False while the curve is empty."""
        if self.curvePoint:
            return True

        # getData() returns (None, None) before any data has been set, and
        # pg.CurvePoint cannot be created for an empty data set.
        xData, _ = self.curveItem.getData()
        if not isinstance(xData, np.ndarray) or len(xData) == 0:
            return False

        self.curvePoint = pg.CurvePoint(self.curveItem, rotate=False)
        self.plotItem.addItem(self.curvePoint)

        label = pg.TextItem(text='',
                            anchor=(0, 1),
                            border=self.curveItem.opts['pen'],
                            fill=(255, 255, 255, 64))
        label.setParentItem(self.curvePoint)
        self.textItem = label

        # New label objects inherit the current visibility of the lines.
        shown = self.isVisible
        self.curvePoint.setVisible(shown)
        self.textItem.setVisible(shown)

        return True


class Curve:
    """A plotted curve together with its crosshair and an optional output file.

    Args:
        item: the data item returned by ``PlotItem.plot``.
        plotItem: the PlotItem that hosts the crosshair.
        filename: path of a text file to open for writing, or None.
    """

    def __init__(self, item, plotItem, filename):
        self.item = item
        self.crosshair = Crosshair(self.item, plotItem)
        # Defined unconditionally so the attributes exist even without a file.
        self.filename = filename
        self.file = None
        if filename is not None:
            # buffering=1 => line-buffered text file (truncated on open).
            # NOTE(review): nothing in this module writes to it; confirm callers do.
            self.file = io.open(filename, mode='w', buffering=1)

    def close(self):
        """Close the output file if one is open; safe to call more than once."""
        if self.file is not None:
            self.file.close()
            self.file = None

    def setData(self, x, y):
        """Replace the curve's data with the given x and y sequences."""
        self.item.setData(x, y)

    class ProjectionResult:
        """Outcome of Curve.project(); defaults describe "no projection"."""
        numPoints = 0        # number of data points in the curve
        prevIndex = -1       # index of the segment start (or the snapped point)
        nextIndex = -1       # index of the segment end (== prevIndex when snapped)
        prevPoint = (0, 0)   # (x, y) at prevIndex
        nextPoint = (0, 0)   # (x, y) at nextIndex
        pointOnCurve = None  # QPointF of the projected point
        dx = 0.0             # nextPoint.x - prevPoint.x
        dy = 0.0             # nextPoint.y - prevPoint.y
        curve = None         # the Curve that produced this result

    def project(self, point, snapToControlPoint: bool = False):
        """Project *point* onto the curve.

        Args:
            point: object with ``x()``/``y()`` in view coordinates.
            snapToControlPoint: if True, return the data point whose x is
                closest to ``point.x()``; otherwise interpolate linearly on the
                segment around it.

        Returns:
            A ProjectionResult, or None when the curve has no data.
        """
        (x, y) = self.item.getData()
        # getData() yields (None, None) before any data has been set.
        if x is None or len(x) == 0:
            return None

        result = self.ProjectionResult()
        result.curve = self
        result.numPoints = len(x)

        result.prevIndex = Curve.findIndex(x, y, point, snapToControlPoint)
        if snapToControlPoint:
            result.nextIndex = result.prevIndex
            result.prevPoint = (x[result.prevIndex], y[result.prevIndex])
            result.nextPoint = result.prevPoint
            result.pointOnCurve = QtCore.QPointF(*result.prevPoint)
            return result

        # findIndex returns the nearest point; step back one when the mouse
        # is left of it so that (assuming increasing x) the segment
        # [prevIndex, nextIndex] brackets point.x().
        if result.prevIndex > 0 and point.x() < x[result.prevIndex]:
            result.prevIndex -= 1
        result.nextIndex = result.prevIndex + 1

        result.prevPoint = (x[result.prevIndex], y[result.prevIndex])
        if result.nextIndex >= result.numPoints:
            # Nearest point is the last one: there is no segment to follow.
            result.pointOnCurve = QtCore.QPointF(*result.prevPoint)
            return result

        result.nextPoint = (x[result.nextIndex], y[result.nextIndex])
        result.dx = result.nextPoint[0] - result.prevPoint[0]
        result.dy = result.nextPoint[1] - result.prevPoint[1]

        if result.dx == 0.0:
            # Vertical segment: the slope is undefined, use the segment start.
            result.pointOnCurve = QtCore.QPointF(*result.prevPoint)
            return result

        xVal = Curve._clampToSegment(point.x(), result.prevPoint[0], result.nextPoint[0])
        yVal = result.prevPoint[1] + (result.dy / result.dx) * (xVal - result.prevPoint[0])
        result.pointOnCurve = QtCore.QPointF(xVal, yVal)

        return result

    @staticmethod
    def _clampToSegment(value, a, b):
        """Clamp *value* into the closed interval spanned by *a* and *b* (either order)."""
        lo, hi = (a, b) if a <= b else (b, a)
        return min(max(value, lo), hi)

    @classmethod
    def findIndex(cls, x, y, point, useXValOnly: bool = False):
        """Return the index of the data point nearest to *point*.

        Args:
            x, y: numpy arrays of the curve's coordinates.
            point: object with ``x()``/``y()``.
            useXValOnly: compare x coordinates only (nearest x wins).

        Returns:
            -1 for an empty curve, 0 for a single point, otherwise the index of
            the nearest point. Only points whose x lies within twice the length
            of the first segment of ``point.x()`` are compared in 2-D; if none
            qualify, the point with the nearest x is used.
        """
        numPoints = len(x)
        if numPoints == 0:
            return -1
        elif numPoints < 2:
            return 0

        if useXValOnly:
            return np.abs(x - point.x()).argmin()

        # Heuristic search window: the first segment's length stands in for
        # the typical spacing between samples.
        estimatedSegmentLength = pointToPointDistance(QtCore.QPointF(x[0], y[0]),
                                                      QtCore.QPointF(x[1], y[1]))
        absDiff = np.abs(x - point.x())
        potentialIndices = np.where(absDiff < 2.0 * estimatedSegmentLength)[0]
        if len(potentialIndices) == 0:
            potentialIndices = [absDiff.argmin()]
        elif len(potentialIndices) == 1:
            return potentialIndices[0]

        bestDistance = infinity()
        bestIndex = 0
        for index in potentialIndices:
            p = QtCore.QPointF(x[index], y[index])
            dist = pointToPointDistanceSq(p, point)
            if dist < bestDistance:
                bestDistance = dist
                bestIndex = index

        return bestIndex


class Plot:
    """Wrap a pyqtgraph PlotItem with named curves and a nearest-curve crosshair.

    ``self.curves`` maps each curve name to ``[curve, xValues, yValues]``.
    The two lists accumulate every point passed to :meth:`addData`.
    """

    # Class-level defaults: onMouseMove reads these, and the mouse-move proxy
    # is connected in __init__ before any setter has had a chance to run.
    crosshairEnabled = True
    crosshairSnap = False

    def __init__(self, item):
        self.plotItem = item
        self.curves = {}

        item.addLegend(offset=(0, 0))

        # Deliver scene mouse moves to onMouseMove at most 60 times a second.
        self.mouseMoveProxy = pg.SignalProxy(item.scene().sigMouseMoved,
                                             rateLimit=60,
                                             slot=self.onMouseMove)

    def setEnableCrosshair(self, enable: bool):
        """Enable/disable the crosshair and hide any crosshair currently shown."""
        self.crosshairEnabled = enable
        for curve, _, _ in self.curves.values():
            curve.crosshair.setVisible(False)

    def getEnableCrosshair(self) -> bool:
        """Return whether the crosshair follows the mouse."""
        return self.crosshairEnabled

    def setEnableCrosshairSnap(self, enable: bool):
        """Choose whether the crosshair snaps to data points instead of interpolating."""
        self.crosshairSnap = enable

    def getEnableCrosshairSnap(self) -> bool:
        """Return whether the crosshair snaps to data points."""
        return self.crosshairSnap

    def addCurve(self, name, pen=(255, 0, 255)):
        """Create a curve called *name* drawn with *pen*."""
        self.createCurve(name, pen=pen)

    def createCurve(self,
                    name: str,
                    **kwargs):
        """Add an empty curve named *name* and return its Curve wrapper.

        Keyword arguments are forwarded to ``PlotItem.plot``, except
        ``filename``, which is consumed by Curve. When no ``pen`` is given,
        the integer ``len(self.curves)`` is used as the pen. Reusing a name
        replaces the stored entry for that name.
        """
        # 'filename' is a Curve option, not a plot-item option.
        filename = kwargs.pop('filename', None)
        if 'pen' not in kwargs:
            kwargs['pen'] = len(self.curves)

        item = self.plotItem.plot(list(), list(), name=name, **kwargs)
        curve = Curve(item, self.plotItem, filename=filename)

        self.curves[name] = [curve, list(), list()]

        return curve

    def addData(self, name, x, y):
        """Append the point (x, y) to curve *name* and redraw it.

        Unknown names are reported on stdout and ignored.
        """
        entry = self.curves.get(name)
        if entry is None:
            print(f"Non existing plot: {name}")
            return

        curve, xValues, yValues = entry
        xValues.append(x)
        yValues.append(y)
        curve.setData(xValues, yValues)

    def onMouseMove(self, evt):
        """Show the crosshair on the curve closest to the mouse.

        Slot for pg.SignalProxy; ``evt[0]`` is the mouse position in scene
        coordinates.
        """
        position = evt[0]

        for curve, _, _ in self.curves.values():
            curve.crosshair.setVisible(False)

        if not self.getEnableCrosshair() or not self.plotItem.sceneBoundingRect().contains(position):
            return

        mousePoint = self.plotItem.vb.mapSceneToView(position)

        # Pick the curve whose projected point lies closest to the mouse.
        snap = self.getEnableCrosshairSnap()
        bestDistance = infinity()
        closest = None
        for curve, _, _ in self.curves.values():
            projection = curve.project(mousePoint, snap)
            if projection is None or projection.pointOnCurve is None:
                continue

            distance = pointToPointDistance(projection.pointOnCurve, mousePoint)
            if distance < bestDistance:
                bestDistance = distance
                closest = projection

        if closest is None:
            return

        curve = closest.curve
        curve.crosshair.setVisible(True)
        curve.crosshair.setPosition(closest.pointOnCurve)
        curve.crosshair.setColor(curve.item.opts['pen'])

        # Fractional data index at which the label is attached to the curve.
        textPosition = 0.0
        interpolatePoint = closest.prevIndex < closest.nextIndex < closest.numPoints
        snappedToPoint = (closest.prevIndex < closest.numPoints and
                          closest.prevIndex == closest.nextIndex)
        isOutOfBounds = closest.nextIndex >= closest.numPoints
        if interpolatePoint:
            if closest.dx > 1.0E-10:
                fraction = (closest.pointOnCurve.x() - closest.prevPoint[0]) / closest.dx
                textPosition = float(closest.prevIndex) + np.clip(fraction, 0.0, 1.0)
            else:
                textPosition = float(closest.prevIndex)
        elif snappedToPoint or isOutOfBounds:
            textPosition = float(closest.prevIndex)

        curve.crosshair.updateText(position=textPosition, value=closest.pointOnCurve)


class PlotWindow(QtWidgets.QMainWindow):
    """Main window showing a grid of plots that a PlotServer populates.

    Keyword arguments:
        port: port handed to PlotServer (default 5555). Other keyword
            arguments are ignored; positional arguments go to QMainWindow.
    """

    def __init__(self, *args, **kwargs):
        super(PlotWindow, self).__init__(*args)
        self.graphicsWindow = pg.GraphicsWindow("plotting")
        self.setCentralWidget(self.graphicsWindow)
        self.graphicsWindow.setBackground('w')
        self.setWindowTitle("Plot Server")

        self.resize(1024, 768)
        # Plots created by add_graph(), keyed by title.
        self.graphs = dict()
        # Plots created by createPlot(), in creation order. Initialised before
        # the server exists in case it creates plots straight away.
        self.plots = []

        # Create a server that will communicate with the client
        self.plotServer = PlotServer(self, port=kwargs.get('port', 5555))

        # Poll the server from the Qt event loop every 3 ms.
        self.timer = QtCore.QTimer()
        self.timer.setInterval(3)
        self.timer.timeout.connect(self.update)
        self.timer.start()

        button = QtWidgets.QPushButton('Save', self)
        button.setToolTip('Save plot to file')
        button.setGeometry(10, 10, 50, 30)
        button.clicked.connect(self.on_click)

        quitButton = QtWidgets.QPushButton('Exit', self)
        quitButton.setToolTip('Exit application')
        quitButton.setGeometry(70, 10, 50, 30)
        quitButton.clicked.connect(self.exit)

        resetButton = QtWidgets.QPushButton('Reset', self)
        resetButton.setToolTip('Reset plot')
        resetButton.setGeometry(70+50+10, 10, 50, 30)
        resetButton.clicked.connect(self.reset)

        self.crossButton = QtWidgets.QCheckBox('Crosshair', self)
        self.crossButton.setChecked(True)
        self.crossButton.setToolTip('Toggle Crosshair')
        self.crossButton.move(190, 10)
        self.crossButton.stateChanged.connect(lambda: self.on_crosshair(self.crossButton))

    def exit(self, event):
        """Terminate the process with status 0 (Exit button handler)."""
        # sys.exit raises SystemExit just like the exit() builtin, but does not
        # depend on the site module having installed that builtin.
        sys.exit(0)

    def on_click(self, event):
        """Save a timestamped PNG of the window into the ``capture`` directory."""
        dateString = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        directory = "capture"
        try:
            # Make sure the target directory exists before exporting into it.
            os.makedirs(directory, exist_ok=True)
        except OSError as err:
            print("*** Could not create " + directory + ": " + str(err))
            return
        # os.path.join keeps this portable; a literal backslash is only a
        # path separator on Windows.
        filename = os.path.join(directory, "plot_" + dateString) + ".png"
        self.save(filename)
        print("*** Saved to " + filename)

    def on_crosshair(self, button):
        """Apply the checkbox state to the crosshair of every add_graph plot."""
        enabled = button.isChecked()
        for graph in self.graphs.values():
            graph.setEnableCrosshair(enabled)

    def save(self, filename):
        """Export the whole graphics scene to *filename* as an image."""
        exporter = pg.exporters.ImageExporter(self.graphicsWindow.scene())
        # Debug output: lists the plot items being exported.
        for graph in self.graphs:
            item = self.graphs[graph]
            print(item.plotItem)

        exporter.export(filename)

    def reset(self):
        """Remove every plot from the window and forget all graphs."""
        # clear() removes all layout items and also rewinds the layout's
        # current row/column, so plots added afterwards start at the top-left.
        self.graphicsWindow.clear()
        self.graphs.clear()
        self.plots.clear()

    def update(self):
        """Timer slot: process pending server messages; close when told to stop.

        NOTE: this name shadows QWidget.update() for Python callers.
        """
        self.plotServer.checkMessages()

        if self.plotServer.shouldStop():
            self.close()

    def add_graph(self, title, new_row=False, xLabel='', xUnit='', yLabel='', yUnit=''):
        """Return the graph titled *title*, creating it if it does not exist.

        Args:
            title: plot title and key in ``self.graphs``.
            new_row: start a new layout row before adding the plot.
            xLabel, xUnit: bottom-axis label and units.
            yLabel, yUnit: left-axis label and units.
        """
        if title in self.graphs:
            print("Graph " + title + " already exists, reusing")
            return self.graphs[title]

        if new_row:
            self.graphicsWindow.nextRow()

        plotItem = self.graphicsWindow.addPlot(title=title)
        plotItem.showGrid(x=True, y=True)
        # Label the left axis, as createPlot does. Crosshair reads its y units
        # from the 'left' axis, so units set here appear in its label.
        plotItem.setLabel('left', yLabel, units=yUnit)
        plotItem.setLabel('bottom', xLabel, units=xUnit)

        plot = Plot(plotItem)

        plot.setEnableCrosshair(True)
        plot.setEnableCrosshairSnap(False)

        self.graphs[title] = plot
        return plot

    def createPlot(self,
                   **kwargs):
        """Create a plot in the next layout cell and return its Plot wrapper.

        Options consumed here (not forwarded to addPlot):
            labels: (xLabel, yLabel), default ('x', 'f(x)')
            units: (xUnit, yUnit), default ('', '')
            showGrid: default True
            showLegend: default True
            showCrosshair: default True
            crosshairSnap: default False
        All remaining keyword arguments are forwarded to addPlot().
        """
        # Pop our own options: anything left over is passed by addPlot to the
        # PlotItem constructor, which does not understand them.
        labels = kwargs.pop('labels', ('x', 'f(x)'))
        units = kwargs.pop('units', ('', ''))
        showGrid = kwargs.pop('showGrid', True)
        showLegend = kwargs.pop('showLegend', True)
        showCrosshair = kwargs.pop('showCrosshair', True)
        crosshairSnap = kwargs.pop('crosshairSnap', False)

        item = self.graphicsWindow.addPlot(**kwargs)
        if showGrid:
            item.showGrid(x=True, y=True)

        if showLegend:
            item.addLegend()

        item.setLabel('bottom', labels[0], units=units[0])
        item.setLabel('left', labels[1], units=units[1])

        plot = Plot(item)
        plot.setEnableCrosshair(showCrosshair)
        plot.setEnableCrosshairSnap(crosshairSnap)

        self.plots.append(plot)

        return plot
