separate plotting and analysis

This commit is contained in:
Jan Grewe 2020-09-25 17:22:55 +02:00
parent 511fddbeb2
commit d0bde3b673
5 changed files with 364 additions and 306 deletions

View File

@ -45,11 +45,12 @@ Won't do, this is trivial?!
* calculate the discriminability between the baseline (no-other fish present) and the another fish is present for each contrast
* Work out the difference between the soliloquy and the response to self generated chirp in a communication context -> done
* Compare to the beat alone parts of the responses. -> done
* What kernels to use?
* What kernels to use? -> done
* Duration of the chrip window?
* sorting according to phase?
* we could filter the P-unit responses to model the ELL filering
* we could filter the P-unit responses to model the ELL filtering
### 4 plot discrimination results
## Random thoughts
* who is sending the chrips? Henninger and also Hupe illustrate the subordinant fish is chirping.

98
nix_util.py Normal file
View File

@ -0,0 +1,98 @@
import nixio as nix
import numpy as np
def read_baseline(block):
spikes = []
if "baseline" not in block.name:
print("Block %s does not appear to be a baseline block!" % block.name )
return spikes
spikes = block.data_arrays[0][:]
return spikes
def sort_blocks(nix_file):
block_map = {}
contrasts = []
deltafs = []
conditions = []
for b in nix_file.blocks:
if "baseline" not in b.name.lower():
name_parts = b.name.split("_")
cntrst = float(name_parts[1])
if cntrst not in contrasts:
contrasts.append(cntrst)
cndtn = name_parts[3]
if cndtn not in conditions:
conditions.append(cndtn)
dltf = float(name_parts[5])
if dltf not in deltafs:
deltafs.append(dltf)
block_map[(cntrst, dltf, cndtn)] = b
else:
block_map["baseline"] = b
return block_map, contrasts, deltafs, conditions
def get_spikes(block):
"""Get the spike trains.
Args:
block ([type]): [description]
Returns:
list of np.ndarray: the spike trains.
"""
response_map = {}
spikes = []
for da in block.data_arrays:
if "spike_times" in da.type and "response" in da.name:
resp_id = int(da.name.split("_")[-1])
response_map[resp_id] = da
for k in sorted(response_map.keys()):
spikes.append(response_map[k][:])
return spikes
def get_signals(block):
"""Read the fish signals from block.
Args:
block ([type]): the block containing the data for a given df, contrast and condition
Raises:
ValueError: when the complete stimulus data is not found
ValueError: when the no-other animal data is not found
Returns:
np.ndarray: the complete signal
np.ndarray: the frequency profile of the recorded fish
np.ndarray: the frequency profile of the other fish
np.ndarray: the time axis
"""
self_freq = None
other_freq = None
signal = None
time = None
if "complete stimulus" not in block.data_arrays or "self frequency" not in block.data_arrays:
raise ValueError("Signals not stored in block!")
if "no-other" not in block.name and "other frequency" not in block.data_arrays:
raise ValueError("Signals not stored in block!")
signal = block.data_arrays["complete stimulus"][:]
time = np.asarray(block.data_arrays["complete stimulus"].dimensions[0].axis(len(signal)))
self_freq = block.data_arrays["self frequency"][:]
if "no-other" not in block.name:
other_freq = block.data_arrays["other frequency"][:]
return signal, self_freq, other_freq, time
def get_chirp_metadata(block):
trial_duration = float(block.metadata["stimulus parameter"]["duration"])
dt = float(block.metadata["stimulus parameter"]["dt"])
chirp_duration = block.metadata["stimulus parameter"]["chirp_duration"]
chirp_size = block.metadata["stimulus parameter"]["chirp_size"]
chirp_times = block.metadata["stimulus parameter"]["chirp_times"]
return trial_duration, dt, chirp_size, chirp_duration, chirp_times

224
plots.py Normal file
View File

@ -0,0 +1,224 @@
import glob
import os
import nixio as nix
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection
from matplotlib.patches import ConnectionPatch
from nix_util import sort_blocks, read_baseline, get_signals
from util import despine
figure_folder = "figures"
data_folder = "data"
def plot_comparisons(current_df=20):
files = sorted(glob.glob(os.path.join(data_folder, "*.nix")))
if len(files) < 1:
print("plot comparisons: no data found!")
return
filename = files[0]
nf = nix.File.open(filename, nix.FileMode.ReadOnly)
block_map, all_contrasts, _, all_conditions = sort_blocks(nf)
conditions = ["no-other", "self", "other"]
min_time = 0.5
max_time = min_time + 0.5
fig = plt.figure(figsize=(6.5, 2.))
fig_grid = (3, len(all_conditions)*3+2)
axes = []
for i, condition in enumerate(conditions):
# plot the signals
block = block_map[(all_contrasts[0], current_df, condition)]
_, self_freq, other_freq, time = get_signals(block)
self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"]
other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"]
# plot frequency traces
ax = plt.subplot2grid(fig_grid, (0, i * 3 + i), rowspan=2, colspan=3, fig=fig)
ax.plot(time[(time > min_time) & (time < max_time)], self_freq[(time > min_time) & (time < max_time)],
color="#ff7f0e", label="%iHz" % self_eodf)
ax.text(min_time-0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e", va="center", ha="right", fontsize=9)
if other_freq is not None:
ax.plot(time[(time > min_time) & (time < max_time)], other_freq[(time > min_time) & (time < max_time)],
color="#1f77b4", label="%iHz" % other_eodf)
ax.text(min_time-0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4", va="center", ha="right", fontsize=9)
# ax.set_title(condition_labels[i])
ax.set_ylim([735, 885])
despine(ax, ["top", "bottom", "left", "right"], True)
axes.append(ax)
rects = []
rect = Rectangle((0.675, 740), 0.098, 140)
rects.append(rect)
rect = Rectangle((0.57, 740), 0.098, 140)
rects.append(rect)
pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
axes[0].add_collection(pc)
rects = []
rect = Rectangle((0.675, 740), 0.098, 140)
rects.append(rect)
rect = Rectangle((0.575, 740), 0.098, 140)
rects.append(rect)
pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
axes[1].add_collection(pc)
rects = []
rect = Rectangle((0.57, 740), 0.098, 140)
rects.append(rect)
pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
axes[2].add_collection(pc)
con = ConnectionPatch(xyA=(0.625, 735), xyB=(0.625, 740), coordsA="data", coordsB="data",
axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=.35")
axes[1].add_artist(con)
con = ConnectionPatch(xyA=(0.725, 885), xyB=(0.725, 880), coordsA="data", coordsB="data",
axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=-.25")
axes[1].add_artist(con)
con = ConnectionPatch(xyA=(0.725, 735), xyB=(0.625, 740), coordsA="data", coordsB="data",
axesA=axes[1], axesB=axes[2], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=.35")
axes[1].add_artist(con)
axes[0].text(1., 660, "2.")
axes[1].text(1.05, 660, "3.")
axes[0].text(1.1, 890, "1.")
fig.subplots_adjust(bottom=0.1, top=0.8, left=0.1, right=0.9)
fig.savefig(os.path.join(figure_folder, "comparisons.pdf"))
plt.close()
nf.close()
def create_response_plot(filename, current_df=20, figure_name=None):
files = sorted(glob.glob(os.path.join(data_folder, "*.nix")))
if len(files) < 1:
print("plot comparisons: no data found!")
return
filename = files[0]
nf = nix.File.open(filename, nix.FileMode.ReadOnly)
block_map, all_contrasts, _, all_conditions = sort_blocks(nf)
conditions = ["no-other", "self", "other"]
condition_labels = ["soliloquy", "self chirping", "other chirping"]
min_time = 0.5
max_time = min_time + 0.5
fig = plt.figure(figsize=(6.5, 5.5))
fig_grid = (len(all_contrasts)*2 + 6, len(all_conditions)*3+2)
all_contrasts = sorted(all_contrasts, reverse=True)
for i, condition in enumerate(conditions):
# plot the signals
block = block_map[(all_contrasts[0], current_df, condition)]
signal, self_freq, other_freq, time = get_signals(block)
am = extract_am(signal)
self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"]
other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"]
# plot frequency traces
ax = plt.subplot2grid(fig_grid, (0, i * 3 + i), rowspan=2, colspan=3, fig=fig)
ax.plot(time[(time > min_time) & (time < max_time)], self_freq[(time > min_time) & (time < max_time)],
color="#ff7f0e", label="%iHz" % self_eodf)
ax.text(min_time-0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e", va="center", ha="right", fontsize=9)
if other_freq is not None:
ax.plot(time[(time > min_time) & (time < max_time)], other_freq[(time > min_time) & (time < max_time)],
color="#1f77b4", label="%iHz" % other_eodf)
ax.text(min_time-0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4", va="center", ha="right", fontsize=9)
ax.set_title(condition_labels[i])
despine(ax, ["top", "bottom", "left", "right"], True)
# plot the am
ax = plt.subplot2grid(fig_grid, (3, i * 3 + i), rowspan=2, colspan=3, fig=fig)
ax.plot(time[(time > min_time) & (time < max_time)], signal[(time > min_time) & (time < max_time)],
color="#2ca02c", label="signal")
ax.plot(time[(time > min_time) & (time < max_time)], am[(time > min_time) & (time < max_time)],
color="#d62728", label="am")
despine(ax, ["top", "bottom", "left", "right"], True)
ax.set_ylim([-1.25, 1.25])
ax.legend(ncol=2, loc=(0.01, -0.5), fontsize=7, markerscale=0.5, frameon=False)
# for each contrast plot the firing rate
for j, contrast in enumerate(all_contrasts):
t, rates, _ = get_firing_rate(block_map, current_df, contrast, condition)
avg_resp = np.mean(rates, axis=0)
error = np.std(rates, axis=0)
ax = plt.subplot2grid(fig_grid, (j*2 + 6, i * 3 + i), rowspan=2, colspan=3)
ax.plot(t[(t > min_time) & (t < max_time)], avg_resp[(t > min_time) & (t < max_time)], color="k", lw=0.5)
ax.fill_between(t[(t > min_time) & (t < max_time)], (avg_resp - error)[(t > min_time) & (t < max_time)],
(avg_resp + error)[(t > min_time) & (t < max_time)], color="k", lw=0.0, alpha=0.25)
ax.set_ylim([0, 750])
ax.set_xlabel("")
ax.set_ylabel("")
ax.set_xticks(np.arange(min_time, max_time+.01, 0.250))
ax.set_xticklabels(map(int, (np.arange(min_time, max_time + .01, 0.250) - min_time) * 1000))
ax.set_xticks(np.arange(min_time, max_time+.01, 0.125), minor=True)
if j < len(all_contrasts) -1:
ax.set_xticklabels([])
ax.set_yticks(np.arange(0.0, 751., 500))
ax.set_yticks(np.arange(0.0, 751., 125), minor=True)
if i > 0:
ax.set_yticklabels([])
despine(ax, ["top", "right"], False)
if i == 2:
ax.text(max_time + 0.025*max_time, 350, "c=%.3f" % all_contrasts[j],
color="#d62728", ha="left", fontsize=7)
if i == 1:
ax.set_xlabel("time [ms]")
if i == 0:
ax.set_ylabel("frequency [Hz]", va="center")
ax.yaxis.set_label_coords(-0.45, 3.5)
name = figure_name if figure_name is not None else "chirp_responses.pdf"
name = (name + ".pdf") if ".pdf" not in name else name
plt.savefig(os.path.join(figure_folder, name))
plt.close()
nf.close()
def response_examples(*kwargs):
if filename in kwargs:
#fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf"
#create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name)
#fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf"
#create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name)
def main(task=None, parameter={}):
plot_tasks = {"comparisons": plot_comparisons,
"response_examples": create_response_plot}
if task is not None and task in plot_tasks.keys():
plot_tasks[task](*parameter)
elif task is None:
for t in plot_tasks.keys():
plot_tasks[t](*parameter)
if __name__ == "__main__":
main("comparisons")
def plot_examples(filename, dfs=[], contrasts=[], conditions=[]):
# plot the responses
#fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf"
#create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name)
#fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf"
#create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name)
# sketch showing the comparisons
#plot_comparisons(block_map, all_dfs, all_contrasts, all_conditions, 20)
# plot the discrimination analyses
#cell_name = filename.split(os.path.sep)[-1].split(".nix")[0]
# results = foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=20,
# cell_name=cell_name, store_roc=True)
# pdf = pd.DataFrame(results)
# plot_detection_results(pdf, 20, 0.001, cell_name)
#nf.close()
pass

View File

@ -7,6 +7,8 @@ from chirp_ams import get_signals
from model import simulate, load_models
from IPython import embed
import matplotlib.pyplot as plt
import multiprocessing
from joblib import Parallel, delayed
data_folder = "data"
@ -181,8 +183,7 @@ def simulate_responses(stimulus_params, model_params, repeats=10, deltaf=20):
print("\n")
def main():
models = load_models("models.csv")
def simulate_cell(cell_id, models):
deltafs = [-200, -100, -50, -20, -10, -5, 5, 10, 20, 50, 100, 200] # Hz, difference frequency between self and other
stimulus_params = { "eodfs": {"self": 0.0, "other": 0.0}, # eod frequency in Hz, to be overwritten
"contrasts": [20, 10, 5, 2.5, 1.25, 0.625, 0.3125],
@ -192,26 +193,30 @@ def main():
"chirp_frequency": 5, # Hz, how often does the fish chirp
"duration": 5., # s, total duration of simulation
"dt": 1, # s, stepsize of the simulation, to be overwritten
}
for cell_id in range(len(models)):
model_params = models[cell_id]
baseline_spikes = get_baseline_response(model_params, duration=30)
save_baseline_response( "cell_%s.nix" % model_params["cell"], "baseline response", baseline_spikes, model_params)
print("Cell: %s" % model_params["cell"])
for deltaf in deltafs:
stimulus_params["eodfs"] = {"self": model_params["EODf"], "other": model_params["EODf"] + deltaf}
stimulus_params["dt"] = model_params["deltat"]
print("\t Deltaf: %i" % deltaf)
chirp_times = np.arange(stimulus_params["chirp_duration"],
stimulus_params["duration"] - stimulus_params["chirp_duration"],
1./stimulus_params["chirp_frequency"])
stimulus_params["chirp_times"] = chirp_times
simulate_responses(stimulus_params, model_params, repeats=25, deltaf=deltaf)
if cell_id == 9:
exit() # the first 10 cell only for now!
}
model_params = models[cell_id]
baseline_spikes = get_baseline_response(model_params, duration=30)
filename = os.path.join(data_folder, "cell_%s.nix" % model_params["cell"])
save_baseline_response(filename, "baseline response", baseline_spikes, model_params)
print("Cell: %s" % model_params["cell"])
for deltaf in deltafs:
stimulus_params["eodfs"] = {"self": model_params["EODf"], "other": model_params["EODf"] + deltaf}
stimulus_params["dt"] = model_params["deltat"]
print("\t Deltaf: %i" % deltaf)
chirp_times = np.arange(stimulus_params["chirp_duration"],
stimulus_params["duration"] - stimulus_params["chirp_duration"],
1./stimulus_params["chirp_frequency"])
stimulus_params["chirp_times"] = chirp_times
simulate_responses(stimulus_params, model_params, repeats=25, deltaf=deltaf)
def main():
models = load_models("models.csv")
num_cores = multiprocessing.cpu_count() - 6
Parallel(n_jobs=num_cores)(delayed(simulate_cell)(cell_id, models) for cell_id in range(len(models[:10])))
if __name__ == "__main__":

View File

@ -4,71 +4,17 @@ import pandas as pd
import nixio as nix
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection
from matplotlib.patches import ConnectionPatch
from sklearn.metrics import roc_curve, roc_auc_score
from IPython import embed
import multiprocessing
from joblib import Parallel, delayed
from util import firing_rate, despine, extract_am, within_group_distance, across_group_distance
from nix_util import read_baseline, sort_blocks, get_spikes, get_signals, get_chirp_metadata
figure_folder = "figures"
data_folder = "data"
def read_baseline(block):
spikes = []
if "baseline" not in block.name:
print("Block %s does not appear to be a baseline block!" % block.name )
return spikes
spikes = block.data_arrays[0][:]
return spikes
def sort_blocks(nix_file):
block_map = {}
contrasts = []
deltafs = []
conditions = []
for b in nix_file.blocks:
if "baseline" not in b.name.lower():
name_parts = b.name.split("_")
cntrst = float(name_parts[1])
if cntrst not in contrasts:
contrasts.append(cntrst)
cndtn = name_parts[3]
if cndtn not in conditions:
conditions.append(cndtn)
dltf = float(name_parts[5])
if dltf not in deltafs:
deltafs.append(dltf)
block_map[(cntrst, dltf, cndtn)] = b
else:
block_map["baseline"] = b
return block_map, contrasts, deltafs, conditions
def get_spikes(block):
"""Get the spike trains.
Args:
block ([type]): [description]
Returns:
list of np.ndarray: the spike trains.
"""
response_map = {}
spikes = []
for da in block.data_arrays:
if "spike_times" in da.type and "response" in da.name:
resp_id = int(da.name.split("_")[-1])
response_map[resp_id] = da
for k in sorted(response_map.keys()):
spikes.append(response_map[k][:])
return spikes
def get_rates(spike_trains, duration, dt, kernel_width):
"""Convert the spike trains (list of spike_times) to rates using a Gaussian kernel of the given size.
@ -114,126 +60,7 @@ def get_firing_rate(block_map, df, contrast, condition, kernel_width=0.0005):
return time, rates, spikes
def get_signals(block):
"""Read the fish signals from block.
Args:
block ([type]): the block containing the data for a given df, contrast and condition
Raises:
ValueError: when the complete stimulus data is not found
ValueError: when the no-other animal data is not found
Returns:
np.ndarray: the complete signal
np.ndarray: the frequency profile of the recorded fish
np.ndarray: the frequency profile of the other fish
np.ndarray: the time axis
"""
self_freq = None
other_freq = None
signal = None
time = None
if "complete stimulus" not in block.data_arrays or "self frequency" not in block.data_arrays:
raise ValueError("Signals not stored in block!")
if "no-other" not in block.name and "other frequency" not in block.data_arrays:
raise ValueError("Signals not stored in block!")
signal = block.data_arrays["complete stimulus"][:]
time = np.asarray(block.data_arrays["complete stimulus"].dimensions[0].axis(len(signal)))
self_freq = block.data_arrays["self frequency"][:]
if "no-other" not in block.name:
other_freq = block.data_arrays["other frequency"][:]
return signal, self_freq, other_freq, time
def create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, current_df, figure_name=None):
conditions = ["no-other", "self", "other"]
condition_labels = ["soliloquy", "self chirping", "other chirping"]
min_time = 0.5
max_time = min_time + 0.5
fig = plt.figure(figsize=(6.5, 5.5))
fig_grid = (len(all_contrasts)*2 + 6, len(all_conditions)*3+2)
all_contrasts = sorted(all_contrasts, reverse=True)
for i, condition in enumerate(conditions):
# plot the signals
block = block_map[(all_contrasts[0], current_df, condition)]
signal, self_freq, other_freq, time = get_signals(block)
am = extract_am(signal)
self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"]
other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"]
# plot frequency traces
ax = plt.subplot2grid(fig_grid, (0, i * 3 + i), rowspan=2, colspan=3, fig=fig)
ax.plot(time[(time > min_time) & (time < max_time)], self_freq[(time > min_time) & (time < max_time)],
color="#ff7f0e", label="%iHz" % self_eodf)
ax.text(min_time-0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e", va="center", ha="right", fontsize=9)
if other_freq is not None:
ax.plot(time[(time > min_time) & (time < max_time)], other_freq[(time > min_time) & (time < max_time)],
color="#1f77b4", label="%iHz" % other_eodf)
ax.text(min_time-0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4", va="center", ha="right", fontsize=9)
ax.set_title(condition_labels[i])
despine(ax, ["top", "bottom", "left", "right"], True)
# plot the am
ax = plt.subplot2grid(fig_grid, (3, i * 3 + i), rowspan=2, colspan=3, fig=fig)
ax.plot(time[(time > min_time) & (time < max_time)], signal[(time > min_time) & (time < max_time)],
color="#2ca02c", label="signal")
ax.plot(time[(time > min_time) & (time < max_time)], am[(time > min_time) & (time < max_time)],
color="#d62728", label="am")
despine(ax, ["top", "bottom", "left", "right"], True)
ax.set_ylim([-1.25, 1.25])
ax.legend(ncol=2, loc=(0.01, -0.5), fontsize=7, markerscale=0.5, frameon=False)
# for each contrast plot the firing rate
for j, contrast in enumerate(all_contrasts):
t, rates, _ = get_firing_rate(block_map, current_df, contrast, condition)
avg_resp = np.mean(rates, axis=0)
error = np.std(rates, axis=0)
ax = plt.subplot2grid(fig_grid, (j*2 + 6, i * 3 + i), rowspan=2, colspan=3)
ax.plot(t[(t > min_time) & (t < max_time)], avg_resp[(t > min_time) & (t < max_time)], color="k", lw=0.5)
ax.fill_between(t[(t > min_time) & (t < max_time)], (avg_resp - error)[(t > min_time) & (t < max_time)],
(avg_resp + error)[(t > min_time) & (t < max_time)], color="k", lw=0.0, alpha=0.25)
ax.set_ylim([0, 750])
ax.set_xlabel("")
ax.set_ylabel("")
ax.set_xticks(np.arange(min_time, max_time+.01, 0.250))
ax.set_xticklabels(map(int, (np.arange(min_time, max_time + .01, 0.250) - min_time) * 1000))
ax.set_xticks(np.arange(min_time, max_time+.01, 0.125), minor=True)
if j < len(all_contrasts) -1:
ax.set_xticklabels([])
ax.set_yticks(np.arange(0.0, 751., 500))
ax.set_yticks(np.arange(0.0, 751., 125), minor=True)
if i > 0:
ax.set_yticklabels([])
despine(ax, ["top", "right"], False)
if i == 2:
ax.text(max_time + 0.025*max_time, 350, "c=%.3f" % all_contrasts[j],
color="#d62728", ha="left", fontsize=7)
if i == 1:
ax.set_xlabel("time [ms]")
if i == 0:
ax.set_ylabel("frequency [Hz]", va="center")
ax.yaxis.set_label_coords(-0.45, 3.5)
name = figure_name if figure_name is not None else "chirp_responses.pdf"
name = (name + ".pdf") if ".pdf" not in name else name
plt.savefig(os.path.join(figure_folder, name))
plt.close()
def get_chirp_metadata(block):
trial_duration = float(block.metadata["stimulus parameter"]["duration"])
dt = float(block.metadata["stimulus parameter"]["dt"])
chirp_duration = block.metadata["stimulus parameter"]["chirp_duration"]
chirp_size = block.metadata["stimulus parameter"]["chirp_size"]
chirp_times = block.metadata["stimulus parameter"]["chirp_times"]
return trial_duration, dt, chirp_size, chirp_duration, chirp_times
def foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005, cell_name="", store_roc=False):
@ -436,79 +263,6 @@ def plot_detection_results(data_frame, df, kernel_width, cell, figure_name=None)
fig.savefig(os.path.join(figure_folder, name))
def plot_comparisons(block_map, all_dfs, all_contrasts, all_conditions, current_df):
conditions = ["no-other", "self", "other"]
condition_labels = ["soliloquy", "self chirping", "other chirping"]
min_time = 0.5
max_time = min_time + 0.5
fig = plt.figure(figsize=(6.5, 2.))
fig_grid = (3, len(all_conditions)*3+2)
axes = []
for i, condition in enumerate(conditions):
# plot the signals
block = block_map[(all_contrasts[0], current_df, condition)]
signal, self_freq, other_freq, time = get_signals(block)
self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"]
other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"]
# plot frequency traces
ax = plt.subplot2grid(fig_grid, (0, i * 3 + i), rowspan=2, colspan=3, fig=fig)
ax.plot(time[(time > min_time) & (time < max_time)], self_freq[(time > min_time) & (time < max_time)],
color="#ff7f0e", label="%iHz" % self_eodf)
ax.text(min_time-0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e", va="center", ha="right", fontsize=9)
if other_freq is not None:
ax.plot(time[(time > min_time) & (time < max_time)], other_freq[(time > min_time) & (time < max_time)],
color="#1f77b4", label="%iHz" % other_eodf)
ax.text(min_time-0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4", va="center", ha="right", fontsize=9)
# ax.set_title(condition_labels[i])
ax.set_ylim([735, 885])
despine(ax, ["top", "bottom", "left", "right"], True)
axes.append(ax)
rects = []
rect = Rectangle((0.675, 740), 0.098, 140)
rects.append(rect)
rect = Rectangle((0.57, 740), 0.098, 140)
rects.append(rect)
pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
axes[0].add_collection(pc)
rects = []
rect = Rectangle((0.675, 740), 0.098, 140)
rects.append(rect)
rect = Rectangle((0.575, 740), 0.098, 140)
rects.append(rect)
pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
axes[1].add_collection(pc)
rects = []
rect = Rectangle((0.57, 740), 0.098, 140)
rects.append(rect)
pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
axes[2].add_collection(pc)
con = ConnectionPatch(xyA=(0.625, 735), xyB=(0.625, 740), coordsA="data", coordsB="data",
axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=.35")
axes[1].add_artist(con)
con = ConnectionPatch(xyA=(0.725, 885), xyB=(0.725, 880), coordsA="data", coordsB="data",
axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=-.25")
axes[1].add_artist(con)
con = ConnectionPatch(xyA=(0.725, 735), xyB=(0.625, 740), coordsA="data", coordsB="data",
axesA=axes[1], axesB=axes[2], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=.35")
axes[1].add_artist(con)
axes[0].text(1., 660, "2.")
axes[1].text(1.05, 660, "3.")
axes[0].text(1.1, 890, "1.")
fig.subplots_adjust(bottom=0.1, top=0.8, left=0.1, right=0.9)
fig.savefig(os.path.join(figure_folder, "comparisons.pdf"))
plt.close()
def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, cell_name="", store_roc=False):
dfs = [current_df] if current_df is not None else all_dfs
kernels = [0.00025, 0.0005, 0.001, 0.0025]
@ -530,6 +284,7 @@ def estimate_chirp_phase(am, chirp_times):
def process_cell(filename, dfs=[], contrasts=[], conditions=[]):
print(filename)
nf = nix.File.open(filename, nix.FileMode.ReadOnly)
block_map, all_contrasts, all_dfs, all_conditions = sort_blocks(nf)
if "baseline" in block_map.keys():
@ -542,41 +297,16 @@ def process_cell(filename, dfs=[], contrasts=[], conditions=[]):
return results
def plot_examples(filename, dfs=[], contrasts=[], conditions=[]):
nf = nix.File.open(filename, nix.FileMode.ReadOnly)
block_map, all_contrasts, all_dfs, all_conditions = sort_blocks(nf)
if "baseline" in block_map.keys():
baseline_spikes = read_baseline(block_map["baseline"])
else:
print("ERROR: no baseline data for file %s!" % filename)
# plot the responses
#fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf"
#create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name)
#fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf"
#create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name)
# sketch showing the comparisons
# plot_comparisons(block_map, all_dfs, all_contrasts, all_conditions, 20)
# plot the discrimination analyses
#cell_name = filename.split(os.path.sep)[-1].split(".nix")[0]
# results = foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=20,
# cell_name=cell_name, store_roc=True)
# pdf = pd.DataFrame(results)
# plot_detection_results(pdf, 20, 0.001, cell_name)
nf.close()
def main():
num_cores = multiprocessing.cpu_count() - 6
nix_files = sorted(glob.glob(os.path.join(data_folder, "cell*.nix")))
for nix_file in nix_files:
#plot_examples(nix_file, dfs=[20], contrasts=[20], conditions=["self"])
results = process_cell(nix_file, dfs=[], contrasts=[20], conditions=["self"])
# break
embed()
processed_list = Parallel(n_jobs=num_cores)(delayed(process_cell)(nix_file) for nix_file in nix_files)
results = []
for pr in processed_list:
results.extend(pr)
df = pd.DataFrame(results)
df.to_csv(os.path.join(data_folder, "discimination_results.csv"), sep=";")
if __name__ == "__main__":
main()