import numpy as np
import datajoint as dj
import nixio as nix
import os
import glob
import socket
from fishbook.backend.util import read_info_file, read_dataset_info, read_stimuli_file
from fishbook.backend.util import find_key_recursive, deep_get, find_mtags_for_tag
from fishbook.backend.util import mtag_settings_to_yaml, nix_metadata_to_yaml, mtag_features_to_yaml, progress
import uuid
import yaml

from IPython import embed

# DataJoint option "enable_python_native_blobs"; set here, before any table
# class below is declared.
dj.config["enable_python_native_blobs"] = True
# Module-wide schema object: every ``@schema`` class below is declared in the
# "fish_book" database. locals() is handed over as the namespace in which
# references inside ``definition`` strings (e.g. "-> Subjects") are resolved.
# NOTE(review): creating the schema presumably contacts the database, so
# importing this module needs a configured DataJoint connection — confirm.
schema = dj.schema("fish_book", locals())


@schema
class Datasets(dj.Manual):
    """One row per recording dataset.

    ``dataset_id`` is the name of the dataset directory and ``data_source``
    its absolute path (both are filled in by ``populate_datasets``).
    """
    definition = """ # _Dataset
       dataset_id : varchar(256)
       ----
       data_source : varchar(512) # path to the dataset
       data_host : varchar(512) # fully qualified domain name
       experimenter : varchar(512)
       setup : varchar(128) 
       recording_date : date
       quality : varchar(512)
       comment : varchar(1024)
       duration : float
       has_nix : bool
       """

    @staticmethod
    def get_template_tuple(id=None):
        """Return the stored row for dataset *id*, or a blank template if *id* is None.

        The parameter name shadows the builtin ``id``; it is kept so that
        existing keyword callers keep working.
        """
        if id is not None:
            return dict((Datasets() & {"dataset_id": id}).fetch1())
        return dict(dataset_id=None, data_source="", data_host="", experimenter="", setup="",
                    recording_date=None, quality="", comment="", duration=0.0, has_nix=False)

    @staticmethod
    def get_nix_file(key):
        """Return the path of the dataset's nix file, or None if it is unusable.

        The file is expected at ``<data_source>/<dataset_id>.nix``.

        Args:
            key: restriction selecting exactly one ``Datasets`` row.

        Returns:
            The file path. Returns None if the row carries a truthy ``ignore``
            flag, the file does not exist, or it fails ``check_file_integrity``.
        """
        dset = dict((Datasets() & key).fetch1())
        # The table definition has no ``ignore`` column, so ``dset["ignore"]``
        # always raised KeyError; a missing flag now means "not ignored".
        if dset.get("ignore", False):
            return None
        file_path = os.path.join(dset["data_source"], dset["dataset_id"] + ".nix")
        if not os.path.exists(file_path):
            print("\t No nix file found for path: %s" % dset["data_source"])
            return None
        if not Datasets.check_file_integrity(file_path):
            return None
        return file_path

    @staticmethod
    def check_file_integrity(nix_file):
        """Return True if *nix_file* opens and its first block has a "Recording" section.

        Any exception raised while opening or inspecting the file (e.g. no
        blocks or no metadata) marks the file as not sane instead of
        propagating. The file is always closed again.
        """
        sane = True
        f = None
        try:
            f = nix.File.open(nix_file, nix.FileMode.ReadOnly)
            metadata = f.blocks[0].metadata
            if "Recording" not in metadata.sections:
                # Formerly ``Warning(...)``: that only built an exception object
                # and discarded it, so nothing was ever reported.
                print("\t Could not find Recording section in dataset: %s" % nix_file)
                sane = False
        except Exception:
            # ``except ():`` matched no exception, so broken files crashed callers.
            print("file: %s is NOT SANE!" % nix_file)
            sane = False
        finally:
            if f is not None:
                f.close()
        return sane


@schema
class Subjects(dj.Manual):
    """Recorded animals: one row per subject id with its species."""
    definition = """
    # Subjects
    subject_id : varchar(256)
    ----
    species : varchar(256)
    """

    @staticmethod
    def get_template_tuple(subject_id=None):
        """Return the stored row for *subject_id*, or a blank template if it is None."""
        if subject_id is not None:
            return dict((Subjects() & {"subject_id": subject_id}).fetch1())
        return dict(subject_id=None, species="")

    def make(self, key):
        """Insert the subject described in the nix file of the dataset selected by *key*.

        Args:
            key: restriction selecting one ``Datasets`` row; it is passed to
                ``Datasets.get_nix_file`` and merged into the inserted row.

        Nothing is inserted when the dataset has no usable nix file.
        """
        file_path = Datasets.get_nix_file(key)
        if file_path is None:
            return
        nix_file = nix.File.open(file_path, nix.FileMode.ReadOnly)
        try:
            m = nix_file.blocks[0].metadata
            inserts = Subjects.get_template_tuple()
            subj_info = m["Recording"]["Subject"]
            inserts["subject_id"] = subj_info["Identifier"]
            inserts["species"] = subj_info["Species"][0]
            # NOTE(review): weight, size and eod_frequency are not columns of
            # this table (they live in SubjectProperties) and are discarded
            # by ignore_extra_fields below; consider storing them there.
            inserts["weight"] = subj_info["Weight"]
            inserts["size"] = subj_info["Size"]
            inserts["eod_frequency"] = np.round(subj_info["EOD Frequency"] * 10) / 10
            inserts.update(key)
            # Without ignore_extra_fields the non-column fields above (and the
            # dataset key) made datajoint reject the row.
            self.insert1(inserts, skip_duplicates=True, ignore_extra_fields=True)
        finally:
            nix_file.close()


@schema
class SubjectDatasetMap(dj.Manual):
    """Links subjects to datasets: each row pairs one subject with one dataset.

    Both foreign keys form the primary key. Rows are written by
    ``populate_subjects``.
    """
    definition = """
    # SubjectDatasetMap
    -> Subjects
    -> Datasets
    """


@schema
class SubjectProperties(dj.Manual):
    """Measurements of a subject at one recording date.

    Each row stores weight, size and EOD frequency and is keyed by an
    auto-increment ``id``.
    """
    definition = """
    # _SubjectProperties
    id : int auto_increment
    ----
    -> Subjects
    recording_date : date
    weight : float
    size : float
    eod_frequency : float
    """

    @staticmethod
    def get_template_tuple(id=None):
        """Return a blank row dict, or the stored row when *id* is given.

        A blank template has ``id=None``, so an insert lets the database
        assign the auto-increment value.
        """
        if id is None:
            return {
                "id": None,
                "subject_id": None,
                "recording_date": None,
                "weight": 0.0,
                "size": 0.0,
                "eod_frequency": 0.0,
            }
        stored = SubjectProperties() & {"id": id}
        return dict(stored.fetch1())


@schema
class Cells(dj.Manual):
    """Recorded cells; each cell references exactly one subject."""
    definition = """
    # Table that stores information about recorded cells.
    cell_id : varchar(256)
    ----
    -> Subjects
    cell_type : varchar(256)
    firing_rate : float
    structure : varchar(256)
    region : varchar(256)
    subregion : varchar(256)
    depth : float
    lateral_pos : float
    transversal_section : float
    """

    @staticmethod
    def get_template_tuple(cell_id=None):
        """Return a blank row dict, or the stored row when *cell_id* is given."""
        if cell_id is None:
            return {
                "cell_id": None,
                "subject_id": None,
                "cell_type": "",
                "firing_rate": 0.0,
                "depth": 0.0,
                "region": "",
                "subregion": "",
                "structure": "",
                "lateral_pos": 0.0,
                "transversal_section": 0.0,
            }
        stored = Cells() & {"cell_id": cell_id}
        return dict(stored.fetch1())


@schema
class CellDatasetMap(dj.Manual):
    """Links cells to datasets: each row pairs one dataset with one cell.

    Both foreign keys form the primary key. Rows are written by
    ``populate_cells``.
    """
    definition = """
    # Table that maps recorded cells to datasets
    -> Datasets
    -> Cells
    """


@schema
class Repros(dj.Manual):
    """Runs of relacs RePros, keyed by (repro_id, run, cell_id)."""
    definition = """
    repro_id : varchar(512)     # The name that was given to the RePro run by relacs
    run : smallint              # A counter counting the runs of the ReProp in this dataset
    -> Cells                    # 
    ----
    repro_name : varchar(512)   # The original name of the RePro itself, not any given name by user or relacs
    settings : varchar(3000)    # Yaml formatted string containing the repro settings (tag.metadata in case of a nix file)
    start : float               # The start time of the repro
    duration : float            # The duration of the repro
    """

    @staticmethod
    def get_template_tuple(repro_id=None, cell_id=None, run=None):
        """Return a blank row dict, or the stored row for *repro_id*.

        The primary key is (repro_id, run, cell_id), so the same repro_id can
        occur for several cells or runs. ``fetch1`` requires exactly one
        match; pass *cell_id* and/or *run* to narrow the lookup. Both extra
        filters are optional and are only used when *repro_id* is given, so
        existing single-argument calls behave as before.
        """
        if repro_id is None:
            return dict(repro_id=None, cell_id=None, run=0, repro_name="", settings=None,
                        start=None, duration=None)
        restriction = {"repro_id": repro_id}
        if cell_id is not None:
            restriction["cell_id"] = cell_id
        if run is not None:
            restriction["run"] = run
        return dict((Repros() & restriction).fetch1())


@schema
class Stimuli(dj.Manual):
    """Individual stimuli presented during a RePro run."""
    definition = """
    stimulus_id : varchar(50)
    -> Repros
    ---
    stimulus_index : int
    stimulus_name : varchar(512)
    mtag_id : varchar(50)
    start_time : float
    start_index : int
    duration : float
    settings : varchar(3000)
    """

    @staticmethod
    def get_template_tuple(stimulus_id=None):
        """Return a blank row dict, or the stored row for *stimulus_id*.

        The blank template has no ``mtag_id`` or Repros key fields; callers
        add those before inserting.
        """
        if stimulus_id is None:
            return {
                "stimulus_id": None,
                "stimulus_index": None,
                "stimulus_name": "",
                "start_index": 0,
                "start_time": 0.0,
                "duration": 0.0,
                "settings": None,
            }
        stored = Stimuli & {"stimulus_id": stimulus_id}
        return dict(stored.fetch1())


def populate_datasets(data_path, update=False):
    """Insert the dataset stored in directory *data_path* into ``Datasets``.

    Metadata is read from ``<data_path>/info.dat`` via ``read_dataset_info``.

    Args:
        data_path: directory of a single recording. Its last path component
            becomes ``dataset_id``.
        update: if False, a dataset that is already in the table is reported
            and skipped. If True the function proceeds. The insert uses
            ``skip_duplicates=True``, so an existing row is left unchanged.

    Returns:
        True if the caller should continue importing this dataset. False if
        the path does not exist, no experimenter was found, or the dataset
        is already present and *update* is False.
    """
    if not os.path.exists(data_path):
        return False
    dset_name = os.path.split(data_path)[-1]
    experimenter, rec_date, quality, comment, has_nix, rec_duration, setup = \
        read_dataset_info(os.path.join(data_path, 'info.dat'))
    if not experimenter:
        return False

    inserts = Datasets.get_template_tuple()
    inserts["dataset_id"] = dset_name
    inserts["data_source"] = os.path.abspath(data_path)
    inserts["data_host"] = socket.getfqdn()
    inserts["experimenter"] = experimenter
    inserts["recording_date"] = rec_date
    # Empty entries in info.dat come back as dicts; store them as "".
    inserts["quality"] = quality if not isinstance(quality, dict) else ""
    inserts["comment"] = comment if not isinstance(comment, dict) else ""
    inserts["duration"] = rec_duration
    inserts["setup"] = setup
    inserts["has_nix"] = has_nix
    # Exact-match dict restriction. The former "dataset_id like '%s'" string
    # broke on quotes in the name and treated "_"/"%" as wildcards.
    if not update and len(Datasets & {"dataset_id": dset_name}) > 0:
        print('\t\t %s is already in database!' % dset_name)
        return False
    Datasets().insert1(inserts, skip_duplicates=True)
    return True


def populate_subjects(data_path):
    """Import the subject of the dataset stored in directory *data_path*.

    Reads ``<data_path>/info.dat`` and uses its "Subject" section to insert:

    - a ``Subjects`` row;
    - a ``SubjectDatasetMap`` row linking the subject to the dataset, which
      must already be in ``Datasets``;
    - a ``SubjectProperties`` row, unless an identical one already exists.

    Returns:
        ``(None, None, False)`` when ``info.dat`` is missing, otherwise None.
    """
    print("\tImporting subject(s) of %s" % data_path)
    dset_name = os.path.split(data_path)[-1]
    info_file = os.path.join(data_path, 'info.dat')
    if not os.path.exists(info_file):
        return None, None, False
    info = read_info_file(info_file)

    path = []
    find_key_recursive(info, "Subject", path)
    subj = deep_get(info, path) if len(path) > 0 else {}
    if not isinstance(subj, dict):
        subj = {}

    # Prefer the identifier of the Subject section, then a top-level one.
    # A missing identifier, or an empty one (parsed as a dict), falls back
    # to a dataset-specific placeholder id.
    id_source = subj if "Identifier" in subj.keys() else info
    if "Identifier" in id_source.keys() and not isinstance(id_source["Identifier"], dict):
        subj_id = id_source["Identifier"]
    else:
        subj_id = "unspecified_" + dset_name

    # A missing species used to raise KeyError; an empty one (a dict) is stored as "".
    species = subj["Species"] if "Species" in subj.keys() else ""
    if isinstance(species, dict):
        species = ""

    inserts = Subjects.get_template_tuple()
    inserts["subject_id"] = subj_id
    inserts["species"] = species
    Subjects().insert1(inserts, skip_duplicates=True)

    # multi match entry linking subject and dataset
    dataset = dict((Datasets() & {"dataset_id": dset_name}).fetch1())
    mm = dict(dataset_id=dataset["dataset_id"], subject_id=subj_id)
    SubjectDatasetMap.insert1(mm, skip_duplicates=True)

    # Subject properties. The slices strip a trailing unit suffix from the
    # values (presumably "g", "cm" and "Hz" — confirm against info.dat).
    props = SubjectProperties.get_template_tuple()
    props["subject_id"] = subj_id
    props["recording_date"] = dataset["recording_date"]
    if "Weight" in subj.keys():
        props["weight"] = np.round(float(subj["Weight"][:-1]), 1)
    if "Size" in subj.keys():
        props["size"] = np.round(float(subj["Size"][:-2]), 1)
    if "EOD Frequency" in subj.keys():
        props["eod_frequency"] = np.round(float(subj["EOD Frequency"][:-2]))
    # Compare on everything except the auto-increment id, so that re-imports
    # do not create duplicate property rows.
    existing = {k: v for k, v in props.items() if k != "id"}
    if len(SubjectProperties & existing) == 0:
        SubjectProperties.insert1(props, skip_duplicates=True)


def populate_cells(data_path):
    """Import the cell recorded in the dataset stored in directory *data_path*.

    Reads ``<data_path>/info.dat``, then inserts a ``Cells`` row and a
    ``CellDatasetMap`` row linking it to the dataset. The dataset and its
    subject must already be stored, because both are fetched with
    ``fetch1`` (``populate`` runs ``populate_subjects`` first).

    Returns:
        ``(None, None, False)`` when ``info.dat`` is missing, otherwise None.
    """
    print("\tImporting cell(s) of %s" % data_path)
    dset_name = os.path.split(data_path)[-1]
    info_file = os.path.join(data_path, 'info.dat')
    if not os.path.exists(info_file):
        return None, None, False
    info = read_info_file(info_file)

    path = []
    find_key_recursive(info, "Subject", path)
    subject_info = deep_get(info, path)
    if not isinstance(subject_info, dict):
        subject_info = {}

    path = []
    find_key_recursive(info, "Cell", path)
    cell_info = deep_get(info, path)
    if not isinstance(cell_info, dict):
        cell_info = {}

    path = []
    if find_key_recursive(info, "Firing Rate1", path):
        firing_rate = deep_get(info, path, default=0.0)
    else:
        firing_rate = 0.0
    if isinstance(firing_rate, str):
        # strip a two-character unit suffix (presumably "Hz")
        firing_rate = float(firing_rate[:-2])

    # Same subject-id resolution as populate_subjects: Subject section first,
    # then top level, else a dataset-specific placeholder.
    id_source = subject_info if "Identifier" in subject_info.keys() else info
    if "Identifier" in id_source.keys() and not isinstance(id_source["Identifier"], dict):
        subj_id = id_source["Identifier"]
    else:
        subj_id = "unspecified_" + dset_name
    dataset = dict((Datasets & {"dataset_id": dset_name}).fetch1())
    subject = dict((Subjects & {"subject_id": subj_id}).fetch1())

    dataset_id = dataset["dataset_id"]
    # The cell id is the first four "-"-separated fields of the dataset id.
    # The old ``len(dataset_id) > 4`` guard tested the string length and never
    # changed the result.
    cell_id = "-".join(dataset_id.split("-")[:4])

    cell_type = cell_info.get("CellType", "")
    if isinstance(cell_type, dict):  # empty entry in info.dat
        cell_type = ""

    cell_props = Cells.get_template_tuple()
    cell_props["subject_id"] = subject["subject_id"]
    cell_props["cell_id"] = cell_id
    cell_props["cell_type"] = cell_type
    cell_props["firing_rate"] = firing_rate
    if "Structure" in cell_info.keys():
        cell_props["structure"] = cell_info["Structure"]
    if "BrainRegion" in cell_info.keys():
        cell_props["region"] = cell_info["BrainRegion"]
    if "BrainSubRegion" in cell_info.keys():
        cell_props["subregion"] = cell_info["BrainSubRegion"]
    # Depth and lateral position carry a two-character unit suffix that is cut off.
    if "Depth" in cell_info.keys():
        cell_props["depth"] = float(cell_info["Depth"][:-2])
    if "Lateral position" in cell_info.keys():
        cell_props["lateral_pos"] = float(cell_info["Lateral position"][:-2])
    if "Transverse section" in cell_info.keys():
        cell_props["transversal_section"] = float(cell_info["Transverse section"])
    Cells.insert1(cell_props, skip_duplicates=True)

    # multi match entry linking dataset and cell
    mm = dict(dataset_id=dataset["dataset_id"], cell_id=cell_props["cell_id"])
    CellDatasetMap.insert1(mm, skip_duplicates=True)


def _mtag_positions_extents(mt):
    """Return ``(positions, extents)`` of multi-tag *mt* as 2-D arrays, one row per segment.

    ``np.atleast_2d`` turns a 1-D vector of n entries into shape (1, n). In
    that case both arrays are transposed, so that their first axis matches
    ``mt.positions`` again. A multi-tag without extents gets zero extents.
    """
    positions = np.atleast_2d(mt.positions[:])
    extents = None if mt.extents is None else np.atleast_2d(mt.extents[:])
    if mt.positions.shape[0] != positions.shape[0]:
        positions = positions.T
        if extents is not None:
            extents = extents.T
    if extents is None:
        extents = np.zeros_like(positions)
    return positions, extents


def scan_nix_file_for_repros(dataset):
    """Import RePro runs and their stimuli from the nix file(s) of *dataset*.

    Every ``*.nix`` file in ``dataset["data_source"]`` that passes
    ``Datasets.check_file_integrity`` is scanned:

    - each tag of type "relacs.repro_run" in its first block becomes a
      ``Repros`` row;
    - each multi-tag segment returned by ``find_mtags_for_tag`` for that tag
      becomes a ``Stimuli`` row.

    All runs are attributed to the first cell linked to the dataset.

    Args:
        dataset: a ``Datasets`` row as a dict.
    """
    print("\t\tscanning nix file")
    cell_query = Cells * CellDatasetMap * (Datasets & {"dataset_id": dataset["dataset_id"]})
    cell_id = cell_query.fetch("cell_id", limit=1)[0]
    nix_files = glob.glob(os.path.join(dataset["data_source"], "*.nix"))
    for nf in nix_files:
        if not Datasets.check_file_integrity(nf):
            print("\t\tfile is not sane!!!")
            continue
        f = nix.File.open(nf, nix.FileMode.ReadOnly)
        try:
            b = f.blocks[0]
            repro_runs = [t for t in b.tags if "relacs.repro_run" in t.type]
            total = len(repro_runs)
            for i, t in enumerate(repro_runs):
                if t.metadata is None:
                    continue
                run_sections = t.metadata.find_sections(lambda x: "Run" in x.props)
                # Check for an empty result *before* indexing; indexing first
                # raised IndexError for runs without a "Run" property.
                if len(run_sections) == 0:
                    continue
                rs = run_sections[0]

                progress(i + 1, total, "Scanning repro run %s" % rs["RePro"])

                rp = Repros.get_template_tuple()
                rp["run"] = rs["Run"]
                rp["repro_name"] = rs["RePro"]
                rp["cell_id"] = cell_id
                rp["repro_id"] = t.name
                # Prefer a dedicated settings section, else dump all tag metadata.
                settings = t.metadata.find_sections(lambda x: "settings" in x.type)
                rp["settings"] = nix_metadata_to_yaml(settings[0] if len(settings) > 0 else t.metadata)
                rp["start"] = t.position[0]
                rp["duration"] = t.extent[0]
                Repros.insert1(rp, skip_duplicates=True)

                # Key columns of the stored repro row, copied into every stimulus.
                stored = dict((Repros & dict(repro_id=rp["repro_id"], cell_id=cell_id)).fetch1())
                repro = {k: v for k, v in stored.items()
                         if k not in ("settings", "repro_name", "start", "duration")}

                mtags, positions = find_mtags_for_tag(b, t)
                # mtag id -> (metadata yaml, positions, extents), so that a
                # multi-tag occurring repeatedly is converted only once.
                mtag_cache = {}
                for j, mt in enumerate(mtags):
                    if mt.id not in mtag_cache:
                        mdata_yaml = nix_metadata_to_yaml(mt.metadata)
                        mtag_cache[mt.id] = (mdata_yaml,) + _mtag_positions_extents(mt)
                    mdata_yaml, mt_positions, mt_extents = mtag_cache[mt.id]
                    for p in positions[j]:
                        stim = Stimuli.get_template_tuple()
                        stim["stimulus_id"] = str(uuid.uuid1())
                        stim["stimulus_index"] = p
                        stim["start_time"] = mt_positions[p, 0]
                        stim["start_index"] = -1
                        stim["duration"] = mt_extents[p, 0]
                        stim["settings"] = mtag_features_to_yaml(mt, p, mdata_yaml)
                        stim["mtag_id"] = mt.id
                        stim["stimulus_name"] = mt.name
                        stim.update(repro)
                        Stimuli.insert1(stim, skip_duplicates=True)
                print(" " * 120, end="\r")
            print("\n")
        finally:
            f.close()


def _duration_in_seconds(value):
    """Convert a duration setting such as ``"100ms"`` or ``"0.5s"`` to seconds.

    Plain numbers and unit-less numeric strings are taken as seconds.
    Non-numeric, non-string values (e.g. an empty setting parsed as a dict,
    or None) yield 0.0. A malformed number still raises ValueError.
    """
    if isinstance(value, (int, float)):
        return float(value)
    if not isinstance(value, str):
        return 0.0
    if "ms" in value:
        # Convert to seconds: storing the raw millisecond number mixed units
        # with the "s" branch below and with the nix import.
        return float(value[:value.index("ms")]) / 1000.0
    if "s" in value:
        return float(value[:value.index("s")])
    return float(value)


def scan_folder_for_repros(dataset):
    """Import RePro runs and stimuli of *dataset* from its stimuli file (no nix file).

    ``read_stimuli_file`` yields, per RePro run, its settings and a mapping
    whose values become the stimuli's ``start_index``. Repeated RePros get a
    running counter appended to their name to form ``repro_id``. Start times
    are not available here and are stored as 0.

    Args:
        dataset: a ``Datasets`` row as a dict.
    """
    print("\t\tNo nix-file, scanning directory!")
    repro_settings, stim_indices = read_stimuli_file(dataset["data_source"])
    repro_counts = {}
    cell_query = Cells * CellDatasetMap * (Datasets & {"dataset_id": dataset["dataset_id"]})
    cell_id = cell_query.fetch("cell_id", limit=1)[0]
    for rs, si in zip(repro_settings, stim_indices):
        rp = Repros.get_template_tuple()
        path = []
        if not find_key_recursive(rs, "run", path):
            find_key_recursive(rs, "Run", path)
        rp["run"] = deep_get(rs, path, 0) if len(path) > 0 else -1

        path = []
        if not find_key_recursive(rs, "repro", path):
            find_key_recursive(rs, "RePro", path)
        rp["repro_name"] = deep_get(rs, path, "None")

        # Number repeated RePros so that each run gets a distinct repro_id.
        repro_counts[rp["repro_name"]] = repro_counts.get(rp["repro_name"], 0) + 1
        rp["cell_id"] = cell_id
        rp["repro_id"] = rp["repro_name"] + str(repro_counts[rp["repro_name"]])
        rp["start"] = 0.
        rp["duration"] = 0.
        settings_yaml = yaml.dump(rs).replace("'", "")
        rp["settings"] = settings_yaml
        Repros.insert1(rp, skip_duplicates=True)

        # Key columns of the stored repro row, copied into every stimulus.
        stored = dict((Repros & dict(repro_id=rp["repro_id"], cell_id=cell_id)).fetch1())
        repro = {k: v for k, v in stored.items()
                 if k not in ("settings", "repro_name", "start", "duration")}

        # The stimulus duration depends only on the RePro settings, so it is
        # looked up once per RePro instead of once per stimulus.
        path = []
        if not find_key_recursive(rs, "duration", path):
            find_key_recursive(rs, "Duration", path)
        stim_duration = _duration_in_seconds(deep_get(rs, path, None)) if len(path) > 0 else 0.0

        total = len(si)
        for j, k in enumerate(si.keys()):
            progress(j + 1, total, "scanning repro %s" % rp["repro_name"])
            stim = Stimuli.get_template_tuple()
            stim["stimulus_id"] = str(uuid.uuid1())
            stim["stimulus_index"] = j
            stim["start_time"] = 0.
            stim["start_index"] = int(si[k])
            stim["duration"] = stim_duration
            stim["settings"] = settings_yaml
            stim["mtag_id"] = ""
            stim["stimulus_name"] = ""
            stim.update(repro)
            Stimuli.insert1(stim, skip_duplicates=True)
        print(" " * 120, end='\r')


def populate_repros(data_path):
    """Import the RePro runs of the dataset stored in directory *data_path*.

    Datasets with a nix file are scanned with ``scan_nix_file_for_repros``;
    all others with ``scan_folder_for_repros``.

    Returns:
        False if the dataset is not (uniquely) present in ``Datasets``,
        otherwise True.
    """
    print("\tImporting RePro(s) of %s" % data_path)
    dataset_name = os.path.split(data_path)[-1]
    matches = Datasets & {"dataset_id": dataset_name}
    if len(matches) != 1:
        return False
    dataset = dict(matches.fetch1())

    scanner = scan_nix_file_for_repros if dataset["has_nix"] else scan_folder_for_repros
    scanner(dataset)
    return True


def drop_tables():
    """Drop the ``Datasets`` table, then the ``Subjects`` table."""
    for table in (Datasets, Subjects):
        table.drop()


def populate(datasets, update=False):
    """Import each dataset directory in *datasets* into the database.

    For every path the dataset row is inserted first. If
    ``populate_datasets`` reports that the dataset should be skipped, the
    remaining steps are skipped as well. Otherwise subjects and cells are
    imported, then RePros and stimuli.

    Errors while importing RePros are reported and the loop continues with
    the next dataset. Errors in the earlier steps propagate.

    Args:
        datasets: sequence of dataset directory paths.
        update: forwarded to ``populate_datasets``.
    """
    total = len(datasets)
    for i, d in enumerate(datasets):
        print("Importing %i of %i: %s" % (i + 1, total, d))
        if not populate_datasets(d, update):
            continue
        populate_subjects(d)
        populate_cells(d)
        try:
            populate_repros(d)
        except Exception as err:
            # ``except ():`` matched no exception at all, so one broken
            # dataset aborted the whole batch instead of being reported.
            print("\t\tsomething went wrong! %s" % d)
            print("\t\t%s: %s" % (type(err).__name__, err))


if __name__ == "__main__":
    data_dir = "/data/apteronotus"
    # data_dir = "../high_freq_chirps/data"
    # drop_tables()
    # The pattern is relative to data_dir. The old absolute pattern made
    # os.path.join discard data_dir, so changing data_dir had no effect.
    # Sorting gives a deterministic import order.
    datasets = sorted(glob.glob(os.path.join(data_dir, "2010-06-18*")))
    populate(datasets, update=False)