from .frontend_classes import Dataset, RePro, Stimulus
from .util import BoltzmannFit, unzip_if_needed, gaussian_kernel, zero_crossings
import numpy as np
import nixio as nix
from scipy.stats import circstd
# from scipy.optimize import curve_fit
import os
import subprocess
from tqdm import tqdm

from IPython import embed


class BaselineData:
    """
    Class representing the Baseline data that has been recorded within a given Dataset.

    Upon construction all BaselineActivity repro runs of the dataset's cell are
    loaded (spike times and the matching EOD trace). The class then offers
    baseline analyses such as the serial correlation of the interspike
    intervals, vector strength, burst index, and coefficient of variation.
    """
    def __init__(self, dataset: Dataset):
        # Private state, filled by _get_data().
        self.__spike_data = []  # list of np.ndarray: spike times per repro run
        self.__eod_data = []    # list of np.ndarray: raw EOD trace per repro run
        self.__eod_times = []   # optional per-run buffer of EOD zero-crossing times
        self.__dataset = dataset
        self.__repros = None
        self.__cell = dataset.cells[0]  # Beware: Assumption that there is only a single cell
        self._get_data()

    def _get_data(self):
        """Load spike and EOD data of every BaselineActivity run of the cell."""
        if not self.__dataset:
            return
        self.__repros = RePro.find("BaselineActivity", cell_id=self.__cell.id)
        for i in tqdm(range(len(self.__repros)), desc="loading data"):
            r = self.__repros[i]
            sd = self.__read_spike_data(r)
            if sd is None or len(sd) <= 1:
                continue
            self.__spike_data.append(sd)
            # The EOD trace is cut at the time of the last spike of this run.
            self.__eod_data.append(self.__read_eod_data(r, self.__spike_data[-1][-1]))

    def valid(self):
        """Return True if any baseline spike data could be loaded.

        NOTE(review): the original carried a 'fixme implement me!' placeholder
        that always returned None; this minimal implementation stays falsy when
        nothing was loaded, so truthiness-based callers keep working.
        """
        return len(self.__spike_data) > 0

    def __read_spike_data(self, r: RePro):
        """Dispatch spike loading to the nix-file or the directory reader."""
        if self.__dataset.has_nix:
            return self.__read_spike_data_from_nix(r)
        return self.__read_spike_data_from_directory(r)

    def __read_eod_data(self, r: RePro, duration):
        """Dispatch EOD loading to the nix-file or the directory reader."""
        if self.__dataset.has_nix:
            return self.__read_eod_data_from_nix(r, duration)
        return self.__read_eod_data_from_directory(r, duration)

    def __get_serial_correlation(self, times, max_lags=50):
        """Serial correlation of the ISIs derived from *times* for up to *max_lags* lags."""
        if times is None or len(times) < max_lags:
            return None
        isis = np.diff(times)
        unbiased = isis - np.mean(isis, 0)
        norm = np.sum(unbiased ** 2)
        a_corr = np.correlate(unbiased, unbiased, "same") / norm
        # np.correlate returns the full symmetric correlation; keep lags >= 0.
        a_corr = a_corr[int(len(a_corr) / 2):]
        return a_corr[:max_lags]

    def serial_correlation(self, max_lags=50):
        """
            Returns the serial correlation of the interspike intervals.
        @param max_lags: The number of lags to take into account
        @return: the serial correlation as a function of the lag
        """
        scs = []
        for sd in self.__spike_data:
            if sd is None or len(sd) < 100:  # require a minimum number of spikes
                continue
            corr = self.__get_serial_correlation(sd, max_lags=max_lags)
            if corr is not None:
                scs.append(corr)
        return scs

    def circular_std(self):
        """Circular standard deviation of the spike phases, one value per trial."""
        circular_stds = []
        for i in range(self.size):
            phases = self.__spike_phases(index=i)
            circular_stds.append(circstd(phases))
        return circular_stds

    def eod_frequency(self):
        """Average EOD frequency (Hz) across trials; 0.0 if no estimate is possible."""
        eod_frequencies = []
        for i in range(self.size):
            eod, time = self.eod(i)
            if eod is None:
                continue
            xings = zero_crossings(eod, time, interpolate=False)
            # Guard against empty traces: without crossings there is no estimate
            # (the original indexed xings[-1] unconditionally and could crash).
            if len(xings) > 0 and xings[-1] > 0:
                eod_frequencies.append(len(xings) / xings[-1])
        return 0.0 if len(eod_frequencies) < 1 else np.mean(eod_frequencies)

    def __spike_phases(self, index=0):  # fixme buffer this stuff
        """Spike phases in [0, 2*pi) relative to the preceding EOD zero crossing."""
        e_times = self.eod_times(index=index)
        eod_period = np.mean(np.diff(e_times))
        phases = np.zeros(len(self.spikes(index)))
        for i, st in enumerate(self.spikes(index)):
            last_eod_index = np.where(e_times <= st)[0]
            if len(last_eod_index) == 0:
                continue
            phases[i] = (st - e_times[last_eod_index[-1]]) / eod_period * 2 * np.pi
        return phases

    def eod_times(self, index=0, interpolate=True):
        """EOD zero-crossing times of the given trial, or None for an invalid index."""
        if index >= self.size:
            return None
        if index < len(self.__eod_times):
            # A buffered result exists for exactly this trial; use it.
            return self.__eod_times[index]
        eod, time = self.eod(index)
        if eod is None:
            return None
        return zero_crossings(eod, time, interpolate=interpolate)

    @property
    def dataset(self):
        """The Dataset entity this baseline data belongs to."""
        return self.__dataset

    @property
    def cell(self):
        """The recorded cell; a list if the dataset contains several cells."""
        cells = self.__dataset.cells
        return cells if len(cells) > 1 else cells[0]

    @property
    def subject(self):
        """The recorded subject; a list if the dataset contains several subjects."""
        subjects = self.__dataset.subjects
        return subjects if len(subjects) > 1 else subjects[0]

    def spikes(self, index: int = 0):
        """Get the spike times of the spikes recorded in the given baseline recording.

        Args:
            index (int, optional): If the baseline activity has been recorded several times, the index can be given. Defaults to 0.

        Returns:
            np.ndarray: the spike times, or None for an invalid index.
        """
        # The original bounds check 'len(...) >= index' was off by one and let
        # index == len through to an IndexError; test the valid range instead.
        return self.__spike_data[index] if 0 <= index < len(self.__spike_data) else None

    def membrane_voltage(self, index: int = 0):
        """Read the raw membrane voltage trace ('V-1') of the given trial.

        :param index: index of the baseline repro run.
        :return: tuple of (time, voltage) arrays; (None, None) for non-nix datasets.
        :raises IndexError: if index is out of bounds.
        """
        if index >= self.size:
            raise IndexError("Index %i out of bounds for size %i!" % (index, self.size))
        if not self.__dataset.has_nix:
            print('Sorry, this is not supported for non-nixed datasets. Implement it at '
                  'fishbook.reproclasses.BaselineData.membrane_voltage and send a pull request!')
            return None, None
        rp = self.__repros[index]
        data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
        f = nix.File.open(data_source, nix.FileMode.ReadOnly)
        try:
            b = f.blocks[0]
            t = b.tags[rp.id]
            if not t:
                print("Tag not found!")
            try:
                data = t.retrieve_data("V-1")[:]
                time = np.asarray(t.references["V-1"].dimensions[0].axis(len(data)))
            except Exception:
                # np.empty() without a shape raises a TypeError; return proper
                # empty arrays instead. Narrowed from a bare 'except:'.
                data = np.array([])
                time = np.array([])
        finally:
            # Close the nix file even when reading fails.
            f.close()
        return time, data

    def eod(self, index: int = 0):
        """The raw EOD trace and a matching time axis of the given trial.

        :param index: index of the baseline repro run.
        :return: tuple of (eod, time); (None, None) for an invalid index.
        """
        # The original check 'len(...) >= index' was off by one and then called
        # len(None) for missing data; guard both cases explicitly.
        if not (0 <= index < len(self.__eod_data)) or self.__eod_data[index] is None:
            return None, None
        eod = self.__eod_data[index]
        time = np.arange(len(eod)) / self.__dataset.samplerate
        return eod, time

    @property
    def burst_index(self):
        """Fraction of ISIs shorter than 1.5 EOD periods, one value per trial."""
        bi = []
        for i, sd in enumerate(self.__spike_data):
            if len(sd) < 2:
                continue
            et = self.eod_times(index=i)
            eod_period = np.mean(np.diff(et))
            isis = np.diff(sd)
            bi.append(np.sum(isis < (1.5 * eod_period)) / len(isis))
        return bi

    @property
    def coefficient_of_variation(self):
        """Coefficient of variation (std/mean) of the ISIs, one value per trial."""
        cvs = []
        for d in self.__spike_data:
            isis = np.diff(d)
            cvs.append(np.std(isis) / np.mean(isis))
        return cvs

    @property
    def vector_strength(self):
        """Vector strength of the phase locking to the EOD.

        :return: tuple of (list of vector strengths, list of per-trial phase arrays)
        """
        vss = []
        spike_phases = []
        for i, sd in enumerate(self.__spike_data):
            phases = self.__spike_phases(i)
            ms_sin_alpha = np.mean(np.sin(phases)) ** 2
            ms_cos_alpha = np.mean(np.cos(phases)) ** 2
            vss.append(np.sqrt(ms_cos_alpha + ms_sin_alpha))
            spike_phases.append(phases)
        return vss, spike_phases

    @property
    def size(self):
        """Number of successfully loaded baseline trials."""
        return len(self.__spike_data)

    def __str__(self):
        return "Baseline data of cell %s " % self.__cell.id

    def __read_eod_data_from_nix(self, r: RePro, duration) -> np.ndarray:
        """Read the EOD trace of run *r* from the dataset's nix file."""
        data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
        if not os.path.exists(data_source):
            print("Data not found! Trying from directory")
            return self.__read_eod_data_from_directory(r, duration)
        f = nix.File.open(data_source, nix.FileMode.ReadOnly)
        try:
            b = f.blocks[0]
            t = b.tags[r.id]
            if not t:
                print("Tag not found!")
            try:
                data = t.retrieve_data("EOD")[:]
            except Exception:
                # np.empty() without a shape raises a TypeError; return an
                # empty array instead. Narrowed from a bare 'except:'.
                data = np.array([])
        finally:
            f.close()
        return data

    def __read_eod_data_from_directory(self, r: RePro, duration) -> np.ndarray:
        """Read the first *duration* seconds of the raw EOD trace ('trace-2.raw')."""
        sr = self.__dataset.samplerate
        unzip_if_needed(self.__dataset.data_source, "trace-2.raw")
        eod = np.fromfile(self.__dataset.data_source + "/trace-2.raw", np.float32)
        return eod[:int(duration * sr)]

    def __read_spike_data_from_nix(self, r: RePro) -> np.ndarray:
        """Read the spike times of run *r* from the dataset's nix file.

        Returns None if the data cannot be read or fewer than 100 spikes exist.
        """
        data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
        if not os.path.exists(data_source):
            print("Data not found! Trying from directory")
            return self.__read_spike_data_from_directory(r)
        f = nix.File.open(data_source, nix.FileMode.ReadOnly)
        try:
            b = f.blocks[0]
            t = b.tags[r.id]
            if not t:
                print("Tag not found!")
            try:
                data = t.retrieve_data("Spikes-1")[:]
            except Exception:
                data = None
        finally:
            f.close()
        # The original called len(data) even when data was None; guard it.
        if data is None or len(data) < 100:
            return None
        return data

    def __read_spike_data_from_directory(self, r) -> np.ndarray:
        """Read the spike times of run *r* from the plain-text 'basespikes1.dat'."""
        data = []
        data_source = os.path.join(self.__dataset.data_source, "basespikes1.dat")
        if os.path.exists(data_source):
            found_run = False
            with open(data_source, 'r') as f:
                l = f.readline()
                while l:
                    if "index" in l:
                        # '# index: N' lines separate the runs in the file.
                        index = int(l.strip("#").strip().split(":")[-1])
                        found_run = index == r.run
                    if l.startswith("#Key") and found_run:
                        data = self.__do_read(f)
                        break
                    l = f.readline()
        if len(data) < 100:  # discard runs with too few spikes
            return None
        return np.asarray(data)

    @staticmethod
    def __do_read(f) -> np.ndarray:
        """Read one block of spike times; values given in 'ms' are scaled to seconds."""
        data = []
        f.readline()  # skip the key header line
        unit = f.readline().strip("#").strip()
        scale = 0.001 if unit == "ms" else 1
        l = f.readline()
        while l and "#" not in l and len(l.strip()) > 0:
            data.append(float(l.strip()) * scale)
            l = f.readline()
        return np.asarray(data)



class FIData:
    """
    Class representing the data recorded with the relacs FI-Curve repro. The instance will load the data upon
    construction which may take a while.
    FI Data offers convenient access to the spike and local EOD data as well as offers convenience methods to get the
    firing rate and also to fit a Boltzmann function to the FI curve.
    """
    def __init__(self, dataset: Dataset):
        """
        Constructor.

        :param dataset: The dataset entity for which the fi curve repro data should be loaded.
        """
        self.__spike_data = []  # per-trial spike-time arrays
        self.__contrasts = []   # stimulus contrast per trial
        self.__eod_data = []    # local EOD trace per trial (nix data only)
        self.__eod_times = []   # time axis per trial
        self.__dataset = dataset
        self.__repros = None
        self.__cell = dataset.cells[0]  # Beware: Assumption that there is only a single cell
        self._get_data()

    def _get_data(self):
        """Load spikes, contrasts, EOD traces and time axes of all FICurve runs."""
        if not self.__dataset:
            return
        self.__repros = RePro.find("FICurve", cell_id=self.__cell.id)
        for r in self.__repros:
            sd, c, eods, time = self.__read_spike_data(r)
            if sd is None or len(sd) <= 1:
                continue
            self.__spike_data.extend(sd)
            if eods:  # old-style directory data yields no EOD traces
                self.__eod_data.extend(eods)
            self.__contrasts.extend(c)
            if time:
                self.__eod_times.extend(time)

    def __read_spike_data(self, repro: RePro):
        """
        Dispatch reading to the nix-file or the plain directory reader.

        :param repro:
        :return: spike data and the respective contrasts
        """
        if self.__dataset.has_nix:
            return self.__read_spikes_from_nix(repro)
        return self.__read_spikes_from_directory(repro)

    def __do_read_spike_data_from_nix(self, mtag: nix.pycore.MultiTag, stimulus: Stimulus, repro: RePro):
        """Cut one trial (spikes, local EOD, time axis, contrast) out of the nix data."""
        r_settings = repro.settings.split("\n")
        s_settings = stimulus.settings.split("\n")
        delay = 0.0
        contrast = 0.0
        for s in r_settings:
            if "delay:" in s:
                delay = float(s.split(":")[-1])
                break
        for s in s_settings:
            # Skip PreContrast, indented sub-entries and '+-' error entries.
            if "Contrast:" in s and "PreContrast" not in s and "\t\t" not in s and "+-" not in s:
                contrast = float(s.split(":")[-1])
                break

        start_time = stimulus.start_time - delay
        end_time = stimulus.start_time + stimulus.duration
        eod_da = mtag.references["LocalEOD-1"]
        start_index_eod = eod_da.dimensions[0].index_of(start_time)
        end_index_eod = eod_da.dimensions[0].index_of(end_time)

        local_eod = eod_da[start_index_eod:end_index_eod]
        # NOTE(review): spikes are shifted by start_time + delay while the time
        # axis is only shifted by delay; start_time already includes the delay,
        # so this looks like a double subtraction -- confirm against the relacs
        # trace layout before changing.
        spikes = self.__all_spikes[(self.__all_spikes >= start_time) & (self.__all_spikes < end_time)] - start_time - delay
        time = np.asarray(eod_da.dimensions[0].axis(end_index_eod - start_index_eod)) - delay
        return spikes, local_eod, time, contrast

    def __read_spikes_from_nix(self, repro: RePro):
        """Read all trials of the given FICurve run from the dataset's nix file."""
        spikes = []
        eods = []
        time = []
        contrasts = []
        stimuli = Stimulus.find(cell_id=repro.cell_id, repro_id=repro.id)
        if len(stimuli) == 0:
            return spikes, contrasts, eods, time
        data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
        if not os.path.exists(data_source):
            print("Data not found! Trying from directory")
            return self.__read_spikes_from_directory(repro)
        f = nix.File.open(data_source, nix.FileMode.ReadOnly)
        try:
            b = f.blocks[0]
            self.__all_spikes = b.data_arrays["Spikes-1"][:]
            mt = None
            for i in tqdm(range(len(stimuli)), desc="Loading data"):
                s = stimuli[i]
                # Re-fetch the multi tag only when the stimulus points elsewhere.
                if not mt or mt.id != s.multi_tag_id:
                    mt = b.multi_tags[s.multi_tag_id]
                sp, eod, t, c = self.__do_read_spike_data_from_nix(mt, s, repro)
                spikes.append(sp)
                eods.append(eod)
                time.append(t)
                contrasts.append(c)
        finally:
            # Close the nix file even if a trial fails to load.
            f.close()
        return spikes, contrasts, eods, time

    def __do_read_data_block(self, f, l):
        """Read one block of spike times (ms -> s) until a blank or comment line."""
        spikes = []
        while len(l.strip()) > 0 and "#" not in l:
            spikes.append(float(l.strip()) / 1000)
            l = f.readline()
        return spikes, l

    def __read_spikes_from_directory(self, repro: RePro):
        """Read all trials of old-style relacs data from 'fispikes1.dat'."""
        spikes = []
        contrasts = []
        times = []
        print("Warning! Exact reconstruction of stimulus order not possible for old relacs files!")
        data_source = os.path.join(self.__dataset.data_source, "fispikes1.dat")
        delay = 0.0
        pause = 0.0
        for s in repro.settings.split(", "):
            if "pause" in s:
                s = s.split(":")[-1].strip()
                pause = float(s[:-2]) / 1000 if "ms" in s else float(s[:-1])
            if "delay" in s:
                s = s.split(":")[-1].strip()
                delay = float(s[:-2]) / 1000 if "ms" in s else float(s[:-1])
        t_start = -delay
        t_end = pause + delay
        time = np.arange(t_start, t_end, 1. / repro.dataset.samplerate)
        if os.path.exists(data_source):
            with open(data_source, 'r') as f:
                line = f.readline()
                fish_intensity = None
                stim_intensity = None
                while line:
                    line = line.strip().lower()
                    if "index" in line:
                        # A new run starts; forget the previous intensities.
                        fish_intensity = None
                        stim_intensity = None
                    if "intensity = " in line:
                        if "true intensity = " in line:
                            fish_intensity = float(line.split("=")[-1].strip()[:-2])
                        elif "pre" not in line:
                            stim_intensity = float(line.split("=")[-1].strip()[:-2])
                    if len(line) > 0 and "#" not in line:  # data line
                        sp, line = self.__do_read_data_block(f, line)
                        spikes.append(sp)
                        times.append(time)
                        # Guard the division: the original raised a TypeError
                        # when either intensity was missing from the header.
                        if fish_intensity is not None and stim_intensity is not None and fish_intensity != 0:
                            contrasts.append((stim_intensity / fish_intensity - 1) * 100)
                        else:
                            contrasts.append(None)
                        continue
                    line = f.readline()
        return spikes, contrasts, None, times

    @property
    def size(self) -> int:
        """
        The number of recorded trials

        :return: An integer with the number of trials.
        """
        return len(self.__spike_data)

    def spikes(self, index=-1):
        """
        The spike times recorded in the specified trial(s)

        :param index: the index of the trial. Default of -1 indicates that all data should be returned.
        :return: the spike times of one trial, or the list of all trials.
        """
        if 0 <= index < self.size:
            return self.__spike_data[index]
        return self.__spike_data

    def eod(self, index=-1):
        """
        The local eod (including the stimulus) measurement of the selected trial(s).

        :param index: the index of the trial. Default of -1 indicates that all data should be returned.
        :return: Either two vectors representing time and the local eod or two lists of such vectors
        """
        if len(self.__eod_data) == 0:
            print("EOD data not available for old-style relacs data.")
            return None, None
        if 0 <= index < self.size:
            return self.__eod_times[index], self.__eod_data[index]
        return self.__eod_times, self.__eod_data

    def contrast(self, index=-1):
        """
        The stimulus contrast used in the respective trial(s).

        :param index:  the index of the trial. Default of -1 indicates that all data should be returned.
        :return: Either a single scalar representing the contrast, or a list of such scalars, one entry for each trial.
        """
        if 0 <= index < self.size:
            return self.__contrasts[index]
        return self.__contrasts

    def time_axis(self, index=-1):
        """
        Get the time axis of a single trial or a list of time-vectors for all trials.

        :param index: the index of the trial. Default of -1 indicates that all data should be returned.
        :return: Either a single vector representing time, or a list of such vectors, one for each trial.
        """
        if 0 <= index < self.size:
            return self.__eod_times[index]
        return self.__eod_times

    def rate(self, index=0, kernel_width=0.005):
        """
        Returns the firing rate for a single trial.

        :param index: The index of the trial. 0 <= index < size
        :param kernel_width: The width of the gaussian kernel in seconds
        :return: tuple of time and rate
        """
        t = self.time_axis(index)
        dt = np.mean(np.diff(t))
        # Directory-read trials store spikes as plain lists; make the
        # arithmetic below work for both lists and arrays.
        sp = np.asarray(self.spikes(index))
        binary = np.zeros(t.shape)
        spike_indices = ((sp - t[0]) / dt).astype(int)
        # Only mark spikes that fall inside the time axis.
        binary[spike_indices[(spike_indices >= 0) & (spike_indices < len(binary))]] = 1
        g = gaussian_kernel(kernel_width, dt)
        rate = np.convolve(binary, g, mode='same')
        return t, rate

    def boltzmann_fit(self, start_time=0.01, end_time=0.05, kernel_width=0.005):
        """
        Extracts the average firing rate within a time window from the averaged across trial firing rate.
        The analysis time window is specified by the start_time and end_time parameters. Firing rate is estimated by
        convolution with a Gaussian kernel of a given width. All parameters are given in 's'.

        :param start_time: the start of the analysis window.
        :param end_time: the end of the analysis window.
        :param kernel_width: standard deviation of the Gaussian kernel used for firing rate estimation.
        :return: object of type BoltzmannFit
        """
        contrasts = np.zeros(self.size)
        rates = np.zeros(self.size)
        for i in range(self.size):
            contrasts[i] = np.round(self.contrast(i))
            t, r = self.rate(i, kernel_width)
            rates[i] = np.mean(r[(t >= start_time) & (t < end_time)])

        return BoltzmannFit(contrasts, rates)


class FileStimulusData:
    """Access to the data recorded with the relacs FileStimulus repro.

    NOTE(review): this class is work in progress -- spike cutting, the stimulus
    file name, and reading from old-style directories are not implemented yet.
    """

    def __init__(self, dataset: Dataset):
        """
        Constructor.

        :param dataset: The dataset entity for which the filestimulus repro data should be loaded.
        """
        self.__spike_data = []  # per-trial spike data (currently always None, see below)
        self.__contrasts = []   # stimulus contrast per trial
        self.__stimuli = []     # stimulus file per trial (currently always empty)
        self.__dataset = dataset
        self.__repros = None
        self.__cell = dataset.cells[0]  # Beware: Assumption that there is only a single cell
        self._get_data()

    def _get_data(self):
        """Load all FileStimulus runs of the cell."""
        if not self.__dataset:
            return
        self.__repros = RePro.find("FileStimulus", cell_id=self.__cell.id)
        for r in self.__repros:
            sd, c, stims = self.__read_spike_data_from_nix(r) if self.__dataset.has_nix else self.__read_spike_data_from_directory(r)
            if sd is None or len(sd) <= 1:
                continue
            self.__spike_data.extend(sd)
            self.__contrasts.extend(c)
            self.__stimuli.extend(stims)

    def __do_read_spike_data_from_nix(self, mt: nix.pycore.MultiTag, stimulus: Stimulus, repro: RePro):
        """Extract one trial's spikes, contrast and stimulus file from the nix data.

        TODO: spike cutting and the stimulus file name are not implemented yet;
        'spikes' is always None and 'stim_file' is always empty.
        Fixed here: removed a leftover IPython embed() call that blocked the
        loading loop, and unreachable dead code after the return statement that
        referenced undefined names.
        """
        spikes = None
        stim_file = ""

        r_settings = repro.settings.split("\n")
        s_settings = stimulus.settings.split("\n")
        delay = 0.0
        for s in r_settings:
            if "delay:" in s:
                delay = float(s.split(":")[-1])
                break
        # Trial window; unused until spike cutting is implemented.
        start_time = stimulus.start_time - delay
        end_time = stimulus.start_time + mt.extents[stimulus.index]
        contrast = 0.0  # this is a quick fix!!!
        for s in s_settings:
            # Skip PreContrast, indented sub-entries and '+-' error entries.
            if "Contrast:" in s and "PreContrast" not in s and "\t\t" not in s and "+-" not in s:
                contrast = float(s.split(":")[-1])
                break

        return spikes, contrast, stim_file

    def __read_spike_data_from_nix(self, repro: RePro):
        """Read all trials of the given FileStimulus run from the nix file."""
        spikes = []
        contrasts = []
        stim_files = []
        stimuli = Stimulus.find(cell_id=repro.cell_id, repro_id=repro.id)
        if len(stimuli) == 0:
            return spikes, contrasts, stim_files
        data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
        if not os.path.exists(data_source):
            print("Data not found! Trying from directory")
            return self.__read_spike_data_from_directory(repro)
        f = nix.File.open(data_source, nix.FileMode.ReadOnly)
        try:
            b = f.blocks[0]
            self.__all_spikes = b.data_arrays["Spikes-1"][:]
            mt = None
            for i in tqdm(range(len(stimuli)), desc="Loading data"):
                s = stimuli[i]
                # Re-fetch the multi tag only when the stimulus points elsewhere.
                if not mt or mt.id != s.multi_tag_id:
                    mt = b.multi_tags[s.multi_tag_id]
                sp, c, stim = self.__do_read_spike_data_from_nix(mt, s, repro)
                spikes.append(sp)
                contrasts.append(c)
                stim_files.append(stim)
        finally:
            # Close the nix file even if a trial fails to load.
            f.close()
        return spikes, contrasts, stim_files

    def __read_spike_data_from_directory(self, repro: RePro):
        """TODO: reading old-style relacs FileStimulus data is not implemented yet."""
        print("not yet my friend!")
        spikes = []
        contrast = 0.0
        stim = None
        return spikes, contrast, stim

    def read_stimulus(self, index=0):
        """TODO: not implemented yet."""
        pass


if __name__ == "__main__":
    # Ad-hoc exploration entry point: load one dataset, build the repro data,
    # then drop into an interactive IPython shell.
    # Other datasets used during development:
    # dataset = Dataset(dataset_id='2011-06-14-ag')
    # dataset = Dataset(dataset_id='2013-04-18-ac')
    dataset = Dataset(dataset_id="2018-09-13-ac-invivo-1")
    file_stim_data = FileStimulusData(dataset)
    embed()