import datajoint as dj
import nixio as nix
import os
import numpy as np
from IPython import embed
from database import _Dataset, _Repro
# DataJoint schema object for the "fish_book" database, created with this
# module's locals() as its context.
# NOTE(review): `schema` does not appear to be referenced elsewhere in this
# module — confirm whether it is still needed.
schema = dj.schema("fish_book", locals())


class BaselineData(object):
    """Spike data recorded during the "BaselineActivity" repros of one dataset.

    On construction every ``_Repro`` entry of the dataset whose ``repro_name``
    matches ``BaselineActivity%`` is loaded, either from the dataset's NIX file
    (when ``dataset["has_nix"]`` is truthy) or from ``basespikes1.dat`` in the
    dataset's ``data_source`` directory.
    """

    def __init__(self, dataset: "_Dataset"):
        """Load the baseline data of *dataset*.

        :param dataset: the dataset entry. It is indexed with the keys
            ``"has_nix"``, ``"data_source"`` and ``"dataset_id"`` and used as a
            restriction on ``_Repro``. The ``__main__`` block passes the result
            of ``fetch1()`` — presumably a dict; verify for other callers.
        """
        self.__data = []
        self.__dataset = dataset
        self._get_data()

    def _get_data(self):
        """(Re)load the data of every baseline repro of the dataset."""
        self.__data = []
        if not self.__dataset:
            # Without this return an empty/None dataset would still be used to
            # query _Repro and then be indexed for "has_nix".
            return

        repros = _Repro & self.__dataset & "repro_name like 'BaselineActivity%'"
        for r in repros:
            self.__data.append(self.__read_data(r))

    def __read_data(self, r: "_Repro"):
        """Read one repro's spike data from NIX if available, else from the folder."""
        if self.__dataset["has_nix"]:
            return self.__read_data_from_nix(r)
        return self.__read_data_from_directory(r)

    @property
    def dataset(self):
        """The dataset entry this object was created with."""
        return self.__dataset

    @property
    def data(self, index: int = 0):
        """Spike data of the first baseline repro, or ``None`` if none was loaded.

        ``index`` is kept only for backward compatibility: a property getter is
        always invoked without arguments, so it is always 0 here.
        """
        if 0 <= index < len(self.__data):
            return self.__data[index]
        return None

    @property
    def size(self):
        """Number of baseline repros that were loaded."""
        return len(self.__data)

    def __str__(self):
        # Wrap in a tuple so a tuple-valued dataset is not expanded by ``%``.
        return "Baseline data of %s " % (self.__dataset,)

    def __read_data_from_nix(self, r) -> np.ndarray:
        """Read the "Spikes-1" data of repro *r* from the dataset's NIX file.

        The file is ``<data_source>/<dataset_id>.nix``. If it does not exist, the
        directory reader is used instead. If the first block has no tag named
        ``r["repro_id"]``, a message is printed and an empty list is returned
        (matching the directory reader's "not found" result).
        """
        data_source = os.path.join(self.__dataset["data_source"], self.__dataset["dataset_id"] + ".nix")
        if not os.path.exists(data_source):
            print("Data not found! Trying from folder")
            return self.__read_data_from_directory(r)
        f = nix.File.open(data_source, nix.FileMode.ReadOnly)
        try:
            b = f.blocks[0]
            try:
                t = b.tags[r["repro_id"]]
            except KeyError:
                # Looking up a missing name raises rather than returning a falsy value.
                t = None
            if t is None:
                print("Tag not found!")
                return []
            return t.retrieve_data("Spikes-1")[:]
        finally:
            # Close the file even if reading the tag or its data raises.
            f.close()

    def __read_data_from_directory(self, r) -> np.ndarray:
        """Read the spike times of run ``r["run"]`` from ``basespikes1.dat``.

        The file is scanned line by line. Any line containing "index" sets the
        current run from the integer after its last ':'. The first "#Key" line
        seen while the current run equals ``r["run"]`` starts the data section.

        :returns: the values as a numpy array, or an empty list if the file is
            missing or the run is not found.
        """
        data = []
        data_source = os.path.join(self.__dataset["data_source"], "basespikes1.dat")
        if not os.path.exists(data_source):
            return data
        found_run = False
        with open(data_source, 'r') as f:
            line = f.readline()
            while line:
                if "index" in line:
                    index = int(line.strip("#").strip().split(":")[-1])
                    found_run = index == r["run"]
                if line.startswith("#Key") and found_run:
                    data = self.__do_read(f)
                    break
                line = f.readline()
        return data

    def __do_read(self, f) -> np.ndarray:
        """Read one column of floats from *f*, positioned just after a "#Key" line.

        The two lines following "#Key" are skipped (presumably column header
        lines — confirm against the file format). Reading stops at EOF, at a
        blank line, or at a line containing '#'.
        """
        data = []
        f.readline()
        f.readline()
        line = f.readline()
        while line and "#" not in line and len(line.strip()) > 0:
            data.append(float(line.strip()))
            line = f.readline()
        return np.asarray(data)


if __name__ == "__main__":
    # Manual smoke test: load the baseline data of one dataset and open an
    # IPython shell with `dataset` and `baseline` in scope for inspection.
    print("Test")
    restriction = "dataset_id like '2018-11-09-aa-%' "
    dataset = _Dataset & restriction
    entry = dataset.fetch1()
    baseline = BaselineData(entry)
    embed()