from PyQt6.QtGui import QAction, QIcon, QKeySequence
from PyQt6.QtCore import Qt, QSize, QThreadPool
from PyQt6.QtWidgets import (
    QGridLayout,
    QPushButton,
    QToolBar,
    QWidget,
    QMainWindow,
    QPlainTextEdit,
    QMenuBar,
    QStatusBar,
)
import uldaq
import numpy as np
import nixio as nix
import pyqtgraph as pg

from pathlib import Path as path
from scipy.signal import welch, find_peaks

from pyrelacs.worker import Worker
from pyrelacs.repros.repros import Repro
from pyrelacs.util.logging import config_logging
from pyrelacs.ui.about import AboutDialog

# Logger obtained from the project's config_logging() helper; the window class
# below logs through it.
log = config_logging()
# Directory two levels above this file; icon files are loaded from its
# "icons" subdirectory.
_root = path(__file__).parent.parent

from IPython import embed


class PyRelacs(QMainWindow):
    """Main window of PyRelacs.

    Hosts the DAQ connect/disconnect controls, one toolbar action per
    available repro, a read-only log view, and two plots (squared beat trace
    and its power spectrum) filled by ``plot_calibration``. Repro data is
    written to ``data.nix`` in the current working directory.
    """

    # Sampling rate in Hz used for the time axis and the Welch spectra in
    # plot_calibration. NOTE(review): hard-coded; confirm it matches the rate
    # the repros record with.
    SAMPLE_RATE = 40_000.0
    # Pen colors, cycled when more stimulus/fish pairs are plotted than there
    # are colors.
    PLOT_COLORS = ("red", "green", "blue", "black", "yellow")
    # Minimum prominence (in dB) for a spectral peak to be marked in the plot.
    PEAK_PROMINENCE_DB = 20

    def __init__(self):
        super().__init__()
        self.setWindowTitle("PyRelacs")
        self.beat_plot = pg.PlotWidget()
        self.power_plot = pg.PlotWidget()

        self.threadpool = QThreadPool()
        self.repros = Repro()
        # Currently connected uldaq.DaqDevice, or None while disconnected.
        self.daq_device = None

        # Read-only log view; run_repro appends status lines to it.
        self.text = QPlainTextEdit()
        self.text.setReadOnly(True)

        self.setMenuBar(QMenuBar(self))
        self.setStatusBar(QStatusBar(self))
        self.create_actions()
        self.create_menu()
        self.create_buttons()
        self.create_toolbars()
        # Start in the "disconnected" state: only "connect" is available.
        self._set_daq_connected(False)

        layout = QGridLayout()
        layout.addWidget(self.daq_connect_button, 0, 0)
        layout.addWidget(self.daq_disconnect_button, 0, 1)
        layout.addWidget(self.plot_calibration_button, 1, 0, 1, 2)
        layout.addWidget(self.beat_plot, 2, 0, 1, 2)
        layout.addWidget(self.power_plot, 3, 0, 1, 2)
        layout.addWidget(self.text, 4, 0, 1, 2)

        widget = QWidget()
        widget.setLayout(layout)
        self.setCentralWidget(widget)

        # NOTE(review): FileMode.Overwrite discards any data.nix left in the
        # working directory by a previous session -- confirm this is intended.
        filename = path.cwd() / "data.nix"
        self.nix_file = nix.File.open(str(filename), nix.FileMode.Overwrite)

    @staticmethod
    def _icon(name):
        """Return the QIcon for file ``name`` in the project's icons directory."""
        return QIcon(str(_root / "icons" / name))

    def create_actions(self):
        """Create the QActions shared by the menu bar and the toolbars."""
        self._rlx_exitaction = QAction(self._icon("exit.png"), "Exit", self)
        self._rlx_exitaction.setStatusTip("Close relacs")
        self._rlx_exitaction.setShortcut(QKeySequence("Alt+q"))
        self._rlx_exitaction.triggered.connect(self.on_exit)

        # Parented to the window like every other action so Qt owns it.
        self._rlx_aboutaction = QAction("about", self)
        self._rlx_aboutaction.setStatusTip("Show about dialog")
        self._rlx_aboutaction.setEnabled(True)
        self._rlx_aboutaction.triggered.connect(self.on_about)

        self._daq_connectaction = QAction(
            self._icon("connect.png"), "Connect DAQ", self
        )
        self._daq_connectaction.setStatusTip("Connect to daq device")
        self._daq_connectaction.triggered.connect(self.connect_dac)

        self._daq_disconnectaction = QAction(
            self._icon("disconnect.png"), "Disconnect DAQ", self
        )
        self._daq_disconnectaction.setStatusTip("Disconnect the DAQ device")
        self._daq_disconnectaction.triggered.connect(self.disconnect_dac)

        self._daq_calibaction = QAction(
            self._icon("calibration.png"), "Plot calibration", self
        )
        self._daq_calibaction.setStatusTip("Calibrate the attenuator device")
        self._daq_calibaction.triggered.connect(self.plot_calibration)

    def create_menu(self):
        """Populate the File, DAQ and Help menus with the created actions."""
        menu = self.menuBar()
        file_menu = menu.addMenu("&File")
        file_menu.addAction(self._rlx_exitaction)

        device_menu = menu.addMenu("&DAQ")
        device_menu.addAction(self._daq_connectaction)
        device_menu.addAction(self._daq_disconnectaction)
        device_menu.addSeparator()
        device_menu.addAction(self._daq_calibaction)

        help_menu = menu.addMenu("&Help")
        help_menu.addAction(self._rlx_aboutaction)

    def create_toolbars(self):
        """Add the Relacs, DAQ and Repros toolbars to the top toolbar area."""
        rlx_toolbar = QToolBar("Relacs")
        rlx_toolbar.addAction(self._rlx_exitaction)
        rlx_toolbar.setIconSize(QSize(24, 24))
        self.addToolBar(Qt.ToolBarArea.TopToolBarArea, rlx_toolbar)

        daq_toolbar = QToolBar("DAQ")
        for action in (
            self._daq_connectaction,
            self._daq_disconnectaction,
            self._daq_calibaction,
        ):
            daq_toolbar.addAction(action)
        self.addToolBar(Qt.ToolBarArea.TopToolBarArea, daq_toolbar)

        # Single repro toolbar; repros_to_toolbar fills it (previously the
        # repro actions were added to two separate "Repros" toolbars).
        self.toolbar = QToolBar("Repros")
        self.repros_to_toolbar()
        self.addToolBar(Qt.ToolBarArea.TopToolBarArea, self.toolbar)

    def create_buttons(self):
        """Create the push buttons for DAQ connection and calibration plotting."""
        self.daq_connect_button = QPushButton("Connect Daq")
        self.daq_connect_button.clicked.connect(self.connect_dac)

        self.daq_disconnect_button = QPushButton("Disconnect Daq")
        self.daq_disconnect_button.clicked.connect(self.disconnect_dac)

        self.plot_calibration_button = QPushButton("Plot Calibration")
        self.plot_calibration_button.clicked.connect(self.plot_calibration)

    @staticmethod
    def _decibel(power, ref_power=1.0, min_power=1e-20):
        """Transform power to decibel relative to ref_power.

        \\[ decibel = 10 \\cdot \\log_{10}(power/ref\\_power) \\]
        Power values smaller than or equal to `min_power` are set to `-np.inf`.

        Parameters
        ----------
        power: float or array
            Power values, for example from a power spectrum or spectrogram.
        ref_power: float or None or 'peak'
            Reference power for computing decibel.
            If set to `None` or 'peak', the maximum power is used.
        min_power: float
            Power values smaller than or equal to `min_power` are set to
            `-np.inf`.

        Returns
        -------
        decibel_psd: float or array
            Power values in decibel relative to `ref_power`; a scalar when
            `power` is a scalar.
        """
        is_scalar = np.isscalar(power)
        # Float copy: the input is never modified, and integer input does not
        # truncate the -inf / log10 results written back below.
        decibel_psd = np.array(power, dtype=float, ndmin=1)
        if ref_power is None or ref_power == "peak":
            ref_power = np.max(decibel_psd)
        # Both masks come from the untouched values; NaNs match neither and
        # are passed through unchanged.
        too_small = decibel_psd <= min_power
        valid = decibel_psd > min_power
        decibel_psd[too_small] = float("-inf")
        decibel_psd[valid] = 10.0 * np.log10(decibel_psd[valid] / ref_power)
        return decibel_psd[0] if is_scalar else decibel_psd

    def plot_calibration(self):
        """Plot the squared beat and its power spectrum per stimulus/fish pair.

        The data arrays of the first block in the NIX file are taken as
        consecutive (stimulus, fish) pairs. For each pair the beat
        ``stim + fish`` is squared and plotted against time in ``beat_plot``;
        its Welch power spectrum in dB is plotted in ``power_plot`` with
        prominent peaks marked by an ``x``. Previous curves are cleared first
        so repeated clicks do not stack duplicate plots.
        """
        # NOTE(review): a running repro worker may be writing to the same
        # nix_file concurrently -- confirm nixio/HDF5 access is safe here.
        if len(self.nix_file.blocks) == 0:
            log.warning("No recorded data in the NIX file, nothing to plot")
            return

        arrays = list(self.nix_file.blocks[0].data_arrays)
        self.beat_plot.clear()
        self.power_plot.clear()
        for i, (stim, fish) in enumerate(zip(arrays[::2], arrays[1::2])):
            beat_squared = (stim[:] + fish[:]) ** 2

            freqs, power = welch(beat_squared, fs=self.SAMPLE_RATE)
            power_db = self._decibel(power)
            peaks = find_peaks(power_db, prominence=self.PEAK_PROMINENCE_DB)[0]

            pen = pg.mkPen(self.PLOT_COLORS[i % len(self.PLOT_COLORS)])
            self.beat_plot.plot(
                np.arange(len(beat_squared)) / self.SAMPLE_RATE,
                beat_squared,
                pen=pen,
            )
            self.power_plot.plot(freqs, power_db, pen=pen)
            # Peaks index power_db, so mark them at the matching frequencies.
            self.power_plot.plot(
                freqs[peaks], power_db[peaks], pen=None, symbol="x"
            )

    def _set_daq_connected(self, connected):
        """Enable the DAQ buttons/actions that match the connection state."""
        self.daq_connect_button.setEnabled(not connected)
        self._daq_connectaction.setEnabled(not connected)
        self.daq_disconnect_button.setEnabled(connected)
        self._daq_disconnectaction.setEnabled(connected)

    def connect_dac(self):
        """Connect to the first USB DAQ device found.

        If no device is found the application is closed (as before). A uldaq
        error during discovery or connection is logged and leaves the window
        disconnected instead of propagating out of the Qt slot.
        """
        if self.daq_device is not None:
            log.debug("DAQ is already connected")
            return
        try:
            devices = uldaq.get_daq_device_inventory(uldaq.InterfaceType.USB)
        except uldaq.ULException as err:
            log.error(f"Could not list DAQ devices: {err}")
            return
        if not devices:
            log.debug("DAQ is not connected, closing")
            self.on_exit()
            return

        log.debug(f"Found daq devices {len(devices)}, connecting to the first one")
        device = None
        try:
            device = uldaq.DaqDevice(devices[0])
            device.connect()
        except uldaq.ULException as err:
            log.error(f"Could not connect to DAQ: {err}")
            if device is not None:
                device.release()
            return
        self.daq_device = device
        log.debug("Connected")
        self._set_daq_connected(True)

    def disconnect_dac(self):
        """Disconnect and release the DAQ device, if one is connected."""
        device = self.daq_device
        if device is None:
            log.debug("DAQ was not connected")
            return
        log.debug(f"{device}")
        try:
            device.disconnect()
        except uldaq.ULException as err:
            log.error(f"Could not disconnect DAQ cleanly: {err}")
        finally:
            # Always release the handle and forget it so a later disconnect
            # cannot touch a released device.
            device.release()
            self.daq_device = None
            self._set_daq_connected(False)

    def repros_to_toolbar(self):
        """Add one action per available repro to ``self.toolbar``."""
        repro_names, file_names = self.repros.names_of_repros()
        for rep, fn in zip(repro_names, file_names):
            individual_repro_button = QAction(rep, self)
            individual_repro_button.setStatusTip(rep)
            # rep/fn are bound as defaults so each action runs its own repro
            # (no late binding); `checked` absorbs QAction.triggered(bool).
            individual_repro_button.triggered.connect(
                lambda checked=False, n=rep, f=fn: self.run_repro(n, f)
            )
            self.toolbar.addAction(individual_repro_button)

    def run_repro(self, n, fn):
        """Run repro ``n`` (from file ``fn``) on the thread pool.

        The worker calls ``self.repros.run_repro(self.nix_file, n, fn)``; its
        result, finished and progress signals are routed to the handlers below.
        """
        self.text.appendPlainText(f"started Repro {n}, {fn}")
        # NOTE(review): several repros can run concurrently on the pool and all
        # share self.nix_file -- confirm concurrent nixio access is safe.
        worker = Worker(self.repros.run_repro, self.nix_file, n, fn)
        worker.signals.result.connect(self.print_output)
        worker.signals.finished.connect(self.thread_complete)
        worker.signals.progress.connect(self.progress_fn)

        self.threadpool.start(worker)

    def on_exit(self):
        """Close the main window (triggers closeEvent cleanup)."""
        log.debug("exit button!")
        self.close()

    def on_about(self, e=None):
        """Show the About dialog.

        ``e`` receives the ``checked`` flag of ``QAction.triggered`` and is
        unused; it defaults to None so the slot can also be called directly.
        """
        about = AboutDialog(self)
        about.show()

    def closeEvent(self, event):
        """Release resources before the window closes.

        Waits for running repro workers (they are handed the NIX file), then
        disconnects the DAQ if one is connected and closes the NIX file.
        """
        self.threadpool.waitForDone()
        if self.daq_device is not None:
            self.disconnect_dac()
        if self.nix_file is not None:
            self.nix_file.close()
            self.nix_file = None
        super().closeEvent(event)

    def print_output(self, s):
        """Print the result emitted by a finished repro worker."""
        print(s)

    def thread_complete(self):
        """Report that a repro worker has finished."""
        print("THREAD COMPLETE!")

    def progress_fn(self, n):
        """Print a repro worker's progress percentage."""
        print("%d%% done" % n)