# chirp_probing/response_discriminability.py

import os
import glob
import nixio as nix
import numpy as np
import scipy.signal as sig
import matplotlib.pyplot as plt
from IPython import embed
from util import firing_rate, despine
# Directory (relative path) into which create_response_plot saves its figures.
figure_folder = "figures"
# Directory (relative path) that main() searches for "cell*.nix" recordings.
data_folder = "data"
def read_baseline(block):
    """Return the spike data stored in a baseline block.

    Parameters
    ----------
    block : nixio.Block
        Block whose name is expected to contain "baseline".

    Returns
    -------
    The full contents of the block's first data array, or an empty list
    (after printing a warning) if the block's name does not contain
    "baseline".
    """
    # Compare case-insensitively, exactly like sort_blocks does when it files a
    # block under block_map["baseline"]; otherwise a block named "Baseline..."
    # would be accepted there and then rejected here.
    if "baseline" not in block.name.lower():
        print("Block %s does not appear to be a baseline block!" % block.name )
        return []
    return block.data_arrays[0][:]
def sort_blocks(nix_file):
    """Index the blocks of a nix file by their stimulus parameters.

    Non-baseline block names are split on "_"; field 1 is parsed as the
    contrast, field 3 is taken as the condition and field 5 is parsed as the
    delta-f. Any block whose lower-cased name contains "baseline" is stored
    under the key "baseline" instead.

    Returns
    -------
    tuple
        (block_map, contrasts, deltafs, conditions). block_map maps
        (contrast, deltaf, condition) -> block (plus the "baseline" entry);
        the three lists hold the distinct values in first-seen order.
    """
    block_map = {}
    contrasts, deltafs, conditions = [], [], []
    for blk in nix_file.blocks:
        if "baseline" in blk.name.lower():
            block_map["baseline"] = blk
            continue
        fields = blk.name.split("_")
        contrast = float(fields[1])
        condition = fields[3]
        deltaf = float(fields[5])
        # Record each distinct value once, keeping the order of first appearance.
        for seen, value in ((contrasts, contrast), (conditions, condition), (deltafs, deltaf)):
            if value not in seen:
                seen.append(value)
        block_map[(contrast, deltaf, condition)] = blk
    return block_map, contrasts, deltafs, conditions
def get_firing_rate(block_map, df, contrast, condition, kernel_width=0.0005):
    """Compute per-trial firing rates for one stimulus condition.

    Parameters
    ----------
    block_map : dict
        Maps (contrast, df, condition) to a nix block (see sort_blocks).
    df, contrast, condition
        Key of the block to analyse.
    kernel_width : float
        Passed through to util.firing_rate.

    Returns
    -------
    tuple
        (time, rates, spikes): the time axis np.arange(0, duration, dt), a
        (n_responses, len(time)) array with one firing-rate row per response,
        and the list of spike-time arrays in the same row order.
    """
    block = block_map[(contrast, df, condition)]
    # Response arrays are named "..._<id>"; key them by that trailing integer.
    responses = {
        int(da.name.split("_")[-1]): da
        for da in block.data_arrays
        if "spike_times" in da.type and "response" in da.name
    }
    stim_params = block.metadata["stimulus parameter"]
    duration = float(stim_params["duration"])
    dt = float(stim_params["dt"])
    time = np.arange(0.0, duration, dt)

    rates = np.zeros((len(responses), len(time)))
    spikes = []
    for row, resp_id in enumerate(responses):
        spikes.append(responses[resp_id][:])
        rates[row, :] = firing_rate(spikes[row], duration, kernel_width, dt)
    return time, rates, spikes
def get_signals(block):
    """Load the stimulus and frequency traces stored in a block.

    Returns
    -------
    tuple
        (signal, self_freq, other_freq, time). other_freq is None for blocks
        whose name contains "no-other"; time is the sampled axis of the
        "complete stimulus" array's first dimension as a numpy array.

    Raises
    ------
    ValueError
        If "complete stimulus" or "self frequency" is missing, or if
        "other frequency" is missing from a block that is not "no-other".
    """
    arrays = block.data_arrays
    has_other = "no-other" not in block.name
    required_missing = "complete stimulus" not in arrays or "self frequency" not in arrays
    if required_missing:
        raise ValueError("Signals not stored in block!")
    if has_other and "other frequency" not in arrays:
        raise ValueError("Signals not stored in block!")

    stimulus = arrays["complete stimulus"]
    signal = stimulus[:]
    time = np.asarray(stimulus.dimensions[0].axis(len(signal)))
    self_freq = arrays["self frequency"][:]
    other_freq = arrays["other frequency"][:] if has_other else None
    return signal, self_freq, other_freq, time
def extract_am(signal):
    """Return the amplitude modulation (envelope) of a 1-D signal.

    The envelope is the magnitude of the analytic signal computed with
    scipy.signal.hilbert. To reduce edge effects, the signal is first padded
    on both ends with a mirrored copy of its first/last len(signal) // 100
    samples; the padding is removed again afterwards.

    Parameters
    ----------
    signal : array_like
        1-D signal.

    Returns
    -------
    numpy.ndarray
        Envelope with the same length as ``signal`` (empty for empty input).
    """
    signal = np.asarray(signal)
    n = len(signal)
    if n == 0:
        # scipy.signal.hilbert rejects zero-length input.
        return np.zeros(0)
    pad = n // 100
    # Use explicit indices: with pad == 0, signal[-pad:] would be the WHOLE
    # signal (since -0 == 0), which padded signals shorter than 100 samples
    # asymmetrically with a full mirrored copy at the back only.
    front_pad = np.flip(signal[:pad])
    back_pad = np.flip(signal[n - pad:])
    padded = np.hstack((front_pad, signal, back_pad))
    am = np.abs(sig.hilbert(padded))
    return am[pad:pad + n]
def create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, current_df, figure_name=None):
    """Plot stimuli and firing-rate responses of one cell for a single delta-f.

    The figure has one column per condition ("no-other", "self", "other",
    titled "soliloquy", "self chirping", "other chirping"). Each column shows,
    top to bottom:

    * the self/other frequency traces,
    * the stimulus together with its amplitude modulation (both taken from the
      largest-contrast block),
    * the across-trial mean +/- standard deviation of the firing rate for every
      contrast, largest contrast on top.

    Only the window 0.5 < t < 1.0 (time axis in seconds, tick labels in ms
    relative to 0.5) is drawn.

    Parameters
    ----------
    block_map : dict
        Maps (contrast, df, condition) to a nix block (see sort_blocks).
    all_dfs : list
        Not used; kept for interface compatibility.
    all_contrasts : list of float
        Contrasts to plot; one row of rate panels per contrast.
    all_conditions : list of str
        Not used for the layout (the three plotted conditions are fixed);
        kept for interface compatibility.
    current_df : float
        Delta-f whose blocks are plotted.
    figure_name : str, optional
        File name inside ``figure_folder``; ".pdf" is appended if missing.
        Defaults to "chirp_responses.pdf".
    """
    conditions = ["no-other", "self", "other"]
    condition_labels = ["soliloquy", "self chirping", "other chirping"]
    min_time = 0.5
    max_time = min_time + 0.5
    self_color = "#ff7f0e"
    other_color = "#1f77b4"

    all_contrasts = sorted(all_contrasts, reverse=True)
    fig = plt.figure(figsize=(6.5, 5.5))
    # Each condition occupies 3 grid columns followed by a 1-column gap, so the
    # grid width follows from the conditions actually laid out (11 for three).
    n_cols = len(conditions) * 4 - 1
    fig_grid = (len(all_contrasts) * 2 + 6, n_cols)

    # Tick positions/labels shared by all firing-rate panels.
    major_xticks = np.arange(min_time, max_time + .01, 0.250)
    major_xticklabels = [int(v) for v in (major_xticks - min_time) * 1000]
    minor_xticks = np.arange(min_time, max_time + .01, 0.125)

    for i, condition in enumerate(conditions):
        col = i * 4

        # Stimulus panels use the largest-contrast block of this condition.
        block = block_map[(all_contrasts[0], current_df, condition)]
        signal, self_freq, other_freq, time = get_signals(block)
        am = extract_am(signal)
        eodfs = block.metadata["stimulus parameter"]["eodfs"]
        self_eodf = eodfs["self"]
        other_eodf = eodfs["other"]
        window = (time > min_time) & (time < max_time)

        # Frequency traces.
        ax = plt.subplot2grid(fig_grid, (0, col), rowspan=2, colspan=3, fig=fig)
        ax.plot(time[window], self_freq[window], color=self_color, label="%iHz" % self_eodf)
        ax.text(min_time - 0.05, self_eodf, "%iHz" % self_eodf, color=self_color,
                va="center", ha="right", fontsize=9)
        if other_freq is not None:
            ax.plot(time[window], other_freq[window], color=other_color, label="%iHz" % other_eodf)
            ax.text(min_time - 0.05, other_eodf, "%iHz" % other_eodf, color=other_color,
                    va="center", ha="right", fontsize=9)
        ax.set_title(condition_labels[i])
        despine(ax, ["top", "bottom", "left", "right"], True)

        # Stimulus and its amplitude modulation.
        ax = plt.subplot2grid(fig_grid, (3, col), rowspan=2, colspan=3, fig=fig)
        ax.plot(time[window], signal[window], color="#2ca02c", label="signal")
        ax.plot(time[window], am[window], color="#d62728", label="am")
        despine(ax, ["top", "bottom", "left", "right"], True)
        ax.set_ylim([-1.25, 1.25])
        ax.legend(ncol=2, loc=(0.01, -0.5), fontsize=7, markerscale=0.5, frameon=False)

        # One firing-rate panel (mean +/- std across trials) per contrast.
        for j, contrast in enumerate(all_contrasts):
            t, rates, _ = get_firing_rate(block_map, current_df, contrast, condition)
            avg_resp = np.mean(rates, axis=0)
            error = np.std(rates, axis=0)
            rate_window = (t > min_time) & (t < max_time)

            ax = plt.subplot2grid(fig_grid, (j * 2 + 6, col), rowspan=2, colspan=3, fig=fig)
            ax.plot(t[rate_window], avg_resp[rate_window], color="k", lw=0.5)
            ax.fill_between(t[rate_window], (avg_resp - error)[rate_window],
                            (avg_resp + error)[rate_window], color="k", lw=0.0, alpha=0.25)
            ax.set_ylim([0, 750])
            ax.set_xlabel("")
            ax.set_ylabel("")
            ax.set_xticks(major_xticks)
            ax.set_xticklabels(major_xticklabels)
            ax.set_xticks(minor_xticks, minor=True)
            # Only the bottom-most panel keeps its x tick labels.
            if j < len(all_contrasts) - 1:
                ax.set_xticklabels([])
            ax.set_yticks(np.arange(0.0, 751., 500))
            ax.set_yticks(np.arange(0.0, 751., 125), minor=True)
            if i > 0:
                ax.set_yticklabels([])
            despine(ax, ["top", "right"], False)
            # Annotate the contrast to the right of the last column.
            if i == len(conditions) - 1:
                ax.text(max_time + 0.025 * max_time, 350, "c=%.3f" % contrast,
                        color="#d62728", ha="left", fontsize=7)

        # Axis labels go on the bottom-most rate panel of the relevant column.
        if i == 1:
            ax.set_xlabel("time [ms]")
        if i == 0:
            ax.set_ylabel("frequency [Hz]", va="center")
            # Shift the label up so it is centred on the stack of rate panels.
            ax.yaxis.set_label_coords(-0.45, 3.5)

    name = figure_name if figure_name is not None else "chirp_responses.pdf"
    if not name.endswith(".pdf"):
        name = name + ".pdf"
    os.makedirs(figure_folder, exist_ok=True)
    fig.savefig(os.path.join(figure_folder, name))
    plt.close(fig)
def chrip_detection_soliloquy(spikes, chirp_times, kernel_width=0.0005):
    """Placeholder for chirp detection in the soliloquy ("no-other") condition.

    Not implemented yet: the arguments are ignored and ``None`` is returned.
    """
    return None
def chirp_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, current_condition=None):
    """Placeholder for chirp detection across the recorded stimulus blocks.

    Not implemented yet: the arguments are ignored and ``None`` is returned.
    """
    return None
def process_cell(filename, dfs=None, contrasts=None, conditions=None):
    """Create the response figures for one recorded cell.

    Opens the nix file read-only and indexes its blocks with sort_blocks. It
    then writes one response figure per delta-f (20 Hz and -100 Hz), named
    "<cell>_df_<df>Hz.pdf", and runs chirp_detection. The file is closed even
    if any of these steps raises.

    Parameters
    ----------
    filename : str
        Path to the cell's .nix file.
    dfs, contrasts, conditions : list, optional
        Currently unused. They default to None rather than a shared mutable
        ``[]`` and are kept for interface compatibility.
    """
    nf = nix.File.open(filename, nix.FileMode.ReadOnly)
    try:
        block_map, all_contrasts, all_dfs, all_conditions = sort_blocks(nf)
        if "baseline" in block_map:
            # NOTE: baseline spikes are loaded but not used further yet.
            baseline_spikes = read_baseline(block_map["baseline"])
        else:
            print("ERROR: no baseline data for file %s!" % filename)
        # Strip the directory and the ".nix" suffix to get the cell name.
        cell_name = os.path.basename(filename).split(".nix")[0]
        for plot_df in (20, -100):
            fig_name = "%s_df_%iHz.pdf" % (cell_name, plot_df)
            create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, plot_df,
                                 figure_name=fig_name)
        chirp_detection(block_map, all_dfs, all_contrasts, all_conditions)
    finally:
        nf.close()
def main():
    """Run process_cell on every "cell*.nix" file in data_folder, in sorted order."""
    pattern = os.path.join(data_folder, "cell*.nix")
    for path in sorted(glob.glob(pattern)):
        process_cell(path, dfs=[20], contrasts=[20], conditions=["self"])
# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()