import signal
import sys
import faulthandler
import time

import uldaq
from IPython import embed
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import welch, csd
from scipy.signal import find_peaks

from pyrelacs.devices.mccdac import MccDac
from pyrelacs.util.logging import config_logging

# Module-level logger used by all calibration routines below.
log = config_logging()
# Dump Python tracebacks on fatal signals such as SIGSEGV (stdlib faulthandler);
# presumably to diagnose crashes inside the native DAQ driver.
faulthandler.enable()


class Calibration(MccDac):
    """Hardware checks for the MCC DAQ analog output/input chain.

    Every check builds a sine stimulus from the attributes set in
    ``__init__``, arms ``write_analog`` and ``read_analog`` with
    ``uldaq.ScanOption.EXTTRIGGER``, starts them via ``diggital_trigger`` and
    plots the recordings with matplotlib.
    """

    def __init__(self) -> None:
        super().__init__()
        self.SAMPLERATE = 40_000.0  # sampling rate in Hz (also welch's fs)
        self.DURATION = 5  # stimulus / recording length in seconds
        # Sine amplitude; units are whatever MccDac.write_analog expects — TODO confirm
        self.AMPLITUDE = 1
        self.SINFREQ = 750  # sine frequency in Hz

    def run(self):
        """Run the calibration.

        Placeholder: no calibration sequence has been defined yet; call
        ``check_amplitude`` or ``check_beat`` directly.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("Calibration.run() has not been implemented yet")

    def segfault_handler(self, signum, frame):
        """Disconnect the DAQ and exit the process when called as a signal handler.

        NOTE: this handler is intentionally not installed for SIGSEGV. Python
        cannot recover from a real segmentation fault raised in C code: the
        faulting instruction is re-executed after the Python-level handler
        returns, which makes the process appear to hang. Installing it would
        also replace the ``faulthandler`` handler enabled for this module.
        """
        print(f"Segmentation fault caught! Signal number: {signum}")
        self.disconnect_dac()
        sys.exit(1)  # Gracefully exit the program

    def _reset_outputs(self):
        """Clear digital bit 0 (the trigger line), wait 1 s, then zero the analog outputs."""
        self.write_bit(channel=0, bit=0)
        time.sleep(1)
        self.set_analog_to_zero()

    def _wait_for_scans(self, poll_interval=0.001):
        """Poll the analog input and output scans until BOTH are idle.

        Parameters
        ----------
        poll_interval: float
            Seconds to sleep between status polls (avoids a busy spin).

        Returns
        -------
        (ai_status, ao_status)
            The final scan statuses of the input and output devices.
        """
        ai_status = uldaq.ScanStatus.RUNNING
        ao_status = uldaq.ScanStatus.RUNNING
        log.debug(
            f"Status Analog_output {ao_status}\n, Status Analog_input {ai_status}"
        )
        # Keep waiting while EITHER scan is still running; stopping as soon as
        # one of them is idle can leave the input buffer incomplete.
        while (ai_status != uldaq.ScanStatus.IDLE) or (
            ao_status != uldaq.ScanStatus.IDLE
        ):
            time.sleep(poll_interval)
            ai_status = self.ai_device.get_scan_status()[0]
            ao_status = self.ao_device.get_scan_status()[0]
        return ai_status, ao_status

    def check_amplitude(self):
        """Play the sine stimulus at several attenuation levels and plot the readout.

        For each level both attenuator channels are set. The stimulus is then
        written with ``write_analog`` and recorded with ``read_analog`` on
        channel range ``[0, 0]``, and the recording is plotted against time.
        The written stimulus is plotted once, for the first level.

        If waiting for a scan raises ``ULException`` (e.g. a timeout), the
        outputs are reset, the DAQ is disconnected and the check is aborted
        without showing the plot. Otherwise the DAQ is disconnected after
        ``plt.show()`` returns.
        """
        db_values = [0.0, -5.0, -10.0, -20.0, -50.0]
        colors = ["red", "green", "blue", "black", "yellow"]
        self.set_attenuation_level(db_channel1=0.0, db_channel2=0.0)
        t = np.arange(0, self.DURATION, 1 / self.SAMPLERATE)
        data = self.AMPLITUDE * np.sin(2 * np.pi * self.SINFREQ * t)
        fig, ax = plt.subplots()

        for i, (db_value, color) in enumerate(zip(db_values, colors)):
            self.set_attenuation_level(db_channel1=db_value, db_channel2=db_value)
            log.debug(f"{db_value}")

            # Both scans are only armed here; they start on the external trigger.
            stim = self.write_analog(
                data,
                [0, 0],
                self.SAMPLERATE,
                ScanOption=uldaq.ScanOption.EXTTRIGGER,
            )
            data_channel_one = self.read_analog(
                [0, 0],
                self.DURATION,
                self.SAMPLERATE,
                ScanOption=uldaq.ScanOption.EXTTRIGGER,
            )
            time.sleep(1)

            log.debug("Starting the Scan")
            self.diggital_trigger()

            try:
                self.ao_device.scan_wait(uldaq.WaitType.WAIT_UNTIL_DONE, 15)
                # Also wait for the input scan so the readout buffer is complete.
                self.ai_device.scan_wait(uldaq.WaitType.WAIT_UNTIL_DONE, 15)
            except uldaq.ul_exception.ULException:
                log.error("Operation timed out")
                # The DAQ gets disconnected here, so later iterations could not
                # use it: clean up and abort instead of continuing the loop.
                self._reset_outputs()
                self.disconnect_dac()
                plt.close(fig)
                return
            log.debug("Scan finished")
            self._reset_outputs()

            if i == 0:
                ax.plot(t, stim, label=f"Input_{db_value}", color=color)
            ax.plot(t, data_channel_one, label=f"Readout {db_value}", color=color)

        ax.legend()
        plt.show()

        self.disconnect_dac()

    def check_beat(self):
        """Record two input channels and analyse their sum (the beat).

        The sine stimulus is written with ``write_analog`` while
        ``read_analog`` records channels ``[0, 1]``. This is repeated for
        several attenuation levels of attenuator channel 1; channel 2 stays
        at the 0 dB set at the start.

        For each level the figure shows:
        - the two recorded channels;
        - their sum (``beat``) and its square, over time;
        - the Welch power spectra of all of these, in dB.

        Peaks of the squared-beat spectrum with a prominence of at least
        20 dB are marked, and the lowest peak frequency is shown in the
        legend (``None`` if no peak was found). Analog outputs are zeroed and
        the DAQ is disconnected after ``plt.show()`` returns.
        """
        self.set_attenuation_level(db_channel1=-10.0, db_channel2=0.0)
        t = np.arange(0, self.DURATION, 1 / self.SAMPLERATE)
        data = self.AMPLITUDE * np.sin(2 * np.pi * self.SINFREQ * t)
        db_values = [0.0, -5.0, -8.5, -10.0]
        colors = ["red", "blue", "black", "green"]
        colors_in = ["lightcoral", "lightblue", "grey", "lightgreen"]
        _, axes = plt.subplots(2, 2, sharex="col")
        for db_value, color, color_in in zip(db_values, colors, colors_in):
            self.set_attenuation_level(db_channel1=db_value)
            # Both scans are only armed here; they start on the external trigger.
            self.write_analog(
                data,
                [0, 0],
                self.SAMPLERATE,
                ScanOption=uldaq.ScanOption.EXTTRIGGER,
            )
            readout = self.read_analog(
                [0, 1],
                self.DURATION,
                self.SAMPLERATE,
                ScanOption=uldaq.ScanOption.EXTTRIGGER,
            )
            self.diggital_trigger()
            log.info(self.ao_device)
            ai_status, ao_status = self._wait_for_scans()

            self.write_bit(channel=0, bit=0)
            log.debug(
                f"Status Analog_output {ao_status}\n, Status Analog_input {ai_status}"
            )
            # readout holds channels 0 and 1 interleaved sample by sample.
            channel1 = np.array(readout[::2])
            channel2 = np.array(readout[1::2])
            beat = channel1 + channel2
            beat_square = beat**2

            f, powerspec = welch(beat, fs=self.SAMPLERATE)
            powerspec = decibel(powerspec)

            f_sq, powerspec_sq = welch(beat_square, fs=self.SAMPLERATE)
            powerspec_sq = decibel(powerspec_sq)
            peaks = find_peaks(powerspec_sq, prominence=20)[0]
            # find_peaks may return no peaks; np.min of an empty array raises.
            first_peak = np.min(f_sq[peaks]) if peaks.size else None

            f_stim, powerspec_stim = welch(channel1, fs=self.SAMPLERATE)
            powerspec_stim = decibel(powerspec_stim)

            f_in, powerspec_in = welch(channel2, fs=self.SAMPLERATE)
            powerspec_in = decibel(powerspec_in)

            axes[0, 0].plot(
                t, channel1, label=f"{db_value} Readout Channel0", color=color
            )
            axes[0, 0].plot(
                t, channel2, label=f"{db_value} Readout Channel1", color=color_in
            )

            axes[0, 1].plot(
                f_stim,
                powerspec_stim,
                label=f"{db_value} powerspec Channel0",
                color=color,
            )
            axes[0, 1].plot(
                f_in,
                powerspec_in,
                label=f"{db_value} powerspec Channel1",
                color=color_in,
            )

            axes[1, 0].plot(t, beat, label="Beat", color=color)
            axes[1, 0].plot(t, beat_square, label="Beat squared", color=color_in)

            axes[1, 1].plot(f, powerspec, color=color)
            axes[1, 1].plot(
                f_sq,
                powerspec_sq,
                color=color_in,
                label=f"dB {db_value}, first peak {first_peak}",
            )
            axes[1, 1].scatter(f_sq[peaks], powerspec_sq[peaks], color="maroon")

        # Labels and legends only need to be set once, after all traces exist.
        for ax in (axes[0, 1], axes[1, 1]):
            ax.set_xlabel("Freq [HZ]")
            ax.set_ylabel("dB")
        axes[0, 0].legend()
        axes[1, 0].legend()
        axes[1, 1].legend()
        plt.show()
        self.set_analog_to_zero()
        self.disconnect_dac()


def decibel(power, ref_power=1.0, min_power=1e-20):
    """Transform power to decibel relative to ref_power.

    \\[ decibel = 10 \\cdot \\log_{10}(power/ref\\_power) \\]
    Power values smaller than or equal to `min_power` are set to `-np.inf`.

    Parameters
    ----------
    power: float or array_like
        Power values, for example from a power spectrum or spectrogram.
        Integer input and plain sequences (e.g. lists) are converted to a
        floating-point array; the input itself is never modified.
    ref_power: float or None or 'peak'
        Reference power for computing decibel.
        If set to `None` or 'peak', the maximum power is used.
    min_power: float
        Power values smaller than or equal to `min_power` are set to `-np.inf`.

    Returns
    -------
    decibel_psd: float or array
        Power values in decibel relative to `ref_power`: a scalar for scalar
        input, otherwise an array with the shape of `power`.
    """
    is_scalar = np.isscalar(power)
    values = np.asarray([power] if is_scalar else power)
    # Integer (or bool) buffers cannot hold -inf or fractional decibel values.
    if not np.issubdtype(values.dtype, np.floating):
        values = values.astype(float)
    decibel_psd = values.copy()
    if ref_power is None or (isinstance(ref_power, str) and ref_power == "peak"):
        ref_power = np.max(decibel_psd)
    above = values > min_power
    decibel_psd[values <= min_power] = float("-inf")
    decibel_psd[above] = 10.0 * np.log10(values[above] / ref_power)
    if is_scalar:
        return decibel_psd[0]
    return decibel_psd