beat discrimination working

This commit is contained in:
Jan Grewe 2020-09-23 13:31:08 +02:00
parent 1f8d9a3624
commit 8ef2e672c5
2 changed files with 75 additions and 42 deletions

View File

@ -4,9 +4,10 @@ import nixio as nix
import numpy as np
import scipy.signal as sig
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score
from IPython import embed
from util import firing_rate, despine
from util import firing_rate, despine, extract_am, within_group_distance, across_group_distance
figure_folder = "figures"
data_folder = "data"
@ -143,25 +144,6 @@ def get_signals(block):
return signal, self_freq, other_freq, time
def extract_am(signal):
"""Extract the amplitude modulation from a signal using the Hilbert transform. Performs padding to avoid artefacts at beginning and end.
Args:
signal (np.ndarray): the signal
Returns:
np.ndarray: the am, i.e. the absolute value of the Hilbert transform.
"""
# first add some padding to both ends
front_pad = np.flip(signal[:int(len(signal)/100)])
back_pad = np.flip(signal[-int(len(signal)/100):])
padded = np.hstack((front_pad, signal, back_pad))
# do the hilbert and take abs, cut away the padding
am = np.abs(sig.hilbert(padded))
am = am[len(front_pad):-len(back_pad)]
return am
def create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, current_df, figure_name=None):
conditions = ["no-other", "self", "other"]
condition_labels = ["soliloquy", "self chirping", "other chirping"]
@ -243,34 +225,60 @@ def create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, curr
def foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005):
detection_performance = {}
for contrast in all_contrasts:
print(" " * 50, end="\r")
print("Contrast: %.3f" % contrast, end="\r")
no_other_block = block_map[(contrast, df, "no-other")]
self_block = block_map[(contrast, df, "self")]
# get some metadata assuming they are all the same for each condition
# get some metadata assuming they are all the same for each conditionm, which they should
duration = float(self_block.metadata["stimulus parameter"]["duration"])
dt = float(self_block.metadata["stimulus parameter"]["dt"])
chirp_duration = self_block.metadata["stimulus parameter"]["chirp_duration"]
chirp_times = self_block.metadata["stimulus parameter"]["chirp_times"]
interchirp_starts = []
interchirp_ends = []
for ct in chirp_times:
interchirp_starts.append(ct + 1.5 * chirp_duration)
interchirp_ends.append(ct - 1.5 * chirp_duration)
del interchirp_ends[0]
del interchirp_starts[-1]
interchirp_starts = np.add(chirp_times, 1.5 * chirp_duration)[:-1]
interchirp_ends = np.subtract(chirp_times, 1.5 * chirp_duration)[1:]
ici = np.floor(np.mean(np.subtract(interchirp_ends, interchirp_starts))*1000) / 1000
# get the spiking responses
no_other_spikes = get_spikes(no_other_block)
self_spikes = get_spikes(self_block)
# get firing rates
no_other_rates = get_rates(no_other_spikes, duration, dt, kernel_width)
self_rates = get_rates(self_spikes, duration, dt, kernel_width)
no_other_rates, _ = get_rates(no_other_spikes, duration, dt, kernel_width)
self_rates, _ = get_rates(self_spikes, duration, dt, kernel_width)
# get the response snippets between chrips
# get the distances and do the roc
embed()
break;
no_other_snippets = np.zeros((len(interchirp_starts) * no_other_rates.shape[0], int(ici / dt)))
self_snippets = np.zeros_like(no_other_snippets)
for i in range(no_other_rates.shape[0]):
for j, start in enumerate(interchirp_starts):
start_index = int(start/dt)
end_index = start_index + no_other_snippets.shape[1]
index = i * len(interchirp_starts) + j
no_other_snippets[index, :] = no_other_rates[i, start_index:end_index]
self_snippets[index, :] = self_rates[i, start_index:end_index]
# get the distances
baseline_dist = within_group_distance(no_other_snippets)
comp_dist = across_group_distance(no_other_snippets, self_snippets)
# sort and perfom roc
triangle_indices = np.tril_indices_from(baseline_dist, -1)
valid_distances_baseline = baseline_dist[triangle_indices]
temp1 = np.zeros_like(valid_distances_baseline)
valid_distances_comparison = comp_dist.ravel()
temp2 = np.ones_like(valid_distances_comparison)
group = np.hstack((temp1, temp2))
score = np.hstack((valid_distances_baseline, valid_distances_comparison))
fpr, tpr, _ = roc_curve(group, score, pos_label=1)
auc = roc_auc_score(group, score)
detection_performance[(contrast, kernel_width)] = {"auc": auc, "true positives": tpr, "false positives": fpr}
print("\n")
return detection_performance
@ -281,11 +289,11 @@ def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, k
def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, kernel_width=0.0005):
dfs = [current_df] if current_df is not None else all_dfs
detection_performance_beat = []
detection_performance_chirp = []
detection_performance_beat = {}
detection_performance_chirp = {}
for df in dfs:
detection_performance_beat.append(foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width))
detection_performance_chirp.append(foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kernel_width))
detection_performance_beat[df] = foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width)
detection_performance_chirp[df] = foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kernel_width)
return detection_performance_beat, detection_performance_chirp

35
util.py
View File

@ -34,6 +34,26 @@ def gaussKernel(sigma, dt):
return y
def extract_am(signal):
"""Extract the amplitude modulation from a signal using the Hilbert transform. Performs padding to avoid artefacts at beginning and end.
Args:
signal (np.ndarray): the signal
Returns:
np.ndarray: the am, i.e. the absolute value of the Hilbert transform.
"""
# first add some padding to both ends
front_pad = np.flip(signal[:int(len(signal)/100)])
back_pad = np.flip(signal[-int(len(signal)/100):])
padded = np.hstack((front_pad, signal, back_pad))
# do the hilbert and take abs, cut away the padding
am = np.abs(sig.hilbert(padded))
am = am[len(front_pad):-len(back_pad)]
return am
def firing_rate(spikes, duration, sigma=0.005, dt=1./20000.):
"""Convert spike times to a firing rate using the kernel convolution with a Gaussian kernel
@ -92,25 +112,30 @@ def spiketrain_distance(spikes, duration, dt, kernel_width=0.001):
return distances
def rate_distance(rates1, rates2, axis=0):
def across_group_distance(rates1, rates2, axis=0):
if axis == 1:
rates1 = rates1.T
rates2 = rates2.T
distances = np.zeros((rates1.shape[axis], rates2.shape[axis]))
for i in range(distances.shape[0]):
for j in range(distances.shape[1]):
distances[i, j] = np.sqrt(np.sum((rates1[i,:] - rates2[j,:])**2))
distances[i, j] = np.sqrt(np.sum((rates1[i,:] - rates2[j,:])**2))/rates1.shape[1-axis]
return distances
def rate_distance(rates, axis=0):
def within_group_distance(rates, axis=0):
distances = np.zeros((rates.shape[axis], rates.shape[axis]))
if axis == 1:
rates = rates.T
for i in range(distances.shape[0]):
for j in range(distances.shape[1]):
if i < j:
distances[i, j] = np.sqrt(np.sum((rates[i,:] - rates[j,:])**2))
if j < i:
distances[i, j] = np.mean(np.sqrt(np.sum((rates[i,:] - rates[j,:])**2)))/rates.shape[1-axis]
distances[j, i] = distances[i, j]
elif i == j:
distances[i, j] = 0.0
else:
break
return distances