from PyQt5.QtWidgets import QWidget, QPushButton, QSizePolicy, QLineEdit,\
    QMessageBox, QVBoxLayout, QHBoxLayout, QGridLayout, QLabel, QFrame,\
    QDoubleSpinBox, QComboBox, QSpinBox
from PyQt5.QtCore import pyqtSlot

import numpy as np

from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from DataProvider import DataProvider
from redetector import detect_spiketimes


class SpikeRedetectGui(QWidget):
    """Window for re-running spike detection on recorded traces.

    A ``PlotCanvas`` on the left shows the selected trial. The panel on the
    right selects the repro, stimulus value and trial, and holds the three
    re-detection parameters (threshold, min window size, step size). Those
    parameters are passed on to ``PlotCanvas.plot``.
    """

    def __init__(self, data_provider: DataProvider):
        super().__init__()
        self.data_provider = data_provider
        self.title = 'Spike Redetection'
        self.left = 10
        self.top = 10
        # Not called ``width``/``height``: instance attributes with those names
        # would shadow the QWidget.width()/QWidget.height() methods.
        self.initial_width = 640
        self.initial_height = 400

        # 0-based index of the displayed trial (the trial spinbox is 1-based).
        self.trial_idx = 0

        self.initUI()

    def initUI(self):
        """Build the widget tree, draw the initial plot and show the window."""
        self.setWindowTitle(self.title)
        self.setGeometry(self.left, self.top, self.initial_width, self.initial_height)

        middle = QHBoxLayout()
        middle.addWidget(self._build_plot_area())
        middle.addWidget(QVLine())
        middle.addWidget(self._build_side_panel())
        self.setLayout(middle)

        # repro_box is filled before its currentTextChanged signal is connected,
        # so repro_change() never runs for the initially selected repro. Run it
        # once now to fill the stimulus values and draw the first plot.
        if self.repro_box.count() > 0:
            self.repro_change()

        self.show()

    def _build_plot_area(self):
        """Return the frame holding the matplotlib canvas and the preset buttons."""
        plot_area = QFrame()
        plot_area_layout = QVBoxLayout()
        self.canvas = PlotCanvas(self)
        self.canvas.move(0, 0)
        plot_area_layout.addWidget(self.canvas)

        plot_area_buttons = QFrame()
        plot_area_buttons_layout = QHBoxLayout()
        plot_area_buttons.setLayout(plot_area_buttons_layout)
        plot_area_layout.addWidget(plot_area_buttons)

        # (label, tooltip, threshold preset) for each quick-select button.
        presets = (
            ('Button1', 'A nice button!', 1),
            ('Button2', 'Another nice button!', 2),
            ('Button3', 'Even more nice buttons!', 3),
            ('Button4', 'Even more nice buttons!', 4),
        )
        for text, tooltip, threshold in presets:
            button = QPushButton(text, self)
            button.setToolTip(tooltip)
            # ``value`` is bound as a default so every button keeps its own preset
            # (avoids the late-binding closure pitfall). ``checked`` absorbs the
            # bool that QPushButton.clicked may pass. threshold_spinbox is looked
            # up only when a button is clicked, after the side panel exists.
            button.clicked.connect(
                lambda checked=False, value=threshold: self.threshold_spinbox.setValue(value))
            button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
            plot_area_buttons_layout.addWidget(button)

        plot_area.setLayout(plot_area_layout)
        return plot_area

    def _build_side_panel(self):
        """Return the options panel (repro, stimulus, trial, detection parameters)."""
        panel = QFrame()
        panel.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum)
        panel.setMaximumWidth(200)
        panel_layout = QVBoxLayout()

        # Created before the repro box so repro_change() can always fill it.
        stim_val_label = QLabel("Stimulus value:")
        self.stim_val_box = QComboBox()

        panel_layout.addWidget(QLabel("Repro:"))
        self.repro_box = QComboBox()
        for repro in self.data_provider.get_repros():
            self.repro_box.addItem(repro)
        self.repro_box.currentTextChanged.connect(self.repro_change)
        panel_layout.addWidget(self.repro_box)
        panel_layout.addWidget(stim_val_label)
        panel_layout.addWidget(self.stim_val_box)

        panel_layout.addWidget(QLabel("Trial:"))
        self.trial_spinbox = QSpinBox(self)
        # Trials are shown 1-based, while self.trial_idx is 0-based.
        self.trial_spinbox.setMinimum(1)
        self.trial_spinbox.setValue(self.trial_idx + 1)
        self.trial_spinbox.setSingleStep(1)
        self.trial_spinbox.valueChanged.connect(
            lambda: self.trial_change(self.trial_spinbox.value() - 1))
        panel_layout.addWidget(self.trial_spinbox)

        panel_layout.addWidget(QFill(minh=200))

        self.status_label = QLabel("Done x/15 Stimulus Values")
        panel_layout.addWidget(self.status_label)
        panel_layout.addWidget(QFill())

        # Each parameter spinbox is connected only after its initial setValue(),
        # so building the panel does not trigger replots while later spinboxes
        # are still missing.
        panel_layout.addWidget(QLabel("Threshold:"))
        self.threshold_spinbox = QDoubleSpinBox(self)
        # NOTE(review): QDoubleSpinBox's default range is [0, 99.99]. A base
        # threshold outside it is clamped by setValue(); confirm that is acceptable.
        self.threshold_spinbox.setValue(self.data_provider.base_detection_parameters[0])
        self.threshold_spinbox.setSingleStep(0.5)
        self.threshold_spinbox.valueChanged.connect(lambda: self.redetection_changed())
        panel_layout.addWidget(self.threshold_spinbox)

        panel_layout.addWidget(QLabel("Min window size:"))
        self.window_spinbox = QSpinBox(self)
        self.window_spinbox.setMaximum(2 ** 21)
        self.window_spinbox.setValue(self.data_provider.base_detection_parameters[1])
        self.window_spinbox.setSingleStep(500)
        self.window_spinbox.valueChanged.connect(lambda: self.redetection_changed())
        panel_layout.addWidget(self.window_spinbox)

        panel_layout.addWidget(QLabel("step size:"))
        self.step_spinbox = QSpinBox(self)
        self.step_spinbox.setMaximum(2 ** 21)
        self.step_spinbox.setValue(self.data_provider.base_detection_parameters[2])
        self.step_spinbox.setSingleStep(200)
        self.step_spinbox.valueChanged.connect(lambda: self.redetection_changed())
        panel_layout.addWidget(self.step_spinbox)

        button = QPushButton('Accept!', self)
        button.setToolTip('Accept the threshold for current stimulus value')
        panel_layout.addWidget(button)

        panel.setLayout(panel_layout)
        return panel

    def _redetection_params(self):
        """Return (threshold, min window size, step size) from the spinboxes."""
        return (self.threshold_spinbox.value(),
                self.window_spinbox.value(),
                self.step_spinbox.value())

    def _replot(self):
        """Redraw the current trial of the selected repro with the current parameters."""
        self.canvas.plot(self.trial_idx, self.repro_box.currentText(),
                         self.data_provider, self._redetection_params())

    @pyqtSlot()
    def redetection_changed(self):
        """Replot after one of the detection parameters changed."""
        self._replot()

    @pyqtSlot()
    def trial_change(self, new_trial_idx):
        """Show the trial with 0-based index ``new_trial_idx`` and replot.

        TODO: the upper bound (number of trials of the repro) is not checked
        here. The trial spinbox only limits the value to its own maximum.
        """
        self.trial_idx = new_trial_idx
        self._replot()

    @pyqtSlot()
    def repro_change(self):
        """Refill the stimulus values for the selected repro and replot."""
        repro = self.repro_box.currentText()
        self.stim_val_box.clear()

        for val in self.data_provider.get_stim_values(repro):
            self.stim_val_box.addItem(str(val))

        self._replot()


class PlotCanvas(FigureCanvas):
    """Matplotlib canvas showing one trial, with mouse pan and zoom.

    Mouse interaction works in display (pixel) coordinates. Their origin (0, 0)
    is the bottom-left corner of the canvas, and they change when the window is
    resized.

    * Left-button drag pans, so that the data point grabbed at press time stays
      under the cursor.
    * Right-button drag zooms about the centre of the view limits at press time.
      Every 50 pixels of drag halves or doubles the axis range; dragging
      right/up zooms in.
    * Scrolling is only printed for now.
    """

    def __init__(self, parent=None, dpi=100):
        self.fig = Figure(dpi=dpi)
        self.axes = self.fig.add_subplot(111)

        FigureCanvas.__init__(self, self.fig)
        self.setParent(parent)

        FigureCanvas.setSizePolicy(self,
                                   QSizePolicy.Expanding,
                                   QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)

        # Drag state: set in onclick(), cleared in release().
        self.mouse_button_pressed = False
        # Compared against the ints 1 (left) and 3 (right); -1 means "none".
        self.mouse_button = -1
        self.mouse_position_start = (-1, -1)    # pixels at press time
        self.mouse_data_start = (0.0, 0.0)      # data coords at press time
        self.start_limits = ((0, 0), (0, 0))    # (xlim, ylim) at press time

        # False until plot() has drawn a trial. Before that the axes only hold
        # matplotlib's default (0, 1) limits, which must not be carried over.
        self._trial_plotted = False

        self.fig.canvas.mpl_connect('button_press_event', self.onclick)
        self.fig.canvas.mpl_connect('button_release_event', self.release)
        self.fig.canvas.mpl_connect('motion_notify_event', self.moved)
        self.fig.canvas.mpl_connect('scroll_event', self.scrolled)

    def scrolled(self, event):
        """Handle a scroll event (not implemented; only printed).

        ``event.button`` is "up" or "down".
        """
        print('Scrolled: ', event)

    def moved(self, event):
        """Pan (left drag) or zoom (right drag) while a mouse button is held."""
        # Ignore motion when no drag is in progress or the cursor is outside the axes.
        if not self.mouse_button_pressed or event.xdata is None:
            return

        if self.mouse_button == 1:
            self._pan_drag(event.x, event.y)
        elif self.mouse_button == 3:
            self._zoom_drag(event.x, event.y)

    def _pan_drag(self, x, y):
        """Shift the view so the data point grabbed on press sits at pixel (x, y)."""
        # Map (x, y) to data coordinates under the *current* limits. The offset
        # from the grabbed point is how far the view has to move.
        mouse_data_new = self.axes.transData.inverted().transform((x, y))
        diff_x = mouse_data_new[0] - self.mouse_data_start[0]
        diff_y = mouse_data_new[1] - self.mouse_data_start[1]

        self.axes.set_xlim(np.asarray(self.axes.get_xlim()) - diff_x)
        self.axes.set_ylim(np.asarray(self.axes.get_ylim()) - diff_y)
        self.draw()

    def _zoom_drag(self, x, y, zoom_strength=50):
        """Rescale the press-time limits about their centre by the drag distance.

        ``zoom_strength`` is the number of pixels of drag that halves or doubles
        the axis range. Dragging right/up shrinks the range (zooms in).
        """
        (x0, x1), (y0, y1) = self.start_limits
        factor_x = 2 ** ((self.mouse_position_start[0] - x) / zoom_strength)
        factor_y = 2 ** ((self.mouse_position_start[1] - y) / zoom_strength)

        half_x = 0.5 * (x1 - x0) * factor_x
        half_y = 0.5 * (y1 - y0) * factor_y
        center_x = 0.5 * (x0 + x1)
        center_y = 0.5 * (y0 + y1)

        self.axes.set_xlim((center_x - half_x, center_x + half_x))
        self.axes.set_ylim((center_y - half_y, center_y + half_y))
        self.draw()

    def onclick(self, event):
        """Start a pan (left button) or zoom (right button) drag."""
        if event.button not in (1, 3):
            return

        self.mouse_button_pressed = True
        self.mouse_button = event.button
        self.mouse_position_start = (event.x, event.y)
        # Data point under the cursor. Panning keeps this point under the cursor.
        self.mouse_data_start = self.axes.transData.inverted().transform(self.mouse_position_start)
        self.start_limits = (self.axes.get_xlim(), self.axes.get_ylim())

    def release(self, event):
        """End the current drag when the left or right button is released."""
        if event.button in (1, 3):
            self.mouse_button_pressed = False

    @pyqtSlot()
    def plot(self, trial_idx, repro, data_provider: DataProvider, redetection_vars):
        """Draw one trial with its stored and re-detected spike times.

        Args:
            trial_idx: 0-based index into the traces of ``repro``.
            repro: repro name passed to ``data_provider.get_traces_with_spiketimes``.
            data_provider: source of traces, spike times and the sampling interval.
            redetection_vars: sequence whose first three items are passed to
                ``detect_spiketimes`` after ``time`` and ``trace``.

        Once a trial has been drawn, the current view limits are kept across
        redraws, so pan/zoom survive a parameter change. The first plot keeps
        matplotlib's autoscaled limits instead. If ``trial_idx`` is out of range,
        the axes are cleared and a message is shown instead of raising: an
        exception inside a Qt slot would abort the application.
        """
        traces, spiketimes, recording_times = data_provider.get_traces_with_spiketimes(repro)
        ax = self.axes

        if not 0 <= trial_idx < len(traces):
            ax.clear()
            ax.set_title('Trial index {} not available ({} trials)'.format(trial_idx, len(traces)))
            # Nothing worth preserving: let the next valid plot autoscale.
            self._trial_plotted = False
            self.draw()
            return

        trace = traces[trial_idx]
        trial_spiketimes = spiketimes[trial_idx]
        sampling_interval = data_provider.sampling_interval

        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        ax.clear()
        # NOTE(review): recording_times[0] is used as the time offset for every
        # trial, not a per-trial value. Confirm that is the intended layout.
        time = np.arange(len(trace)) * sampling_interval - recording_times[0]
        ax.plot(time, trace)

        # Spike rasters are drawn just above the trace maximum.
        trace_max = np.max(trace)
        ax.eventplot(trial_spiketimes, lineoffsets=trace_max + 1, colors="black")
        redetect = detect_spiketimes(time, trace, redetection_vars[0], redetection_vars[1], redetection_vars[2])
        ax.eventplot(redetect, lineoffsets=trace_max + 2, colors="red")
        ax.set_title('Trial XYZ')

        if self._trial_plotted:
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)
        self._trial_plotted = True
        self.draw()


class QHLine(QFrame):
    """Horizontal separator: a QFrame drawn as a sunken horizontal line."""

    def __init__(self):
        super().__init__()
        # Shape and shadow combined into a single frame-style value.
        self.setFrameStyle(QFrame.HLine | QFrame.Sunken)


class QVLine(QFrame):
    """Vertical separator: a QFrame drawn as a sunken vertical line."""

    def __init__(self):
        super().__init__()
        # Shape and shadow combined into a single frame-style value.
        self.setFrameStyle(QFrame.VLine | QFrame.Sunken)


class QFill(QFrame):
    """Expanding filler frame with optional minimum/maximum size bounds.

    No frame shape is set, so it only takes up layout space. Default bounds are
    0 x 0 minimum and (1 << 24) - 1 maximum in each direction, which is Qt's
    QWIDGETSIZE_MAX.
    """

    def __init__(self, maxw=(1 << 24) - 1, maxh=(1 << 24) - 1, minw=0, minh=0):
        super().__init__()
        expanding = QSizePolicy.Expanding
        self.setSizePolicy(expanding, expanding)
        self.setMaximumSize(maxw, maxh)
        self.setMinimumSize(minw, minh)