from PyQt5.QtWidgets import QHBoxLayout, QPushButton, QVBoxLayout, QWidget
from PyQt5.QtCore import QObject, pyqtSignal, Qt
import matplotlib
matplotlib.use('Qt5Agg')
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
from matplotlib.figure import Figure

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
import nixio as nix




def create_label(entity):
    """Build an axis label string for a nix entity.

    The text comes from ``entity.label``. If that attribute exists but is
    ``None`` or empty, ``entity.name`` is used instead (when present).
    Entities without a ``label`` attribute contribute no text of their own.
    A non-``None`` ``entity.unit`` is appended as ``" [unit]"``.

    Args:
        entity: any object; only ``label``, ``name`` and ``unit`` are read.

    Returns:
        The label string, possibly empty.
    """
    text = ""
    if hasattr(entity, "label"):
        if entity.label is not None:
            text += entity.label
        if not text and hasattr(entity, "name"):
            text += entity.name
    unit = getattr(entity, "unit", None)
    if unit is not None:
        text += " [%s]" % unit
    return text


class Plotter:
    """Common base class of the plotters in this module."""

    def show(self):
        """Display the open matplotlib figures by calling ``plt.show()``."""
        return plt.show()


class EventPlotter(Plotter):
    """Plot a 1-D DataArray of event positions as a single scatter row.

    Every value of the array is drawn as a marker at ``y == 1``; only 1-D
    arrays are plotted, other ranks make ``plot`` return ``None``.
    """

    def __init__(self, data_array, xdim=-1):
        """Create the plotter.

        Args:
            data_array: the nix DataArray to plot.
            xdim: index of the dimension used for the x axis; ``-1`` lets
                ``guess_best_xdim`` choose.

        Raises:
            ValueError: if ``xdim`` is larger than 1.
        """
        self.array = data_array
        self.sc = None
        self.fig = None
        self.axis = None
        self.dim_count = len(data_array.dimensions)
        if xdim == -1:
            self.xdim = guess_best_xdim(self.array)
        elif xdim > 1:
            raise ValueError("EventPlotter: xdim is larger than 1! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim

    def plot(self, axis=None):
        """Plot into *axis*, or into a new figure when *axis* is ``None``.

        Returns:
            The axes used, or ``None`` if the array is not 1-D (the figure /
            axis attributes are still set up in that case).
        """
        if axis is None:
            self.fig = plt.figure(figsize=[5.5, 2.])
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        if len(self.array.dimensions) == 1:
            return self.plot_1d()
        return None

    def plot_1d(self):
        """Scatter all values at ``y == 1`` and label the x axis."""
        data = self.array[:]
        dim = self.array.dimensions[self.xdim]
        xlabel = create_label(dim)
        # The array's own label/unit is only shown on the y axis when the
        # x dimension is a non-alias Range dimension.
        if dim.dimension_type == nix.DimensionType.Range and not dim.is_alias:
            ylabel = create_label(self.array)
        else:
            ylabel = ""
        self.sc = self.axis.scatter(data, np.ones(data.shape))
        # A single row of markers: fixed y range, one unlabeled tick.
        self.axis.set_ylim([0.5, 1.5])
        self.axis.set_yticks([1.])
        self.axis.set_yticklabels([])
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        return self.axis


class CategoryPlotter(Plotter):
    """Draw bar charts of DataArrays indexed by a Set (category) dimension.

    1-D arrays are drawn as one bar per category. 2-D arrays are drawn as
    grouped bars: one group per entry along ``xdim`` and one bar (series)
    per entry of the other dimension.
    """

    def __init__(self, data_array, xdim=-1):
        """Create the plotter.

        Args:
            data_array: the nix DataArray to plot.
            xdim: index of the category dimension; ``-1`` lets
                ``guess_best_xdim`` choose.

        Raises:
            ValueError: if ``xdim`` is larger than 2.
        """
        self.array = data_array
        self.bars = []
        self.fig = None
        self.axis = None
        if xdim == -1:
            self.xdim = guess_best_xdim(self.array)
        elif xdim > 2:
            raise ValueError("CategoryPlotter: xdim is larger than 2! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim

    def plot(self, axis=None):
        """Plot into *axis*, or into a new figure when *axis* is ``None``.

        Returns:
            The axes used, or ``None`` for arrays with more than 2 dimensions.
        """
        if axis is None:
            self.fig = plt.figure()
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        dim_count = len(self.array.dimensions)
        if dim_count == 1:
            return self.plot_1d()
        if dim_count == 2:
            return self.plot_2d()
        return None

    @staticmethod
    def _set_labels(dim):
        """Return ``dim.labels`` for a Set dimension, ``None`` otherwise."""
        if dim.dimension_type == nix.DimensionType.Set:
            return dim.labels
        return None

    @staticmethod
    def _labels_or_default(labels, prefix, count):
        """Return *labels* as a list, or ``"<prefix>-<i>"`` names if empty."""
        labels = list(labels) if labels is not None else []
        if not labels:
            labels = ["%s-%i" % (prefix, i) for i in range(count)]
        return labels

    def plot_1d(self):
        """Draw one bar per category.

        Categories without labels get generated ``Cat-<i>`` names.

        Returns:
            The axes, or ``None`` if the x dimension is not a Set dimension.

        Raises:
            ValueError: if there are neither labels nor data points.
        """
        data = self.array[:]
        dim = self.array.dimensions[self.xdim]
        if dim.dimension_type != nix.DimensionType.Set:
            return None
        categories = self._labels_or_default(dim.labels, "Cat", len(data))
        if not categories:
            raise ValueError("Cannot plot a bar chart without any labels")
        ylabel = create_label(self.array)
        self.bars.append(self.axis.bar(range(1, len(categories) + 1), data,
                                       tick_label=categories))
        self.axis.set_ylabel(ylabel)
        return self.axis

    def plot_2d(self):
        """Draw grouped bars, one group per category along ``xdim``.

        Dimensions that are not Set dimensions, or that have no labels, get
        generated ``Cat-<i>`` / ``Series-<i>`` names.
        """
        data = np.asarray(self.array[:])
        if self.xdim == 1:
            # make rows the categories and columns the series
            data = data.T
        n_categories, n_series = data.shape
        categories = self._labels_or_default(
            self._set_labels(self.array.dimensions[self.xdim]),
            "Cat", n_categories)
        series_names = self._labels_or_default(
            self._set_labels(self.array.dimensions[1 - self.xdim]),
            "Series", n_series)

        # the bars of one group together are 0.75 units wide
        bar_width = 0.75 / n_series
        positions = np.arange(n_categories)
        handles = []
        for i in range(n_series):
            bars = self.axis.bar(positions + i * bar_width, data[:, i],
                                 width=bar_width, align="center")
            handles.append(bars[0])
        self.bars.extend(handles)
        # With align="center", bar i of a group is centred at
        # position + i * bar_width (i = 0 .. n_series - 1), so the group's
        # centre is position + (n_series - 1) * bar_width / 2.
        self.axis.set_xticks(positions + (n_series - 1) * bar_width / 2)
        self.axis.set_xticklabels(categories)
        self.axis.set_ylabel(create_label(self.array))
        # Use only this call's handles so repeated plots keep a matching legend.
        self.axis.legend(handles, series_names, loc=1)
        return self.axis


class ImagePlotter(Plotter):
    """Show 2-D DataArrays (or 3-D ones with at most 3 channels) as images.

    Dimension 0 of the array is drawn along x and dimension 1 along y.
    """

    def __init__(self, data_array, xdim=-1):
        """Create the plotter.

        ``xdim`` is accepted for signature parity with the other plotters
        but is not used.
        """
        self.array = data_array
        self.image = None
        self.fig = None
        self.axis = None

    def plot(self, axis=None):
        """Plot into *axis*, or into a new figure when *axis* is ``None``.

        Returns:
            The axes used, or ``None`` for arrays that are neither 2-D nor
            3-D (or 3-D with too many channels).
        """
        dim_count = len(self.array.dimensions)
        if axis is None:
            self.fig = plt.figure()
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
        else:
            self.fig = axis.figure
            self.axis = axis
        if dim_count == 2:
            return self.plot_2d()
        if dim_count == 3:
            return self.plot_3d()
        return None

    def plot_2d(self):
        """Draw the array with ``imshow`` and label both axes."""
        data = self.array[:]
        x = self.array.dimensions[0].axis(data.shape[0])
        y = self.array.dimensions[1].axis(data.shape[1])
        xlabel = create_label(self.array.dimensions[0])
        ylabel = create_label(self.array.dimensions[1])
        # imshow lays out the first array axis vertically (as rows). Swap the
        # first two axes so dimension 0 runs horizontally, matching the x
        # label and extent; a trailing channel axis stays last.
        # origin="lower" puts index 0 at y[0], consistent with the extent.
        # aspect="auto" avoids squashing the image when the two dimensions
        # have very different value ranges.
        self.image = self.axis.imshow(np.swapaxes(data, 0, 1),
                                      extent=[x[0], x[-1], y[0], y[-1]],
                                      origin="lower", aspect="auto")
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        return self.axis

    def plot_3d(self):
        """Plot 3-D data whose last dimension holds at most 3 channels."""
        if self.array.shape[2] > 3:
            print("cannot plot 3d data with more than 3 channels "
                  "in the third dim")
            return None
        return self.plot_2d()


class LinePlotter(Plotter):
    """Draw 1-D or 2-D DataArrays as line plots.

    At most ``maxpoints`` samples along ``xdim`` are drawn at a time. When
    ``plot`` creates its own figure and the data along ``xdim`` is longer
    than one window, a slider below the axes scrolls the visible window.
    For 2-D arrays one line is drawn per entry of the other dimension.
    """

    def __init__(self, data_array, xdim=-1):
        """Create the plotter.

        Args:
            data_array: the nix DataArray to plot.
            xdim: index of the dimension used for the x axis; ``-1`` lets
                ``guess_best_xdim`` choose.

        Raises:
            ValueError: if ``xdim`` is larger than 2.
        """
        self.array = data_array
        self.lines = []
        self.dim_count = len(data_array.dimensions)
        if xdim == -1:
            self.xdim = guess_best_xdim(self.array)
        elif xdim > 2:
            raise ValueError("LinePlotter: xdim is larger than 2! "
                             "Cannot plot that kind of data")
        else:
            self.xdim = xdim
        self.fig = None
        self.axis = None
        # Only created by plot() when it owns the figure and scrolling is needed.
        self.slider = None

    def plot(self, axis=None, maxpoints=100000):
        """Plot into *axis*, or into a new figure when *axis* is ``None``.

        Args:
            axis: matplotlib axes to draw into; a new figure (with a scroll
                slider if needed) is created when ``None``.
            maxpoints: maximum number of samples drawn at once.

        Returns:
            The axes used, or ``None`` for arrays with more than 2 dimensions.
        """
        self.maxpoints = maxpoints
        if axis is None:
            self.fig = plt.figure()
            self.axis = self.fig.add_axes([0.15, .2, 0.8, 0.75])
            self.axis.set_title(self.array.name)
            self.__add_slider()
        else:
            self.fig = axis.figure
            self.axis = axis

        dim_count = len(self.array.dimensions)
        if dim_count > 2:
            return None
        if dim_count == 1:
            return self.plot_array_1d()
        return self.plot_array_2d()

    def __add_slider(self):
        """Add a scroll slider when the data exceed one window of points."""
        steps = self.array.shape[self.xdim] / self.maxpoints
        if steps <= 1:
            # All points fit into the first window; a slider ranging from
            # 1 to a value <= 1 would be degenerate.
            return
        slider_ax = self.fig.add_axes([0.15, 0.025, 0.8, 0.025])
        self.slider = Slider(slider_ax, 'Slider', 1., steps, valinit=1.,
                             valstep=0.25)
        self.slider.on_changed(self.__update)

    def __update(self, val):
        """Slider callback: show the window ending at ``val * maxpoints``."""
        if self.lines:
            start = max(val * self.maxpoints - self.maxpoints, 0)
            end = val * self.maxpoints
            self.__draw(start, end)
        self.fig.canvas.draw_idle()

    def __draw(self, start, end):
        """Redraw the samples in ``[start, end)`` along ``xdim``."""
        if self.dim_count == 1:
            self.__draw_1d(start, end)
        else:
            self.__draw_2d(start, end)

    @staticmethod
    def _clamp_range(start, end, length):
        """Return integer ``(start, end)`` with ``0 <= start <= end <= length``."""
        start = min(max(int(start), 0), length)
        end = min(max(int(end), start), length)
        return start, end

    def _set_xlim(self, x):
        """Fit the x limits to *x*; skipped when it has fewer than 2 values."""
        if len(x) > 1:
            self.axis.set_xlim([x[0], x[-1]])

    @staticmethod
    def _series_labels(dimension, count):
        """Return *count* legend labels, preferring ``dimension.labels``.

        Dimensions without a ``labels`` attribute, and positions beyond the
        available labels, fall back to the series index as a string.
        """
        labels = getattr(dimension, "labels", None)
        labels = list(labels) if labels is not None else []
        return [labels[i] if i < len(labels) else str(i) for i in range(count)]

    def __draw_1d(self, start, end):
        """Create or update the single line for samples ``[start, end)``."""
        start, end = self._clamp_range(start, end,
                                       self.array.shape[self.xdim])
        y = self.array[start:end]
        dim = self.array.dimensions[self.xdim]
        x = np.asarray(dim.axis(len(y), start))

        if not self.lines:
            line, = self.axis.plot(x, y, label=self.array.name)
            self.lines.append(line)
        else:
            self.lines[0].set_ydata(y)
            self.lines[0].set_xdata(x)

        self._set_xlim(x)

    def __draw_2d(self, start, end):
        """Create or update one line per series for samples ``[start, end)``."""
        start, end = self._clamp_range(start, end,
                                       self.array.shape[self.xdim])
        series_axis = 1 - self.xdim
        x_dimension = self.array.dimensions[self.xdim]
        # Integer bounds keep len(x) equal to the length of every y slice.
        x = np.asarray(x_dimension.axis(end - start, start))
        labels = self._series_labels(self.array.dimensions[series_axis],
                                     self.array.shape[series_axis])

        for i, label in enumerate(labels):
            if self.xdim == 0:
                y = self.array[start:end, i]
            else:
                y = self.array[i, start:end]

            if len(self.lines) <= i:
                line, = self.axis.plot(x, y, label=label)
                self.lines.append(line)
            else:
                self.lines[i].set_ydata(y)
                self.lines[i].set_xdata(x)

        self._set_xlim(x)

    def plot_array_1d(self):
        """Draw the first window of a 1-D array and label both axes."""
        self.__draw_1d(0, self.maxpoints)
        xlabel = create_label(self.array.dimensions[self.xdim])
        ylabel = create_label(self.array)
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        return self.axis

    def plot_array_2d(self):
        """Draw the first window of a 2-D array, label axes, add a legend."""
        self.__draw_2d(0, self.maxpoints)
        xlabel = create_label(self.array.dimensions[self.xdim])
        ylabel = create_label(self.array)
        self.axis.set_xlabel(xlabel)
        self.axis.set_ylabel(ylabel)
        self.axis.legend(loc=1)
        return self.axis


class MplCanvas(FigureCanvasQTAgg):
    """Qt canvas widget holding a matplotlib figure with a single subplot.

    The subplot is exposed as ``self.axes``. ``parent`` is accepted but is
    not passed on to the canvas.
    """

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        figure = Figure(figsize=(width, height), dpi=dpi)
        self.axes = figure.add_subplot(1, 1, 1)
        super().__init__(figure)


class PlotScreen(QWidget):
    """Widget with a matplotlib canvas above a "close" button.

    Clicking the button emits ``close_signal``; the widget does not close
    itself.
    """

    close_signal = pyqtSignal()

    def __init__(self, parent) -> None:
        super().__init__(parent=parent)
        canvas = MplCanvas(self, width=5, height=4, dpi=100)
        # Fixed demo data; ``plot`` below does not draw anything yet.
        canvas.axes.plot([0, 1, 2, 3, 4], [10, 1, 20, 3, 40])

        layout = QVBoxLayout()
        self.setLayout(layout)
        layout.addWidget(canvas)

        close_button = QPushButton("close")
        close_button.clicked.connect(self.on_close)
        layout.addWidget(close_button)

    def on_close(self) -> None:
        """Emit ``close_signal`` in response to the close button."""
        self.close_signal.emit()

    def plot(self, item) -> None:
        """Placeholder: currently only prints *item*."""
        print("plot!", item)