import faulthandler
import time

import nixio as nix
import uldaq
from IPython import embed
import numpy as np
import matplotlib.pyplot as plt

from pyrelacs.devices.mccdaq import MccDaq
from pyrelacs.util.logging import config_logging

# Module-level logger obtained from the project's logging helper.
log = config_logging()
# Dump a Python traceback when the process dies on a fatal signal such as a
# segmentation fault (presumably possible inside the native DAQ driver code),
# so hardware-related crashes can be located.
faulthandler.enable()


class Calibration(MccDaq):
    """Calibration routines for the MCC DAQ stimulus/recording chain.

    A sine stimulus is played on analog output while analog input is
    recorded for several attenuation levels. ``check_amplitude`` plots the
    traces, ``check_beat`` stores them in a NIX file.
    """

    def __init__(self) -> None:
        super().__init__()
        # Sampling rate shared by the output and input scans, in Hz.
        self.SAMPLERATE = 40_000.0
        # Length of each stimulus/recording, in seconds.
        self.DURATION = 5
        # Peak amplitude of the sine stimulus (presumably volts - confirm
        # against MccDaq.write_analog).
        self.AMPLITUDE = 1
        # Frequency of the sine stimulus, in Hz.
        self.SINFREQ = 750

    @classmethod
    def run(cls, nix_file: nix.File) -> None:
        """Create a calibration instance and record the beat calibration into `nix_file`."""
        calb = cls()
        calb.check_beat(nix_file)

    def _sine_stimulus(self):
        """Return the time axis (s) and the sine stimulus sampled at SAMPLERATE."""
        t = np.arange(0, self.DURATION, 1 / self.SAMPLERATE)
        data = self.AMPLITUDE * np.sin(2 * np.pi * self.SINFREQ * t)
        return t, data

    def _wait_for_scans(self, timeout=None, poll_interval=0.01):
        """Poll until both the analog input and analog output scans are idle.

        Parameters
        ----------
        timeout: float or None
            Maximum time to wait in seconds; defaults to ``DURATION + 10``.
        poll_interval: float
            Pause between status queries in seconds, so polling does not
            spin a CPU core.

        Returns
        -------
        finished: bool
            True once BOTH scans report IDLE, False if the timeout elapsed first.
        """
        if timeout is None:
            timeout = self.DURATION + 10.0
        deadline = time.monotonic() + timeout
        while True:
            ai_status = self.ai_device.get_scan_status()[0]
            ao_status = self.ao_device.get_scan_status()[0]
            log.debug(
                f"Status Analog_output {ao_status}\n, Status Analog_input {ai_status}"
            )
            # Wait for BOTH scans: stopping as soon as one is idle would read
            # an input buffer that may still be filling.
            if (
                ai_status == uldaq.ScanStatus.IDLE
                and ao_status == uldaq.ScanStatus.IDLE
            ):
                return True
            if time.monotonic() >= deadline:
                log.error("Operation timed out")
                return False
            time.sleep(poll_interval)

    def _store_trace(self, block, name, array_type, samples):
        """Store a 1-D voltage trace in `block` with a sampled time dimension."""
        data_array = block.create_data_array(
            name,
            array_type,
            shape=samples.shape,
            data=samples,
            label="Voltage",
            unit="V",
        )
        # nixio expects the sampling *interval* here, not the sampling rate.
        data_array.append_sampled_dimension(
            1.0 / self.SAMPLERATE,
            label="time",
            unit="s",
        )
        return data_array

    def check_amplitude(self):
        """Plot the recorded sine response for a series of attenuation levels.

        For each level the sine is written to analog output channel 0 and
        analog input channel 0 is read. Both scans start on the external
        trigger issued by ``digital_trigger``. The stimulus (first level
        only) and every readout are plotted against time. The DAC is
        disconnected when the method returns or raises.

        Raises
        ------
        TimeoutError
            If a scan does not become idle in time.
        """
        db_values = [0.0, -5.0, -10.0, -20.0, -50.0]
        colors = ["red", "green", "blue", "black", "yellow"]
        t, data = self._sine_stimulus()
        _, ax = plt.subplots()

        try:
            for i, (db_value, color) in enumerate(zip(db_values, colors)):
                self.set_attenuation_level(db_channel1=db_value, db_channel2=db_value)
                log.debug(f"{db_value}")

                stim = self.write_analog(
                    data,
                    [0, 0],
                    self.SAMPLERATE,
                    ScanOption=uldaq.ScanOption.EXTTRIGGER,
                )
                data_channel_one = self.read_analog(
                    [0, 0],
                    self.DURATION,
                    self.SAMPLERATE,
                    ScanOption=uldaq.ScanOption.EXTTRIGGER,
                )
                time.sleep(1)

                log.debug("Starting the Scan")
                self.digital_trigger()
                try:
                    finished = self._wait_for_scans()
                finally:
                    # Always release the digital trigger and silence the
                    # output, even if polling the device raised.
                    self.write_bit(channel=0, bit=0)
                    time.sleep(1)
                    self.set_analog_to_zero()
                if not finished:
                    # Abort instead of continuing with incomplete data.
                    raise TimeoutError(
                        f"Scan at {db_value} dB did not finish in time"
                    )
                log.debug("Scan finished")

                if i == 0:
                    ax.plot(t, stim, label=f"Input_{db_value}", color=color)
                ax.plot(t, data_channel_one, label=f"Readout {db_value}", color=color)

            ax.legend()
            plt.show()
        finally:
            self.disconnect_dac()

    def check_beat(self, nix_file: nix.File):
        """Record stimulus and fish channels for several attenuation levels into NIX.

        For each level the sine is written to analog output channel 0 while
        analog input channels 0 and 1 are recorded. The two channels are
        stored in a "Calibration" block as ``stimulus_<dB>`` and
        ``fish_<dB>``, each with a sampled time dimension. The analog output
        is set to zero when the method returns or raises.

        Raises
        ------
        TimeoutError
            If a scan does not become idle in time.
        """
        self.set_attenuation_level(db_channel1=-10.0, db_channel2=0.0)
        _, data = self._sine_stimulus()
        db_values = [0.0, -5.0, -8.5, -10.0]
        block = nix_file.create_block("Calibration", "data")

        try:
            for db_value in db_values:
                self.set_attenuation_level(db_channel1=db_value)
                self.write_analog(
                    data,
                    [0, 0],
                    self.SAMPLERATE,
                    ScanOption=uldaq.ScanOption.EXTTRIGGER,
                )
                readout = self.read_analog(
                    [0, 1],
                    self.DURATION,
                    self.SAMPLERATE,
                    ScanOption=uldaq.ScanOption.EXTTRIGGER,
                )
                self.digital_trigger()
                try:
                    finished = self._wait_for_scans()
                finally:
                    # Reset the digital trigger line even if polling raised.
                    self.write_bit(channel=0, bit=0)
                if not finished:
                    raise TimeoutError(
                        f"Scan at {db_value} dB did not finish in time"
                    )

                # The two-channel input buffer is interleaved sample by sample.
                channel1 = np.array(readout[::2])
                channel2 = np.array(readout[1::2])

                # NOTE(review): the two arrays use different NIX types
                # ("nix.regular_sampled" vs "Array"); kept as-is because
                # readers may filter on them - confirm which is intended.
                self._store_trace(
                    block, f"stimulus_{db_value}", "nix.regular_sampled", channel1
                )
                self._store_trace(block, f"fish_{db_value}", "Array", channel2)
        finally:
            self.set_analog_to_zero()


def decibel(power, ref_power=1.0, min_power=1e-20):
    """Transform power to decibel relative to ref_power.

    \\[ decibel = 10 \\cdot \\log_{10}(power/ref\\_power) \\]
    Power values smaller than or equal to `min_power` are set to `-np.inf`;
    NaN values are passed through unchanged.

    Parameters
    ----------
    power: float, sequence, or array
        Power values, for example from a power spectrum or spectrogram.
        Integer input is converted to floating point. The input itself is
        never modified.
    ref_power: float or None or 'peak'
        Reference power for computing decibel.
        If set to `None` or 'peak', the maximum power is used.
    min_power: float
        Power values smaller than or equal to `min_power` are set to `-np.inf`.

    Returns
    -------
    decibel_psd: float or array
        Power values in decibel relative to `ref_power`. A scalar `power`
        yields a scalar; otherwise an array with the shape of `power`.
    """
    is_scalar = np.isscalar(power)
    values = np.asarray(power)
    # Keep floating dtypes (e.g. float32) but promote integers/bools to
    # float64: writing -inf or log values into an integer array would raise
    # or silently truncate the result.
    dtype = values.dtype if np.issubdtype(values.dtype, np.floating) else np.float64
    # np.array copies, so the caller's data is left untouched.
    decibel_psd = np.array(values, dtype=dtype, ndmin=1 if is_scalar else 0)
    # isinstance guard avoids an element-wise comparison if ref_power is an array.
    if ref_power is None or (isinstance(ref_power, str) and ref_power == "peak"):
        ref_power = np.max(decibel_psd)
    # Compute both masks before assigning so NaNs match neither and stay NaN.
    too_small = decibel_psd <= min_power
    valid = decibel_psd > min_power
    decibel_psd[too_small] = -np.inf
    decibel_psd[valid] = 10.0 * np.log10(decibel_psd[valid] / ref_power)
    return decibel_psd[0] if is_scalar else decibel_psd