from fishbook.fishbook import Dataset, RePro
import numpy as np
import nixio as nix
from scipy.stats import circstd
import os
import subprocess

from IPython import embed


def _zero_crossings(x, t, interpolate=False):
    dt = t[1] - t[0]
    x_shift = np.roll(x, 1)
    x_shift[0] = 0.0
    xings = np.where((x >= 0 ) & (x_shift < 0))[0]
    crossings = np.zeros(len(xings))
    if interpolate:
        for i, tf in enumerate(xings):
            if x[tf] > 0.001:
                m = (x[tf] - x[tf-1])/dt
                crossings[i] = t[tf] - x[tf]/m
            elif x[tf] < -0.001:
                m = (x[tf + 1] - x[tf]) / dt
                crossings[i] = t[tf] - x[tf]/m
            else:
                crossings[i] = t[tf]
    else:
        crossings = t[xings]
    return crossings


def _unzip_if_needed(dataset, tracename='trace-1.raw'):
    file_name = os.path.join(dataset, tracename)
    if os.path.exists(file_name):
        return
    if os.path.exists(file_name + '.gz'):
        print("\tunzip: %s" % tracename)
        subprocess.check_call(["gunzip", os.path.join(dataset, tracename + ".gz")])


class BaselineData:
    """Access to the baseline activity recorded for a single cell.

    Loads the spike times and the raw EOD trace of every "BaselineActivity"
    RePro run of the dataset — from the dataset's nix file when available,
    otherwise from the raw relacs files — and offers basic analyses on them
    (serial correlation, vector strength, burst index, CV, EOD frequency).
    """

    def __init__(self, dataset: Dataset):
        """
        @param dataset: the Dataset entity whose baseline data should be loaded.
        """
        self.__spike_data = []   # one np.ndarray of spike times per run
        self.__eod_data = []     # one np.ndarray (raw EOD trace) per run
        self.__eod_times = []    # cache for detected EOD times (see eod_times)
        self.__dataset = dataset
        self.__repros = None
        self.__cell = dataset.cells[0]  # Beware: Assumption that there is only a single cell
        self._get_data()

    def _get_data(self):
        """Collect spike and EOD data of all BaselineActivity runs of the cell."""
        if not self.__dataset:
            return
        self.__repros = RePro.find("BaselineActivity", cell_id=self.__cell.id)
        for r in self.__repros:
            sd = self.__read_spike_data(r)
            if sd is None or len(sd) <= 1:
                continue
            self.__spike_data.append(sd)
            # the last spike time bounds the duration of the EOD snippet
            self.__eod_data.append(self.__read_eod_data(r, self.__spike_data[-1][-1]))

    def valid(self):
        # fixme implement me!
        pass

    def __read_spike_data(self, r: RePro):
        """Read the spike times of the given run, from nix file or directory."""
        if self.__dataset.has_nix:
            return self.__read_spike_data_from_nix(r)
        else:
            return self.__read_spike_data_from_directory(r)

    def __read_eod_data(self, r: RePro, duration):
        """Read the EOD trace of the given run, from nix file or directory."""
        if self.__dataset.has_nix:
            return self.__read_eod_data_from_nix(r, duration)
        else:
            return self.__read_eod_data_from_directory(r, duration)

    def __get_serial_correlation(self, times, max_lags=50):
        """Serial correlation of the inter-spike intervals up to max_lags lags."""
        if times is None or len(times) < max_lags:
            return None
        isis = np.diff(times)
        unbiased = isis - np.mean(isis, 0)
        norm = np.sum(unbiased ** 2)
        a_corr = np.correlate(unbiased, unbiased, "same") / norm
        # keep only the non-negative lags
        a_corr = a_corr[int(len(a_corr) / 2):]
        return a_corr[:max_lags]

    def serial_correlation(self, max_lags=50):
        """
            Return the serial correlation of the inter-spike intervals for each
            baseline run that contains at least 100 spikes.
        @param max_lags: The number of lags to take into account
        @return: list of serial correlations as a function of the lag
        """
        scs = []
        for sd in self.__spike_data:
            if sd is None or len(sd) < 100:
                continue
            corr = self.__get_serial_correlation(sd, max_lags=max_lags)
            if corr is not None:
                scs.append(corr)
        return scs

    def circular_std(self):
        """Circular standard deviation of the spike phases, one value per run."""
        cstds = []
        for i in range(self.size):
            phases = self.__spike_phases(index=i)
            cstds.append(circstd(phases))
        return cstds

    def eod_frequency(self):
        """Mean EOD frequency [Hz] across all runs; 0.0 if there is no data."""
        eodfs = []
        for i in range(self.size):
            eod, time = self.eod(i)
            xings = _zero_crossings(eod, time, interpolate=False)
            # fix: guard against traces without any detected crossing
            # (the original divided by xings[-1] unconditionally)
            if len(xings) > 0 and xings[-1] > 0:
                eodfs.append(len(xings) / xings[-1])
        return 0.0 if len(eodfs) < 1 else np.mean(eodfs)

    def __spike_phases(self, index=0):  # fixme buffer this stuff
        """Phase (0..2*pi) of each spike within its EOD cycle."""
        etimes = self.eod_times(index=index)
        eod_period = np.mean(np.diff(etimes))
        phases = np.zeros(len(self.spikes(index)))
        for i, st in enumerate(self.spikes(index)):
            last_eod_index = np.where(etimes <= st)[0]
            if len(last_eod_index) == 0:
                continue
            phases[i] = (st - etimes[last_eod_index[-1]]) / eod_period * 2 * np.pi
        return phases

    def eod_times(self, index=0, interpolate=True):
        """Times of the upward EOD zero crossings of the run at the given index.

        @param index: index of the baseline run
        @param interpolate: whether crossing times are linearly interpolated
        @return: np.ndarray of crossing times, or None for an invalid index
        """
        if index >= self.size:
            return None
        if len(self.__eod_times) < len(self.__eod_data):
            eod, time = self.eod(index)
            etimes = _zero_crossings(eod, time, interpolate=interpolate)
        else:
            etimes = self.__eod_times[index]
        return etimes

    @property
    def dataset(self):
        return self.__dataset

    @property
    def cell(self):
        """The cell(s) of the dataset; a single Cell if there is only one."""
        cells = self.__dataset.cells
        return cells if len(cells) > 1 else cells[0]

    @property
    def subject(self):
        """The subject(s) of the dataset; a single Subject if there is only one."""
        subjects = self.__dataset.subjects
        return subjects if len(subjects) > 1 else subjects[0]

    def spikes(self, index: int = 0):
        """Spike times of the run at the given index, or None if out of range."""
        # fix: the original bounds check (len(...) >= index) let index == len
        # through, which raised an IndexError
        return self.__spike_data[index] if 0 <= index < len(self.__spike_data) else None

    def eod(self, index: int = 0):
        """EOD trace and matching time axis of the run at the given index.

        @return: tuple (eod, time); (None, None) if the index is out of range.
        """
        # fix: the original bounds check (len(...) >= index) let index == len
        # through (IndexError) and called len(None) for larger indices
        if not 0 <= index < len(self.__eod_data):
            return None, None
        eod = self.__eod_data[index]
        time = np.arange(len(eod)) / self.__dataset.samplerate
        return eod, time

    @property
    def burst_index(self):
        """Fraction of ISIs shorter than 1.5 EOD periods, one value per run."""
        bi = []
        for i, sd in enumerate(self.__spike_data):
            if len(sd) < 2:
                continue
            et = self.eod_times(index=i)
            # fix: guard against runs without enough EOD crossings
            if et is None or len(et) < 2:
                continue
            eod_period = np.mean(np.diff(et))
            isis = np.diff(sd)
            bi.append(np.sum(isis < (1.5 * eod_period)) / len(isis))
        return bi

    @property
    def coefficient_of_variation(self):
        """CV (std/mean) of the inter-spike intervals, one value per run."""
        cvs = []
        for d in self.__spike_data:
            isis = np.diff(d)
            cvs.append(np.std(isis) / np.mean(isis))
        return cvs

    @property
    def vector_strength(self):
        """Vector strength of the phase locking of the spikes to the EOD.

        @return: tuple (list of vector strengths, list of spike-phase arrays)
        """
        vss = []
        spike_phases = []
        for i, sd in enumerate(self.__spike_data):
            phases = self.__spike_phases(i)
            ms_sin_alpha = np.mean(np.sin(phases)) ** 2
            ms_cos_alpha = np.mean(np.cos(phases)) ** 2
            vs = np.sqrt(ms_cos_alpha + ms_sin_alpha)
            vss.append(vs)
            spike_phases.append(phases)
        return vss, spike_phases

    @property
    def size(self):
        """Number of baseline runs for which data was loaded."""
        return len(self.__spike_data)

    def __str__(self):
        # fix: the original built the string (shadowing the builtin str) but
        # never returned it, so str(obj) raised a TypeError
        return "Baseline data of cell %s " % self.__cell.id

    def __read_eod_data_from_nix(self, r: RePro, duration) -> np.ndarray:
        """Read the run's EOD trace from the dataset's nix file; falls back
        to the raw directory if the nix file is missing."""
        data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
        if not os.path.exists(data_source):
            print("Data not found! Trying from directory")
            return self.__read_eod_data_from_directory(r, duration)
        f = nix.File.open(data_source, nix.FileMode.ReadOnly)
        try:
            b = f.blocks[0]
            t = b.tags[r.id]
            if not t:
                print("Tag not found!")
            try:
                data = t.retrieve_data("EOD")[:]
            except Exception:
                # fix: np.empty() without a shape raises a TypeError;
                # return an empty array instead
                data = np.empty(0)
        finally:
            # fix: close the file even if reading raises
            f.close()
        return data

    def __read_eod_data_from_directory(self, r: RePro, duration) -> np.ndarray:
        """Read the run's EOD trace from the raw trace-2.raw file, truncated
        to the given duration [s]."""
        sr = self.__dataset.samplerate
        _unzip_if_needed(self.__dataset.data_source, "trace-2.raw")
        eod = np.fromfile(os.path.join(self.__dataset.data_source, "trace-2.raw"), np.float32)
        return eod[:int(duration * sr)]

    def __read_spike_data_from_nix(self, r: RePro) -> np.ndarray:
        """Read the run's spike times from the dataset's nix file; falls back
        to the raw directory if the nix file is missing. Returns None if
        fewer than 100 spikes were found."""
        data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
        if not os.path.exists(data_source):
            print("Data not found! Trying from directory")
            return self.__read_spike_data_from_directory(r)
        f = nix.File.open(data_source, nix.FileMode.ReadOnly)
        try:
            b = f.blocks[0]
            t = b.tags[r.id]
            if not t:
                print("Tag not found!")
            try:
                data = t.retrieve_data("Spikes-1")[:]
            except Exception:
                data = None
        finally:
            # fix: close the file even if reading raises
            f.close()
        # fix: the original called len(data) while data could be None
        if data is None or len(data) < 100:
            return None
        return data

    def __read_spike_data_from_directory(self, r) -> np.ndarray:
        """Read the spike times of the run matching r.run from basespikes1.dat.
        Returns None if fewer than 100 spikes were found."""
        data = []
        data_source = os.path.join(self.__dataset.data_source, "basespikes1.dat")
        if os.path.exists(data_source):
            found_run = False
            with open(data_source, 'r') as f:
                l = f.readline()
                while l:
                    if "index" in l:
                        index = int(l.strip("#").strip().split(":")[-1])
                        found_run = index == r.run
                    if l.startswith("#Key") and found_run:
                        data = self.__do_read(f)
                        break
                    l = f.readline()
        if len(data) < 100:
            return None
        return np.asarray(data)

    def __do_read(self, f) -> np.ndarray:
        """Read spike times (one float per line) until the next comment or
        blank line; times given in ms are converted to seconds."""
        data = []
        f.readline()  # skip the key header line
        unit = f.readline().strip("#").strip()
        scale = 0.001 if unit == "ms" else 1
        l = f.readline()
        while l and "#" not in l and len(l.strip()) > 0:
            data.append(float(l.strip()) * scale)
            l = f.readline()
        return np.asarray(data)


if __name__ == "__main__":
    # quick interactive check: load one example dataset and drop into IPython
    example = Dataset(dataset_id='2011-06-14-ag')
    # example = Dataset(dataset_id='2018-11-09-aa-invivo-1')
    baseline = BaselineData(example)
    embed()