diff --git a/plots.py b/plots.py index 6b43073..6512b75 100644 --- a/plots.py +++ b/plots.py @@ -38,7 +38,7 @@ def plot_comparisons(args): axes = [] for i, condition in enumerate(conditions): # plot the signals - block = block_map[(all_contrasts[0], args.current_df, condition)] + block = block_map[(all_contrasts[0], args.deltaf, args.chirpsize, condition)] _, self_freq, other_freq, time = get_signals(block) self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"] @@ -54,47 +54,70 @@ def plot_comparisons(args): color="#1f77b4", label="%iHz" % other_eodf) ax.text(min_time-0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4", va="center", ha="right", fontsize=9) # ax.set_title(condition_labels[i]) - ax.set_ylim([735, 885]) + ax.set_ylim([735, 895]) despine(ax, ["top", "bottom", "left", "right"], True) axes.append(ax) rects = [] - rect = Rectangle((0.675, 740), 0.098, 140) + rect = Rectangle((0.675, 740), 0.098, 150) rects.append(rect) - rect = Rectangle((0.57, 740), 0.098, 140) + rect = Rectangle((0.57, 740), 0.098, 150) rects.append(rect) pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--") axes[0].add_collection(pc) + axes[0].text(0.625, 860, "a)", ha="center", fontsize=7) + axes[0].text(0.724, 860, "b)", ha="center", fontsize=7) rects = [] - rect = Rectangle((0.675, 740), 0.098, 140) + rect = Rectangle((0.675, 740), 0.098, 150) rects.append(rect) - rect = Rectangle((0.575, 740), 0.098, 140) + rect = Rectangle((0.57, 740), 0.098, 150) rects.append(rect) pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--") axes[1].add_collection(pc) - + axes[1].text(0.625, 860, "c)", ha="center", fontsize=7) + axes[1].text(0.724, 860, "d)", ha="center", fontsize=7) + rects = [] - rect = Rectangle((0.57, 740), 0.098, 140) + rect = Rectangle((0.57, 740), 0.098, 150) rects.append(rect) pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--") axes[2].add_collection(pc) + axes[2].text(0.625, 860, "e)", ha="center", fontsize=7) + con = ConnectionPatch(xyA=(0.625, 735), xyB=(0.625, 740), coordsA="data", coordsB="data", axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=.35") axes[1].add_artist(con) - con = ConnectionPatch(xyA=(0.725, 885), xyB=(0.725, 880), coordsA="data", coordsB="data", - axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=-.25") + con = ConnectionPatch(xyA=(0.725, 895), xyB=(0.725, 890), coordsA="data", coordsB="data", + axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5, + connectionstyle="arc3,rad=-.25") axes[1].add_artist(con) con = ConnectionPatch(xyA=(0.725, 735), xyB=(0.625, 740), coordsA="data", coordsB="data", axesA=axes[1], axesB=axes[2], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=.35") axes[1].add_artist(con) + con = ConnectionPatch(xyA=(0.625, 895), xyB=(0.725, 890), coordsA="data", coordsB="data", + axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5, + connectionstyle="arc3,rad=-.35") + axes[1].add_artist(con) - axes[0].text(1., 660, "2.") - axes[1].text(1.05, 660, "3.") - axes[0].text(1.1, 890, "1.") + con = ConnectionPatch(xyA=(0.615, 735), xyB=(0.735, 745), coordsA="data", coordsB="data", + axesA=axes[1], axesB=axes[1], arrowstyle="<->", shrinkB=5, + connectionstyle="arc3,rad=1.") + axes[1].add_artist(con) + + con = ConnectionPatch(xyA=(0.625, 895), xyB=(0.625, 895), coordsA="data", coordsB="data", + axesA=axes[1], axesB=axes[2], arrowstyle="<->", shrinkB=5, + connectionstyle="arc3,rad=-.35") + axes[1].add_artist(con) + axes[0].text(1., 655, "2.") + axes[1].text(1.05, 655, "3.") + axes[0].text(1.1, 895, "1.") + axes[0].text(0.6, 925, "4.") + axes[1].text(0.675, 680, "5.") + axes[1].text(1.05, 895, "6.") fig.subplots_adjust(bottom=0.1, top=0.8, left=0.1, right=0.9) fig.savefig(args.outfile) plt.close() @@ -333,7 +356,8 @@ def main(): help="Sub commands for plotting different figures", description="", dest="explore_cmd") comp_parser = subparsers.add_parser("comparisons", help="Create a didactic plot illustrating the comparisons") - comp_parser.add_argument("-df", "--deltaf", type=int, default=20, help="The difference frequency to used for plotting") + comp_parser.add_argument("-df", "--deltaf", type=int, default=20, help="The difference frequency to used for plotting. Defaults to 20 Hz") + comp_parser.add_argument("-cs", "--chirpsize", type=int, default=60, help="The chirpsize. Defaults to 60 Hz.") comp_parser.add_argument("-o", "--outfile", default=os.path.join(figure_folder, "comparisons.pdf"), help="filename of the plot") comp_parser.set_defaults(func=plot_comparisons) diff --git a/punit_responses.py b/punit_responses.py index f360bb6..208169d 100644 --- a/punit_responses.py +++ b/punit_responses.py @@ -220,7 +220,7 @@ def main(): parser = argparse.ArgumentParser(description="Simulate P-unit responses using the model parameters from the models.csv file. Calling it without any arguments works with the defaults, may need some time.") parser.add_argument("-n", "--number", type=int, default=20, help="Number of simulated neurons. Randomly chosen from model list. Defaults to 20") parser.add_argument("-t", "--trials", type=int, default=25, help="Number of stimulus repetitions, trials. Defaults to 25") - parser.add_argument("-dfs", "--deltafs", type=float, nargs="+", default=[-200, -100, -50, -20, -10, -5, 5, 10, 20, 50, 100, 200], + parser.add_argument("-dfs", "--deltafs", type=float, nargs="+", default=[-200, -100, -50, -20, -10, -5, 5, 10, 20, 50, 100, 150, 200, 250, 300, 350, 400, 500], help="List of difference frequencies. Defaults to [-200, -100, -50, -20, -10, -5, 5, 10, 20, 50, 100, 200]") parser.add_argument("-cs", "--chirpsizes", type=float, nargs="+", default=[40, 60, 100], help="List of chirp sizes. Defaults to [40, 60, 100]") @@ -232,13 +232,13 @@ def main(): models = load_models("models.csv") num_models = args.number - if args.number > len(models): + if num_models > len(models): print("INFO: number of cells larger than number of available models. Reset to max number of models.") num_models = len(models) indices = list(range(len(models))) np.random.shuffle(indices) - Parallel(n_jobs=args.jobs)(delayed(simulate_cell)(cell_id, models, args) for cell_id in indices) + Parallel(n_jobs=args.jobs)(delayed(simulate_cell)(cell_id, models, args) for cell_id in indices[:num_models]) if __name__ == "__main__": diff --git a/response_discriminability.py b/response_discriminability.py index 4dfb449..995abd9 100644 --- a/response_discriminability.py +++ b/response_discriminability.py @@ -103,8 +103,8 @@ def foreign_fish_detection_beat(block_map, df, cs, all_contrasts, all_conditions self_rates, _ = get_rates(self_spikes, duration, dt, kernel_width) # get the response snippets between chrips - no_other_snippets = np.zeros((len(interchirp_starts) * no_other_rates.shape[0], int(ici / dt))) - self_snippets = np.zeros_like(no_other_snippets) + no_other_snippets = np.zeros((len(interchirp_starts) * no_other_rates.shape[0], int(ici / dt))) # section b, alone, no chirps + self_snippets = np.zeros_like(no_other_snippets) # section d, in company, no chirps, just beat for i in range(no_other_rates.shape[0]): for j, start in enumerate(interchirp_starts): start_index = int(start/dt) @@ -188,10 +188,10 @@ def foreign_fish_detection_chirp(block_map, df, cs, all_contrasts, all_condition other_rates, _ = get_rates(other_spikes, duration, dt, kernel_width) # get the chirp response snippets - alone_chirping_snippets = np.zeros((len(chirp_times) * no_other_rates.shape[0], int(chirp_duration / dt))) - self_snippets = np.zeros_like(alone_chirping_snippets) - other_snippets = np.zeros_like(alone_chirping_snippets) - silence_snippets = np.zeros_like(alone_chirping_snippets) + alone_chirping_snippets = np.zeros((len(chirp_times) * no_other_rates.shape[0], int(chirp_duration / dt))) # section a, alone self-chirping + self_snippets = np.zeros_like(alone_chirping_snippets) # section c, self chirping in company + other_snippets = np.zeros_like(alone_chirping_snippets) # section e, other chirping in company + silence_snippets = np.zeros_like(alone_chirping_snippets) # section d, in company no one chirping for i in range(no_other_rates.shape[0]): for j, chirp_time in enumerate(chirp_times): @@ -210,10 +210,15 @@ def foreign_fish_detection_chirp(block_map, df, cs, all_contrasts, all_condition # 2. Nobody chirps, all alone aka baseline response # 3. I chirp while the other is present compared to self chirping without the other one present # 4. the otherone chrips to me compared to baseline with anyone chirping - alone_chirping_dist = within_group_distance(alone_chirping_snippets) - silence_dist = within_group_distance(silence_snippets) - self_vs_alone_dist = across_group_distance(alone_chirping_snippets, self_snippets) - other_vs_silence_dist = across_group_distance(silence_snippets, other_snippets) + alone_chirping_dist = within_group_distance(alone_chirping_snippets) # within section a + silence_dist = within_group_distance(silence_snippets) # within section d + other_chirp_dist = within_group_distance(other_snippets) # within section e + + self_vs_alone_dist = across_group_distance(alone_chirping_snippets, self_snippets) # section a vs. section c + other_vs_silence_dist = across_group_distance(silence_snippets, other_snippets) # section d vs. section e + self_other_chirp_dist = across_group_distance(self_snippets, other_snippets) # section c vs. section e + self_chirp_beat_dist = across_group_distance(self_snippets, silence_snippets) # section c vs. section d + alone_chirp_beat_dist = across_group_distance(alone_chirping_snippets, silence_snippets) # section a vs. section d # sort and perfom ROC analysis for two comparisons # 1. soliloquy vs. self chirping in company @@ -225,12 +230,25 @@ def foreign_fish_detection_chirp(block_map, df, cs, all_contrasts, all_condition valid_silence_distances = silence_dist[triangle_indices] silence_temp = np.zeros_like(valid_silence_distances) + valid_other_chirp_distances = other_chirp_dist[triangle_indices] + other_chirp_temp = np.zeros_like(valid_other_chirp_distances) + valid_self_vs_alone_distances = self_vs_alone_dist.ravel() self_vs_alone_temp = np.ones_like(valid_self_vs_alone_distances) valid_other_vs_silence_distances = other_vs_silence_dist.ravel() other_vs_silence_temp = np.ones_like(valid_other_vs_silence_distances) + valid_self_vs_other_chirp_distances = self_other_chirp_dist.ravel() + self_vs_other_chirps_temp = np.ones_like(valid_self_vs_other_chirp_distances) + + valid_self_beat_distances = self_chirp_beat_dist.ravel() + self_vs_beat_temp = np.ones_like(valid_self_beat_distances) + + valid_alone_chirp_beat_distance = alone_chirp_beat_dist.ravel() + alone_chirp_beat_temp = np.ones_like(valid_alone_chirp_beat_distance) + + # Comparison 2: alone chirping (soliloquy) vs. self-chirping in company group = np.hstack((no_other_temp, self_vs_alone_temp)) score = np.hstack((valid_no_other_distances, valid_self_vs_alone_distances)) auc = roc_auc_score(group, score) @@ -239,6 +257,8 @@ def foreign_fish_detection_chirp(block_map, df, cs, all_contrasts, all_condition detection_performances.append({"cell": cell_name, "detection_task": "self vs soliloquy", "contrast": contrast, "df": df, "kernel_width": kernel_width, "chirpsize": cs, "auc": auc, "true_positives": tpr, "false_positives": fpr}) else: detection_performances.append({"cell": cell_name, "detection_task": "self vs soliloquy", "contrast": contrast, "df": df, "kernel_width": kernel_width, "chirpsize": cs, "auc": auc}) + + # Comparison 3: other fish chirping vs. beat group = np.hstack((silence_temp, other_vs_silence_temp)) score = np.hstack((valid_silence_distances, valid_other_vs_silence_distances)) auc = roc_auc_score(group, score) @@ -248,6 +268,35 @@ def foreign_fish_detection_chirp(block_map, df, cs, all_contrasts, all_condition else: detection_performances.append({"cell": cell_name, "detection_task": "other vs quietness", "contrast": contrast, "df": df, "kernel_width": kernel_width, "chirpsize": cs, "auc": auc}) + # Comparison 4: soliloquy vs. beat + group = np.hstack((no_other_temp, alone_chirp_beat_temp)) + score = np.hstack((valid_no_other_distances, valid_alone_chirp_beat_distance)) + auc = roc_auc_score(group, score) + if store_roc: + fpr, tpr, _ = roc_curve(group, score, pos_label=1) + detection_performances.append({"cell": cell_name, "detection_task": "soliliquy vs beat", "contrast": contrast, "df": df, "kernel_width": kernel_width, "chirpsize": cs, "auc": auc, "true_positives": tpr, "false_positives": fpr}) + else: + detection_performances.append({"cell": cell_name, "detection_task": "soliliquy vs beat", "contrast": contrast, "df": df, "kernel_width": kernel_width, "chirpsize": cs, "auc": auc}) + + # Comparison 5: beat vs self-chirping in company + group = np.hstack((silence_temp, self_vs_beat_temp)) + score = np.hstack((valid_silence_distances, valid_alone_chirp_beat_distance)) + auc = roc_auc_score(group, score) + if store_roc: + fpr, tpr, _ = roc_curve(group, score, pos_label=1) + detection_performances.append({"cell": cell_name, "detection_task": "beat vs self", "contrast": contrast, "df": df, "kernel_width": kernel_width, "chirpsize": cs, "auc": auc, "true_positives": tpr, "false_positives": fpr}) + else: + detection_performances.append({"cell": cell_name, "detection_task": "beat vs self", "contrast": contrast, "df": df, "kernel_width": kernel_width, "chirpsize": cs, "auc": auc}) + + # Comparison 6: self vs other-chirping in company + group = np.hstack((other_chirp_temp, self_vs_other_chirps_temp)) + score = np.hstack((valid_other_chirp_distances, valid_self_vs_other_chirp_distances)) + auc = roc_auc_score(group, score) + if store_roc: + fpr, tpr, _ = roc_curve(group, score, pos_label=1) + detection_performances.append({"cell": cell_name, "detection_task": "self vs other", "contrast": contrast, "df": df, "kernel_width": kernel_width, "chirpsize": cs, "auc": auc, "true_positives": tpr, "false_positives": fpr}) + else: + detection_performances.append({"cell": cell_name, "detection_task": "self vs other", "contrast": contrast, "df": df, "kernel_width": kernel_width, "chirpsize": cs, "auc": auc}) print("\n") return detection_performances