work work

This commit is contained in:
Jan Grewe 2023-04-01 18:16:13 +02:00
parent d98d38c9be
commit 116ebfd70f
3 changed files with 46 additions and 38 deletions

View File

@ -281,34 +281,22 @@ def foreign_fish_detection_example_plot(args):
store.close()
def performance_plot(args):
if not os.path.exists(args.inputfile):
raise ValueError("Error plotting discrimination performance. Input file (%s) not found!" % args.inputfile)
df = pd.read_csv(args.inputfile, sep=";")
dfs = np.sort(df.df.unique())
contrasts = np.sort(df.contrast.unique())
tasks = df.detection_task.unique()
kernel_widths = list(df.kernel_width.unique())
kernel_width = args.kernel_width if args.kernel_width in kernel_widths else kernel_widths[0]
chirpsizes = list(df.chirpsize.unique())
if args.chirpsize not in chirpsizes:
raise ValueError("Error plotting discrimination performance. Requested chirpsize (%i Hz) is not found in the data. Available chirpsizes are: " % args.chirpsize + str(chirpsizes))
def plot_surfaces(data_frame, dfs, contrasts, tasks, selected_dfs, selected_contrasts, kernel_width, chirpsize, filename):
X, Y = np.meshgrid(dfs, contrasts)
Z = np.zeros_like(X)
fig = plt.figure(figsize=(8.0, 10))
fig_grid = (19, 10)
for index, t in enumerate(tasks):
for index, t in enumerate(tasks):
ax = plt.subplot2grid(fig_grid, (index * 5 + index * 2, 0), colspan=4, rowspan=5, projection="3d" )
ax.set_title(t, loc="left", pad=-0.5)
for i, d in enumerate(dfs):
for j, c in enumerate(contrasts):
data_df = df[(df.kernel_width == kernel_width) & (df.contrast == c) & (df.df == d) & (df.detection_task == t) & (df.chirpsize == args.chirpsize)]
data_df = data_frame[(data_frame.kernel_width == kernel_width) & (data_frame.contrast == c) & (data_frame.df == d) & (data_frame.detection_task == t) & (data_frame.chirpsize == chirpsize)]
Z[j, i] = np.mean(data_df.auc)
ax.plot_surface(X, Y, Z, cmap=cm.coolwarm, linewidth=0.2, edgecolor="white", antialiased=True, alpha=0.85, vmin=0.5, vmax=1.0)
ax.set_xlabel(r"$\Delta_f [Hz]$", fontsize=8)
ax.set_xlabel(r"$\Delta f [Hz]$", fontsize=8)
ax.set_ylabel("contrast [%]", fontsize=8)
ax.set_zlabel("performance", fontsize=8, rotation=180)
ax.set_zlim([0.45, 1.0])
@ -318,12 +306,12 @@ def performance_plot(args):
cntrst_ax = plt.subplot2grid(fig_grid, (index * 5 + index * 2, 6), colspan=4, rowspan=2)
performances = np.zeros_like(contrasts)
errors = np.zeros_like(contrasts)
for d in args.deltafs:
for d in selected_dfs:
for i, c in enumerate(contrasts):
data_df = df[(df.kernel_width == kernel_width) & (df.contrast == c) & (df.df == d) & (df.detection_task == t) & (df.chirpsize == args.chirpsize)]
data_df = data_frame[(data_frame.kernel_width == kernel_width) & (data_frame.contrast == c) & (data_frame.df == d) & (data_frame.detection_task == t) & (data_frame.chirpsize == chirpsize)]
performances[i] = np.mean(data_df.auc)
errors[i] = np.std(data_df.auc)
cntrst_ax.errorbar(contrasts, performances, yerr=errors, fmt=".-", label=r"$\Delta_f:$ %i Hz" % d)
cntrst_ax.errorbar(contrasts, performances, yerr=errors, fmt=".-", label=r"$\Delta f:$ %i Hz" % d)
cntrst_ax.set_ylim([0.25, 1.0])
cntrst_ax.set_ylabel("performance", fontsize=8)
cntrst_ax.set_xlabel("contrast [%]", fontsize=8)
@ -333,23 +321,46 @@ def performance_plot(args):
df_ax = plt.subplot2grid(fig_grid, (index * 5 + index * 2 + 3, 6), colspan=4, rowspan=2)
performances = np.zeros_like(dfs)
errors = np.zeros_like(dfs)
for c in args.contrasts:
for c in selected_contrasts:
for i, d in enumerate(dfs):
data_df = df[(df.kernel_width == kernel_width) & (df.contrast == c) & (df.df == d) & (df.detection_task == t) & (df.chirpsize == args.chirpsize)]
data_df = data_frame[(data_frame.kernel_width == kernel_width) & (data_frame.contrast == c) & (data_frame.df == d) & (data_frame.detection_task == t) & (data_frame.chirpsize == chirpsize)]
performances[i] = np.mean(data_df.auc)
errors[i] = np.std(data_df.auc)
df_ax.errorbar(dfs, performances, yerr=errors, fmt=".-", label="%.2f" % c)
df_ax.set_ylim([0.25, 1.0])
df_ax.set_ylabel("performance", fontsize=8)
df_ax.set_xlabel(r"$\Delta_f$ [Hz]", fontsize=8)
df_ax.set_xlabel(r"$\Delta f$ [Hz]", fontsize=8)
df_ax.hlines(0.5, dfs[0], dfs[-1], color="k", ls="--", lw=0.2)
df_ax.legend(fontsize=7, ncol=4, frameon=False, loc="lower center", mode="expand", handlelength=1.0, handletextpad=0.25)
fig.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.975)
fig.savefig(args.outfile)
fig.savefig(filename)
plt.close()
def performance_plot(args):
if not os.path.exists(args.inputfile):
raise ValueError("Error plotting discrimination performance. Input file (%s) not found!" % args.inputfile)
df = pd.read_csv(args.inputfile, sep=";")
all_dfs = np.sort(df.df.unique())
all_contrasts = np.sort(df.contrast.unique())
tasks = df.detection_task.unique()
kernel_widths = list(df.kernel_width.unique())
kernel_width = args.kernel_width if args.kernel_width in kernel_widths else kernel_widths[0]
chirpsizes = list(df.chirpsize.unique())
if args.chirpsize not in chirpsizes:
raise ValueError("Error plotting discrimination performance. Requested chirpsize (%i Hz) is not found in the data. Available chirpsizes are: " % args.chirpsize + str(chirpsizes))
selected_dfs = args.deltafs
selected_contrasts = args.contrasts
filename = args.outfile
temp = filename.split('.')
filename = temp[0] + "_1." + temp[-1]
plot_surfaces(df, all_dfs, all_contrasts, tasks[:3], selected_dfs, selected_contrasts, kernel_width, args.chirpsize, filename)
filename = temp[0] + "_2." + temp[-1]
plot_surfaces(df, all_dfs, all_contrasts, tasks[3:], selected_dfs, selected_contrasts, kernel_width, args.chirpsize, filename)
def main():
parser = argparse.ArgumentParser(description="Plotting tool for chrip probing project.")
subparsers = parser.add_subparsers(title="commands",

View File

@ -2,12 +2,9 @@ import numpy as np
import nixio as nix
import argparse
import os
from numpy.core.fromnumeric import repeat
from traitlets.traitlets import Instance
from chirp_ams import get_signals
from model import simulate, load_models
from IPython import embed
import matplotlib.pyplot as plt
import multiprocessing
from joblib import Parallel, delayed
@ -26,7 +23,7 @@ def append_settings(section, sec_name, sec_type, settings):
else:
section[k] = settings[k]
def save(filename, name, stimulus_settings, model_settings, self_signal, other_signal, self_freq, other_freq, complete_stimulus, responses, overwrite=False):
if os.path.exists(filename) and not overwrite:
nf = nix.File.open(filename, nix.FileMode.ReadWrite)
@ -237,7 +234,7 @@ def main():
num_models = len(models)
indices = list(range(len(models)))
np.random.shuffle(indices)
Parallel(n_jobs=args.jobs)(delayed(simulate_cell)(cell_id, models, args) for cell_id in indices[:num_models])

View File

@ -82,8 +82,8 @@ def foreign_fish_detection_beat(block_map, df, cs, all_contrasts, all_conditions
detection_performances = []
for contrast in all_contrasts:
print(" " * 50, end="\r")
print("Contrast: %.3f" % contrast, end="\r")
# print(" " * 50, end="\r")
# print("Contrast: %.3f" % contrast, end="\r")
no_other_block = block_map[(contrast, df, cs, "no-other")]
self_block = block_map[(contrast, df, cs, "self")]
@ -133,7 +133,7 @@ def foreign_fish_detection_beat(block_map, df, cs, all_contrasts, all_conditions
detection_performances.append({"cell": cell_name, "detection_task": "beat", "contrast": contrast, "df": df, "kernel_width": kernel_width, "chirpsize": cs, "auc": auc, "true_positives": tpr, "false_positives": fpr})
else:
detection_performances.append({"cell": cell_name, "detection_task": "beat", "contrast": contrast, "df": df, "kernel_width": kernel_width, "chirpsize": cs, "auc": auc})
print("\n")
# print("\n")
return detection_performances
@ -168,8 +168,8 @@ def foreign_fish_detection_chirp(block_map, df, cs, all_contrasts, all_condition
detection_performances = []
for contrast in all_contrasts:
print(" " * 50, end="\r")
print("Contrast: %.3f" % contrast, end="\r")
# print(" " * 50, end="\r")
# print("Contrast: %.3f" % contrast, end="\r")
no_other_block = block_map[(contrast, df, cs, "no-other")]
self_block = block_map[(contrast, df, cs, "self")]
other_block = block_map[(contrast, df, cs, "self")]
@ -297,7 +297,7 @@ def foreign_fish_detection_chirp(block_map, df, cs, all_contrasts, all_condition
detection_performances.append({"cell": cell_name, "detection_task": "self vs other", "contrast": contrast, "df": df, "kernel_width": kernel_width, "chirpsize": cs, "auc": auc, "true_positives": tpr, "false_positives": fpr})
else:
detection_performances.append({"cell": cell_name, "detection_task": "self vs other", "contrast": contrast, "df": df, "kernel_width": kernel_width, "chirpsize": cs, "auc": auc})
print("\n")
# print("\n")
return detection_performances
@ -308,18 +308,18 @@ def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, al
result_dicts = []
for cs in chirp_sizes:
for df in dfs:
print("%s, chirp size: %i Hz, deltaf %.1f Hz" % (cell_name, cs, df))
for kw in kernels:
print("cs: %i Hz, df: %i Hz, kernel: %.4fs" % (cs, df, kw))
print("Foreign fish detection during beat:")
#print("cs: %i Hz, df: %i Hz, kernel: %.4fs" % (cs, df, kw))
#print("Foreign fish detection during beat:")
result_dicts.extend(foreign_fish_detection_beat(block_map, df, cs, all_contrasts, all_conditions, kw, cell_name, store_roc))
print("Foreign fish detection during chirp:")
#print("Foreign fish detection during chirp:")
result_dicts.extend(foreign_fish_detection_chirp(block_map, df, cs, all_contrasts, all_conditions, kw, cell_name, store_roc))
return result_dicts
def process_cell(filename):
print(filename)
nf = nix.File.open(filename, nix.FileMode.ReadOnly)
block_map, all_contrasts, all_dfs, all_chirpsizes, all_conditions = sort_blocks(nf)
if "baseline" not in block_map.keys():