from __future__ import print_function
import os
import datetime
from shutil import copyfile

try:
    import RPi.GPIO as GPIO
except:
    pass
import subprocess
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from time import sleep, time
from os import system
from sys import stdout
from IPython import embed

from uldaq import (get_daq_device_inventory, DaqDevice, AInScanFlag, ScanStatus,
                   ScanOption, create_float_buffer, InterfaceType, AiInputMode)


class plot():
    """Matplotlib figure holding one axis per displayed channel.

    After construction the caller is expected to assign ``n_rows``, ``n_cols``
    and ``max_v`` and then call either ``create_axis`` or
    ``creat_axis_competition``; both append the axes they create to ``axs``.
    ``channel_handle`` starts empty and is meant to be filled by the caller
    with the line artists it draws.
    """

    def __init__(self):
        # Grid shape and initial y-limit; all three are set by the caller.
        self.n_rows = None
        self.n_cols = None
        self.max_v = None
        self.channel_handle = []

        # 20 cm x 12 cm figure (matplotlib expects inches, hence / 2.54).
        self.fig = plt.figure(figsize=(20 / 2.54, 12 / 2.54), facecolor='white')
        self.axs = []
        # Open the window without blocking so the caller can keep running.
        plt.show(block=False)

    def create_axis(self):
        """Create a gap-free ``n_rows`` x ``n_cols`` grid of axes on ``self.fig``.

        Axes are appended to ``axs`` column by column (all rows of column 0
        first).  Only the bottom row keeps its x ticks/labels and only the
        first column keeps its y ticks/labels.  Every axis starts with the
        y-limits ``(-max_v, max_v)``.
        """
        gs = gridspec.GridSpec(self.n_rows, self.n_cols)
        last_row = self.n_rows - 1
        for col in range(self.n_cols):
            for row in range(self.n_rows):
                # Attach to our own figure explicitly instead of relying on
                # pyplot's notion of the "current" figure.
                ax = self.fig.add_subplot(gs[row, col])

                if row != last_row:
                    ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)

                if col != 0:
                    ax.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)

                ax.set_ylim(-self.max_v, self.max_v)

                self.axs.append(ax)
        gs.update(left=0.1, bottom=0.05, top=1, right=1, hspace=0, wspace=0)

    def creat_axis_competition(self):
        """Create the fixed 15-axis layout on a 3 x 8 grid of ``self.fig``.

        Rows 0 and 2 carry four axes each (columns 1-4), row 1 carries six
        axes (columns 0-5) plus one detached axis in column 7.  The axes are
        appended to ``axs`` in the order of ``pos`` below.  ``show_x`` and
        ``show_y`` flag, per axis, whether its x / y ticks and labels stay
        visible.  Every axis starts with the y-limits ``(-max_v, max_v)``.
        """
        gs = gridspec.GridSpec(3, 8, left=0.1, bottom=0.05, top=1, right=1)
        # (row, column) of each axis, in the order they end up in ``axs``.
        pos = [[0, 1],
               [0, 2],
               [0, 3],
               [0, 4],
               [1, 0],
               [1, 1],
               [1, 2],
               [1, 3],
               [1, 4],
               [1, 5],
               [2, 1],
               [2, 2],
               [2, 3],
               [2, 4],
               [1, 7]]

        show_x = np.array([0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], dtype=bool)
        show_y = np.array([1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1], dtype=bool)
        for (row, col), keep_x, keep_y in zip(pos, show_x, show_y):
            ax = self.fig.add_subplot(gs[row, col])
            ax.set_ylim(-self.max_v, self.max_v)
            if not keep_x:
                ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
            if not keep_y:
                ax.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)

            self.axs.append(ax)


def GPIO_setup(LED1_pin, LED2_pin, Button1_pin, Button2_pin):
    """Prepare two LED outputs and two button inputs via RPi.GPIO.

    Pin numbers are interpreted in ``GPIO.BOARD`` numbering mode.  Each LED
    pin is configured as an output and immediately driven low; each button
    pin is configured as a plain input (no ``pull_up_down`` is requested).

    Returns:
        list: ``[False, False]`` -- the initial on/off state of LED1 and LED2.
    """
    GPIO.setmode(GPIO.BOARD)

    # LEDs: output pins, switched off to start with.
    for led_pin in (LED1_pin, LED2_pin):
        GPIO.setup(led_pin, GPIO.OUT)
        GPIO.output(led_pin, GPIO.LOW)

    # Buttons: switch-controlled inputs.
    for button_pin in (Button1_pin, Button2_pin):
        GPIO.setup(button_pin, GPIO.IN)

    return [False, False]


def read_cfg(cfg_file, now, init_read=False):
    """Read the recording settings from a cfg file and stamp it with ``now``.

    Lines are matched by substring (``Columns1``, ``Rows1``, ``AISampleRate``,
    ``AIMaxVolt``, ``Gain``, checked in that order); if a key occurs on
    several lines the last one wins.

    Args:
        cfg_file: Path of the configuration file.
        now: ``datetime.datetime`` written into the ``StartDate`` and
            ``StartTime`` lines (time with millisecond resolution).
        init_read: If true, only look up ``PathFormat`` and return its value
            without modifying the file.  If the file has no ``PathFormat``
            line the call falls through to the full parse below (kept for
            backward compatibility).

    Returns:
        For ``init_read`` with a ``PathFormat`` line: the value after the
        first colon, stripped of surrounding whitespace and of all quote
        characters.
        Otherwise the tuple ``(channels, samplerate, n_cols, n_rows, max_v,
        gain)`` where ``channels = n_rows * n_cols``, ``samplerate`` is the
        ``kHz`` value converted to Hz as ``int`` and ``max_v`` is the
        ``AIMaxVolt`` value with its ``mV`` suffix removed.

    Raises:
        ValueError: If a required key is missing (the file is then left
            untouched) or a value cannot be converted to a number.
    """
    # Read-only access is sufficient here; the file is re-opened for writing
    # further down.  ``with`` guarantees the handle is closed on any error.
    with open(cfg_file, 'r') as cfg_f:
        cfg = cfg_f.readlines()

    if init_read:
        for line in cfg:
            if "PathFormat" in line:
                # The value may itself contain ':' -- rejoin everything after
                # the first one.
                value = ':'.join(line.split(':')[1:]).strip()
                return value.replace('"', '').replace("'", "")

    ### read cfg information ###
    values = {}
    for line in cfg:
        if 'Columns1' in line:
            values['Columns1'] = int(line.split(':')[1].strip())
        elif 'Rows1' in line:
            values['Rows1'] = int(line.split(':')[1].strip())
        elif "AISampleRate" in line:
            values['AISampleRate'] = int(float(line.split(':')[-1].strip().replace('kHz', '')) * 1000)
        elif "AIMaxVolt" in line:
            values['AIMaxVolt'] = float(line.split(':')[1].strip().replace('mV', ''))
        elif 'Gain' in line:
            values['Gain'] = int(line.split(':')[1].strip())

    required = ('Columns1', 'Rows1', 'AISampleRate', 'AIMaxVolt', 'Gain')
    missing = [key for key in required if key not in values]
    if missing:
        # Fail before touching the file, with a message that names the problem.
        raise ValueError('%s: missing cfg entries: %s' % (cfg_file, ', '.join(missing)))

    n_cols = values['Columns1']
    n_rows = values['Rows1']
    samplerate = values['AISampleRate']
    max_v = values['AIMaxVolt']
    gain = values['Gain']
    channels = n_rows * n_cols

    ### alter information and re-write ###
    date_line = '          StartDate        : %s\n' % now.strftime('%Y-%m-%d')
    # '.%f' yields '.' plus six digits; keeping four characters leaves milliseconds.
    time_line = '          StartTime        : %s\n' % (now.strftime('%H:%M:%S') + now.strftime(".%f")[:4])
    for enu, line in enumerate(cfg):
        if "StartDate" in line:
            cfg[enu] = date_line
        elif "StartTime" in line:
            cfg[enu] = time_line

    with open(cfg_file, 'w') as cfg_f:
        cfg_f.writelines(cfg)

    return channels, samplerate, n_cols, n_rows, max_v, gain

    # for line in cfg:
    #     if 'Columns1' in line:
    #         self.Grid.columns_val = int(line.split(':')[1].strip())
    #     elif 'Rows1' in line:
    #         self.Grid.rows_val = int(line.split(':')[1].strip())
    #     elif "ColumnDistance1" in line:
    #         self.Grid.col_dist_val = float(line.split(':')[-1].strip().replace('cm', ''))
    #     elif "RowDistance1" in line:
    #         self.Grid.row_dist_val = float(line.split(':')[-1].strip().replace('cm', ''))
    #     elif "ChannelOffset1" in line:
    #         self.Grid.channel_offset_val = int(line.split(':')[-1].strip())
    #     elif "ElectrodeType1" in line:
    #         self.Grid.elec_type_val = line.split(':')[-1].strip()
    #     elif "RefElectrodeType1" in line:
    #         self.Grid.ref_elec_type_val = line.split(":")[-1].strip()
    #     elif "RefElectrodePosX1" in line:
    #         self.Grid.ref_elec_posx_val = float(line.split(':')[-1].strip().replace('m', ''))
    #     elif 'RefElectrodePosY1' in line:
    #         self.Grid.ref_elec_posy_val = float(line.split(':')[-1].strip().replace('m', ''))
    #     elif 'WaterDepth1' in line:
    #         self.Grid.water_depth_val = float(line.split(':')[-1].strip().replace('m', ''))
    #
    #     elif "AISampleRate" in line:
    #         self.HardWare.ai_sr_val = float(line.split(':')[-1].strip().replace('kHz', ''))
    #     elif "AIMaxVolt" in line:
    #         self.HardWare.ai_max_vol_val = float(line.split(':')[-1].strip().replace('mV', ''))
    #     elif "AmplName" in line:
    #         self.HardWare.amp_name_val = line.split(':')[-1].strip()
    #     elif "AmplModel" in line:
    #         self.HardWare.amp_model_val = line.split(':')[-1].strip()
    #     elif '  Type  ' in line:
    #         self.HardWare.amp_type_val = line.split(':')[-1].strip()
    #     elif 'Gain' in line:
    #         self.HardWare.gain_val = line.split(':')[-1].strip()
    #     elif "HighpassCutoff" in line:
    #         self.HardWare.highpass_cutoff_val = int(line.split(':')[-1].strip().replace('Hz', ''))
    #     elif 'LowpassCutoff' in line:
    #         self.HardWare.lowpass_cutoff_val = float(line.split(':')[-1].strip().replace('kHz', ''))
    #
    #     elif "Experiment.Name" in line:
    #         self.Recording.experiment_name_val = line.split(':')[-1].strip()
    #     elif "StartDate" in line:
    #         self.Recording.startdate_val = line.split(':')[-1].strip()
    #     elif "StartTime" in line:
    #         self.Recording.starttime_val = ':'.join(line.split(':')[1:]).strip()
    #     elif "Location" in line:
    #         self.Recording.location_val = line.split(':')[-1].strip()
    #     elif "Position" in line:
    #         self.Recording.position_val = line.split(':')[-1].strip()
    #     elif "WaterTemperature" in line:
    #         self.Recording.water_temp_val = float(line.split(':')[-1].strip().replace('C', ''))
    #     elif "WaterConductivity" in line:
    #         self.Recording.water_cond_val = float(line.split(':')[-1].strip().replace('uS/cm', ''))
    #     elif 'WaterpH' in line:
    #         self.Recording.water_ph_val = float(line.split(':')[-1].strip().replace('pH', ''))
    #     elif "WaterOxygen" in line:
    #         self.Recording.water_oxy_val = float(line.split(':')[-1].strip().replace('mg/l', ''))
    #     elif "Comment" in line:
    #         self.Recording.comment_val = ':'.join(line.split(':')[1:]).strip()
    #     elif "Experimenter" in line:
    #         self.Recording.experimenter_val = ':'.join(line.split(':')[1:]).strip()
    #     elif "DataTime" in line:
    #         self.Recording.datatime_val = int(line.split(':')[-1].strip().replace('ms', ''))
    #     elif "DataInterval" in line:
    #         self.Recording.datainterval_val = int(line.split(':')[-1].strip().replace('ms', ''))
    #     elif "BufferTime" in line:
    #         self.Recording.buffertime_val = int(line.split(':')[-1].strip().replace('s', ''))
    #     else:
    #         continue


def _bipolar_range_volts(range_name):
    """Return the +/- limit in volts encoded in a bipolar range name.

    ``'BIP10VOLTS'`` gives ``10.0``; ``PT`` is read as the decimal point, so
    ``'BIP2PT5VOLTS'`` gives ``2.5`` and ``'BIPPT5VOLTS'`` gives ``0.5``.
    Names that do not follow the ``BIP<number>VOLTS`` pattern (for example
    unipolar ranges) give ``None``.
    """
    if not (range_name.startswith('BIP') and range_name.endswith('VOLTS')):
        return None
    number = range_name[len('BIP'):-len('VOLTS')].replace('PT', '.')
    try:
        return float(number)
    except ValueError:
        return None


def main():
    """Show a continuously refreshed view of the start of the DAQ buffer.

    Steps:
      1. Read ``fishgrid.cfg`` (which also stamps its StartDate/StartTime
         lines with the current time) to get channel count, sample rate,
         maximum voltage and gain.
      2. Connect to the first USB DAQ device, pick the smallest bipolar input
         range that covers ``max_v * gain / 1000`` and start a continuous
         analog-input scan into a buffer of two seconds per channel.
      3. Each time the scan's ``current_index`` moves backwards, redraw the
         first 250 samples of every channel (divided by ``gain``); all axes
         share the y-limits of the channel with the largest standard
         deviation, padded by 20 %.

    The loop runs until interrupted with Ctrl-C.  The DAQ device is stopped,
    disconnected and released on every exit path.  GPIO LEDs/buttons are not
    used by this function.

    Raises:
        SystemExit: With status 1 if the cfg file does not exist.
        Exception: If no DAQ device is found or the device lacks (hardware
            paced) analog input.
    """
    now = datetime.datetime.now()

    # Use /media/pi/data1 when it exists, otherwise the fallback directory.
    if os.path.exists('/media/pi/data1'):
        init_path = '/media/pi/data1'
    else:
        init_path = '/home/raab/data/rasp_test'

    init_cfgfile = os.path.join(init_path, 'fishgrid.cfg')
    if not os.path.exists(init_cfgfile):
        print('cfg file missing !!!')
        # quit() is a helper injected by the site module and not always
        # available; raise SystemExit directly and signal failure.
        raise SystemExit(1)

    # read and edit config file
    channels, rate, n_cols, n_rows, max_v, gain = read_cfg(init_cfgfile, now)

    n_show = 250  # samples per channel displayed on each refresh

    # Initialised before the try block so the cleanup in ``finally`` is valid
    # no matter how far the setup got.
    daq_device = None
    ai_device = None
    status = ScanStatus.IDLE

    try:
        # ---- DAQ setup ----
        descriptor_index = 0  # always use the first device that is found
        range_index = 0  # fallback when no parsed range is large enough

        interface_type = InterfaceType.USB
        low_channel = 0
        high_channel = channels - 1

        samples_per_channel = rate * 2  # * channels = Buffer size
        scan_options = ScanOption.CONTINUOUS
        flags = AInScanFlag.DEFAULT

        # Get descriptors for all of the available DAQ devices.
        devices = get_daq_device_inventory(interface_type)
        number_of_devices = len(devices)
        if number_of_devices == 0:
            raise Exception('Error: No DAQ devices found')

        print('Found', number_of_devices, 'DAQ device(s):')
        for device in devices:
            print('  ', device.product_name, ' (', device.unique_id, ')', sep='')

        # Create the DAQ device object associated with the specified descriptor index.
        daq_device = DaqDevice(devices[descriptor_index])

        # Get the AiDevice object and verify that it is valid.
        ai_device = daq_device.get_ai_device()
        if ai_device is None:
            raise Exception('Error: The DAQ device does not support analog input')

        # Verify that the specified device supports hardware pacing for analog input.
        ai_info = ai_device.get_info()
        if not ai_info.has_pacer():
            raise Exception('\nError: The specified DAQ device does not support hardware paced analog input')

        # Establish a connection to the DAQ device.
        descriptor = daq_device.get_descriptor()
        print('\nConnecting to', descriptor.dev_string, '- please wait...')
        daq_device.connect()

        # The default input mode is SINGLE_ENDED.
        input_mode = AiInputMode.SINGLE_ENDED
        # If SINGLE_ENDED input mode is not supported, set to DIFFERENTIAL.
        if ai_info.get_num_chans_by_mode(AiInputMode.SINGLE_ENDED) <= 0:
            input_mode = AiInputMode.DIFFERENTIAL

        # Get the number of channels and validate the high channel number.
        number_of_channels = ai_info.get_num_chans_by_mode(input_mode)
        if high_channel >= number_of_channels:
            high_channel = number_of_channels - 1
        # Number of channels actually scanned; this (not the cfg value) is the
        # interleave stride of the data buffer.
        channel_count = high_channel - low_channel + 1

        # Pick the smallest bipolar range whose limit covers the expected
        # amplified signal.  max_v is read from the cfg with its 'mV' suffix
        # stripped, hence the division by 1000.
        ranges = ai_info.get_ranges(input_mode)
        required_volts = max_v * gain / 1000
        candidates = []
        for idx, rng in enumerate(ranges):
            limit = _bipolar_range_volts(rng.name)
            if limit is not None:
                candidates.append((limit, idx))
        for limit, idx in sorted(candidates):
            if required_volts <= limit:
                range_index = idx
                break
        print(ranges[range_index])

        # Allocate a buffer to receive the data.
        data = create_float_buffer(channel_count, samples_per_channel)

        # Start the acquisition; a_in_scan returns the rate that is used
        # from here on for the time axis.
        rate = ai_device.a_in_scan(low_channel, high_channel, input_mode, ranges[range_index], samples_per_channel,
                                   rate, scan_options, flags, data)
        last_idx = 0

        # ---- figure setup ----
        Plot = plot()
        Plot.max_v = max_v
        Plot.n_rows = n_rows
        Plot.n_cols = n_cols
        Plot.creat_axis_competition()

        first_draw = True
        while True:
            # Get the status of the background operation
            status, transfer_status = ai_device.get_scan_status()
            index = transfer_status.current_index

            # Redraw whenever the index moved backwards (presumably the
            # continuous scan wrapped around to the start of the buffer).
            if (last_idx > index) and (index != -1):
                # One snapshot of the first n_show scans; the buffer is read
                # as interleaved, i.e. row k / column c is data[k * channel_count + c].
                snapshot = np.array(data[:n_show * channel_count]).reshape(-1, channel_count)
                power_channel = int(np.argmax(snapshot.std(axis=0)))

                t = np.arange(snapshot.shape[0]) / rate
                scaled = snapshot / gain

                # Common y-limits: span of the strongest channel plus 20 % padding.
                low = np.min(snapshot[:, power_channel]) / gain
                high = np.max(snapshot[:, power_channel]) / gain
                margin = abs(high - low) * 0.2
                ylim = (low - margin, high + margin)

                for ch in range(channel_count):
                    if first_draw:
                        h, = Plot.axs[ch].plot(t, scaled[:, ch], color='k')
                        Plot.channel_handle.append(h)
                    else:
                        Plot.channel_handle[ch].set_data(t, scaled[:, ch])
                    Plot.axs[ch].set_ylim(ylim)

                Plot.fig.canvas.draw()
                first_draw = False

            # An index of -1 is remembered as len(data), which is larger than
            # any real index, so the next real index triggers a redraw.
            last_idx = len(data) if index == -1 else index

    except KeyboardInterrupt:
        plt.close()

    finally:
        if daq_device:
            # Stop the acquisition if it is still running.
            if status == ScanStatus.RUNNING:
                ai_device.scan_stop()
            if daq_device.is_connected():
                daq_device.disconnect()
            daq_device.release()


# Start the live view only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()