from nixview.util.enums import PlotterTypes
from PyQt5.QtWidgets import QGroupBox, QHBoxLayout, QLabel, QPushButton, QSizePolicy, QSlider, QVBoxLayout, QWidget
from PyQt5.QtCore import pyqtSignal, Qt
import matplotlib
matplotlib.use('Qt5Agg')
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
from matplotlib.figure import Figure
import nixio as nix
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
from IPython import embed

from nixview.util.file_handler import FileHandler
from nixview.util.dataview import DataView
from nixview.communicator import communicator


def create_label(item):
    """Return an axis label for *item*, e.g. ``"time [s]"``.

    The text part is ``item.label``. If that is ``None`` or empty and *item*
    has a ``name`` attribute, the name is used instead. Items without a
    ``label`` attribute get no text part. A non-empty ``item.unit`` is
    appended in square brackets.

    Args:
        item: any object (e.g. a nix entity or dimension); every attribute is
            optional and probed with ``hasattr``/``getattr``.

    Returns:
        The label string, possibly empty.
    """
    label = ""
    if hasattr(item, "label"):
        label = item.label or ""
        if not label and hasattr(item, "name"):
            # ``or ""`` so a None name does not make the label None.
            label = item.name or ""
    unit = getattr(item, "unit", None)
    if unit:
        # Skip empty units so we never render a bare " []".
        label += " [%s]" % unit
    return label


class MplCanvas(FigureCanvasQTAgg):
    """Qt widget hosting a matplotlib figure with a single subplot.

    Args:
        parent: optional Qt parent widget.
        width: figure width in inches.
        height: figure height in inches.
        dpi: figure resolution in dots per inch.
    """

    # Emitted by subclasses after they (re)draw the plotted data.
    view_changed = pyqtSignal()

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        figure = Figure(figsize=(width, height), dpi=dpi)
        self.axis = figure.add_subplot(111)
        super(MplCanvas, self).__init__(figure)
        # FigureCanvasQTAgg's constructor takes no parent, so apply it here.
        if parent is not None:
            self.setParent(parent)
        # Same object as ``self.figure``; kept for backward compatibility.
        self._figure = figure


class Plotter(MplCanvas):
    """Base canvas for drawing one item of the open NIX file.

    Args:
        file_handler: handler used by subclasses to request data/axes.
        item: the item to be plotted.
        data_view: view on the item's data.
        parent: optional Qt parent widget.
    """

    def __init__(self, file_handler, item, data_view, parent=None) -> None:
        super().__init__(parent=parent)
        # Collaborators that concrete plotters read from.
        self._file_handler, self._item, self._dataview = file_handler, item, data_view

    def current_view(self):
        """Describe the currently displayed view; subclasses may override.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        message = "current_view is not implemented on the current plotter"
        raise NotImplementedError(message)
    


class EventPlotter(Plotter):
    """Scatter plot of 1-D data: each value is drawn as a point at y == 1.

    NOTE(review): this class takes a nix DataArray directly and does not call
    ``Plotter.__init__``. When no axis is given, it draws into a new pyplot
    figure rather than into this canvas.
    """

    def __init__(self, data_array, xdim=-1):
        self.array = data_array
        self.sc = None  # PathCollection returned by axis.scatter, once plotted
        self.dim_count = len(data_array.dimensions)
        if xdim == -1:
            # NOTE(review): guess_best_xdim is not defined or imported in the
            # visible part of this module; confirm it exists before relying
            # on xdim=-1.
            self.xdim = guess_best_xdim(self.array)
        elif xdim > 1:
            raise ValueError("EventPlotter: xdim is larger than 1! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim

    def plot(self, axis=None):
        """Plot into *axis*, or into a new pyplot figure when *axis* is None.

        Returns:
            The axis for 1-D data, otherwise None (only 1-D is supported).
        """
        if axis is None:
            self.fig = plt.figure(figsize=[5.5, 2.])
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        if len(self.array.dimensions) == 1:
            return self.plot_1d()
        else:
            return None

    def plot_1d(self):
        """Scatter the data values along x at y == 1 and label the axes."""
        data = self.array[:]
        xlabel = create_label(self.array.dimensions[self.xdim])
        dim = self.array.dimensions[self.xdim]
        # The array's own label is used for y only for a non-alias Range dimension.
        if dim.dimension_type == nix.DimensionType.Range and not dim.is_alias:
            ylabel = create_label(self.array)
        else:
            ylabel = ""
        self.sc = self.axis.scatter(data, np.ones(data.shape))
        # A single tick row; the y value carries no information.
        self.axis.set_ylim([0.5, 1.5])
        self.axis.set_yticks([1.])
        self.axis.set_yticklabels([])
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        return self.axis


class CategoryPlotter(Plotter):
    """Bar chart of 1-D or 2-D data with a set (category) x dimension.

    NOTE(review): this class takes a nix DataArray directly and does not call
    ``Plotter.__init__``. When no axis is given, it draws into a new pyplot
    figure rather than into this canvas.
    """

    def __init__(self, data_array, xdim=-1):
        self.array = data_array
        self.bars = []  # bar artists created by plot_1d / plot_2d
        if xdim == -1:
            # NOTE(review): guess_best_xdim is not defined or imported in the
            # visible part of this module; confirm it exists before relying
            # on xdim=-1.
            self.xdim = guess_best_xdim(self.array)
        elif xdim > 2:
            raise ValueError("CategoryPlotter: xdim is larger than 2! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim

    def plot(self, axis=None):
        """Plot into *axis*, or into a new pyplot figure when *axis* is None.

        Returns:
            The axis, or None for data with more than 2 dimensions (or when the
            plotting method declines).
        """
        if axis is None:
            self.fig = plt.figure()
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        if len(self.array.dimensions) == 1:
            return self.plot_1d()
        elif len(self.array.dimensions) == 2:
            return self.plot_2d()
        else:
            return None

    @staticmethod
    def _set_labels(dim):
        """Return *dim*'s labels as a list; [] for a non-Set dimension or no labels."""
        if dim.dimension_type != nix.DimensionType.Set:
            return []
        labels = dim.labels
        return [] if labels is None else list(labels)

    def plot_1d(self):
        """Draw one bar per category.

        Returns:
            The axis, or None if the x dimension is not a Set dimension.

        Raises:
            ValueError: if there is nothing to label (empty data and no labels).
        """
        data = self.array[:]
        dim = self.array.dimensions[self.xdim]
        if dim.dimension_type != nix.DimensionType.Set:
            return None
        categories = self._set_labels(dim)
        if not categories:
            # The set dimension carries no labels: generate placeholder names.
            categories = ["Cat-%i" % i for i in range(len(data))]
        if not categories:
            raise ValueError("Cannot plot a bar chart without any labels")
        self.bars.append(self.axis.bar(range(1, len(categories) + 1), data,
                                       tick_label=categories))
        self.axis.set_ylabel(create_label(self.array))
        return self.axis

    def plot_2d(self):
        """Draw grouped bars: one group per category, one bar per series."""
        data = self.array[:]
        if self.xdim == 1:
            # Put categories on axis 0 and series on axis 1.
            data = data.T
        category_count, series_count = data.shape

        categories = self._set_labels(self.array.dimensions[self.xdim])
        if not categories:
            categories = ["Cat-%i" % i for i in range(category_count)]
        series_names = self._set_labels(self.array.dimensions[1 - self.xdim])
        if not series_names:
            series_names = ["Series-%i" % i for i in range(series_count)]

        bar_width = 1 / series_count * 0.75
        group_positions = np.arange(category_count)
        handles = []
        for i in range(series_count):
            bars = self.axis.bar(group_positions + i * bar_width, data[:, i],
                                 width=bar_width, align="center")
            # One representative bar per series serves as its legend handle.
            handles.append(bars[0])
        self.bars.extend(handles)
        # Bars are centred at x + i * bar_width, so a group's centre lies
        # (series_count - 1) * bar_width / 2 to the right of x.
        self.axis.set_xticks(group_positions +
                             (series_count - 1) * bar_width / 2)
        self.axis.set_xticklabels(categories)
        self.axis.legend(handles, series_names, loc=1)
        return self.axis


class ImagePlotter(Plotter):
    """Image plot of 2-D data, or 3-D data with at most 3 entries in dimension 2.

    Dimension 0 is drawn along the horizontal axis and dimension 1 along the
    vertical axis.

    NOTE(review): this class takes a nix DataArray directly and does not call
    ``Plotter.__init__``.
    """

    def __init__(self, data_array, xdim=-1):
        self.array = data_array
        self.image = None  # AxesImage returned by imshow, once plotted
        # NOTE(review): xdim is accepted for signature parity but unused.

    def plot(self, axis=None):
        """Plot into *axis*, or into a new pyplot figure when *axis* is None.

        Returns:
            The axis, or None for unsupported dimensionality.
        """
        dim_count = len(self.array.dimensions)
        if axis is None:
            self.fig = plt.figure()
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        if dim_count == 2:
            return self.plot_2d()
        elif dim_count == 3:
            return self.plot_3d()
        else:
            return None

    def plot_2d(self):
        """Draw the data with dimension 0 horizontal and dimension 1 vertical."""
        data = self.array[:]
        x = self.array.dimensions[0].axis(data.shape[0])
        y = self.array.dimensions[1].axis(data.shape[1])
        xlabel = create_label(self.array.dimensions[0])
        ylabel = create_label(self.array.dimensions[1])
        # imshow draws input axis 0 vertically and axis 1 horizontally. Swap
        # the first two axes (a trailing colour axis stays last) so dimension
        # 0 matches the horizontal extent and label. origin="lower" places
        # element [0, 0] at (x[0], y[0]).
        self.image = self.axis.imshow(np.swapaxes(data, 0, 1), origin="lower",
                                      extent=[x[0], x[-1], y[0], y[-1]])
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        return self.axis

    def plot_3d(self):
        """Draw 3-D data as a colour image; refuse more than 3 entries in dimension 2."""
        if self.array.shape[2] > 3:
            print("cannot plot 3d data with more than 3 channels "
                  "in the third dim")
            return None
        return self.plot_2d()


class LinePlotter(Plotter):
    """Line plot of the 1-D or 2-D data held in a DataView.

    For 2-D data, one line is drawn per entry along the non-x dimension. At
    most ``maxpoints`` samples along the x dimension are drawn at once.
    ``view_changed`` is emitted after the initial draw.
    """

    def __init__(self, file_handler, item, data_view, xdim=-1, parent=None):
        super().__init__(file_handler, item, data_view, parent)

        self.dimensions = self._file_handler.request_dimensions(self._item.block_id, self._item.id)
        self.lines = []  # one matplotlib Line2D per plotted trace
        self.dim_count = len(self._dataview.full_shape)
        if xdim == -1:
            self.xdim = item.best_xdim
        elif xdim > 2:
            raise ValueError("LinePlotter: xdim is larger than 2! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim

    def plot(self, maxpoints=100000):
        """Draw the first *maxpoints* samples; data with more than 2 dims is ignored."""
        self.maxpoints = maxpoints
        if self.dim_count > 2:
            return
        if self.dim_count == 1:
            self.plot_array_1d()
        else:
            self.plot_array_2d()

    def __add_slider(self):
        """Add a matplotlib slider that pages through the data in ``maxpoints`` chunks.

        NOTE(review): not called anywhere in the visible code.
        """
        steps = self._dataview.full_shape[self.xdim] / self.maxpoints
        slider_ax = self.figure.add_axes([0.15, 0.025, 0.8, 0.025])
        self.slider = Slider(slider_ax, 'Slider', 1., steps, valinit=1.,
                             valstep=0.25)
        self.slider.on_changed(self.__update)

    def __update(self, val):
        """Slider callback: show the window ending at ``val * maxpoints``."""
        if len(self.lines) > 0:
            end = val * self.maxpoints
            start = max(end - self.maxpoints, 0)
            self.__draw(start, end)
        self.draw_idle()

    def __draw(self, start, end):
        """Draw samples [start, end) with the 1-D or 2-D drawing routine."""
        if self.dim_count == 1:
            self.__draw_1d(start, end)
        else:
            self.__draw_2d(start, end)

    def __clamp_range(self, start, end):
        """Clip [start, end) to [0, current_shape[xdim]] and return ints.

        Converting once keeps the data slice and the requested axis the same length.
        """
        start = max(start, 0)
        end = min(end, self._dataview.current_shape[self.xdim])
        return int(start), int(end)

    def __fit_xlim(self, x):
        """Set the x limits to the first/last drawn x value; no-op when nothing was drawn."""
        if len(x) > 0:
            self.axis.set_xlim([x[0], x[-1]])

    def __draw_1d(self, start, end):
        """Create or update the single line for samples [start, end)."""
        start, end = self.__clamp_range(start, end)
        y = self._dataview._buffer[start:end]
        x = self._file_handler.request_axis(self._item.block_id, self._item.id, 0, len(y), start)

        if len(self.lines) == 0:
            line, = self.axis.plot(x, y, label=self._item.name)
            self.lines.append(line)
        else:
            self.lines[0].set_ydata(y)
            self.lines[0].set_xdata(x)
        self.__fit_xlim(x)

    def __draw_2d(self, start, end):
        """Create or update one line per non-x entry for samples [start, end)."""
        start, end = self.__clamp_range(start, end)
        x = self._file_handler.request_axis(self._item.block_id, self._item.id, self.xdim, end - start, start)
        line_count = self._dataview.current_shape[1 - self.xdim]
        line_labels = self._file_handler.request_axis(self._item.block_id, self._item.id, 1 - self.xdim, line_count, 0)

        for i, label in enumerate(line_labels):
            if self.xdim == 0:
                y = self._dataview._buffer[start:end, i]
            else:
                y = self._dataview._buffer[i, start:end]

            if len(self.lines) <= i:
                line, = self.axis.plot(x, y, label=label)
                self.lines.append(line)
            else:
                self.lines[i].set_ydata(y)
                self.lines[i].set_xdata(x)
        self.__fit_xlim(x)

    def __label_axes(self):
        """Label x with the x dimension and y with the plotted item."""
        self.axis.set_xlabel(create_label(self.dimensions[self.xdim]))
        self.axis.set_ylabel(create_label(self._item))

    def plot_array_1d(self):
        """Initial draw of 1-D data, then emit view_changed."""
        self.__draw_1d(0, self.maxpoints)
        self.__label_axes()
        self.view_changed.emit()

    def plot_array_2d(self):
        """Initial draw of 2-D data with a legend, then emit view_changed."""
        self.__draw_2d(0, self.maxpoints)
        self.__label_axes()
        self.axis.legend(loc=1)
        self.view_changed.emit()


class PlotContainer(QWidget):
    """Widget that shows one plotter at a time in a vertical layout."""

    def __init__(self, parent=None) -> None:
        super().__init__(parent=parent)
        self.setLayout(QVBoxLayout())
        self.plotter = None  # the currently displayed plotter widget, if any

    def set_plotter(self, plotter):
        """Replace the displayed plotter with *plotter*.

        The previous plotter is removed from the layout and unparented.
        QLayout.removeWidget alone would leave it as a visible child of this
        container, with stale geometry.
        """
        old = self.plotter
        if old is plotter:
            return
        if old is not None:
            self.layout().removeWidget(old)
            # Unparenting hides it and hands ownership back to Python.
            old.setParent(None)
        self.layout().addWidget(plotter)
        self.plotter = plotter


class PlotScreen(QWidget):
    """Screen with a plot area, pan/zoom sliders and a close button.

    The sliders currently only print their values. ``close_signal`` is
    emitted when the close button is clicked.
    """
    close_signal = pyqtSignal()

    def __init__(self, parent=None) -> None:
        super().__init__(parent=parent)
        self._file_handler = FileHandler()
        # Set by plot(); initialised here so slots can test it safely.
        self.plotter = None

        self.setLayout(QVBoxLayout())
        self._container = PlotContainer(self)
        self.layout().addWidget(self._container)

        close_btn = QPushButton("close")
        close_btn.clicked.connect(self.on_close)

        self._create_plot_controls()
        self.layout().addWidget(close_btn)
        self._data_view = None

    def _create_plot_controls(self):
        """Build the row with the horizontal-position and zoom sliders."""
        plot_controls = QGroupBox()
        plot_controls.setFlat(True)
        plot_controls.setLayout(QHBoxLayout())
        plot_controls.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
        self._zoom_slider = QSlider(Qt.Horizontal)
        self._zoom_slider.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
        self._zoom_slider.setFixedWidth(120)
        self._zoom_slider.setFixedHeight(20)
        self._zoom_slider.setTickPosition(QSlider.TicksBelow)
        self._zoom_slider.setSliderPosition(50)
        self._zoom_slider.setMinimum(0)
        self._zoom_slider.setMaximum(100)
        self._zoom_slider.setTickInterval(25)
        self._zoom_slider.valueChanged.connect(self.on_zoom)

        self._pan_slider = QSlider(Qt.Horizontal)
        self._pan_slider.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
        self._pan_slider.setFixedHeight(20)
        self._pan_slider.setTickPosition(QSlider.TicksBelow)
        self._pan_slider.setSliderPosition(0)
        self._pan_slider.setMinimum(0)
        self._pan_slider.setMaximum(100)
        self._pan_slider.setTickInterval(25)
        self._pan_slider.valueChanged.connect(self.on_pan)

        pl = QLabel("horiz. pos.:")
        pl.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
        pl.setMaximumWidth(150)
        zl = QLabel("zoom:")
        zl.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed)
        zl.setMaximumWidth(75)
        plot_controls.layout().addWidget(pl)
        plot_controls.layout().addWidget(self._pan_slider)
        plot_controls.layout().addWidget(zl)
        plot_controls.layout().addWidget(QLabel("+"))
        plot_controls.layout().addWidget(self._zoom_slider)
        plot_controls.layout().addWidget(QLabel("-"))
        self.layout().addWidget(plot_controls)

    def on_close(self):
        """Forward the close button click as close_signal."""
        self.close_signal.emit()

    def on_zoom(self, new_position):
        """Zoom slider slot (currently prints the value only)."""
        print("zoom", new_position)

    def on_pan(self, new_position):
        """Pan slider slot (currently prints the value only)."""
        print("pan", new_position)

    def on_view_changed(self):
        """Debug slot: print the plotter's current view when it reports a change."""
        print("view changed!")
        if self.plotter is None:
            return
        try:
            print(self.plotter.current_view())
        except NotImplementedError as e:
            # Not every plotter implements current_view. With PyQt5's default
            # excepthook, an exception escaping a slot aborts the application.
            print(e)

    def plot(self, item):
        """Load *item*'s data and show it with the plotter the item suggests.

        DataView construction errors (ValueError) are reported through
        ``communicator.plot_error``. Only ``PlotterTypes.LinePlotter`` is
        handled; other suggested plotters are currently ignored.
        """
        try:
            self._data_view = DataView(item, self._file_handler)
        except ValueError as e:
            communicator.plot_error.emit("error in plotscreen.plot %s" % e)
            return
        if self._data_view is None:
            return
        while not self._data_view.fully_loaded:
            self._data_view.request_more()  # TODO this is just a test, needs to be removed
        if item.suggested_plotter == PlotterTypes.LinePlotter:
            self.plotter = LinePlotter(self._file_handler, item, self._data_view)
            self.plotter.view_changed.connect(self.on_view_changed)
            self._container.set_plotter(self.plotter)
            self.plotter.plot()