import ctypes
import time

import uldaq
from IPython import embed
from pyrelacs.repros.repos import Repos
from pyrelacs.util.logging import config_logging
import numpy as np
import matplotlib.pyplot as plt

log = config_logging()


class Attenuator(Repos):
    """Drive a two-channel digital attenuator through the DAQ's digital lines.

    Digital output pin assignment used by this class:

    ident     : attdev-1
    strobepin : 6
    datainpin : 5
    dataoutpin: -1
    cspin     : 4
    mutepin   : 7
    zcenpin   : -1
    """

    # Attenuation levels accepted by the hardware, in dB: -95.5 ... 31.5 in
    # 0.5 dB steps (255 values). The level at index i is sent as code i + 1;
    # code 0 ("00000000") is what set_attenuation_level sends to mute a channel.
    _HARDWARE_DB = np.arange(-95.5, 32.0, 0.5)

    def __init__(self) -> None:
        super().__init__()

    def check_attenuator(
        self,
        samplerate=None,
        duration=None,
        amplitude=None,
        sinfreq=None,
    ):
        """Play a sine on analog outputs 0 and 1 at a series of attenuation levels.

        For each level in ``db_values`` the attenuator is programmed (the
        second step mutes both channels). The stimulus is queued with an
        external trigger, the digital trigger is fired, and the method waits
        up to 15 s for the scan before resetting the DAC. The attenuator is
        deactivated at the end, even if a step raises.

        Parameters
        ----------
        samplerate, duration, amplitude, sinfreq
            Stimulus settings (samples/s, s, DAC output units, Hz). When a
            value is omitted, the module-level ``SAMPLERATE``, ``DURATION``,
            ``AMPLITUDE`` or ``SINFREQ`` name is used if defined (the
            ``__main__`` block sets them). Otherwise the defaults are
            40000.0, 5, 1 and 1.
        """
        module_settings = globals()
        if samplerate is None:
            samplerate = module_settings.get("SAMPLERATE", 40_000.0)
        if duration is None:
            duration = module_settings.get("DURATION", 5)
        if amplitude is None:
            amplitude = module_settings.get("AMPLITUDE", 1)
        if sinfreq is None:
            sinfreq = module_settings.get("SINFREQ", 1)

        t = np.arange(0, duration, 1 / samplerate)
        data = amplitude * np.sin(2 * np.pi * sinfreq * t)
        # NOTE(review): the two channels are concatenated (all channel-0
        # samples, then all channel-1 samples). uldaq multi-channel output
        # scans normally expect interleaved samples; confirm that
        # write_analog_dac handles this layout.
        data_channels = np.concatenate((data, data))

        db_values = [0, 0, -2, -5, -10, -20, -50]
        try:
            for i, db_value in enumerate(db_values):
                log.info(f"Attenuating the Channels, with {db_value}")
                if i == 1:
                    log.info("Muting the Channels")
                    self.set_attenuation_level(
                        db_value, db_value, mute_channel1=True, mute_channel2=True
                    )
                else:
                    self.set_attenuation_level(db_value, db_value)

                _stim, ao_device = self.write_analog_dac(
                    data_channels,
                    [0, 1],
                    samplerate,
                    ScanOption=uldaq.ScanOption.EXTTRIGGER,
                    Range=uldaq.Range.BIP10VOLTS,
                )
                self.diggital_trigger()

                try:
                    ao_device.scan_wait(uldaq.WaitType.WAIT_UNTIL_DONE, 15)
                except uldaq.ul_exception.ULException as err:
                    log.warning(f"Operation timed out: {err}")
                finally:
                    # Reset exactly once per step, whether or not the wait failed.
                    self._reset_output()

                log.info("Sleeping for 1 second, before next attenuation")
                time.sleep(1)
        finally:
            # Always leave the attenuator muted, even if a step raised.
            self.deactivate_attenuator()

    def _reset_output(self):
        """Clear digital line 0, reconnect the DAC and zero the analog outputs."""
        self.write_bit(channel=0, bit=0)
        self.disconnect_dac()
        self.connect_dac()
        self.set_analog_to_zero()

    @classmethod
    def _db_to_byte(cls, db: float) -> int:
        """Return the 8-bit code (1..255) for an attenuation level in dB.

        Raises
        ------
        ValueError
            If ``db`` is not a multiple of 0.5 within [-95.5, 31.5].
        """
        matches = np.flatnonzero(cls._HARDWARE_DB == db)
        if matches.size == 0:
            raise ValueError(
                f"{db} dB is not a supported attenuation level; expected a "
                "multiple of 0.5 between -95.5 and 31.5 dB"
            )
        return int(matches[0]) + 1

    def set_attenuation_level(
        self,
        db_channel1: float = 5.0,
        db_channel2: float = 5.0,
        mute_channel1: bool = False,
        mute_channel2: bool = False,
    ):
        """Program both attenuator channels over the serial digital lines.

        Each level is converted to an 8-bit code; a muted channel is sent as
        code 0. The 16 bits (channel 1 first, MSB first within each byte) are
        written to the data-in pin (5). Each bit is clocked by pulsing the
        strobe pin (6) while chip-select (4) is held low.

        Parameters
        ----------
        db_channel1, db_channel2
            Level in dB; must be a multiple of 0.5 in [-95.5, 31.5].
        mute_channel1, mute_channel2
            Send code 0 for that channel instead of its level.

        Raises
        ------
        ValueError
            If a level is not supported by the hardware. This is raised before
            any digital line is written.
        """
        # Validate both levels before touching any hardware line.
        code_db1 = self._db_to_byte(db_channel1)
        code_db2 = self._db_to_byte(db_channel2)

        self.activate_attenuator()
        binary_db1 = np.binary_repr(code_db1, width=8)
        binary_db2 = np.binary_repr(code_db2, width=8)
        if mute_channel1:
            log.info("Muting channel one")
            binary_db1 = "00000000"
        if mute_channel2:
            log.info("Muting channel two")
            binary_db2 = "00000000"

        channels_db = binary_db1 + binary_db2
        self.write_bit(channel=4, bit=0)  # chip-select low: start transfer
        for b in channels_db:
            # NOTE(review): there is no explicit delay between edges; spacing
            # comes only from write_bit call latency. Add a real delay here if
            # the device needs a minimum pulse width.
            self.write_bit(channel=5, bit=int(b))
            self.write_bit(channel=6, bit=1)
            self.write_bit(channel=6, bit=0)
        self.write_bit(channel=4, bit=1)  # chip-select high: end transfer

    def activate_attenuator(self):
        """Set the attenuator control lines to their starting levels.

        Writes chip-select (4) high, data-in (5) low, strobe (6) low and the
        mute pin (7) high. deactivate_attenuator drives pin 7 low to enable
        mute, so high presumably releases it.
        """
        for ch, b in zip([4, 5, 6, 7], [1, 0, 0, 1]):
            self.write_bit(channel=ch, bit=b)

    def deactivate_attenuator(self):
        """Drive the mute pin (7) low."""
        # mute should be enabled for starting calibration
        self.write_bit(channel=7, bit=0)


if __name__ == "__main__":
    # Stimulus settings for check_attenuator(), which looks these
    # module-level names up when it runs.
    SAMPLERATE, DURATION, AMPLITUDE, SINFREQ = 40_000.0, 5, 1, 1

    Attenuator().check_attenuator()