chirp_probing/response_discriminability.py
2023-04-01 18:16:13 +02:00

347 lines
19 KiB
Python

import os
import glob
import pandas as pd
import nixio as nix
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score
from IPython import embed
import multiprocessing
from joblib import Parallel, delayed
from util import firing_rate, despine, extract_am, within_group_distance, across_group_distance
from nix_util import read_baseline, sort_blocks, get_spikes, get_signals, get_chirp_metadata
# Folder name for figures. It is not referenced by any function visible in
# this file -- presumably used by plotting code elsewhere; TODO confirm.
figure_folder = "figures"
# Folder (relative to the working directory) that main() scans for
# "cell*.nix" recordings and into which it writes the results CSV.
data_folder = "data"
def get_rates(spike_trains, duration, dt, kernel_width):
    """Turn a collection of spike trains into a trial-by-time matrix of firing rates.

    Every spike train is handed to ``firing_rate`` (from ``util``) together with
    ``duration``, ``kernel_width`` and ``dt``; the result fills one row of the
    returned matrix.

    Args:
        spike_trains: sized iterable with one entry of spike times per trial.
        duration: end of the time axis; the axis is ``np.arange(0.0, duration, dt)``.
        dt: step of the time axis.
        kernel_width: kernel width forwarded to ``firing_rate`` (described as the
            standard deviation of a Gaussian kernel elsewhere in this file --
            confirm against ``util.firing_rate``).

    Returns:
        np.ndarray: firing rates, shape ``(number of trials, number of time steps)``.
        np.ndarray: the time axis.
    """
    time_axis = np.arange(0.0, duration, dt)
    rate_matrix = np.zeros((len(spike_trains), time_axis.size))
    # Iterating the matrix yields row views, so writing into ``row`` fills the matrix.
    for row, train in zip(rate_matrix, spike_trains):
        row[:] = firing_rate(train, duration, kernel_width, dt)
    return rate_matrix, time_axis
def get_firing_rate(block_map, df, contrast, condition, kernel_width=0.0005, chirpsize=None):
    """Return the time axis, the firing rates and the spike trains of one stimulus block.

    Args:
        block_map: mapping from stimulus-parameter tuples to nix blocks.
        df: the difference frequency of the requested block.
        contrast: the contrast of the requested block.
        condition: the chirp condition of the requested block.
        kernel_width (float, optional): kernel width passed on to ``get_rates``.
            Defaults to 0.0005.
        chirpsize (optional): chirp size of the requested block. The detection
            functions in this file index ``block_map`` with
            ``(contrast, df, chirpsize, condition)``; pass the chirp size to use
            that key layout. With the default ``None`` the historical key
            ``(contrast, df, condition)`` is used.

    Returns:
        np.ndarray: the time vector.
        np.ndarray: the rates with the first dimension representing the trials.
        the spike trains as returned by ``get_spikes``.

    Raises:
        KeyError: if ``block_map`` has no entry for the requested combination.
    """
    if chirpsize is None:
        key = (contrast, df, condition)
    else:
        key = (contrast, df, chirpsize, condition)
    block = block_map[key]
    spikes = get_spikes(block)
    stimulus_parameter = block.metadata["stimulus parameter"]
    duration = float(stimulus_parameter["duration"])
    dt = float(stimulus_parameter["dt"])
    rates, time = get_rates(spikes, duration, dt, kernel_width)
    return time, rates, spikes
def foreign_fish_detection_beat(block_map, df, cs, all_contrasts, all_conditions, kernel_width=0.0005, cell_name="", store_roc=False):
    """Quantify how well the beat response can be told apart from the response without another fish.

    For every contrast the firing rates of the "no-other" and the "self" block
    are cut into snippets taken between consecutive chirps. A ROC analysis then
    tests whether the distances between "no-other" and "self" snippets (label 1)
    are larger than the distances among the "no-other" snippets (label 0).

    Args:
        block_map: maps ``(contrast, df, chirpsize, condition)`` to nix blocks.
        df: the difference frequency that should be used.
        cs: the chirp size that should be used.
        all_contrasts: list of all used contrasts.
        all_conditions: list of all chirp conditions; not used by this function.
        kernel_width (float, optional): kernel width passed to ``get_rates``. Defaults to 0.0005.
        cell_name (str, optional): name of the cell. Defaults to "".
        store_roc (bool, optional): if true the full false-positive and
            true-positive rates are stored as well (large results). Defaults to False.

    Returns:
        list of dict: one entry per contrast with 'detection_task' set to 'beat'
        and 'auc' holding the area under the ROC curve.
    """
    performances = []
    for contrast in all_contrasts:
        alone_block = block_map[(contrast, df, cs, "no-other")]
        company_block = block_map[(contrast, df, cs, "self")]

        # Stimulus metadata is read from the "self" block and applied to both
        # conditions, i.e. they are assumed to share timing and sampling.
        duration, dt, _, chirp_duration, chirp_times = get_chirp_metadata(company_block)
        margin = 1.5 * chirp_duration
        window_starts = np.add(chirp_times, margin)[:-1]
        window_ends = np.subtract(chirp_times, margin)[1:]
        mean_gap = np.mean(np.subtract(window_ends, window_starts))
        # Common snippet duration: mean chirp-free interval, floored to three decimals.
        ici = np.floor(mean_gap * 1000) / 1000

        alone_spikes = get_spikes(alone_block)
        company_spikes = get_spikes(company_block)
        alone_rates, _ = get_rates(alone_spikes, duration, dt, kernel_width)
        company_rates, _ = get_rates(company_spikes, duration, dt, kernel_width)

        n_trials = alone_rates.shape[0]
        n_windows = len(window_starts)
        n_samples = int(ici / dt)
        alone_snippets = np.zeros((n_windows * n_trials, n_samples))  # alone, no chirps
        company_snippets = np.zeros_like(alone_snippets)  # in company, no chirps, just beat
        # The trial count of the "no-other" block drives the loop; the "self"
        # block is indexed with the same trial numbers.
        for trial in range(n_trials):
            for window, start in enumerate(window_starts):
                first = int(start / dt)
                last = first + n_samples
                row = trial * n_windows + window
                alone_snippets[row, :] = alone_rates[trial, first:last]
                company_snippets[row, :] = company_rates[trial, first:last]

        within = within_group_distance(alone_snippets)
        across = across_group_distance(alone_snippets, company_snippets)

        # Only the strict lower triangle of the within-group matrix is used,
        # all entries of the across-group matrix are.
        reference = within[np.tril_indices_from(within, -1)]
        comparison = across.ravel()
        group = np.hstack((np.zeros_like(reference), np.ones_like(comparison)))
        score = np.hstack((reference, comparison))
        auc = roc_auc_score(group, score)

        entry = {"cell": cell_name, "detection_task": "beat", "contrast": contrast, "df": df, "kernel_width": kernel_width, "chirpsize": cs, "auc": auc}
        if store_roc:
            fpr, tpr, _ = roc_curve(group, score, pos_label=1)
            entry["true_positives"] = tpr
            entry["false_positives"] = fpr
        performances.append(entry)
    return performances
def _roc_performance(reference_distances, comparison_distances, store_roc=False):
    """Run the ROC analysis for one discrimination task.

    Reference (within-group) distances are labelled 0, comparison
    (across-group) distances are labelled 1.

    Returns:
        dict: 'auc', plus 'true_positives' and 'false_positives' if ``store_roc`` is set.
    """
    group = np.hstack((np.zeros_like(reference_distances), np.ones_like(comparison_distances)))
    score = np.hstack((reference_distances, comparison_distances))
    performance = {"auc": roc_auc_score(group, score)}
    if store_roc:
        fpr, tpr, _ = roc_curve(group, score, pos_label=1)
        performance["true_positives"] = tpr
        performance["false_positives"] = fpr
    return performance


def foreign_fish_detection_chirp(block_map, df, cs, all_contrasts, all_conditions, kernel_width=0.0005, cell_name="", store_roc=False):
    """Estimate how well chirp responses in different social contexts can be discriminated.

    For every contrast four sets of firing-rate snippets are extracted:
    a) self-chirping alone ("no-other" block, chirp window),
    c) self-chirping in company ("self" block, chirp window),
    e) other fish chirping ("other" block, chirp window),
    d) in company but nobody chirping ("other" block, window starting
       1.5 chirp durations after each chirp).

    Each comparison is a ROC analysis testing whether across-group distances
    (label 1) are larger than the within-group distances of a reference group
    (label 0). The resulting 'detection_task' labels are:
        "self vs soliloquy":  a vs. c, reference: within a
        "other vs quietness": d vs. e, reference: within d
        "soliliquy vs beat":  a vs. d, reference: within a
        "beat vs self":       c vs. d, reference: within d
        "self vs other":      c vs. e, reference: within e
    (The spelling "soliliquy" is kept on purpose; it is the key written to the
    result files.)

    Args:
        block_map: maps ``(contrast, df, chirpsize, condition)`` to nix blocks.
        df: the difference frequency that should be used.
        cs: the chirp size that should be used.
        all_contrasts: list of all used contrasts.
        all_conditions: list of all chirp conditions, i.e. self, other, or
            no-other; not used by this function.
        kernel_width (float, optional): kernel width passed to ``get_rates``. Defaults to 0.0005.
        cell_name (str, optional): name of the cell. Defaults to "".
        store_roc (bool, optional): if true the full false-positive and
            true-positive rates are stored as well (large results). Defaults to False.

    Returns:
        list of dict: five entries per contrast; 'auc' is the area under the
        ROC curve, i.e. the discrimination performance in the range [0, 1].

    Raises:
        KeyError: if a "no-other", "self" or "other" block is missing for a contrast.
    """
    detection_performances = []
    for contrast in all_contrasts:
        no_other_block = block_map[(contrast, df, cs, "no-other")]
        self_block = block_map[(contrast, df, cs, "self")]
        # The other-chirping condition needs its own block; it must not alias the "self" block.
        other_block = block_map[(contrast, df, cs, "other")]

        # Metadata is read from the "self" block, assuming it is identical for all conditions.
        duration, dt, _, chirp_duration, chirp_times = get_chirp_metadata(self_block)

        no_other_spikes = get_spikes(no_other_block)
        self_spikes = get_spikes(self_block)
        other_spikes = get_spikes(other_block)

        no_other_rates, _ = get_rates(no_other_spikes, duration, dt, kernel_width)
        self_rates, _ = get_rates(self_spikes, duration, dt, kernel_width)
        other_rates, _ = get_rates(other_spikes, duration, dt, kernel_width)

        n_trials = no_other_rates.shape[0]
        n_chirps = len(chirp_times)
        n_samples = int(chirp_duration / dt)
        alone_chirping_snippets = np.zeros((n_chirps * n_trials, n_samples))  # section a
        self_snippets = np.zeros_like(alone_chirping_snippets)  # section c
        other_snippets = np.zeros_like(alone_chirping_snippets)  # section e
        silence_snippets = np.zeros_like(alone_chirping_snippets)  # section d
        # The trial count of the "no-other" block drives the loop; the other
        # blocks are indexed with the same trial numbers.
        for i in range(n_trials):
            for j, chirp_time in enumerate(chirp_times):
                # Chirp window: starts half a chirp duration before the chirp time, shifted by 0.003.
                start_index = int((chirp_time - chirp_duration/2 + 0.003)/dt)
                end_index = start_index + n_samples
                index = i * n_chirps + j
                alone_chirping_snippets[index, :] = no_other_rates[i, start_index:end_index]
                self_snippets[index, :] = self_rates[i, start_index:end_index]
                other_snippets[index, :] = other_rates[i, start_index:end_index]
                # Chirp-free window of equal length, 1.5 chirp durations after the chirp.
                silence_start_index = int((chirp_time + 1.5 * chirp_duration)/dt)
                silence_end_index = silence_start_index + n_samples
                silence_snippets[index, :] = other_rates[i, silence_start_index:silence_end_index]

        # Within-group distance matrices (references) ...
        alone_chirping_dist = within_group_distance(alone_chirping_snippets)  # within a
        silence_dist = within_group_distance(silence_snippets)  # within d
        other_chirp_dist = within_group_distance(other_snippets)  # within e
        # ... and across-group distance matrices (comparisons).
        self_vs_alone_dist = across_group_distance(alone_chirping_snippets, self_snippets)  # a vs. c
        other_vs_silence_dist = across_group_distance(silence_snippets, other_snippets)  # d vs. e
        self_other_chirp_dist = across_group_distance(self_snippets, other_snippets)  # c vs. e
        self_chirp_beat_dist = across_group_distance(self_snippets, silence_snippets)  # c vs. d
        alone_chirp_beat_dist = across_group_distance(alone_chirping_snippets, silence_snippets)  # a vs. d

        # All snippet matrices share one shape, so the same strict lower
        # triangle selects the valid entries of every within-group matrix.
        triangle_indices = np.tril_indices_from(alone_chirping_dist, -1)
        valid_no_other_distances = alone_chirping_dist[triangle_indices]
        valid_silence_distances = silence_dist[triangle_indices]
        valid_other_chirp_distances = other_chirp_dist[triangle_indices]

        comparisons = [
            ("self vs soliloquy", valid_no_other_distances, self_vs_alone_dist.ravel()),
            ("other vs quietness", valid_silence_distances, other_vs_silence_dist.ravel()),
            ("soliliquy vs beat", valid_no_other_distances, alone_chirp_beat_dist.ravel()),
            # beat vs. self-chirping in company is scored with the c vs. d distances.
            ("beat vs self", valid_silence_distances, self_chirp_beat_dist.ravel()),
            ("self vs other", valid_other_chirp_distances, self_other_chirp_dist.ravel()),
        ]
        for task, reference, comparison in comparisons:
            performance = {"cell": cell_name, "detection_task": task, "contrast": contrast, "df": df, "kernel_width": kernel_width, "chirpsize": cs}
            performance.update(_roc_performance(reference, comparison, store_roc))
            detection_performances.append(performance)
    return detection_performances
def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, all_chirpsizes, current_df=None, current_chirpsize=None, cell_name="", store_roc=False):
    """Run the beat and the chirp detection analyses for all requested stimulus combinations.

    For every chirp size and difference frequency a progress line is printed,
    then both analyses are run once per kernel width
    (0.00025, 0.0005, 0.001 and 0.0025).

    Args:
        block_map: maps stimulus-parameter tuples to nix blocks.
        all_dfs: all difference frequencies; used when ``current_df`` is None.
        all_contrasts: all contrasts, passed on to the analyses.
        all_conditions: all chirp conditions, passed on to the analyses.
        all_chirpsizes: all chirp sizes; used when ``current_chirpsize`` is None.
        current_df (optional): restrict the analysis to this difference frequency.
        current_chirpsize (optional): restrict the analysis to this chirp size.
        cell_name (str, optional): name of the cell. Defaults to "".
        store_roc (bool, optional): passed on to the analyses. Defaults to False.

    Returns:
        list of dict: the concatenated results of all analyses.
    """
    if current_df is None:
        selected_dfs = all_dfs
    else:
        selected_dfs = [current_df]
    if current_chirpsize is None:
        selected_chirpsizes = all_chirpsizes
    else:
        selected_chirpsizes = [current_chirpsize]

    kernel_widths = (0.00025, 0.0005, 0.001, 0.0025)
    collected = []
    for chirpsize in selected_chirpsizes:
        for delta_f in selected_dfs:
            print("%s, chirp size: %i Hz, deltaf %.1f Hz" % (cell_name, chirpsize, delta_f))
            for width in kernel_widths:
                beat_results = foreign_fish_detection_beat(block_map, delta_f, chirpsize, all_contrasts, all_conditions, width, cell_name, store_roc)
                collected += beat_results
                chirp_results = foreign_fish_detection_chirp(block_map, delta_f, chirpsize, all_contrasts, all_conditions, width, cell_name, store_roc)
                collected += chirp_results
    return collected
def process_cell(filename):
    """Run the complete detection analysis for one recorded cell.

    Opens the NIX file read-only, sorts its blocks with ``sort_blocks`` and runs
    ``foreign_fish_detection`` over all difference frequencies and chirp sizes
    without storing the full ROC curves. The cell name is the file name up to
    ".nix".

    A missing "baseline" entry in the block map is only reported on stdout; the
    analysis is still attempted.

    Args:
        filename (str): path of the NIX file.

    Returns:
        list of dict: the results returned by ``foreign_fish_detection``.
    """
    cell_name = os.path.basename(filename).split(".nix")[0]
    nf = nix.File.open(filename, nix.FileMode.ReadOnly)
    # Close the file even if sorting the blocks or the analysis raises, so
    # that no file handle is leaked.
    try:
        block_map, all_contrasts, all_dfs, all_chirpsizes, all_conditions = sort_blocks(nf)
        if "baseline" not in block_map:
            print("ERROR: no baseline data for file %s!" % filename)
        results = foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, all_chirpsizes,
                                         current_df=None, current_chirpsize=None,
                                         cell_name=cell_name, store_roc=False)
    finally:
        nf.close()
    return results
def main():
    """Analyse all "cell*.nix" files in the data folder in parallel and store the results.

    The per-cell result lists are concatenated into one table that is written
    to "discrimination_results.csv" (separator ";") in the data folder.
    """
    # Leave six cores free but always use at least one worker: joblib rejects
    # n_jobs=0 and interprets negative values as "all CPUs but ...", which is
    # not what the subtraction is meant to express on small machines.
    num_cores = max(1, multiprocessing.cpu_count() - 6)
    nix_files = sorted(glob.glob(os.path.join(data_folder, "cell*.nix")))
    processed_list = Parallel(n_jobs=num_cores)(delayed(process_cell)(nix_file) for nix_file in nix_files)
    results = [entry for cell_results in processed_list for entry in cell_results]
    df = pd.DataFrame(results)
    df.to_csv(os.path.join(data_folder, "discrimination_results.csv"), sep=";")


# Run the analysis only when the file is executed as a script.
if __name__ == "__main__":
    main()