from fishbook.fishbook import Dataset, RePro
import numpy as np
import nixio as nix
import os
import subprocess

from IPython import embed


def _unzip_if_needed(dataset, tracename='trace-1.raw'):
    """Decompress a gzipped trace file unless the plain file is already there.

    Args:
        dataset: directory that holds the trace files.
        tracename: name of the uncompressed trace file inside that directory.

    Nothing happens when the plain file exists or when no ``.gz`` version of
    it can be found. Otherwise ``gunzip`` is run on the compressed file; a
    failing ``gunzip`` raises ``subprocess.CalledProcessError``.
    """
    raw_path = os.path.join(dataset, tracename)
    zipped_path = raw_path + '.gz'
    if not os.path.exists(raw_path) and os.path.exists(zipped_path):
        print("\tunzip: %s" % tracename)
        subprocess.check_call(["gunzip", zipped_path])


class BaselineData:
    """Spike times and EOD traces of the baseline activity of a dataset.

    On construction the "BaselineActivity" repro runs of the dataset's first
    cell are looked up. For every run the spike times and an EOD trace are
    read, either from the dataset's nix file or, if the dataset has no nix
    file, from the files in the dataset directory.
    """

    def __init__(self, dataset: "Dataset"):
        """Read the baseline data of *dataset*.

        Args:
            dataset: the dataset to read from; its first cell is used.
        """
        self.__spike_data = []
        self.__eod_data = []
        self.__dataset = dataset
        self.__repros = None
        self.__cell = dataset.cells[0]  # Beware: Assumption that there is only a single cell
        self._get_data()

    def _get_data(self):
        """Find the baseline repro runs and read spikes and EOD of each run."""
        if not self.__dataset:
            return
        self.__repros = RePro.find("BaselineActivity", cell_id=self.__cell.id)
        for r in self.__repros:
            spikes = self.__read_spike_data(r)
            # The time of the last spike is handed on as the duration of the
            # EOD to read; a run without any spikes must not raise IndexError.
            duration = spikes[-1] if len(spikes) > 0 else 0.0
            self.__spike_data.append(spikes)
            self.__eod_data.append(self.__read_eod_data(r, duration))

    def __read_spike_data(self, r: "RePro"):
        """Read the spike times of run *r* from nix file or directory."""
        if self.__dataset.has_nix:
            return self.__read_spike_data_from_nix(r)
        return self.__read_spike_data_from_directory(r)

    def __read_eod_data(self, r: "RePro", duration):
        """Read the EOD trace of run *r* from nix file or directory."""
        if self.__dataset.has_nix:
            return self.__read_eod_data_from_nix(r, duration)
        return self.__read_eod_data_from_directory(r, duration)

    @property
    def dataset(self):
        """The dataset this object was created with."""
        return self.__dataset

    @property
    def cell(self):
        """The dataset's cell, or the list of cells if there are several."""
        cells = self.__dataset.cells
        return cells if len(cells) > 1 else cells[0]

    @property
    def subject(self):
        """The dataset's subject, or the list of subjects if there are several."""
        subjects = self.__dataset.subjects
        return subjects if len(subjects) > 1 else subjects[0]

    def spike_data(self, index: int = 0):
        """Return the spike times of the run at *index*.

        Returns:
            The spike times, or None if there is no run at *index*.
        """
        try:
            return self.__spike_data[index]
        except IndexError:
            return None

    def eod_data(self, index: int = 0):
        """Return the EOD trace of the run at *index* and its time axis.

        The time axis is the sample number divided by the dataset's
        samplerate.

        Returns:
            Tuple ``(eod, time)``, or ``(None, None)`` if there is no run at
            *index*.
        """
        try:
            eod = self.__eod_data[index]
        except IndexError:
            return None, None
        time = np.arange(len(eod)) / self.__dataset.samplerate
        return eod, time

    @property
    def coefficient_of_variation(self):
        """Coefficient of variation of the interspike intervals of each run.

        Returns:
            List with one value per run: standard deviation of the
            interspike intervals divided by their mean. Runs with fewer than
            two spikes have no interval and yield ``nan``.
        """
        cvs = []
        for d in self.__spike_data:
            isis = np.diff(d)
            cvs.append(np.std(isis) / np.mean(isis) if len(isis) > 0 else np.nan)
        return cvs

    @property
    def vector_strength(self):
        """Not implemented yet: always returns an empty list."""
        vss = []
        return vss

    @property
    def size(self):
        """Number of baseline runs for which data was read."""
        return len(self.__spike_data)

    def __str__(self):
        return "Baseline data of cell %s " % self.__cell.id

    def __read_from_nix(self, r: "RePro", array_name: str):
        """Read the data called *array_name* referenced by the tag of run *r*.

        The tag is looked up by the id of *r* in the first block of the
        dataset's nix file. The file is closed again in any case.

        Returns:
            The retrieved data, an empty array if the tag lookup yields
            nothing, or None if the nix file does not exist.
        """
        data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
        if not os.path.exists(data_source):
            return None
        f = nix.File.open(data_source, nix.FileMode.ReadOnly)
        try:
            b = f.blocks[0]
            t = b.tags[r.id]
            if not t:
                print("Tag not found!")
                return np.array([])
            return t.retrieve_data(array_name)[:]
        finally:
            f.close()

    def __read_eod_data_from_nix(self, r: "RePro", duration) -> np.ndarray:
        """Read the "EOD" data of run *r* from the nix file.

        Falls back to reading from the dataset directory if the nix file
        does not exist; *duration* is only used by that fallback.
        """
        data = self.__read_from_nix(r, "EOD")
        if data is None:
            print("Data not found! Trying from directory")
            return self.__read_eod_data_from_directory(r, duration)
        return data

    def __read_eod_data_from_directory(self, r: "RePro", duration) -> np.ndarray:
        """Read the first *duration* seconds of ``trace-2.raw`` as float32.

        The trace file is unzipped first if only a gzipped version exists.
        The run *r* is not used for selecting the data.
        """
        sr = self.__dataset.samplerate
        _unzip_if_needed(self.__dataset.data_source, "trace-2.raw")
        eod = np.fromfile(os.path.join(self.__dataset.data_source, "trace-2.raw"), np.float32)
        return eod[:int(duration * sr)]

    def __read_spike_data_from_nix(self, r: "RePro") -> np.ndarray:
        """Read the "Spikes-1" data of run *r* from the nix file.

        Falls back to reading from the dataset directory if the nix file
        does not exist.
        """
        data = self.__read_from_nix(r, "Spikes-1")
        if data is None:
            print("Data not found! Trying from directory")
            return self.__read_spike_data_from_directory(r)
        return data

    def __read_spike_data_from_directory(self, r) -> np.ndarray:
        """Read the spike times of run *r* from ``basespikes1.dat``.

        The file is scanned for a line containing "index" whose number
        equals ``r.run``; the data following the next line that starts with
        "#Key" is returned.

        Returns:
            The spike times, or an empty array if the file or the run is
            missing.
        """
        data_source = os.path.join(self.__dataset.data_source, "basespikes1.dat")
        if not os.path.exists(data_source):
            return np.array([])
        found_run = False
        with open(data_source, 'r') as f:
            # readline() instead of file iteration, because __do_read
            # continues reading from the very same file position.
            for line in iter(f.readline, ""):
                if "index" in line:
                    index = int(line.strip("#").strip().split(":")[-1])
                    found_run = index == r.run
                if found_run and line.startswith("#Key"):
                    return self.__do_read(f)
        return np.array([])

    def __do_read(self, f) -> np.ndarray:
        """Read one block of numbers from the open file *f*.

        The first line is skipped and the second line is taken as the unit.
        Then one number per line is read until the end of the file, an empty
        line, or a line containing "#". Values with unit "ms" are multiplied
        by 0.001; all other units are left unscaled.
        """
        data = []
        f.readline()
        unit = f.readline().strip("#").strip()
        scale = 0.001 if unit == "ms" else 1
        line = f.readline()
        while line and "#" not in line and len(line.strip()) > 0:
            data.append(float(line.strip()) * scale)
            line = f.readline()
        return np.asarray(data)


if __name__ == "__main__":
    # Manual check: load the baseline data of one dataset and drop into an
    # interactive IPython shell, where `dataset` and `baseline` can be inspected.
    dataset = Dataset(dataset_id='2011-06-14-ag')
    # Alternative dataset id to try instead:
    # dataset = Dataset(dataset_id='2018-11-09-aa-invivo-1')
    baseline = BaselineData(dataset)
    embed()