beat discrimination working
This commit is contained in:
parent
1f8d9a3624
commit
8ef2e672c5
@ -4,9 +4,10 @@ import nixio as nix
|
||||
import numpy as np
|
||||
import scipy.signal as sig
|
||||
import matplotlib.pyplot as plt
|
||||
from sklearn.metrics import roc_curve, roc_auc_score
|
||||
from IPython import embed
|
||||
|
||||
from util import firing_rate, despine
|
||||
from util import firing_rate, despine, extract_am, within_group_distance, across_group_distance
|
||||
figure_folder = "figures"
|
||||
data_folder = "data"
|
||||
|
||||
@ -143,25 +144,6 @@ def get_signals(block):
|
||||
return signal, self_freq, other_freq, time
|
||||
|
||||
|
||||
def extract_am(signal):
|
||||
"""Extract the amplitude modulation from a signal using the Hilbert transform. Performs padding to avoid artefacts at beginning and end.
|
||||
|
||||
Args:
|
||||
signal (np.ndarray): the signal
|
||||
|
||||
Returns:
|
||||
np.ndarray: the am, i.e. the absolute value of the Hilbert transform.
|
||||
"""
|
||||
# first add some padding to both ends
|
||||
front_pad = np.flip(signal[:int(len(signal)/100)])
|
||||
back_pad = np.flip(signal[-int(len(signal)/100):])
|
||||
padded = np.hstack((front_pad, signal, back_pad))
|
||||
# do the hilbert and take abs, cut away the padding
|
||||
am = np.abs(sig.hilbert(padded))
|
||||
am = am[len(front_pad):-len(back_pad)]
|
||||
return am
|
||||
|
||||
|
||||
def create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, current_df, figure_name=None):
|
||||
conditions = ["no-other", "self", "other"]
|
||||
condition_labels = ["soliloquy", "self chirping", "other chirping"]
|
||||
@ -243,34 +225,60 @@ def create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, curr
|
||||
|
||||
def foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005):
|
||||
detection_performance = {}
|
||||
|
||||
for contrast in all_contrasts:
|
||||
print(" " * 50, end="\r")
|
||||
print("Contrast: %.3f" % contrast, end="\r")
|
||||
no_other_block = block_map[(contrast, df, "no-other")]
|
||||
self_block = block_map[(contrast, df, "self")]
|
||||
|
||||
# get some metadata assuming they are all the same for each condition
|
||||
# get some metadata assuming they are all the same for each conditionm, which they should
|
||||
duration = float(self_block.metadata["stimulus parameter"]["duration"])
|
||||
dt = float(self_block.metadata["stimulus parameter"]["dt"])
|
||||
chirp_duration = self_block.metadata["stimulus parameter"]["chirp_duration"]
|
||||
chirp_times = self_block.metadata["stimulus parameter"]["chirp_times"]
|
||||
interchirp_starts = []
|
||||
interchirp_ends = []
|
||||
for ct in chirp_times:
|
||||
interchirp_starts.append(ct + 1.5 * chirp_duration)
|
||||
interchirp_ends.append(ct - 1.5 * chirp_duration)
|
||||
del interchirp_ends[0]
|
||||
del interchirp_starts[-1]
|
||||
|
||||
interchirp_starts = np.add(chirp_times, 1.5 * chirp_duration)[:-1]
|
||||
interchirp_ends = np.subtract(chirp_times, 1.5 * chirp_duration)[1:]
|
||||
ici = np.floor(np.mean(np.subtract(interchirp_ends, interchirp_starts))*1000) / 1000
|
||||
|
||||
# get the spiking responses
|
||||
no_other_spikes = get_spikes(no_other_block)
|
||||
self_spikes = get_spikes(self_block)
|
||||
|
||||
# get firing rates
|
||||
no_other_rates = get_rates(no_other_spikes, duration, dt, kernel_width)
|
||||
self_rates = get_rates(self_spikes, duration, dt, kernel_width)
|
||||
no_other_rates, _ = get_rates(no_other_spikes, duration, dt, kernel_width)
|
||||
self_rates, _ = get_rates(self_spikes, duration, dt, kernel_width)
|
||||
|
||||
# get the response snippets between chrips
|
||||
no_other_snippets = np.zeros((len(interchirp_starts) * no_other_rates.shape[0], int(ici / dt)))
|
||||
self_snippets = np.zeros_like(no_other_snippets)
|
||||
for i in range(no_other_rates.shape[0]):
|
||||
for j, start in enumerate(interchirp_starts):
|
||||
start_index = int(start/dt)
|
||||
end_index = start_index + no_other_snippets.shape[1]
|
||||
index = i * len(interchirp_starts) + j
|
||||
no_other_snippets[index, :] = no_other_rates[i, start_index:end_index]
|
||||
self_snippets[index, :] = self_rates[i, start_index:end_index]
|
||||
|
||||
# get the distances and do the roc
|
||||
embed()
|
||||
break;
|
||||
# get the distances
|
||||
baseline_dist = within_group_distance(no_other_snippets)
|
||||
comp_dist = across_group_distance(no_other_snippets, self_snippets)
|
||||
|
||||
# sort and perfom roc
|
||||
triangle_indices = np.tril_indices_from(baseline_dist, -1)
|
||||
valid_distances_baseline = baseline_dist[triangle_indices]
|
||||
temp1 = np.zeros_like(valid_distances_baseline)
|
||||
|
||||
valid_distances_comparison = comp_dist.ravel()
|
||||
temp2 = np.ones_like(valid_distances_comparison)
|
||||
|
||||
group = np.hstack((temp1, temp2))
|
||||
score = np.hstack((valid_distances_baseline, valid_distances_comparison))
|
||||
fpr, tpr, _ = roc_curve(group, score, pos_label=1)
|
||||
auc = roc_auc_score(group, score)
|
||||
detection_performance[(contrast, kernel_width)] = {"auc": auc, "true positives": tpr, "false positives": fpr}
|
||||
print("\n")
|
||||
return detection_performance
|
||||
|
||||
|
||||
@ -281,11 +289,11 @@ def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, k
|
||||
|
||||
def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, kernel_width=0.0005):
|
||||
dfs = [current_df] if current_df is not None else all_dfs
|
||||
detection_performance_beat = []
|
||||
detection_performance_chirp = []
|
||||
detection_performance_beat = {}
|
||||
detection_performance_chirp = {}
|
||||
for df in dfs:
|
||||
detection_performance_beat.append(foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width))
|
||||
detection_performance_chirp.append(foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kernel_width))
|
||||
detection_performance_beat[df] = foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width)
|
||||
detection_performance_chirp[df] = foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kernel_width)
|
||||
|
||||
return detection_performance_beat, detection_performance_chirp
|
||||
|
||||
|
35
util.py
35
util.py
@ -34,6 +34,26 @@ def gaussKernel(sigma, dt):
|
||||
return y
|
||||
|
||||
|
||||
def extract_am(signal):
|
||||
"""Extract the amplitude modulation from a signal using the Hilbert transform. Performs padding to avoid artefacts at beginning and end.
|
||||
|
||||
Args:
|
||||
signal (np.ndarray): the signal
|
||||
|
||||
Returns:
|
||||
np.ndarray: the am, i.e. the absolute value of the Hilbert transform.
|
||||
"""
|
||||
# first add some padding to both ends
|
||||
front_pad = np.flip(signal[:int(len(signal)/100)])
|
||||
back_pad = np.flip(signal[-int(len(signal)/100):])
|
||||
padded = np.hstack((front_pad, signal, back_pad))
|
||||
# do the hilbert and take abs, cut away the padding
|
||||
am = np.abs(sig.hilbert(padded))
|
||||
am = am[len(front_pad):-len(back_pad)]
|
||||
return am
|
||||
|
||||
|
||||
|
||||
def firing_rate(spikes, duration, sigma=0.005, dt=1./20000.):
|
||||
"""Convert spike times to a firing rate using the kernel convolution with a Gaussian kernel
|
||||
|
||||
@ -92,25 +112,30 @@ def spiketrain_distance(spikes, duration, dt, kernel_width=0.001):
|
||||
return distances
|
||||
|
||||
|
||||
def rate_distance(rates1, rates2, axis=0):
|
||||
def across_group_distance(rates1, rates2, axis=0):
|
||||
if axis == 1:
|
||||
rates1 = rates1.T
|
||||
rates2 = rates2.T
|
||||
distances = np.zeros((rates1.shape[axis], rates2.shape[axis]))
|
||||
for i in range(distances.shape[0]):
|
||||
for j in range(distances.shape[1]):
|
||||
distances[i, j] = np.sqrt(np.sum((rates1[i,:] - rates2[j,:])**2))
|
||||
distances[i, j] = np.sqrt(np.sum((rates1[i,:] - rates2[j,:])**2))/rates1.shape[1-axis]
|
||||
|
||||
return distances
|
||||
|
||||
|
||||
def rate_distance(rates, axis=0):
|
||||
def within_group_distance(rates, axis=0):
|
||||
distances = np.zeros((rates.shape[axis], rates.shape[axis]))
|
||||
if axis == 1:
|
||||
rates = rates.T
|
||||
for i in range(distances.shape[0]):
|
||||
for j in range(distances.shape[1]):
|
||||
if i < j:
|
||||
distances[i, j] = np.sqrt(np.sum((rates[i,:] - rates[j,:])**2))
|
||||
if j < i:
|
||||
distances[i, j] = np.mean(np.sqrt(np.sum((rates[i,:] - rates[j,:])**2)))/rates.shape[1-axis]
|
||||
distances[j, i] = distances[i, j]
|
||||
elif i == j:
|
||||
distances[i, j] = 0.0
|
||||
else:
|
||||
break
|
||||
|
||||
return distances
|
||||
|
Loading…
Reference in New Issue
Block a user