chirp_probing/response_discriminability.py

317 lines
12 KiB
Python

import os
import glob
import nixio as nix
import numpy as np
import scipy.signal as sig
import matplotlib.pyplot as plt
from IPython import embed
from util import firing_rate, despine
# Output directory for rendered figures and input directory holding the
# per-cell NIX recordings; both are relative paths, resolved against the
# current working directory.
figure_folder, data_folder = "figures", "data"
def read_baseline(block):
    """Return the spike times stored in a baseline block.

    A block counts as a baseline recording when its name contains "baseline"
    in any letter case. This is the same test ``sort_blocks`` uses to file a
    block under the "baseline" key, so every block it maps there is accepted.

    Args:
        block: a NIX block whose first data array holds the baseline spikes.

    Returns:
        The contents of ``block.data_arrays[0]`` for a baseline block.
        Otherwise a warning is printed and an empty list is returned.
    """
    if "baseline" not in block.name.lower():
        print("Block %s does not appear to be a baseline block!" % block.name)
        return []
    return block.data_arrays[0][:]
def sort_blocks(nix_file):
    """Index the blocks of a NIX file by their stimulus parameters.

    Names of non-baseline blocks are split on "_":
    - field 1 is parsed as a float contrast;
    - field 3 is the condition string;
    - field 5 is parsed as a float difference frequency.

    A block whose name contains "baseline" (any letter case) is stored under
    the key "baseline"; a later such block replaces an earlier one.

    Args:
        nix_file: an open NIX file; its ``blocks`` are iterated.

    Returns:
        tuple: ``(block_map, contrasts, deltafs, conditions)``.
        - ``block_map`` maps ``(contrast, deltaf, condition)`` tuples (and
          "baseline") to blocks.
        - The three lists hold the distinct values in order of first
          appearance.
    """
    block_map = {}
    contrasts, deltafs, conditions = [], [], []
    for blk in nix_file.blocks:
        if "baseline" in blk.name.lower():
            block_map["baseline"] = blk
            continue
        fields = blk.name.split("_")
        contrast = float(fields[1])
        condition = fields[3]
        deltaf = float(fields[5])
        # record each parameter value once, preserving first-seen order
        for value, seen in ((contrast, contrasts), (condition, conditions), (deltaf, deltafs)):
            if value not in seen:
                seen.append(value)
        block_map[(contrast, deltaf, condition)] = blk
    return block_map, contrasts, deltafs, conditions
def get_spikes(block):
    """Collect the response spike trains of a block, ordered by trial id.

    A data array is treated as a response spike train when its ``type``
    contains "spike_times" and its ``name`` contains "response". The trial
    id is the integer after the last "_" in the name. If two arrays share
    an id, the later one wins.

    Args:
        block: a NIX block whose ``data_arrays`` are scanned.

    Returns:
        list: the contents (``[:]``) of each matching data array, sorted by
        ascending trial id.
    """
    trains_by_id = {
        int(arr.name.split("_")[-1]): arr
        for arr in block.data_arrays
        if "spike_times" in arr.type and "response" in arr.name
    }
    return [trains_by_id[trial][:] for trial in sorted(trains_by_id)]
def get_rates(spike_trains, duration, dt, kernel_width):
    """Convert spike trains into per-trial firing rates.

    Each train is passed to ``firing_rate`` (from ``util``). That helper
    presumably applies a Gaussian kernel, but this is not verified here.

    Args:
        spike_trains: sequence of spike-time arrays, one per trial.
        duration: length of the time axis, in the units of the spike times.
        dt: sampling interval of the time axis.
        kernel_width: kernel width forwarded to ``firing_rate``.

    Returns:
        np.ndarray: rates of shape (number of trials, number of samples).
        np.ndarray: the time vector ``np.arange(0.0, duration, dt)``.
    """
    time = np.arange(0.0, duration, dt)
    rates = np.zeros((len(spike_trains), time.shape[0]))
    # each row is a view into ``rates``, so filling it fills the matrix
    for row, train in zip(rates, spike_trains):
        row[:] = firing_rate(train, duration, kernel_width, dt)
    return rates, time
def get_firing_rate(block_map, df, contrast, condition, kernel_width=0.0005):
    """Return the time axis, per-trial firing rates and spike trains of one stimulus.

    Args:
        block_map: dict from ``sort_blocks``, keyed by ``(contrast, df, condition)``.
        df: difference frequency of the requested stimulus.
        contrast: stimulus contrast.
        condition: stimulus condition, e.g. "no-other", "self" or "other".
        kernel_width (float, optional): kernel width forwarded to
            ``get_rates``. Defaults to 0.0005.

    Returns:
        np.ndarray: the time vector.
        np.ndarray: the rates, with the first dimension indexing trials.
        list of np.ndarray: the spike trains.
    """
    block = block_map[(contrast, df, condition)]
    spike_trains = get_spikes(block)
    stim_params = block.metadata["stimulus parameter"]
    duration = float(stim_params["duration"])
    dt = float(stim_params["dt"])
    rates, time = get_rates(spike_trains, duration, dt, kernel_width)
    return time, rates, spike_trains
def get_signals(block):
    """Read the stimulus and EOD frequency traces stored in a block.

    Args:
        block: the block holding the data for one df, contrast and condition.

    Raises:
        ValueError: if "complete stimulus" or "self frequency" is missing.
        ValueError: if the block is not a "no-other" block and
            "other frequency" is missing.

    Returns:
        The complete stimulus signal.
        The frequency trace of the recorded fish ("self frequency").
        The frequency trace of the other fish, or None for "no-other" blocks.
        np.ndarray: the time axis of the stimulus.
    """
    arrays = block.data_arrays
    has_required = "complete stimulus" in arrays and "self frequency" in arrays
    if not has_required:
        raise ValueError("Signals not stored in block!")
    is_soliloquy = "no-other" in block.name
    if not is_soliloquy and "other frequency" not in arrays:
        raise ValueError("Signals not stored in block!")

    stimulus = arrays["complete stimulus"]
    signal = stimulus[:]
    time = np.asarray(stimulus.dimensions[0].axis(len(signal)))
    self_freq = arrays["self frequency"][:]
    other_freq = None if is_soliloquy else arrays["other frequency"][:]
    return signal, self_freq, other_freq, time
def extract_am(signal):
    """Extract the amplitude modulation of a signal using the Hilbert transform.

    Before the transform, the signal is mirror-padded at both ends by 1% of
    its length (at least one sample) to reduce edge artefacts. The padding
    is removed again from the result.

    Args:
        signal (np.ndarray): the 1-D signal.

    Returns:
        np.ndarray: the am, i.e. the absolute value of the analytic signal,
        with the same length as ``signal``. An empty input yields an empty
        array.
    """
    signal = np.asarray(signal)
    n = len(signal)
    if n == 0:
        # scipy.signal.hilbert raises for zero-length input
        return np.zeros(0)
    # int(n / 100) alone is 0 for signals shorter than 100 samples
    pad = max(1, int(n / 100))
    # use an explicit start index: signal[-0:] would select the whole signal,
    # which previously produced a one-sided, full-length "pad" for short inputs
    front_pad = np.flip(signal[:pad])
    back_pad = np.flip(signal[n - pad:])
    padded = np.hstack((front_pad, signal, back_pad))
    am = np.abs(sig.hilbert(padded))
    return am[pad:pad + n]
def create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, current_df, figure_name=None):
    """Plot stimuli and firing-rate responses of one cell for a single difference frequency.

    The figure has one column per condition ("no-other", "self", "other").
    Each column shows a 0.5 s window starting at 0.5 s:
    - the EOD frequency traces, taken from the highest contrast;
    - the stimulus with its amplitude modulation, also from the highest
      contrast;
    - one mean +- std firing-rate panel per contrast, highest contrast on top.

    The figure is written as a PDF into ``figure_folder``, which is created
    if necessary.

    Args:
        block_map: dict from ``sort_blocks``.
        all_dfs: all difference frequencies (unused; kept for interface
            compatibility).
        all_contrasts: contrasts to plot, one firing-rate row each.
        all_conditions: conditions found in the file (unused; the plotted
            conditions are fixed; kept for interface compatibility).
        current_df: the difference frequency to plot.
        figure_name: output file name; ".pdf" is appended if the name does
            not end with it. Defaults to "chirp_responses.pdf".
    """
    conditions = ["no-other", "self", "other"]
    condition_labels = ["soliloquy", "self chirping", "other chirping"]
    min_time = 0.5
    max_time = min_time + 0.5
    # ticks shared by all firing-rate panels; labels are ms relative to the window start
    major_xticks = np.arange(min_time, max_time + .01, 0.250)
    minor_xticks = np.arange(min_time, max_time + .01, 0.125)
    major_xticklabels = [int(v) for v in (major_xticks - min_time) * 1000]

    all_contrasts = sorted(all_contrasts, reverse=True)
    # rows: 2 frequency + 1 spacer + 2 am + 1 spacer (legend), then 2 per contrast;
    # columns: 3 per plotted condition, separated by one spacer column. The width
    # must follow the conditions actually drawn, not those found in the file.
    fig_grid = (len(all_contrasts) * 2 + 6, len(conditions) * 4 - 1)

    fig = plt.figure(figsize=(6.5, 5.5))
    try:
        for i, (condition, condition_label) in enumerate(zip(conditions, condition_labels)):
            col = i * 4
            # stimulus traces are taken from the highest contrast
            block = block_map[(all_contrasts[0], current_df, condition)]
            signal, self_freq, other_freq, time = get_signals(block)
            am = extract_am(signal)
            eodfs = block.metadata["stimulus parameter"]["eodfs"]
            self_eodf = eodfs["self"]
            window = (time > min_time) & (time < max_time)

            # frequency traces
            ax = plt.subplot2grid(fig_grid, (0, col), rowspan=2, colspan=3, fig=fig)
            ax.plot(time[window], self_freq[window], color="#ff7f0e", label="%iHz" % self_eodf)
            ax.text(min_time - 0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e",
                    va="center", ha="right", fontsize=9)
            if other_freq is not None:
                other_eodf = eodfs["other"]
                ax.plot(time[window], other_freq[window], color="#1f77b4", label="%iHz" % other_eodf)
                ax.text(min_time - 0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4",
                        va="center", ha="right", fontsize=9)
            ax.set_title(condition_label)
            despine(ax, ["top", "bottom", "left", "right"], True)

            # stimulus and its amplitude modulation
            ax = plt.subplot2grid(fig_grid, (3, col), rowspan=2, colspan=3, fig=fig)
            ax.plot(time[window], signal[window], color="#2ca02c", label="signal")
            ax.plot(time[window], am[window], color="#d62728", label="am")
            despine(ax, ["top", "bottom", "left", "right"], True)
            ax.set_ylim([-1.25, 1.25])
            ax.legend(ncol=2, loc=(0.01, -0.5), fontsize=7, markerscale=0.5, frameon=False)

            # mean +- std firing rate, one row per contrast
            for j, contrast in enumerate(all_contrasts):
                t, rates, _ = get_firing_rate(block_map, current_df, contrast, condition)
                rate_window = (t > min_time) & (t < max_time)
                avg_resp = np.mean(rates, axis=0)
                error = np.std(rates, axis=0)
                ax = plt.subplot2grid(fig_grid, (j * 2 + 6, col), rowspan=2, colspan=3, fig=fig)
                ax.plot(t[rate_window], avg_resp[rate_window], color="k", lw=0.5)
                ax.fill_between(t[rate_window], (avg_resp - error)[rate_window],
                                (avg_resp + error)[rate_window], color="k", lw=0.0, alpha=0.25)
                ax.set_ylim([0, 750])
                ax.set_xlabel("")
                ax.set_ylabel("")
                ax.set_xticks(major_xticks)
                ax.set_xticklabels(major_xticklabels)
                ax.set_xticks(minor_xticks, minor=True)
                if j < len(all_contrasts) - 1:
                    ax.set_xticklabels([])
                ax.set_yticks(np.arange(0.0, 751., 500))
                ax.set_yticks(np.arange(0.0, 751., 125), minor=True)
                if i > 0:
                    ax.set_yticklabels([])
                despine(ax, ["top", "right"], False)
                if i == 2:
                    ax.text(max_time + 0.025 * max_time, 350, "c=%.3f" % contrast,
                            color="#d62728", ha="left", fontsize=7)

            # axis labels are attached to the bottom firing-rate panel of the column
            if i == 1:
                ax.set_xlabel("time [ms]")
            if i == 0:
                ax.set_ylabel("frequency [Hz]", va="center")
                # shift the label up so it sits beside the stack of rate panels
                ax.yaxis.set_label_coords(-0.45, 3.5)

        name = figure_name if figure_name is not None else "chirp_responses.pdf"
        if not name.endswith(".pdf"):
            name += ".pdf"
        os.makedirs(figure_folder, exist_ok=True)
        fig.savefig(os.path.join(figure_folder, name))
    finally:
        # release the figure even if plotting fails, so batch runs do not leak figures
        plt.close(fig)
def foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005):
    """Prepare the beat-based foreign-fish detection for one difference frequency.

    For every contrast, the "no-other" and "self" responses are converted to
    firing rates. The inter-chirp windows are also computed: each window
    runs from 1.5 chirp durations after one chirp to 1.5 chirp durations
    before the next.

    Snippet extraction, distance computation and the ROC analysis are not
    implemented yet, so the returned dictionary is currently always empty.

    Args:
        block_map: dict from ``sort_blocks``, keyed by ``(contrast, df, condition)``.
        df: the difference frequency to analyse.
        all_contrasts: contrasts to analyse.
        all_conditions: available conditions (currently unused).
        kernel_width: kernel width forwarded to ``get_rates``.

    Returns:
        dict: detection performance, intended to be keyed by contrast
        (currently empty).
    """
    detection_performance = {}
    for contrast in all_contrasts:
        no_other_block = block_map[(contrast, df, "no-other")]
        self_block = block_map[(contrast, df, "self")]

        # metadata is assumed to be identical for all conditions of a contrast
        stim_params = self_block.metadata["stimulus parameter"]
        duration = float(stim_params["duration"])
        dt = float(stim_params["dt"])
        chirp_duration = stim_params["chirp_duration"]
        chirp_times = np.atleast_1d(stim_params["chirp_times"])

        # windows between consecutive chirps; slicing (instead of deleting the
        # first/last entry) also copes with zero or one chirp
        interchirp_starts = chirp_times[:-1] + 1.5 * chirp_duration
        interchirp_ends = chirp_times[1:] - 1.5 * chirp_duration

        # firing rates of both conditions on a common time axis
        no_other_rates, time = get_rates(get_spikes(no_other_block), duration, dt, kernel_width)
        self_rates, _ = get_rates(get_spikes(self_block), duration, dt, kernel_width)

        # TODO: cut the response snippets between chirps (interchirp_starts/ends)
        # TODO: compute the distances, run the ROC analysis and store the result
        #       in detection_performance[contrast]
    return detection_performance
def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005):
    """Placeholder for the chirp-based foreign-fish detection analysis.

    Not implemented yet. The arguments mirror ``foreign_fish_detection_beat``
    and are ignored.

    Returns:
        None
    """
    # TODO: implement chirp-based detection
    return None
def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, kernel_width=0.0005):
    """Run beat- and chirp-based foreign-fish detection for one or all difference frequencies.

    Args:
        block_map: dict from ``sort_blocks``.
        all_dfs: every difference frequency in the file; used when
            ``current_df`` is None.
        all_contrasts: contrasts forwarded to the analyses.
        all_conditions: conditions forwarded to the analyses.
        current_df: restrict the analysis to this single difference frequency.
        kernel_width: kernel width forwarded to the analyses.

    Returns:
        tuple: two lists, the beat results and the chirp results, each with
        one entry per analysed difference frequency.
    """
    selected_dfs = all_dfs if current_df is None else [current_df]
    beat_results, chirp_results = [], []
    for delta_f in selected_dfs:
        # beat analysis runs before the chirp analysis for each df
        beat_results.append(
            foreign_fish_detection_beat(block_map, delta_f, all_contrasts, all_conditions, kernel_width))
        chirp_results.append(
            foreign_fish_detection_chirp(block_map, delta_f, all_contrasts, all_conditions, kernel_width))
    return beat_results, chirp_results
def process_cell(filename, dfs=None, contrasts=None, conditions=None, *, plot_dfs=()):
    """Run the analyses for a single cell's NIX file.

    The file is opened read-only and is always closed again, even if an
    analysis raises.

    Args:
        filename: path of the NIX file.
        dfs: currently unused. The foreign-fish detection runs for
            df = 20 only. Defaults to None (was a shared mutable list).
        contrasts: currently unused. Defaults to None.
        conditions: currently unused. Defaults to None.
        plot_dfs: difference frequencies for which a response figure named
            "<file stem>_df_<df>Hz.pdf" is created. Empty by default, so no
            figures are produced.
    """
    nf = nix.File.open(filename, nix.FileMode.ReadOnly)
    try:
        block_map, all_contrasts, all_dfs, all_conditions = sort_blocks(nf)
        if "baseline" in block_map:
            baseline_spikes = read_baseline(block_map["baseline"])  # currently unused
        else:
            print("ERROR: no baseline data for file %s!" % filename)

        stem = os.path.basename(filename).split(".nix")[0]
        for plot_df in plot_dfs:
            fig_name = "%s_df_%iHz.pdf" % (stem, plot_df)
            create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, plot_df,
                                 figure_name=fig_name)

        foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=20)
    finally:
        nf.close()
def main():
    """Process every ``cell*.nix`` file in ``data_folder``, in sorted path order."""
    pattern = os.path.join(data_folder, "cell*.nix")
    for path in sorted(glob.glob(pattern)):
        process_cell(path, dfs=[20], contrasts=[20], conditions=["self"])


# only run the batch analysis when executed as a script, not on import
if __name__ == "__main__":
    main()