import os
import glob
import pandas as pd
import nixio as nix
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score     
from IPython import embed
import multiprocessing
from joblib import Parallel, delayed

from util import firing_rate, despine, extract_am, within_group_distance, across_group_distance
from nix_util import read_baseline, sort_blocks, get_spikes, get_signals, get_chirp_metadata

# Output directory for figures (not referenced elsewhere in this file).
# Input directory searched for cell*.nix recordings; main() also writes its CSV here.
figure_folder, data_folder = "figures", "data"

def get_rates(spike_trains, duration, dt, kernel_width):
    """Convert spike trains into firing rates using a Gaussian kernel of the given size.

    Args:
        spike_trains (sequence): one collection of spike times per trial.
        duration (float): length of the time axis.
        dt (float): time step of the rate estimate.
        kernel_width (float): width of the kernel, passed on to ``firing_rate``.

    Returns:
        np.ndarray: rate matrix of shape (number of trials, number of time samples).
        np.ndarray: the time vector ``np.arange(0.0, duration, dt)``.
    """
    time = np.arange(0.0, duration, dt)
    rates = np.zeros((len(spike_trains), time.size))
    # Every row yielded by iterating ``rates`` is a view, so filling it in place fills the matrix.
    for row, spike_times in zip(rates, spike_trains):
        row[:] = firing_rate(spike_times, duration, kernel_width, dt)
    return rates, time


def get_firing_rate(block_map, df, contrast, condition, kernel_width=0.0005, chirpsize=None):
    """Return the time axis, the firing rates and the spike trains of one stimulus condition.

    Args:
        block_map (dict): maps tuples of stimulus parameters to nix blocks.
        df: the difference frequency.
        contrast: the stimulus contrast.
        condition (str): the chirp condition, e.g. "self", "other", or "no-other".
        kernel_width (float, optional): width of the Gaussian kernel used for the rate
            estimation. Defaults to 0.0005.
        chirpsize (optional): the chirp size. When given, the block is looked up with the
            key (contrast, df, chirpsize, condition), the key layout used by the detection
            functions in this module. When None (default), the legacy key
            (contrast, df, condition) is used.

    Returns:
        np.ndarray: the time vector.
        np.ndarray: the rates with the first dimension representing the trials.
        the spike trains (one entry per trial) as returned by ``get_spikes``.

    Raises:
        KeyError: if no block is stored under the requested key.
    """
    if chirpsize is None:
        key = (contrast, df, condition)
    else:
        key = (contrast, df, chirpsize, condition)
    block = block_map[key]
    spikes = get_spikes(block)
    duration = float(block.metadata["stimulus parameter"]["duration"])
    dt = float(block.metadata["stimulus parameter"]["dt"])

    rates, time = get_rates(spikes, duration, dt, kernel_width)
    return time, rates, spikes


def foreign_fish_detection_beat(block_map, df, cs, all_contrasts, all_conditions, kernel_width=0.0005, cell_name="", store_roc=False):
    """Estimate how well the beat responses can be told apart from responses without another fish.

    For each contrast, firing-rate snippets are cut from the windows between chirps in the
    "no-other" condition (alone) and in the "self" condition (in company, i.e. beat). Each
    window starts 1.5 chirp durations after one chirp. All windows share one length: the mean
    inter-chirp window length, truncated to whole thousandths.
    A ROC analysis then tests whether the distances between alone and beat snippets are larger
    than the distances among the alone snippets.

    Args:
        block_map (dict): maps (contrast, df, chirpsize, condition) tuples to nix blocks.
        df: the difference frequency that should be used.
        cs: the chirp size that should be used.
        all_contrasts (iterable): the contrasts to analyse.
        all_conditions: list of all chirp conditions; not used by this function.
        kernel_width (float, optional): std of the Gaussian kernel. Defaults to 0.0005.
        cell_name (str, optional): name of the cell. Defaults to "".
        store_roc (bool, optional): if True, the full true/false positive rates are added to
            each result (this makes result files huge). Defaults to False.

    Returns:
        list of dict: one entry per contrast with keys cell, detection_task ("beat"), contrast,
        df, kernel_width, chirpsize and auc (area under the ROC curve, in [0, 1]), plus
        true_positives and false_positives when ``store_roc`` is set.
    """
    results = []
    for contrast in all_contrasts:
        alone_block = block_map[(contrast, df, cs, "no-other")]
        company_block = block_map[(contrast, df, cs, "self")]

        # Metadata is read from the "self" block only; all conditions are expected to share it.
        duration, dt, _, chirp_duration, chirp_times = get_chirp_metadata(company_block)

        # Windows between neighbouring chirps, keeping 1.5 chirp durations away from each chirp.
        margin = 1.5 * chirp_duration
        window_starts = np.add(chirp_times, margin)[:-1]
        window_ends = np.subtract(chirp_times, margin)[1:]
        window_length = np.floor(np.mean(np.subtract(window_ends, window_starts)) * 1000) / 1000

        alone_spikes = get_spikes(alone_block)
        company_spikes = get_spikes(company_block)
        alone_rates, _ = get_rates(alone_spikes, duration, dt, kernel_width)
        company_rates, _ = get_rates(company_spikes, duration, dt, kernel_width)

        # One row per (trial, window); trial-major ordering.
        n_trials = alone_rates.shape[0]
        width = int(window_length / dt)
        alone_snippets = np.zeros((len(window_starts) * n_trials, width))  # alone, no chirps
        company_snippets = np.zeros_like(alone_snippets)  # in company, no chirps, just beat
        row = 0
        for trial in range(n_trials):
            for window_start in window_starts:
                first = int(window_start / dt)
                last = first + width
                alone_snippets[row, :] = alone_rates[trial, first:last]
                company_snippets[row, :] = company_rates[trial, first:last]
                row += 1

        within_alone = within_group_distance(alone_snippets)
        alone_vs_company = across_group_distance(alone_snippets, company_snippets)

        # Negatives: each within-group pair once (strict lower triangle); positives: all across pairs.
        negatives = within_alone[np.tril_indices_from(within_alone, -1)]
        positives = alone_vs_company.ravel()
        labels = np.hstack((np.zeros_like(negatives), np.ones_like(positives)))
        scores = np.hstack((negatives, positives))

        entry = {"cell": cell_name, "detection_task": "beat", "contrast": contrast, "df": df,
                 "kernel_width": kernel_width, "chirpsize": cs, "auc": roc_auc_score(labels, scores)}
        if store_roc:
            fpr, tpr, _ = roc_curve(labels, scores, pos_label=1)
            entry["true_positives"] = tpr
            entry["false_positives"] = fpr
        results.append(entry)
    return results


def foreign_fish_detection_chirp(block_map, df, cs, all_contrasts, all_conditions, kernel_width=0.0005, cell_name="", store_roc=False):
    """Tries to detect the presence of a foreign fish from the responses around chirps.

    For each contrast, firing-rate snippets are cut from the responses recorded in the
    "no-other", "self" and "other" conditions. For each task, distances within one group of
    snippets (label 0) are compared to distances across two groups (label 1) by a ROC
    analysis. An AUC close to 1 means the across-group distances are consistently larger,
    so the two situations can be discriminated.

    Snippet groups (section labels used in the comments below):
        a) chirp windows, "no-other" condition (self-chirping alone, soliloquy)
        c) chirp windows, "self" condition (self-chirping in company)
        e) chirp windows, "other" condition (other fish chirping)
        d) windows starting 1.5 chirp durations after each chirp, "other" condition
           (in company, nobody chirping, i.e. beat)

    Discrimination tasks, in the order they are appended for each contrast
    ('detection_task' value: within-group vs. across-group distances):
        "self vs soliloquy":  within a vs. a-c
        "other vs quietness": within d vs. d-e
        "soliliquy vs beat":  within a vs. a-d  (label spelling kept for downstream compatibility)
        "beat vs self":       within d vs. c-d
        "self vs other":      within e vs. c-e

    Args:
        block_map (dict): maps (contrast, df, chirpsize, condition) tuples to nix blocks.
        df: the difference frequency that should be used.
        cs: the chirp size that should be used.
        all_contrasts (iterable): list of all used contrasts.
        all_conditions: list of all chirp conditions; not used by this function.
        kernel_width (float, optional): std of Gaussian kernel. Defaults to 0.0005.
        cell_name (str, optional): name of the cell. Defaults to "".
        store_roc (bool, optional): if True, the full true/false positive rates are added to
            each result (this makes result files huge). Defaults to False.

    Returns:
        list of dict: five entries per contrast (see the tasks above) with keys cell,
        detection_task, contrast, df, kernel_width, chirpsize and auc (area under the ROC
        curve, in [0, 1]), plus true_positives and false_positives when ``store_roc`` is set.

    Raises:
        KeyError: if one of the "no-other", "self" or "other" blocks is missing.
    """

    def roc_entry(task, contrast, within_dist, across_dist):
        # ROC analysis for one task: each within-group pair once (strict lower triangle,
        # no self-distances) is a negative, every across-group distance is a positive.
        negatives = within_dist[np.tril_indices_from(within_dist, -1)]
        positives = np.ravel(across_dist)
        group = np.hstack((np.zeros_like(negatives), np.ones_like(positives)))
        score = np.hstack((negatives, positives))
        entry = {"cell": cell_name, "detection_task": task, "contrast": contrast, "df": df,
                 "kernel_width": kernel_width, "chirpsize": cs, "auc": roc_auc_score(group, score)}
        if store_roc:
            fpr, tpr, _ = roc_curve(group, score, pos_label=1)
            entry["true_positives"] = tpr
            entry["false_positives"] = fpr
        return entry

    detection_performances = []
    for contrast in all_contrasts:
        no_other_block = block_map[(contrast, df, cs, "no-other")]
        self_block = block_map[(contrast, df, cs, "self")]
        other_block = block_map[(contrast, df, cs, "other")]

        # Metadata is read from the "self" block only; all conditions are expected to share it.
        duration, dt, _, chirp_duration, chirp_times = get_chirp_metadata(self_block)

        # get the spiking responses and firing rates
        no_other_spikes = get_spikes(no_other_block)
        self_spikes = get_spikes(self_block)
        other_spikes = get_spikes(other_block)
        no_other_rates, _ = get_rates(no_other_spikes, duration, dt, kernel_width)
        self_rates, _ = get_rates(self_spikes, duration, dt, kernel_width)
        other_rates, _ = get_rates(other_spikes, duration, dt, kernel_width)

        # One snippet row per (trial, chirp); trial-major ordering.
        # NOTE(review): assumes the "self" and "other" blocks have at least as many trials as
        # the "no-other" block; trial i of every condition is used - confirm.
        n_trials = no_other_rates.shape[0]
        width = int(chirp_duration / dt)
        shape = (len(chirp_times) * n_trials, width)
        alone_chirping_snippets = np.zeros(shape)  # section a, alone self-chirping
        self_snippets = np.zeros(shape)            # section c, self chirping in company
        other_snippets = np.zeros(shape)           # section e, other chirping in company
        silence_snippets = np.zeros(shape)         # section d, in company no one chirping

        for i in range(n_trials):
            for j, chirp_time in enumerate(chirp_times):
                index = i * len(chirp_times) + j
                # Chirp window centred on the chirp and shifted by a fixed 0.003
                # (NOTE(review): presumably a 3 ms response-latency offset - confirm).
                start_index = int((chirp_time - chirp_duration / 2 + 0.003) / dt)
                end_index = start_index + width
                alone_chirping_snippets[index, :] = no_other_rates[i, start_index:end_index]
                self_snippets[index, :] = self_rates[i, start_index:end_index]
                other_snippets[index, :] = other_rates[i, start_index:end_index]
                # Beat window: same length, starting 1.5 chirp durations after the chirp.
                silence_start_index = int((chirp_time + 1.5 * chirp_duration) / dt)
                silence_end_index = silence_start_index + width
                silence_snippets[index, :] = other_rates[i, silence_start_index:silence_end_index]

        # within-group distances
        alone_chirping_dist = within_group_distance(alone_chirping_snippets)    # within section a
        silence_dist = within_group_distance(silence_snippets)                  # within section d
        other_chirp_dist = within_group_distance(other_snippets)                # within section e

        # across-group distances
        self_vs_alone_dist = across_group_distance(alone_chirping_snippets, self_snippets)       # a vs. c
        other_vs_silence_dist = across_group_distance(silence_snippets, other_snippets)         # d vs. e
        self_other_chirp_dist = across_group_distance(self_snippets, other_snippets)            # c vs. e
        self_chirp_beat_dist = across_group_distance(self_snippets, silence_snippets)           # c vs. d
        alone_chirp_beat_dist = across_group_distance(alone_chirping_snippets, silence_snippets)  # a vs. d

        detection_performances.extend([
            roc_entry("self vs soliloquy", contrast, alone_chirping_dist, self_vs_alone_dist),
            roc_entry("other vs quietness", contrast, silence_dist, other_vs_silence_dist),
            roc_entry("soliliquy vs beat", contrast, alone_chirping_dist, alone_chirp_beat_dist),
            roc_entry("beat vs self", contrast, silence_dist, self_chirp_beat_dist),
            roc_entry("self vs other", contrast, other_chirp_dist, self_other_chirp_dist),
        ])
    return detection_performances


def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, all_chirpsizes, current_df=None, current_chirpsize=None, cell_name="", store_roc=False, kernel_widths=(0.00025, 0.0005, 0.001, 0.0025)):
    """Run the beat and chirp detection analyses for all requested stimulus settings.

    Args:
        block_map (dict): maps stimulus-parameter tuples to nix blocks.
        all_dfs (iterable): all difference frequencies; used when ``current_df`` is None.
        all_contrasts (iterable): contrasts passed on to the individual analyses.
        all_conditions: chirp conditions passed on to the individual analyses.
        all_chirpsizes (iterable): all chirp sizes; used when ``current_chirpsize`` is None.
        current_df (optional): restrict the analysis to this difference frequency.
        current_chirpsize (optional): restrict the analysis to this chirp size.
        cell_name (str, optional): name of the cell, stored in every result. Defaults to "".
        store_roc (bool, optional): store full ROC curves in the results. Defaults to False.
        kernel_widths (iterable of float, optional): Gaussian kernel widths; every analysis is
            repeated once per width. Defaults to (0.00025, 0.0005, 0.001, 0.0025).

    Returns:
        list of dict: concatenated results of ``foreign_fish_detection_beat`` and
        ``foreign_fish_detection_chirp``. They are ordered by chirp size, then df, then kernel
        width, with the beat results before the chirp results for each kernel width.
    """
    dfs = [current_df] if current_df is not None else all_dfs
    chirp_sizes = [current_chirpsize] if current_chirpsize is not None else all_chirpsizes
    result_dicts = []
    for cs in chirp_sizes:
        for df in dfs:
            print("%s, chirp size: %i Hz, deltaf %.1f Hz" % (cell_name, cs, df))
            for kw in kernel_widths:
                result_dicts.extend(foreign_fish_detection_beat(block_map, df, cs, all_contrasts, all_conditions, kw, cell_name, store_roc))
                result_dicts.extend(foreign_fish_detection_chirp(block_map, df, cs, all_contrasts, all_conditions, kw, cell_name, store_roc))

    return result_dicts


def process_cell(filename):
    """Run the foreign-fish detection analysis on a single nix file.

    Args:
        filename (str): path to the cell's nix file.

    Returns:
        list of dict: results of ``foreign_fish_detection`` over all dfs and chirp sizes in
        the file. The cell name is the file's base name with everything from ".nix" on removed.
    """
    cell_name = os.path.basename(filename).split(".nix")[0]
    nf = nix.File.open(filename, nix.FileMode.ReadOnly)
    try:
        block_map, all_contrasts, all_dfs, all_chirpsizes, all_conditions = sort_blocks(nf)
        if "baseline" not in block_map.keys():
            # Reported only; processing continues without the baseline block.
            print("ERROR: no baseline data for file %s!" % filename)
        return foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, all_chirpsizes,
                                      current_df=None, current_chirpsize=None,
                                      cell_name=cell_name, store_roc=False)
    finally:
        # Close the file even if sorting or the analysis raises.
        nf.close()


def main():
    """Analyse every cell*.nix file in ``data_folder`` in parallel and save all results as CSV.

    The results are written to ``data_folder/discrimination_results.csv`` with ";" as separator.
    """
    # Leave six cores free for other work, but always use at least one worker: joblib rejects
    # n_jobs=0 and treats negative values as "all CPUs but some", not as "fewer workers".
    num_cores = max(1, multiprocessing.cpu_count() - 6)
    nix_files = sorted(glob.glob(os.path.join(data_folder, "cell*.nix")))

    processed_list = Parallel(n_jobs=num_cores)(delayed(process_cell)(nix_file) for nix_file in nix_files)
    results = [entry for cell_results in processed_list for entry in cell_results]
    df = pd.DataFrame(results)
    df.to_csv(os.path.join(data_folder, "discrimination_results.csv"), sep=";")


if __name__ == "__main__":
    main()