from ctypes import Array, c_double
import time
from typing import Union
from IPython import embed
import numpy.typing as npt
import uldaq
import numpy as np

from pyrelacs.util.logging import config_logging

# Module-level logger obtained from the project's logging helper; used by MccDac below.
log = config_logging()


class MccDac:
    """Wrapper around the first USB Measurement Computing DAQ device found by uldaq.

    Provides analog-input/-output scans, single-bit digital I/O on AUXPORT, and
    programming of an attenuator whose control lines sit on AUXPORT bits 4-7.
    """

    # Attenuator wiring on AUXPORT (attdev-1 configuration):
    # cspin 4, datainpin 5, strobepin 6, mutepin 7 (dataoutpin/zcenpin unused: -1).
    _ATT_CS_PIN = 4
    _ATT_DATA_PIN = 5
    _ATT_STROBE_PIN = 6
    _ATT_MUTE_PIN = 7
    # Attenuation levels (dB) the attenuator accepts; the level at index i is sent as code i + 1.
    _ATT_DB_LEVELS = np.arange(-95.5, 32.0, 0.5)

    def __init__(self) -> None:
        """Open and connect the first USB DAQ device and bind its subsystems.

        If the first ``connect()`` raises a ULException, the device is
        disconnected/released and a fresh connection is made with
        ``connect_dac()`` (whose errors propagate).

        Raises:
            SystemExit: if no USB DAQ device is found.
        """
        self._open_first_device()
        try:
            self.daq_device.connect()
        except uldaq.ul_exception.ULException:
            self.disconnect_dac()
            self.connect_dac()
        else:
            self._bind_subdevices()

    def _open_first_device(self) -> None:
        """Create ``self.daq_device`` for the first USB DAQ device in the inventory."""
        devices = uldaq.get_daq_device_inventory(uldaq.InterfaceType.USB)
        log.debug(f"Found daq devices {len(devices)}, connecting to the first one")
        if not devices:
            log.error("Did not found daq devices, please connect one")
            # Same effect as the builtin exit(1), without relying on the site module.
            raise SystemExit(1)
        self.daq_device = uldaq.DaqDevice(devices[0])
        self._released = False

    def _bind_subdevices(self) -> None:
        """Cache the analog-in, analog-out and digital-I/O subsystems of the device."""
        self.ai_device = self.daq_device.get_ai_device()
        self.ao_device = self.daq_device.get_ao_device()
        self.dio_device = self.daq_device.get_dio_device()
        log.debug("Connected")

    def connect_dac(self):
        """(Re)open the first USB DAQ device, connect to it and bind its subsystems.

        Raises:
            SystemExit: if no USB DAQ device is found.
        """
        self._open_first_device()
        self.daq_device.connect()
        self._bind_subdevices()

    @staticmethod
    def _channel_count(channels: list[int]) -> int:
        """Return how many channels the inclusive range ``channels[0]..channels[1]`` spans.

        Raises:
            ValueError: if ``channels`` is not ``[low, high]`` with ``low <= high``.
        """
        if len(channels) != 2:
            log.error("You can only provide two channels [0, 1]")
            raise ValueError(
                f"channels must be [low_channel, high_channel], got {channels!r}"
            )
        low, high = channels
        if high < low:
            raise ValueError(
                f"high channel {high} must not be lower than low channel {low}"
            )
        return high - low + 1

    @staticmethod
    def _sample_count(duration: float, samplerate: float) -> int:
        """Return the number of points on the grid ``np.arange(0, duration, 1 / samplerate)``."""
        return len(np.arange(0, duration, 1 / samplerate))

    def read_analog(
        self,
        channels: list[int],
        duration: int,
        samplerate: float,
        AiInputMode: uldaq.AiInputMode = uldaq.AiInputMode.SINGLE_ENDED,
        Range: uldaq.Range = uldaq.Range.BIP10VOLTS,
        ScanOption: uldaq.ScanOption = uldaq.ScanOption.DEFAULTIO,
        AInScanFlag: uldaq.AInScanFlag = uldaq.AInScanFlag.DEFAULT,
    ) -> Array[c_double]:
        """Start an analog-input scan over channels ``channels[0]..channels[1]``.

        Args:
            channels: ``[low_channel, high_channel]`` (inclusive).
            duration: Scan length in seconds.
            samplerate: Requested sample rate per channel.
            AiInputMode, Range, ScanOption, AInScanFlag: Passed to ``a_in_scan``.

        Returns:
            The float buffer given to ``a_in_scan``, sized for every channel in
            the range. NOTE(review): uldaq scans may still be running when this
            returns; callers should ``scan_wait`` before reading the buffer.

        Raises:
            ValueError: if ``channels`` is not a valid ``[low, high]`` pair.
        """
        num_channels = self._channel_count(channels)
        samples_per_channel = self._sample_count(duration, samplerate)
        # The buffer must hold samples_per_channel values for *each* channel.
        data_analog_input = uldaq.create_float_buffer(num_channels, samples_per_channel)

        actual_rate = self.ai_device.a_in_scan(
            channels[0],
            channels[1],
            AiInputMode,
            Range,
            samples_per_channel,
            samplerate,
            ScanOption,
            AInScanFlag,
            data=data_analog_input,
        )
        log.debug(f"Analog input scan started, a_in_scan returned {actual_rate}")
        return data_analog_input

    def write_analog(
        self,
        data: Union[list, npt.NDArray],
        channels: list[int],
        samplerate: float,
        Range: uldaq.Range = uldaq.Range.BIP10VOLTS,
        ScanOption: uldaq.ScanOption = uldaq.ScanOption.DEFAULTIO,
        AOutScanFlag: uldaq.AOutScanFlag = uldaq.AOutScanFlag.DEFAULT,
    ) -> Array[c_double]:
        """Start an analog-output scan of ``data`` on channels ``channels[0]..channels[1]``.

        Args:
            data: All samples to output. With several channels, ``len(data)``
                must be a multiple of the channel count (NOTE(review): samples
                are presumably channel-interleaved per uldaq convention — confirm).
            channels: ``[low_channel, high_channel]`` (inclusive).
            samplerate: Requested output rate per channel.
            Range, ScanOption, AOutScanFlag: Passed to ``a_out_scan``.

        Returns:
            The ctypes double buffer handed to ``a_out_scan``. If the scan
            raises, the error is logged, outputs are set to zero and the device
            is disconnected; the error is not re-raised.

        Raises:
            ValueError: if ``channels`` is invalid or ``len(data)`` is not a
                multiple of the number of channels.
        """
        num_channels = self._channel_count(channels)
        if len(data) % num_channels:
            raise ValueError(
                f"len(data)={len(data)} is not a multiple of {num_channels} channels"
            )
        # a_out_scan takes a per-channel count; the buffer holds all channels' samples.
        samples_per_channel = len(data) // num_channels
        data_analog_output = (c_double * len(data))(*data)
        log.debug(f"Created C_double data {data_analog_output}")

        try:
            self.ao_device.a_out_scan(
                channels[0],
                channels[1],
                Range,
                samples_per_channel,
                samplerate,
                ScanOption,
                AOutScanFlag,
                data_analog_output,
            )
        except Exception as e:
            # Best-effort recovery: park outputs at 0 and release the device.
            log.error(f"{e}")
            self.set_analog_to_zero()
            self.disconnect_dac()

        return data_analog_output

    def set_analog_to_zero(self, channels: Union[list[int], None] = None):
        """Write 0 to every analog output in ``channels[0]..channels[1]``.

        Args:
            channels: ``[low_channel, high_channel]``; defaults to ``[0, 1]``.

        On a device error the error is logged and the device is disconnected.

        Raises:
            ValueError: if ``channels`` is not a valid ``[low, high]`` pair.
        """
        if channels is None:
            channels = [0, 1]
        num_channels = self._channel_count(channels)
        try:
            # One range and one value per channel in the range.
            self.ao_device.a_out_list(
                channels[0],
                channels[1],
                [uldaq.Range.BIP10VOLTS] * num_channels,
                uldaq.AOutListFlag.DEFAULT,
                [0] * num_channels,
            )
        except Exception as e:
            log.error(f"{e}")
            log.error("disconnection dac")
            self.disconnect_dac()

    def diggital_trigger(self) -> None:
        """Produce a rising edge on AUXPORT bit 0 and leave it high.

        If the bit currently reads high it is pulled low first.
        """
        if self.read_bit(channel=0):
            self.write_bit(channel=0, bit=0)
        self.write_bit(channel=0, bit=1)

    def write_bit(self, channel: int = 0, bit: int = 1) -> None:
        """Configure AUXPORT bit ``channel`` as an output and drive it to ``bit``."""
        self.dio_device.d_config_bit(
            uldaq.DigitalPortType.AUXPORT, channel, uldaq.DigitalDirection.OUTPUT
        )
        self.dio_device.d_bit_out(
            uldaq.DigitalPortType.AUXPORT, bit_number=channel, data=bit
        )

    def read_bit(self, channel: int = 0):
        """Return the value read from AUXPORT bit ``channel``."""
        return self.dio_device.d_bit_in(uldaq.DigitalPortType.AUXPORT, channel)

    def read_digitalio(
        self,
        channels: list[int],
        duration: float,
        samplerate: float,
        ScanOptions: uldaq.ScanOption = uldaq.ScanOption.DEFAULTIO,
        DInScanFlag: uldaq.DInScanFlag = uldaq.DInScanFlag.DEFAULT,
    ):
        """Start a digital-input scan of AUXPORT0 and return its integer buffer.

        The buffer holds ``channel_len * n`` entries, where ``n`` is the number
        of samples for ``duration`` at ``samplerate`` and ``channel_len`` is 1
        when ``channels[0] == channels[1]`` and ``len(channels)`` otherwise.

        NOTE(review): the scan requests ``len(buffer)`` samples from a single
        port, so with ``channel_len == 2`` it lasts twice ``duration``; also the
        port is configured as AUXPORT but scanned as AUXPORT0. Confirm both
        against the hardware.
        """
        channel_len = 1 if channels[0] == channels[1] else len(channels)
        buffer_len = self._sample_count(duration, samplerate)
        data_digital_input = uldaq.create_int_buffer(channel_len, buffer_len)

        self.dio_device.d_config_port(
            uldaq.DigitalPortType.AUXPORT, uldaq.DigitalDirection.INPUT
        )
        scan_rate = self.dio_device.d_in_scan(
            uldaq.DigitalPortType.AUXPORT0,
            uldaq.DigitalPortType.AUXPORT0,
            len(data_digital_input),
            samplerate,
            ScanOptions,
            DInScanFlag,
            data_digital_input,
        )
        log.debug(f"Digital input scan started, d_in_scan returned {scan_rate}")
        return data_digital_input

    def disconnect_dac(self):
        """Disconnect and release ``self.daq_device``.

        A second call before the next ``connect_dac()`` is a no-op. This matters
        because the error paths in ``write_analog``/``set_analog_to_zero`` can
        both try to release the same device.
        """
        if getattr(self, "_released", False):
            log.debug("DAQ device already released")
            return
        self.daq_device.disconnect()
        self.daq_device.release()
        self._released = True

    def check_attenuator(self):
        """Play a sine on AO channel 0 at several attenuation levels.

        Uses attenuator attdev-1: strobepin 6, datainpin 5, dataoutpin -1,
        cspin 4, mutepin 7, zcenpin -1.

        For each level in ``[0, -10, -20]`` dB (the second pass with both
        channels muted):
        1. Program the attenuator.
        2. Queue a 5 s, 1 Hz, amplitude-1 sine at 40 kHz with EXTTRIGGER.
        3. Pulse AUXPORT bit 0 (NOTE(review): presumably wired to the external
           trigger input — confirm).
        4. Wait up to 15 s for the scan.
        5. Reset the device with outputs at zero.
        6. Sleep 1 s.
        """
        SAMPLERATE = 40_000.0
        DURATION = 5
        AMPLITUDE = 1
        SINFREQ = 1
        t = np.arange(0, DURATION, 1 / SAMPLERATE)
        data = AMPLITUDE * np.sin(2 * np.pi * SINFREQ * t)

        db_values = [0, -10, -20]
        for i, db_value in enumerate(db_values):
            log.info(f"Attenuating the Channels, with {db_value}")
            if i == 1:
                log.info("Muting the Channels")
                self.set_attenuation_level(
                    db_value, db_value, mute_channel1=True, mute_channel2=True
                )
            else:
                self.set_attenuation_level(db_value, db_value)

            _ = self.write_analog(
                data,
                [0, 0],
                SAMPLERATE,
                ScanOption=uldaq.ScanOption.EXTTRIGGER,
                Range=uldaq.Range.BIP10VOLTS,
            )
            self.diggital_trigger()

            try:
                self.ao_device.scan_wait(uldaq.WaitType.WAIT_UNTIL_DONE, 15)
            except uldaq.ul_exception.ULException:
                log.debug("Operation timed out")
            else:
                # Scan completed: park the output at 0 before resetting the device.
                self.set_analog_to_zero()
            finally:
                # Always: drop the trigger line, reconnect, and zero the outputs.
                self.write_bit(channel=0, bit=0)
                self.disconnect_dac()
                self.connect_dac()
                self.set_analog_to_zero()

            log.info("Sleeping for 1 second, before next attenuation")
            time.sleep(1)

    @classmethod
    def _attenuation_bits(cls, db: float) -> str:
        """Return the 8-bit attenuator code for ``db`` as a '0'/'1' string, MSB first.

        Raises:
            ValueError: if ``db`` is not one of ``_ATT_DB_LEVELS``.
        """
        matches = np.flatnonzero(cls._ATT_DB_LEVELS == db)
        if matches.size == 0:
            raise ValueError(
                f"{db} dB is not a supported attenuation level; "
                "use a multiple of 0.5 between -95.5 and 31.5"
            )
        # Codes start at 1: -95.5 dB -> 1, ..., 31.5 dB -> 255 (0 is used for mute).
        return np.binary_repr(int(matches[0]) + 1, width=8)

    def set_attenuation_level(
        self,
        db_channel1: float = 5.0,
        db_channel2: float = 5.0,
        mute_channel1: bool = False,
        mute_channel2: bool = False,
    ):
        """Program both attenuator channels.

        Pin mapping: cspin 4, datainpin 5, strobepin 6, mutepin 7.

        Each un-muted level must be a multiple of 0.5 dB in [-95.5, 31.5]. A
        muted channel is sent code 0 instead. The 16 bits (channel 2's byte,
        then channel 1's) are shifted out MSB first. Chip-select is held low
        throughout; for each bit the data pin is set and the strobe is pulsed.

        Raises:
            ValueError: if an un-muted channel's level is not supported.
        """
        # Validate before touching the hardware lines.
        binary_db1 = "00000000" if mute_channel1 else self._attenuation_bits(db_channel1)
        binary_db2 = "00000000" if mute_channel2 else self._attenuation_bits(db_channel2)
        if mute_channel1:
            log.info("Muting channel one")
        if mute_channel2:
            log.info("Muting channel two")

        self.activate_attenuator()
        self.write_bit(channel=self._ATT_CS_PIN, bit=0)
        for bit in binary_db2 + binary_db1:
            self.write_bit(channel=self._ATT_DATA_PIN, bit=int(bit))
            self.write_bit(channel=self._ATT_STROBE_PIN, bit=1)
            self.write_bit(channel=self._ATT_STROBE_PIN, bit=0)
        self.write_bit(channel=self._ATT_CS_PIN, bit=1)

    def activate_attenuator(self):
        """Set the attenuator lines: CS high, data low, strobe low, mute high."""
        for pin, level in (
            (self._ATT_CS_PIN, 1),
            (self._ATT_DATA_PIN, 0),
            (self._ATT_STROBE_PIN, 0),
            (self._ATT_MUTE_PIN, 1),
        ):
            self.write_bit(channel=pin, bit=level)

    def deactivate_attenuator(self):
        """Drive the attenuator mute pin (AUXPORT bit 7) low."""
        # mute should be enabled for starting calibration
        # NOTE(review): implies the mute line is active-low (0 engages mute) — confirm.
        self.write_bit(channel=self._ATT_MUTE_PIN, bit=0)