import numpy as np
import datajoint as dj
import nixio as nix
import os
import glob
import util as ut

from IPython import embed

schema = dj.schema("fish_book", locals())


@schema
class Dataset(dj.Manual):
    definition = """ # Dataset
       dataset_id : varchar(256)
       ----
       data_source : varchar(512) # path to the dataset
       experimenter : varchar(512) 
       recording_date : date
       quality : varchar(512)
       comment : varchar(1024)
       has_nix : bool
       """

    @staticmethod
    def get_template_tuple(id=None):
        if id is not None:
            d = dict((Dataset() & {"dataset_id": id}).fetch1())
            return d
        return dict(dataset_id=None, data_source="", experimenter="", recording_date=None,
                    quality="", comment="", has_nix=False)

    @staticmethod
    def get_nix_file(key):
        dset = (Dataset() & key).fetch1()
        if dset["ignore"]:
            return None
        file_path = os.path.join(dset["data_source"], dset["dataset_id"] + ".nix")
        if not (os.path.exists(file_path)):
            print("\t No nix file found for path: %s" % dset["data_source"])
            return None
        if not Dataset.check_file_integrity(file_path):
            return None
        return file_path

    @staticmethod
    def check_file_integrity(nix_file):
        sane = True
        try:
            f = nix.File.open(nix_file, nix.FileMode.ReadOnly)
            b = f.blocks[0]
            m = b.metadata
            if "Recording" not in m.sections:
                Warning("\t Could not find Recording section in dataset: %s" % nix_file)
                sane = False
            f.close()
        except ():
            f = None
            print("file: %s is NOT SANE!")
            sane = False
        return sane


@schema
class Subject(dj.Manual):
    definition = """
    # Subject
    subject_id : varchar(256)
    ----
    species : varchar(256)
    """

    @staticmethod
    def get_template_tuple(subject_id=None):
        tup = dict(subject_id=None, species="")
        if subject_id is not None:
            d = dict((Subject() & {"subject_id": subject_id}).fetch1())
            return d
        return tup

    def make(self, key):
        file_path = Dataset.get_nix_file(key)
        if file_path is None:
            return
        nix_file = nix.File.open(file_path, nix.FileMode.ReadOnly)
        m = nix_file.blocks[0].metadata
        inserts = Subject.get_template_tuple()
        subj_info = m["Recording"]["Subject"]
        inserts["subject_id"] = subj_info["Identifier"]
        inserts["species"] = subj_info["Species"][0]
        inserts["weight"] = subj_info["Weight"]
        inserts["size"] = subj_info["Size"]
        inserts["eod_frequency"] = np.round(subj_info["EOD Frequency"] * 10) / 10
        inserts.update(key)
        self.insert1(inserts, skip_duplicates=True)
        nix_file.close()


@schema
class SubjectDatasetMap(dj.Manual):
    definition = """
    # SubjectDatasetMap
    -> Subject
    -> Dataset
    """


@schema
class SubjectProperties(dj.Manual):
    definition = """
    # SubjectProperties
    id : int auto_increment
    ----
    -> Subject
    recording_date : date
    weight : float
    size : float
    eod_frequency : float
    """

    def get_template_tuple(id=None):
        tup = dict(id=None, subject_id=None, recording_date=None, weight=0.0, size=0.0,
                   eod_frequency=0.0)
        if id is not None:
            return dict((SubjectProperties() & {"id": id}).fetch1())
        return tup


@schema
class Cell(dj.Manual):
    definition = """
    # Table that stores information about recorded cells.
    cell_id : varchar(256)
    ----
    -> Subject
    cell_type : varchar(256)
    firing_rate : float
    structure : varchar(256)
    region : varchar(256)
    subregion : varchar(256)
    depth : float
    lateral_pos : float
    transversal_section : float
    """

    @staticmethod
    def get_template_tuple(cell_id=None):
        tup = dict(cell_id=None, subject_id=None, cell_type="", firing_rate=0.0,
                   depth=0.0, region="", subregion="", structure="",
                   lateral_pos=0.0, transversal_section=0.0)
        if cell_id is not None:
            d = dict((Cell() & {"cell_id": cell_id}).fetch1())
            return d
        return tup


@schema
class CellDatasetMap(dj.Manual):
    definition = """
    # Table that maps recorded cells to datasets
    -> Dataset
    -> Cell
    """


@schema
class Repro(dj.Manual):
    definition = """
    repro_id : varchar(512)
    run : smallint
    -> Dataset
    ----
    repro_name : varchar(512)
    settings : varchar(3000)
    start : float
    duration : float
    """

    @staticmethod
    def get_template_tuple(repro_id=None):
        tup = dict(repro_id=None, dataset_id=None, run=0, repro_name="", settings=None, start=None, duration=None)
        if repro_id is not None:
            d = dict((RePro() & {"repro_id": repro_id}).fetch1())
            return d
        return tup


@schema
class Stimulus(dj.Manual):
    definition = """
    stimulus_id : int
    -> Repro 
    ---
    settings : varchar(3000)
    start : float
    duration : float
    """

def populate_datasets(data_path):
    print("Importing dataset %s" % data_path)
    if not os.path.exists(data_path):
        return
    dset_name = os.path.split(data_path)[-1]
    experimenter, rec_date, quality, comment, has_nix = ut.read_dataset_info(os.path.join(data_path, 'info.dat'))
    if not experimenter:
        return False

    inserts = Dataset.get_template_tuple()
    inserts["dataset_id"] = dset_name
    inserts["data_source"] = data_path
    inserts["experimenter"] = experimenter
    inserts["recording_date"] = rec_date
    inserts["quality"] = quality
    inserts["comment"] = comment
    inserts["has_nix"] = has_nix
    Dataset().insert1(inserts, skip_duplicates=True)
    return True


def populate_subjects(data_path):
    print("\tImporting subject(s) of %s" % data_path)
    dset_name = os.path.split(data_path)[-1]
    info_file = os.path.join(data_path, 'info.dat')
    if not os.path.exists(info_file):
        return None, None, False
    info = ut.read_info_file(info_file)
    p = []
    ut.find_key_recursive(info, "Subject", p)
    if len(p) > 0:
        subj = ut.deep_get(info, p)
    inserts = Subject.get_template_tuple()
    inserts["subject_id"] = subj["Identifier"]
    inserts["species"] = subj["Species"]
    Subject().insert1(inserts, skip_duplicates=True)

    # multi mach entry
    dataset = dict((Dataset() & {"dataset_id": dset_name}).fetch1())
    mm = dict(dataset_id=dataset["dataset_id"], subject_id=subj["Identifier"])
    SubjectDatasetMap.insert1(mm, skip_duplicates=True)

    # subject properties
    props = SubjectProperties.get_template_tuple()
    props["subject_id"] = subj["Identifier"]
    props["recording_date"] = dataset["recording_date"]
    if "Weight" in subj.keys():
        props["weight"] = np.round(float(subj["Weight"][:-1]), 1)
    if "Size" in subj.keys():
        props["size"] = np.round(float(subj["Size"][:-2]), 1)
    if "EOD Frequency" in subj.keys():
        props["eod_frequency"] = np.round(float(subj["EOD Frequency"][:-2]))
    p = props.copy()
    p.pop("id")
    if len(SubjectProperties & p) == 0:
        SubjectProperties.insert1(props, skip_duplicates=True)


def populate_cells(data_path):
    print("\tImporting cell(s) of %s" % data_path)
    dset_name = os.path.split(data_path)[-1]
    info_file = os.path.join(data_path, 'info.dat')
    if not os.path.exists(info_file):
        return None, None, False
    info = ut.read_info_file(info_file)
    p = []
    ut.find_key_recursive(info, "Subject", p)
    subject_info = ut.deep_get(info, p)

    p = []
    ut.find_key_recursive(info, "Cell", p)
    cell_info = ut.deep_get(info, p)

    p = []
    ut.find_key_recursive(info, "Firing Rate1", p)
    firing_rate = ut.deep_get(info, p, default=0.0)
    if isinstance(firing_rate, str):
        firing_rate = float(firing_rate[:-2])

    dataset = dict((Dataset & {"dataset_id": dset_name}).fetch1())
    subject = dict((Subject & {"subject_id": subject_info["Identifier"]}).fetch1())

    dataset_id = dataset["dataset_id"]
    cell_id = "-".join(dataset_id.split("-")[:4]) if len(dataset_id) > 4 else dataset_id
    cell_props = Cell.get_template_tuple()
    cell_props["subject_id"] = subject["subject_id"]
    cell_props["cell_id"] = cell_id
    cell_props["cell_type"] = cell_info["CellType"]
    cell_props["firing_rate"] = firing_rate
    if "Structure" in cell_info.keys():
        cell_props["structure"] = cell_info["Structure"]
    if "BrainRegion" in cell_info.keys():
        cell_props["region"] = cell_info["BrainRegion"]
    if "BrainSubRegion" in cell_info.keys():
        cell_props["subregion"] = cell_info["BrainSubRegion"]
    if "Depth" in cell_info.keys():
        cell_props["depth"] = float(cell_info["Depth"][:-2])
    if "Lateral position" in cell_info.keys():
        cell_props["lateral_pos"] = float(cell_info["Lateral position"][:-2])
    if "Transverse section" in cell_info.keys():
        cell_props["transversal_section"] = float(cell_info["Transverse section"])

    Cell.insert1(cell_props, skip_duplicates=True)

    # multi mach entry
    mm = dict(dataset_id=dataset["dataset_id"], cell_id=cell_props["cell_id"])
    CellDatasetMap.insert1(mm, skip_duplicates=True)


def populate_repros(data_path):
    print("\tImporting RePro(s) of %s" % data_path)
    dset_name = os.path.split(data_path)[-1]
    if len(Dataset & {"dataset_id": dset_name}) != 1:
        return False
    dataset = dict((Dataset & {"dataset_id": dset_name}).fetch1())

    if dataset["has_nix"]:
        print("\t\tscanning nix file")
        nix_files = glob.glob(os.path.join(dataset["data_source"], "*.nix"))
        for nf in nix_files:
            if not Dataset.check_file_integrity(nf):
                print("file is not sane!!!")
                continue
            f = nix.File.open(nf, nix.FileMode.ReadOnly)
            b = f.blocks[0]
            for t in b.tags:
                if "relacs.repro_run" in t.type:
                    rs = t.metadata.find_sections(lambda x: "Run" in x.props)
                    if len(rs) == 0:
                        continue
                    rs = rs[0]
                    rp = Repro.get_template_tuple()
                    rp["run"] = rs["Run"]
                    rp["repro_name"] = rs["RePro"]
                    rp["dataset_id"] = dataset["dataset_id"]
                    rp["repro_id"] = t.name
                    settings = t.metadata.find_sections(lambda x: "settings" in x.type)
                    if len(settings) > 0:
                        rp["settings"] = ut.nix_metadata_to_yaml(settings[0])
                    else:
                        rp["settings"] = ut.nix_metadata_to_yaml(t.metadata)
                    rp["start"] = t.position[0]
                    rp["duration"] = t.extent[0]
                    Repro.insert1(rp, skip_duplicates=True)
            f.close()
            f = None
    else:
        pass
    return True


def drop_tables():
    Dataset.drop()
    Subject.drop()


def populate(datasets):
    for d in datasets:
        if not populate_datasets(d):
            continue
        populate_subjects(d)
        populate_cells(d)
        try:
            populate_repros(d)
        except ():
            print("something went wrong! %s" % d)


if __name__ == "__main__":
    # data_dir = "../../science/high_frequency_chirps/data"
    data_dir = "../high_freq_chirps/data"
    datasets = glob.glob(os.path.join(data_dir, '2018-11-20*'))
    # drop_tables()
    populate(datasets)