From 511fddbeb23fd1a45404b5e969f6d34907d38d46 Mon Sep 17 00:00:00 2001
From: Jan Grewe <jan.grewe@g-node.org>
Date: Thu, 24 Sep 2020 17:27:36 +0200
Subject: [PATCH] larger stimulus range, add didactic figure for

---
 punit_responses.py           |   5 +-
 response_discriminability.py | 167 +++++++++++++++++++++++++++--------
 util.py                      |   1 +
 3 files changed, 135 insertions(+), 38 deletions(-)

diff --git a/punit_responses.py b/punit_responses.py
index efa2be2..090cd1e 100644
--- a/punit_responses.py
+++ b/punit_responses.py
@@ -183,7 +183,7 @@ def simulate_responses(stimulus_params, model_params, repeats=10, deltaf=20):
 
 def main():
     models = load_models("models.csv")
-    deltafs = [-200, -100, -20, 20, 100, 200]  # Hz, difference frequency between self and other
+    deltafs = [-200, -100, -50, -20, -10, -5, 5, 10, 20, 50, 100, 200]  # Hz, difference frequency between self and other
     stimulus_params = { "eodfs": {"self": 0.0, "other": 0.0}, # eod frequency in Hz, to be overwritten
                         "contrasts": [20, 10, 5, 2.5, 1.25, 0.625, 0.3125],
                         "chirp_size": 100,  # Hz, frequency excursion
@@ -210,7 +210,8 @@ def main():
                                     1./stimulus_params["chirp_frequency"])
             stimulus_params["chirp_times"] = chirp_times     
             simulate_responses(stimulus_params, model_params, repeats=25, deltaf=deltaf)
-        exit() # the first cell only for now!
+        if cell_id == 9:
+            exit() # the first 10 cell only for now!
 
 
 if __name__ == "__main__":
diff --git a/response_discriminability.py b/response_discriminability.py
index 0401569..01e4bff 100644
--- a/response_discriminability.py
+++ b/response_discriminability.py
@@ -3,8 +3,10 @@ import glob
 import pandas as pd
 import nixio as nix
 import numpy as np
-import scipy.signal as sig 
 import matplotlib.pyplot as plt
+from matplotlib.patches import Rectangle
+from matplotlib.collections import PatchCollection
+from matplotlib.patches import ConnectionPatch
 from sklearn.metrics import roc_curve, roc_auc_score     
 from IPython import embed
 
@@ -320,7 +322,7 @@ def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, k
         alone_chirping_snippets = np.zeros((len(chirp_times) * no_other_rates.shape[0], int(chirp_duration / dt)))
         self_snippets = np.zeros_like(alone_chirping_snippets)
         other_snippets = np.zeros_like(alone_chirping_snippets)
-        baseline_snippets = np.zeros_like(alone_chirping_snippets)
+        silence_snippets = np.zeros_like(alone_chirping_snippets)
 
         for i in range(no_other_rates.shape[0]):
             for j, chirp_time in enumerate(chirp_times):
@@ -330,9 +332,9 @@ def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, k
                 alone_chirping_snippets[index, :] = no_other_rates[i, start_index:end_index]
                 self_snippets[index, :] = self_rates[i, start_index:end_index]
                 other_snippets[index, :] = other_rates[i, start_index:end_index]
-                baseline_start_index = int((chirp_time + 1.5 * chirp_duration)/dt)
-                baseline_end_index = baseline_start_index + alone_chirping_snippets.shape[1]
-                baseline_snippets[index, :] = no_other_rates[i, baseline_start_index:baseline_end_index]
+                silence_start_index = int((chirp_time + 1.5 * chirp_duration)/dt)
+                silence_end_index = silence_start_index + alone_chirping_snippets.shape[1]
+                silence_snippets[index, :] = other_rates[i, silence_start_index:silence_end_index]
         
         # get the distances 
         # 1. Soliloquy
@@ -340,9 +342,9 @@ def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, k
         # 3. I chirp while the other is present compared to self chirping without the other one present
         # 4. the otherone chrips to me compared to baseline with anyone chirping
         alone_chirping_dist = within_group_distance(alone_chirping_snippets)  
-        baseline_dist = within_group_distance(baseline_snippets)
+        silence_dist = within_group_distance(silence_snippets)
         self_vs_alone_dist = across_group_distance(alone_chirping_snippets, self_snippets)
-        other_vs_baseline_dist = across_group_distance(baseline_snippets, other_snippets)
+        other_vs_silence_dist = across_group_distance(silence_snippets, other_snippets)
 
         # sort and perfom ROC analysis for two comparisons
         # 1. soliloquy vs. self chirping in company
@@ -351,14 +353,14 @@ def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, k
         valid_no_other_distances = alone_chirping_dist[triangle_indices]
         no_other_temp = np.zeros_like(valid_no_other_distances)
 
-        valid_baseline_distances = baseline_dist[triangle_indices]
-        baseline_temp = np.zeros_like(valid_baseline_distances)
+        valid_silence_distances = silence_dist[triangle_indices]
+        silence_temp = np.zeros_like(valid_silence_distances)
 
         valid_self_vs_alone_distances = self_vs_alone_dist.ravel()
         self_vs_alone_temp = np.ones_like(valid_self_vs_alone_distances)
 
-        valid_other_vs_baseline_distances = other_vs_baseline_dist.ravel()
-        other_vs_baseline_temp = np.ones_like(valid_other_vs_baseline_distances)
+        valid_other_vs_silence_distances = other_vs_silence_dist.ravel()
+        other_vs_silence_temp = np.ones_like(valid_other_vs_silence_distances)
 
         group = np.hstack((no_other_temp, self_vs_alone_temp))
         score = np.hstack((valid_no_other_distances, valid_self_vs_alone_distances))
@@ -368,8 +370,8 @@ def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, k
             detection_performances.append({"cell": cell_name, "detection_task": "self vs soliloquy", "contrast": contrast, "df": df, "kernel_width": kernel_width, "auc": auc, "true_positives": tpr, "false_positives": fpr})
         else:
             detection_performances.append({"cell": cell_name, "detection_task": "self vs soliloquy", "contrast": contrast, "df": df, "kernel_width": kernel_width, "auc": auc})
-        group = np.hstack((baseline_temp, other_vs_baseline_temp))
-        score = np.hstack((valid_baseline_distances, valid_other_vs_baseline_distances))
+        group = np.hstack((silence_temp, other_vs_silence_temp))
+        score = np.hstack((valid_silence_distances, valid_other_vs_silence_distances))
         fpr, tpr, _ = roc_curve(group, score, pos_label=1)
         auc = roc_auc_score(group, score)
         if store_roc:
@@ -394,14 +396,15 @@ def plot_detection_results(data_frame, df, kernel_width, cell, figure_name=None)
         contrasts = roc_data.contrast.unique()
         
         roc_ax = plt.subplot2grid(fig_grid, (i * 2 + i, 0), colspan=3, rowspan=2)
-        roc_ax.set_title(c, fontsize=9, ha="left")
+        roc_ax.set_title(c, fontsize=9, loc="left")
         auc_ax = plt.subplot2grid(fig_grid, (i * 2 + i, 4), colspan=3, rowspan=2)
         
         for c in contrasts:
             tpr = roc_data.true_positives[roc_data.contrast == c].values[0]
             fpr = roc_data.false_positives[roc_data.contrast == c].values[0]
             roc_ax.plot(fpr, tpr, label="%.3f" % c, zorder=2)
-        roc_ax.legend(loc="best", fontsize=6, ncol=2, frameon=False)
+        if i == 0:
+            roc_ax.legend(loc="lower right", fontsize=6, ncol=2, frameon=False, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25)
         roc_ax.plot([0., 1.],[0., 1.], color="k", lw=0.5, ls="--", zorder=0)
         roc_ax.set_xticks(np.arange(0.0, 1.01, 0.5))
         roc_ax.set_xticks(np.arange(0.0, 1.01, 0.25), minor=True)
@@ -418,23 +421,95 @@ def plot_detection_results(data_frame, df, kernel_width, cell, figure_name=None)
             aucs = np.asarray(condition_results.auc[condition_results.kernel_width == k])
             aucs_sorted = aucs[np.argsort(contrasts)]
             contrasts_sorted = np.sort(contrasts)
-            auc_ax.plot(contrasts_sorted, aucs_sorted, marker=".", label=r"$\sigma$: %.4f" % k, zorder=1)
+            auc_ax.plot(contrasts_sorted, aucs_sorted, marker=".", label=r"$\sigma$: %.2f ms" % (k * 1000), zorder=1)
         if i == len(conditions) - 1:
             auc_ax.set_xlabel("contrast [%]", fontsize=9)
-        else:
-            auc_ax.set_xticklabels("")
         auc_ax.set_ylim([0.25, 1.0])
         auc_ax.set_yticks(np.arange(0.25, 1.01, 0.25))
         auc_ax.set_yticklabels(np.arange(0.25, 1.01, 0.25), fontsize=8)
         auc_ax.set_ylabel("discriminability", fontsize=9)
-        auc_ax.legend(ncol=2, fontsize=6, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25)
+        if i == 0:
+            auc_ax.legend(ncol=2, fontsize=6, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25, frameon=False, loc="lower center")
         auc_ax.plot([min(contrasts), max(contrasts)], [0.5, 0.5], lw=0.5, ls="--", zorder=0)
     name = figure_name if figure_name is not None else "foreign_fish_detection.pdf"
     name = (name + ".pdf") if ".pdf" not in name else name
-    plt.savefig(os.path.join(figure_folder, name))
+    fig.savefig(os.path.join(figure_folder, name))
+
+
+def plot_comparisons(block_map, all_dfs, all_contrasts, all_conditions, current_df):
+    conditions = ["no-other", "self", "other"]
+    condition_labels = ["soliloquy", "self chirping", "other chirping"]
+    min_time = 0.5
+    max_time = min_time + 0.5
+
+    fig = plt.figure(figsize=(6.5, 2.))
+    fig_grid = (3, len(all_conditions)*3+2)
+    axes = []
+    for i, condition in enumerate(conditions):
+        # plot the signals
+        block = block_map[(all_contrasts[0], current_df, condition)]
+        signal, self_freq, other_freq, time = get_signals(block)
+        
+        self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"]
+        other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"]
+        
+        # plot frequency traces
+        ax = plt.subplot2grid(fig_grid, (0, i * 3 + i), rowspan=2, colspan=3, fig=fig)
+        ax.plot(time[(time > min_time) & (time < max_time)], self_freq[(time > min_time) & (time < max_time)],
+                color="#ff7f0e", label="%iHz" % self_eodf)
+        ax.text(min_time-0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e", va="center", ha="right", fontsize=9)
+        if other_freq is not None:
+            ax.plot(time[(time > min_time) & (time < max_time)], other_freq[(time > min_time) & (time < max_time)],
+                    color="#1f77b4", label="%iHz" % other_eodf)
+            ax.text(min_time-0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4", va="center", ha="right", fontsize=9)       
+        # ax.set_title(condition_labels[i])
+        ax.set_ylim([735, 885])
+        despine(ax, ["top", "bottom", "left", "right"], True)
+        axes.append(ax)
+
+    rects = []
+    rect = Rectangle((0.675, 740), 0.098, 140)
+    rects.append(rect)
+    rect = Rectangle((0.57, 740), 0.098, 140)
+    rects.append(rect)
+       
+    pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
+    axes[0].add_collection(pc)
+        
+    rects = []
+    rect = Rectangle((0.675, 740), 0.098, 140)
+    rects.append(rect)
+    rect = Rectangle((0.575, 740), 0.098, 140)
+    rects.append(rect)
+       
+    pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
+    axes[1].add_collection(pc)
+
+    rects = []
+    rect = Rectangle((0.57, 740), 0.098, 140)
+    rects.append(rect)   
+    pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
+    axes[2].add_collection(pc)
+
+    con = ConnectionPatch(xyA=(0.625, 735), xyB=(0.625, 740), coordsA="data", coordsB="data",
+                          axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=.35")
+    axes[1].add_artist(con)
+    con = ConnectionPatch(xyA=(0.725, 885), xyB=(0.725, 880), coordsA="data", coordsB="data",
+                          axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=-.25")
+    axes[1].add_artist(con)
+    con = ConnectionPatch(xyA=(0.725, 735), xyB=(0.625, 740), coordsA="data", coordsB="data",
+                          axesA=axes[1], axesB=axes[2], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=.35")
+    axes[1].add_artist(con)
+
+    axes[0].text(1., 660, "2.")
+    axes[1].text(1.05, 660, "3.")
+    axes[0].text(1.1, 890, "1.")
+    fig.subplots_adjust(bottom=0.1, top=0.8, left=0.1, right=0.9)
+    fig.savefig(os.path.join(figure_folder, "comparisons.pdf"))
+    plt.close()
 
 
-def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, cell_name=""):
+def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, cell_name="", store_roc=False):
     dfs = [current_df] if current_df is not None else all_dfs 
     kernels = [0.00025, 0.0005, 0.001, 0.0025]
     result_dicts = []
@@ -442,15 +517,10 @@ def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, cu
         for kw in kernels:
             print("df: %i, kernel: %.4f" % (df, kw))
             print("Foreign fish detection during beat:")
-            result_dicts.extend(foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kw, cell_name))
+            result_dicts.extend(foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kw, cell_name, store_roc))
             print("Foreign fish detection during chirp:")
-            result_dicts.extend(foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kw, cell_name))
-        
+            result_dicts.extend(foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kw, cell_name, store_roc))
 
-        break
-
-    embed()
- 
     return result_dicts
 
 
@@ -466,21 +536,46 @@ def process_cell(filename, dfs=[], contrasts=[], conditions=[]):
         baseline_spikes = read_baseline(block_map["baseline"])
     else:
         print("ERROR: no baseline data for file %s!" % filename)
-    # fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf"
-    # create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name)
-    # fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf"
-    # create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name)
-    results = foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=20, 
-                                     cell_name=filename.split(os.path.sep)[-1].split(".nix")[0])
+    results = foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, 
+                                     cell_name=filename.split(os.path.sep)[-1].split(".nix")[0], store_roc=False)
+    nf.close()
+    return results
 
+
+def plot_examples(filename, dfs=[], contrasts=[], conditions=[]):
+    nf = nix.File.open(filename, nix.FileMode.ReadOnly)
+    block_map, all_contrasts, all_dfs, all_conditions  = sort_blocks(nf)
+    if "baseline" in block_map.keys():
+        baseline_spikes = read_baseline(block_map["baseline"])
+    else:
+        print("ERROR: no baseline data for file %s!" % filename)
     
-    nf.close()
+    # plot the responses 
+    #fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf"
+    #create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name)
+    #fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf"
+    #create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name)
+    
+    # sketch showing the comparisons
+    # plot_comparisons(block_map, all_dfs, all_contrasts, all_conditions, 20)
+
+    # plot the discrimination analyses
+    #cell_name = filename.split(os.path.sep)[-1].split(".nix")[0]
+    # results = foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=20, 
+    #                                  cell_name=cell_name, store_roc=True)
+    # pdf = pd.DataFrame(results)
+    # plot_detection_results(pdf, 20, 0.001, cell_name)
+    
+    nf.close()    
 
 
 def main():
     nix_files = sorted(glob.glob(os.path.join(data_folder, "cell*.nix")))
     for nix_file in nix_files:
-        process_cell(nix_file, dfs=[20], contrasts=[20], conditions=["self"])
+        #plot_examples(nix_file, dfs=[20], contrasts=[20], conditions=["self"])
+        results = process_cell(nix_file, dfs=[], contrasts=[20], conditions=["self"])
+        # break
+    embed()
 
 
 if __name__ == "__main__":
diff --git a/util.py b/util.py
index 4b5192a..f2de0b2 100644
--- a/util.py
+++ b/util.py
@@ -1,5 +1,6 @@
 from typing import ValuesView
 import numpy as np
+import scipy.signal as sig 
 from numpy.lib.function_base import iterable
 from numpy.lib.index_tricks import diag_indices