import os
import glob
import pandas as pd
import nixio as nix
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score
from IPython import embed
import multiprocessing
from joblib import Parallel, delayed

from util import firing_rate, despine, extract_am, within_group_distance, across_group_distance
from nix_util import read_baseline, sort_blocks, get_spikes, get_signals, get_chirp_metadata

figure_folder = "figures"
data_folder = "data"


def get_rates(spike_trains, duration, dt, kernel_width):
    """Convert the spike trains (list of spike_times) to rates using a Gaussian kernel of the given size.

    Args:
        spike_trains (list): one sequence of spike times per trial.
        duration (float): duration of each trial; the time axis runs from 0 to duration (exclusive).
        dt (float): sampling interval of the rate vectors.
        kernel_width (float): kernel width handed on to ``util.firing_rate``.

    Returns:
        np.ndarray: Matrix of firing rates, 1. dimension is the number of trials.
        np.ndarray: the time vector.
    """
    time = np.arange(0.0, duration, dt)
    rates = np.zeros((len(spike_trains), len(time)))
    for i, sp in enumerate(spike_trains):
        # NOTE(review): assumes firing_rate returns exactly len(time) samples — TODO confirm
        rates[i, :] = firing_rate(sp, duration, kernel_width, dt)
    return rates, time


def get_firing_rate(block_map, df, contrast, condition, kernel_width=0.0005):
    """Return the time axis, the firing rates and the spike trains of one stimulus block.

    Args:
        block_map (dict): maps (contrast, df, condition) tuples to nix blocks.
        df: difference frequency used as part of the block key.
        contrast: stimulus contrast used as part of the block key.
        condition (str): stimulus condition used as part of the block key.
        kernel_width (float, optional): kernel width for the rate estimation. Defaults to 0.0005.

    Returns:
        np.ndarray: the time vector.
        np.ndarray: the rates with the first dimension representing the trials.
        list: the spike trains.
    """
    block = block_map[(contrast, df, condition)]
    spikes = get_spikes(block)
    duration = float(block.metadata["stimulus parameter"]["duration"])
    dt = float(block.metadata["stimulus parameter"]["dt"])
    rates, time = get_rates(spikes, duration, dt, kernel_width)
    return time, rates, spikes


def _cut_snippets(rates, start_times, dt, snippet_length, n_trials):
    """Cut a snippet of snippet_length samples at every start time out of each of the first n_trials rate rows.

    Rows of the result are ordered trial-major: row = trial * len(start_times) + start.
    """
    snippets = np.zeros((len(start_times) * n_trials, snippet_length))
    for trial in range(n_trials):
        for j, start in enumerate(start_times):
            start_index = int(start / dt)
            snippets[trial * len(start_times) + j, :] = rates[trial, start_index:start_index + snippet_length]
    return snippets


def _roc_analysis(within_dist, across_dist):
    """ROC analysis separating within-group distances (label 0) from across-group distances (label 1).

    Only the strictly lower triangle of the square within-group distance matrix is used, so that
    self-distances and duplicated pairs do not enter; all across-group distances are used.

    Returns:
        tuple: (false positive rates, true positive rates, area under the ROC curve)
    """
    within = within_dist[np.tril_indices_from(within_dist, -1)]
    across = np.ravel(across_dist)
    group = np.hstack((np.zeros_like(within), np.ones_like(across)))
    score = np.hstack((within, across))
    fpr, tpr, _ = roc_curve(group, score, pos_label=1)
    auc = roc_auc_score(group, score)
    return fpr, tpr, auc


def _performance_entry(cell_name, task, contrast, df, kernel_width, auc, tpr, fpr, store_roc):
    """Build one result record; the full ROC curve is only included when store_roc is set."""
    entry = {"cell": cell_name, "detection_task": task, "contrast": contrast, "df": df,
             "kernel_width": kernel_width, "auc": auc}
    if store_roc:
        entry["true_positives"] = tpr
        entry["false_positives"] = fpr
    return entry


def foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005, cell_name="", store_roc=False):
    """Quantify how well the response between chirps reveals the presence of another fish (beat only).

    For every contrast, firing-rate snippets taken between chirps in the "no-other" condition are
    compared with the same windows in the "self" condition using an ROC analysis on pairwise distances.

    Args:
        block_map (dict): maps (contrast, df, condition) tuples to nix blocks.
        df: difference frequency to analyse.
        all_contrasts (iterable): contrasts to analyse.
        all_conditions: unused, kept for interface compatibility.
        kernel_width (float, optional): kernel width for the rate estimation.
        cell_name (str, optional): label stored in each result record.
        store_roc (bool, optional): also store the ROC curve (true/false positive rates).

    Returns:
        list of dict: one record per contrast with detection_task "beat".
    """
    detection_performances = []
    for contrast in all_contrasts:
        print(" " * 50, end="\r")
        print("Contrast: %.3f" % contrast, end="\r")
        no_other_block = block_map[(contrast, df, "no-other")]
        self_block = block_map[(contrast, df, "self")]

        # get some metadata assuming they are all the same for each condition, which they should
        duration, dt, _, chirp_duration, chirp_times = get_chirp_metadata(self_block)
        # inter-chirp windows keep 1.5 chirp durations distance from the neighbouring chirps
        interchirp_starts = np.add(chirp_times, 1.5 * chirp_duration)[:-1]
        interchirp_ends = np.subtract(chirp_times, 1.5 * chirp_duration)[1:]
        # mean inter-chirp interval, floored to the third decimal
        ici = np.floor(np.mean(np.subtract(interchirp_ends, interchirp_starts)) * 1000) / 1000

        no_other_rates, _ = get_rates(get_spikes(no_other_block), duration, dt, kernel_width)
        self_rates, _ = get_rates(get_spikes(self_block), duration, dt, kernel_width)

        # response snippets between chirps; the trial count of the "no-other" condition defines both sets
        n_trials = no_other_rates.shape[0]
        snippet_length = int(ici / dt)
        no_other_snippets = _cut_snippets(no_other_rates, interchirp_starts, dt, snippet_length, n_trials)
        self_snippets = _cut_snippets(self_rates, interchirp_starts, dt, snippet_length, n_trials)

        baseline_dist = within_group_distance(no_other_snippets)
        comp_dist = across_group_distance(no_other_snippets, self_snippets)

        fpr, tpr, auc = _roc_analysis(baseline_dist, comp_dist)
        detection_performances.append(_performance_entry(cell_name, "beat", contrast, df, kernel_width,
                                                         auc, tpr, fpr, store_roc))
    print("\n")
    return detection_performances


def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005, cell_name="", store_roc=False):
    """Quantify how well chirp responses reveal the presence of another fish.

    For every contrast two ROC analyses are run on firing-rate snippets:

    1. "self vs soliloquy": chirp-window responses in the "self" condition compared with the same
       windows in the "no-other" condition.
    2. "other vs quietness": chirp-window responses in the "other" condition compared with windows
       after each chirp in the same condition, where nobody chirps.

    Args:
        block_map (dict): maps (contrast, df, condition) tuples to nix blocks.
        df: difference frequency to analyse.
        all_contrasts (iterable): contrasts to analyse.
        all_conditions: unused, kept for interface compatibility.
        kernel_width (float, optional): kernel width for the rate estimation.
        cell_name (str, optional): label stored in each result record.
        store_roc (bool, optional): also store the ROC curve (true/false positive rates).

    Returns:
        list of dict: two records per contrast.
    """
    detection_performances = []
    for contrast in all_contrasts:
        print(" " * 50, end="\r")
        print("Contrast: %.3f" % contrast, end="\r")
        no_other_block = block_map[(contrast, df, "no-other")]
        self_block = block_map[(contrast, df, "self")]
        # Previously this read the "self" block, so the "other vs quietness" task silently
        # analysed the self-chirping recording.
        # NOTE(review): assumes the other-fish condition is labelled "other" by sort_blocks — confirm.
        other_block = block_map[(contrast, df, "other")]

        # get some metadata assuming they are all the same for each condition, which they should
        duration, dt, _, chirp_duration, chirp_times = get_chirp_metadata(self_block)

        no_other_rates, _ = get_rates(get_spikes(no_other_block), duration, dt, kernel_width)
        self_rates, _ = get_rates(get_spikes(self_block), duration, dt, kernel_width)
        other_rates, _ = get_rates(get_spikes(other_block), duration, dt, kernel_width)

        n_trials = no_other_rates.shape[0]
        snippet_length = int(chirp_duration / dt)
        # chirp windows start half a chirp duration before each chirp time, shifted by a fixed 0.003
        # (the reason for this offset is not visible here)
        chirp_starts = [chirp_time - chirp_duration / 2 + 0.003 for chirp_time in chirp_times]
        # "silent" windows start 1.5 chirp durations after each chirp
        silence_starts = [chirp_time + 1.5 * chirp_duration for chirp_time in chirp_times]

        alone_chirping_snippets = _cut_snippets(no_other_rates, chirp_starts, dt, snippet_length, n_trials)
        self_snippets = _cut_snippets(self_rates, chirp_starts, dt, snippet_length, n_trials)
        other_snippets = _cut_snippets(other_rates, chirp_starts, dt, snippet_length, n_trials)
        silence_snippets = _cut_snippets(other_rates, silence_starts, dt, snippet_length, n_trials)

        # distances:
        # - within soliloquy (self chirping, nobody else present)
        # - within silence (nobody chirps)
        # - self chirping in company vs. soliloquy
        # - the other one chirping vs. silence
        alone_chirping_dist = within_group_distance(alone_chirping_snippets)
        silence_dist = within_group_distance(silence_snippets)
        self_vs_alone_dist = across_group_distance(alone_chirping_snippets, self_snippets)
        other_vs_silence_dist = across_group_distance(silence_snippets, other_snippets)

        fpr, tpr, auc = _roc_analysis(alone_chirping_dist, self_vs_alone_dist)
        detection_performances.append(_performance_entry(cell_name, "self vs soliloquy", contrast, df,
                                                         kernel_width, auc, tpr, fpr, store_roc))

        fpr, tpr, auc = _roc_analysis(silence_dist, other_vs_silence_dist)
        detection_performances.append(_performance_entry(cell_name, "other vs quietness", contrast, df,
                                                         kernel_width, auc, tpr, fpr, store_roc))
    print("\n")
    return detection_performances


def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, cell_name="", store_roc=False, kernel_widths=None):
    """Run the beat and chirp detection analyses for each difference frequency and kernel width.

    Args:
        block_map (dict): maps (contrast, df, condition) tuples to nix blocks.
        all_dfs (iterable): difference frequencies, used when current_df is None.
        all_contrasts (iterable): contrasts to analyse.
        all_conditions: passed on to the detection functions.
        current_df (optional): analyse only this difference frequency.
        cell_name (str, optional): label stored in each result record.
        store_roc (bool, optional): also store the ROC curves.
        kernel_widths (iterable, optional): kernel widths to test; defaults to
            [0.00025, 0.0005, 0.001, 0.0025].

    Returns:
        list of dict: all result records.
    """
    dfs = [current_df] if current_df is not None else all_dfs
    kernels = [0.00025, 0.0005, 0.001, 0.0025] if kernel_widths is None else list(kernel_widths)
    result_dicts = []
    for df in dfs:
        for kw in kernels:
            print("df: %i, kernel: %.4f" % (df, kw))
            print("Foreign fish detection during beat:")
            result_dicts.extend(foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions,
                                                            kw, cell_name, store_roc))
            print("Foreign fish detection during chirp:")
            result_dicts.extend(foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions,
                                                             kw, cell_name, store_roc))
    return result_dicts


def estimate_chirp_phase(am, chirp_times):
    # TODO: not implemented yet
    pass


def process_cell(filename, dfs=None, contrasts=None, conditions=None):
    """Run the foreign-fish detection analyses on a single nix file.

    Args:
        filename (str): path of the nix file.
        dfs, contrasts, conditions: currently ignored; all values found in the file are analysed.

    Returns:
        list of dict: result records of foreign_fish_detection.
    """
    print(filename)
    nf = nix.File.open(filename, nix.FileMode.ReadOnly)
    try:
        block_map, all_contrasts, all_dfs, all_conditions = sort_blocks(nf)
        if "baseline" in block_map:
            # baseline is read but not used by the analyses below yet
            baseline_spikes = read_baseline(block_map["baseline"])
        else:
            print("ERROR: no baseline data for file %s!" % filename)
        cell_name = os.path.basename(filename).split(".nix")[0]
        results = foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None,
                                         cell_name=cell_name, store_roc=False)
    finally:
        # make sure the file handle is released even if the analysis fails
        nf.close()
    return results


def main():
    """Analyse all cell*.nix files in the data folder in parallel and store the results as csv."""
    # leave six cores free, but request at least one worker: joblib rejects n_jobs=0 and
    # interprets negative values as "all cores but N"
    num_cores = max(1, multiprocessing.cpu_count() - 6)
    nix_files = sorted(glob.glob(os.path.join(data_folder, "cell*.nix")))
    processed_list = Parallel(n_jobs=num_cores)(delayed(process_cell)(nix_file) for nix_file in nix_files)
    results = [record for processed in processed_list for record in processed]
    df = pd.DataFrame(results)
    df.to_csv(os.path.join(data_folder, "discimination_results.csv"), sep=";")


if __name__ == "__main__":
    main()