from IPython import embed
import pyqtgraph as pg
import numpy as np
from scipy.signal import welch, find_peaks
from scipy.integrate import romb


class CalibrationPlot:
    """Plot calibration recordings stored in a NIX file with pyqtgraph.

    The data arrays of the file's first block are taken pairwise in storage
    order: arrays at even positions are treated as stimulus traces, arrays at
    odd positions as the matching fish traces. For every pair the summed
    (beat) trace is drawn in the top plot and the power spectrum of the
    squared beat, in decibel, in the bottom plot.
    """

    def __init__(self, figure: "pg.GraphicsLayoutWidget", nix_file):
        """Store the target figure and the data source.

        Parameters
        ----------
        figure: pg.GraphicsLayoutWidget
            Layout widget the beat and power plots are added to.
        nix_file:
            Data source; only ``nix_file.blocks[0].data_arrays`` is read.
            Presumably an open ``nixio`` file -- confirm against the caller.
        """
        self.figure = figure
        self.nix_file = nix_file

    def plot(self, sampling_rate=40_000.0, nperseg=100_000):
        """Draw beat traces and their power spectra into ``self.figure``.

        Creates ``self.beat_plot`` (row 0) and ``self.power_plot`` (row 1).
        Peaks of each power spectrum with a prominence of at least 40 dB are
        marked with an "x". An unpaired trailing data array is ignored.

        Parameters
        ----------
        sampling_rate: float
            Sampling rate of all traces in Hz (default: the previously
            hard-coded 40 kHz).
        nperseg: int
            Segment length passed to :func:`scipy.signal.welch`.
        """
        self.figure.setBackground("w")
        self.beat_plot = self.figure.addPlot(row=0, col=0)
        self.power_plot = self.figure.addPlot(row=1, col=0)
        self.beat_plot.addLegend()
        self.power_plot.addLegend()
        # self.power_plot.setLogMode(x=False, y=True)

        block = self.nix_file.blocks[0]
        colors = ["red", "green", "blue", "black", "yellow"]
        # Materialize once; the slices pair the arrays as (0, 1), (2, 3), ...
        data_arrays = list(block.data_arrays)
        pairs = zip(data_arrays[::2], data_arrays[1::2])
        for i, (stim, fish) in enumerate(pairs):
            # Load each trace once instead of re-reading it for every use.
            stim_data = stim[:]
            fish_data = fish[:]

            freq_stim = self._peak_frequency(stim_data, sampling_rate, nperseg)
            freq_fish = self._peak_frequency(fish_data, sampling_rate, nperseg)
            beat_frequency = np.abs(freq_fish - freq_stim)

            beat = stim_data + fish_data
            f, powerspec = self._power_spectrum(beat**2, sampling_rate, nperseg)
            powerspec = self.decibel(powerspec)

            # TODO: the spectrum around the beat frequency is selected but not
            # used yet -- presumably meant to be integrated to obtain the beat
            # power.
            padding = 20
            integration_window = powerspec[
                (f > beat_frequency - padding) & (f < beat_frequency + padding)
            ]

            peaks = find_peaks(powerspec, prominence=40)[0]

            # Cycle the palette so more than len(colors) pairs do not raise.
            pen = pg.mkPen(colors[i % len(colors)])

            self.beat_plot.plot(
                np.arange(len(beat)) / sampling_rate,
                beat,
                pen=pen,
                name=stim.name,
            )
            self.power_plot.plot(f, powerspec, pen=pen, name=stim.name)
            self.power_plot.plot(f[peaks], powerspec[peaks], pen=None, symbol="x")

    @staticmethod
    def _power_spectrum(signal, sampling_rate, nperseg):
        """Return ``(frequencies, psd)`` of `signal` via Welch with a flat-top window."""
        return welch(signal, fs=sampling_rate, window="flattop", nperseg=nperseg)

    def _peak_frequency(self, signal, sampling_rate, nperseg):
        """Return the frequency at which the Welch power spectrum of `signal` peaks."""
        freqs, power = self._power_spectrum(signal, sampling_rate, nperseg)
        return freqs[np.argmax(self.decibel(power))]

    def decibel(self, power, ref_power=1.0, min_power=1e-20):
        """Transform power to decibel relative to ref_power.

        \\[ decibel = 10 \\cdot \\log_{10}(power/ref\\_power) \\]
        Power values smaller than or equal to `min_power` are set to `-np.inf`.
        Integer input is converted to float; the input is never modified.

        Parameters
        ----------
        power: float or array
            Power values, for example from a power spectrum or spectrogram.
        ref_power: float or None or 'peak'
            Reference power for computing decibel.
            If set to `None` or 'peak', the maximum power is used.
        min_power: float
            Power values smaller than or equal to `min_power` are set to
            `-np.inf`.

        Returns
        -------
        decibel_psd: float or array
            Power values in decibel relative to `ref_power`; a scalar if
            `power` was a scalar.
        """
        is_scalar = np.isscalar(power)
        tmp_power = np.array([power]) if is_scalar else np.asarray(power)
        # Integer storage would truncate the logarithm and cannot hold -inf,
        # so promote non-float input to float; float input keeps its dtype.
        # `astype` copies, so the caller's array is never modified.
        if np.issubdtype(tmp_power.dtype, np.floating):
            decibel_psd = tmp_power.astype(tmp_power.dtype)
        else:
            decibel_psd = tmp_power.astype(float)
        if ref_power is None or (isinstance(ref_power, str) and ref_power == "peak"):
            ref_power = np.max(decibel_psd)
        above = tmp_power > min_power
        decibel_psd[tmp_power <= min_power] = -np.inf
        decibel_psd[above] = 10.0 * np.log10(decibel_psd[above] / ref_power)
        if is_scalar:
            return decibel_psd[0]
        return decibel_psd