chirp_probing/response_discriminability.py

167 lines
6.5 KiB
Python

import os
import glob
import nixio as nix
import numpy as np
import matplotlib.pyplot as plt
from util import firing_rate, despine
from IPython import embed
def read_baseline(block):
    """Return the contents of the first data array of a baseline block.

    The first data array is presumably the baseline spike times (the original
    code calls it ``spikes``) -- TODO confirm against the recording format.

    The block is recognised as a baseline block when its name contains
    "baseline" in any letter case. This is the same test ``sort_blocks`` uses
    to file a block under the "baseline" key.

    Args:
        block: nix block to read.

    Returns:
        ``block.data_arrays[0][:]``, or an empty list (after printing a
        warning) if the block name does not look like a baseline block.
    """
    # Compare case-insensitively so that a block which sort_blocks accepted
    # as baseline (e.g. "Baseline_1") is not rejected here.
    if "baseline" not in block.name.lower():
        print("Block %s does not appear to be a baseline block!" % block.name)
        return []
    return block.data_arrays[0][:]
def sort_blocks(nix_file):
    """Index the blocks of an opened nix file by their stimulus parameters.

    A block whose lower-cased name contains "baseline" is stored under the
    key "baseline"; if several match, the last one wins. Every other block
    name is split on "_", and three fields are read from it:

    - field 1 is the contrast, parsed as a float;
    - field 3 is the condition, kept as a string;
    - field 5 is the delta-f, parsed as a float.

    The remaining fields are ignored.

    Args:
        nix_file: object exposing an iterable ``blocks`` attribute.

    Returns:
        A tuple ``(block_map, contrasts, deltafs, conditions)``:

        - ``block_map`` maps ``(contrast, deltaf, condition)`` tuples, plus
          the string "baseline", to blocks.
        - The three lists hold the distinct values seen, in order of first
          appearance.
    """
    mapping = {}
    seen_contrasts, seen_dfs, seen_conditions = [], [], []

    def _note(bucket, value):
        # Preserve first-appearance order and skip duplicates.
        if value not in bucket:
            bucket.append(value)

    for blk in nix_file.blocks:
        if "baseline" in blk.name.lower():
            mapping["baseline"] = blk
            continue
        fields = blk.name.split("_")
        contrast = float(fields[1])
        _note(seen_contrasts, contrast)
        condition = fields[3]
        _note(seen_conditions, condition)
        delta_f = float(fields[5])
        _note(seen_dfs, delta_f)
        mapping[(contrast, delta_f, condition)] = blk
    return mapping, seen_contrasts, seen_dfs, seen_conditions
def get_firing_rate(block_map, df, contrast, condition, kernel_width=0.0005):
    """Compute one firing-rate trace per trial for a single stimulus setting.

    The block ``block_map[(contrast, df, condition)]`` is looked up and its
    name is printed. Every data array whose ``type`` contains "spike_times"
    and whose ``name`` contains "response" is treated as one trial. The
    integer after the last "_" in the name is used as the trial id; a later
    array with the same id replaces an earlier one.

    Args:
        block_map: mapping as returned by ``sort_blocks``.
        df: delta-f key.
        contrast: contrast key.
        condition: condition key.
        kernel_width: passed unchanged to ``util.firing_rate``. Presumably in
            the same time units as the stored ``dt`` -- confirm in util.

    Returns:
        A tuple ``(time, rates, spikes)``:

        - ``time`` is ``np.arange(0.0, duration, dt)``, with ``duration`` and
          ``dt`` read from the block's "stimulus parameter" metadata.
        - ``rates`` is a ``(n_trials, len(time))`` array, row i filled by
          ``firing_rate``.
        - ``spikes`` is the list of per-trial spike arrays, in row order.
    """
    block = block_map[(contrast, df, condition)]
    print(block.name)

    trials = {}
    for array in block.data_arrays:
        if "spike_times" in array.type and "response" in array.name:
            trials[int(array.name.split("_")[-1])] = array

    params = block.metadata["stimulus parameter"]
    duration = float(params["duration"])
    dt = float(params["dt"])
    time = np.arange(0.0, duration, dt)

    rates = np.zeros((len(trials), len(time)))
    spikes = []
    for row, trial in enumerate(trials.values()):
        trial_spikes = trial[:]
        spikes.append(trial_spikes)
        rates[row, :] = firing_rate(trial_spikes, duration, kernel_width, dt)
    return time, rates, spikes
def get_signals(block):
    """Load the stimulus waveform and frequency traces stored in a block.

    The block name is printed first.

    Args:
        block: nix block holding the data arrays "complete stimulus",
            "self frequency" and, unless its name contains "no-other",
            "other frequency".

    Returns:
        A tuple ``(signal, self_freq, other_freq, time)``:

        - ``signal`` holds the "complete stimulus" data.
        - ``self_freq`` holds the "self frequency" data.
        - ``other_freq`` holds the "other frequency" data, or None for
          "no-other" blocks.
        - ``time`` is the first dimension's axis of "complete stimulus",
          evaluated for ``len(signal)`` samples and converted to an ndarray.

    Raises:
        ValueError: if a required data array is missing.
    """
    print(block.name)
    arrays = block.data_arrays
    expects_other = "no-other" not in block.name

    missing_required = any(key not in arrays for key in ("complete stimulus", "self frequency"))
    if missing_required or (expects_other and "other frequency" not in arrays):
        raise ValueError("Signals not stored in block!")

    stimulus = arrays["complete stimulus"]
    signal = stimulus[:]
    time = np.asarray(stimulus.dimensions[0].axis(len(signal)))
    self_freq = arrays["self frequency"][:]
    other_freq = arrays["other frequency"][:] if expects_other else None
    return signal, self_freq, other_freq, time
def _plot_mean_rate(ax, t, rates, max_time, show_std):
    """Plot the trial-averaged rate for t < max_time, optionally with a +/- 1 std band."""
    shown = t < max_time
    avg_resp = np.mean(rates, axis=0)[shown]
    ax.plot(t[shown], avg_resp, color="k", lw=0.5)
    if show_std:
        error = np.std(rates, axis=0)[shown]
        ax.fill_between(t[shown], avg_resp - error, avg_resp + error, color="k", lw=None, alpha=0.25)
    despine(ax, ["top", "right"], False)


def create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, current_df):
    """Plot stimulus and firing-rate overview for one delta-f and save it.

    One column of panels is drawn per condition: "no-other", "self" and
    "other", titled "alone", "self" and "other". The rows are:

    - Row 0: the self and, if present, the other frequency traces.
    - Row 1: the mean +/- std firing rate at the largest contrast.
    - Following rows: the mean firing rate of each remaining contrast, in
      descending contrast order.

    Only samples with time < 0.5 are shown (presumably seconds -- confirm).
    The figure is written to "chirp_responses.pdf".

    Args:
        block_map: mapping returned by ``sort_blocks``.
        all_dfs: not used; kept for interface compatibility.
        all_contrasts: contrasts to plot, one row each.
        all_conditions: not used. The panels always cover the three
            conditions listed above, so the grid is sized from that list.
        current_df: delta-f whose blocks are plotted.
    """
    conditions = ["no-other", "self", "other"]
    condition_labels = ["alone", "self", "other"]
    max_time = 0.5
    panel_width = 3               # grid columns per condition panel
    column_step = panel_width + 1  # one empty spacer column between panels

    fig = plt.figure(figsize=(6.5, 5.5))
    # Size the grid from the conditions actually drawn. The file's condition
    # count can differ from three, which would break or pad the layout.
    fig_grid = (len(all_contrasts) + 1, len(conditions) * column_step - 1)
    all_contrasts = sorted(all_contrasts, reverse=True)

    for i, condition in enumerate(conditions):
        col = i * column_step

        # Row 0: frequency traces of the largest-contrast stimulus.
        block = block_map[(all_contrasts[0], current_df, condition)]
        _, self_freq, other_freq, time = get_signals(block)
        eodfs = block.metadata["stimulus parameter"]["eodfs"]
        self_eodf = eodfs["self"]
        other_eodf = eodfs["other"]
        shown = time < max_time
        ax = plt.subplot2grid(fig_grid, (0, col), rowspan=1, colspan=panel_width, fig=fig)
        ax.plot(time[shown], self_freq[shown], color="#ff7f0e", label="%iHz" % self_eodf)
        ax.text(-0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e", va="center", ha="right", fontsize=9)
        if other_freq is not None:
            ax.plot(time[shown], other_freq[shown], color="#1f77b4", label="%iHz" % other_eodf)
            ax.text(-0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4", va="center", ha="right", fontsize=9)
        ax.set_title(condition_labels[i])
        despine(ax, ["top", "bottom", "left", "right"], True)

        # Row 1: largest contrast, mean firing rate with +/- 1 std band.
        t, rates, _ = get_firing_rate(block_map, current_df, all_contrasts[0], condition, kernel_width=0.001)
        ax = plt.subplot2grid(fig_grid, (1, col), rowspan=1, colspan=panel_width, fig=fig)
        _plot_mean_rate(ax, t, rates, max_time, show_std=True)

        # Remaining rows: mean firing rate only, one row per smaller contrast.
        for row, contrast in enumerate(all_contrasts[1:], start=2):
            t, rates, _ = get_firing_rate(block_map, current_df, contrast, condition)
            ax = plt.subplot2grid(fig_grid, (row, col), rowspan=1, colspan=panel_width, fig=fig)
            _plot_mean_rate(ax, t, rates, max_time, show_std=False)

    fig.savefig("chirp_responses.pdf")
    plt.close(fig)
def process_cell(filename, dfs=None, contrasts=None, conditions=None):
    """Open one cell's nix file and create its chirp response overview plot.

    The file is opened read-only and is always closed again, even if
    plotting fails. If a "baseline" block exists it is read through
    ``read_baseline``; otherwise an error message is printed. The plot is
    made for delta-f 20.

    Args:
        filename: path of the nix file.
        dfs, contrasts, conditions: accepted for interface compatibility.
            They are currently not used by this function.
    """
    nf = nix.File.open(filename, nix.FileMode.ReadOnly)
    try:
        # sort_blocks returns (block_map, contrasts, deltafs, conditions);
        # unpack in that order so the names match their contents.
        block_map, all_contrasts, all_dfs, all_conditions = sort_blocks(nf)
        if "baseline" in block_map:
            # NOTE: baseline spikes are loaded (and validated) but not used yet.
            read_baseline(block_map["baseline"])
        else:
            print("ERROR: no baseline data for file %s!" % filename)
        create_response_plot(block_map, all_dfs=all_dfs, all_contrasts=all_contrasts,
                             all_conditions=all_conditions, current_df=20)
    finally:
        nf.close()
def main():
    """Process every ``cell*.nix`` file in the working directory, in sorted order."""
    for path in sorted(glob.glob("cell*.nix")):
        process_cell(path, dfs=[20], contrasts=[20], conditions=["self"])


if __name__ == "__main__":
    main()