chirp_probing/response_discriminability.py
2020-09-23 18:10:27 +02:00

467 lines
21 KiB
Python

import os
import glob
import pandas as pd
import nixio as nix
import numpy as np
import scipy.signal as sig
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score
from IPython import embed
from util import firing_rate, despine, extract_am, within_group_distance, across_group_distance
# Directory into which create_response_plot saves its figures.
figure_folder = "figures"
# Directory that main() searches for the recorded "cell*.nix" files.
data_folder = "data"
def read_baseline(block):
    """Read the baseline data stored in a baseline block.

    The name check is case-insensitive so that it agrees with sort_blocks,
    which files a block under the "baseline" key whenever the lower-cased
    block name contains "baseline".

    Args:
        block: block-like object providing ``name`` and ``data_arrays``.

    Returns:
        The full content of the first data array of the block
        (presumably the baseline spike times -- confirm against the
        recording code), or an empty list if the block name does not
        contain "baseline".
    """
    if "baseline" not in block.name.lower():
        print("Block %s does not appear to be a baseline block!" % block.name )
        return []
    return block.data_arrays[0][:]
def sort_blocks(nix_file):
    """Index the blocks of a file by their stimulus parameters.

    Block names are split at underscores; field 1 is read as the contrast,
    field 3 as the condition and field 5 as the difference frequency.
    Any block whose lower-cased name contains "baseline" is stored under the
    key "baseline" instead.

    Args:
        nix_file: object whose ``blocks`` attribute iterates over the blocks.

    Returns:
        dict: maps ``(contrast, df, condition)`` (and "baseline") to blocks.
        list: the distinct contrasts in order of first appearance.
        list: the distinct difference frequencies in order of first appearance.
        list: the distinct conditions in order of first appearance.
    """
    def remember(value, seen):
        # append only unseen values so first-occurrence order is kept
        if value not in seen:
            seen.append(value)

    lookup = {}
    seen_contrasts, seen_dfs, seen_conditions = [], [], []
    for blk in nix_file.blocks:
        if "baseline" in blk.name.lower():
            lookup["baseline"] = blk
            continue
        fields = blk.name.split("_")
        contrast = float(fields[1])
        remember(contrast, seen_contrasts)
        condition = fields[3]
        remember(condition, seen_conditions)
        delta_f = float(fields[5])
        remember(delta_f, seen_dfs)
        lookup[(contrast, delta_f, condition)] = blk
    return lookup, seen_contrasts, seen_dfs, seen_conditions
def get_spikes(block):
    """Collect the spike trains of a block, ordered by response index.

    A data array is used when its ``type`` contains "spike_times" and its
    ``name`` contains "response". The integer after the last underscore of the
    name serves as the sort key.

    Args:
        block: block-like object whose ``data_arrays`` can be iterated.

    Returns:
        list: the content of each matching data array, sorted by index.
    """
    by_index = {
        int(arr.name.split("_")[-1]): arr
        for arr in block.data_arrays
        if "spike_times" in arr.type and "response" in arr.name
    }
    return [by_index[idx][:] for idx in sorted(by_index)]
def get_rates(spike_trains, duration, dt, kernel_width):
    """Turn spike trains into firing rates, one row per trial.

    Each spike train is passed to the ``firing_rate`` helper (imported from
    util; presumably a Gaussian-kernel estimate -- confirm there) and the
    result is stored in the row of the corresponding trial.

    Args:
        spike_trains: sequence of spike trains, one per trial.
        duration: end of the time axis; the axis is ``arange(0, duration, dt)``.
        dt: step of the time axis.
        kernel_width: kernel width handed on to ``firing_rate``.

    Returns:
        np.ndarray: the rates, shaped (number of trials, number of time steps).
        np.ndarray: the time vector.
    """
    time = np.arange(0.0, duration, dt)
    rates = np.zeros((len(spike_trains), len(time)))
    # the rows of ``rates`` are views, so writing into them fills the matrix
    for row, train in zip(rates, spike_trains):
        row[:] = firing_rate(train, duration, kernel_width, dt)
    return rates, time
def get_firing_rate(block_map, df, contrast, condition, kernel_width=0.0005):
    """Return time axis, firing rates and spike trains for one stimulus block.

    The block is looked up as ``block_map[(contrast, df, condition)]``. The
    trial duration and the step size are read from its "stimulus parameter"
    metadata.

    Args:
        block_map: mapping from ``(contrast, df, condition)`` to blocks.
        df: the difference frequency of the requested block.
        contrast: the contrast of the requested block.
        condition: the condition of the requested block.
        kernel_width (float, optional): kernel width passed to get_rates.
            Defaults to 0.0005.

    Returns:
        np.ndarray: the time vector.
        np.ndarray: the rates with the first dimension representing the trials.
        list: the spike trains.
    """
    block = block_map[(contrast, df, condition)]
    spikes = get_spikes(block)
    stimulus_parameter = block.metadata["stimulus parameter"]
    rates, time = get_rates(
        spikes,
        float(stimulus_parameter["duration"]),
        float(stimulus_parameter["dt"]),
        kernel_width,
    )
    return time, rates, spikes
def get_signals(block):
    """Read the stimulus signal and the frequency traces stored in a block.

    The "other frequency" trace is only required, and only read, when the
    block name does not contain "no-other".

    Args:
        block: the block containing the data for a given df, contrast and
            condition.

    Raises:
        ValueError: when "complete stimulus" or "self frequency" is missing.
        ValueError: when "other frequency" is missing although the block name
            does not contain "no-other".

    Returns:
        the content of "complete stimulus"
        the content of "self frequency"
        the content of "other frequency", or None for "no-other" blocks
        np.ndarray: the time axis taken from the first dimension of
            "complete stimulus"
    """
    arrays = block.data_arrays
    if not ("complete stimulus" in arrays and "self frequency" in arrays):
        raise ValueError("Signals not stored in block!")
    with_other = "no-other" not in block.name
    if with_other and "other frequency" not in arrays:
        raise ValueError("Signals not stored in block!")

    stimulus = arrays["complete stimulus"]
    signal = stimulus[:]
    time = np.asarray(stimulus.dimensions[0].axis(len(signal)))
    self_freq = arrays["self frequency"][:]
    other_freq = arrays["other frequency"][:] if with_other else None
    return signal, self_freq, other_freq, time
def create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, current_df, figure_name=None):
    """Plot stimuli and firing-rate responses for one difference frequency.

    One column of axes is drawn for each of the conditions "no-other", "self"
    and "other". Each column shows, from top to bottom, the frequency traces,
    the stimulus together with its amplitude modulation, and one firing-rate
    panel (mean +- standard deviation across trials) per contrast, with the
    contrasts sorted in descending order. Only the time window from 0.5 to
    1.0 is displayed. The figure is saved as PDF into ``figure_folder``,
    which is created if it does not exist, and is closed afterwards.

    Args:
        block_map: mapping from ``(contrast, df, condition)`` to blocks.
        all_dfs: not used by this function.
        all_contrasts: the contrasts to plot; must not be empty.
        all_conditions: only its length is used, to size the subplot grid.
        current_df: the difference frequency to plot.
        figure_name (str, optional): file name of the figure; ".pdf" is
            appended if the name does not end with it. Defaults to
            "chirp_responses.pdf".
    """
    conditions = ["no-other", "self", "other"]
    condition_labels = ["soliloquy", "self chirping", "other chirping"]
    min_time = 0.5
    max_time = min_time + 0.5
    major_xticks = np.arange(min_time, max_time + .01, 0.250)
    minor_xticks = np.arange(min_time, max_time + .01, 0.125)
    # tick labels give milliseconds relative to the start of the window;
    # a real list is built because tick-label setters expect a sequence
    xtick_labels = [int(round(ms)) for ms in (major_xticks - min_time) * 1000]

    fig = plt.figure(figsize=(6.5, 5.5))
    fig_grid = (len(all_contrasts)*2 + 6, len(all_conditions)*3+2)
    all_contrasts = sorted(all_contrasts, reverse=True)

    for i, condition in enumerate(conditions):
        column = i * 3 + i  # three grid columns per panel plus one spacer
        # the signals are taken from the block with the largest contrast
        block = block_map[(all_contrasts[0], current_df, condition)]
        signal, self_freq, other_freq, time = get_signals(block)
        am = extract_am(signal)
        window = (time > min_time) & (time < max_time)

        eodfs = block.metadata["stimulus parameter"]["eodfs"]
        self_eodf = eodfs["self"]
        other_eodf = eodfs["other"]

        # plot frequency traces
        ax = plt.subplot2grid(fig_grid, (0, column), rowspan=2, colspan=3, fig=fig)
        ax.plot(time[window], self_freq[window], color="#ff7f0e", label="%iHz" % self_eodf)
        ax.text(min_time-0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e", va="center", ha="right", fontsize=9)
        if other_freq is not None:
            ax.plot(time[window], other_freq[window], color="#1f77b4", label="%iHz" % other_eodf)
            ax.text(min_time-0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4", va="center", ha="right", fontsize=9)
        ax.set_title(condition_labels[i])
        despine(ax, ["top", "bottom", "left", "right"], True)

        # plot the stimulus and its am
        ax = plt.subplot2grid(fig_grid, (3, column), rowspan=2, colspan=3, fig=fig)
        ax.plot(time[window], signal[window], color="#2ca02c", label="signal")
        ax.plot(time[window], am[window], color="#d62728", label="am")
        despine(ax, ["top", "bottom", "left", "right"], True)
        ax.set_ylim([-1.25, 1.25])
        ax.legend(ncol=2, loc=(0.01, -0.5), fontsize=7, markerscale=0.5, frameon=False)

        # for each contrast plot the firing rate
        for j, contrast in enumerate(all_contrasts):
            t, rates, _ = get_firing_rate(block_map, current_df, contrast, condition)
            avg_resp = np.mean(rates, axis=0)
            error = np.std(rates, axis=0)
            rate_window = (t > min_time) & (t < max_time)

            ax = plt.subplot2grid(fig_grid, (j*2 + 6, column), rowspan=2, colspan=3, fig=fig)
            ax.plot(t[rate_window], avg_resp[rate_window], color="k", lw=0.5)
            ax.fill_between(t[rate_window], (avg_resp - error)[rate_window],
                            (avg_resp + error)[rate_window], color="k", lw=0.0, alpha=0.25)
            ax.set_ylim([0, 750])
            ax.set_xlabel("")
            ax.set_ylabel("")
            ax.set_xticks(major_xticks)
            ax.set_xticklabels(xtick_labels)
            ax.set_xticks(minor_xticks, minor=True)
            if j < len(all_contrasts) - 1:
                # only the bottom panel keeps its tick labels
                ax.set_xticklabels([])
            ax.set_yticks(np.arange(0.0, 751., 500))
            ax.set_yticks(np.arange(0.0, 751., 125), minor=True)
            if i > 0:
                ax.set_yticklabels([])
            despine(ax, ["top", "right"], False)
            if i == 2:
                ax.text(max_time + 0.025*max_time, 350, "c=%.3f" % all_contrasts[j],
                        color="#d62728", ha="left", fontsize=7)
        # axis labels are attached to the bottom panel of a column
        if i == 1:
            ax.set_xlabel("time [ms]")
        if i == 0:
            ax.set_ylabel("frequency [Hz]", va="center")
            ax.yaxis.set_label_coords(-0.45, 3.5)

    name = figure_name if figure_name is not None else "chirp_responses.pdf"
    if not name.endswith(".pdf"):
        name += ".pdf"
    os.makedirs(figure_folder, exist_ok=True)
    fig.savefig(os.path.join(figure_folder, name))
    plt.close(fig)
def get_chirp_metadata(block):
    """Read trial and chirp parameters from the metadata of a block.

    All values come from the "stimulus parameter" section. The duration and
    the step size are converted to float, the chirp entries are returned as
    stored.

    Args:
        block: block-like object with a ``metadata`` mapping.

    Returns:
        tuple: (trial duration, dt, chirp size, chirp duration, chirp times).
            Note that the chirp size precedes the chirp duration.
    """
    params = block.metadata["stimulus parameter"]
    trial_duration = float(params["duration"])
    dt = float(params["dt"])
    chirp_duration, chirp_size, chirp_times = (
        params[key] for key in ("chirp_duration", "chirp_size", "chirp_times")
    )
    return trial_duration, dt, chirp_size, chirp_duration, chirp_times
def foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005, cell_name=""):
    """Quantify how well the responses between chirps reveal a second fish.

    For every contrast the firing rates of the "no-other" and the "self"
    condition are cut into snippets that lie between consecutive chirps,
    leaving a margin of 1.5 chirp durations around each chirp time. The
    distances among the "no-other" snippets (lower triangle of the result of
    ``within_group_distance``) are labelled 0, the distances between
    "no-other" and "self" snippets (``across_group_distance``) are labelled 1,
    and an ROC analysis is run on these scores. Both distance helpers are
    imported from util.

    Args:
        block_map: mapping from ``(contrast, df, condition)`` to blocks.
        df: the difference frequency to analyse.
        all_contrasts: the contrasts to analyse.
        all_conditions: not used by this function.
        kernel_width (float, optional): kernel width for the firing rate
            estimation. Defaults to 0.0005.
        cell_name (str, optional): stored in the result entries.

    Returns:
        list of dict: one entry per contrast with the keys "cell",
            "detection_task" (always "beat"), "contrast", "df",
            "kernel_width", "auc", "true_positives" and "false_positives".
    """
    results = []
    for contrast in all_contrasts:
        # progress output that overwrites the current terminal line
        print(" " * 50, end="\r")
        print("Contrast: %.3f" % contrast, end="\r")
        alone_block = block_map[(contrast, df, "no-other")]
        company_block = block_map[(contrast, df, "self")]

        # the metadata of the "self" block is taken to hold for both conditions
        duration, dt, _, chirp_duration, chirp_times = get_chirp_metadata(company_block)
        margin = 1.5 * chirp_duration
        window_starts = np.add(chirp_times, margin)[:-1]
        window_ends = np.subtract(chirp_times, margin)[1:]
        # mean length of the chirp-free windows, truncated to three decimals
        ici = np.floor(np.mean(np.subtract(window_ends, window_starts))*1000) / 1000

        alone_spikes = get_spikes(alone_block)
        company_spikes = get_spikes(company_block)
        alone_rates, _ = get_rates(alone_spikes, duration, dt, kernel_width)
        company_rates, _ = get_rates(company_spikes, duration, dt, kernel_width)

        # cut out the responses between chirps, one row per trial and window
        trial_count = alone_rates.shape[0]
        sample_count = int(ici / dt)
        alone_snippets = np.zeros((len(window_starts) * trial_count, sample_count))
        company_snippets = np.zeros_like(alone_snippets)
        row = 0
        for trial in range(trial_count):
            for start in window_starts:
                first = int(start/dt)
                last = first + sample_count
                alone_snippets[row, :] = alone_rates[trial, first:last]
                company_snippets[row, :] = company_rates[trial, first:last]
                row += 1

        within = within_group_distance(alone_snippets)
        across = across_group_distance(alone_snippets, company_snippets)

        # label within-group distances 0 and across-group distances 1
        within_scores = within[np.tril_indices_from(within, -1)]
        across_scores = across.ravel()
        group = np.hstack((np.zeros_like(within_scores), np.ones_like(across_scores)))
        score = np.hstack((within_scores, across_scores))
        fpr, tpr, _ = roc_curve(group, score, pos_label=1)
        auc = roc_auc_score(group, score)
        results.append({"cell": cell_name, "detection_task": "beat", "contrast": contrast, "df": df, "kernel_width": kernel_width, "auc": auc, "true_positives": tpr, "false_positives": fpr})
    print("\n")
    return results
def _roc_detection_entry(cell_name, task, contrast, df, kernel_width, within_scores, across_scores):
    """Run the ROC analysis for one task and pack the result into a dict.

    ``within_scores`` are labelled 0 and ``across_scores`` are labelled 1.
    """
    group = np.hstack((np.zeros_like(within_scores), np.ones_like(across_scores)))
    score = np.hstack((within_scores, across_scores))
    fpr, tpr, _ = roc_curve(group, score, pos_label=1)
    auc = roc_auc_score(group, score)
    return {"cell": cell_name, "detection_task": task, "contrast": contrast, "df": df, "kernel_width": kernel_width, "auc": auc, "true_positives": tpr, "false_positives": fpr}


def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005, cell_name=""):
    """Quantify how well chirp responses reveal a second fish.

    For every contrast, snippets of one chirp duration are cut from the firing
    rates of the "no-other", "self" and "other" condition around each chirp
    time. Baseline snippets of the same length are taken from the "no-other"
    responses, starting 1.5 chirp durations after each chirp time. Two ROC
    analyses on snippet distances (helpers imported from util) are run:

    * "self vs soliloquy": distances among the "no-other" chirp snippets
      (label 0) against distances between "no-other" and "self" chirp
      snippets (label 1).
    * "other vs quietness": distances among the baseline snippets (label 0)
      against distances between baseline and "other" chirp snippets (label 1).

    Args:
        block_map: mapping from ``(contrast, df, condition)`` to blocks; must
            provide the conditions "no-other", "self" and "other".
        df: the difference frequency to analyse.
        all_contrasts: the contrasts to analyse.
        all_conditions: not used by this function.
        kernel_width (float, optional): kernel width for the firing rate
            estimation. Defaults to 0.0005.
        cell_name (str, optional): stored in the result entries.

    Returns:
        list of dict: two entries per contrast ("self vs soliloquy" followed
            by "other vs quietness") with the keys "cell", "detection_task",
            "contrast", "df", "kernel_width", "auc", "true_positives" and
            "false_positives".
    """
    detection_performances = []
    for contrast in all_contrasts:
        # progress output that overwrites the current terminal line
        print(" " * 50, end="\r")
        print("Contrast: %.3f" % contrast, end="\r")
        no_other_block = block_map[(contrast, df, "no-other")]
        self_block = block_map[(contrast, df, "self")]
        # responses to the chirps of the other fish are stored under "other"
        other_block = block_map[(contrast, df, "other")]

        # the metadata of the "self" block is taken to hold for all conditions
        duration, dt, _, chirp_duration, chirp_times = get_chirp_metadata(self_block)

        # get the spiking responses
        no_other_spikes = get_spikes(no_other_block)
        self_spikes = get_spikes(self_block)
        other_spikes = get_spikes(other_block)

        # get firing rates
        no_other_rates, _ = get_rates(no_other_spikes, duration, dt, kernel_width)
        self_rates, _ = get_rates(self_spikes, duration, dt, kernel_width)
        other_rates, _ = get_rates(other_spikes, duration, dt, kernel_width)

        # get the chirp response snippets, one row per trial and chirp
        snippet_length = int(chirp_duration / dt)
        alone_chirping_snippets = np.zeros((len(chirp_times) * no_other_rates.shape[0], snippet_length))
        self_snippets = np.zeros_like(alone_chirping_snippets)
        other_snippets = np.zeros_like(alone_chirping_snippets)
        baseline_snippets = np.zeros_like(alone_chirping_snippets)
        for i in range(no_other_rates.shape[0]):
            for j, chirp_time in enumerate(chirp_times):
                # window of one chirp duration, centred on the chirp time and
                # shifted by 3 ms
                start_index = int((chirp_time - chirp_duration/2 + 0.003)/dt)
                end_index = start_index + snippet_length
                index = i * len(chirp_times) + j
                alone_chirping_snippets[index, :] = no_other_rates[i, start_index:end_index]
                self_snippets[index, :] = self_rates[i, start_index:end_index]
                other_snippets[index, :] = other_rates[i, start_index:end_index]
                # chirp-free reference taken from the "no-other" response
                baseline_start_index = int((chirp_time + 1.5 * chirp_duration)/dt)
                baseline_end_index = baseline_start_index + snippet_length
                baseline_snippets[index, :] = no_other_rates[i, baseline_start_index:baseline_end_index]

        # get the distances
        # 1. soliloquy: chirping while alone
        # 2. baseline: nobody chirps and nobody else is present
        # 3. chirping in company compared to chirping alone
        # 4. the other fish chirping compared to the baseline
        alone_chirping_dist = within_group_distance(alone_chirping_snippets)
        baseline_dist = within_group_distance(baseline_snippets)
        self_vs_alone_dist = across_group_distance(alone_chirping_snippets, self_snippets)
        other_vs_baseline_dist = across_group_distance(baseline_snippets, other_snippets)

        # both within-group matrices stem from equally shaped snippet
        # matrices, hence the same lower-triangle indices are used for both
        triangle_indices = np.tril_indices_from(alone_chirping_dist, -1)

        # 1. soliloquy vs. self chirping in company
        detection_performances.append(_roc_detection_entry(
            cell_name, "self vs soliloquy", contrast, df, kernel_width,
            alone_chirping_dist[triangle_indices], self_vs_alone_dist.ravel()))
        # 2. other chirping vs. nobody is chirping
        detection_performances.append(_roc_detection_entry(
            cell_name, "other vs quietness", contrast, df, kernel_width,
            baseline_dist[triangle_indices], other_vs_baseline_dist.ravel()))
    print("\n")
    return detection_performances
def plot_detection_results(data_frame, df, kernel_width, cell):
    """Plot ROC curves and discriminability for one cell and one df.

    For every detection task found for the cell a row with two panels is
    drawn: on the left the ROC curves of all contrasts for the requested
    kernel width, on the right the area under the ROC curve as a function of
    contrast, with one line per kernel width. The figure is saved as
    "discrimination.pdf" in the current working directory.

    Args:
        data_frame (pd.DataFrame): detection results with the columns "cell",
            "df", "detection_task", "kernel_width", "contrast", "auc",
            "true_positives" and "false_positives".
        df: the difference frequency to select.
        kernel_width: the kernel width for which the ROC curves are shown.
        cell: the name of the cell to select.
    """
    # "df" is selected by key to keep the column access unambiguous
    cell_results = data_frame[(data_frame.cell == cell) & (data_frame["df"] == df)]
    tasks = sorted(cell_results.detection_task.unique())
    kernels = sorted(cell_results.kernel_width.unique())
    major_ticks = np.arange(0.0, 1.01, 0.5)
    minor_ticks = np.arange(0.0, 1.01, 0.25)

    fig = plt.figure(figsize=(6.5, 5.5))
    fig_grid = (8, 7)  # offers room for three tasks: two rows each plus spacers
    for i, task in enumerate(tasks):
        task_results = cell_results[cell_results.detection_task == task]
        row = i * 2 + i
        roc_ax = plt.subplot2grid(fig_grid, (row, 0), colspan=3, rowspan=2)
        auc_ax = plt.subplot2grid(fig_grid, (row, 4), colspan=3, rowspan=2)

        # ROC curves, one per contrast, for the requested kernel width
        roc_data = task_results[task_results.kernel_width == kernel_width]
        for contrast in roc_data.contrast.unique():
            of_contrast = roc_data.contrast == contrast
            tpr = roc_data.true_positives[of_contrast].values[0]
            fpr = roc_data.false_positives[of_contrast].values[0]
            roc_ax.plot(fpr, tpr, label="%.3f" % contrast, zorder=2)
        roc_ax.legend(loc="best", fontsize=6, ncol=2, frameon=False)
        # diagonal marks chance level
        roc_ax.plot([0., 1.], [0., 1.], color="k", lw=0.5, ls="--", zorder=0)
        roc_ax.set_xlabel("false positive rate", fontsize=9)
        roc_ax.set_ylabel("true positive rate", fontsize=9)
        roc_ax.set_xticks(major_ticks)
        roc_ax.set_xticks(minor_ticks, minor=True)
        roc_ax.set_xticklabels(major_ticks, fontsize=8)
        roc_ax.set_yticks(major_ticks)
        roc_ax.set_yticks(minor_ticks, minor=True)
        roc_ax.set_yticklabels(major_ticks, fontsize=8)

        # area under the curve against contrast, one line per kernel width
        for k in kernels:
            of_kernel = task_results.kernel_width == k
            contrasts = np.asarray(task_results.contrast[of_kernel])
            aucs = np.asarray(task_results.auc[of_kernel])
            order = np.argsort(contrasts)
            auc_ax.plot(contrasts[order], aucs[order], marker=".", label=r"$\sigma$: %.4f" % k)
        auc_ax.set_xlabel("contrast [%]")
        auc_ax.set_ylim([0.25, 1.0])
        auc_ax.set_ylabel("discriminability")
        auc_ax.legend(ncol=2, fontsize=6)
        # chance level across the full contrast range of this task
        task_contrasts = task_results.contrast
        if len(task_contrasts) > 0:
            auc_ax.plot([task_contrasts.min(), task_contrasts.max()], [0.5, 0.5], lw=0.5, ls="--")
    fig.savefig("discrimination.pdf")
def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, cell_name=""):
    """Run the beat and the chirp detection analysis for all kernel widths.

    Every difference frequency is combined with each of the kernel widths
    0.00025, 0.0005, 0.001, 0.0025 and 0.005. The function runs unattended:
    it neither stops after the first kernel width nor opens an interactive
    shell.

    Args:
        block_map: mapping from ``(contrast, df, condition)`` to blocks.
        all_dfs: the difference frequencies to analyse when ``current_df`` is
            None.
        all_contrasts: the contrasts to analyse.
        all_conditions: handed on to the detection functions.
        current_df (optional): restricts the analysis to this difference
            frequency. Defaults to None.
        cell_name (str, optional): stored in the result entries.

    Returns:
        list of dict: the concatenated results of
            foreign_fish_detection_beat and foreign_fish_detection_chirp.
    """
    dfs = [current_df] if current_df is not None else all_dfs
    kernels = [0.00025, 0.0005, 0.001, 0.0025, 0.005]
    result_dicts = []
    for df in dfs:
        for kw in kernels:
            print("df: %i, kernel: %.4f" % (df, kw))
            print("Foreign fish detection during beat:")
            result_dicts.extend(foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kw, cell_name))
            print("Foreign fish detection during chirp:")
            result_dicts.extend(foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kw, cell_name))
    return result_dicts
def process_cell(filename, dfs=None, contrasts=None, conditions=None):
    """Run the foreign fish detection analysis on the data of one cell.

    The file is opened read-only and is closed again even if the analysis
    raises. The analysis is run for a difference frequency of 20; the cell
    name is the file name up to ".nix".

    Response figures for a given difference frequency can be created with
    create_response_plot(block_map, all_dfs, all_contrasts, all_conditions,
    df, figure_name=...).

    Args:
        filename (str): path of the nix file.
        dfs (optional): currently not used.
        contrasts (optional): currently not used.
        conditions (optional): currently not used.

    Returns:
        list of dict: the results of foreign_fish_detection.
    """
    nf = nix.File.open(filename, nix.FileMode.ReadOnly)
    try:
        block_map, all_contrasts, all_dfs, all_conditions = sort_blocks(nf)
        if "baseline" in block_map:
            # the baseline is read but not used by the analysis below
            read_baseline(block_map["baseline"])
        else:
            print("ERROR: no baseline data for file %s!" % filename)

        cell_name = os.path.basename(filename).split(".nix")[0]
        results = foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions,
                                         current_df=20, cell_name=cell_name)
    finally:
        nf.close()
    return results
def main():
    """Process every "cell*.nix" file found in the data folder.

    The files are handled in alphabetical order of their paths.
    """
    pattern = os.path.join(data_folder, "cell*.nix")
    for cell_file in sorted(glob.glob(pattern)):
        process_cell(cell_file, dfs=[20], contrasts=[20], conditions=["self"])


if __name__ == "__main__":
    main()