from fishbook.fishbook import Dataset, RePro, Stimulus
import numpy as np
import nixio as nix
from scipy.stats import circstd
import os
import subprocess

from IPython import embed


def _zero_crossings(x, t, interpolate=False):
    """Return the times at which x crosses zero from below.

    A crossing is detected at sample i (i >= 1) when x[i] >= 0 and
    x[i - 1] < 0. The very first sample is never reported as a crossing.

    @param x: the signal, array-like.
    @param t: the time axis belonging to x, array-like of the same length.
              The interpolation uses t[1] - t[0] as the sampling interval,
              i.e. it assumes that t is evenly spaced.
    @param interpolate: if True, the crossing time is moved back from the
              sample time by linear interpolation between the two samples
              surrounding the crossing. Samples that are already within
              [0, 0.001] of zero are reported at their own sample time.
    @return: numpy array with the crossing times. Empty if there are fewer
             than two samples or no crossing was found.
    """
    x = np.asarray(x)
    t = np.asarray(t)
    if len(x) < 2 or len(t) < 2:
        # A crossing needs a sample below and one at/above zero.
        return np.zeros(0)

    # Index of the first non-negative sample after a negative one.
    xings = np.flatnonzero((x[1:] >= 0) & (x[:-1] < 0)) + 1
    if not interpolate:
        return t[xings]

    crossings = t[xings].astype(float)
    dt = t[1] - t[0]
    after = x[xings]
    # Only interpolate where the sample is clearly above zero; the slope is
    # strictly positive there because the preceding sample is negative.
    steep = after > 0.001
    rise = (after[steep] - x[xings - 1][steep]).astype(float)
    slope = rise / dt
    crossings[steep] -= after[steep].astype(float) / slope
    return crossings


def _unzip_if_needed(dataset, tracename='trace-1.raw'):
    """Make sure the raw trace file is available in uncompressed form.

    Nothing happens when the uncompressed file is already there or when
    there is no gzipped version of it either. Otherwise the external
    ``gunzip`` program is run on the ``.gz`` file.

    @param dataset: path of the folder that contains the trace files.
    @param tracename: name of the (uncompressed) trace file.
    """
    raw_path = os.path.join(dataset, tracename)
    needs_unzip = not os.path.exists(raw_path) and os.path.exists(raw_path + '.gz')
    if needs_unzip:
        print("\tunzip: %s" % tracename)
        zipped_path = os.path.join(dataset, tracename + ".gz")
        subprocess.check_call(["gunzip", zipped_path])


class BaselineData:
    """Spike and EOD data of all "BaselineActivity" runs of a dataset.

    On construction all "BaselineActivity" repro runs of the first cell of
    the dataset are looked up and their spike times and EOD traces are read,
    either from the dataset's NIX file or from the files in the dataset
    folder. Runs without usable spike data are skipped, so that spike and
    EOD data always share the same index.
    """

    def __init__(self, dataset: "Dataset"):
        """
        @param dataset: the fishbook Dataset the baseline data is read from.
        """
        self.__spike_data = []
        self.__eod_data = []
        self.__eod_times = {}  # cache of EOD times, key: (index, interpolate)
        self.__dataset = dataset
        self.__repros = None
        self.__cell = dataset.cells[0]  # Beware: Assumption that there is only a single cell
        self._get_data()

    def _get_data(self):
        """Read spike and EOD data of all baseline repro runs of the cell."""
        if not self.__dataset:
            return
        self.__repros = RePro.find("BaselineActivity", cell_id=self.__cell.id)
        for r in self.__repros:
            sd = self.__read_spike_data(r)
            if sd is None or len(sd) < 2:
                continue
            self.__spike_data.append(sd)
            # The time of the last spike defines how much of the EOD is read.
            self.__eod_data.append(self.__read_eod_data(r, sd[-1]))

    def valid(self):
        """Not implemented yet, always returns None."""
        # fixme implement me!
        pass

    def __read_spike_data(self, r: "RePro"):
        """Read the spike times of repro run r from NIX file or folder."""
        if self.__dataset.has_nix:
            return self.__read_spike_data_from_nix(r)
        return self.__read_spike_data_from_directory(r)

    def __read_eod_data(self, r: "RePro", duration):
        """Read the EOD trace of repro run r from NIX file or folder."""
        if self.__dataset.has_nix:
            return self.__read_eod_data_from_nix(r, duration)
        return self.__read_eod_data_from_directory(r, duration)

    def __get_serial_correlation(self, times, max_lags=50):
        """Serial correlation of the interspike intervals of one spike train.

        @param times: the spike times.
        @param max_lags: the number of lags to return, starting with lag 0.
        @return: the correlation for the first max_lags lags, or None if
                 there are fewer than max_lags (or fewer than two) spikes.
        """
        if times is None or len(times) < 2 or len(times) < max_lags:
            return None
        isis = np.diff(times)
        unbiased = isis - np.mean(isis, 0)
        norm = np.sum(unbiased ** 2)
        a_corr = np.correlate(unbiased, unbiased, "same") / norm
        # In "same" mode lag zero sits in the middle, keep the positive lags.
        a_corr = a_corr[int(len(a_corr) / 2):]
        return a_corr[:max_lags]

    def serial_correlation(self, max_lags=50):
        """
            return the serial correlation of the interspike intervals for each
            baseline recording. Recordings with less than 100 spikes are skipped.
        @param max_lags: The number of lags to take into account
        @return: list with the serial correlation as a function of the lag,
                 one entry per recording that has enough spikes.
        """
        scs = []
        for sd in self.__spike_data:
            if sd is None or len(sd) < 100:
                continue
            corr = self.__get_serial_correlation(sd, max_lags=max_lags)
            if corr is not None:
                scs.append(corr)
        return scs

    def circular_std(self):
        """
        @return: list with the circular standard deviation of the spike
                 phases relative to the EOD, one entry per recording.
        """
        return [circstd(self.__spike_phases(index=i)) for i in range(self.size)]

    def eod_frequency(self):
        """Estimate the EOD frequency from the upward zero crossings.

        For each recording the number of crossings is divided by the time of
        the last crossing. Recordings without EOD data or without any
        crossing are ignored.

        @return: the average over all recordings, 0.0 if there is none.
        """
        eodfs = []
        for i in range(self.size):
            eod, time = self.eod(i)
            if eod is None or len(eod) < 2:
                continue
            xings = _zero_crossings(eod, time, interpolate=False)
            if len(xings) == 0:
                continue
            eodfs.append(len(xings)/xings[-1])
        return 0.0 if len(eodfs) < 1 else np.mean(eodfs)

    def __spike_phases(self, index=0):
        """Phase of each spike within the EOD cycle it falls into.

        The phase is the time since the last EOD time (upward zero crossing)
        relative to the average EOD period, scaled to radians. Spikes that
        occur before the first EOD time keep a phase of zero.

        @param index: the index of the baseline recording.
        @return: numpy array with one phase per spike.
        """
        etimes = self.eod_times(index=index)
        if etimes is None:
            return np.zeros(0)
        spike_times = np.asarray(self.spikes(index))
        eod_period = np.mean(np.diff(etimes))
        phases = np.zeros(len(spike_times))
        # etimes are sorted, hence the last EOD time <= spike time can be
        # found with a binary search; -1 marks spikes before the first one.
        last_eod_index = np.searchsorted(etimes, spike_times, side="right") - 1
        has_eod = last_eod_index >= 0
        phases[has_eod] = (spike_times[has_eod] - etimes[last_eod_index[has_eod]]) / eod_period * 2 * np.pi
        return phases

    def eod_times(self, index=0, interpolate=True):
        """Times of the upward zero crossings of the EOD of one recording.

        The result is computed once per (index, interpolate) combination and
        cached; the returned array should not be modified in place.

        @param index: the index of the baseline recording.
        @param interpolate: passed on to the zero crossing detection.
        @return: numpy array with the EOD times, None for an invalid index.
        """
        if not -self.size <= index < self.size:
            return None
        key = (index, bool(interpolate))
        if key not in self.__eod_times:
            eod, time = self.eod(index)
            if eod is None or len(eod) < 2:
                etimes = np.zeros(0)
            else:
                etimes = _zero_crossings(eod, time, interpolate=interpolate)
            self.__eod_times[key] = etimes
        return self.__eod_times[key]

    @property
    def dataset(self):
        """The dataset the data was read from."""
        return self.__dataset

    @property
    def cell(self):
        """The cell of the dataset, or the list of cells if there are several."""
        cells = self.__dataset.cells
        return cells if len(cells) > 1 else cells[0]

    @property
    def subject(self):
        """The subject of the dataset, or the list of subjects if there are several."""
        subjects = self.__dataset.subjects
        return subjects if len(subjects) > 1 else subjects[0]

    def spikes(self, index: int = 0):
        """
        @param index: the index of the baseline recording.
        @return: the spike times of that recording, None for an invalid index.
        """
        count = len(self.__spike_data)
        return self.__spike_data[index] if -count <= index < count else None

    def eod(self, index: int = 0):
        """
        @param index: the index of the baseline recording.
        @return: tuple (eod, time) with the EOD trace and a matching time
                 axis built from the dataset's samplerate, (None, None) for
                 an invalid index.
        """
        count = len(self.__eod_data)
        if not -count <= index < count:
            return None, None
        eod = self.__eod_data[index]
        time = np.arange(len(eod)) / self.__dataset.samplerate
        return eod, time

    @property
    def burst_index(self):
        """Fraction of interspike intervals shorter than 1.5 EOD periods, one entry per recording."""
        bi = []
        for i, sd in enumerate(self.__spike_data):
            if len(sd) < 2:
                continue
            et = self.eod_times(index=i)
            eod_period = np.mean(np.diff(et))
            isis = np.diff(sd)
            bi.append(np.sum(isis < (1.5 * eod_period))/len(isis))
        return bi

    @property
    def coefficient_of_variation(self):
        """Standard deviation divided by mean of the interspike intervals, one entry per recording."""
        cvs = []
        for d in self.__spike_data:
            isis = np.diff(d)
            cvs.append(np.std(isis)/np.mean(isis))
        return cvs

    @property
    def vector_strength(self):
        """
        @return: tuple (vector strengths, spike phases), each a list with
                 one entry per recording.
        """
        vss = []
        spike_phases = []
        for i in range(len(self.__spike_data)):
            phases = self.__spike_phases(i)
            ms_sin_alpha = np.mean(np.sin(phases)) ** 2
            ms_cos_alpha = np.mean(np.cos(phases)) ** 2
            vss.append(np.sqrt(ms_cos_alpha + ms_sin_alpha))
            spike_phases.append(phases)
        return vss, spike_phases

    @property
    def size(self):
        """The number of baseline recordings that were read."""
        return len(self.__spike_data)

    def __str__(self):
        return "Baseline data of cell %s " % self.__cell.id

    def __read_eod_data_from_nix(self, r: "RePro", duration) -> np.ndarray:
        """Read the "EOD" data tagged by the repro run from the NIX file.

        Falls back to reading from the dataset folder if there is no NIX
        file. duration is only used by that fallback.

        @return: the EOD data, an empty array if it could not be read.
        """
        data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
        if not os.path.exists(data_source):
            print("Data not found! Trying from directory")
            return self.__read_eod_data_from_directory(r, duration)
        f = nix.File.open(data_source, nix.FileMode.ReadOnly)
        try:
            t = f.blocks[0].tags[r.id]
            if not t:
                print("Tag not found!")
                return np.empty(0)
            try:
                # [:] reads the data into memory before the file is closed.
                return t.retrieve_data("EOD")[:]
            except Exception:
                return np.empty(0)
        finally:
            f.close()

    def __read_eod_data_from_directory(self, r: "RePro", duration) -> np.ndarray:
        """Read the first duration * samplerate samples of "trace-2.raw" (float32)."""
        sr = self.__dataset.samplerate
        _unzip_if_needed(self.__dataset.data_source, "trace-2.raw")
        eod = np.fromfile(os.path.join(self.__dataset.data_source, "trace-2.raw"), np.float32)
        return eod[:int(duration * sr)]

    def __read_spike_data_from_nix(self, r: "RePro") -> np.ndarray:
        """Read the "Spikes-1" data tagged by the repro run from the NIX file.

        Falls back to reading from the dataset folder if there is no NIX file.

        @return: the spike times, None if they could not be read or if there
                 are fewer than 100 spikes.
        """
        data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
        if not os.path.exists(data_source):
            print("Data not found! Trying from directory")
            return self.__read_spike_data_from_directory(r)
        data = None
        f = nix.File.open(data_source, nix.FileMode.ReadOnly)
        try:
            t = f.blocks[0].tags[r.id]
            if not t:
                print("Tag not found!")
            else:
                try:
                    # [:] reads the data into memory before the file is closed.
                    data = t.retrieve_data("Spikes-1")[:]
                except Exception:
                    data = None
        finally:
            f.close()
        if data is None or len(data) < 100:
            return None
        return data

    def __read_spike_data_from_directory(self, r) -> np.ndarray:
        """Read the spike times of run r.run from "basespikes1.dat".

        The file is scanned for the metadata line containing "index" whose
        value matches r.run; the data following the next "#Key" line is read.

        @return: the spike times, None if the file or the run does not exist
                 or if there are fewer than 100 spikes.
        """
        data = []
        data_source = os.path.join(self.__dataset.data_source, "basespikes1.dat")
        if os.path.exists(data_source):
            found_run = False
            with open(data_source, 'r') as f:
                line = f.readline()
                while line:
                    if "index" in line:
                        index = int(line.strip("#").strip().split(":")[-1])
                        found_run = index == r.run
                    if line.startswith("#Key") and found_run:
                        data = self.__do_read(f)
                        break
                    line = f.readline()
        if len(data) < 100:
            return None
        return np.asarray(data)

    def __do_read(self, f) -> np.ndarray:
        """Read one block of values from the open file f.

        One line is skipped, the next line is taken as the unit and the
        following lines are read as numbers until an empty line, a line
        containing "#" or the end of the file. Values given in "ms" are
        scaled by 0.001.
        """
        data = []
        f.readline()
        unit = f.readline().strip("#").strip()
        scale = 0.001 if unit == "ms" else 1
        line = f.readline()
        while line and "#" not in line and len(line.strip()) > 0:
            data.append(float(line.strip())*scale)
            line = f.readline()
        return np.asarray(data)


class FIData:
    """Spike and local EOD data of all "FICurve" runs of a dataset.

    On construction all "FICurve" repro runs of the first cell of the
    dataset are looked up. For each stimulus presentation the spike times,
    the local EOD, the matching time axis and the contrast are stored under
    the same index. So far the data can only be read from NIX files.
    """

    def __init__(self, dataset: "Dataset"):
        """
        @param dataset: the fishbook Dataset the FI data is read from.
        """
        self.__spike_data = []
        self.__contrasts = []
        self.__eod_data = []
        self.__eod_times = []
        self.__dataset = dataset
        self.__repros = None
        self.__cell = dataset.cells[0]  # Beware: Assumption that there is only a single cell
        self._get_data()

    def _get_data(self):
        """Read the data of all stimulus presentations of all FICurve runs."""
        if not self.__dataset:
            return
        self.__repros = RePro.find("FICurve", cell_id=self.__cell.id)
        for r in self.__repros:
            sd, c, eods, time = self.__read_spike_data(r)
            # NOTE(review): runs with a single stimulus are skipped as well,
            # confirm that this is intended.
            if sd is None or len(sd) < 2:
                continue
            self.__spike_data.extend(sd)
            self.__eod_data.extend(eods)
            self.__contrasts.extend(c)
            self.__eod_times.extend(time)

    def __read_spike_data(self, repro: "RePro"):
        """
        Read the data of all stimulus presentations of a repro run.

        :param repro: the repro run.
        :return: tuple of lists (spike data, contrasts, local eods, times),
                 four times None if the dataset has no NIX file.
        """
        if self.__dataset.has_nix:
            return self.__read_spikes_from_nix(repro)
        print("Sorry, so far only from nix!!!")
        return None, None, None, None

    # The MultiTag annotation is quoted on purpose so that it is not evaluated
    # when the class is defined.
    def __do_read_spike_data_from_nix(self, mtag: "nix.MultiTag", stimulus: "Stimulus", repro: "RePro"):
        """
        Read the data belonging to one stimulus presentation.

        The "delay:" entry of the repro settings and the "Contrast:" entry
        of the stimulus settings are parsed (0.0 if missing). The data is cut
        from delay before the stimulus start until the stimulus end.

        :param mtag: the multi tag that references "Spikes-1" and "LocalEOD-1".
        :param stimulus: the stimulus presentation.
        :param repro: the repro run the stimulus belongs to.
        :return: tuple (spikes, local_eod, time, contrast). Spike times are
                 given relative to the start of the cut-out, the time axis
                 is shifted by -delay.
        """
        delay = 0.0
        for s in repro.settings.split("\n"):
            if "delay:" in s:
                delay = float(s.split(":")[-1])
                break
        contrast = 0.0
        for s in stimulus.settings.split("\n"):
            if "Contrast:" in s and "PreContrast" not in s and "\t\t" not in s and "+-" not in s:
                contrast = float(s.split(":")[-1])
                break

        start_time = stimulus.start_time - delay
        end_time = stimulus.start_time + stimulus.duration
        spikes_da = mtag.references["Spikes-1"]
        eod_da = mtag.references["LocalEOD-1"]
        start_index_spikes = spikes_da.dimensions[0].index_of(start_time)
        end_index_spikes = spikes_da.dimensions[0].index_of(end_time)
        start_index_eod = eod_da.dimensions[0].index_of(start_time)
        end_index_eod = eod_da.dimensions[0].index_of(end_time)

        local_eod = eod_da[start_index_eod:end_index_eod]
        spikes = spikes_da[start_index_spikes:end_index_spikes] - start_time
        time = np.asarray(eod_da.dimensions[0].axis(len(local_eod))) - delay

        return spikes, local_eod, time, contrast

    def __read_spikes_from_nix(self, repro: "RePro"):
        """
        Read the data of all stimuli of a repro run from the NIX file.

        :param repro: the repro run.
        :return: tuple of lists (spikes, contrasts, eods, time) with one
                 entry per stimulus. The lists are empty if there are no
                 stimuli or if the NIX file does not exist.
        """
        spikes = []
        eods = []
        time = []
        contrasts = []
        stimuli = Stimulus.find(cell_id=repro.cell_id, repro_id=repro.id)
        if len(stimuli) == 0:
            return spikes, contrasts, eods, time
        data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
        if not os.path.exists(data_source):
            # There is no reader for the folder based storage in this class.
            print("Data not found!")
            return spikes, contrasts, eods, time
        f = nix.File.open(data_source, nix.FileMode.ReadOnly)
        try:
            b = f.blocks[0]
            mt = None
            for s in stimuli:
                # Consecutive stimuli usually share the multi tag, look it up
                # only when it changes.
                if not mt or mt.id != s.multi_tag_id:
                    mt = b.multi_tags[s.multi_tag_id]
                sp, eod, t, c = self.__do_read_spike_data_from_nix(mt, s, repro)
                spikes.append(sp)
                eods.append(eod)
                time.append(t)
                contrasts.append(c)
        finally:
            f.close()
        return spikes, contrasts, eods, time

    @property
    def size(self):
        """The number of stimulus presentations that were read."""
        return len(self.__spike_data)

    def spikes(self, index=-1):
        """
        :param index: the index of the stimulus presentation.
        :return: the spike times of that presentation, or the list of all
                 spike trains if index is outside [0, size), e.g. the default.
        """
        if 0 <= index < self.size:
            return self.__spike_data[index]
        return self.__spike_data

    def eod(self, index=-1):
        """
        :param index: the index of the stimulus presentation.
        :return: tuple (time, eod) of that presentation, or the tuple of the
                 complete lists if index is outside [0, size). Note that the
                 time comes first here, unlike in BaselineData.eod.
        """
        if 0 <= index < self.size:
            return self.__eod_times[index], self.__eod_data[index]
        return self.__eod_times, self.__eod_data

    def contrast(self, index=-1):
        """
        :param index: the index of the stimulus presentation.
        :return: the contrast of that presentation, or the list of all
                 contrasts if index is outside [0, size), e.g. the default.
        """
        if 0 <= index < self.size:
            return self.__contrasts[index]
        return self.__contrasts

if __name__ == "__main__":
    # Manual check when the module is run as a script: read the baseline data
    # of one dataset and open an interactive IPython shell (embed) in which
    # `dataset` and `baseline` can be inspected.
    # The commented-out lines are alternative dataset ids.
    #dataset = Dataset(dataset_id='2011-06-14-ag')
    dataset = Dataset(dataset_id="2018-01-19-ac-invivo-1")
    # dataset = Dataset(dataset_id='2018-11-09-aa-invivo-1')
    baseline = BaselineData(dataset)
    embed()