from functools import reduce
import numpy as np
import nixio as nix
import re
import os
import glob
import datetime as dt
import subprocess
from IPython import embed


def read_info_file(file_name):
    """
    Reads the info file and returns the stored metadata in a dictionary. The dictionary may be nested.

    Only lines starting with "#" are considered. A line without ": " opens a new
    subsection (a trailing ":" is dropped from its name). A "key: value" line is
    stored in the most recently opened subsection, or at the top level if no
    subsection has been opened yet. Surrounding double quotes are stripped from values.

    @param file_name:  The name of the info file.
    @return: dictionary, the stored information.
    """
    root = {}

    try:
        with open(file_name, 'r') as f:
            lines = f.readlines()
    except UnicodeDecodeError:
        # NOTE(review): this workaround rewrites the file in place (via sudo). It replaces the
        # whole Experimenter line with a hard-coded name, assuming the undecodable characters
        # are on that line. Decoding with a fallback encoding would be less invasive.
        print("Replacing experimenter!!!")
        # Arguments are passed as a list (no shell), so file names containing spaces or
        # shell metacharacters are neither split nor interpreted.
        subprocess.check_call(["sudo", "sed", "-i",
                               r"/Experimenter/c\#       Experimenter: Anna Stoeckl",
                               file_name])
        with open(file_name, 'r') as f:
            lines = f.readlines()
    # Key/value lines that appear before any subsection header go to the top level
    # (previously they raised a NameError because no section existed yet).
    sec = root
    for l in lines:
        if not l.startswith("#"):
            continue
        l = l.strip("#").strip()
        if len(l) == 0:
            continue
        if ": " not in l:  # subsection
            sec = {}
            root[l[:-1] if l.endswith(":") else l] = sec
        else:
            # split only at the first ": " so that values containing ": " stay intact
            key, value = l.split(': ', 1)
            sec[key.strip()] = value.strip('"').strip()
    return root


def parse_metadata_line(line):
    """
    Splits a "# name: value" metadata line into its name and value.

    @param line: a single (stripped) line of a relacs data file.
    @return: tuple (name, value). (None, None) if the line does not start with "#".
             value is None if the line has no colon or nothing after the first colon
             (e.g. a section header like "# Settings:").
    """
    if not line.startswith("#"):
        return None, None

    line = line.strip("#").strip()
    # Split only at the first colon: values such as times ("12:30:00") contain colons
    # themselves. Taking the last fragment, as before, returned just "00".
    name, _, value = line.partition(":")
    value = value.strip()
    return name.strip(), (value if value else None)


def has_signal(line, col_names):
    """
    Tells whether a table row carries a stimulus in any of its "signal" columns.

    :param line: one row of the data table; values are separated by whitespace
    :param col_names: the table header's column names, in column order
    :return: True if some column named "signal" (case-insensitive) exists in the row and
             its value is non-empty and does not start with "-", otherwise False
    """
    cells = line.split()
    signal_columns = (idx for idx, col in enumerate(col_names) if col.lower() == "signal")
    return any(
        idx < len(cells)
        and cells[idx].strip() != ""
        and not cells[idx].strip().startswith("-")
        for idx in signal_columns
    )


def parse_table(lines, start_index):
    """
    Reads the data table that starts at lines[start_index]. It collects the first column of
    every row, starting with the first row that carries a stimulus (see has_signal).

    The column names are taken from lines[start_index + 3], with its first character
    (presumably the leading "#") removed. They are split at runs of two or more whitespace
    characters. Lines starting with "#" are skipped. The table ends at the first empty line.

    :param lines: all lines of the file, as returned by readlines()
    :param start_index: index of the table's first line (the "#Key" line)
    :return: tuple (data_indices, next_index). data_indices maps a running row counter
             (0, 1, ...) to the first whitespace-separated token of that row. next_index is
             the index after the terminating empty line, or len(lines).
    """
    header = lines[start_index + 3][1:].strip()
    names = re.split(r'\s{2,}', header)
    data_indices = {}
    pos = start_index
    total = len(lines)
    while pos < total:
        row = lines[pos].strip()
        pos += 1
        if row.startswith("#"):
            continue
        if not row:
            break
        # rows before the first stimulus are dropped; once one is found, every row counts
        if data_indices or has_signal(row, names):
            data_indices[len(data_indices)] = row.split()[0]
    return data_indices, pos


def read_stimuli_file(dataset):
    """
    Parses the 'stimuli.dat' file of a dataset directory.

    The file is read as a sequence of metadata blocks, each followed by a data table.
    Metadata consists of "# section:" header lines and "# name: value" lines. A table is
    introduced by a line starting with "#key" (case-insensitive). Each metadata-block/table
    pair is treated as one repro run.

    @param dataset: path to the dataset directory that contains 'stimuli.dat'.
    @return: tuple (repro_settings, stimulus_indices). repro_settings is a list with one
             dict per run (section name -> {name: value}). stimulus_indices is a list with
             the dict returned by parse_table for the same run.
    """
    repro_settings = []
    stimulus_indices = []
    settings = {}
    with open(os.path.join(dataset, 'stimuli.dat'), 'r') as f:
        lines = f.readlines()
    index = 0
    current_section = None
    current_section_name = ""
    while index < len(lines):
        l = lines[index].strip()
        if len(l) == 0:
            index += 1
        elif l.startswith("#") and "key" not in l.lower():
            # NOTE: comment lines that mention "key" anywhere are not parsed as metadata.
            # Only those starting with "#key" open a table (below); the rest are skipped.
            name, value = parse_metadata_line(l)
            # advance before any 'continue'; previously an empty name looped forever
            index += 1
            if not name:
                continue
            if not value:
                if current_section:
                    settings[current_section_name] = current_section.copy()
                current_section = {}
                current_section_name = name
            else:
                # a value before any section header would otherwise hit None
                if current_section is None:
                    current_section = {}
                current_section[name] = value
        elif l.lower().startswith("#key"):  # table data coming
            data, index = parse_table(lines, index)
            # we are done with this repro run
            stimulus_indices.append(data)
            if current_section is not None:
                settings[current_section_name] = current_section.copy()
            repro_settings.append(settings.copy())
            current_section = None
            current_section_name = ""
            settings = {}
        else:
            # data lines, ignore them here
            index += 1
    return repro_settings, stimulus_indices


def find_key_recursive(dictionary, key, path=None):
    """
    Searches a nested dictionary (depth first) for the first occurrence of key.

    @param dictionary: the (possibly nested) dictionary to search.
    @param key: the key to look for.
    @param path: optional list that receives the key path to the match. The path runs from
                 the outermost key to key itself, e.g. ["a", "b", key], and is appended to
                 whatever the list already contains. The list is left unchanged if key
                 is not found.
    @return: True if key was found, False otherwise.
    """
    assert(isinstance(dictionary, dict))
    # a fresh list per call: the former mutable default accumulated results across calls
    if path is None:
        path = []
    if key in dictionary:
        path.append(key)
        return True
    for k, value in dictionary.items():
        if isinstance(value, dict):
            insert_at = len(path)
            if find_key_recursive(value, key, path):
                # put this level's key in front of the sub-path found below it
                # (insert(-1, ...) reversed the order for paths deeper than two levels)
                path.insert(insert_at, k)
                return True
    return False


def deep_get(dictionary, keys, default=None):
    """
    Follows a list of keys into a nested dictionary.

    @param dictionary: the (possibly nested) dictionary.
    @param keys: list of keys, outermost first.
    @param default: value used whenever a key is missing or a non-dict is reached.
    @return: the value at the end of the key path, or default. An empty key list returns
             dictionary itself.
    """
    assert(isinstance(dictionary, dict))
    assert(isinstance(keys, list))
    node = dictionary
    for k in keys:
        node = node.get(k, default) if isinstance(node, dict) else default
    return node


def read_dataset_info(info_file):
    """
    Collects summary information about a dataset from its info file.

    @param info_file: path to the dataset's info file.
    @return: tuple (experimenter, recording date, recording quality, comment, has_nix).
             experimenter, quality and comment default to "", the date (a datetime.date)
             defaults to None. has_nix tells whether the info file's directory contains
             a *.nix file. If info_file does not exist, only the defaults are returned
             (has_nix=False).
    @raise ValueError: if a "Date" entry exists but is not an ISO date (date.fromisoformat).
    """
    exp = ""
    quality = ""
    comment = ""
    rec_date = None
    has_nix = False
    if not os.path.exists(info_file):
        return exp, rec_date, quality, comment, has_nix
    # os.path.dirname yields "" (the current directory) for a bare file name. The former
    # manual split/join produced "/*.nix" and searched the file system root instead.
    dataset_dir = glob.escape(os.path.dirname(info_file))
    has_nix = len(glob.glob(os.path.join(dataset_dir, "*.nix"))) > 0
    info = read_info_file(info_file)

    def key_path(key):
        # A fresh list for every lookup: the "Comment" lookup used to reuse the list
        # filled by "Recording quality" and so returned the wrong entry.
        path = []
        find_key_recursive(info, key, path)
        return path

    p = key_path("Experimenter")
    if len(p) > 0:
        exp = deep_get(info, p)
    p = key_path("Date")
    if len(p) > 0:
        rec_date = dt.date.fromisoformat(deep_get(info, p))
    p = key_path("Recording quality")
    if len(p) > 0:
        quality = deep_get(info, p)
    p = key_path("Comment")
    if len(p) > 0:
        comment = deep_get(info, p, default="")

    return exp, rec_date, quality, comment, has_nix


def nix_metadata_to_dict(section):
    """
    Converts a nix metadata section into a nested dictionary.

    Each property name maps to the list of the `.value` attributes of its values. Each
    subsection name maps to the dictionary of that subsection, converted recursively.
    A subsection overrides a property of the same name.

    @param section: the nix section to convert.
    @return: the nested dictionary.
    """
    result = {prop.name: [entry.value for entry in prop.values] for prop in section.props}
    result.update((sub.name, nix_metadata_to_dict(sub)) for sub in section.sections)
    return result


def nix_metadata_to_yaml(section, cur_depth=0, val_count=1):
    """
    Renders a nix metadata section and all its subsections as an indented, yaml-like string.

    Each nesting level is indented with one tab. Strict YAML does not allow tabs for
    indentation, so the output is meant for display rather than for a YAML parser.

    @param section: the nix section to render.
    @param cur_depth: indentation depth (number of tabs) of the section's header line.
    @param val_count: if > 1, multi-valued properties are rendered as "[v1, v2, ...]".
                      Otherwise they are written with an empty value. Applies to all
                      subsections as well.
    @return: the rendered string, with one "name: value" entry per line.
    """
    assert(isinstance(section, nix.section.SectionMixin))
    yaml = "%s%s:\n" % ("\t" * cur_depth, section.name)
    for p in section.props:
        val_str = ""
        if val_count > 1 and len(p.values) > 1:
            val_str = "[" + ', '.join([v.to_string() for v in p.values]) + "]"
        elif len(p.values) == 1:
            val_str = p.values[0].to_string()
        yaml += "%s%s: %s\n" % ("\t" * (cur_depth + 1), p.name, val_str)
    for s in section.sections:
        # forward val_count; it used to be dropped, so subsections always used the default
        yaml += nix_metadata_to_yaml(s, cur_depth + 1, val_count)
    return yaml


def find_mtags_for_tag(block, tag):
    """
        Finds those multi tags and the respective positions within that match to a certain
        repro run.

        A multi tag position matches if, in every compared dimension, the interval
        [position, position + extent] lies inside the tag's interval
        [tag.position, tag.position + tag.extent]. A missing extent counts as zero.

        @:param block: the nix block whose multi tags are searched.
        @:param tag: the tag (repro run) that the positions must lie within.
        @:returns list of mtags, list of mtag positions (one list of position indices per mtag)
    """
    assert(isinstance(block, nix.pycore.block.Block))
    assert(isinstance(tag, nix.pycore.tag.Tag))
    mtags = []
    indices = []
    tag_start = np.atleast_1d(tag.position)
    tag_extent = tag.extent
    # Treat a tag without extent as a point instead of failing on None / empty extents.
    if tag_extent is None or np.size(tag_extent) == 0:
        tag_extent = np.zeros_like(tag_start)
    tag_end = tag_start + np.atleast_1d(tag_extent)
    for mt in block.multi_tags:
        # NOTE(review): assumes mt.extents is None when the multi tag has no extents;
        # such positions are treated as points.
        extents = mt.extents
        in_tag_positions = []
        for i in range(mt.positions.shape[0]):
            mt_start = np.atleast_1d(mt.positions[i, :])
            if extents is None:
                mt_end = mt_start
            else:
                mt_end = mt_start + np.atleast_1d(extents[i, :])
            # A position belongs to the tag only if it is contained in *every* dimension.
            # It used to be appended once per matching dimension, which produced duplicates
            # and accepted partial matches. zip compares the dimensions both sides have.
            if all(ms >= ts and me <= te
                   for ms, me, ts, te in zip(mt_start, mt_end, tag_start, tag_end)):
                in_tag_positions.append(i)
        if in_tag_positions:
            mtags.append(mt)
            indices.append(in_tag_positions)
    return mtags, indices


def mtag_settings_to_yaml(mtag, pos_index):
    """
    Renders the settings that apply to one position of a multi tag as a yaml-like string.

    The multi tag's metadata (if any) is rendered with nix_metadata_to_yaml. It is followed
    by one tab-indented "name: value unit" line for each feature whose data for this
    position is one-dimensional. Features with other data shapes are skipped. The name is
    the feature data's label, falling back to its name.

    @param mtag: the nix multi tag.
    @param pos_index: index of the position within the multi tag.
    @return: the rendered string.
    """
    assert(isinstance(mtag, nix.pycore.multi_tag.MultiTag))
    assert(0 <= pos_index < mtag.positions.shape[0])

    chunks = []
    if mtag.metadata is not None:
        chunks.append(nix_metadata_to_yaml(mtag.metadata))
    for feature_index, feature in enumerate(mtag.features):
        values = mtag.retrieve_feature_data(pos_index, feature_index)
        if len(values.shape) != 1:
            continue
        source = feature.data
        label = source.label or source.name
        unit = source.unit or ""
        if values.shape[0] == 1:
            rendered = "%s %s" % (values[0], unit)
        else:
            rendered = "[%s] %s" % (','.join(map(str, values[:])), unit)
        chunks.append("\t%s: %s\n" % (label, rendered))
    return "".join(chunks)


if __name__ == "__main__":
    # Ad-hoc debugging entry point; the paths are specific to the author's machine.
    #
    # Inspecting the metadata of a nix file:
    # nix_file = "../../science/high_freq_chirps/data/2018-11-09-aa-invivo-1/2018-11-09-aa-invivo-1.nix"
    # f = nix.File.open(nix_file, nix.FileMode.ReadOnly)
    # b = f.blocks[0]
    # yml = nix_metadata_to_yaml(b.tags[0].metadata)
    # print(yml)
    # print("-"* 80)
    # print(nix_metadata_to_yaml(b.metadata))
    # embed()
    # f.close()
    dataset = "/Users/jan/zwischenlager/2012-03-23-ad"
    # read_stimuli_file takes the dataset directory and appends 'stimuli.dat' itself.
    # Passing the file path here produced '.../stimuli.dat/stimuli.dat'.
    settings = read_stimuli_file(dataset)
    embed()