larger stimulus range, add didactic figure for

This commit is contained in:
Jan Grewe 2020-09-24 17:27:36 +02:00
parent bdbab448ed
commit 511fddbeb2
3 changed files with 135 additions and 38 deletions

View File

@ -183,7 +183,7 @@ def simulate_responses(stimulus_params, model_params, repeats=10, deltaf=20):
def main():
models = load_models("models.csv")
deltafs = [-200, -100, -20, 20, 100, 200] # Hz, difference frequency between self and other
deltafs = [-200, -100, -50, -20, -10, -5, 5, 10, 20, 50, 100, 200] # Hz, difference frequency between self and other
stimulus_params = { "eodfs": {"self": 0.0, "other": 0.0}, # eod frequency in Hz, to be overwritten
"contrasts": [20, 10, 5, 2.5, 1.25, 0.625, 0.3125],
"chirp_size": 100, # Hz, frequency excursion
@ -210,7 +210,8 @@ def main():
1./stimulus_params["chirp_frequency"])
stimulus_params["chirp_times"] = chirp_times
simulate_responses(stimulus_params, model_params, repeats=25, deltaf=deltaf)
exit() # the first cell only for now!
if cell_id == 9:
exit() # the first 10 cell only for now!
if __name__ == "__main__":

View File

@ -3,8 +3,10 @@ import glob
import pandas as pd
import nixio as nix
import numpy as np
import scipy.signal as sig
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection
from matplotlib.patches import ConnectionPatch
from sklearn.metrics import roc_curve, roc_auc_score
from IPython import embed
@ -320,7 +322,7 @@ def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, k
alone_chirping_snippets = np.zeros((len(chirp_times) * no_other_rates.shape[0], int(chirp_duration / dt)))
self_snippets = np.zeros_like(alone_chirping_snippets)
other_snippets = np.zeros_like(alone_chirping_snippets)
baseline_snippets = np.zeros_like(alone_chirping_snippets)
silence_snippets = np.zeros_like(alone_chirping_snippets)
for i in range(no_other_rates.shape[0]):
for j, chirp_time in enumerate(chirp_times):
@ -330,9 +332,9 @@ def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, k
alone_chirping_snippets[index, :] = no_other_rates[i, start_index:end_index]
self_snippets[index, :] = self_rates[i, start_index:end_index]
other_snippets[index, :] = other_rates[i, start_index:end_index]
baseline_start_index = int((chirp_time + 1.5 * chirp_duration)/dt)
baseline_end_index = baseline_start_index + alone_chirping_snippets.shape[1]
baseline_snippets[index, :] = no_other_rates[i, baseline_start_index:baseline_end_index]
silence_start_index = int((chirp_time + 1.5 * chirp_duration)/dt)
silence_end_index = silence_start_index + alone_chirping_snippets.shape[1]
silence_snippets[index, :] = other_rates[i, silence_start_index:silence_end_index]
# get the distances
# 1. Soliloquy
@ -340,9 +342,9 @@ def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, k
# 3. I chirp while the other is present compared to self chirping without the other one present
# 4. the otherone chrips to me compared to baseline with anyone chirping
alone_chirping_dist = within_group_distance(alone_chirping_snippets)
baseline_dist = within_group_distance(baseline_snippets)
silence_dist = within_group_distance(silence_snippets)
self_vs_alone_dist = across_group_distance(alone_chirping_snippets, self_snippets)
other_vs_baseline_dist = across_group_distance(baseline_snippets, other_snippets)
other_vs_silence_dist = across_group_distance(silence_snippets, other_snippets)
# sort and perfom ROC analysis for two comparisons
# 1. soliloquy vs. self chirping in company
@ -351,14 +353,14 @@ def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, k
valid_no_other_distances = alone_chirping_dist[triangle_indices]
no_other_temp = np.zeros_like(valid_no_other_distances)
valid_baseline_distances = baseline_dist[triangle_indices]
baseline_temp = np.zeros_like(valid_baseline_distances)
valid_silence_distances = silence_dist[triangle_indices]
silence_temp = np.zeros_like(valid_silence_distances)
valid_self_vs_alone_distances = self_vs_alone_dist.ravel()
self_vs_alone_temp = np.ones_like(valid_self_vs_alone_distances)
valid_other_vs_baseline_distances = other_vs_baseline_dist.ravel()
other_vs_baseline_temp = np.ones_like(valid_other_vs_baseline_distances)
valid_other_vs_silence_distances = other_vs_silence_dist.ravel()
other_vs_silence_temp = np.ones_like(valid_other_vs_silence_distances)
group = np.hstack((no_other_temp, self_vs_alone_temp))
score = np.hstack((valid_no_other_distances, valid_self_vs_alone_distances))
@ -368,8 +370,8 @@ def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, k
detection_performances.append({"cell": cell_name, "detection_task": "self vs soliloquy", "contrast": contrast, "df": df, "kernel_width": kernel_width, "auc": auc, "true_positives": tpr, "false_positives": fpr})
else:
detection_performances.append({"cell": cell_name, "detection_task": "self vs soliloquy", "contrast": contrast, "df": df, "kernel_width": kernel_width, "auc": auc})
group = np.hstack((baseline_temp, other_vs_baseline_temp))
score = np.hstack((valid_baseline_distances, valid_other_vs_baseline_distances))
group = np.hstack((silence_temp, other_vs_silence_temp))
score = np.hstack((valid_silence_distances, valid_other_vs_silence_distances))
fpr, tpr, _ = roc_curve(group, score, pos_label=1)
auc = roc_auc_score(group, score)
if store_roc:
@ -394,14 +396,15 @@ def plot_detection_results(data_frame, df, kernel_width, cell, figure_name=None)
contrasts = roc_data.contrast.unique()
roc_ax = plt.subplot2grid(fig_grid, (i * 2 + i, 0), colspan=3, rowspan=2)
roc_ax.set_title(c, fontsize=9, ha="left")
roc_ax.set_title(c, fontsize=9, loc="left")
auc_ax = plt.subplot2grid(fig_grid, (i * 2 + i, 4), colspan=3, rowspan=2)
for c in contrasts:
tpr = roc_data.true_positives[roc_data.contrast == c].values[0]
fpr = roc_data.false_positives[roc_data.contrast == c].values[0]
roc_ax.plot(fpr, tpr, label="%.3f" % c, zorder=2)
roc_ax.legend(loc="best", fontsize=6, ncol=2, frameon=False)
if i == 0:
roc_ax.legend(loc="lower right", fontsize=6, ncol=2, frameon=False, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25)
roc_ax.plot([0., 1.],[0., 1.], color="k", lw=0.5, ls="--", zorder=0)
roc_ax.set_xticks(np.arange(0.0, 1.01, 0.5))
roc_ax.set_xticks(np.arange(0.0, 1.01, 0.25), minor=True)
@ -418,23 +421,95 @@ def plot_detection_results(data_frame, df, kernel_width, cell, figure_name=None)
aucs = np.asarray(condition_results.auc[condition_results.kernel_width == k])
aucs_sorted = aucs[np.argsort(contrasts)]
contrasts_sorted = np.sort(contrasts)
auc_ax.plot(contrasts_sorted, aucs_sorted, marker=".", label=r"$\sigma$: %.4f" % k, zorder=1)
auc_ax.plot(contrasts_sorted, aucs_sorted, marker=".", label=r"$\sigma$: %.2f ms" % (k * 1000), zorder=1)
if i == len(conditions) - 1:
auc_ax.set_xlabel("contrast [%]", fontsize=9)
else:
auc_ax.set_xticklabels("")
auc_ax.set_ylim([0.25, 1.0])
auc_ax.set_yticks(np.arange(0.25, 1.01, 0.25))
auc_ax.set_yticklabels(np.arange(0.25, 1.01, 0.25), fontsize=8)
auc_ax.set_ylabel("discriminability", fontsize=9)
auc_ax.legend(ncol=2, fontsize=6, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25)
if i == 0:
auc_ax.legend(ncol=2, fontsize=6, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25, frameon=False, loc="lower center")
auc_ax.plot([min(contrasts), max(contrasts)], [0.5, 0.5], lw=0.5, ls="--", zorder=0)
name = figure_name if figure_name is not None else "foreign_fish_detection.pdf"
name = (name + ".pdf") if ".pdf" not in name else name
plt.savefig(os.path.join(figure_folder, name))
fig.savefig(os.path.join(figure_folder, name))
def plot_comparisons(block_map, all_dfs, all_contrasts, all_conditions, current_df):
conditions = ["no-other", "self", "other"]
condition_labels = ["soliloquy", "self chirping", "other chirping"]
min_time = 0.5
max_time = min_time + 0.5
fig = plt.figure(figsize=(6.5, 2.))
fig_grid = (3, len(all_conditions)*3+2)
axes = []
for i, condition in enumerate(conditions):
# plot the signals
block = block_map[(all_contrasts[0], current_df, condition)]
signal, self_freq, other_freq, time = get_signals(block)
self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"]
other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"]
# plot frequency traces
ax = plt.subplot2grid(fig_grid, (0, i * 3 + i), rowspan=2, colspan=3, fig=fig)
ax.plot(time[(time > min_time) & (time < max_time)], self_freq[(time > min_time) & (time < max_time)],
color="#ff7f0e", label="%iHz" % self_eodf)
ax.text(min_time-0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e", va="center", ha="right", fontsize=9)
if other_freq is not None:
ax.plot(time[(time > min_time) & (time < max_time)], other_freq[(time > min_time) & (time < max_time)],
color="#1f77b4", label="%iHz" % other_eodf)
ax.text(min_time-0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4", va="center", ha="right", fontsize=9)
# ax.set_title(condition_labels[i])
ax.set_ylim([735, 885])
despine(ax, ["top", "bottom", "left", "right"], True)
axes.append(ax)
rects = []
rect = Rectangle((0.675, 740), 0.098, 140)
rects.append(rect)
rect = Rectangle((0.57, 740), 0.098, 140)
rects.append(rect)
pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
axes[0].add_collection(pc)
rects = []
rect = Rectangle((0.675, 740), 0.098, 140)
rects.append(rect)
rect = Rectangle((0.575, 740), 0.098, 140)
rects.append(rect)
pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
axes[1].add_collection(pc)
rects = []
rect = Rectangle((0.57, 740), 0.098, 140)
rects.append(rect)
pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
axes[2].add_collection(pc)
con = ConnectionPatch(xyA=(0.625, 735), xyB=(0.625, 740), coordsA="data", coordsB="data",
axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=.35")
axes[1].add_artist(con)
con = ConnectionPatch(xyA=(0.725, 885), xyB=(0.725, 880), coordsA="data", coordsB="data",
axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=-.25")
axes[1].add_artist(con)
con = ConnectionPatch(xyA=(0.725, 735), xyB=(0.625, 740), coordsA="data", coordsB="data",
axesA=axes[1], axesB=axes[2], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=.35")
axes[1].add_artist(con)
axes[0].text(1., 660, "2.")
axes[1].text(1.05, 660, "3.")
axes[0].text(1.1, 890, "1.")
fig.subplots_adjust(bottom=0.1, top=0.8, left=0.1, right=0.9)
fig.savefig(os.path.join(figure_folder, "comparisons.pdf"))
plt.close()
def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, cell_name=""):
def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, cell_name="", store_roc=False):
dfs = [current_df] if current_df is not None else all_dfs
kernels = [0.00025, 0.0005, 0.001, 0.0025]
result_dicts = []
@ -442,15 +517,10 @@ def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, cu
for kw in kernels:
print("df: %i, kernel: %.4f" % (df, kw))
print("Foreign fish detection during beat:")
result_dicts.extend(foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kw, cell_name))
result_dicts.extend(foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kw, cell_name, store_roc))
print("Foreign fish detection during chirp:")
result_dicts.extend(foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kw, cell_name))
result_dicts.extend(foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kw, cell_name, store_roc))
break
embed()
return result_dicts
@ -466,21 +536,46 @@ def process_cell(filename, dfs=[], contrasts=[], conditions=[]):
baseline_spikes = read_baseline(block_map["baseline"])
else:
print("ERROR: no baseline data for file %s!" % filename)
# fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf"
# create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name)
# fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf"
# create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name)
results = foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=20,
cell_name=filename.split(os.path.sep)[-1].split(".nix")[0])
results = foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None,
cell_name=filename.split(os.path.sep)[-1].split(".nix")[0], store_roc=False)
nf.close()
return results
def plot_examples(filename, dfs=[], contrasts=[], conditions=[]):
nf = nix.File.open(filename, nix.FileMode.ReadOnly)
block_map, all_contrasts, all_dfs, all_conditions = sort_blocks(nf)
if "baseline" in block_map.keys():
baseline_spikes = read_baseline(block_map["baseline"])
else:
print("ERROR: no baseline data for file %s!" % filename)
nf.close()
# plot the responses
#fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf"
#create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name)
#fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf"
#create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name)
# sketch showing the comparisons
# plot_comparisons(block_map, all_dfs, all_contrasts, all_conditions, 20)
# plot the discrimination analyses
#cell_name = filename.split(os.path.sep)[-1].split(".nix")[0]
# results = foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=20,
# cell_name=cell_name, store_roc=True)
# pdf = pd.DataFrame(results)
# plot_detection_results(pdf, 20, 0.001, cell_name)
nf.close()
def main():
nix_files = sorted(glob.glob(os.path.join(data_folder, "cell*.nix")))
for nix_file in nix_files:
process_cell(nix_file, dfs=[20], contrasts=[20], conditions=["self"])
#plot_examples(nix_file, dfs=[20], contrasts=[20], conditions=["self"])
results = process_cell(nix_file, dfs=[], contrasts=[20], conditions=["self"])
# break
embed()
if __name__ == "__main__":

View File

@ -1,5 +1,6 @@
from typing import ValuesView
import numpy as np
import scipy.signal as sig
from numpy.lib.function_base import iterable
from numpy.lib.index_tricks import diag_indices