separate plotting and analysis
This commit is contained in:
parent
511fddbeb2
commit
d0bde3b673
@ -45,11 +45,12 @@ Won't do, this is trivial?!
|
||||
* calculate the discriminability between the baseline (no-other fish present) and the another fish is present for each contrast
|
||||
* Work out the difference between the soliloquy and the response to self generated chirp in a communication context -> done
|
||||
* Compare to the beat alone parts of the responses. -> done
|
||||
* What kernels to use?
|
||||
* What kernels to use? -> done
|
||||
* Duration of the chrip window?
|
||||
* sorting according to phase?
|
||||
* we could filter the P-unit responses to model the ELL filering
|
||||
* we could filter the P-unit responses to model the ELL filtering
|
||||
|
||||
### 4 plot discrimination results
|
||||
## Random thoughts
|
||||
|
||||
* who is sending the chrips? Henninger and also Hupe illustrate the subordinant fish is chirping.
|
||||
|
98
nix_util.py
Normal file
98
nix_util.py
Normal file
@ -0,0 +1,98 @@
|
||||
import nixio as nix
|
||||
import numpy as np
|
||||
|
||||
def read_baseline(block):
|
||||
spikes = []
|
||||
if "baseline" not in block.name:
|
||||
print("Block %s does not appear to be a baseline block!" % block.name )
|
||||
return spikes
|
||||
spikes = block.data_arrays[0][:]
|
||||
return spikes
|
||||
|
||||
|
||||
def sort_blocks(nix_file):
|
||||
block_map = {}
|
||||
contrasts = []
|
||||
deltafs = []
|
||||
conditions = []
|
||||
for b in nix_file.blocks:
|
||||
if "baseline" not in b.name.lower():
|
||||
name_parts = b.name.split("_")
|
||||
cntrst = float(name_parts[1])
|
||||
if cntrst not in contrasts:
|
||||
contrasts.append(cntrst)
|
||||
cndtn = name_parts[3]
|
||||
if cndtn not in conditions:
|
||||
conditions.append(cndtn)
|
||||
dltf = float(name_parts[5])
|
||||
if dltf not in deltafs:
|
||||
deltafs.append(dltf)
|
||||
block_map[(cntrst, dltf, cndtn)] = b
|
||||
else:
|
||||
block_map["baseline"] = b
|
||||
return block_map, contrasts, deltafs, conditions
|
||||
|
||||
|
||||
def get_spikes(block):
|
||||
"""Get the spike trains.
|
||||
|
||||
Args:
|
||||
block ([type]): [description]
|
||||
|
||||
Returns:
|
||||
list of np.ndarray: the spike trains.
|
||||
"""
|
||||
response_map = {}
|
||||
spikes = []
|
||||
|
||||
for da in block.data_arrays:
|
||||
if "spike_times" in da.type and "response" in da.name:
|
||||
resp_id = int(da.name.split("_")[-1])
|
||||
response_map[resp_id] = da
|
||||
for k in sorted(response_map.keys()):
|
||||
spikes.append(response_map[k][:])
|
||||
|
||||
return spikes
|
||||
|
||||
|
||||
def get_signals(block):
|
||||
"""Read the fish signals from block.
|
||||
|
||||
Args:
|
||||
block ([type]): the block containing the data for a given df, contrast and condition
|
||||
|
||||
Raises:
|
||||
ValueError: when the complete stimulus data is not found
|
||||
ValueError: when the no-other animal data is not found
|
||||
|
||||
Returns:
|
||||
np.ndarray: the complete signal
|
||||
np.ndarray: the frequency profile of the recorded fish
|
||||
np.ndarray: the frequency profile of the other fish
|
||||
np.ndarray: the time axis
|
||||
"""
|
||||
self_freq = None
|
||||
other_freq = None
|
||||
signal = None
|
||||
time = None
|
||||
if "complete stimulus" not in block.data_arrays or "self frequency" not in block.data_arrays:
|
||||
raise ValueError("Signals not stored in block!")
|
||||
if "no-other" not in block.name and "other frequency" not in block.data_arrays:
|
||||
raise ValueError("Signals not stored in block!")
|
||||
|
||||
signal = block.data_arrays["complete stimulus"][:]
|
||||
time = np.asarray(block.data_arrays["complete stimulus"].dimensions[0].axis(len(signal)))
|
||||
self_freq = block.data_arrays["self frequency"][:]
|
||||
if "no-other" not in block.name:
|
||||
other_freq = block.data_arrays["other frequency"][:]
|
||||
return signal, self_freq, other_freq, time
|
||||
|
||||
|
||||
def get_chirp_metadata(block):
|
||||
trial_duration = float(block.metadata["stimulus parameter"]["duration"])
|
||||
dt = float(block.metadata["stimulus parameter"]["dt"])
|
||||
chirp_duration = block.metadata["stimulus parameter"]["chirp_duration"]
|
||||
chirp_size = block.metadata["stimulus parameter"]["chirp_size"]
|
||||
chirp_times = block.metadata["stimulus parameter"]["chirp_times"]
|
||||
|
||||
return trial_duration, dt, chirp_size, chirp_duration, chirp_times
|
224
plots.py
Normal file
224
plots.py
Normal file
@ -0,0 +1,224 @@
|
||||
import glob
|
||||
import os
|
||||
import nixio as nix
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.patches import Rectangle
|
||||
from matplotlib.collections import PatchCollection
|
||||
from matplotlib.patches import ConnectionPatch
|
||||
|
||||
from nix_util import sort_blocks, read_baseline, get_signals
|
||||
from util import despine
|
||||
|
||||
figure_folder = "figures"
|
||||
data_folder = "data"
|
||||
|
||||
|
||||
def plot_comparisons(current_df=20):
|
||||
files = sorted(glob.glob(os.path.join(data_folder, "*.nix")))
|
||||
if len(files) < 1:
|
||||
print("plot comparisons: no data found!")
|
||||
return
|
||||
filename = files[0]
|
||||
nf = nix.File.open(filename, nix.FileMode.ReadOnly)
|
||||
block_map, all_contrasts, _, all_conditions = sort_blocks(nf)
|
||||
conditions = ["no-other", "self", "other"]
|
||||
min_time = 0.5
|
||||
max_time = min_time + 0.5
|
||||
|
||||
fig = plt.figure(figsize=(6.5, 2.))
|
||||
fig_grid = (3, len(all_conditions)*3+2)
|
||||
axes = []
|
||||
for i, condition in enumerate(conditions):
|
||||
# plot the signals
|
||||
block = block_map[(all_contrasts[0], current_df, condition)]
|
||||
_, self_freq, other_freq, time = get_signals(block)
|
||||
|
||||
self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"]
|
||||
other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"]
|
||||
|
||||
# plot frequency traces
|
||||
ax = plt.subplot2grid(fig_grid, (0, i * 3 + i), rowspan=2, colspan=3, fig=fig)
|
||||
ax.plot(time[(time > min_time) & (time < max_time)], self_freq[(time > min_time) & (time < max_time)],
|
||||
color="#ff7f0e", label="%iHz" % self_eodf)
|
||||
ax.text(min_time-0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e", va="center", ha="right", fontsize=9)
|
||||
if other_freq is not None:
|
||||
ax.plot(time[(time > min_time) & (time < max_time)], other_freq[(time > min_time) & (time < max_time)],
|
||||
color="#1f77b4", label="%iHz" % other_eodf)
|
||||
ax.text(min_time-0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4", va="center", ha="right", fontsize=9)
|
||||
# ax.set_title(condition_labels[i])
|
||||
ax.set_ylim([735, 885])
|
||||
despine(ax, ["top", "bottom", "left", "right"], True)
|
||||
axes.append(ax)
|
||||
|
||||
rects = []
|
||||
rect = Rectangle((0.675, 740), 0.098, 140)
|
||||
rects.append(rect)
|
||||
rect = Rectangle((0.57, 740), 0.098, 140)
|
||||
rects.append(rect)
|
||||
|
||||
pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
|
||||
axes[0].add_collection(pc)
|
||||
|
||||
rects = []
|
||||
rect = Rectangle((0.675, 740), 0.098, 140)
|
||||
rects.append(rect)
|
||||
rect = Rectangle((0.575, 740), 0.098, 140)
|
||||
rects.append(rect)
|
||||
|
||||
pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
|
||||
axes[1].add_collection(pc)
|
||||
|
||||
rects = []
|
||||
rect = Rectangle((0.57, 740), 0.098, 140)
|
||||
rects.append(rect)
|
||||
pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
|
||||
axes[2].add_collection(pc)
|
||||
|
||||
con = ConnectionPatch(xyA=(0.625, 735), xyB=(0.625, 740), coordsA="data", coordsB="data",
|
||||
axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=.35")
|
||||
axes[1].add_artist(con)
|
||||
con = ConnectionPatch(xyA=(0.725, 885), xyB=(0.725, 880), coordsA="data", coordsB="data",
|
||||
axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=-.25")
|
||||
axes[1].add_artist(con)
|
||||
con = ConnectionPatch(xyA=(0.725, 735), xyB=(0.625, 740), coordsA="data", coordsB="data",
|
||||
axesA=axes[1], axesB=axes[2], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=.35")
|
||||
axes[1].add_artist(con)
|
||||
|
||||
axes[0].text(1., 660, "2.")
|
||||
axes[1].text(1.05, 660, "3.")
|
||||
axes[0].text(1.1, 890, "1.")
|
||||
fig.subplots_adjust(bottom=0.1, top=0.8, left=0.1, right=0.9)
|
||||
fig.savefig(os.path.join(figure_folder, "comparisons.pdf"))
|
||||
plt.close()
|
||||
nf.close()
|
||||
|
||||
|
||||
def create_response_plot(filename, current_df=20, figure_name=None):
|
||||
files = sorted(glob.glob(os.path.join(data_folder, "*.nix")))
|
||||
if len(files) < 1:
|
||||
print("plot comparisons: no data found!")
|
||||
return
|
||||
filename = files[0]
|
||||
nf = nix.File.open(filename, nix.FileMode.ReadOnly)
|
||||
block_map, all_contrasts, _, all_conditions = sort_blocks(nf)
|
||||
|
||||
conditions = ["no-other", "self", "other"]
|
||||
condition_labels = ["soliloquy", "self chirping", "other chirping"]
|
||||
min_time = 0.5
|
||||
max_time = min_time + 0.5
|
||||
|
||||
fig = plt.figure(figsize=(6.5, 5.5))
|
||||
fig_grid = (len(all_contrasts)*2 + 6, len(all_conditions)*3+2)
|
||||
all_contrasts = sorted(all_contrasts, reverse=True)
|
||||
|
||||
for i, condition in enumerate(conditions):
|
||||
# plot the signals
|
||||
block = block_map[(all_contrasts[0], current_df, condition)]
|
||||
signal, self_freq, other_freq, time = get_signals(block)
|
||||
am = extract_am(signal)
|
||||
|
||||
self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"]
|
||||
other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"]
|
||||
|
||||
# plot frequency traces
|
||||
ax = plt.subplot2grid(fig_grid, (0, i * 3 + i), rowspan=2, colspan=3, fig=fig)
|
||||
ax.plot(time[(time > min_time) & (time < max_time)], self_freq[(time > min_time) & (time < max_time)],
|
||||
color="#ff7f0e", label="%iHz" % self_eodf)
|
||||
ax.text(min_time-0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e", va="center", ha="right", fontsize=9)
|
||||
if other_freq is not None:
|
||||
ax.plot(time[(time > min_time) & (time < max_time)], other_freq[(time > min_time) & (time < max_time)],
|
||||
color="#1f77b4", label="%iHz" % other_eodf)
|
||||
ax.text(min_time-0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4", va="center", ha="right", fontsize=9)
|
||||
ax.set_title(condition_labels[i])
|
||||
despine(ax, ["top", "bottom", "left", "right"], True)
|
||||
|
||||
# plot the am
|
||||
ax = plt.subplot2grid(fig_grid, (3, i * 3 + i), rowspan=2, colspan=3, fig=fig)
|
||||
ax.plot(time[(time > min_time) & (time < max_time)], signal[(time > min_time) & (time < max_time)],
|
||||
color="#2ca02c", label="signal")
|
||||
ax.plot(time[(time > min_time) & (time < max_time)], am[(time > min_time) & (time < max_time)],
|
||||
color="#d62728", label="am")
|
||||
despine(ax, ["top", "bottom", "left", "right"], True)
|
||||
ax.set_ylim([-1.25, 1.25])
|
||||
ax.legend(ncol=2, loc=(0.01, -0.5), fontsize=7, markerscale=0.5, frameon=False)
|
||||
|
||||
# for each contrast plot the firing rate
|
||||
for j, contrast in enumerate(all_contrasts):
|
||||
t, rates, _ = get_firing_rate(block_map, current_df, contrast, condition)
|
||||
avg_resp = np.mean(rates, axis=0)
|
||||
error = np.std(rates, axis=0)
|
||||
ax = plt.subplot2grid(fig_grid, (j*2 + 6, i * 3 + i), rowspan=2, colspan=3)
|
||||
ax.plot(t[(t > min_time) & (t < max_time)], avg_resp[(t > min_time) & (t < max_time)], color="k", lw=0.5)
|
||||
ax.fill_between(t[(t > min_time) & (t < max_time)], (avg_resp - error)[(t > min_time) & (t < max_time)],
|
||||
(avg_resp + error)[(t > min_time) & (t < max_time)], color="k", lw=0.0, alpha=0.25)
|
||||
ax.set_ylim([0, 750])
|
||||
ax.set_xlabel("")
|
||||
ax.set_ylabel("")
|
||||
ax.set_xticks(np.arange(min_time, max_time+.01, 0.250))
|
||||
ax.set_xticklabels(map(int, (np.arange(min_time, max_time + .01, 0.250) - min_time) * 1000))
|
||||
ax.set_xticks(np.arange(min_time, max_time+.01, 0.125), minor=True)
|
||||
if j < len(all_contrasts) -1:
|
||||
ax.set_xticklabels([])
|
||||
ax.set_yticks(np.arange(0.0, 751., 500))
|
||||
ax.set_yticks(np.arange(0.0, 751., 125), minor=True)
|
||||
if i > 0:
|
||||
ax.set_yticklabels([])
|
||||
despine(ax, ["top", "right"], False)
|
||||
if i == 2:
|
||||
ax.text(max_time + 0.025*max_time, 350, "c=%.3f" % all_contrasts[j],
|
||||
color="#d62728", ha="left", fontsize=7)
|
||||
|
||||
if i == 1:
|
||||
ax.set_xlabel("time [ms]")
|
||||
if i == 0:
|
||||
ax.set_ylabel("frequency [Hz]", va="center")
|
||||
ax.yaxis.set_label_coords(-0.45, 3.5)
|
||||
|
||||
name = figure_name if figure_name is not None else "chirp_responses.pdf"
|
||||
name = (name + ".pdf") if ".pdf" not in name else name
|
||||
plt.savefig(os.path.join(figure_folder, name))
|
||||
plt.close()
|
||||
nf.close()
|
||||
|
||||
def response_examples(*kwargs):
|
||||
|
||||
if filename in kwargs:
|
||||
|
||||
#fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf"
|
||||
#create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name)
|
||||
#fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf"
|
||||
#create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name)
|
||||
|
||||
def main(task=None, parameter={}):
|
||||
plot_tasks = {"comparisons": plot_comparisons,
|
||||
"response_examples": create_response_plot}
|
||||
if task is not None and task in plot_tasks.keys():
|
||||
plot_tasks[task](*parameter)
|
||||
elif task is None:
|
||||
for t in plot_tasks.keys():
|
||||
plot_tasks[t](*parameter)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main("comparisons")
|
||||
|
||||
|
||||
def plot_examples(filename, dfs=[], contrasts=[], conditions=[]):
|
||||
# plot the responses
|
||||
#fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf"
|
||||
#create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name)
|
||||
#fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf"
|
||||
#create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name)
|
||||
|
||||
# sketch showing the comparisons
|
||||
#plot_comparisons(block_map, all_dfs, all_contrasts, all_conditions, 20)
|
||||
|
||||
# plot the discrimination analyses
|
||||
#cell_name = filename.split(os.path.sep)[-1].split(".nix")[0]
|
||||
# results = foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=20,
|
||||
# cell_name=cell_name, store_roc=True)
|
||||
# pdf = pd.DataFrame(results)
|
||||
# plot_detection_results(pdf, 20, 0.001, cell_name)
|
||||
|
||||
#nf.close()
|
||||
pass
|
@ -7,6 +7,8 @@ from chirp_ams import get_signals
|
||||
from model import simulate, load_models
|
||||
from IPython import embed
|
||||
import matplotlib.pyplot as plt
|
||||
import multiprocessing
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
data_folder = "data"
|
||||
|
||||
@ -181,8 +183,7 @@ def simulate_responses(stimulus_params, model_params, repeats=10, deltaf=20):
|
||||
print("\n")
|
||||
|
||||
|
||||
def main():
|
||||
models = load_models("models.csv")
|
||||
def simulate_cell(cell_id, models):
|
||||
deltafs = [-200, -100, -50, -20, -10, -5, 5, 10, 20, 50, 100, 200] # Hz, difference frequency between self and other
|
||||
stimulus_params = { "eodfs": {"self": 0.0, "other": 0.0}, # eod frequency in Hz, to be overwritten
|
||||
"contrasts": [20, 10, 5, 2.5, 1.25, 0.625, 0.3125],
|
||||
@ -192,26 +193,30 @@ def main():
|
||||
"chirp_frequency": 5, # Hz, how often does the fish chirp
|
||||
"duration": 5., # s, total duration of simulation
|
||||
"dt": 1, # s, stepsize of the simulation, to be overwritten
|
||||
}
|
||||
}
|
||||
model_params = models[cell_id]
|
||||
baseline_spikes = get_baseline_response(model_params, duration=30)
|
||||
filename = os.path.join(data_folder, "cell_%s.nix" % model_params["cell"])
|
||||
save_baseline_response(filename, "baseline response", baseline_spikes, model_params)
|
||||
|
||||
print("Cell: %s" % model_params["cell"])
|
||||
for deltaf in deltafs:
|
||||
stimulus_params["eodfs"] = {"self": model_params["EODf"], "other": model_params["EODf"] + deltaf}
|
||||
stimulus_params["dt"] = model_params["deltat"]
|
||||
|
||||
print("\t Deltaf: %i" % deltaf)
|
||||
chirp_times = np.arange(stimulus_params["chirp_duration"],
|
||||
stimulus_params["duration"] - stimulus_params["chirp_duration"],
|
||||
1./stimulus_params["chirp_frequency"])
|
||||
stimulus_params["chirp_times"] = chirp_times
|
||||
simulate_responses(stimulus_params, model_params, repeats=25, deltaf=deltaf)
|
||||
|
||||
|
||||
def main():
|
||||
models = load_models("models.csv")
|
||||
num_cores = multiprocessing.cpu_count() - 6
|
||||
|
||||
for cell_id in range(len(models)):
|
||||
model_params = models[cell_id]
|
||||
baseline_spikes = get_baseline_response(model_params, duration=30)
|
||||
save_baseline_response( "cell_%s.nix" % model_params["cell"], "baseline response", baseline_spikes, model_params)
|
||||
|
||||
print("Cell: %s" % model_params["cell"])
|
||||
for deltaf in deltafs:
|
||||
stimulus_params["eodfs"] = {"self": model_params["EODf"], "other": model_params["EODf"] + deltaf}
|
||||
stimulus_params["dt"] = model_params["deltat"]
|
||||
|
||||
print("\t Deltaf: %i" % deltaf)
|
||||
chirp_times = np.arange(stimulus_params["chirp_duration"],
|
||||
stimulus_params["duration"] - stimulus_params["chirp_duration"],
|
||||
1./stimulus_params["chirp_frequency"])
|
||||
stimulus_params["chirp_times"] = chirp_times
|
||||
simulate_responses(stimulus_params, model_params, repeats=25, deltaf=deltaf)
|
||||
if cell_id == 9:
|
||||
exit() # the first 10 cell only for now!
|
||||
Parallel(n_jobs=num_cores)(delayed(simulate_cell)(cell_id, models) for cell_id in range(len(models[:10])))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -4,71 +4,17 @@ import pandas as pd
|
||||
import nixio as nix
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.patches import Rectangle
|
||||
from matplotlib.collections import PatchCollection
|
||||
from matplotlib.patches import ConnectionPatch
|
||||
from sklearn.metrics import roc_curve, roc_auc_score
|
||||
from IPython import embed
|
||||
import multiprocessing
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
from util import firing_rate, despine, extract_am, within_group_distance, across_group_distance
|
||||
from nix_util import read_baseline, sort_blocks, get_spikes, get_signals, get_chirp_metadata
|
||||
|
||||
figure_folder = "figures"
|
||||
data_folder = "data"
|
||||
|
||||
|
||||
def read_baseline(block):
|
||||
spikes = []
|
||||
if "baseline" not in block.name:
|
||||
print("Block %s does not appear to be a baseline block!" % block.name )
|
||||
return spikes
|
||||
spikes = block.data_arrays[0][:]
|
||||
return spikes
|
||||
|
||||
|
||||
def sort_blocks(nix_file):
|
||||
block_map = {}
|
||||
contrasts = []
|
||||
deltafs = []
|
||||
conditions = []
|
||||
for b in nix_file.blocks:
|
||||
if "baseline" not in b.name.lower():
|
||||
name_parts = b.name.split("_")
|
||||
cntrst = float(name_parts[1])
|
||||
if cntrst not in contrasts:
|
||||
contrasts.append(cntrst)
|
||||
cndtn = name_parts[3]
|
||||
if cndtn not in conditions:
|
||||
conditions.append(cndtn)
|
||||
dltf = float(name_parts[5])
|
||||
if dltf not in deltafs:
|
||||
deltafs.append(dltf)
|
||||
block_map[(cntrst, dltf, cndtn)] = b
|
||||
else:
|
||||
block_map["baseline"] = b
|
||||
return block_map, contrasts, deltafs, conditions
|
||||
|
||||
|
||||
def get_spikes(block):
|
||||
"""Get the spike trains.
|
||||
|
||||
Args:
|
||||
block ([type]): [description]
|
||||
|
||||
Returns:
|
||||
list of np.ndarray: the spike trains.
|
||||
"""
|
||||
response_map = {}
|
||||
spikes = []
|
||||
|
||||
for da in block.data_arrays:
|
||||
if "spike_times" in da.type and "response" in da.name:
|
||||
resp_id = int(da.name.split("_")[-1])
|
||||
response_map[resp_id] = da
|
||||
for k in sorted(response_map.keys()):
|
||||
spikes.append(response_map[k][:])
|
||||
|
||||
return spikes
|
||||
|
||||
|
||||
def get_rates(spike_trains, duration, dt, kernel_width):
|
||||
"""Convert the spike trains (list of spike_times) to rates using a Gaussian kernel of the given size.
|
||||
|
||||
@ -114,126 +60,7 @@ def get_firing_rate(block_map, df, contrast, condition, kernel_width=0.0005):
|
||||
return time, rates, spikes
|
||||
|
||||
|
||||
def get_signals(block):
|
||||
"""Read the fish signals from block.
|
||||
|
||||
Args:
|
||||
block ([type]): the block containing the data for a given df, contrast and condition
|
||||
|
||||
Raises:
|
||||
ValueError: when the complete stimulus data is not found
|
||||
ValueError: when the no-other animal data is not found
|
||||
|
||||
Returns:
|
||||
np.ndarray: the complete signal
|
||||
np.ndarray: the frequency profile of the recorded fish
|
||||
np.ndarray: the frequency profile of the other fish
|
||||
np.ndarray: the time axis
|
||||
"""
|
||||
self_freq = None
|
||||
other_freq = None
|
||||
signal = None
|
||||
time = None
|
||||
if "complete stimulus" not in block.data_arrays or "self frequency" not in block.data_arrays:
|
||||
raise ValueError("Signals not stored in block!")
|
||||
if "no-other" not in block.name and "other frequency" not in block.data_arrays:
|
||||
raise ValueError("Signals not stored in block!")
|
||||
|
||||
signal = block.data_arrays["complete stimulus"][:]
|
||||
time = np.asarray(block.data_arrays["complete stimulus"].dimensions[0].axis(len(signal)))
|
||||
self_freq = block.data_arrays["self frequency"][:]
|
||||
if "no-other" not in block.name:
|
||||
other_freq = block.data_arrays["other frequency"][:]
|
||||
return signal, self_freq, other_freq, time
|
||||
|
||||
|
||||
def create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, current_df, figure_name=None):
|
||||
conditions = ["no-other", "self", "other"]
|
||||
condition_labels = ["soliloquy", "self chirping", "other chirping"]
|
||||
min_time = 0.5
|
||||
max_time = min_time + 0.5
|
||||
|
||||
fig = plt.figure(figsize=(6.5, 5.5))
|
||||
fig_grid = (len(all_contrasts)*2 + 6, len(all_conditions)*3+2)
|
||||
all_contrasts = sorted(all_contrasts, reverse=True)
|
||||
|
||||
for i, condition in enumerate(conditions):
|
||||
# plot the signals
|
||||
block = block_map[(all_contrasts[0], current_df, condition)]
|
||||
signal, self_freq, other_freq, time = get_signals(block)
|
||||
am = extract_am(signal)
|
||||
|
||||
self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"]
|
||||
other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"]
|
||||
|
||||
# plot frequency traces
|
||||
ax = plt.subplot2grid(fig_grid, (0, i * 3 + i), rowspan=2, colspan=3, fig=fig)
|
||||
ax.plot(time[(time > min_time) & (time < max_time)], self_freq[(time > min_time) & (time < max_time)],
|
||||
color="#ff7f0e", label="%iHz" % self_eodf)
|
||||
ax.text(min_time-0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e", va="center", ha="right", fontsize=9)
|
||||
if other_freq is not None:
|
||||
ax.plot(time[(time > min_time) & (time < max_time)], other_freq[(time > min_time) & (time < max_time)],
|
||||
color="#1f77b4", label="%iHz" % other_eodf)
|
||||
ax.text(min_time-0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4", va="center", ha="right", fontsize=9)
|
||||
ax.set_title(condition_labels[i])
|
||||
despine(ax, ["top", "bottom", "left", "right"], True)
|
||||
|
||||
# plot the am
|
||||
ax = plt.subplot2grid(fig_grid, (3, i * 3 + i), rowspan=2, colspan=3, fig=fig)
|
||||
ax.plot(time[(time > min_time) & (time < max_time)], signal[(time > min_time) & (time < max_time)],
|
||||
color="#2ca02c", label="signal")
|
||||
ax.plot(time[(time > min_time) & (time < max_time)], am[(time > min_time) & (time < max_time)],
|
||||
color="#d62728", label="am")
|
||||
despine(ax, ["top", "bottom", "left", "right"], True)
|
||||
ax.set_ylim([-1.25, 1.25])
|
||||
ax.legend(ncol=2, loc=(0.01, -0.5), fontsize=7, markerscale=0.5, frameon=False)
|
||||
|
||||
# for each contrast plot the firing rate
|
||||
for j, contrast in enumerate(all_contrasts):
|
||||
t, rates, _ = get_firing_rate(block_map, current_df, contrast, condition)
|
||||
avg_resp = np.mean(rates, axis=0)
|
||||
error = np.std(rates, axis=0)
|
||||
ax = plt.subplot2grid(fig_grid, (j*2 + 6, i * 3 + i), rowspan=2, colspan=3)
|
||||
ax.plot(t[(t > min_time) & (t < max_time)], avg_resp[(t > min_time) & (t < max_time)], color="k", lw=0.5)
|
||||
ax.fill_between(t[(t > min_time) & (t < max_time)], (avg_resp - error)[(t > min_time) & (t < max_time)],
|
||||
(avg_resp + error)[(t > min_time) & (t < max_time)], color="k", lw=0.0, alpha=0.25)
|
||||
ax.set_ylim([0, 750])
|
||||
ax.set_xlabel("")
|
||||
ax.set_ylabel("")
|
||||
ax.set_xticks(np.arange(min_time, max_time+.01, 0.250))
|
||||
ax.set_xticklabels(map(int, (np.arange(min_time, max_time + .01, 0.250) - min_time) * 1000))
|
||||
ax.set_xticks(np.arange(min_time, max_time+.01, 0.125), minor=True)
|
||||
if j < len(all_contrasts) -1:
|
||||
ax.set_xticklabels([])
|
||||
ax.set_yticks(np.arange(0.0, 751., 500))
|
||||
ax.set_yticks(np.arange(0.0, 751., 125), minor=True)
|
||||
if i > 0:
|
||||
ax.set_yticklabels([])
|
||||
despine(ax, ["top", "right"], False)
|
||||
if i == 2:
|
||||
ax.text(max_time + 0.025*max_time, 350, "c=%.3f" % all_contrasts[j],
|
||||
color="#d62728", ha="left", fontsize=7)
|
||||
|
||||
if i == 1:
|
||||
ax.set_xlabel("time [ms]")
|
||||
if i == 0:
|
||||
ax.set_ylabel("frequency [Hz]", va="center")
|
||||
ax.yaxis.set_label_coords(-0.45, 3.5)
|
||||
|
||||
name = figure_name if figure_name is not None else "chirp_responses.pdf"
|
||||
name = (name + ".pdf") if ".pdf" not in name else name
|
||||
plt.savefig(os.path.join(figure_folder, name))
|
||||
plt.close()
|
||||
|
||||
|
||||
def get_chirp_metadata(block):
|
||||
trial_duration = float(block.metadata["stimulus parameter"]["duration"])
|
||||
dt = float(block.metadata["stimulus parameter"]["dt"])
|
||||
chirp_duration = block.metadata["stimulus parameter"]["chirp_duration"]
|
||||
chirp_size = block.metadata["stimulus parameter"]["chirp_size"]
|
||||
chirp_times = block.metadata["stimulus parameter"]["chirp_times"]
|
||||
|
||||
return trial_duration, dt, chirp_size, chirp_duration, chirp_times
|
||||
|
||||
|
||||
def foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005, cell_name="", store_roc=False):
|
||||
@ -436,79 +263,6 @@ def plot_detection_results(data_frame, df, kernel_width, cell, figure_name=None)
|
||||
fig.savefig(os.path.join(figure_folder, name))
|
||||
|
||||
|
||||
def plot_comparisons(block_map, all_dfs, all_contrasts, all_conditions, current_df):
|
||||
conditions = ["no-other", "self", "other"]
|
||||
condition_labels = ["soliloquy", "self chirping", "other chirping"]
|
||||
min_time = 0.5
|
||||
max_time = min_time + 0.5
|
||||
|
||||
fig = plt.figure(figsize=(6.5, 2.))
|
||||
fig_grid = (3, len(all_conditions)*3+2)
|
||||
axes = []
|
||||
for i, condition in enumerate(conditions):
|
||||
# plot the signals
|
||||
block = block_map[(all_contrasts[0], current_df, condition)]
|
||||
signal, self_freq, other_freq, time = get_signals(block)
|
||||
|
||||
self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"]
|
||||
other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"]
|
||||
|
||||
# plot frequency traces
|
||||
ax = plt.subplot2grid(fig_grid, (0, i * 3 + i), rowspan=2, colspan=3, fig=fig)
|
||||
ax.plot(time[(time > min_time) & (time < max_time)], self_freq[(time > min_time) & (time < max_time)],
|
||||
color="#ff7f0e", label="%iHz" % self_eodf)
|
||||
ax.text(min_time-0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e", va="center", ha="right", fontsize=9)
|
||||
if other_freq is not None:
|
||||
ax.plot(time[(time > min_time) & (time < max_time)], other_freq[(time > min_time) & (time < max_time)],
|
||||
color="#1f77b4", label="%iHz" % other_eodf)
|
||||
ax.text(min_time-0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4", va="center", ha="right", fontsize=9)
|
||||
# ax.set_title(condition_labels[i])
|
||||
ax.set_ylim([735, 885])
|
||||
despine(ax, ["top", "bottom", "left", "right"], True)
|
||||
axes.append(ax)
|
||||
|
||||
rects = []
|
||||
rect = Rectangle((0.675, 740), 0.098, 140)
|
||||
rects.append(rect)
|
||||
rect = Rectangle((0.57, 740), 0.098, 140)
|
||||
rects.append(rect)
|
||||
|
||||
pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
|
||||
axes[0].add_collection(pc)
|
||||
|
||||
rects = []
|
||||
rect = Rectangle((0.675, 740), 0.098, 140)
|
||||
rects.append(rect)
|
||||
rect = Rectangle((0.575, 740), 0.098, 140)
|
||||
rects.append(rect)
|
||||
|
||||
pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
|
||||
axes[1].add_collection(pc)
|
||||
|
||||
rects = []
|
||||
rect = Rectangle((0.57, 740), 0.098, 140)
|
||||
rects.append(rect)
|
||||
pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
|
||||
axes[2].add_collection(pc)
|
||||
|
||||
con = ConnectionPatch(xyA=(0.625, 735), xyB=(0.625, 740), coordsA="data", coordsB="data",
|
||||
axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=.35")
|
||||
axes[1].add_artist(con)
|
||||
con = ConnectionPatch(xyA=(0.725, 885), xyB=(0.725, 880), coordsA="data", coordsB="data",
|
||||
axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=-.25")
|
||||
axes[1].add_artist(con)
|
||||
con = ConnectionPatch(xyA=(0.725, 735), xyB=(0.625, 740), coordsA="data", coordsB="data",
|
||||
axesA=axes[1], axesB=axes[2], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=.35")
|
||||
axes[1].add_artist(con)
|
||||
|
||||
axes[0].text(1., 660, "2.")
|
||||
axes[1].text(1.05, 660, "3.")
|
||||
axes[0].text(1.1, 890, "1.")
|
||||
fig.subplots_adjust(bottom=0.1, top=0.8, left=0.1, right=0.9)
|
||||
fig.savefig(os.path.join(figure_folder, "comparisons.pdf"))
|
||||
plt.close()
|
||||
|
||||
|
||||
def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, cell_name="", store_roc=False):
|
||||
dfs = [current_df] if current_df is not None else all_dfs
|
||||
kernels = [0.00025, 0.0005, 0.001, 0.0025]
|
||||
@ -530,6 +284,7 @@ def estimate_chirp_phase(am, chirp_times):
|
||||
|
||||
|
||||
def process_cell(filename, dfs=[], contrasts=[], conditions=[]):
|
||||
print(filename)
|
||||
nf = nix.File.open(filename, nix.FileMode.ReadOnly)
|
||||
block_map, all_contrasts, all_dfs, all_conditions = sort_blocks(nf)
|
||||
if "baseline" in block_map.keys():
|
||||
@ -542,41 +297,16 @@ def process_cell(filename, dfs=[], contrasts=[], conditions=[]):
|
||||
return results
|
||||
|
||||
|
||||
def plot_examples(filename, dfs=[], contrasts=[], conditions=[]):
|
||||
nf = nix.File.open(filename, nix.FileMode.ReadOnly)
|
||||
block_map, all_contrasts, all_dfs, all_conditions = sort_blocks(nf)
|
||||
if "baseline" in block_map.keys():
|
||||
baseline_spikes = read_baseline(block_map["baseline"])
|
||||
else:
|
||||
print("ERROR: no baseline data for file %s!" % filename)
|
||||
|
||||
# plot the responses
|
||||
#fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf"
|
||||
#create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name)
|
||||
#fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf"
|
||||
#create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name)
|
||||
|
||||
# sketch showing the comparisons
|
||||
# plot_comparisons(block_map, all_dfs, all_contrasts, all_conditions, 20)
|
||||
|
||||
# plot the discrimination analyses
|
||||
#cell_name = filename.split(os.path.sep)[-1].split(".nix")[0]
|
||||
# results = foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=20,
|
||||
# cell_name=cell_name, store_roc=True)
|
||||
# pdf = pd.DataFrame(results)
|
||||
# plot_detection_results(pdf, 20, 0.001, cell_name)
|
||||
|
||||
nf.close()
|
||||
|
||||
|
||||
def main():
|
||||
num_cores = multiprocessing.cpu_count() - 6
|
||||
nix_files = sorted(glob.glob(os.path.join(data_folder, "cell*.nix")))
|
||||
for nix_file in nix_files:
|
||||
#plot_examples(nix_file, dfs=[20], contrasts=[20], conditions=["self"])
|
||||
results = process_cell(nix_file, dfs=[], contrasts=[20], conditions=["self"])
|
||||
# break
|
||||
embed()
|
||||
|
||||
|
||||
processed_list = Parallel(n_jobs=num_cores)(delayed(process_cell)(nix_file) for nix_file in nix_files)
|
||||
results = []
|
||||
for pr in processed_list:
|
||||
results.extend(pr)
|
||||
df = pd.DataFrame(results)
|
||||
df.to_csv(os.path.join(data_folder, "discimination_results.csv"), sep=";")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue
Block a user