From 116ebfd70f8422d4474ec8cf786c2ef32c1ab89d Mon Sep 17 00:00:00 2001 From: Jan Grewe Date: Sat, 1 Apr 2023 18:16:13 +0200 Subject: [PATCH] work work --- plots.py | 57 +++++++++++++++++++++--------------- punit_responses.py | 7 ++--- response_discriminability.py | 20 ++++++------- 3 files changed, 46 insertions(+), 38 deletions(-) diff --git a/plots.py b/plots.py index 6512b75..9a64c92 100644 --- a/plots.py +++ b/plots.py @@ -281,34 +281,22 @@ def foreign_fish_detection_example_plot(args): store.close() -def performance_plot(args): - if not os.path.exists(args.inputfile): - raise ValueError("Error plotting discrimination performance. Input file (%s) not found!" % args.inputfile) - df = pd.read_csv(args.inputfile, sep=";") - dfs = np.sort(df.df.unique()) - contrasts = np.sort(df.contrast.unique()) - tasks = df.detection_task.unique() - kernel_widths = list(df.kernel_width.unique()) - kernel_width = args.kernel_width if args.kernel_width in kernel_widths else kernel_widths[0] - chirpsizes = list(df.chirpsize.unique()) - if args.chirpsize not in chirpsizes: - raise ValueError("Error plotting discrimination performance. Requested chirpsize (%i Hz) is not found in the data. Available chirpsizes are: " % args.chirpsize + str(chirpsizes)) - +def plot_surfaces(data_frame, dfs, contrasts, tasks, selected_dfs, selected_contrasts, kernel_width, chirpsize, filename): X, Y = np.meshgrid(dfs, contrasts) Z = np.zeros_like(X) fig = plt.figure(figsize=(8.0, 10)) fig_grid = (19, 10) - for index, t in enumerate(tasks): + for index, t in enumerate(tasks): ax = plt.subplot2grid(fig_grid, (index * 5 + index * 2, 0), colspan=4, rowspan=5, projection="3d" ) ax.set_title(t, loc="left", pad=-0.5) for i, d in enumerate(dfs): for j, c in enumerate(contrasts): - data_df = df[(df.kernel_width == kernel_width) & (df.contrast == c) & (df.df == d) & (df.detection_task == t) & (df.chirpsize == args.chirpsize)] + data_df = data_frame[(data_frame.kernel_width == kernel_width) & (data_frame.contrast == c) & (data_frame.df == d) & (data_frame.detection_task == t) & (data_frame.chirpsize == chirpsize)] Z[j, i] = np.mean(data_df.auc) ax.plot_surface(X, Y, Z, cmap=cm.coolwarm, linewidth=0.2, edgecolor="white", antialiased=True, alpha=0.85, vmin=0.5, vmax=1.0) - ax.set_xlabel(r"$\Delta_f [Hz]$", fontsize=8) + ax.set_xlabel(r"$\Delta f [Hz]$", fontsize=8) ax.set_ylabel("contrast [%]", fontsize=8) ax.set_zlabel("performance", fontsize=8, rotation=180) ax.set_zlim([0.45, 1.0]) @@ -318,12 +306,12 @@ def performance_plot(args): cntrst_ax = plt.subplot2grid(fig_grid, (index * 5 + index * 2, 6), colspan=4, rowspan=2) performances = np.zeros_like(contrasts) errors = np.zeros_like(contrasts) - for d in args.deltafs: + for d in selected_dfs: for i, c in enumerate(contrasts): - data_df = df[(df.kernel_width == kernel_width) & (df.contrast == c) & (df.df == d) & (df.detection_task == t) & (df.chirpsize == args.chirpsize)] + data_df = data_frame[(data_frame.kernel_width == kernel_width) & (data_frame.contrast == c) & (data_frame.df == d) & (data_frame.detection_task == t) & (data_frame.chirpsize == chirpsize)] performances[i] = np.mean(data_df.auc) errors[i] = np.std(data_df.auc) - cntrst_ax.errorbar(contrasts, performances, yerr=errors, fmt=".-", label=r"$\Delta_f:$ %i Hz" % d) + cntrst_ax.errorbar(contrasts, performances, yerr=errors, fmt=".-", label=r"$\Delta f:$ %i Hz" % d) cntrst_ax.set_ylim([0.25, 1.0]) cntrst_ax.set_ylabel("performance", fontsize=8) cntrst_ax.set_xlabel("contrast [%]", fontsize=8) @@ -333,23 +321,46 @@ def performance_plot(args): df_ax = plt.subplot2grid(fig_grid, (index * 5 + index * 2 + 3, 6), colspan=4, rowspan=2) performances = np.zeros_like(dfs) errors = np.zeros_like(dfs) - for c in args.contrasts: + for c in selected_contrasts: for i, d in enumerate(dfs): - data_df = df[(df.kernel_width == kernel_width) & (df.contrast == c) & (df.df == d) & (df.detection_task == t) & (df.chirpsize == args.chirpsize)] + data_df = data_frame[(data_frame.kernel_width == kernel_width) & (data_frame.contrast == c) & (data_frame.df == d) & (data_frame.detection_task == t) & (data_frame.chirpsize == chirpsize)] performances[i] = np.mean(data_df.auc) errors[i] = np.std(data_df.auc) df_ax.errorbar(dfs, performances, yerr=errors, fmt=".-", label="%.2f" % c) df_ax.set_ylim([0.25, 1.0]) df_ax.set_ylabel("performance", fontsize=8) - df_ax.set_xlabel(r"$\Delta_f$ [Hz]", fontsize=8) + df_ax.set_xlabel(r"$\Delta f$ [Hz]", fontsize=8) df_ax.hlines(0.5, dfs[0], dfs[-1], color="k", ls="--", lw=0.2) df_ax.legend(fontsize=7, ncol=4, frameon=False, loc="lower center", mode="expand", handlelength=1.0, handletextpad=0.25) fig.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.975) - fig.savefig(args.outfile) + fig.savefig(filename) plt.close() +def performance_plot(args): + if not os.path.exists(args.inputfile): + raise ValueError("Error plotting discrimination performance. Input file (%s) not found!" % args.inputfile) + df = pd.read_csv(args.inputfile, sep=";") + all_dfs = np.sort(df.df.unique()) + all_contrasts = np.sort(df.contrast.unique()) + tasks = df.detection_task.unique() + kernel_widths = list(df.kernel_width.unique()) + kernel_width = args.kernel_width if args.kernel_width in kernel_widths else kernel_widths[0] + chirpsizes = list(df.chirpsize.unique()) + if args.chirpsize not in chirpsizes: + raise ValueError("Error plotting discrimination performance. Requested chirpsize (%i Hz) is not found in the data. Available chirpsizes are: " % args.chirpsize + str(chirpsizes)) + selected_dfs = args.deltafs + selected_contrasts = args.contrasts + + filename = args.outfile + temp = filename.split('.') + filename = temp[0] + "_1." + temp[-1] + plot_surfaces(df, all_dfs, all_contrasts, tasks[:3], selected_dfs, selected_contrasts, kernel_width, args.chirpsize, filename) + filename = temp[0] + "_2." + temp[-1] + plot_surfaces(df, all_dfs, all_contrasts, tasks[3:], selected_dfs, selected_contrasts, kernel_width, args.chirpsize, filename) + + def main(): parser = argparse.ArgumentParser(description="Plotting tool for chrip probing project.") subparsers = parser.add_subparsers(title="commands", diff --git a/punit_responses.py b/punit_responses.py index 208169d..749be03 100644 --- a/punit_responses.py +++ b/punit_responses.py @@ -2,12 +2,9 @@ import numpy as np import nixio as nix import argparse import os -from numpy.core.fromnumeric import repeat -from traitlets.traitlets import Instance from chirp_ams import get_signals from model import simulate, load_models from IPython import embed -import matplotlib.pyplot as plt import multiprocessing from joblib import Parallel, delayed @@ -26,7 +23,7 @@ def append_settings(section, sec_name, sec_type, settings): else: section[k] = settings[k] - + def save(filename, name, stimulus_settings, model_settings, self_signal, other_signal, self_freq, other_freq, complete_stimulus, responses, overwrite=False): if os.path.exists(filename) and not overwrite: nf = nix.File.open(filename, nix.FileMode.ReadWrite) @@ -237,7 +234,7 @@ def main(): num_models = len(models) indices = list(range(len(models))) np.random.shuffle(indices) - + Parallel(n_jobs=args.jobs)(delayed(simulate_cell)(cell_id, models, args) for cell_id in indices[:num_models]) diff --git a/response_discriminability.py b/response_discriminability.py index 995abd9..a8652cc 100644 --- a/response_discriminability.py +++ b/response_discriminability.py @@ -82,8 +82,8 @@ def foreign_fish_detection_beat(block_map, df, cs, all_contrasts, all_conditions detection_performances = [] for contrast in all_contrasts: - print(" " * 50, end="\r") - print("Contrast: %.3f" % contrast, end="\r") + # print(" " * 50, end="\r") + # print("Contrast: %.3f" % contrast, end="\r") no_other_block = block_map[(contrast, df, cs, "no-other")] self_block = block_map[(contrast, df, cs, "self")] @@ -133,7 +133,7 @@ def foreign_fish_detection_beat(block_map, df, cs, all_contrasts, all_conditions detection_performances.append({"cell": cell_name, "detection_task": "beat", "contrast": contrast, "df": df, "kernel_width": kernel_width, "chirpsize": cs, "auc": auc, "true_positives": tpr, "false_positives": fpr}) else: detection_performances.append({"cell": cell_name, "detection_task": "beat", "contrast": contrast, "df": df, "kernel_width": kernel_width, "chirpsize": cs, "auc": auc}) - print("\n") + # print("\n") return detection_performances @@ -168,8 +168,8 @@ def foreign_fish_detection_chirp(block_map, df, cs, all_contrasts, all_condition detection_performances = [] for contrast in all_contrasts: - print(" " * 50, end="\r") - print("Contrast: %.3f" % contrast, end="\r") + # print(" " * 50, end="\r") + # print("Contrast: %.3f" % contrast, end="\r") no_other_block = block_map[(contrast, df, cs, "no-other")] self_block = block_map[(contrast, df, cs, "self")] other_block = block_map[(contrast, df, cs, "self")] @@ -297,7 +297,7 @@ def foreign_fish_detection_chirp(block_map, df, cs, all_contrasts, all_condition detection_performances.append({"cell": cell_name, "detection_task": "self vs other", "contrast": contrast, "df": df, "kernel_width": kernel_width, "chirpsize": cs, "auc": auc, "true_positives": tpr, "false_positives": fpr}) else: detection_performances.append({"cell": cell_name, "detection_task": "self vs other", "contrast": contrast, "df": df, "kernel_width": kernel_width, "chirpsize": cs, "auc": auc}) - print("\n") + # print("\n") return detection_performances @@ -308,18 +308,18 @@ def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, al result_dicts = [] for cs in chirp_sizes: for df in dfs: + print("%s, chirp size: %i Hz, deltaf %.1f Hz" % (cell_name, cs, df)) for kw in kernels: - print("cs: %i Hz, df: %i Hz, kernel: %.4fs" % (cs, df, kw)) - print("Foreign fish detection during beat:") + #print("cs: %i Hz, df: %i Hz, kernel: %.4fs" % (cs, df, kw)) + #print("Foreign fish detection during beat:") result_dicts.extend(foreign_fish_detection_beat(block_map, df, cs, all_contrasts, all_conditions, kw, cell_name, store_roc)) - print("Foreign fish detection during chirp:") + #print("Foreign fish detection during chirp:") result_dicts.extend(foreign_fish_detection_chirp(block_map, df, cs, all_contrasts, all_conditions, kw, cell_name, store_roc)) return result_dicts def process_cell(filename): - print(filename) nf = nix.File.open(filename, nix.FileMode.ReadOnly) block_map, all_contrasts, all_dfs, all_chirpsizes, all_conditions = sort_blocks(nf) if "baseline" not in block_map.keys():