From 1f8d9a36243876e0e0f2ddafa51e16bb3137db52 Mon Sep 17 00:00:00 2001 From: Jan Grewe Date: Mon, 21 Sep 2020 08:54:35 +0200 Subject: [PATCH] more on the foreign fish detection, some notes --- chirps_as_probing_signals.md | 11 ++- response_discriminability.py | 179 ++++++++++++++++++++++++++--------- util.py | 66 ++++++++++++- 3 files changed, 207 insertions(+), 49 deletions(-) diff --git a/chirps_as_probing_signals.md b/chirps_as_probing_signals.md index 188b3ed..244fedd 100644 --- a/chirps_as_probing_signals.md +++ b/chirps_as_probing_signals.md @@ -32,13 +32,19 @@ Dimensionalities involved, The beat frequency, the distance (contrast), the chir * with foreign generated chirps Won't do, this is trivial?! -### 2. Use Alex' model to get the P-unit responses +### 2. Use Alex' model to get the P-unit responses --> Done * implement the Chripstimulus class * move along the same lines as for the input signals * create the stimulus for a range of contrasts, with self of the other fish chirping, each stimulus phase contains a phase in wich there is no foreign fish. * calculate a bunch (10) trials for each condition and estimate the detecatability of a foreign fish -* estimate the distance between the responses without the other fish and the beat response as well as the chirp response. + +### 3. Does the chirp increase the detectablility of another animal? + +* Work out the difference between baseline activity and a foreign chirp response: + * calculate the discriminability between the baseline (no-other fish present) and the another fish is present for each contrast +* Work out the difference between the soliloquy and the response to self generated chirp in a communication context +* Compare to the beat alone parts of the responses. ## Random thoughts @@ -46,3 +52,4 @@ Won't do, this is trivial?! * Raab et al show this is also the case with rises. * Check role of AFRs and rises in Tallarovic et al, Hupe et al. * we actually do not observe chirps without stimulation +* diff --git a/response_discriminability.py b/response_discriminability.py index 96d8d05..fdcd156 100644 --- a/response_discriminability.py +++ b/response_discriminability.py @@ -43,25 +43,89 @@ def sort_blocks(nix_file): return block_map, contrasts, deltafs, conditions -def get_firing_rate(block_map, df, contrast, condition, kernel_width=0.0005): - block = block_map[(contrast, df, condition)] +def get_spikes(block): + """Get the spike trains. + + Args: + block ([type]): [description] + + Returns: + list of np.ndarray: the spike trains. + """ response_map = {} spikes = [] + for da in block.data_arrays: if "spike_times" in da.type and "response" in da.name: resp_id = int(da.name.split("_")[-1]) response_map[resp_id] = da + for k in sorted(response_map.keys()): + spikes.append(response_map[k][:]) + + return spikes + + +def get_rates(spike_trains, duration, dt, kernel_width): + """Convert the spike trains (list of spike_times) to rates using a Gaussian kernel of the given size. + + Args: + spike_trains ([type]): [description] + duration ([type]): [description] + dt ([type]): [description] + kernel_width ([type]): [description] + + Returns: + np.ndarray: Matrix of firing rates, 1. dimension is the number of trials + np.ndarray: the time vector + """ + time = np.arange(0.0, duration, dt) + rates = np.zeros((len(spike_trains), len(time))) + for i, sp in enumerate(spike_trains): + rates[i, :] = firing_rate(sp, duration, kernel_width, dt) + + return rates, time + + +def get_firing_rate(block_map, df, contrast, condition, kernel_width=0.0005): + """Retruns the firing rates and the spikes + + Args: + block_map ([type]): [description] + df ([type]): [description] + contrast ([type]): [description] + condition ([type]): [description] + kernel_width (float, optional): [description]. Defaults to 0.0005. + + Returns: + np.ndarray: the time vector. + np.ndarray: the rates with the first dimension representing the trials. + np.adarray: the spike trains. + """ + block = block_map[(contrast, df, condition)] + spikes = get_spikes(block) duration = float(block.metadata["stimulus parameter"]["duration"]) dt = float(block.metadata["stimulus parameter"]["dt"]) - time = np.arange(0.0, duration, dt) - rates = np.zeros((len(response_map.keys()), len(time))) - for i, k in enumerate(response_map.keys()): - spikes.append(response_map[k][:]) - rates[i,:] = firing_rate(spikes[-1], duration, kernel_width, dt) + + rates, time = get_rates(spikes, duration, dt, kernel_width) return time, rates, spikes def get_signals(block): + """Read the fish signals from block. + + Args: + block ([type]): the block containing the data for a given df, contrast and condition + + Raises: + ValueError: when the complete stimulus data is not found + ValueError: when the no-other animal data is not found + + Returns: + np.ndarray: the complete signal + np.ndarray: the frequency profile of the recorded fish + np.ndarray: the frequency profile of the other fish + np.ndarray: the time axis + """ self_freq = None other_freq = None signal = None @@ -80,11 +144,19 @@ def get_signals(block): def extract_am(signal): - # first add some padding + """Extract the amplitude modulation from a signal using the Hilbert transform. Performs padding to avoid artefacts at beginning and end. + + Args: + signal (np.ndarray): the signal + + Returns: + np.ndarray: the am, i.e. the absolute value of the Hilbert transform. + """ + # first add some padding to both ends front_pad = np.flip(signal[:int(len(signal)/100)]) back_pad = np.flip(signal[-int(len(signal)/100):]) padded = np.hstack((front_pad, signal, back_pad)) - # do the hilbert and take abs + # do the hilbert and take abs, cut away the padding am = np.abs(sig.hilbert(padded)) am = am[len(front_pad):-len(back_pad)] return am @@ -130,33 +202,9 @@ def create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, curr despine(ax, ["top", "bottom", "left", "right"], True) ax.set_ylim([-1.25, 1.25]) ax.legend(ncol=2, loc=(0.01, -0.5), fontsize=7, markerscale=0.5, frameon=False) - """ - # for the largest contrast plot the raster with psth, only a section of the data (e.g. 1s) - t, rates, spikes = get_firing_rate(block_map, current_df, all_contrasts[0], condition, kernel_width=0.001) - avg_resp = np.mean(rates, axis=0) - error = np.std(rates, axis=0) - - ax = plt.subplot2grid(fig_grid, (6, i * 3 + i), rowspan=2, colspan=3) - ax.plot(t[(t > min_time) & (t < max_time)], avg_resp[(t > min_time) & (t < max_time)], - color="k", lw=0.5) - ax.fill_between(t[(t > min_time) & (t < max_time)], (avg_resp - error)[(t > min_time) & (t < max_time)], - (avg_resp + error)[(t > min_time) & (t < max_time)], color="k", lw=0.0, alpha=0.25) - ax.set_ylim([0, 750]) - ax.set_xlabel("") - ax.set_ylabel("") - ax.set_xticks(np.arange(min_time, max_time+.01, 0.250)) - ax.set_xticklabels(map(int, (np.arange(min_time, max_time + .01, 0.250) - min_time) * 1000)) - ax.set_xticks(np.arange(min_time, max_time+.01, 0.0625), minor=True) - ax.set_xticklabels([]) - ax.set_yticks(np.arange(0.0, 751., 500)) - ax.set_yticks(np.arange(0.0, 751., 125), minor=True) - if i > 0: - ax.set_yticklabels([]) - despine(ax, ["top", "right"], False) - """ - # for all other contrast plot the firing rate alone - for j in range(0, len(all_contrasts)): - contrast = all_contrasts[j] + + # for each contrast plot the firing rate + for j, contrast in enumerate(all_contrasts): t, rates, _ = get_firing_rate(block_map, current_df, contrast, condition) avg_resp = np.mean(rates, axis=0) error = np.std(rates, axis=0) @@ -193,14 +241,53 @@ def create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, curr plt.close() -def chrip_detection_soliloquy(spikes, chirp_times, kernel_width=0.0005): - # - pass +def foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005): + detection_performance = {} + for contrast in all_contrasts: + no_other_block = block_map[(contrast, df, "no-other")] + self_block = block_map[(contrast, df, "self")] + + # get some metadata assuming they are all the same for each condition + duration = float(self_block.metadata["stimulus parameter"]["duration"]) + dt = float(self_block.metadata["stimulus parameter"]["dt"]) + chirp_duration = self_block.metadata["stimulus parameter"]["chirp_duration"] + chirp_times = self_block.metadata["stimulus parameter"]["chirp_times"] + interchirp_starts = [] + interchirp_ends = [] + for ct in chirp_times: + interchirp_starts.append(ct + 1.5 * chirp_duration) + interchirp_ends.append(ct - 1.5 * chirp_duration) + del interchirp_ends[0] + del interchirp_starts[-1] + # get the spiking responses + no_other_spikes = get_spikes(no_other_block) + self_spikes = get_spikes(self_block) + # get firing rates + no_other_rates = get_rates(no_other_spikes, duration, dt, kernel_width) + self_rates = get_rates(self_spikes, duration, dt, kernel_width) + + # get the response snippets between chrips + + # get the distances and do the roc + embed() + break; + return detection_performance -def chirp_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, current_condition=None): +def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005): + # + return None + - pass +def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, kernel_width=0.0005): + dfs = [current_df] if current_df is not None else all_dfs + detection_performance_beat = [] + detection_performance_chirp = [] + for df in dfs: + detection_performance_beat.append(foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width)) + detection_performance_chirp.append(foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kernel_width)) + + return detection_performance_beat, detection_performance_chirp def process_cell(filename, dfs=[], contrasts=[], conditions=[]): @@ -210,11 +297,11 @@ def process_cell(filename, dfs=[], contrasts=[], conditions=[]): baseline_spikes = read_baseline(block_map["baseline"]) else: print("ERROR: no baseline data for file %s!" % filename) - fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf" - create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name) - fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf" - create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name) - chirp_detection(block_map, all_dfs, all_contrasts, all_conditions) + # fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf" + # create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name) + # fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf" + # create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name) + foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=20) nf.close() diff --git a/util.py b/util.py index 2832dc7..682dc5c 100644 --- a/util.py +++ b/util.py @@ -1,4 +1,7 @@ +from typing import ValuesView import numpy as np +from numpy.lib.function_base import iterable +from numpy.lib.index_tricks import diag_indices def despine(axis, spines=None, hide_ticks=True): @@ -49,4 +52,65 @@ def firing_rate(spikes, duration, sigma=0.005, dt=1./20000.): kernel = gaussKernel(sigma, dt) rate = np.convolve(kernel, binary, mode="same") - return rate \ No newline at end of file + return rate + + +def spiketrain_distance(spikes, duration, dt, kernel_width=0.001): + """Calculate the Euclidean distance between spike trains. Firing rates are estimated using the kernel + convloution technique applying a Gaussian kernel of the given standard deviation. + + Args: + spikes (list of iterable): list of spike trains. event times are given in seconds. + duration (float): duration of a trial given in seconds. + dt (float): stepsize of the recording, given in seconds. + kernel_width (float, optional): standard deviation of the Gaussian kernel used to estimate the firing rate. Defaults to 0.001. + + Returns: + np.ndarray: the distances + """ + # perform some checks + if not isinstance(spikes, list): + raise ValueError("spikes must be a list of spike trains, aka iterables of spike times.") + if len(spikes) > 1 and not isinstance(spikes[0], iterable): + raise ValueError("spikes must be a list of spike trains, aka iterables of spike times.") + + rates = np.zeros((len(spikes), int(duration/dt))) + for i in range(len(spikes)): + rates[i,:] = firing_rate(spikes[0], duration, kernel_width, dt) + + distances = np.zeros((len(spikes), len(spikes))) + for i in range(len(spikes)): + for j in range(len(spikes)): + if i < j: + distances[i, j] = np.sqrt(np.sum((rates[i,:] - rates[j,:])**2)) + distances[j, i] = distances[i, j] + elif i == j: + distances[i, j] = 0.0 + else: + break + + return distances + + +def rate_distance(rates1, rates2, axis=0): + distances = np.zeros((rates1.shape[axis], rates2.shape[axis])) + for i in range(distances.shape[0]): + for j in range(distances.shape[1]): + distances[i, j] = np.sqrt(np.sum((rates1[i,:] - rates2[j,:])**2)) + + return distances + + +def rate_distance(rates, axis=0): + distances = np.zeros((rates.shape[axis], rates.shape[axis])) + if axis == 1: + rates = rates.T + for i in range(distances.shape[0]): + for j in range(distances.shape[1]): + if i < j: + distances[i, j] = np.sqrt(np.sum((rates[i,:] - rates[j,:])**2)) + distances[j, i] = distances[i, j] + elif i == j: + distances[i, j] = 0.0 + else: + break