from functools import reduce
import numpy as np
import nixio as nix
import re
import os
import sys
import glob
import datetime as dt
import subprocess
from IPython import embed
# Compatibility shim: if dt.date has no fromisoformat (older Python versions), patch it in
# via the backports package so that dt.date.fromisoformat (used in _get_date) is available.
# NOTE(review): the module-level name `iso` is not used anywhere in the visible code.
try:
    iso = dt.date.fromisoformat
except AttributeError:
    from backports.datetime_fromisoformat import MonkeyPatch
    MonkeyPatch.patch_fromisoformat()


def progress(count, total, status='', bar_len=60):
    """
    Draws a single-line text progress bar on stderr (the line is ended with a carriage
    return, so successive calls overwrite each other).

    modified after https://gist.github.com/vladignatyev/06860ec2040cb497f0f3
    by Vladimir Ignatev published under MIT License

    @param count: number of items processed so far.
    @param total: total number of items; a total of 0 is reported as complete (100%)
                  instead of raising ZeroDivisionError.
    @param status: text appended after the percentage.
    @param bar_len: width of the bar in characters.
    """
    percents = count / total if total else 1.0
    # clamp so that count > total (or a negative count) cannot draw a bar wider than bar_len
    filled_len = min(max(int(percents * bar_len), 0), bar_len)
    bar = '=' * filled_len + '-' * (bar_len - filled_len)

    sys.stderr.write('[%s] %.2f%s ...%s\r' % (bar, percents * 100, '%', status))
    sys.stderr.flush()


def read_info_file(file_name):
    """
    Reads the info file and returns the stored metadata in a dictionary.

    Only lines starting with '#' are considered. A line without ": " opens a new section
    (a trailing ':' is dropped from its name); a "key: value" line is stored in the most
    recently opened section. Key/value lines that appear before any section header are
    stored at the top level. Only the first ": " separates key and value, so values may
    contain ": " themselves. Surrounding double quotes are removed from values.
    The result is therefore at most two levels deep: {section: {key: value}} plus any
    top-level {key: value} entries.

    If reading the file raises a UnicodeDecodeError, the "Experimenter" line of the file is
    overwritten in place (via `sudo sed`) with a fixed experimenter name and the file is read again.

    @param file_name:  The name of the info file.
    @return: dictionary, the stored information.
    """
    root = {}

    try:
        with open(file_name, 'r') as f:
            lines = f.readlines()
    except UnicodeDecodeError:
        print("Replacing experimenter!!!")
        # HACK: modifies the file on disk. Arguments are passed as a list (no shell) so the
        # file name is neither split at spaces nor interpreted by the shell.
        subprocess.check_call(["sudo", "sed", "-i",
                               r"/Experimenter/c\#       Experimenter: Anna Stoeckl",
                               file_name])
        with open(file_name, 'r') as f:
            lines = f.readlines()

    sec = root  # key/value pairs before the first section header land at the top level
    for l in lines:
        if not l.startswith("#"):
            continue
        l = l.strip("#").strip()
        if len(l) == 0:
            continue
        if ": " not in l:  # subsection
            sec = {}
            root[l[:-1] if l.endswith(":") else l] = sec
        else:
            key, value = l.split(': ', 1)
            sec[key.strip()] = value.strip('"').strip()
    return root


def parse_metadata_line(line):
    """
    Splits a metadata comment line of the form "# name: value" into name and value.

    Only the first colon separates name from value, so values that contain colons
    themselves (e.g. times like "12:30:00") are kept intact.

    :param line: the line to parse; it must start with '#' (no leading whitespace).
    :return: tuple (name, value), both stripped of surrounding whitespace.
             (None, None) if the line does not start with '#'. value is None if the line
             has no colon or nothing after it (i.e. it is a section header).
    """
    if not line.startswith("#"):
        return None, None

    name, _, value = line.strip("#").strip().partition(":")
    value = value.strip()
    return name.strip(), (value if value else None)


def has_signal(line, col_names):
    """
    Tells whether a data-table line carries a signal/stimulus.

    Every header column named "signal" (case-insensitive) that has a corresponding
    whitespace-separated field in the line is inspected; the line has a signal if any of
    those fields is non-empty and does not start with "-".

    :param line: the current line of the data table
    :param col_names: The names of the table header columns
    :return: True if at least one signal entry is set, False otherwise
    """
    fields = line.split()
    signal_columns = (idx for idx, name in enumerate(col_names)
                      if name.lower() == "signal" and idx < len(fields))
    # [:1] is "" for an empty field, so both "empty" and "starts with '-'" count as no signal
    return any(fields[idx].strip()[:1] not in ("", "-") for idx in signal_columns)


def parse_table(lines, start_index):
    """
    Parses the data table that starts at lines[start_index] (the "#Key" line).

    The column names are taken from the line three lines below start_index: its first
    character (presumably the '#') is dropped and the columns are separated by two or more
    whitespace characters. Comment lines ('#') are skipped and the table ends at the first
    empty line. Data lines before the first line that carries a signal (see has_signal)
    are ignored; that line and every data line after it are recorded.

    :param lines: all lines of the stimuli file (as returned by readlines()).
    :param start_index: index of the "#Key" line in lines.
    :return: tuple (data_indices, next_index). data_indices maps a running count
             (0, 1, 2, ...) to the first whitespace-separated field of each recorded data
             line. next_index is the index just past the terminating empty line, or
             len(lines) if the file ends first.
    """
    data_indices = {}
    stim_count = 0
    header_index = start_index + 3
    # A truncated file may end before the column-name line. Without names no line can carry
    # a signal, so nothing is recorded instead of raising IndexError.
    if header_index < len(lines):
        names = re.split(r'\s{2,}', lines[header_index][1:].strip())
    else:
        names = []

    index = start_index
    while index < len(lines):
        l = lines[index].strip()
        index += 1
        if not l:  # an empty line terminates the table
            break
        if l.startswith("#"):  # ignore
            continue
        # recording starts with the first signal line and then includes every data line
        if stim_count > 0 or has_signal(l, names):
            data_indices[stim_count] = l.split()[0]
            stim_count += 1
    return data_indices, index


def read_stimuli_file(dataset):
    """
    Parses the "stimuli.dat" file of a dataset into per-run settings and stimulus indices.

    Metadata comment lines ("# name: value") are grouped into sections: a line with a name
    but no value opens a new section, and the following key/value lines are collected into it.
    Every line starting with "#Key" (case-insensitive) starts a data table, which is parsed by
    parse_table. The table closes the current run, and the sections collected for it are stored.

    :param dataset: path of the dataset directory that contains "stimuli.dat".
    :return: tuple (repro_settings, stimulus_indices). repro_settings is a list with one dict
             per table, mapping section names to {name: value} dicts. stimulus_indices is a
             list with the data-index dict that parse_table returned for each table. Both are
             in file order.
    """
    repro_settings = []
    stimulus_indices = []
    settings = {}
    with open(os.path.join(dataset, 'stimuli.dat'), 'r') as f:
        lines = f.readlines()

    index = 0
    current_section = None
    current_section_name = ""
    while index < len(lines):
        l = lines[index].strip()
        if len(l) == 0:
            index += 1
        elif l.lower().startswith("#key"):  # table data coming
            data, index = parse_table(lines, index)
            # we are done with this repro run
            stimulus_indices.append(data)
            if current_section is not None:
                settings[current_section_name] = current_section.copy()
            repro_settings.append(settings.copy())
            current_section = None
            settings = {}
        elif l.startswith("#"):
            name, value = parse_metadata_line(l)
            # advance before any `continue`; otherwise a line without a name (e.g. a bare "#")
            # would be processed forever
            index += 1
            if not name:
                continue
            if not value:  # section header
                if current_section:
                    settings[current_section_name] = current_section.copy()
                current_section = {}
                current_section_name = name
            else:
                if current_section is None:  # key/value before any header in this run
                    current_section = {}
                current_section[name] = value
        else:
            # data lines, ignore them here
            index += 1
    return repro_settings, stimulus_indices


def find_key_recursive(dictionary, key, path=None):
    """
    Depth-first search for key in a (possibly nested) dictionary.

    A key at the current level wins over nested ones; nested dictionaries are searched in
    iteration order and the first match is used. On success, the keys leading to the match
    (outermost first, ending with key itself) are appended to path. On failure path is left
    unchanged.

    :param dictionary: the dictionary to search.
    :param key: the key to look for.
    :param path: optional list that receives the key path; a fresh list is used when omitted.
    :return: True if the key was found, False otherwise.
    """
    assert(isinstance(dictionary, dict))
    if path is None:
        # a mutable default list would be shared between calls and leak earlier results
        path = []
    if key in dictionary:
        path.append(key)
        return True
    start = len(path)
    for k, v in dictionary.items():
        if isinstance(v, dict) and find_key_recursive(v, key, path):
            # the recursive call appended the inner part of the path; put this level's key first
            path.insert(start, k)
            return True
    return False


def deep_get(dictionary, keys, default=None):
    """
    Follows keys one level at a time into a nested dictionary.

    :param dictionary: the dictionary to start from.
    :param keys: list of keys, outermost first.
    :param default: value used whenever a key is missing or the current value is not a dict;
                    the walk continues from it for any remaining keys.
    :return: the value reached after the last key (dictionary itself if keys is empty).
    """
    assert(isinstance(dictionary, dict))
    assert(isinstance(keys, list))
    node = dictionary
    for key in keys:
        if isinstance(node, dict):
            node = node.get(key, default)
        else:
            node = default
    return node


def _get_string(dictionary: dict, key:str, alt_key=None, default=None):
    """
    Looks up the value stored under key (or under alt_key if key is absent) anywhere in a
    nested dictionary such as the one returned by read_info_file.

    :param dictionary: the (nested) dictionary to search.
    :param key: the key to look for first.
    :param alt_key: fallback key, used only when key is not found.
    :param default: returned when neither key is found.
    :return: the stored value. A matching section (a dict) is returned as-is when default is
             falsy, and replaced by default when default is truthy.
    """
    p = []
    value = default
    find_key_recursive(dictionary, key, p)
    if len(p) == 0 and alt_key:
        find_key_recursive(dictionary, alt_key, p)
    # Only look the value up if a key was actually found: deep_get with an empty path
    # returns the whole dictionary instead of the default.
    if len(p) > 0:
        value = deep_get(dictionary, p, default)
    # A falsy default lets section dicts through; read_dataset_info relies on this for "Setup".
    if default and value != default and isinstance(value, dict):
        value = default
    return value


def _get_date(dictionary: dict, key: str, alt_key=None, default=None):
    """
    Looks up a date string stored under key (or under alt_key if key is absent) anywhere
    in a nested dictionary and parses it with datetime.date.fromisoformat.

    :param dictionary: the (nested) dictionary to search.
    :param key: the key to look for first.
    :param alt_key: fallback key, used only when key is not found.
    :param default: returned when neither key is found or the stored value is not a string
                    (e.g. a section dict).
    :return: the parsed datetime.date, or default.
    :raises ValueError: if the stored string is not a valid ISO date.
    """
    p = []
    find_key_recursive(dictionary, key, p)
    if len(p) == 0 and alt_key:
        find_key_recursive(dictionary, alt_key, p)
    if len(p) == 0:
        # without this check deep_get(dictionary, []) would hand the whole dict to fromisoformat
        return default
    raw = deep_get(dictionary, p, default)
    if not isinstance(raw, str):
        return default
    return dt.date.fromisoformat(raw)


def read_dataset_info(info_file):
    """
    Collects summary information about a recording from its info file.

    @param info_file: path of the info file of a dataset.
    @return: tuple (experimenter, recording date, recording quality, comment, has_nix,
             recording duration, setup identifier). has_nix tells whether a "*.nix" file
             exists in the info file's directory. The recording duration is the string from
             the info file with a trailing "min" removed (e.g. "5.3"), or 0.0 if it is
             missing. If the info file does not exist, ("", None, "", "", False, 0.0, "")
             is returned.
    """
    exp = ""
    quality = ""
    comment = ""
    rec_date = None
    has_nix = False
    setup = ""
    rec_duration = 0.0
    if not os.path.exists(info_file):
        return exp, rec_date, quality, comment, has_nix, rec_duration, setup
    # dirname("info.dat") is "" (current directory) instead of "/" (filesystem root);
    # glob.escape protects directory names that contain *, ? or [
    nix_pattern = os.path.join(glob.escape(os.path.dirname(info_file)), "*.nix")
    has_nix = len(glob.glob(nix_pattern)) > 0
    info = read_info_file(info_file)
    exp = _get_string(info, "Experimenter")
    rec_date = _get_date(info, "Date")
    quality = _get_string(info, "Recording quality")
    comment = _get_string(info, "Comment", default="")
    # "Recording duratio" is accepted as fallback key (presumably misspelled in some files)
    rec_duration = _get_string(info, "Recording duration", "Recording duratio", default=0.0)

    if isinstance(rec_duration, str):
        rec_duration = rec_duration.strip()
        # only strip the unit if it really is the trailing "min"
        if rec_duration.endswith("min"):
            rec_duration = rec_duration[:-3].strip()
    elif isinstance(rec_duration, dict):
        rec_duration = 0.0
    setup_info = _get_string(info, "Setup", default=None)
    if setup_info and isinstance(setup_info, dict):
        setup = _get_string(setup_info, "Identifier")
    return exp, rec_date, quality, comment, has_nix, rec_duration, setup


def nix_metadata_to_dict(section):
    """Converts a nix.Section and all its subsections to a (nested) dictionary.

    Each property becomes an entry mapping its name to the list of its values; each
    subsection becomes an entry mapping its name to its own converted dictionary (a
    subsection whose name equals a property name replaces that property entry). Additional
    information such as unit, definition etc. is discarded.

    Args:
        section (nix.Section): The section

    Returns:
        dict: the dictionary containing the section info.
    """
    converted = {prop.name: [entry.value for entry in prop.values] for prop in section.props}
    converted.update((sub.name, nix_metadata_to_dict(sub)) for sub in section.sections)
    return converted


def nix_metadata_to_yaml(section, cur_depth=0, val_count=1):
    """Convert a nix section and all its subsections to a YAML-like string.

    The section is written as "name:" followed by one "name: value" line per property.
    Subsections are appended recursively, one indentation level deeper. Indentation uses
    tab characters.
    NOTE(review): YAML does not allow tabs for indentation, so strict YAML parsers will
    likely reject this output.

    Args:
        section (nix.section.SectionMixin): the section to convert.
        cur_depth (int, optional): indentation level (number of tabs) of the section line.
            Defaults to 0.
        val_count (int, optional): a property with more than one value is written as a
            "[v1, v2, ...]" list only if val_count > 1; otherwise its value is left empty.
            Properties with exactly one value are always written. The setting also applies
            to all nested subsections. Defaults to 1.

    Returns:
        str: the YAML-like text, one newline-terminated line per section and property.
    """
    assert(isinstance(section, nix.section.SectionMixin))
    yaml = "%s%s:\n" % ("\t" * cur_depth, section.name)
    for p in section.props:
        val_str = ""
        if val_count > 1 and len(p.values) > 1:
            val_str = "[" + ', '.join([v.to_string() for v in p.values]) + "]"
        elif len(p.values) == 1:
            val_str = p.values[0].to_string()
        yaml += "%s%s: %s\n" % ("\t" * (cur_depth + 1), p.name, val_str)
    for s in section.sections:
        # pass val_count on so nested sections are rendered with the same setting
        yaml += nix_metadata_to_yaml(s, cur_depth + 1, val_count)
    return yaml


def find_mtags_for_tag(block, tag):
    """
        Finds those multi tags and the respective positions within that match to a certain
        repro run.

        A multi-tag position matches if its segment (position to position + extent) lies
        completely inside the tag's segment in every dimension of the tag. Missing extents
        (of the tag or of a multi tag) are treated as zero. One-dimensional position/extent
        arrays are treated as one position per entry.

        @:param block: the block whose multi tags are searched.
        @:param tag: the tag marking the repro run.
        @:returns list of mtags, list of mtag positions (for every returned mtag, the list of
                  indices of its matching positions)
    """
    assert(isinstance(block, nix.pycore.block.Block))
    assert(isinstance(tag, nix.pycore.tag.Tag))
    mtags = []
    indices = []
    tag_start = np.atleast_1d(tag.position)
    tag_extent = np.atleast_1d(tag.extent) if tag.extent is not None else np.zeros(0)
    if tag_extent.size == 0:
        tag_extent = np.zeros_like(tag_start)
    tag_end = tag_start + tag_extent
    n_dims = len(tag_start)
    for mt in block.multi_tags:
        # read positions/extents once per multi tag instead of once per position
        positions = np.asarray(mt.positions[:])
        if positions.ndim == 1:
            positions = positions[:, np.newaxis]
        extents = None
        if mt.extents is not None:
            extents = np.asarray(mt.extents[:])
            if extents.ndim == 1:
                extents = extents[:, np.newaxis]
        in_tag_positions = []
        for i in range(positions.shape[0]):
            mt_start = np.atleast_1d(positions[i, :])[:n_dims]
            mt_end = mt_start if extents is None else mt_start + np.atleast_1d(extents[i, :])[:n_dims]
            # Containment must hold in all dimensions at once; a per-dimension check would
            # add an index once per matching dimension and accept partial overlaps.
            if np.all(mt_start >= tag_start) and np.all(mt_end <= tag_end):
                in_tag_positions.append(i)
        if len(in_tag_positions) > 0:
            mtags.append(mt)
            indices.append(in_tag_positions)
    return mtags, indices


def mtag_features_to_yaml(mtag, pos_index, section_yaml=None):
    """
    Appends the feature values of one multi-tag position to a YAML-like string.

    For every feature, one tab-indented line "name: value unit" is added. The name is the
    feature data's label if set, otherwise its name; the unit is omitted if not set.
    Single-element data is written as a scalar; larger data is squeezed, rounded to 10
    decimals and written as a comma-separated list in brackets.

    :param mtag: the multi tag whose features are read.
    :param pos_index: index of the position whose feature data is retrieved.
    :param section_yaml: text to append to (e.g. the output of nix_metadata_to_yaml);
                         None is treated as "".
    :return: the combined string.
    """
    yaml = section_yaml if section_yaml is not None else ""
    # index-based loop: the same index i is passed to retrieve_feature_data
    for i in range(len(mtag.features)):
        feat = mtag.features[i]
        feat_name = feat.data.label if feat.data.label else feat.data.name
        feat_unit = feat.data.unit if feat.data.unit else ""
        feat_data = mtag.retrieve_feature_data(pos_index, i)

        if np.prod(feat_data.shape) == 1:
            # ravel so that e.g. a (1, 1) result is written as a scalar and not as "[x]"
            feat_content = "%s %s" % (np.asarray(feat_data[:]).ravel()[0], feat_unit)
        else:
            feat_data = np.round(np.squeeze(feat_data), 10)
            feat_content = "[" + ','.join(map(str, feat_data[:])) + "] %s" % feat_unit
        yaml += "\t%s: %s\n" % (feat_name, feat_content)

    return yaml


def mtag_settings_to_yaml(mtag, pos_index):
    """
    Renders a multi tag's metadata plus its feature values at one position as YAML-like text.

    :param mtag: the multi tag.
    :param pos_index: index of the position whose feature data is included; must satisfy
                      0 <= pos_index < number of positions.
    :return: the metadata text from nix_metadata_to_yaml (empty if the multi tag has no
             metadata), followed by the feature lines from mtag_features_to_yaml.
    """
    assert isinstance(mtag, nix.pycore.multi_tag.MultiTag)
    assert 0 <= pos_index < mtag.positions.shape[0]

    metadata_yaml = "" if mtag.metadata is None else nix_metadata_to_yaml(mtag.metadata)
    return mtag_features_to_yaml(mtag, pos_index, metadata_yaml)


if __name__ == "__main__":
    # Example for inspecting the nix helpers interactively:
    #   nix_file = "../../science/high_freq_chirps/data/2018-11-09-aa-invivo-1/2018-11-09-aa-invivo-1.nix"
    #   f = nix.File.open(nix_file, nix.FileMode.ReadOnly)
    #   b = f.blocks[0]
    #   yml = nix_metadata_to_yaml(b.tags[0].metadata)
    #   print(yml)
    #   print("-"* 80)
    #   print(nix_metadata_to_yaml(b.metadata))
    #   embed()
    #   f.close()
    #
    # Usage: python <this file> [dataset_dir]
    # Parses the dataset's stimuli.dat and opens an IPython shell to inspect the result.
    dataset = sys.argv[1] if len(sys.argv) > 1 else "/data/apteronotus/2012-03-23-ad"
    settings = read_stimuli_file(dataset)  # tuple (repro_settings, stimulus_indices)
    embed()