import os
import glob

import pandas as pd
import nixio as nix
import numpy as np
import scipy.signal as sig
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score
from IPython import embed

from util import firing_rate, despine, extract_am, within_group_distance, across_group_distance

figure_folder = "figures"
data_folder = "data"


def read_baseline(block):
    """Read the spike times stored in a baseline block.

    Args:
        block (nix.Block): the block expected to hold the baseline recording.

    Returns:
        The content of the block's first data array, or an empty list if the
        block name does not contain "baseline".
    """
    # Compare case-insensitively so this check agrees with sort_blocks, which
    # classifies a block as baseline via ``b.name.lower()``.
    if "baseline" not in block.name.lower():
        print("Block %s does not appear to be a baseline block!" % block.name)
        return []
    return block.data_arrays[0][:]


def sort_blocks(nix_file):
    """Index the blocks of a nix file by their stimulus parameters.

    Non-baseline block names are split at "_"; field 1 is parsed as the
    contrast, field 3 as the condition and field 5 as the delta-f.

    Args:
        nix_file (nix.File): the opened nix file.

    Returns:
        dict: maps ``(contrast, df, condition)`` tuples to blocks; a block whose
            name contains "baseline" (case-insensitive) is stored under the key
            ``"baseline"``.
        list of float: the contrasts in order of first appearance.
        list of float: the delta-fs in order of first appearance.
        list of str: the conditions in order of first appearance.
    """
    block_map = {}
    contrasts = []
    deltafs = []
    conditions = []
    for b in nix_file.blocks:
        if "baseline" in b.name.lower():
            block_map["baseline"] = b
            continue
        name_parts = b.name.split("_")
        cntrst = float(name_parts[1])
        if cntrst not in contrasts:
            contrasts.append(cntrst)
        cndtn = name_parts[3]
        if cndtn not in conditions:
            conditions.append(cndtn)
        dltf = float(name_parts[5])
        if dltf not in deltafs:
            deltafs.append(dltf)
        block_map[(cntrst, dltf, cndtn)] = b
    return block_map, contrasts, deltafs, conditions


def get_spikes(block):
    """Get the spike trains of all response trials stored in a block.

    Data arrays whose type contains "spike_times" and whose name contains
    "response" are collected; the integer after the last "_" in the name is
    used as the trial index and determines the order of the returned trains.

    Args:
        block (nix.Block): the block to read from.

    Returns:
        list of np.ndarray: the spike trains, sorted by trial index.
    """
    response_map = {}
    for da in block.data_arrays:
        if "spike_times" in da.type and "response" in da.name:
            resp_id = int(da.name.split("_")[-1])
            response_map[resp_id] = da
    return [response_map[k][:] for k in sorted(response_map)]


def get_rates(spike_trains, duration, dt, kernel_width):
    """Convert spike trains (lists of spike times) to firing rates.

    The conversion is delegated to ``util.firing_rate`` using a Gaussian kernel
    of the given width.

    Args:
        spike_trains (list of np.ndarray): spike times of each trial.
        duration (float): trial duration, same unit as the spike times.
        dt (float): sampling interval of the rate traces.
        kernel_width (float): width of the smoothing kernel.

    Returns:
        np.ndarray: matrix of firing rates, the first dimension is the trial.
        np.ndarray: the time vector ``np.arange(0, duration, dt)``.
    """
    time = np.arange(0.0, duration, dt)
    rates = np.zeros((len(spike_trains), len(time)))
    for i, sp in enumerate(spike_trains):
        rates[i, :] = firing_rate(sp, duration, kernel_width, dt)
    return rates, time


def get_firing_rate(block_map, df, contrast, condition, kernel_width=0.0005):
    """Return the firing rates and spike trains for one stimulus condition.

    Duration and dt are read from the block's "stimulus parameter" metadata.

    Args:
        block_map (dict): mapping as returned by ``sort_blocks``.
        df (float): delta-f of the stimulus.
        contrast (float): stimulus contrast.
        condition (str): the condition name, e.g. "no-other", "self", "other".
        kernel_width (float, optional): width of the rate kernel. Defaults to 0.0005.

    Returns:
        np.ndarray: the time vector.
        np.ndarray: the rates with the first dimension representing the trials.
        list of np.ndarray: the spike trains.
    """
    block = block_map[(contrast, df, condition)]
    spikes = get_spikes(block)
    duration = float(block.metadata["stimulus parameter"]["duration"])
    dt = float(block.metadata["stimulus parameter"]["dt"])
    rates, time = get_rates(spikes, duration, dt, kernel_width)
    return time, rates, spikes


def get_signals(block):
    """Read the fish signals from a block.

    Args:
        block (nix.Block): the block containing the data for a given df,
            contrast and condition.

    Raises:
        ValueError: when the complete stimulus or the self-frequency data is missing.
        ValueError: when a block that is not a "no-other" block lacks the
            other-frequency data.

    Returns:
        np.ndarray: the complete signal.
        np.ndarray: the frequency profile of the recorded fish.
        np.ndarray or None: the frequency profile of the other fish (None for
            "no-other" blocks).
        np.ndarray: the time axis of the complete signal.
    """
    if "complete stimulus" not in block.data_arrays or "self frequency" not in block.data_arrays:
        raise ValueError("Signals not stored in block!")
    if "no-other" not in block.name and "other frequency" not in block.data_arrays:
        raise ValueError("Signals not stored in block!")

    stimulus = block.data_arrays["complete stimulus"]
    signal = stimulus[:]
    time = np.asarray(stimulus.dimensions[0].axis(len(signal)))
    self_freq = block.data_arrays["self frequency"][:]
    other_freq = None
    if "no-other" not in block.name:
        other_freq = block.data_arrays["other frequency"][:]
    return signal, self_freq, other_freq, time


def create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, current_df, figure_name=None):
    """Plot stimulus traces and trial-averaged rates for one delta-f.

    One column per condition ("no-other", "self", "other"): the frequency
    traces, the signal with its amplitude modulation, and below them the mean
    firing rate (+/- std across trials) for each contrast in descending order.
    Only the window 0.5 s to 1.0 s is shown.

    Args:
        block_map (dict): mapping as returned by ``sort_blocks``.
        all_dfs (list): all delta-fs (currently unused).
        all_contrasts (list of float): contrasts to plot, one row each.
        all_conditions (list of str): conditions; only used to size the grid.
        current_df (float): the delta-f to plot.
        figure_name (str, optional): output file name inside ``figure_folder``;
            ".pdf" is appended if missing. Defaults to "chirp_responses.pdf".
    """
    conditions = ["no-other", "self", "other"]
    condition_labels = ["soliloquy", "self chirping", "other chirping"]
    min_time = 0.5
    max_time = min_time + 0.5

    fig = plt.figure(figsize=(6.5, 5.5))
    fig_grid = (len(all_contrasts) * 2 + 6, len(all_conditions) * 3 + 2)
    all_contrasts = sorted(all_contrasts, reverse=True)

    for i, condition in enumerate(conditions):
        # plot the signals
        block = block_map[(all_contrasts[0], current_df, condition)]
        signal, self_freq, other_freq, time = get_signals(block)
        am = extract_am(signal)
        in_window = (time > min_time) & (time < max_time)

        self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"]
        other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"]

        # plot frequency traces
        ax = plt.subplot2grid(fig_grid, (0, i * 3 + i), rowspan=2, colspan=3, fig=fig)
        ax.plot(time[in_window], self_freq[in_window], color="#ff7f0e", label="%iHz" % self_eodf)
        ax.text(min_time - 0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e", va="center", ha="right", fontsize=9)
        if other_freq is not None:
            ax.plot(time[in_window], other_freq[in_window], color="#1f77b4", label="%iHz" % other_eodf)
            ax.text(min_time - 0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4", va="center", ha="right", fontsize=9)
        ax.set_title(condition_labels[i])
        despine(ax, ["top", "bottom", "left", "right"], True)

        # plot the am
        ax = plt.subplot2grid(fig_grid, (3, i * 3 + i), rowspan=2, colspan=3, fig=fig)
        ax.plot(time[in_window], signal[in_window], color="#2ca02c", label="signal")
        ax.plot(time[in_window], am[in_window], color="#d62728", label="am")
        despine(ax, ["top", "bottom", "left", "right"], True)
        ax.set_ylim([-1.25, 1.25])
        ax.legend(ncol=2, loc=(0.01, -0.5), fontsize=7, markerscale=0.5, frameon=False)

        # for each contrast plot the firing rate
        for j, contrast in enumerate(all_contrasts):
            t, rates, _ = get_firing_rate(block_map, current_df, contrast, condition)
            avg_resp = np.mean(rates, axis=0)
            error = np.std(rates, axis=0)
            t_window = (t > min_time) & (t < max_time)

            ax = plt.subplot2grid(fig_grid, (j * 2 + 6, i * 3 + i), rowspan=2, colspan=3, fig=fig)
            ax.plot(t[t_window], avg_resp[t_window], color="k", lw=0.5)
            ax.fill_between(t[t_window], (avg_resp - error)[t_window], (avg_resp + error)[t_window],
                            color="k", lw=0.0, alpha=0.25)
            ax.set_ylim([0, 750])
            ax.set_xlabel("")
            ax.set_ylabel("")
            major_ticks = np.arange(min_time, max_time + .01, 0.250)
            ax.set_xticks(major_ticks)
            # tick labels in ms relative to the start of the shown window
            ax.set_xticklabels([int(v) for v in (major_ticks - min_time) * 1000])
            ax.set_xticks(np.arange(min_time, max_time + .01, 0.125), minor=True)
            if j < len(all_contrasts) - 1:
                ax.set_xticklabels([])
            ax.set_yticks(np.arange(0.0, 751., 500))
            ax.set_yticks(np.arange(0.0, 751., 125), minor=True)
            if i > 0:
                ax.set_yticklabels([])
            despine(ax, ["top", "right"], False)
            if i == 2:
                ax.text(max_time + 0.025 * max_time, 350, "c=%.3f" % all_contrasts[j], color="#d62728",
                        ha="left", fontsize=7)
        # axis labels go on the bottom-most rate panel of the column; the
        # y-label is shifted upwards to sit alongside the stack of rate panels
        if i == 1:
            ax.set_xlabel("time [ms]")
        if i == 0:
            ax.set_ylabel("frequency [Hz]", va="center")
            ax.yaxis.set_label_coords(-0.45, 3.5)

    name = figure_name if figure_name is not None else "chirp_responses.pdf"
    name = name if name.endswith(".pdf") else name + ".pdf"
    # savefig does not create missing directories
    os.makedirs(figure_folder, exist_ok=True)
    plt.savefig(os.path.join(figure_folder, name))
    plt.close()


def get_chirp_metadata(block):
    """Read the chirp-related stimulus parameters from a block's metadata.

    Returns:
        float: trial duration.
        float: dt (sampling interval).
        chirp size, chirp duration, chirp times: as stored in the
            "stimulus parameter" section (not converted).
        Note the order: (trial_duration, dt, chirp_size, chirp_duration, chirp_times).
    """
    trial_duration = float(block.metadata["stimulus parameter"]["duration"])
    dt = float(block.metadata["stimulus parameter"]["dt"])
    chirp_duration = block.metadata["stimulus parameter"]["chirp_duration"]
    chirp_size = block.metadata["stimulus parameter"]["chirp_size"]
    chirp_times = block.metadata["stimulus parameter"]["chirp_times"]
    return trial_duration, dt, chirp_size, chirp_duration, chirp_times


def _cut_snippets(rates, start_indices, length):
    """Cut windows of ``length`` samples out of every trial of a rate matrix.

    Rows are trial-major: row ``i * len(start_indices) + j`` holds the window
    of trial ``i`` that begins at sample ``start_indices[j]``.
    """
    n_windows = len(start_indices)
    snippets = np.zeros((n_windows * rates.shape[0], length))
    for i in range(rates.shape[0]):
        for j, start in enumerate(start_indices):
            snippets[i * n_windows + j, :] = rates[i, start:start + length]
    return snippets


def _discrimination_roc(within_dist, across_dist):
    """ROC analysis separating within-group from across-group distances.

    The strictly lower triangle of ``within_dist`` (each pair once, no
    self-distances; assumes a square pairwise-distance matrix as returned by
    ``util.within_group_distance`` -- confirm) is labelled 0, every entry of
    ``across_dist`` is labelled 1, and the distances themselves are the scores.

    Returns:
        np.ndarray: false positive rates.
        np.ndarray: true positive rates.
        float: area under the ROC curve.
    """
    within = within_dist[np.tril_indices_from(within_dist, -1)]
    across = np.asarray(across_dist).ravel()
    group = np.hstack((np.zeros_like(within), np.ones_like(across)))
    score = np.hstack((within, across))
    fpr, tpr, _ = roc_curve(group, score, pos_label=1)
    auc = roc_auc_score(group, score)
    return fpr, tpr, auc


def _detection_result(cell_name, task, contrast, df, kernel_width, fpr, tpr, auc):
    """Build one detection-performance record."""
    return {"cell": cell_name, "detection_task": task, "contrast": contrast, "df": df,
            "kernel_width": kernel_width, "auc": auc, "true_positives": tpr,
            "false_positives": fpr}


def foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005, cell_name=""):
    """Discriminate "no-other" from "self" responses in the gaps between chirps.

    For every contrast, rate windows are cut from each trial starting 1.5 chirp
    durations after each chirp (except the last). The window length is the mean
    gap to 1.5 chirp durations before the next chirp, floored to 1 ms. Distances
    within the "no-other" windows are compared against distances between
    "no-other" and "self" windows with an ROC analysis.

    Args:
        block_map (dict): mapping as returned by ``sort_blocks``.
        df (float): delta-f to analyse.
        all_contrasts (list of float): contrasts to analyse.
        all_conditions (list): unused, kept for interface compatibility.
        kernel_width (float, optional): rate kernel width. Defaults to 0.0005.
        cell_name (str, optional): stored in each result record.

    Raises:
        ValueError: if fewer than two chirps are present (no inter-chirp interval).

    Returns:
        list of dict: one record per contrast with task "beat".
    """
    detection_performances = []
    for contrast in all_contrasts:
        print(" " * 50, end="\r")
        print("Contrast: %.3f" % contrast, end="\r")
        no_other_block = block_map[(contrast, df, "no-other")]
        self_block = block_map[(contrast, df, "self")]

        # get some metadata assuming they are all the same for each condition, which they should
        duration, dt, _, chirp_duration, chirp_times = get_chirp_metadata(self_block)
        if len(chirp_times) < 2:
            raise ValueError("At least two chirps are needed to define an inter-chirp interval!")

        interchirp_starts = np.add(chirp_times, 1.5 * chirp_duration)[:-1]
        interchirp_ends = np.subtract(chirp_times, 1.5 * chirp_duration)[1:]
        ici = np.floor(np.mean(np.subtract(interchirp_ends, interchirp_starts)) * 1000) / 1000
        snippet_length = int(ici / dt)
        start_indices = [int(start / dt) for start in interchirp_starts]

        # get the firing rates
        no_other_rates, _ = get_rates(get_spikes(no_other_block), duration, dt, kernel_width)
        self_rates, _ = get_rates(get_spikes(self_block), duration, dt, kernel_width)

        # get the response snippets between chirps
        no_other_snippets = _cut_snippets(no_other_rates, start_indices, snippet_length)
        self_snippets = _cut_snippets(self_rates, start_indices, snippet_length)

        # get the distances and perform the roc analysis
        baseline_dist = within_group_distance(no_other_snippets)
        comp_dist = across_group_distance(no_other_snippets, self_snippets)
        fpr, tpr, auc = _discrimination_roc(baseline_dist, comp_dist)
        detection_performances.append(_detection_result(cell_name, "beat", contrast, df, kernel_width,
                                                        fpr, tpr, auc))
    print("\n")
    return detection_performances


def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005, cell_name=""):
    """Discriminate chirp responses with and without the other fish.

    Two comparisons per contrast:
    1. "self vs soliloquy": chirp windows of the "no-other" condition among
       themselves vs. "no-other" against "self" chirp windows.
    2. "other vs quietness": windows 1.5 chirp durations after each chirp in
       "no-other" (nobody chirping) among themselves vs. those against the
       chirp windows of the "other" condition.

    Args:
        block_map (dict): mapping as returned by ``sort_blocks``.
        df (float): delta-f to analyse.
        all_contrasts (list of float): contrasts to analyse.
        all_conditions (list): unused, kept for interface compatibility.
        kernel_width (float, optional): rate kernel width. Defaults to 0.0005.
        cell_name (str, optional): stored in each result record.

    Returns:
        list of dict: two records per contrast.
    """
    detection_performances = []
    for contrast in all_contrasts:
        print(" " * 50, end="\r")
        print("Contrast: %.3f" % contrast, end="\r")
        no_other_block = block_map[(contrast, df, "no-other")]
        self_block = block_map[(contrast, df, "self")]
        # previously read the "self" block here, so the "other" task never saw the other-chirping data
        other_block = block_map[(contrast, df, "other")]

        # get some metadata assuming they are all the same for each condition, which they should
        duration, dt, _, chirp_duration, chirp_times = get_chirp_metadata(self_block)

        # get firing rates
        no_other_rates, _ = get_rates(get_spikes(no_other_block), duration, dt, kernel_width)
        self_rates, _ = get_rates(get_spikes(self_block), duration, dt, kernel_width)
        other_rates, _ = get_rates(get_spikes(other_block), duration, dt, kernel_width)

        # get the chirp response snippets; the +0.003 shifts each window 3 ms
        # later (presumably to account for response latency -- confirm)
        snippet_length = int(chirp_duration / dt)
        chirp_starts = [int((chirp_time - chirp_duration / 2 + 0.003) / dt) for chirp_time in chirp_times]
        baseline_starts = [int((chirp_time + 1.5 * chirp_duration) / dt) for chirp_time in chirp_times]

        alone_chirping_snippets = _cut_snippets(no_other_rates, chirp_starts, snippet_length)
        self_snippets = _cut_snippets(self_rates, chirp_starts, snippet_length)
        other_snippets = _cut_snippets(other_rates, chirp_starts, snippet_length)
        baseline_snippets = _cut_snippets(no_other_rates, baseline_starts, snippet_length)

        # get the distances
        # 1. Soliloquy
        # 2. Nobody chirps, all alone aka baseline response
        # 3. I chirp while the other is present compared to self chirping without the other one present
        # 4. the other one chirps to me compared to baseline without anyone chirping
        alone_chirping_dist = within_group_distance(alone_chirping_snippets)
        baseline_dist = within_group_distance(baseline_snippets)
        self_vs_alone_dist = across_group_distance(alone_chirping_snippets, self_snippets)
        other_vs_baseline_dist = across_group_distance(baseline_snippets, other_snippets)

        # roc for the two comparisons
        fpr, tpr, auc = _discrimination_roc(alone_chirping_dist, self_vs_alone_dist)
        detection_performances.append(_detection_result(cell_name, "self vs soliloquy", contrast, df,
                                                        kernel_width, fpr, tpr, auc))

        fpr, tpr, auc = _discrimination_roc(baseline_dist, other_vs_baseline_dist)
        detection_performances.append(_detection_result(cell_name, "other vs quietness", contrast, df,
                                                        kernel_width, fpr, tpr, auc))
    print("\n")
    return detection_performances


def plot_detection_results(data_frame, df, kernel_width, cell):
    """Plot ROC curves and AUC-vs-contrast curves for one cell and delta-f.

    One row per detection task: the left panel shows the ROC curves of every
    contrast for ``kernel_width``, the right panel the AUC as a function of
    contrast for every kernel width. Saved as "discrimination.pdf".

    Args:
        data_frame (pd.DataFrame): detection results with columns "cell", "df",
            "detection_task", "kernel_width", "contrast", "auc",
            "true_positives", "false_positives".
        df (float): delta-f to plot.
        kernel_width (float): kernel width whose ROC curves are shown.
        cell (str): cell name to plot.
    """
    cell_results = data_frame[(data_frame.cell == cell) & (data_frame.df == df)]
    conditions = sorted(cell_results.detection_task.unique())
    kernels = sorted(cell_results.kernel_width.unique())

    fig = plt.figure(figsize=(6.5, 5.5))
    fig_grid = (8, 7)
    for i, condition in enumerate(conditions):
        condition_results = cell_results[cell_results.detection_task == condition]
        roc_ax = plt.subplot2grid(fig_grid, (i * 2 + i, 0), colspan=3, rowspan=2)
        auc_ax = plt.subplot2grid(fig_grid, (i * 2 + i, 4), colspan=3, rowspan=2)

        roc_data = condition_results[condition_results.kernel_width == kernel_width]
        for contrast in roc_data.contrast.unique():
            tpr = roc_data.true_positives[roc_data.contrast == contrast].values[0]
            fpr = roc_data.false_positives[roc_data.contrast == contrast].values[0]
            roc_ax.plot(fpr, tpr, label="%.3f" % contrast, zorder=2)
        roc_ax.legend(loc="best", fontsize=6, ncol=2, frameon=False)
        roc_ax.plot([0., 1.], [0., 1.], color="k", lw=0.5, ls="--", zorder=0)
        roc_ax.set_xlabel("false positive rate", fontsize=9)
        roc_ax.set_ylabel("true positive rate", fontsize=9)
        roc_ax.set_xticks(np.arange(0.0, 1.01, 0.5))
        roc_ax.set_xticks(np.arange(0.0, 1.01, 0.25), minor=True)
        roc_ax.set_xticklabels(np.arange(0.0, 1.01, 0.5), fontsize=8)
        roc_ax.set_yticks(np.arange(0.0, 1.01, 0.5))
        roc_ax.set_yticks(np.arange(0.0, 1.01, 0.25), minor=True)
        roc_ax.set_yticklabels(np.arange(0.0, 1.01, 0.5), fontsize=8)

        for k in kernels:
            contrasts = np.asarray(condition_results.contrast[condition_results.kernel_width == k])
            aucs = np.asarray(condition_results.auc[condition_results.kernel_width == k])
            order = np.argsort(contrasts)
            auc_ax.plot(contrasts[order], aucs[order], marker=".", label=r"$\sigma$: %.4f" % k)
        auc_ax.set_xlabel("contrast [%]")
        auc_ax.set_ylim([0.25, 1.0])
        auc_ax.set_ylabel("discriminability")
        auc_ax.legend(ncol=2, fontsize=6)
        # chance level
        all_contrasts = condition_results.contrast
        auc_ax.plot([min(all_contrasts), max(all_contrasts)], [0.5, 0.5], color="k", lw=0.5, ls="--")
    fig.savefig("discrimination.pdf")
    plt.close(fig)


def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, cell_name="",
                           kernels=None):
    """Run beat and chirp detection analyses for every delta-f and kernel width.

    Args:
        block_map (dict): mapping as returned by ``sort_blocks``.
        all_dfs (list of float): delta-fs to analyse when ``current_df`` is None.
        all_contrasts (list of float): contrasts to analyse.
        all_conditions (list): passed through to the analysis functions.
        current_df (float, optional): analyse only this delta-f.
        cell_name (str, optional): stored in each result record.
        kernels (list of float, optional): rate kernel widths. Defaults to
            [0.00025, 0.0005, 0.001, 0.0025, 0.005].

    Returns:
        list of dict: all detection-performance records.
    """
    dfs = [current_df] if current_df is not None else all_dfs
    if kernels is None:
        kernels = [0.00025, 0.0005, 0.001, 0.0025, 0.005]
    result_dicts = []
    for df in dfs:
        for kw in kernels:
            print("df: %i, kernel: %.4f" % (df, kw))
            print("Foreign fish detection during beat:")
            result_dicts.extend(foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kw,
                                                            cell_name))
            print("Foreign fish detection during chirp:")
            result_dicts.extend(foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kw,
                                                             cell_name))
    return result_dicts


def process_cell(filename, dfs=None, contrasts=None, conditions=None):
    """Run the foreign-fish detection analysis (df = 20 Hz) on one nix file.

    Args:
        filename (str): path to the cell's nix file.
        dfs, contrasts, conditions: currently unused; kept for interface compatibility.

    Returns:
        list of dict: the detection-performance records.
    """
    nf = nix.File.open(filename, nix.FileMode.ReadOnly)
    try:
        block_map, all_contrasts, all_dfs, all_conditions = sort_blocks(nf)
        if "baseline" in block_map:
            # NOTE(review): not yet used by the analysis below
            baseline_spikes = read_baseline(block_map["baseline"])
        else:
            print("ERROR: no baseline data for file %s!" % filename)
        cell_name = os.path.basename(filename).split(".nix")[0]
        # fig_name = cell_name + "_df_20Hz.pdf"
        # create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name)
        # fig_name = cell_name + "_df_-100Hz.pdf"
        # create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name)
        results = foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=20,
                                         cell_name=cell_name)
    finally:
        # close the file even when the analysis fails
        nf.close()
    return results


def main():
    nix_files = sorted(glob.glob(os.path.join(data_folder, "cell*.nix")))
    for nix_file in nix_files:
        process_cell(nix_file, dfs=[20], contrasts=[20], conditions=["self"])


if __name__ == "__main__":
    main()