beat discrimination working
This commit is contained in:
parent
1f8d9a3624
commit
8ef2e672c5
@ -4,9 +4,10 @@ import nixio as nix
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy.signal as sig
|
import scipy.signal as sig
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
from sklearn.metrics import roc_curve, roc_auc_score
|
||||||
from IPython import embed
|
from IPython import embed
|
||||||
|
|
||||||
from util import firing_rate, despine
|
from util import firing_rate, despine, extract_am, within_group_distance, across_group_distance
|
||||||
figure_folder = "figures"
|
figure_folder = "figures"
|
||||||
data_folder = "data"
|
data_folder = "data"
|
||||||
|
|
||||||
@ -143,25 +144,6 @@ def get_signals(block):
|
|||||||
return signal, self_freq, other_freq, time
|
return signal, self_freq, other_freq, time
|
||||||
|
|
||||||
|
|
||||||
def extract_am(signal):
|
|
||||||
"""Extract the amplitude modulation from a signal using the Hilbert transform. Performs padding to avoid artefacts at beginning and end.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
signal (np.ndarray): the signal
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: the am, i.e. the absolute value of the Hilbert transform.
|
|
||||||
"""
|
|
||||||
# first add some padding to both ends
|
|
||||||
front_pad = np.flip(signal[:int(len(signal)/100)])
|
|
||||||
back_pad = np.flip(signal[-int(len(signal)/100):])
|
|
||||||
padded = np.hstack((front_pad, signal, back_pad))
|
|
||||||
# do the hilbert and take abs, cut away the padding
|
|
||||||
am = np.abs(sig.hilbert(padded))
|
|
||||||
am = am[len(front_pad):-len(back_pad)]
|
|
||||||
return am
|
|
||||||
|
|
||||||
|
|
||||||
def create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, current_df, figure_name=None):
|
def create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, current_df, figure_name=None):
|
||||||
conditions = ["no-other", "self", "other"]
|
conditions = ["no-other", "self", "other"]
|
||||||
condition_labels = ["soliloquy", "self chirping", "other chirping"]
|
condition_labels = ["soliloquy", "self chirping", "other chirping"]
|
||||||
@ -243,34 +225,60 @@ def create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, curr
|
|||||||
|
|
||||||
def foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005):
|
def foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005):
|
||||||
detection_performance = {}
|
detection_performance = {}
|
||||||
|
|
||||||
for contrast in all_contrasts:
|
for contrast in all_contrasts:
|
||||||
|
print(" " * 50, end="\r")
|
||||||
|
print("Contrast: %.3f" % contrast, end="\r")
|
||||||
no_other_block = block_map[(contrast, df, "no-other")]
|
no_other_block = block_map[(contrast, df, "no-other")]
|
||||||
self_block = block_map[(contrast, df, "self")]
|
self_block = block_map[(contrast, df, "self")]
|
||||||
|
|
||||||
# get some metadata assuming they are all the same for each condition
|
# get some metadata assuming they are all the same for each conditionm, which they should
|
||||||
duration = float(self_block.metadata["stimulus parameter"]["duration"])
|
duration = float(self_block.metadata["stimulus parameter"]["duration"])
|
||||||
dt = float(self_block.metadata["stimulus parameter"]["dt"])
|
dt = float(self_block.metadata["stimulus parameter"]["dt"])
|
||||||
chirp_duration = self_block.metadata["stimulus parameter"]["chirp_duration"]
|
chirp_duration = self_block.metadata["stimulus parameter"]["chirp_duration"]
|
||||||
chirp_times = self_block.metadata["stimulus parameter"]["chirp_times"]
|
chirp_times = self_block.metadata["stimulus parameter"]["chirp_times"]
|
||||||
interchirp_starts = []
|
|
||||||
interchirp_ends = []
|
interchirp_starts = np.add(chirp_times, 1.5 * chirp_duration)[:-1]
|
||||||
for ct in chirp_times:
|
interchirp_ends = np.subtract(chirp_times, 1.5 * chirp_duration)[1:]
|
||||||
interchirp_starts.append(ct + 1.5 * chirp_duration)
|
ici = np.floor(np.mean(np.subtract(interchirp_ends, interchirp_starts))*1000) / 1000
|
||||||
interchirp_ends.append(ct - 1.5 * chirp_duration)
|
|
||||||
del interchirp_ends[0]
|
|
||||||
del interchirp_starts[-1]
|
|
||||||
# get the spiking responses
|
# get the spiking responses
|
||||||
no_other_spikes = get_spikes(no_other_block)
|
no_other_spikes = get_spikes(no_other_block)
|
||||||
self_spikes = get_spikes(self_block)
|
self_spikes = get_spikes(self_block)
|
||||||
|
|
||||||
# get firing rates
|
# get firing rates
|
||||||
no_other_rates = get_rates(no_other_spikes, duration, dt, kernel_width)
|
no_other_rates, _ = get_rates(no_other_spikes, duration, dt, kernel_width)
|
||||||
self_rates = get_rates(self_spikes, duration, dt, kernel_width)
|
self_rates, _ = get_rates(self_spikes, duration, dt, kernel_width)
|
||||||
|
|
||||||
# get the response snippets between chrips
|
# get the response snippets between chrips
|
||||||
|
no_other_snippets = np.zeros((len(interchirp_starts) * no_other_rates.shape[0], int(ici / dt)))
|
||||||
|
self_snippets = np.zeros_like(no_other_snippets)
|
||||||
|
for i in range(no_other_rates.shape[0]):
|
||||||
|
for j, start in enumerate(interchirp_starts):
|
||||||
|
start_index = int(start/dt)
|
||||||
|
end_index = start_index + no_other_snippets.shape[1]
|
||||||
|
index = i * len(interchirp_starts) + j
|
||||||
|
no_other_snippets[index, :] = no_other_rates[i, start_index:end_index]
|
||||||
|
self_snippets[index, :] = self_rates[i, start_index:end_index]
|
||||||
|
|
||||||
# get the distances and do the roc
|
# get the distances
|
||||||
embed()
|
baseline_dist = within_group_distance(no_other_snippets)
|
||||||
break;
|
comp_dist = across_group_distance(no_other_snippets, self_snippets)
|
||||||
|
|
||||||
|
# sort and perfom roc
|
||||||
|
triangle_indices = np.tril_indices_from(baseline_dist, -1)
|
||||||
|
valid_distances_baseline = baseline_dist[triangle_indices]
|
||||||
|
temp1 = np.zeros_like(valid_distances_baseline)
|
||||||
|
|
||||||
|
valid_distances_comparison = comp_dist.ravel()
|
||||||
|
temp2 = np.ones_like(valid_distances_comparison)
|
||||||
|
|
||||||
|
group = np.hstack((temp1, temp2))
|
||||||
|
score = np.hstack((valid_distances_baseline, valid_distances_comparison))
|
||||||
|
fpr, tpr, _ = roc_curve(group, score, pos_label=1)
|
||||||
|
auc = roc_auc_score(group, score)
|
||||||
|
detection_performance[(contrast, kernel_width)] = {"auc": auc, "true positives": tpr, "false positives": fpr}
|
||||||
|
print("\n")
|
||||||
return detection_performance
|
return detection_performance
|
||||||
|
|
||||||
|
|
||||||
@ -281,11 +289,11 @@ def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, k
|
|||||||
|
|
||||||
def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, kernel_width=0.0005):
|
def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, kernel_width=0.0005):
|
||||||
dfs = [current_df] if current_df is not None else all_dfs
|
dfs = [current_df] if current_df is not None else all_dfs
|
||||||
detection_performance_beat = []
|
detection_performance_beat = {}
|
||||||
detection_performance_chirp = []
|
detection_performance_chirp = {}
|
||||||
for df in dfs:
|
for df in dfs:
|
||||||
detection_performance_beat.append(foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width))
|
detection_performance_beat[df] = foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width)
|
||||||
detection_performance_chirp.append(foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kernel_width))
|
detection_performance_chirp[df] = foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kernel_width)
|
||||||
|
|
||||||
return detection_performance_beat, detection_performance_chirp
|
return detection_performance_beat, detection_performance_chirp
|
||||||
|
|
||||||
|
35
util.py
35
util.py
@ -34,6 +34,26 @@ def gaussKernel(sigma, dt):
|
|||||||
return y
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
def extract_am(signal):
|
||||||
|
"""Extract the amplitude modulation from a signal using the Hilbert transform. Performs padding to avoid artefacts at beginning and end.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
signal (np.ndarray): the signal
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: the am, i.e. the absolute value of the Hilbert transform.
|
||||||
|
"""
|
||||||
|
# first add some padding to both ends
|
||||||
|
front_pad = np.flip(signal[:int(len(signal)/100)])
|
||||||
|
back_pad = np.flip(signal[-int(len(signal)/100):])
|
||||||
|
padded = np.hstack((front_pad, signal, back_pad))
|
||||||
|
# do the hilbert and take abs, cut away the padding
|
||||||
|
am = np.abs(sig.hilbert(padded))
|
||||||
|
am = am[len(front_pad):-len(back_pad)]
|
||||||
|
return am
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def firing_rate(spikes, duration, sigma=0.005, dt=1./20000.):
|
def firing_rate(spikes, duration, sigma=0.005, dt=1./20000.):
|
||||||
"""Convert spike times to a firing rate using the kernel convolution with a Gaussian kernel
|
"""Convert spike times to a firing rate using the kernel convolution with a Gaussian kernel
|
||||||
|
|
||||||
@ -92,25 +112,30 @@ def spiketrain_distance(spikes, duration, dt, kernel_width=0.001):
|
|||||||
return distances
|
return distances
|
||||||
|
|
||||||
|
|
||||||
def rate_distance(rates1, rates2, axis=0):
|
def across_group_distance(rates1, rates2, axis=0):
|
||||||
|
if axis == 1:
|
||||||
|
rates1 = rates1.T
|
||||||
|
rates2 = rates2.T
|
||||||
distances = np.zeros((rates1.shape[axis], rates2.shape[axis]))
|
distances = np.zeros((rates1.shape[axis], rates2.shape[axis]))
|
||||||
for i in range(distances.shape[0]):
|
for i in range(distances.shape[0]):
|
||||||
for j in range(distances.shape[1]):
|
for j in range(distances.shape[1]):
|
||||||
distances[i, j] = np.sqrt(np.sum((rates1[i,:] - rates2[j,:])**2))
|
distances[i, j] = np.sqrt(np.sum((rates1[i,:] - rates2[j,:])**2))/rates1.shape[1-axis]
|
||||||
|
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
|
|
||||||
def rate_distance(rates, axis=0):
|
def within_group_distance(rates, axis=0):
|
||||||
distances = np.zeros((rates.shape[axis], rates.shape[axis]))
|
distances = np.zeros((rates.shape[axis], rates.shape[axis]))
|
||||||
if axis == 1:
|
if axis == 1:
|
||||||
rates = rates.T
|
rates = rates.T
|
||||||
for i in range(distances.shape[0]):
|
for i in range(distances.shape[0]):
|
||||||
for j in range(distances.shape[1]):
|
for j in range(distances.shape[1]):
|
||||||
if i < j:
|
if j < i:
|
||||||
distances[i, j] = np.sqrt(np.sum((rates[i,:] - rates[j,:])**2))
|
distances[i, j] = np.mean(np.sqrt(np.sum((rates[i,:] - rates[j,:])**2)))/rates.shape[1-axis]
|
||||||
distances[j, i] = distances[i, j]
|
distances[j, i] = distances[i, j]
|
||||||
elif i == j:
|
elif i == j:
|
||||||
distances[i, j] = 0.0
|
distances[i, j] = 0.0
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
return distances
|
||||||
|
Loading…
Reference in New Issue
Block a user