From 8ef2e672c5902bf152a48c1f7cfc85e80f50e87b Mon Sep 17 00:00:00 2001 From: Jan Grewe Date: Wed, 23 Sep 2020 13:31:08 +0200 Subject: [PATCH] beat discrimination working --- response_discriminability.py | 82 ++++++++++++++++++++---------------- util.py | 35 ++++++++++++--- 2 files changed, 75 insertions(+), 42 deletions(-) diff --git a/response_discriminability.py b/response_discriminability.py index fdcd156..1fdd6b9 100644 --- a/response_discriminability.py +++ b/response_discriminability.py @@ -4,9 +4,10 @@ import nixio as nix import numpy as np import scipy.signal as sig import matplotlib.pyplot as plt +from sklearn.metrics import roc_curve, roc_auc_score from IPython import embed -from util import firing_rate, despine +from util import firing_rate, despine, extract_am, within_group_distance, across_group_distance figure_folder = "figures" data_folder = "data" @@ -143,25 +144,6 @@ def get_signals(block): return signal, self_freq, other_freq, time -def extract_am(signal): - """Extract the amplitude modulation from a signal using the Hilbert transform. Performs padding to avoid artefacts at beginning and end. - - Args: - signal (np.ndarray): the signal - - Returns: - np.ndarray: the am, i.e. the absolute value of the Hilbert transform. - """ - # first add some padding to both ends - front_pad = np.flip(signal[:int(len(signal)/100)]) - back_pad = np.flip(signal[-int(len(signal)/100):]) - padded = np.hstack((front_pad, signal, back_pad)) - # do the hilbert and take abs, cut away the padding - am = np.abs(sig.hilbert(padded)) - am = am[len(front_pad):-len(back_pad)] - return am - - def create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, current_df, figure_name=None): conditions = ["no-other", "self", "other"] condition_labels = ["soliloquy", "self chirping", "other chirping"] @@ -243,34 +225,60 @@ def create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, curr def foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005): detection_performance = {} + for contrast in all_contrasts: + print(" " * 50, end="\r") + print("Contrast: %.3f" % contrast, end="\r") no_other_block = block_map[(contrast, df, "no-other")] self_block = block_map[(contrast, df, "self")] - # get some metadata assuming they are all the same for each condition + # get some metadata assuming they are all the same for each conditionm, which they should duration = float(self_block.metadata["stimulus parameter"]["duration"]) dt = float(self_block.metadata["stimulus parameter"]["dt"]) chirp_duration = self_block.metadata["stimulus parameter"]["chirp_duration"] chirp_times = self_block.metadata["stimulus parameter"]["chirp_times"] - interchirp_starts = [] - interchirp_ends = [] - for ct in chirp_times: - interchirp_starts.append(ct + 1.5 * chirp_duration) - interchirp_ends.append(ct - 1.5 * chirp_duration) - del interchirp_ends[0] - del interchirp_starts[-1] + + interchirp_starts = np.add(chirp_times, 1.5 * chirp_duration)[:-1] + interchirp_ends = np.subtract(chirp_times, 1.5 * chirp_duration)[1:] + ici = np.floor(np.mean(np.subtract(interchirp_ends, interchirp_starts))*1000) / 1000 + # get the spiking responses no_other_spikes = get_spikes(no_other_block) self_spikes = get_spikes(self_block) + # get firing rates - no_other_rates = get_rates(no_other_spikes, duration, dt, kernel_width) - self_rates = get_rates(self_spikes, duration, dt, kernel_width) + no_other_rates, _ = get_rates(no_other_spikes, duration, dt, kernel_width) + self_rates, _ = get_rates(self_spikes, duration, dt, kernel_width) # get the response snippets between chrips + no_other_snippets = np.zeros((len(interchirp_starts) * no_other_rates.shape[0], int(ici / dt))) + self_snippets = np.zeros_like(no_other_snippets) + for i in range(no_other_rates.shape[0]): + for j, start in enumerate(interchirp_starts): + start_index = int(start/dt) + end_index = start_index + no_other_snippets.shape[1] + index = i * len(interchirp_starts) + j + no_other_snippets[index, :] = no_other_rates[i, start_index:end_index] + self_snippets[index, :] = self_rates[i, start_index:end_index] - # get the distances and do the roc - embed() - break; + # get the distances + baseline_dist = within_group_distance(no_other_snippets) + comp_dist = across_group_distance(no_other_snippets, self_snippets) + + # sort and perfom roc + triangle_indices = np.tril_indices_from(baseline_dist, -1) + valid_distances_baseline = baseline_dist[triangle_indices] + temp1 = np.zeros_like(valid_distances_baseline) + + valid_distances_comparison = comp_dist.ravel() + temp2 = np.ones_like(valid_distances_comparison) + + group = np.hstack((temp1, temp2)) + score = np.hstack((valid_distances_baseline, valid_distances_comparison)) + fpr, tpr, _ = roc_curve(group, score, pos_label=1) + auc = roc_auc_score(group, score) + detection_performance[(contrast, kernel_width)] = {"auc": auc, "true positives": tpr, "false positives": fpr} + print("\n") return detection_performance @@ -281,11 +289,11 @@ def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, k def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, kernel_width=0.0005): dfs = [current_df] if current_df is not None else all_dfs - detection_performance_beat = [] - detection_performance_chirp = [] + detection_performance_beat = {} + detection_performance_chirp = {} for df in dfs: - detection_performance_beat.append(foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width)) - detection_performance_chirp.append(foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kernel_width)) + detection_performance_beat[df] = foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width) + detection_performance_chirp[df] = foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kernel_width) return detection_performance_beat, detection_performance_chirp diff --git a/util.py b/util.py index 682dc5c..4b5192a 100644 --- a/util.py +++ b/util.py @@ -34,6 +34,26 @@ def gaussKernel(sigma, dt): return y +def extract_am(signal): + """Extract the amplitude modulation from a signal using the Hilbert transform. Performs padding to avoid artefacts at beginning and end. + + Args: + signal (np.ndarray): the signal + + Returns: + np.ndarray: the am, i.e. the absolute value of the Hilbert transform. + """ + # first add some padding to both ends + front_pad = np.flip(signal[:int(len(signal)/100)]) + back_pad = np.flip(signal[-int(len(signal)/100):]) + padded = np.hstack((front_pad, signal, back_pad)) + # do the hilbert and take abs, cut away the padding + am = np.abs(sig.hilbert(padded)) + am = am[len(front_pad):-len(back_pad)] + return am + + + def firing_rate(spikes, duration, sigma=0.005, dt=1./20000.): """Convert spike times to a firing rate using the kernel convolution with a Gaussian kernel @@ -92,25 +112,30 @@ def spiketrain_distance(spikes, duration, dt, kernel_width=0.001): return distances -def rate_distance(rates1, rates2, axis=0): +def across_group_distance(rates1, rates2, axis=0): + if axis == 1: + rates1 = rates1.T + rates2 = rates2.T distances = np.zeros((rates1.shape[axis], rates2.shape[axis])) for i in range(distances.shape[0]): for j in range(distances.shape[1]): - distances[i, j] = np.sqrt(np.sum((rates1[i,:] - rates2[j,:])**2)) + distances[i, j] = np.sqrt(np.sum((rates1[i,:] - rates2[j,:])**2))/rates1.shape[1-axis] return distances -def rate_distance(rates, axis=0): +def within_group_distance(rates, axis=0): distances = np.zeros((rates.shape[axis], rates.shape[axis])) if axis == 1: rates = rates.T for i in range(distances.shape[0]): for j in range(distances.shape[1]): - if i < j: - distances[i, j] = np.sqrt(np.sum((rates[i,:] - rates[j,:])**2)) + if j < i: + distances[i, j] = np.mean(np.sqrt(np.sum((rates[i,:] - rates[j,:])**2)))/rates.shape[1-axis] distances[j, i] = distances[i, j] elif i == j: distances[i, j] = 0.0 else: break + + return distances