beat discrimination working

Jan Grewe 2020-09-23 13:31:08 +02:00
parent 1f8d9a3624
commit 8ef2e672c5
2 changed files with 75 additions and 42 deletions


@@ -4,9 +4,10 @@ import nixio as nix
 import numpy as np
 import scipy.signal as sig
 import matplotlib.pyplot as plt
+from sklearn.metrics import roc_curve, roc_auc_score
 from IPython import embed
-from util import firing_rate, despine
+from util import firing_rate, despine, extract_am, within_group_distance, across_group_distance

 figure_folder = "figures"
 data_folder = "data"
@@ -143,25 +144,6 @@ def get_signals(block):
     return signal, self_freq, other_freq, time

-def extract_am(signal):
-    """Extract the amplitude modulation from a signal using the Hilbert transform. Performs padding to avoid artefacts at beginning and end.
-
-    Args:
-        signal (np.ndarray): the signal
-
-    Returns:
-        np.ndarray: the am, i.e. the absolute value of the Hilbert transform.
-    """
-    # first add some padding to both ends
-    front_pad = np.flip(signal[:int(len(signal)/100)])
-    back_pad = np.flip(signal[-int(len(signal)/100):])
-    padded = np.hstack((front_pad, signal, back_pad))
-    # do the hilbert and take abs, cut away the padding
-    am = np.abs(sig.hilbert(padded))
-    am = am[len(front_pad):-len(back_pad)]
-    return am
-
 def create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, current_df, figure_name=None):
     conditions = ["no-other", "self", "other"]
     condition_labels = ["soliloquy", "self chirping", "other chirping"]
@@ -243,34 +225,60 @@ def create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, curr
 def foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005):
     detection_performance = {}
     for contrast in all_contrasts:
+        print(" " * 50, end="\r")
+        print("Contrast: %.3f" % contrast, end="\r")
         no_other_block = block_map[(contrast, df, "no-other")]
         self_block = block_map[(contrast, df, "self")]

-        # get some metadata assuming they are all the same for each condition
+        # get some metadata assuming they are all the same for each condition, which they should be
         duration = float(self_block.metadata["stimulus parameter"]["duration"])
         dt = float(self_block.metadata["stimulus parameter"]["dt"])
         chirp_duration = self_block.metadata["stimulus parameter"]["chirp_duration"]
         chirp_times = self_block.metadata["stimulus parameter"]["chirp_times"]
-        interchirp_starts = []
-        interchirp_ends = []
-        for ct in chirp_times:
-            interchirp_starts.append(ct + 1.5 * chirp_duration)
-            interchirp_ends.append(ct - 1.5 * chirp_duration)
-        del interchirp_ends[0]
-        del interchirp_starts[-1]
+        interchirp_starts = np.add(chirp_times, 1.5 * chirp_duration)[:-1]
+        interchirp_ends = np.subtract(chirp_times, 1.5 * chirp_duration)[1:]
+        ici = np.floor(np.mean(np.subtract(interchirp_ends, interchirp_starts)) * 1000) / 1000

         # get the spiking responses
         no_other_spikes = get_spikes(no_other_block)
         self_spikes = get_spikes(self_block)

         # get firing rates
-        no_other_rates = get_rates(no_other_spikes, duration, dt, kernel_width)
-        self_rates = get_rates(self_spikes, duration, dt, kernel_width)
+        no_other_rates, _ = get_rates(no_other_spikes, duration, dt, kernel_width)
+        self_rates, _ = get_rates(self_spikes, duration, dt, kernel_width)

         # get the response snippets between chirps
-        # get the distances and do the roc
-        embed()
-        break;
+        no_other_snippets = np.zeros((len(interchirp_starts) * no_other_rates.shape[0], int(ici / dt)))
+        self_snippets = np.zeros_like(no_other_snippets)
+        for i in range(no_other_rates.shape[0]):
+            for j, start in enumerate(interchirp_starts):
+                start_index = int(start / dt)
+                end_index = start_index + no_other_snippets.shape[1]
+                index = i * len(interchirp_starts) + j
+                no_other_snippets[index, :] = no_other_rates[i, start_index:end_index]
+                self_snippets[index, :] = self_rates[i, start_index:end_index]
+
+        # get the distances
+        baseline_dist = within_group_distance(no_other_snippets)
+        comp_dist = across_group_distance(no_other_snippets, self_snippets)
+
+        # sort and perform the ROC analysis
+        triangle_indices = np.tril_indices_from(baseline_dist, -1)
+        valid_distances_baseline = baseline_dist[triangle_indices]
+        temp1 = np.zeros_like(valid_distances_baseline)
+
+        valid_distances_comparison = comp_dist.ravel()
+        temp2 = np.ones_like(valid_distances_comparison)
+
+        group = np.hstack((temp1, temp2))
+        score = np.hstack((valid_distances_baseline, valid_distances_comparison))
+        fpr, tpr, _ = roc_curve(group, score, pos_label=1)
+        auc = roc_auc_score(group, score)
+        detection_performance[(contrast, kernel_width)] = {"auc": auc, "true positives": tpr, "false positives": fpr}
+
+    print("\n")
     return detection_performance
@@ -281,11 +289,11 @@ def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, k
 def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, kernel_width=0.0005):
     dfs = [current_df] if current_df is not None else all_dfs
-    detection_performance_beat = []
-    detection_performance_chirp = []
+    detection_performance_beat = {}
+    detection_performance_chirp = {}
     for df in dfs:
-        detection_performance_beat.append(foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width))
-        detection_performance_chirp.append(foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kernel_width))
+        detection_performance_beat[df] = foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width)
+        detection_performance_chirp[df] = foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kernel_width)
     return detection_performance_beat, detection_performance_chirp
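
For readers skimming the diff, the new beat discrimination in foreign_fish_detection_beat boils down to a distance-based ROC analysis: pairwise distances among the no-other (baseline) snippets get label 0, distances between no-other and self snippets get label 1, and sklearn's roc_curve/roc_auc_score measure how well the two distance distributions separate. Below is a minimal, self-contained sketch of just that scoring step with synthetic rate snippets; pairwise_distance is a simplified stand-in for the within_group_distance/across_group_distance helpers added to util.py further down, and all numbers are made-up illustration values.

import numpy as np
from sklearn.metrics import roc_curve, roc_auc_score

# synthetic "inter-chirp" firing-rate snippets: 20 trials x 200 time bins per condition
rng = np.random.default_rng(42)
baseline_snippets = rng.normal(100.0, 5.0, (20, 200))      # no-other condition
comparison_snippets = rng.normal(110.0, 5.0, (20, 200))    # self condition, slightly shifted

def pairwise_distance(a, b):
    # Euclidean distance between all row pairs, normalized by the number of time bins
    d = np.zeros((a.shape[0], b.shape[0]))
    for i in range(a.shape[0]):
        for j in range(b.shape[0]):
            d[i, j] = np.sqrt(np.sum((a[i, :] - b[j, :])**2)) / a.shape[1]
    return d

within = pairwise_distance(baseline_snippets, baseline_snippets)
across = pairwise_distance(baseline_snippets, comparison_snippets)

# baseline scores: unique within-group pairs (lower triangle, label 0)
baseline_scores = within[np.tril_indices_from(within, -1)]
# comparison scores: all across-group pairs (label 1)
comparison_scores = across.ravel()

group = np.hstack((np.zeros_like(baseline_scores), np.ones_like(comparison_scores)))
score = np.hstack((baseline_scores, comparison_scores))
fpr, tpr, _ = roc_curve(group, score, pos_label=1)
print("AUC: %.3f" % roc_auc_score(group, score))

An AUC near 0.5 means baseline and comparison distances are indistinguishable; values toward 1.0 indicate that the presence of the other fish is detectable from the beat response alone.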

util.py

@@ -34,6 +34,26 @@ def gaussKernel(sigma, dt):
     return y

+def extract_am(signal):
+    """Extract the amplitude modulation from a signal using the Hilbert transform. Performs padding to avoid artefacts at beginning and end.
+
+    Args:
+        signal (np.ndarray): the signal
+
+    Returns:
+        np.ndarray: the am, i.e. the absolute value of the Hilbert transform.
+    """
+    # first add some padding to both ends
+    front_pad = np.flip(signal[:int(len(signal)/100)])
+    back_pad = np.flip(signal[-int(len(signal)/100):])
+    padded = np.hstack((front_pad, signal, back_pad))
+    # do the hilbert and take abs, cut away the padding
+    am = np.abs(sig.hilbert(padded))
+    am = am[len(front_pad):-len(back_pad)]
+    return am
+
 def firing_rate(spikes, duration, sigma=0.005, dt=1./20000.):
     """Convert spike times to a firing rate by convolution with a Gaussian kernel
@@ -92,25 +112,30 @@ def spiketrain_distance(spikes, duration, dt, kernel_width=0.001):
     return distances

-def rate_distance(rates1, rates2, axis=0):
+def across_group_distance(rates1, rates2, axis=0):
+    if axis == 1:
+        rates1 = rates1.T
+        rates2 = rates2.T
     distances = np.zeros((rates1.shape[axis], rates2.shape[axis]))
     for i in range(distances.shape[0]):
         for j in range(distances.shape[1]):
-            distances[i, j] = np.sqrt(np.sum((rates1[i,:] - rates2[j,:])**2))
+            distances[i, j] = np.sqrt(np.sum((rates1[i,:] - rates2[j,:])**2)) / rates1.shape[1-axis]
     return distances

-def rate_distance(rates, axis=0):
+def within_group_distance(rates, axis=0):
     distances = np.zeros((rates.shape[axis], rates.shape[axis]))
     if axis == 1:
         rates = rates.T
     for i in range(distances.shape[0]):
         for j in range(distances.shape[1]):
-            if i < j:
-                distances[i, j] = np.sqrt(np.sum((rates[i,:] - rates[j,:])**2))
+            if j < i:
+                distances[i, j] = np.mean(np.sqrt(np.sum((rates[i,:] - rates[j,:])**2))) / rates.shape[1-axis]
                 distances[j, i] = distances[i, j]
             elif i == j:
                 distances[i, j] = 0.0
             else:
                 break
     return distances
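
As a quick sanity check of the extract_am helper that moved into util.py, one can apply it to a synthetic two-fish signal: the sum of two sine waves whose difference frequency is the beat frequency. The sketch below copies the function body from the diff above so it runs standalone; the EOD frequencies (800 Hz and 820 Hz) and the 20 % contrast are made-up example values.

import numpy as np
import scipy.signal as sig

def extract_am(signal):
    # copy of the helper added to util.py above: Hilbert envelope with mirrored padding
    front_pad = np.flip(signal[:int(len(signal)/100)])
    back_pad = np.flip(signal[-int(len(signal)/100):])
    padded = np.hstack((front_pad, signal, back_pad))
    am = np.abs(sig.hilbert(padded))
    return am[len(front_pad):-len(back_pad)]

dt = 1. / 20000.
time = np.arange(0.0, 1.0, dt)
# own EOD at 800 Hz plus a weaker foreign EOD at 820 Hz -> 20 Hz beat
signal = np.sin(2 * np.pi * 800 * time) + 0.2 * np.sin(2 * np.pi * 820 * time)

am = extract_am(signal)
print(am.min(), am.max())   # envelope oscillates roughly between 0.8 and 1.2, i.e. at the beat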