import os
import glob
import pandas as pd
import nixio as nix
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection
from matplotlib.patches import ConnectionPatch
from sklearn.metrics import roc_curve, roc_auc_score
from IPython import embed

from util import firing_rate, despine, extract_am, within_group_distance, across_group_distance

figure_folder = "figures"
data_folder = "data"


def read_baseline(block):
    """Read the baseline spike times from a baseline block.

    Args:
        block (nix.Block): the block expected to hold the baseline recording.

    Returns:
        The content of the block's first data array, or an empty list if the
        block name does not contain "baseline" (case-insensitive, matching
        the classification done in sort_blocks).
    """
    spikes = []
    # use the same case-insensitive test as sort_blocks, otherwise a block that
    # sort_blocks mapped as "baseline" could be rejected here
    if "baseline" not in block.name.lower():
        print("Block %s does not appear to be a baseline block!" % block.name)
        return spikes
    spikes = block.data_arrays[0][:]
    return spikes


def sort_blocks(nix_file):
    """Index the blocks of a nix file by (contrast, df, condition).

    Non-baseline block names are split at "_"; part 1 is parsed as contrast,
    part 3 as condition and part 5 as delta-f (assumed naming scheme, e.g.
    "<x>_<contrast>_<x>_<condition>_<x>_<df>" -- inferred from the indices used).

    Args:
        nix_file (nix.File): the opened nix file.

    Returns:
        dict: maps (contrast, df, condition) tuples and "baseline" to blocks.
        list of float: the contrasts found, in order of first appearance.
        list of float: the delta-fs found, in order of first appearance.
        list of str: the conditions found, in order of first appearance.
    """
    block_map = {}
    contrasts = []
    deltafs = []
    conditions = []
    for b in nix_file.blocks:
        if "baseline" not in b.name.lower():
            name_parts = b.name.split("_")
            cntrst = float(name_parts[1])
            if cntrst not in contrasts:
                contrasts.append(cntrst)
            cndtn = name_parts[3]
            if cndtn not in conditions:
                conditions.append(cndtn)
            dltf = float(name_parts[5])
            if dltf not in deltafs:
                deltafs.append(dltf)
            block_map[(cntrst, dltf, cndtn)] = b
        else:
            block_map["baseline"] = b
    return block_map, contrasts, deltafs, conditions


def get_spikes(block):
    """Get the spike trains.

    Collects all data arrays whose type contains "spike_times" and whose name
    contains "response"; the trailing "_<n>" of the name gives the trial order.

    Args:
        block (nix.Block): the block holding the responses.

    Returns:
        list of np.ndarray: the spike trains, sorted by response id.
    """
    response_map = {}
    for da in block.data_arrays:
        if "spike_times" in da.type and "response" in da.name:
            resp_id = int(da.name.split("_")[-1])
            response_map[resp_id] = da
    return [response_map[k][:] for k in sorted(response_map.keys())]


def get_rates(spike_trains, duration, dt, kernel_width):
    """Convert the spike trains (list of spike_times) to rates.

    The conversion is done by util.firing_rate with the given kernel width
    (presumably a Gaussian kernel -- see util.firing_rate).

    Args:
        spike_trains (list of np.ndarray): the spike times of each trial.
        duration (float): trial duration in seconds.
        dt (float): sampling interval of the rate in seconds.
        kernel_width (float): width of the smoothing kernel in seconds.

    Returns:
        np.ndarray: Matrix of firing rates, 1. dimension is the number of trials
        np.ndarray: the time vector
    """
    time = np.arange(0.0, duration, dt)
    rates = np.zeros((len(spike_trains), len(time)))
    for i, sp in enumerate(spike_trains):
        rates[i, :] = firing_rate(sp, duration, kernel_width, dt)
    return rates, time


def get_firing_rate(block_map, df, contrast, condition, kernel_width=0.0005):
    """Return the firing rates and the spikes for one stimulus configuration.

    Args:
        block_map (dict): block lookup as returned by sort_blocks.
        df (float): the delta-f of the stimulus.
        contrast (float): the stimulus contrast.
        condition (str): the condition, e.g. "no-other", "self", "other".
        kernel_width (float, optional): kernel width in seconds. Defaults to 0.0005.

    Returns:
        np.ndarray: the time vector.
        np.ndarray: the rates with the first dimension representing the trials.
        list of np.ndarray: the spike trains.
    """
    block = block_map[(contrast, df, condition)]
    spikes = get_spikes(block)
    duration = float(block.metadata["stimulus parameter"]["duration"])
    dt = float(block.metadata["stimulus parameter"]["dt"])
    rates, time = get_rates(spikes, duration, dt, kernel_width)
    return time, rates, spikes


def get_signals(block):
    """Read the fish signals from block.

    Args:
        block (nix.Block): the block containing the data for a given df, contrast and condition

    Raises:
        ValueError: when the complete stimulus or the self frequency data is not found
        ValueError: when the other frequency data is missing in a block whose name does
            not contain "no-other"

    Returns:
        np.ndarray: the complete signal
        np.ndarray: the frequency profile of the recorded fish
        np.ndarray or None: the frequency profile of the other fish (None for "no-other" blocks)
        np.ndarray: the time axis
    """
    if "complete stimulus" not in block.data_arrays or "self frequency" not in block.data_arrays:
        raise ValueError("Signals not stored in block!")
    if "no-other" not in block.name and "other frequency" not in block.data_arrays:
        raise ValueError("Signals not stored in block!")

    other_freq = None
    signal = block.data_arrays["complete stimulus"][:]
    time = np.asarray(block.data_arrays["complete stimulus"].dimensions[0].axis(len(signal)))
    self_freq = block.data_arrays["self frequency"][:]
    if "no-other" not in block.name:
        other_freq = block.data_arrays["other frequency"][:]
    return signal, self_freq, other_freq, time


def create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, current_df, figure_name=None):
    """Plot stimulus frequencies, AM and firing rates per condition and contrast.

    The figure is saved as PDF into figure_folder (default name
    "chirp_responses.pdf") and closed afterwards.
    """
    conditions = ["no-other", "self", "other"]
    condition_labels = ["soliloquy", "self chirping", "other chirping"]
    min_time = 0.5
    max_time = min_time + 0.5

    fig = plt.figure(figsize=(6.5, 5.5))
    fig_grid = (len(all_contrasts) * 2 + 6, len(all_conditions) * 3 + 2)
    all_contrasts = sorted(all_contrasts, reverse=True)

    for i, condition in enumerate(conditions):
        # plot the signals
        block = block_map[(all_contrasts[0], current_df, condition)]
        signal, self_freq, other_freq, time = get_signals(block)
        am = extract_am(signal)
        self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"]
        other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"]
        window = (time > min_time) & (time < max_time)

        # plot frequency traces
        ax = plt.subplot2grid(fig_grid, (0, i * 3 + i), rowspan=2, colspan=3, fig=fig)
        ax.plot(time[window], self_freq[window], color="#ff7f0e", label="%iHz" % self_eodf)
        ax.text(min_time - 0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e", va="center",
                ha="right", fontsize=9)
        if other_freq is not None:
            ax.plot(time[window], other_freq[window], color="#1f77b4", label="%iHz" % other_eodf)
            ax.text(min_time - 0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4", va="center",
                    ha="right", fontsize=9)
        ax.set_title(condition_labels[i])
        despine(ax, ["top", "bottom", "left", "right"], True)

        # plot the am
        ax = plt.subplot2grid(fig_grid, (3, i * 3 + i), rowspan=2, colspan=3, fig=fig)
        ax.plot(time[window], signal[window], color="#2ca02c", label="signal")
        ax.plot(time[window], am[window], color="#d62728", label="am")
        despine(ax, ["top", "bottom", "left", "right"], True)
        ax.set_ylim([-1.25, 1.25])
        ax.legend(ncol=2, loc=(0.01, -0.5), fontsize=7, markerscale=0.5, frameon=False)

        # for each contrast plot the firing rate
        for j, contrast in enumerate(all_contrasts):
            t, rates, _ = get_firing_rate(block_map, current_df, contrast, condition)
            avg_resp = np.mean(rates, axis=0)
            error = np.std(rates, axis=0)
            rate_window = (t > min_time) & (t < max_time)
            ax = plt.subplot2grid(fig_grid, (j * 2 + 6, i * 3 + i), rowspan=2, colspan=3)
            ax.plot(t[rate_window], avg_resp[rate_window], color="k", lw=0.5)
            ax.fill_between(t[rate_window], (avg_resp - error)[rate_window],
                            (avg_resp + error)[rate_window], color="k", lw=0.0, alpha=0.25)
            ax.set_ylim([0, 750])
            ax.set_xlabel("")
            ax.set_ylabel("")
            major_ticks = np.arange(min_time, max_time + .01, 0.250)
            ax.set_xticks(major_ticks)
            # tick labels in ms relative to the window start; a list (not a map
            # iterator) is what matplotlib's set_ticklabels documents
            ax.set_xticklabels([int(v) for v in (major_ticks - min_time) * 1000])
            ax.set_xticks(np.arange(min_time, max_time + .01, 0.125), minor=True)
            if j < len(all_contrasts) - 1:
                ax.set_xticklabels([])
            ax.set_yticks(np.arange(0.0, 751., 500))
            ax.set_yticks(np.arange(0.0, 751., 125), minor=True)
            if i > 0:
                ax.set_yticklabels([])
            despine(ax, ["top", "right"], False)
            if i == 2:
                ax.text(max_time + 0.025 * max_time, 350, "c=%.3f" % all_contrasts[j],
                        color="#d62728", ha="left", fontsize=7)
            if i == 1:
                ax.set_xlabel("time [ms]")
            if i == 0:
                ax.set_ylabel("frequency [Hz]", va="center")
                ax.yaxis.set_label_coords(-0.45, 3.5)

    name = figure_name if figure_name is not None else "chirp_responses.pdf"
    name = (name + ".pdf") if ".pdf" not in name else name
    plt.savefig(os.path.join(figure_folder, name))
    plt.close()


def get_chirp_metadata(block):
    """Read the stimulus timing and chirp parameters from the block metadata.

    Returns:
        tuple: (trial_duration, dt, chirp_size, chirp_duration, chirp_times)
    """
    trial_duration = float(block.metadata["stimulus parameter"]["duration"])
    dt = float(block.metadata["stimulus parameter"]["dt"])
    chirp_duration = block.metadata["stimulus parameter"]["chirp_duration"]
    chirp_size = block.metadata["stimulus parameter"]["chirp_size"]
    chirp_times = block.metadata["stimulus parameter"]["chirp_times"]
    return trial_duration, dt, chirp_size, chirp_duration, chirp_times


def _roc_detection_result(baseline_distances, comparison_distances, cell_name, task, contrast, df,
                          kernel_width, store_roc):
    """Run an ROC analysis (baseline = 0, comparison = 1) and pack the result dict."""
    group = np.hstack((np.zeros_like(baseline_distances), np.ones_like(comparison_distances)))
    score = np.hstack((baseline_distances, comparison_distances))
    fpr, tpr, _ = roc_curve(group, score, pos_label=1)
    auc = roc_auc_score(group, score)
    result = {"cell": cell_name, "detection_task": task, "contrast": contrast, "df": df,
              "kernel_width": kernel_width, "auc": auc}
    if store_roc:
        result["true_positives"] = tpr
        result["false_positives"] = fpr
    return result


def foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005,
                                cell_name="", store_roc=False):
    """Quantify how well inter-chirp responses reveal the presence of another fish.

    Response snippets between chirps in the "no-other" condition are compared
    to those in the "self" condition via within/across group distances and an
    ROC analysis.

    Returns:
        list of dict: one result per contrast with detection_task "beat".
    """
    detection_performances = []

    for contrast in all_contrasts:
        print(" " * 50, end="\r")
        print("Contrast: %.3f" % contrast, end="\r")
        no_other_block = block_map[(contrast, df, "no-other")]
        self_block = block_map[(contrast, df, "self")]

        # get some metadata assuming they are all the same for each condition, which they should
        duration, dt, _, chirp_duration, chirp_times = get_chirp_metadata(self_block)

        # keep 1.5 chirp durations away from each chirp on either side
        interchirp_starts = np.add(chirp_times, 1.5 * chirp_duration)[:-1]
        interchirp_ends = np.subtract(chirp_times, 1.5 * chirp_duration)[1:]
        # mean inter-chirp interval, floored to full milliseconds
        ici = np.floor(np.mean(np.subtract(interchirp_ends, interchirp_starts)) * 1000) / 1000

        # get the spiking responses
        no_other_spikes = get_spikes(no_other_block)
        self_spikes = get_spikes(self_block)

        # get firing rates
        no_other_rates, _ = get_rates(no_other_spikes, duration, dt, kernel_width)
        self_rates, _ = get_rates(self_spikes, duration, dt, kernel_width)

        # get the response snippets between chirps, one row per (trial, interval)
        no_other_snippets = np.zeros((len(interchirp_starts) * no_other_rates.shape[0], int(ici / dt)))
        self_snippets = np.zeros_like(no_other_snippets)
        for i in range(no_other_rates.shape[0]):
            for j, start in enumerate(interchirp_starts):
                start_index = int(start / dt)
                end_index = start_index + no_other_snippets.shape[1]
                index = i * len(interchirp_starts) + j
                no_other_snippets[index, :] = no_other_rates[i, start_index:end_index]
                self_snippets[index, :] = self_rates[i, start_index:end_index]

        # get the distances
        baseline_dist = within_group_distance(no_other_snippets)
        comp_dist = across_group_distance(no_other_snippets, self_snippets)

        # sort and perform ROC analysis; only the strict lower triangle of the
        # within-group matrix holds unique pairs
        triangle_indices = np.tril_indices_from(baseline_dist, -1)
        detection_performances.append(
            _roc_detection_result(baseline_dist[triangle_indices], comp_dist.ravel(), cell_name,
                                  "beat", contrast, df, kernel_width, store_roc))
    print("\n")
    return detection_performances


def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005,
                                 cell_name="", store_roc=False):
    """Quantify how well chirp responses reveal the presence of another fish.

    Two comparisons per contrast:
        1. soliloquy vs. self chirping in company ("self vs soliloquy")
        2. other chirping vs. nobody is chirping ("other vs quietness")

    Returns:
        list of dict: two results per contrast.
    """
    detection_performances = []

    for contrast in all_contrasts:
        print(" " * 50, end="\r")
        print("Contrast: %.3f" % contrast, end="\r")
        no_other_block = block_map[(contrast, df, "no-other")]
        self_block = block_map[(contrast, df, "self")]
        # the responses to the other fish's chirps come from the "other" condition
        other_block = block_map[(contrast, df, "other")]

        # get some metadata assuming they are all the same for each condition, which they should
        duration, dt, _, chirp_duration, chirp_times = get_chirp_metadata(self_block)

        # get the spiking responses
        no_other_spikes = get_spikes(no_other_block)
        self_spikes = get_spikes(self_block)
        other_spikes = get_spikes(other_block)

        # get firing rates
        no_other_rates, _ = get_rates(no_other_spikes, duration, dt, kernel_width)
        self_rates, _ = get_rates(self_spikes, duration, dt, kernel_width)
        other_rates, _ = get_rates(other_spikes, duration, dt, kernel_width)

        # get the chirp response snippets, one row per (trial, chirp)
        alone_chirping_snippets = np.zeros((len(chirp_times) * no_other_rates.shape[0],
                                            int(chirp_duration / dt)))
        self_snippets = np.zeros_like(alone_chirping_snippets)
        other_snippets = np.zeros_like(alone_chirping_snippets)
        silence_snippets = np.zeros_like(alone_chirping_snippets)

        for i in range(no_other_rates.shape[0]):
            for j, chirp_time in enumerate(chirp_times):
                # window centred on the chirp, shifted by 3 ms
                start_index = int((chirp_time - chirp_duration / 2 + 0.003) / dt)
                end_index = start_index + alone_chirping_snippets.shape[1]
                index = i * len(chirp_times) + j
                alone_chirping_snippets[index, :] = no_other_rates[i, start_index:end_index]
                self_snippets[index, :] = self_rates[i, start_index:end_index]
                other_snippets[index, :] = other_rates[i, start_index:end_index]

                # chirp-free window 1.5 chirp durations after the chirp
                silence_start_index = int((chirp_time + 1.5 * chirp_duration) / dt)
                silence_end_index = silence_start_index + alone_chirping_snippets.shape[1]
                silence_snippets[index, :] = other_rates[i, silence_start_index:silence_end_index]

        # get the distances
        # 1. Soliloquy
        # 2. Nobody chirps, all alone aka baseline response
        # 3. I chirp while the other is present compared to self chirping without the other one present
        # 4. the other one chirps to me compared to baseline with nobody chirping
        alone_chirping_dist = within_group_distance(alone_chirping_snippets)
        silence_dist = within_group_distance(silence_snippets)
        self_vs_alone_dist = across_group_distance(alone_chirping_snippets, self_snippets)
        other_vs_silence_dist = across_group_distance(silence_snippets, other_snippets)

        # sort and perform ROC analysis for two comparisons
        # 1. soliloquy vs. self chirping in company
        # 2. other chirping vs. nobody is chirping
        triangle_indices = np.tril_indices_from(alone_chirping_dist, -1)
        valid_no_other_distances = alone_chirping_dist[triangle_indices]
        valid_silence_distances = silence_dist[triangle_indices]

        detection_performances.append(
            _roc_detection_result(valid_no_other_distances, self_vs_alone_dist.ravel(), cell_name,
                                  "self vs soliloquy", contrast, df, kernel_width, store_roc))
        detection_performances.append(
            _roc_detection_result(valid_silence_distances, other_vs_silence_dist.ravel(), cell_name,
                                  "other vs quietness", contrast, df, kernel_width, store_roc))
    print("\n")
    return detection_performances


def plot_detection_results(data_frame, df, kernel_width, cell, figure_name=None):
    """Plot ROC curves and discriminability vs. contrast for each detection task.

    Requires results produced with store_roc=True. The figure is saved as PDF
    into figure_folder (default name "foreign_fish_detection.pdf") and closed.
    """
    cell_results = data_frame[(data_frame.cell == cell) & (data_frame.df == df)]
    conditions = sorted(cell_results.detection_task.unique())
    kernels = sorted(cell_results.kernel_width.unique())

    fig = plt.figure(figsize=(6.5, 5.5))
    fig_grid = (8, 7)
    for i, task in enumerate(conditions):
        condition_results = cell_results[cell_results.detection_task == task]
        roc_data = condition_results[condition_results.kernel_width == kernel_width]
        contrasts = roc_data.contrast.unique()

        roc_ax = plt.subplot2grid(fig_grid, (i * 2 + i, 0), colspan=3, rowspan=2)
        roc_ax.set_title(task, fontsize=9, loc="left")
        auc_ax = plt.subplot2grid(fig_grid, (i * 2 + i, 4), colspan=3, rowspan=2)

        for contrast in contrasts:
            tpr = roc_data.true_positives[roc_data.contrast == contrast].values[0]
            fpr = roc_data.false_positives[roc_data.contrast == contrast].values[0]
            roc_ax.plot(fpr, tpr, label="%.3f" % contrast, zorder=2)
        if i == 0:
            roc_ax.legend(loc="lower right", fontsize=6, ncol=2, frameon=False, handletextpad=0.4,
                          columnspacing=1.0, labelspacing=0.25)
        roc_ax.plot([0., 1.], [0., 1.], color="k", lw=0.5, ls="--", zorder=0)
        roc_ax.set_xticks(np.arange(0.0, 1.01, 0.5))
        roc_ax.set_xticks(np.arange(0.0, 1.01, 0.25), minor=True)
        roc_ax.set_yticks(np.arange(0.0, 1.01, 0.5))
        roc_ax.set_yticks(np.arange(0.0, 1.01, 0.25), minor=True)
        if i == len(conditions) - 1:
            roc_ax.set_xticklabels(np.arange(0.0, 1.01, 0.5), fontsize=8)
            roc_ax.set_xlabel("false positive rate", fontsize=9)
        roc_ax.set_ylabel("true positive rate", fontsize=9)
        roc_ax.set_yticklabels(np.arange(0.0, 1.01, 0.5), fontsize=8)

        for k in kernels:
            contrasts = np.asarray(condition_results.contrast[condition_results.kernel_width == k])
            aucs = np.asarray(condition_results.auc[condition_results.kernel_width == k])
            aucs_sorted = aucs[np.argsort(contrasts)]
            contrasts_sorted = np.sort(contrasts)
            auc_ax.plot(contrasts_sorted, aucs_sorted, marker=".",
                        label=r"$\sigma$: %.2f ms" % (k * 1000), zorder=1)
        if i == len(conditions) - 1:
            auc_ax.set_xlabel("contrast [%]", fontsize=9)
        auc_ax.set_ylim([0.25, 1.0])
        auc_ax.set_yticks(np.arange(0.25, 1.01, 0.25))
        auc_ax.set_yticklabels(np.arange(0.25, 1.01, 0.25), fontsize=8)
        auc_ax.set_ylabel("discriminability", fontsize=9)
        if i == 0:
            auc_ax.legend(ncol=2, fontsize=6, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25,
                          frameon=False, loc="lower center")
        # chance level
        auc_ax.plot([min(contrasts), max(contrasts)], [0.5, 0.5], lw=0.5, ls="--", zorder=0)

    name = figure_name if figure_name is not None else "foreign_fish_detection.pdf"
    name = (name + ".pdf") if ".pdf" not in name else name
    fig.savefig(os.path.join(figure_folder, name))
    # release the figure like the other plotting functions do
    plt.close(fig)


def plot_comparisons(block_map, all_dfs, all_contrasts, all_conditions, current_df):
    """Sketch the three chirp comparisons on top of the frequency traces.

    Saved as "comparisons.pdf" into figure_folder.
    """
    conditions = ["no-other", "self", "other"]
    condition_labels = ["soliloquy", "self chirping", "other chirping"]
    min_time = 0.5
    max_time = min_time + 0.5

    fig = plt.figure(figsize=(6.5, 2.))
    fig_grid = (3, len(all_conditions) * 3 + 2)
    axes = []
    for i, condition in enumerate(conditions):
        # plot the signals
        block = block_map[(all_contrasts[0], current_df, condition)]
        signal, self_freq, other_freq, time = get_signals(block)
        self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"]
        other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"]
        window = (time > min_time) & (time < max_time)

        # plot frequency traces
        ax = plt.subplot2grid(fig_grid, (0, i * 3 + i), rowspan=2, colspan=3, fig=fig)
        ax.plot(time[window], self_freq[window], color="#ff7f0e", label="%iHz" % self_eodf)
        ax.text(min_time - 0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e", va="center",
                ha="right", fontsize=9)
        if other_freq is not None:
            ax.plot(time[window], other_freq[window], color="#1f77b4", label="%iHz" % other_eodf)
            ax.text(min_time - 0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4", va="center",
                    ha="right", fontsize=9)
        # ax.set_title(condition_labels[i])
        ax.set_ylim([735, 885])
        despine(ax, ["top", "bottom", "left", "right"], True)
        axes.append(ax)

    # shaded boxes marking the compared time windows
    rects = [Rectangle((0.675, 740), 0.098, 140), Rectangle((0.57, 740), 0.098, 140)]
    pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
    axes[0].add_collection(pc)

    rects = [Rectangle((0.675, 740), 0.098, 140), Rectangle((0.575, 740), 0.098, 140)]
    pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
    axes[1].add_collection(pc)

    rects = [Rectangle((0.57, 740), 0.098, 140)]
    pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
    axes[2].add_collection(pc)

    # arrows connecting the compared windows across panels
    con = ConnectionPatch(xyA=(0.625, 735), xyB=(0.625, 740), coordsA="data", coordsB="data",
                          axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5,
                          connectionstyle="arc3,rad=.35")
    axes[1].add_artist(con)
    con = ConnectionPatch(xyA=(0.725, 885), xyB=(0.725, 880), coordsA="data", coordsB="data",
                          axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5,
                          connectionstyle="arc3,rad=-.25")
    axes[1].add_artist(con)
    con = ConnectionPatch(xyA=(0.725, 735), xyB=(0.625, 740), coordsA="data", coordsB="data",
                          axesA=axes[1], axesB=axes[2], arrowstyle="<->", shrinkB=5,
                          connectionstyle="arc3,rad=.35")
    axes[1].add_artist(con)

    axes[0].text(1., 660, "2.")
    axes[1].text(1.05, 660, "3.")
    axes[0].text(1.1, 890, "1.")
    fig.subplots_adjust(bottom=0.1, top=0.8, left=0.1, right=0.9)
    fig.savefig(os.path.join(figure_folder, "comparisons.pdf"))
    plt.close()


def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None,
                           cell_name="", store_roc=False):
    """Run the beat and chirp detection analyses for all kernels and dfs.

    Args:
        current_df (float, optional): restrict the analysis to this df; all dfs if None.

    Returns:
        list of dict: the concatenated results of all analyses.
    """
    dfs = [current_df] if current_df is not None else all_dfs
    kernels = [0.00025, 0.0005, 0.001, 0.0025]
    result_dicts = []
    for df in dfs:
        for kw in kernels:
            print("df: %i, kernel: %.4f" % (df, kw))
            print("Foreign fish detection during beat:")
            result_dicts.extend(foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions,
                                                            kw, cell_name, store_roc))
            print("Foreign fish detection during chirp:")
            result_dicts.extend(foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions,
                                                             kw, cell_name, store_roc))
    return result_dicts


def estimate_chirp_phase(am, chirp_times):
    """Placeholder, not implemented yet; returns None."""
    # TODO: implement chirp phase estimation
    pass


def process_cell(filename, dfs=None, contrasts=None, conditions=None):
    """Run the foreign fish detection analysis on one nix file.

    The dfs, contrasts and conditions arguments are currently not used; all
    dfs/contrasts found in the file are analysed.

    Returns:
        list of dict: the detection results (without ROC curves).
    """
    nf = nix.File.open(filename, nix.FileMode.ReadOnly)
    try:
        block_map, all_contrasts, all_dfs, all_conditions = sort_blocks(nf)
        if "baseline" in block_map.keys():
            baseline_spikes = read_baseline(block_map["baseline"])  # not used further yet
        else:
            print("ERROR: no baseline data for file %s!" % filename)

        cell_name = os.path.basename(filename).split(".nix")[0]
        results = foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions,
                                         current_df=None, cell_name=cell_name, store_roc=False)
    finally:
        # close the file even if the analysis fails
        nf.close()
    return results


def plot_examples(filename, dfs=None, contrasts=None, conditions=None):
    """Create the example figures for one nix file (enable the desired plots below)."""
    nf = nix.File.open(filename, nix.FileMode.ReadOnly)
    try:
        block_map, all_contrasts, all_dfs, all_conditions = sort_blocks(nf)
        if "baseline" in block_map.keys():
            baseline_spikes = read_baseline(block_map["baseline"])  # not used further yet
        else:
            print("ERROR: no baseline data for file %s!" % filename)

        # plot the responses
        # fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf"
        # create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name)
        # fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf"
        # create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name)

        # sketch showing the comparisons
        # plot_comparisons(block_map, all_dfs, all_contrasts, all_conditions, 20)

        # plot the discrimination analyses
        # cell_name = filename.split(os.path.sep)[-1].split(".nix")[0]
        # results = foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=20,
        #                                  cell_name=cell_name, store_roc=True)
        # pdf = pd.DataFrame(results)
        # plot_detection_results(pdf, 20, 0.001, cell_name)
    finally:
        nf.close()


def main():
    """Analyse all cell*.nix files in data_folder and drop into IPython."""
    nix_files = sorted(glob.glob(os.path.join(data_folder, "cell*.nix")))
    all_results = []  # results of every processed cell, available in embed()
    for nix_file in nix_files:
        # plot_examples(nix_file, dfs=[20], contrasts=[20], conditions=["self"])
        results = process_cell(nix_file, dfs=[], contrasts=[20], conditions=["self"])
        all_results.extend(results)
        # break
    embed()


if __name__ == "__main__":
    main()