import os
import glob
import pandas as pd
import nixio as nix
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection
from matplotlib.patches import ConnectionPatch
from sklearn.metrics import roc_curve, roc_auc_score     
from IPython import embed

from util import firing_rate, despine, extract_am, within_group_distance, across_group_distance
# Output directory for saved figures (relative to the current working directory).
figure_folder = "figures"
# Directory that main() searches for "cell*.nix" recording files.
data_folder = "data"


def read_baseline(block):
    """Read the baseline spike times from a baseline block.

    Args:
        block: a nix block whose name contains "baseline" (case-insensitive).

    Returns:
        The full contents of the block's first data array, or an empty list
        if the block does not look like a baseline block.
    """
    # Case-insensitive, matching sort_blocks(), which stores any block whose
    # lowercased name contains "baseline" under the "baseline" key.
    if "baseline" not in block.name.lower():
        print("Block %s does not appear to be a baseline block!" % block.name)
        return []
    return block.data_arrays[0][:]


def sort_blocks(nix_file):
    """Index the blocks of a nix file by their stimulus parameters.

    Non-baseline block names are split at "_"; the contrast is read from
    position 1, the condition from position 3 and the difference frequency
    from position 5. A block whose name contains "baseline" (case-insensitive)
    is stored under the key "baseline" instead.

    Args:
        nix_file: an opened nix file.

    Returns:
        dict: maps (contrast, df, condition) tuples, and "baseline", to blocks.
        list of float: distinct contrasts, in order of first appearance.
        list of float: distinct difference frequencies, in order of first appearance.
        list of str: distinct conditions, in order of first appearance.
    """
    mapping = {}
    seen_contrasts, seen_dfs, seen_conditions = [], [], []
    for blk in nix_file.blocks:
        if "baseline" in blk.name.lower():
            mapping["baseline"] = blk
            continue
        fields = blk.name.split("_")
        contrast = float(fields[1])
        condition = fields[3]
        delta_f = float(fields[5])
        for value, bucket in ((contrast, seen_contrasts),
                              (condition, seen_conditions),
                              (delta_f, seen_dfs)):
            if value not in bucket:
                bucket.append(value)
        mapping[(contrast, delta_f, condition)] = blk
    return mapping, seen_contrasts, seen_dfs, seen_conditions


def get_spikes(block):
    """Collect the response spike trains stored in a block.

    Every data array whose type contains "spike_times" and whose name contains
    "response" is one trial. The trial number is parsed from the last
    "_"-separated part of the array name. If two arrays share a trial number,
    the one listed later wins.

    Args:
        block: the block holding the response data arrays.

    Returns:
        list: the spike trains (each array read in full), ordered by trial number.
    """
    by_trial = {
        int(arr.name.split("_")[-1]): arr
        for arr in block.data_arrays
        if "spike_times" in arr.type and "response" in arr.name
    }
    return [by_trial[trial][:] for trial in sorted(by_trial)]
    

def get_rates(spike_trains, duration, dt, kernel_width):
    """Convert spike trains into firing rates with util.firing_rate.

    Args:
        spike_trains: sequence of spike-time arrays, one per trial.
        duration (float): trial duration, in the same time unit as dt.
        dt (float): sampling interval of the rate estimate.
        kernel_width (float): width of the smoothing kernel passed to firing_rate.

    Returns:
        np.ndarray: firing-rate matrix with one row per trial.
        np.ndarray: the time vector, np.arange(0.0, duration, dt).
    """
    time = np.arange(0.0, duration, dt)
    rates = np.zeros((len(spike_trains), time.shape[0]))
    for trial, spikes in enumerate(spike_trains):
        rates[trial] = firing_rate(spikes, duration, kernel_width, dt)
    return rates, time


def get_firing_rate(block_map, df, contrast, condition, kernel_width=0.0005):
    """Return the time axis, firing rates and spike trains for one stimulus setting.

    Args:
        block_map (dict): maps (contrast, df, condition) to blocks, as built by sort_blocks.
        df (float): difference frequency of the stimulus.
        contrast (float): stimulus contrast.
        condition (str): stimulus condition, e.g. "no-other", "self" or "other".
        kernel_width (float, optional): width of the rate kernel. Defaults to 0.0005.

    Returns:
        np.ndarray: the time vector.
        np.ndarray: the rates, with the first dimension representing the trials.
        list: the spike trains.
    """
    block = block_map[(contrast, df, condition)]
    spike_trains = get_spikes(block)
    stimulus_parameters = block.metadata["stimulus parameter"]
    duration, dt = float(stimulus_parameters["duration"]), float(stimulus_parameters["dt"])
    rates, time = get_rates(spike_trains, duration, dt, kernel_width)
    return time, rates, spike_trains


def get_signals(block):
    """Read the stimulus signal and the fish frequency traces from a block.

    Args:
        block: the block containing the data for a given df, contrast and condition.

    Raises:
        ValueError: if "complete stimulus" or "self frequency" is missing, or if
            "other frequency" is missing in a block whose name lacks "no-other".

    Returns:
        np.ndarray: the complete stimulus signal.
        np.ndarray: the frequency profile of the recorded fish.
        np.ndarray or None: the frequency profile of the other fish; None for
            "no-other" blocks.
        np.ndarray: the time axis of the stimulus signal.
    """
    arrays = block.data_arrays
    with_other = "no-other" not in block.name
    needed = ("complete stimulus", "self frequency") + (("other frequency",) if with_other else ())
    if any(name not in arrays for name in needed):
        raise ValueError("Signals not stored in block!")

    stimulus = arrays["complete stimulus"]
    signal = stimulus[:]
    time = np.asarray(stimulus.dimensions[0].axis(len(signal)))
    self_freq = arrays["self frequency"][:]
    other_freq = arrays["other frequency"][:] if with_other else None
    return signal, self_freq, other_freq, time


def create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, current_df, figure_name=None):
    """Plot stimuli and firing-rate responses of one cell for a single df.

    One column per condition ("no-other", "self", "other"). Each column shows:
    the frequency traces of the fish, the stimulus signal with its amplitude
    modulation, and, for every contrast (highest first), the trial-averaged
    firing rate with a +-1 standard deviation band. Only the time window
    (0.5, 1.0) is shown; x tick labels are in ms relative to its start.
    The figure is saved as a PDF into ``figure_folder`` and then closed.

    Args:
        block_map (dict): (contrast, df, condition) -> block, see sort_blocks.
        all_dfs: unused here.
        all_contrasts (list of float): contrasts to plot, one row of panels each.
        all_conditions (list of str): only its length is used, to size the grid.
        current_df (float): the difference frequency to plot.
        figure_name (str, optional): output file name; ".pdf" is appended if
            missing. Defaults to "chirp_responses.pdf".
    """
    conditions = ["no-other", "self", "other"]
    condition_labels = ["soliloquy", "self chirping", "other chirping"]
    # only this time window is plotted
    min_time = 0.5
    max_time = min_time + 0.5

    fig = plt.figure(figsize=(6.5, 5.5))
    # rows: 6 for the frequency/AM panels plus 2 per contrast; columns: 3 per
    # condition plus spacing (each column of panels starts at grid column 4 * i).
    # NOTE(review): the column count uses len(all_conditions) while the loop
    # below always plots the three entries of `conditions` - assumes they match.
    fig_grid = (len(all_contrasts)*2 + 6, len(all_conditions)*3+2)
    # highest contrast first; that block also provides the stimulus traces
    all_contrasts = sorted(all_contrasts, reverse=True)

    for i, condition in enumerate(conditions):
        # plot the signals
        block = block_map[(all_contrasts[0], current_df, condition)]
        signal, self_freq, other_freq, time = get_signals(block)
        am = extract_am(signal)

        self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"]
        other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"]

        # plot frequency traces
        ax = plt.subplot2grid(fig_grid, (0, i * 3 + i), rowspan=2, colspan=3, fig=fig)
        ax.plot(time[(time > min_time) & (time < max_time)], self_freq[(time > min_time) & (time < max_time)],
                color="#ff7f0e", label="%iHz" % self_eodf)
        ax.text(min_time-0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e", va="center", ha="right", fontsize=9)
        # get_signals returns other_freq = None for "no-other" blocks
        if other_freq is not None:
            ax.plot(time[(time > min_time) & (time < max_time)], other_freq[(time > min_time) & (time < max_time)],
                    color="#1f77b4", label="%iHz" % other_eodf)
            ax.text(min_time-0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4", va="center", ha="right", fontsize=9)
        ax.set_title(condition_labels[i])
        despine(ax, ["top", "bottom", "left", "right"], True)

        # plot the am
        ax = plt.subplot2grid(fig_grid, (3, i * 3 + i), rowspan=2, colspan=3, fig=fig)
        ax.plot(time[(time > min_time) & (time < max_time)], signal[(time > min_time) & (time < max_time)],
                color="#2ca02c", label="signal")
        ax.plot(time[(time > min_time) & (time < max_time)], am[(time > min_time) & (time < max_time)],
                color="#d62728", label="am")
        despine(ax, ["top", "bottom", "left", "right"], True)
        ax.set_ylim([-1.25, 1.25])
        # legend placed below the axes (negative y in axes coordinates)
        ax.legend(ncol=2, loc=(0.01, -0.5), fontsize=7, markerscale=0.5, frameon=False)

        # for each contrast plot the firing rate
        for j, contrast in enumerate(all_contrasts):
            t, rates, _ = get_firing_rate(block_map, current_df, contrast, condition)
            # trial average +- one standard deviation across trials
            avg_resp = np.mean(rates, axis=0)
            error = np.std(rates, axis=0)
            ax = plt.subplot2grid(fig_grid, (j*2 + 6, i * 3 + i), rowspan=2, colspan=3)
            ax.plot(t[(t > min_time) & (t < max_time)], avg_resp[(t > min_time) & (t < max_time)], color="k", lw=0.5)
            ax.fill_between(t[(t > min_time) & (t < max_time)], (avg_resp - error)[(t > min_time) & (t < max_time)],
                            (avg_resp + error)[(t > min_time) & (t < max_time)], color="k", lw=0.0, alpha=0.25)
            ax.set_ylim([0, 750])
            ax.set_xlabel("")
            ax.set_ylabel("")
            ax.set_xticks(np.arange(min_time, max_time+.01, 0.250))
            # tick labels in ms relative to the start of the window
            ax.set_xticklabels(map(int, (np.arange(min_time, max_time + .01, 0.250) - min_time) * 1000))
            ax.set_xticks(np.arange(min_time, max_time+.01, 0.125), minor=True)
            # x tick labels only on the bottom (lowest contrast) row
            if j < len(all_contrasts) -1:
                ax.set_xticklabels([])
            ax.set_yticks(np.arange(0.0, 751., 500))
            ax.set_yticks(np.arange(0.0, 751., 125), minor=True)
            # y tick labels only in the first column
            if i > 0:
                ax.set_yticklabels([])
            despine(ax, ["top", "right"], False)
            # contrast annotation to the right of the last column
            if i == 2:
                ax.text(max_time + 0.025*max_time, 350, "c=%.3f" % all_contrasts[j],
                        color="#d62728", ha="left", fontsize=7)

        # `ax` is now the bottom (lowest-contrast) rate panel of this column
        if i == 1:
            ax.set_xlabel("time [ms]")
        if i == 0:
            # NOTE(review): this labels the firing-rate panels; "firing rate [Hz]"
            # may be what was intended - confirm. The label is moved up
            # (y=3.5 in axes coordinates), presumably to centre it over all rate rows.
            ax.set_ylabel("frequency [Hz]", va="center")
            ax.yaxis.set_label_coords(-0.45, 3.5)

    name = figure_name if figure_name is not None else "chirp_responses.pdf"
    name = (name + ".pdf") if ".pdf" not in name else name
    plt.savefig(os.path.join(figure_folder, name))
    plt.close()


def get_chirp_metadata(block):
    """Read trial and chirp parameters from a block's "stimulus parameter" metadata.

    Args:
        block: a block whose metadata has a "stimulus parameter" section.

    Returns:
        tuple: (trial_duration, dt, chirp_size, chirp_duration, chirp_times).
        trial_duration and dt are converted to float; the chirp values are
        returned as stored. Note that chirp_size precedes chirp_duration.
    """
    params = block.metadata["stimulus parameter"]
    trial_duration, dt = (float(params[key]) for key in ("duration", "dt"))
    chirp_duration, chirp_size, chirp_times = (
        params[key] for key in ("chirp_duration", "chirp_size", "chirp_times")
    )
    return trial_duration, dt, chirp_size, chirp_duration, chirp_times


def foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005, cell_name="", store_roc=False):
    """Quantify how well the beat alone reveals the presence of another fish.

    For every contrast, firing-rate snippets are cut from the windows between
    consecutive chirps in the "no-other" and the "self" recordings. Two sets of
    pairwise distances go into an ROC analysis, whose AUC measures
    discriminability:
      - label 0: distances among the "no-other" snippets;
      - label 1: distances between "no-other" and "self" snippets.

    Args:
        block_map (dict): (contrast, df, condition) -> block, see sort_blocks.
        df (float): the difference frequency to analyse.
        all_contrasts (list of float): the contrasts to analyse.
        all_conditions: unused.
        kernel_width (float, optional): kernel width of the rate estimate. Defaults to 0.0005.
        cell_name (str, optional): stored in every result. Defaults to "".
        store_roc (bool, optional): also store the ROC curve. Defaults to False.

    Returns:
        list of dict: one entry per contrast. Keys are "cell", "detection_task"
        ("beat"), "contrast", "df", "kernel_width" and "auc". When store_roc is
        True, "true_positives" and "false_positives" are added.
    """
    performances = []
    for contrast in all_contrasts:
        # overwrite the previous progress line in place
        print(" " * 50, end="\r")
        print("Contrast: %.3f" % contrast, end="\r")
        alone_block = block_map[(contrast, df, "no-other")]
        company_block = block_map[(contrast, df, "self")]

        # metadata are read from the "self" block and assumed equal for all conditions
        duration, dt, _, chirp_duration, chirp_times = get_chirp_metadata(company_block)

        # Windows between chirps stay 1.5 chirp durations away from either chirp.
        # Their common length is the mean window length, truncated to 3 decimals.
        margin = 1.5 * chirp_duration
        window_starts = np.add(chirp_times, margin)[:-1]
        window_ends = np.subtract(chirp_times, margin)[1:]
        window_length = np.floor(np.mean(window_ends - window_starts) * 1000) / 1000

        alone_spikes = get_spikes(alone_block)
        company_spikes = get_spikes(company_block)
        alone_rates, _ = get_rates(alone_spikes, duration, dt, kernel_width)
        company_rates, _ = get_rates(company_spikes, duration, dt, kernel_width)

        # one snippet per (trial, window); rows are ordered trial-major
        n_trials = alone_rates.shape[0]
        n_windows = len(window_starts)
        n_samples = int(window_length / dt)
        alone_snippets = np.zeros((n_windows * n_trials, n_samples))
        company_snippets = np.zeros_like(alone_snippets)
        first_samples = [int(start / dt) for start in window_starts]
        for trial in range(n_trials):
            for k, first in enumerate(first_samples):
                row = trial * n_windows + k
                alone_snippets[row, :] = alone_rates[trial, first:first + n_samples]
                company_snippets[row, :] = company_rates[trial, first:first + n_samples]

        within = within_group_distance(alone_snippets)
        across = across_group_distance(alone_snippets, company_snippets)

        # lower triangle without the diagonal: each distinct pair once
        # (assumes within_group_distance returns a symmetric pairwise matrix)
        negatives = within[np.tril_indices_from(within, -1)]
        positives = across.ravel()
        labels = np.concatenate((np.zeros_like(negatives), np.ones_like(positives)))
        scores = np.concatenate((negatives, positives))
        fpr, tpr, _ = roc_curve(labels, scores, pos_label=1)

        entry = {"cell": cell_name, "detection_task": "beat", "contrast": contrast, "df": df,
                 "kernel_width": kernel_width, "auc": roc_auc_score(labels, scores)}
        if store_roc:
            entry["true_positives"] = tpr
            entry["false_positives"] = fpr
        performances.append(entry)
    print("\n")
    return performances


def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005, cell_name="", store_roc=False):
    """Quantify how well chirp responses reveal the presence of another fish.

    Two discrimination tasks are evaluated per contrast. Each uses an ROC
    analysis of firing-rate snippet distances: within-group distances are the
    negative class (label 0), across-group distances the positive class
    (label 1).

    * "self vs soliloquy": responses around the chirps in the "no-other"
      recording vs. the same windows in the "self" recording.
    * "other vs quietness": responses around the chirps in the "other"
      recording vs. windows starting 1.5 chirp durations after each chirp in
      that same recording.

    Args:
        block_map (dict): (contrast, df, condition) -> block, see sort_blocks.
        df (float): the difference frequency to analyse.
        all_contrasts (list of float): the contrasts to analyse.
        all_conditions: unused.
        kernel_width (float, optional): kernel width of the rate estimate. Defaults to 0.0005.
        cell_name (str, optional): stored in every result. Defaults to "".
        store_roc (bool, optional): also store the ROC curves. Defaults to False.

    Returns:
        list of dict: two entries per contrast, one per task. Keys are "cell",
        "detection_task", "contrast", "df", "kernel_width" and "auc". When
        store_roc is True, "true_positives" and "false_positives" are added.
    """

    def _roc_result(task, contrast, negatives, positives):
        # ROC analysis: within-group distances are label 0, across-group distances label 1.
        group = np.hstack((np.zeros_like(negatives), np.ones_like(positives)))
        score = np.hstack((negatives, positives))
        fpr, tpr, _ = roc_curve(group, score, pos_label=1)
        result = {"cell": cell_name, "detection_task": task, "contrast": contrast, "df": df,
                  "kernel_width": kernel_width, "auc": roc_auc_score(group, score)}
        if store_roc:
            result["true_positives"] = tpr
            result["false_positives"] = fpr
        return result

    detection_performances = []
    for contrast in all_contrasts:
        print(" " * 50, end="\r")
        print("Contrast: %.3f" % contrast, end="\r")
        no_other_block = block_map[(contrast, df, "no-other")]
        self_block = block_map[(contrast, df, "self")]
        # the recording in which the other fish chirps (must not be the "self" block)
        other_block = block_map[(contrast, df, "other")]

        # get some metadata assuming they are all the same for each condition, which they should
        duration, dt, _, chirp_duration, chirp_times = get_chirp_metadata(self_block)

        # get the spiking responses
        no_other_spikes = get_spikes(no_other_block)
        self_spikes = get_spikes(self_block)
        other_spikes = get_spikes(other_block)

        # get firing rates
        no_other_rates, _ = get_rates(no_other_spikes, duration, dt, kernel_width)
        self_rates, _ = get_rates(self_spikes, duration, dt, kernel_width)
        other_rates, _ = get_rates(other_spikes, duration, dt, kernel_width)

        # one snippet of chirp_duration length per (trial, chirp); rows are trial-major
        snippet_length = int(chirp_duration / dt)
        n_snippets = len(chirp_times) * no_other_rates.shape[0]
        alone_chirping_snippets = np.zeros((n_snippets, snippet_length))
        self_snippets = np.zeros_like(alone_chirping_snippets)
        other_snippets = np.zeros_like(alone_chirping_snippets)
        silence_snippets = np.zeros_like(alone_chirping_snippets)

        for i in range(no_other_rates.shape[0]):
            for j, chirp_time in enumerate(chirp_times):
                index = i * len(chirp_times) + j
                # window centred on the chirp, shifted by 0.003
                # (presumably a response latency - TODO confirm)
                start_index = int((chirp_time - chirp_duration / 2 + 0.003) / dt)
                end_index = start_index + snippet_length
                alone_chirping_snippets[index, :] = no_other_rates[i, start_index:end_index]
                self_snippets[index, :] = self_rates[i, start_index:end_index]
                other_snippets[index, :] = other_rates[i, start_index:end_index]
                # window of the same length starting 1.5 chirp durations after the
                # chirp in the "other" recording (presumably chirp-free)
                silence_start_index = int((chirp_time + 1.5 * chirp_duration) / dt)
                silence_end_index = silence_start_index + snippet_length
                silence_snippets[index, :] = other_rates[i, silence_start_index:silence_end_index]

        # distances:
        # - among the chirp responses while alone (soliloquy)
        # - among the windows without a chirp in the "other" recording
        # - between chirping alone and chirping with the other fish present
        # - between the other's chirps and the windows without a chirp
        alone_chirping_dist = within_group_distance(alone_chirping_snippets)
        silence_dist = within_group_distance(silence_snippets)
        self_vs_alone_dist = across_group_distance(alone_chirping_snippets, self_snippets)
        other_vs_silence_dist = across_group_distance(silence_snippets, other_snippets)

        # lower triangle without the diagonal: each distinct pair once
        # (assumes within_group_distance returns a symmetric pairwise matrix)
        triangle_indices = np.tril_indices_from(alone_chirping_dist, -1)
        detection_performances.append(_roc_result("self vs soliloquy", contrast,
                                                  alone_chirping_dist[triangle_indices],
                                                  self_vs_alone_dist.ravel()))
        detection_performances.append(_roc_result("other vs quietness", contrast,
                                                  silence_dist[triangle_indices],
                                                  other_vs_silence_dist.ravel()))

    print("\n")
    return detection_performances


def plot_detection_results(data_frame, df, kernel_width, cell, figure_name=None):
    """Plot ROC curves and discriminability for each detection task of one cell.

    Each detection task gets one row of panels:
      - left: the ROC curves of all contrasts at ``kernel_width``;
      - right: AUC versus contrast for every kernel width in the data.
    The figure is saved as a PDF into ``figure_folder`` and then closed.

    Args:
        data_frame (pd.DataFrame): detection results, e.g. from foreign_fish_detection
            with store_roc=True (needs "true_positives" and "false_positives" columns).
        df (float): the difference frequency to plot.
        kernel_width (float): the kernel width whose ROC curves are shown.
        cell (str): name of the cell to plot.
        figure_name (str, optional): output file name; ".pdf" is appended if
            missing. Defaults to "foreign_fish_detection.pdf".
    """
    cell_results = data_frame[(data_frame.cell == cell) & (data_frame.df == df)]
    conditions = sorted(cell_results.detection_task.unique())
    kernels = sorted(cell_results.kernel_width.unique())

    fig = plt.figure(figsize=(6.5, 5.5))
    # NOTE(review): 8 grid rows fit at most three tasks (3 rows each, 2 used)
    fig_grid = (8, 7)
    for i, task in enumerate(conditions):
        condition_results = cell_results[cell_results.detection_task == task]
        roc_data = condition_results[condition_results.kernel_width == kernel_width]
        contrasts = roc_data.contrast.unique()

        roc_ax = plt.subplot2grid(fig_grid, (i * 2 + i, 0), colspan=3, rowspan=2)
        roc_ax.set_title(task, fontsize=9, loc="left")
        auc_ax = plt.subplot2grid(fig_grid, (i * 2 + i, 4), colspan=3, rowspan=2)

        # one ROC curve per contrast (own loop name, so the task name is not shadowed)
        for contrast in contrasts:
            tpr = roc_data.true_positives[roc_data.contrast == contrast].values[0]
            fpr = roc_data.false_positives[roc_data.contrast == contrast].values[0]
            roc_ax.plot(fpr, tpr, label="%.3f" % contrast, zorder=2)
        if i == 0:
            roc_ax.legend(loc="lower right", fontsize=6, ncol=2, frameon=False, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25)
        # chance diagonal
        roc_ax.plot([0., 1.], [0., 1.], color="k", lw=0.5, ls="--", zorder=0)
        roc_ax.set_xticks(np.arange(0.0, 1.01, 0.5))
        roc_ax.set_xticks(np.arange(0.0, 1.01, 0.25), minor=True)
        roc_ax.set_yticks(np.arange(0.0, 1.01, 0.5))
        roc_ax.set_yticks(np.arange(0.0, 1.01, 0.25), minor=True)
        if i == len(conditions) - 1:
            roc_ax.set_xticklabels(np.arange(0.0, 1.01, 0.5), fontsize=8)
            roc_ax.set_xlabel("false positive rate", fontsize=9)
        roc_ax.set_ylabel("true positive rate", fontsize=9)
        roc_ax.set_yticklabels(np.arange(0.0, 1.01, 0.5), fontsize=8)

        # AUC vs. contrast, one line per kernel width
        for k in kernels:
            contrasts = np.asarray(condition_results.contrast[condition_results.kernel_width == k])
            aucs = np.asarray(condition_results.auc[condition_results.kernel_width == k])
            aucs_sorted = aucs[np.argsort(contrasts)]
            contrasts_sorted = np.sort(contrasts)
            auc_ax.plot(contrasts_sorted, aucs_sorted, marker=".", label=r"$\sigma$: %.2f ms" % (k * 1000), zorder=1)
        if i == len(conditions) - 1:
            auc_ax.set_xlabel("contrast [%]", fontsize=9)
        auc_ax.set_ylim([0.25, 1.0])
        auc_ax.set_yticks(np.arange(0.25, 1.01, 0.25))
        auc_ax.set_yticklabels(np.arange(0.25, 1.01, 0.25), fontsize=8)
        auc_ax.set_ylabel("discriminability", fontsize=9)
        if i == 0:
            auc_ax.legend(ncol=2, fontsize=6, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25, frameon=False, loc="lower center")
        # chance level (AUC = 0.5) over the contrast range of the last kernel plotted
        auc_ax.plot([min(contrasts), max(contrasts)], [0.5, 0.5], lw=0.5, ls="--", zorder=0)
    name = figure_name if figure_name is not None else "foreign_fish_detection.pdf"
    name = (name + ".pdf") if ".pdf" not in name else name
    fig.savefig(os.path.join(figure_folder, name))
    # release the figure, as the other plotting functions do, so repeated calls do not leak
    plt.close(fig)


def plot_comparisons(block_map, all_dfs, all_contrasts, all_conditions, current_df):
    """Sketch the response comparisons used in the detection analysis.

    Draws the frequency traces of the fish for the "no-other", "self" and
    "other" conditions side by side, within the time window (0.5, 1.0).
    Hand-placed time windows are highlighted with dashed rectangles and
    connected by double-headed arrows, labelled "1.", "2." and "3.".
    The figure is saved as "comparisons.pdf" in ``figure_folder`` and then closed.

    NOTE(review): all rectangle, arrow and label coordinates are hard-coded
    data values (time, frequency). They are presumably tuned to one example
    recording (the commented-out call in plot_examples uses df = 20); confirm
    before reusing with other data.

    Args:
        block_map (dict): (contrast, df, condition) -> block, see sort_blocks.
        all_dfs: unused.
        all_contrasts (list of float): only the first entry is used to pick the blocks.
        all_conditions (list of str): only its length is used, to size the grid.
        current_df (float): the difference frequency to plot.
    """
    conditions = ["no-other", "self", "other"]
    condition_labels = ["soliloquy", "self chirping", "other chirping"]
    min_time = 0.5
    max_time = min_time + 0.5

    fig = plt.figure(figsize=(6.5, 2.))
    fig_grid = (3, len(all_conditions)*3+2)
    axes = []
    for i, condition in enumerate(conditions):
        # plot the signals
        block = block_map[(all_contrasts[0], current_df, condition)]
        signal, self_freq, other_freq, time = get_signals(block)

        self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"]
        other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"]

        # plot frequency traces
        ax = plt.subplot2grid(fig_grid, (0, i * 3 + i), rowspan=2, colspan=3, fig=fig)
        ax.plot(time[(time > min_time) & (time < max_time)], self_freq[(time > min_time) & (time < max_time)],
                color="#ff7f0e", label="%iHz" % self_eodf)
        ax.text(min_time-0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e", va="center", ha="right", fontsize=9)
        # get_signals returns other_freq = None for "no-other" blocks
        if other_freq is not None:
            ax.plot(time[(time > min_time) & (time < max_time)], other_freq[(time > min_time) & (time < max_time)],
                    color="#1f77b4", label="%iHz" % other_eodf)
            ax.text(min_time-0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4", va="center", ha="right", fontsize=9)
        # ax.set_title(condition_labels[i])
        ax.set_ylim([735, 885])
        despine(ax, ["top", "bottom", "left", "right"], True)
        axes.append(ax)

    # dashed, translucent boxes (x, y, width, height in data coordinates)
    # marking time windows on the "no-other" panel
    rects = []
    rect = Rectangle((0.675, 740), 0.098, 140)
    rects.append(rect)
    rect = Rectangle((0.57, 740), 0.098, 140)
    rects.append(rect)

    pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
    axes[0].add_collection(pc)

    # boxes on the "self" panel
    # NOTE(review): this box starts at 0.575 while the matching ones on the other
    # panels start at 0.57 - confirm whether the difference is intentional.
    rects = []
    rect = Rectangle((0.675, 740), 0.098, 140)
    rects.append(rect)
    rect = Rectangle((0.575, 740), 0.098, 140)
    rects.append(rect)

    pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
    axes[1].add_collection(pc)

    # single box on the "other" panel
    rects = []
    rect = Rectangle((0.57, 740), 0.098, 140)
    rects.append(rect)
    pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
    axes[2].add_collection(pc)

    # double-headed arrows linking data points of two panels; all are added to axes[1]
    con = ConnectionPatch(xyA=(0.625, 735), xyB=(0.625, 740), coordsA="data", coordsB="data",
                          axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=.35")
    axes[1].add_artist(con)
    con = ConnectionPatch(xyA=(0.725, 885), xyB=(0.725, 880), coordsA="data", coordsB="data",
                          axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=-.25")
    axes[1].add_artist(con)
    con = ConnectionPatch(xyA=(0.725, 735), xyB=(0.625, 740), coordsA="data", coordsB="data",
                          axesA=axes[1], axesB=axes[2], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=.35")
    axes[1].add_artist(con)

    # arrow labels, placed outside the y-limits (735-885) of the panels
    axes[0].text(1., 660, "2.")
    axes[1].text(1.05, 660, "3.")
    axes[0].text(1.1, 890, "1.")
    fig.subplots_adjust(bottom=0.1, top=0.8, left=0.1, right=0.9)
    fig.savefig(os.path.join(figure_folder, "comparisons.pdf"))
    plt.close()


def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, cell_name="", store_roc=False):
    """Run the beat and chirp detection analyses for several kernel widths.

    Args:
        block_map (dict): (contrast, df, condition) -> block, see sort_blocks.
        all_dfs (list of float): difference frequencies analysed when current_df is None.
        all_contrasts (list of float): the contrasts to analyse.
        all_conditions: passed through to the per-task functions.
        current_df (float, optional): analyse only this df. Defaults to None.
        cell_name (str, optional): stored with every result. Defaults to "".
        store_roc (bool, optional): also store the ROC curves. Defaults to False.

    Returns:
        list of dict: all results; for each (df, kernel width) pair the beat
        results come before the chirp results.
    """
    kernel_widths = (0.00025, 0.0005, 0.001, 0.0025)
    selected_dfs = all_dfs if current_df is None else [current_df]
    results = []
    for delta_f in selected_dfs:
        for width in kernel_widths:
            print("df: %i, kernel: %.4f" % (delta_f, width))
            print("Foreign fish detection during beat:")
            results += foreign_fish_detection_beat(block_map, delta_f, all_contrasts, all_conditions,
                                                   width, cell_name, store_roc)
            print("Foreign fish detection during chirp:")
            results += foreign_fish_detection_chirp(block_map, delta_f, all_contrasts, all_conditions,
                                                    width, cell_name, store_roc)
    return results


def estimate_chirp_phase(am, chirp_times):
    """Placeholder, not implemented yet.

    Both arguments are ignored and None is returned.
    """
    return None


def process_cell(filename, dfs=None, contrasts=None, conditions=None):
    """Run the foreign fish detection analyses for one recorded cell.

    Args:
        filename (str): path to the cell's nix file.
        dfs, contrasts, conditions: currently unused; kept for interface compatibility.

    Returns:
        list of dict: the detection results (without ROC curves) for all dfs in the file.
    """
    nf = nix.File.open(filename, nix.FileMode.ReadOnly)
    try:
        block_map, all_contrasts, all_dfs, all_conditions = sort_blocks(nf)
        if "baseline" in block_map:
            # NOTE: the baseline spikes are read but not used yet
            baseline_spikes = read_baseline(block_map["baseline"])
        else:
            print("ERROR: no baseline data for file %s!" % filename)
        cell_name = filename.split(os.path.sep)[-1].split(".nix")[0]
        results = foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None,
                                         cell_name=cell_name, store_roc=False)
    finally:
        # close the file even if reading or the analysis fails
        nf.close()
    return results


def plot_examples(filename, dfs=None, contrasts=None, conditions=None):
    """Create example figures for one recorded cell.

    All plotting calls are currently commented out; enable the ones needed.

    Args:
        filename (str): path to the cell's nix file.
        dfs, contrasts, conditions: currently unused; kept for interface compatibility.
    """
    nf = nix.File.open(filename, nix.FileMode.ReadOnly)
    try:
        block_map, all_contrasts, all_dfs, all_conditions = sort_blocks(nf)
        if "baseline" in block_map:
            # NOTE: the baseline spikes are read but not used yet
            baseline_spikes = read_baseline(block_map["baseline"])
        else:
            print("ERROR: no baseline data for file %s!" % filename)

        # plot the responses
        # fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf"
        # create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name)
        # fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf"
        # create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name)

        # sketch showing the comparisons
        # plot_comparisons(block_map, all_dfs, all_contrasts, all_conditions, 20)

        # plot the discrimination analyses
        # cell_name = filename.split(os.path.sep)[-1].split(".nix")[0]
        # results = foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=20,
        #                                  cell_name=cell_name, store_roc=True)
        # pdf = pd.DataFrame(results)
        # plot_detection_results(pdf, 20, 0.001, cell_name)
    finally:
        # close the file even if reading or plotting fails
        nf.close()


def main():
    """Analyse every "cell*.nix" file in data_folder, then open an IPython shell.

    The detection results of all cells are collected in ``results`` (a list of
    dicts), which is visible in the interactive shell started by embed().
    """
    nix_files = sorted(glob.glob(os.path.join(data_folder, "cell*.nix")))
    results = []
    for nix_file in nix_files:
        # plot_examples(nix_file, dfs=[20], contrasts=[20], conditions=["self"])
        # accumulate instead of overwriting, so every cell's results reach the shell
        results.extend(process_cell(nix_file, dfs=[], contrasts=[20], conditions=["self"]))
    embed()


# Run the full analysis only when executed as a script, not on import.
if __name__ == "__main__":
    main()