import os
import glob

import nixio as nix
import numpy as np
import scipy.signal as sig
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score
from IPython import embed

from util import firing_rate, despine, extract_am, within_group_distance, across_group_distance

figure_folder = "figures"
data_folder = "data"


def read_baseline(block):
    """Read the baseline spike data from a baseline block.

    Args:
        block: the block to read from. Its name must contain "baseline"
            (case-insensitive, the same rule ``sort_blocks`` applies).

    Returns:
        The full content of the first data array of the block, or an empty
        list if the block is not a baseline block.
    """
    if "baseline" not in block.name.lower():
        print("Block %s does not appear to be a baseline block!" % block.name)
        return []
    return block.data_arrays[0][:]


def sort_blocks(nix_file):
    """Index the blocks of a file by (contrast, delta f, condition).

    Block names are split on "_"; the parts at index 1, 3 and 5 are read as
    contrast (float), condition (str) and delta f (float), respectively.
    Blocks whose lower-cased name contains "baseline" are stored under the
    key "baseline" instead.

    Args:
        nix_file: the opened file; only its ``blocks`` are read.

    Returns:
        dict: maps (contrast, deltaf, condition) tuples, and "baseline", to blocks.
        list of float: the distinct contrasts in order of first appearance.
        list of float: the distinct delta fs in order of first appearance.
        list of str: the distinct conditions in order of first appearance.
    """
    block_map = {}
    contrasts = []
    deltafs = []
    conditions = []
    for b in nix_file.blocks:
        if "baseline" in b.name.lower():
            block_map["baseline"] = b
            continue
        name_parts = b.name.split("_")
        cntrst = float(name_parts[1])
        cndtn = name_parts[3]
        dltf = float(name_parts[5])
        if cntrst not in contrasts:
            contrasts.append(cntrst)
        if cndtn not in conditions:
            conditions.append(cndtn)
        if dltf not in deltafs:
            deltafs.append(dltf)
        block_map[(cntrst, dltf, cndtn)] = b
    return block_map, contrasts, deltafs, conditions


def get_spikes(block):
    """Get the spike trains stored in a block.

    Only data arrays whose type contains "spike_times" and whose name contains
    "response" are considered. The trailing "_<number>" of the array name is
    used as the response id, and the trains are returned sorted by that id.

    Args:
        block: the block holding the response data arrays.

    Returns:
        list of np.ndarray: the spike trains, ordered by response id.
    """
    response_map = {}
    for da in block.data_arrays:
        if "spike_times" in da.type and "response" in da.name:
            response_map[int(da.name.split("_")[-1])] = da
    return [response_map[resp_id][:] for resp_id in sorted(response_map)]


def get_rates(spike_trains, duration, dt, kernel_width):
    """Convert the spike trains (list of spike_times) to rates using a Gaussian kernel of the given size.

    Args:
        spike_trains (list): the spike times, one entry per trial.
        duration (float): the trial duration; defines the end of the time axis.
        dt (float): the step size of the time axis.
        kernel_width (float): the kernel width handed on to ``util.firing_rate``.

    Returns:
        np.ndarray: matrix of firing rates, the first dimension is the number of trials.
        np.ndarray: the time vector, ``np.arange(0.0, duration, dt)``.
    """
    time = np.arange(0.0, duration, dt)
    rates = np.zeros((len(spike_trains), len(time)))
    for i, sp in enumerate(spike_trains):
        rates[i, :] = firing_rate(sp, duration, kernel_width, dt)
    return rates, time


def get_firing_rate(block_map, df, contrast, condition, kernel_width=0.0005):
    """Return the firing rates and the spikes for one stimulus combination.

    Args:
        block_map (dict): the block map as returned by ``sort_blocks``.
        df (float): the delta f selecting the block.
        contrast (float): the contrast selecting the block.
        condition (str): the condition selecting the block.
        kernel_width (float, optional): kernel width used for the rate estimation. Defaults to 0.0005.

    Returns:
        np.ndarray: the time vector.
        np.ndarray: the rates with the first dimension representing the trials.
        list of np.ndarray: the spike trains.
    """
    block = block_map[(contrast, df, condition)]
    spikes = get_spikes(block)
    stimulus_params = block.metadata["stimulus parameter"]
    duration = float(stimulus_params["duration"])
    dt = float(stimulus_params["dt"])
    rates, time = get_rates(spikes, duration, dt, kernel_width)
    return time, rates, spikes


def get_signals(block):
    """Read the fish signals from block.

    Args:
        block: the block containing the data for a given df, contrast and condition.

    Raises:
        ValueError: when the "complete stimulus" or the "self frequency" data is not found.
        ValueError: when the block is not a "no-other" block and the "other frequency" data is not found.

    Returns:
        np.ndarray: the complete signal.
        np.ndarray: the frequency profile of the recorded fish.
        np.ndarray: the frequency profile of the other fish, None for "no-other" blocks.
        np.ndarray: the time axis.
    """
    has_other = "no-other" not in block.name
    if "complete stimulus" not in block.data_arrays or "self frequency" not in block.data_arrays:
        raise ValueError("Signals not stored in block!")
    if has_other and "other frequency" not in block.data_arrays:
        raise ValueError("Signals not stored in block!")

    stimulus_array = block.data_arrays["complete stimulus"]
    signal = stimulus_array[:]
    time = np.asarray(stimulus_array.dimensions[0].axis(len(signal)))
    self_freq = block.data_arrays["self frequency"][:]
    other_freq = block.data_arrays["other frequency"][:] if has_other else None
    return signal, self_freq, other_freq, time


def create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, current_df, figure_name=None):
    """Plot stimulus and firing-rate responses for one delta f and save the figure as pdf.

    One column is drawn per condition ("no-other", "self", "other"). Each column
    shows the frequency traces, the signal with its amplitude modulation and,
    below, one firing-rate panel (mean +- std across trials) per contrast,
    ordered from the largest to the smallest contrast. Only the time window
    between 0.5 and 1.0 is shown.

    Args:
        block_map (dict): the block map as returned by ``sort_blocks``.
        all_dfs (list): all delta fs; not used by this function.
        all_contrasts (list): the contrasts to plot.
        all_conditions (list): all conditions; only the number of entries is used for the grid layout.
        current_df (float): the delta f for which the responses are plotted.
        figure_name (str, optional): the file name of the figure, ".pdf" is appended if missing.
            Defaults to None, i.e. "chirp_responses.pdf".
    """
    conditions = ["no-other", "self", "other"]
    condition_labels = ["soliloquy", "self chirping", "other chirping"]
    min_time = 0.5
    max_time = min_time + 0.5

    fig = plt.figure(figsize=(6.5, 5.5))
    fig_grid = (len(all_contrasts) * 2 + 6, len(all_conditions) * 3 + 2)
    all_contrasts = sorted(all_contrasts, reverse=True)
    major_ticks = np.arange(min_time, max_time + .01, 0.250)
    minor_ticks = np.arange(min_time, max_time + .01, 0.125)

    for i, condition in enumerate(conditions):
        column = i * 3 + i
        # the signals are the same for all contrasts, take them from the largest one
        block = block_map[(all_contrasts[0], current_df, condition)]
        signal, self_freq, other_freq, time = get_signals(block)
        am = extract_am(signal)
        eodfs = block.metadata["stimulus parameter"]["eodfs"]
        self_eodf = eodfs["self"]
        other_eodf = eodfs["other"]
        window = (time > min_time) & (time < max_time)

        # plot frequency traces
        ax = plt.subplot2grid(fig_grid, (0, column), rowspan=2, colspan=3, fig=fig)
        ax.plot(time[window], self_freq[window], color="#ff7f0e", label="%iHz" % self_eodf)
        ax.text(min_time - 0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e",
                va="center", ha="right", fontsize=9)
        if other_freq is not None:
            ax.plot(time[window], other_freq[window], color="#1f77b4", label="%iHz" % other_eodf)
            ax.text(min_time - 0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4",
                    va="center", ha="right", fontsize=9)
        ax.set_title(condition_labels[i])
        despine(ax, ["top", "bottom", "left", "right"], True)

        # plot the am
        ax = plt.subplot2grid(fig_grid, (3, column), rowspan=2, colspan=3, fig=fig)
        ax.plot(time[window], signal[window], color="#2ca02c", label="signal")
        ax.plot(time[window], am[window], color="#d62728", label="am")
        despine(ax, ["top", "bottom", "left", "right"], True)
        ax.set_ylim([-1.25, 1.25])
        ax.legend(ncol=2, loc=(0.01, -0.5), fontsize=7, markerscale=0.5, frameon=False)

        # for each contrast plot the firing rate
        for j, contrast in enumerate(all_contrasts):
            t, rates, _ = get_firing_rate(block_map, current_df, contrast, condition)
            avg_resp = np.mean(rates, axis=0)
            error = np.std(rates, axis=0)
            rate_window = (t > min_time) & (t < max_time)

            ax = plt.subplot2grid(fig_grid, (j * 2 + 6, column), rowspan=2, colspan=3, fig=fig)
            ax.plot(t[rate_window], avg_resp[rate_window], color="k", lw=0.5)
            ax.fill_between(t[rate_window], (avg_resp - error)[rate_window],
                            (avg_resp + error)[rate_window], color="k", lw=0.0, alpha=0.25)
            ax.set_ylim([0, 750])
            ax.set_xlabel("")
            ax.set_ylabel("")
            ax.set_xticks(major_ticks)
            # labels are in ms relative to the start of the shown window; pass a
            # real list, a lazy map object is not a sequence of labels
            ax.set_xticklabels(list(map(int, (major_ticks - min_time) * 1000)))
            ax.set_xticks(minor_ticks, minor=True)
            if j < len(all_contrasts) - 1:
                ax.set_xticklabels([])
            ax.set_yticks(np.arange(0.0, 751., 500))
            ax.set_yticks(np.arange(0.0, 751., 125), minor=True)
            if i > 0:
                ax.set_yticklabels([])
            despine(ax, ["top", "right"], False)
            if i == 2:
                ax.text(max_time + 0.025 * max_time, 350, "c=%.3f" % all_contrasts[j],
                        color="#d62728", ha="left", fontsize=7)

        # axis labels go to the bottom panel of the respective column
        if i == 1:
            ax.set_xlabel("time [ms]")
        if i == 0:
            ax.set_ylabel("frequency [Hz]", va="center")
            # move the label up so that it is centered on the stack of rate panels
            ax.yaxis.set_label_coords(-0.45, 3.5)

    name = figure_name if figure_name is not None else "chirp_responses.pdf"
    if not name.endswith(".pdf"):
        name = name + ".pdf"
    os.makedirs(figure_folder, exist_ok=True)
    fig.savefig(os.path.join(figure_folder, name))
    plt.close(fig)


def _interchirp_snippets(rates, starts, snippet_length, dt):
    """Cut a snippet of ``snippet_length`` samples at each start time out of every trial in ``rates``."""
    snippets = np.zeros((len(starts) * rates.shape[0], snippet_length))
    for i in range(rates.shape[0]):
        for j, start in enumerate(starts):
            start_index = int(start / dt)
            snippets[i * len(starts) + j, :] = rates[i, start_index:start_index + snippet_length]
    return snippets


def foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005):
    """Estimate how well "self" responses can be told apart from "no-other" responses between chirps.

    For each contrast the firing rates of the "no-other" and the "self"
    condition are cut into snippets covering the stretches between chirps.
    The distances within the "no-other" snippets (label 0) and the distances
    between "no-other" and "self" snippets (label 1) are then fed into a ROC
    analysis.

    Args:
        block_map (dict): the block map as returned by ``sort_blocks``.
        df (float): the delta f to analyze.
        all_contrasts (list): the contrasts to analyze.
        all_conditions (list): all conditions; not used by this function.
        kernel_width (float, optional): kernel width used for the rate estimation. Defaults to 0.0005.

    Returns:
        dict: maps (contrast, kernel_width) to a dict with the keys "auc",
            "true positives" and "false positives".
    """
    detection_performance = {}
    for contrast in all_contrasts:
        # progress output, overwrites the same console line
        print(" " * 50, end="\r")
        print("Contrast: %.3f" % contrast, end="\r")
        no_other_block = block_map[(contrast, df, "no-other")]
        self_block = block_map[(contrast, df, "self")]

        # get some metadata assuming they are the same for each condition, which they should
        stimulus_params = self_block.metadata["stimulus parameter"]
        duration = float(stimulus_params["duration"])
        dt = float(stimulus_params["dt"])
        chirp_duration = stimulus_params["chirp_duration"]
        chirp_times = stimulus_params["chirp_times"]

        # the stretch between chirps starts/ends 1.5 chirp durations after/before a chirp
        interchirp_starts = np.add(chirp_times, 1.5 * chirp_duration)[:-1]
        interchirp_ends = np.subtract(chirp_times, 1.5 * chirp_duration)[1:]
        # mean interchirp interval, floored to three decimals
        ici = np.floor(np.mean(np.subtract(interchirp_ends, interchirp_starts)) * 1000) / 1000
        snippet_length = int(ici / dt)

        # get the spiking responses and the firing rates
        no_other_rates, _ = get_rates(get_spikes(no_other_block), duration, dt, kernel_width)
        self_rates, _ = get_rates(get_spikes(self_block), duration, dt, kernel_width)

        # get the response snippets between chirps; each condition is handled
        # with its own trial count so that differing counts cannot cause an
        # IndexError or leave rows unfilled
        no_other_snippets = _interchirp_snippets(no_other_rates, interchirp_starts, snippet_length, dt)
        self_snippets = _interchirp_snippets(self_rates, interchirp_starts, snippet_length, dt)

        # get the distances
        baseline_dist = within_group_distance(no_other_snippets)
        comp_dist = across_group_distance(no_other_snippets, self_snippets)

        # sort and perform roc; only the lower triangle (without the diagonal)
        # of the within-group distances is used
        triangle_indices = np.tril_indices_from(baseline_dist, -1)
        valid_distances_baseline = baseline_dist[triangle_indices]
        valid_distances_comparison = comp_dist.ravel()
        group = np.hstack((np.zeros_like(valid_distances_baseline),
                           np.ones_like(valid_distances_comparison)))
        score = np.hstack((valid_distances_baseline, valid_distances_comparison))

        fpr, tpr, _ = roc_curve(group, score, pos_label=1)
        auc = roc_auc_score(group, score)
        detection_performance[(contrast, kernel_width)] = {"auc": auc,
                                                           "true positives": tpr,
                                                           "false positives": fpr}
    print("\n")
    return detection_performance


def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005):
    """Placeholder for the chirp-based detection analysis; not implemented yet.

    The arguments mirror ``foreign_fish_detection_beat`` and are currently ignored.

    Returns:
        None
    """
    return None


def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, kernel_width=0.0005):
    """Run the beat- and the chirp-based detection analysis.

    Args:
        block_map (dict): the block map as returned by ``sort_blocks``.
        all_dfs (list): all delta fs, analyzed if ``current_df`` is None.
        all_contrasts (list): the contrasts to analyze.
        all_conditions (list): all conditions, handed on to the analyses.
        current_df (float, optional): restrict the analysis to this delta f. Defaults to None.
        kernel_width (float, optional): kernel width used for the rate estimation. Defaults to 0.0005.

    Returns:
        dict: the results of ``foreign_fish_detection_beat`` for each delta f.
        dict: the results of ``foreign_fish_detection_chirp`` for each delta f.
    """
    dfs = [current_df] if current_df is not None else all_dfs
    detection_performance_beat = {}
    detection_performance_chirp = {}
    for df in dfs:
        detection_performance_beat[df] = foreign_fish_detection_beat(
            block_map, df, all_contrasts, all_conditions, kernel_width)
        detection_performance_chirp[df] = foreign_fish_detection_chirp(
            block_map, df, all_contrasts, all_conditions, kernel_width)
    return detection_performance_beat, detection_performance_chirp


def process_cell(filename, dfs=None, contrasts=None, conditions=None):
    """Open the data file of a single cell and run the detection analysis for a delta f of 20.

    Args:
        filename (str): path of the nix file to process.
        dfs (list, optional): currently not used. Defaults to None.
        contrasts (list, optional): currently not used. Defaults to None.
        conditions (list, optional): currently not used. Defaults to None.
    """
    nf = nix.File.open(filename, nix.FileMode.ReadOnly)
    try:
        block_map, all_contrasts, all_dfs, all_conditions = sort_blocks(nf)
        if "baseline" in block_map:
            # read but not used by the analysis below so far
            baseline_spikes = read_baseline(block_map["baseline"])
        else:
            print("ERROR: no baseline data for file %s!" % filename)
        # fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf"
        # create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name)
        # fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf"
        # create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name)
        foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=20)
    finally:
        # close the file even if the analysis fails
        nf.close()


def main():
    """Process all "cell*.nix" files found in the data folder."""
    nix_files = sorted(glob.glob(os.path.join(data_folder, "cell*.nix")))
    for nix_file in nix_files:
        process_cell(nix_file, dfs=[20], contrasts=[20], conditions=["self"])


if __name__ == "__main__":
    main()