import os
import glob
import pandas as pd
import nixio as nix
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score
from IPython import embed
import multiprocessing
from joblib import Parallel, delayed

from util import firing_rate, despine, extract_am, within_group_distance, across_group_distance
from nix_util import read_baseline, sort_blocks, get_spikes, get_signals, get_chirp_metadata

figure_folder = "figures"
data_folder = "data"


def get_rates(spike_trains, duration, dt, kernel_width):
    """Convert the spike trains (list of spike_times) to rates using a Gaussian kernel of the given size.

    Args:
        spike_trains (list): one sequence of spike times per trial.
        duration (float): duration of each trial; the time axis spans [0, duration).
        dt (float): sampling interval of the rate vectors (same unit as duration).
        kernel_width (float): width of the kernel passed on to ``util.firing_rate``.

    Returns:
        np.ndarray: Matrix of firing rates, 1. dimension is the number of trials
        np.ndarray: the time vector
    """
    time = np.arange(0.0, duration, dt)
    rates = np.zeros((len(spike_trains), len(time)))
    for i, sp in enumerate(spike_trains):
        rates[i, :] = firing_rate(sp, duration, kernel_width, dt)
    return rates, time


def get_firing_rate(block_map, df, contrast, condition, kernel_width=0.0005, chirpsize=None):
    """Returns the firing rates and the spikes of a single stimulus block.

    Args:
        block_map (dict): maps stimulus-parameter tuples to nix blocks.
        df: the difference frequency of the requested block.
        contrast: the contrast of the requested block.
        condition (str): the chirp condition, e.g. "self", "other", or "no-other".
        kernel_width (float, optional): kernel width for the rate estimation. Defaults to 0.0005.
        chirpsize (optional): if given, the block is looked up with the key
            ``(contrast, df, chirpsize, condition)``, the key layout used by the
            detection functions of this module. If None (default) the legacy key
            ``(contrast, df, condition)`` is used.

    Returns:
        np.ndarray: the time vector.
        np.ndarray: the rates with the first dimension representing the trials.
        np.ndarray: the spike trains.
    """
    if chirpsize is None:
        key = (contrast, df, condition)
    else:
        key = (contrast, df, chirpsize, condition)
    block = block_map[key]
    spikes = get_spikes(block)
    duration = float(block.metadata["stimulus parameter"]["duration"])
    dt = float(block.metadata["stimulus parameter"]["dt"])
    rates, time = get_rates(spikes, duration, dt, kernel_width)
    return time, rates, spikes


def _roc_detection_result(negative_distances, positive_distances, task, cell_name, contrast,
                          df, kernel_width, chirpsize, store_roc):
    """Run a ROC analysis on two distance populations and pack the result into a dict.

    Distances in ``negative_distances`` are labelled 0 and those in
    ``positive_distances`` are labelled 1; the distances themselves are the scores.

    Returns:
        dict: keys cell, detection_task, contrast, df, kernel_width, chirpsize, auc and,
        if store_roc is True, additionally true_positives and false_positives.
    """
    group = np.hstack((np.zeros_like(negative_distances), np.ones_like(positive_distances)))
    score = np.hstack((negative_distances, positive_distances))
    auc = roc_auc_score(group, score)
    result = {"cell": cell_name, "detection_task": task, "contrast": contrast, "df": df,
              "kernel_width": kernel_width, "chirpsize": chirpsize, "auc": auc}
    if store_roc:
        fpr, tpr, _ = roc_curve(group, score, pos_label=1)
        result["true_positives"] = tpr
        result["false_positives"] = fpr
    return result


def foreign_fish_detection_beat(block_map, df, cs, all_contrasts, all_conditions, kernel_width=0.0005,
                                cell_name="", store_roc=False):
    """Tries to detect the presence of a foreign fish by estimating the discriminability of the
    responses during the beat versus the responses without another fish being there, i.e. the
    baseline activity. Applies a ROC analysis to the response segments between chirps.
    Calculates a) the distances between the baseline responses and b) distances between the
    baseline and beat responses. Tests whether distances in b) are larger than a).

    Args:
        block_map (dict): maps nix blocks to combination of stimulus parameters
        df: the difference frequency that should be used
        cs: the chirpsize that should be used
        all_contrasts (list): list of all used contrasts
        all_conditions (list): list of all chirp conditions, i.e. self, other, or no-other (unused here)
        kernel_width (float, optional): std of Gaussian kernel. Defaults to 0.0005.
        cell_name (str, optional): name of the cell. Defaults to "".
        store_roc (bool, optional): if true the full false positives and true positives will be
            returned, leads to huge file sizes!. Defaults to False.

    Returns:
        list of dictionaries: the results, auc is the area under the curve, i.e. the
        discrimination performance in the range [0, 1]. The 'detection_task' is 'beat'
    """
    detection_performances = []

    for contrast in all_contrasts:
        print(" " * 50, end="\r")
        print("Contrast: %.3f" % contrast, end="\r")
        no_other_block = block_map[(contrast, df, cs, "no-other")]
        self_block = block_map[(contrast, df, cs, "self")]

        # get some metadata assuming they are all the same for each condition, which they should
        duration, dt, _, chirp_duration, chirp_times = get_chirp_metadata(self_block)

        # inter-chirp segments start 1.5 chirp durations after one chirp and end
        # 1.5 chirp durations before the next one
        interchirp_starts = np.add(chirp_times, 1.5 * chirp_duration)[:-1]
        interchirp_ends = np.subtract(chirp_times, 1.5 * chirp_duration)[1:]
        # mean segment length floored to three decimals, used as common snippet length
        ici = np.floor(np.mean(np.subtract(interchirp_ends, interchirp_starts)) * 1000) / 1000

        # get the spiking responses
        no_other_spikes = get_spikes(no_other_block)
        self_spikes = get_spikes(self_block)

        # get firing rates
        no_other_rates, _ = get_rates(no_other_spikes, duration, dt, kernel_width)
        self_rates, _ = get_rates(self_spikes, duration, dt, kernel_width)

        # get the response snippets between chirps
        no_other_snippets = np.zeros((len(interchirp_starts) * no_other_rates.shape[0], int(ici / dt)))  # section b, alone, no chirps
        self_snippets = np.zeros_like(no_other_snippets)  # section d, in company, no chirps, just beat
        for i in range(no_other_rates.shape[0]):
            for j, start in enumerate(interchirp_starts):
                start_index = int(start / dt)
                end_index = start_index + no_other_snippets.shape[1]
                index = i * len(interchirp_starts) + j
                no_other_snippets[index, :] = no_other_rates[i, start_index:end_index]
                self_snippets[index, :] = self_rates[i, start_index:end_index]

        # get the distances
        baseline_dist = within_group_distance(no_other_snippets)
        comp_dist = across_group_distance(no_other_snippets, self_snippets)

        # the strictly lower triangle holds each within-group pair once
        # (presumes within_group_distance returns a symmetric square matrix)
        triangle_indices = np.tril_indices_from(baseline_dist, -1)
        detection_performances.append(
            _roc_detection_result(baseline_dist[triangle_indices], comp_dist.ravel(), "beat",
                                  cell_name, contrast, df, kernel_width, cs, store_roc))
    print("\n")
    return detection_performances


def foreign_fish_detection_chirp(block_map, df, cs, all_contrasts, all_conditions, kernel_width=0.0005,
                                 cell_name="", store_roc=False):
    """Tries to detect the presence of a foreign fish by estimating the discriminability of the
    chirp responses in the presence of another fish versus the responses without another fish
    being around. Applies a ROC analysis to the response segments containing the chirp.

    Each comparison tests whether the distances across two groups of response snippets are
    larger than the distances within the reference group:
        1) "self vs soliloquy": self-chirping alone (reference) vs. self-chirping in company.
        2) "other vs quietness": beat without chirps (reference) vs. other-chirping in company.
        3) "soliliquy vs beat": self-chirping alone (reference) vs. beat without chirps.
        4) "beat vs self": beat without chirps (reference) vs. self-chirping in company.
        5) "self vs other": other-chirping in company (reference) vs. self-chirping in company.

    Args:
        block_map (dict): maps nix blocks to combination of stimulus parameters
        df: the difference frequency that should be used
        cs: the chirpsize that should be used
        all_contrasts (list): list of all used contrasts
        all_conditions (list): list of all chirp conditions, i.e. self, other, or no-other (unused here)
        kernel_width (float, optional): std of Gaussian kernel. Defaults to 0.0005.
        cell_name (str, optional): name of the cell. Defaults to "".
        store_roc (bool, optional): if true the full false positives and true positives will be
            returned, leads to huge file sizes!. Defaults to False.

    Returns:
        list of dictionaries: the results, auc is the area under the curve, i.e. the
        discrimination performance in the range [0, 1]. Five entries per contrast, with the
        'detection_task' values listed above, in that order.
    """
    detection_performances = []

    for contrast in all_contrasts:
        print(" " * 50, end="\r")
        print("Contrast: %.3f" % contrast, end="\r")
        no_other_block = block_map[(contrast, df, cs, "no-other")]
        self_block = block_map[(contrast, df, cs, "self")]
        # previously this looked up the "self" condition, so every "other" comparison
        # silently used self-chirping responses
        other_block = block_map[(contrast, df, cs, "other")]

        # get some metadata assuming they are all the same for each condition, which they should
        duration, dt, _, chirp_duration, chirp_times = get_chirp_metadata(self_block)

        # get the spiking responses
        no_other_spikes = get_spikes(no_other_block)
        self_spikes = get_spikes(self_block)
        other_spikes = get_spikes(other_block)

        # get firing rates
        no_other_rates, _ = get_rates(no_other_spikes, duration, dt, kernel_width)
        self_rates, _ = get_rates(self_spikes, duration, dt, kernel_width)
        other_rates, _ = get_rates(other_spikes, duration, dt, kernel_width)

        # get the chirp response snippets
        alone_chirping_snippets = np.zeros((len(chirp_times) * no_other_rates.shape[0], int(chirp_duration / dt)))  # section a, alone self-chirping
        self_snippets = np.zeros_like(alone_chirping_snippets)  # section c, self chirping in company
        other_snippets = np.zeros_like(alone_chirping_snippets)  # section e, other chirping in company
        silence_snippets = np.zeros_like(alone_chirping_snippets)  # section d, in company no one chirping
        for i in range(no_other_rates.shape[0]):
            for j, chirp_time in enumerate(chirp_times):
                # window centred on the chirp, shifted by a fixed 0.003
                # NOTE(review): origin of the 0.003 offset is not documented here
                start_index = int((chirp_time - chirp_duration / 2 + 0.003) / dt)
                end_index = start_index + alone_chirping_snippets.shape[1]
                index = i * len(chirp_times) + j
                alone_chirping_snippets[index, :] = no_other_rates[i, start_index:end_index]
                self_snippets[index, :] = self_rates[i, start_index:end_index]
                other_snippets[index, :] = other_rates[i, start_index:end_index]

                # chirp-free beat segment starting 1.5 chirp durations after the chirp
                # NOTE(review): for the last chirp this window must still fit into the trial
                silence_start_index = int((chirp_time + 1.5 * chirp_duration) / dt)
                silence_end_index = silence_start_index + alone_chirping_snippets.shape[1]
                silence_snippets[index, :] = other_rates[i, silence_start_index:silence_end_index]

        # within-group distances (reference populations)
        alone_chirping_dist = within_group_distance(alone_chirping_snippets)  # within section a
        silence_dist = within_group_distance(silence_snippets)  # within section d
        other_chirp_dist = within_group_distance(other_snippets)  # within section e
        # across-group distances
        self_vs_alone_dist = across_group_distance(alone_chirping_snippets, self_snippets)  # section a vs. section c
        other_vs_silence_dist = across_group_distance(silence_snippets, other_snippets)  # section d vs. section e
        self_other_chirp_dist = across_group_distance(self_snippets, other_snippets)  # section c vs. section e
        self_chirp_beat_dist = across_group_distance(self_snippets, silence_snippets)  # section c vs. section d
        alone_chirp_beat_dist = across_group_distance(alone_chirping_snippets, silence_snippets)  # section a vs. section d

        # the strictly lower triangle holds each within-group pair once; all within-group
        # matrices have the same shape, so the indices can be shared
        triangle_indices = np.tril_indices_from(alone_chirping_dist, -1)
        valid_alone_chirping_distances = alone_chirping_dist[triangle_indices]
        valid_silence_distances = silence_dist[triangle_indices]
        valid_other_chirp_distances = other_chirp_dist[triangle_indices]

        comparisons = [
            # soliloquy vs. self-chirping in company
            ("self vs soliloquy", valid_alone_chirping_distances, self_vs_alone_dist.ravel()),
            # other fish chirping vs. beat
            ("other vs quietness", valid_silence_distances, other_vs_silence_dist.ravel()),
            # soliloquy vs. beat
            ("soliliquy vs beat", valid_alone_chirping_distances, alone_chirp_beat_dist.ravel()),
            # beat vs. self-chirping in company (previously scored with the soliloquy-vs-beat distances)
            ("beat vs self", valid_silence_distances, self_chirp_beat_dist.ravel()),
            # self vs. other-chirping in company
            ("self vs other", valid_other_chirp_distances, self_other_chirp_dist.ravel()),
        ]
        for task, within_distances, across_distances in comparisons:
            detection_performances.append(
                _roc_detection_result(within_distances, across_distances, task, cell_name,
                                      contrast, df, kernel_width, cs, store_roc))
    print("\n")
    return detection_performances


def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, all_chirpsizes,
                           current_df=None, current_chirpsize=None, cell_name="", store_roc=False):
    """Run the beat and chirp detection analyses for all chirp sizes, dfs and kernel widths.

    Args:
        block_map (dict): maps nix blocks to combination of stimulus parameters
        all_dfs (list): all difference frequencies; used unless current_df is given.
        all_contrasts (list): all contrasts.
        all_conditions (list): all chirp conditions.
        all_chirpsizes (list): all chirp sizes; used unless current_chirpsize is given.
        current_df (optional): restrict the analysis to this single df.
        current_chirpsize (optional): restrict the analysis to this single chirp size.
        cell_name (str, optional): name of the cell. Defaults to "".
        store_roc (bool, optional): forwarded to the detection functions. Defaults to False.

    Returns:
        list of dictionaries: concatenated results of the beat and chirp detection analyses.
    """
    dfs = [current_df] if current_df is not None else all_dfs
    chirp_sizes = [current_chirpsize] if current_chirpsize is not None else all_chirpsizes
    kernels = [0.00025, 0.0005, 0.001, 0.0025]
    result_dicts = []
    for cs in chirp_sizes:
        for df in dfs:
            for kw in kernels:
                print("cs: %i Hz, df: %i Hz, kernel: %.4fs" % (cs, df, kw))
                print("Foreign fish detection during beat:")
                result_dicts.extend(foreign_fish_detection_beat(block_map, df, cs, all_contrasts, all_conditions,
                                                                kw, cell_name, store_roc))
                print("Foreign fish detection during chirp:")
                result_dicts.extend(foreign_fish_detection_chirp(block_map, df, cs, all_contrasts, all_conditions,
                                                                 kw, cell_name, store_roc))
    return result_dicts


def process_cell(filename):
    """Open one nix file, run the foreign fish detection analyses and return their results.

    The file is closed even if the analysis raises.
    """
    print(filename)
    nf = nix.File.open(filename, nix.FileMode.ReadOnly)
    try:
        block_map, all_contrasts, all_dfs, all_chirpsizes, all_conditions = sort_blocks(nf)
        if "baseline" not in block_map.keys():
            print("ERROR: no baseline data for file %s!" % filename)
        results = foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, all_chirpsizes,
                                         current_df=None, current_chirpsize=None,
                                         cell_name=filename.split(os.path.sep)[-1].split(".nix")[0],
                                         store_roc=False)
    finally:
        nf.close()
    return results


def main():
    """Process all cell files in parallel and store the results as a csv file."""
    # leave 6 cores free, but never ask joblib for 0 (error) or a negative count
    # (which joblib reinterprets as "all but k cores")
    num_cores = max(1, multiprocessing.cpu_count() - 6)
    nix_files = sorted(glob.glob(os.path.join(data_folder, "cell*.nix")))
    processed_list = Parallel(n_jobs=num_cores)(delayed(process_cell)(nix_file) for nix_file in nix_files)
    results = []
    for pr in processed_list:
        results.extend(pr)
    df = pd.DataFrame(results)
    df.to_csv(os.path.join(data_folder, "discrimination_results.csv"), sep=";")


if __name__ == "__main__":
    main()