From d0bde3b673880997a9912c82448ccba4ef228eec Mon Sep 17 00:00:00 2001 From: Jan Grewe Date: Fri, 25 Sep 2020 17:22:55 +0200 Subject: [PATCH] separate plotting and analysis --- chirps_as_probing_signals.md | 5 +- nix_util.py | 98 ++++++++++++ plots.py | 224 ++++++++++++++++++++++++++ punit_responses.py | 47 +++--- response_discriminability.py | 296 ++--------------------------------- 5 files changed, 364 insertions(+), 306 deletions(-) create mode 100644 nix_util.py create mode 100644 plots.py diff --git a/chirps_as_probing_signals.md b/chirps_as_probing_signals.md index 2ae2212..2869b82 100644 --- a/chirps_as_probing_signals.md +++ b/chirps_as_probing_signals.md @@ -45,11 +45,12 @@ Won't do, this is trivial?! * calculate the discriminability between the baseline (no-other fish present) and the another fish is present for each contrast * Work out the difference between the soliloquy and the response to self generated chirp in a communication context -> done * Compare to the beat alone parts of the responses. -> done -* What kernels to use? +* What kernels to use? -> done * Duration of the chrip window? * sorting according to phase? -* we could filter the P-unit responses to model the ELL filering +* we could filter the P-unit responses to model the ELL filtering +### 4 plot discrimination results ## Random thoughts * who is sending the chrips? Henninger and also Hupe illustrate the subordinant fish is chirping. diff --git a/nix_util.py b/nix_util.py new file mode 100644 index 0000000..e212413 --- /dev/null +++ b/nix_util.py @@ -0,0 +1,98 @@ +import nixio as nix +import numpy as np + +def read_baseline(block): + spikes = [] + if "baseline" not in block.name: + print("Block %s does not appear to be a baseline block!" % block.name ) + return spikes + spikes = block.data_arrays[0][:] + return spikes + + +def sort_blocks(nix_file): + block_map = {} + contrasts = [] + deltafs = [] + conditions = [] + for b in nix_file.blocks: + if "baseline" not in b.name.lower(): + name_parts = b.name.split("_") + cntrst = float(name_parts[1]) + if cntrst not in contrasts: + contrasts.append(cntrst) + cndtn = name_parts[3] + if cndtn not in conditions: + conditions.append(cndtn) + dltf = float(name_parts[5]) + if dltf not in deltafs: + deltafs.append(dltf) + block_map[(cntrst, dltf, cndtn)] = b + else: + block_map["baseline"] = b + return block_map, contrasts, deltafs, conditions + + +def get_spikes(block): + """Get the spike trains. + + Args: + block ([type]): [description] + + Returns: + list of np.ndarray: the spike trains. + """ + response_map = {} + spikes = [] + + for da in block.data_arrays: + if "spike_times" in da.type and "response" in da.name: + resp_id = int(da.name.split("_")[-1]) + response_map[resp_id] = da + for k in sorted(response_map.keys()): + spikes.append(response_map[k][:]) + + return spikes + + +def get_signals(block): + """Read the fish signals from block. + + Args: + block ([type]): the block containing the data for a given df, contrast and condition + + Raises: + ValueError: when the complete stimulus data is not found + ValueError: when the no-other animal data is not found + + Returns: + np.ndarray: the complete signal + np.ndarray: the frequency profile of the recorded fish + np.ndarray: the frequency profile of the other fish + np.ndarray: the time axis + """ + self_freq = None + other_freq = None + signal = None + time = None + if "complete stimulus" not in block.data_arrays or "self frequency" not in block.data_arrays: + raise ValueError("Signals not stored in block!") + if "no-other" not in block.name and "other frequency" not in block.data_arrays: + raise ValueError("Signals not stored in block!") + + signal = block.data_arrays["complete stimulus"][:] + time = np.asarray(block.data_arrays["complete stimulus"].dimensions[0].axis(len(signal))) + self_freq = block.data_arrays["self frequency"][:] + if "no-other" not in block.name: + other_freq = block.data_arrays["other frequency"][:] + return signal, self_freq, other_freq, time + + +def get_chirp_metadata(block): + trial_duration = float(block.metadata["stimulus parameter"]["duration"]) + dt = float(block.metadata["stimulus parameter"]["dt"]) + chirp_duration = block.metadata["stimulus parameter"]["chirp_duration"] + chirp_size = block.metadata["stimulus parameter"]["chirp_size"] + chirp_times = block.metadata["stimulus parameter"]["chirp_times"] + + return trial_duration, dt, chirp_size, chirp_duration, chirp_times diff --git a/plots.py b/plots.py new file mode 100644 index 0000000..eb5ace4 --- /dev/null +++ b/plots.py @@ -0,0 +1,224 @@ +import glob +import os +import nixio as nix +import matplotlib.pyplot as plt +from matplotlib.patches import Rectangle +from matplotlib.collections import PatchCollection +from matplotlib.patches import ConnectionPatch + +from nix_util import sort_blocks, read_baseline, get_signals +from util import despine + +figure_folder = "figures" +data_folder = "data" + + +def plot_comparisons(current_df=20): + files = sorted(glob.glob(os.path.join(data_folder, "*.nix"))) + if len(files) < 1: + print("plot comparisons: no data found!") + return + filename = files[0] + nf = nix.File.open(filename, nix.FileMode.ReadOnly) + block_map, all_contrasts, _, all_conditions = sort_blocks(nf) + conditions = ["no-other", "self", "other"] + min_time = 0.5 + max_time = min_time + 0.5 + + fig = plt.figure(figsize=(6.5, 2.)) + fig_grid = (3, len(all_conditions)*3+2) + axes = [] + for i, condition in enumerate(conditions): + # plot the signals + block = block_map[(all_contrasts[0], current_df, condition)] + _, self_freq, other_freq, time = get_signals(block) + + self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"] + other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"] + + # plot frequency traces + ax = plt.subplot2grid(fig_grid, (0, i * 3 + i), rowspan=2, colspan=3, fig=fig) + ax.plot(time[(time > min_time) & (time < max_time)], self_freq[(time > min_time) & (time < max_time)], + color="#ff7f0e", label="%iHz" % self_eodf) + ax.text(min_time-0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e", va="center", ha="right", fontsize=9) + if other_freq is not None: + ax.plot(time[(time > min_time) & (time < max_time)], other_freq[(time > min_time) & (time < max_time)], + color="#1f77b4", label="%iHz" % other_eodf) + ax.text(min_time-0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4", va="center", ha="right", fontsize=9) + # ax.set_title(condition_labels[i]) + ax.set_ylim([735, 885]) + despine(ax, ["top", "bottom", "left", "right"], True) + axes.append(ax) + + rects = [] + rect = Rectangle((0.675, 740), 0.098, 140) + rects.append(rect) + rect = Rectangle((0.57, 740), 0.098, 140) + rects.append(rect) + + pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--") + axes[0].add_collection(pc) + + rects = [] + rect = Rectangle((0.675, 740), 0.098, 140) + rects.append(rect) + rect = Rectangle((0.575, 740), 0.098, 140) + rects.append(rect) + + pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--") + axes[1].add_collection(pc) + + rects = [] + rect = Rectangle((0.57, 740), 0.098, 140) + rects.append(rect) + pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--") + axes[2].add_collection(pc) + + con = ConnectionPatch(xyA=(0.625, 735), xyB=(0.625, 740), coordsA="data", coordsB="data", + axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=.35") + axes[1].add_artist(con) + con = ConnectionPatch(xyA=(0.725, 885), xyB=(0.725, 880), coordsA="data", coordsB="data", + axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=-.25") + axes[1].add_artist(con) + con = ConnectionPatch(xyA=(0.725, 735), xyB=(0.625, 740), coordsA="data", coordsB="data", + axesA=axes[1], axesB=axes[2], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=.35") + axes[1].add_artist(con) + + axes[0].text(1., 660, "2.") + axes[1].text(1.05, 660, "3.") + axes[0].text(1.1, 890, "1.") + fig.subplots_adjust(bottom=0.1, top=0.8, left=0.1, right=0.9) + fig.savefig(os.path.join(figure_folder, "comparisons.pdf")) + plt.close() + nf.close() + + +def create_response_plot(filename, current_df=20, figure_name=None): + files = sorted(glob.glob(os.path.join(data_folder, "*.nix"))) + if len(files) < 1: + print("plot comparisons: no data found!") + return + filename = files[0] + nf = nix.File.open(filename, nix.FileMode.ReadOnly) + block_map, all_contrasts, _, all_conditions = sort_blocks(nf) + + conditions = ["no-other", "self", "other"] + condition_labels = ["soliloquy", "self chirping", "other chirping"] + min_time = 0.5 + max_time = min_time + 0.5 + + fig = plt.figure(figsize=(6.5, 5.5)) + fig_grid = (len(all_contrasts)*2 + 6, len(all_conditions)*3+2) + all_contrasts = sorted(all_contrasts, reverse=True) + + for i, condition in enumerate(conditions): + # plot the signals + block = block_map[(all_contrasts[0], current_df, condition)] + signal, self_freq, other_freq, time = get_signals(block) + am = extract_am(signal) + + self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"] + other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"] + + # plot frequency traces + ax = plt.subplot2grid(fig_grid, (0, i * 3 + i), rowspan=2, colspan=3, fig=fig) + ax.plot(time[(time > min_time) & (time < max_time)], self_freq[(time > min_time) & (time < max_time)], + color="#ff7f0e", label="%iHz" % self_eodf) + ax.text(min_time-0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e", va="center", ha="right", fontsize=9) + if other_freq is not None: + ax.plot(time[(time > min_time) & (time < max_time)], other_freq[(time > min_time) & (time < max_time)], + color="#1f77b4", label="%iHz" % other_eodf) + ax.text(min_time-0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4", va="center", ha="right", fontsize=9) + ax.set_title(condition_labels[i]) + despine(ax, ["top", "bottom", "left", "right"], True) + + # plot the am + ax = plt.subplot2grid(fig_grid, (3, i * 3 + i), rowspan=2, colspan=3, fig=fig) + ax.plot(time[(time > min_time) & (time < max_time)], signal[(time > min_time) & (time < max_time)], + color="#2ca02c", label="signal") + ax.plot(time[(time > min_time) & (time < max_time)], am[(time > min_time) & (time < max_time)], + color="#d62728", label="am") + despine(ax, ["top", "bottom", "left", "right"], True) + ax.set_ylim([-1.25, 1.25]) + ax.legend(ncol=2, loc=(0.01, -0.5), fontsize=7, markerscale=0.5, frameon=False) + + # for each contrast plot the firing rate + for j, contrast in enumerate(all_contrasts): + t, rates, _ = get_firing_rate(block_map, current_df, contrast, condition) + avg_resp = np.mean(rates, axis=0) + error = np.std(rates, axis=0) + ax = plt.subplot2grid(fig_grid, (j*2 + 6, i * 3 + i), rowspan=2, colspan=3) + ax.plot(t[(t > min_time) & (t < max_time)], avg_resp[(t > min_time) & (t < max_time)], color="k", lw=0.5) + ax.fill_between(t[(t > min_time) & (t < max_time)], (avg_resp - error)[(t > min_time) & (t < max_time)], + (avg_resp + error)[(t > min_time) & (t < max_time)], color="k", lw=0.0, alpha=0.25) + ax.set_ylim([0, 750]) + ax.set_xlabel("") + ax.set_ylabel("") + ax.set_xticks(np.arange(min_time, max_time+.01, 0.250)) + ax.set_xticklabels(map(int, (np.arange(min_time, max_time + .01, 0.250) - min_time) * 1000)) + ax.set_xticks(np.arange(min_time, max_time+.01, 0.125), minor=True) + if j < len(all_contrasts) -1: + ax.set_xticklabels([]) + ax.set_yticks(np.arange(0.0, 751., 500)) + ax.set_yticks(np.arange(0.0, 751., 125), minor=True) + if i > 0: + ax.set_yticklabels([]) + despine(ax, ["top", "right"], False) + if i == 2: + ax.text(max_time + 0.025*max_time, 350, "c=%.3f" % all_contrasts[j], + color="#d62728", ha="left", fontsize=7) + + if i == 1: + ax.set_xlabel("time [ms]") + if i == 0: + ax.set_ylabel("frequency [Hz]", va="center") + ax.yaxis.set_label_coords(-0.45, 3.5) + + name = figure_name if figure_name is not None else "chirp_responses.pdf" + name = (name + ".pdf") if ".pdf" not in name else name + plt.savefig(os.path.join(figure_folder, name)) + plt.close() + nf.close() + +def response_examples(*kwargs): + + if filename in kwargs: + + #fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf" + #create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name) + #fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf" + #create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name) + +def main(task=None, parameter={}): + plot_tasks = {"comparisons": plot_comparisons, + "response_examples": create_response_plot} + if task is not None and task in plot_tasks.keys(): + plot_tasks[task](*parameter) + elif task is None: + for t in plot_tasks.keys(): + plot_tasks[t](*parameter) + + +if __name__ == "__main__": + main("comparisons") + + +def plot_examples(filename, dfs=[], contrasts=[], conditions=[]): + # plot the responses + #fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf" + #create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name) + #fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf" + #create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name) + + # sketch showing the comparisons + #plot_comparisons(block_map, all_dfs, all_contrasts, all_conditions, 20) + + # plot the discrimination analyses + #cell_name = filename.split(os.path.sep)[-1].split(".nix")[0] + # results = foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=20, + # cell_name=cell_name, store_roc=True) + # pdf = pd.DataFrame(results) + # plot_detection_results(pdf, 20, 0.001, cell_name) + + #nf.close() + pass \ No newline at end of file diff --git a/punit_responses.py b/punit_responses.py index 090cd1e..698ae13 100644 --- a/punit_responses.py +++ b/punit_responses.py @@ -7,6 +7,8 @@ from chirp_ams import get_signals from model import simulate, load_models from IPython import embed import matplotlib.pyplot as plt +import multiprocessing +from joblib import Parallel, delayed data_folder = "data" @@ -181,8 +183,7 @@ def simulate_responses(stimulus_params, model_params, repeats=10, deltaf=20): print("\n") -def main(): - models = load_models("models.csv") +def simulate_cell(cell_id, models): deltafs = [-200, -100, -50, -20, -10, -5, 5, 10, 20, 50, 100, 200] # Hz, difference frequency between self and other stimulus_params = { "eodfs": {"self": 0.0, "other": 0.0}, # eod frequency in Hz, to be overwritten "contrasts": [20, 10, 5, 2.5, 1.25, 0.625, 0.3125], @@ -192,26 +193,30 @@ def main(): "chirp_frequency": 5, # Hz, how often does the fish chirp "duration": 5., # s, total duration of simulation "dt": 1, # s, stepsize of the simulation, to be overwritten - } - - for cell_id in range(len(models)): - model_params = models[cell_id] - baseline_spikes = get_baseline_response(model_params, duration=30) - save_baseline_response( "cell_%s.nix" % model_params["cell"], "baseline response", baseline_spikes, model_params) + } + model_params = models[cell_id] + baseline_spikes = get_baseline_response(model_params, duration=30) + filename = os.path.join(data_folder, "cell_%s.nix" % model_params["cell"]) + save_baseline_response(filename, "baseline response", baseline_spikes, model_params) - print("Cell: %s" % model_params["cell"]) - for deltaf in deltafs: - stimulus_params["eodfs"] = {"self": model_params["EODf"], "other": model_params["EODf"] + deltaf} - stimulus_params["dt"] = model_params["deltat"] - - print("\t Deltaf: %i" % deltaf) - chirp_times = np.arange(stimulus_params["chirp_duration"], - stimulus_params["duration"] - stimulus_params["chirp_duration"], - 1./stimulus_params["chirp_frequency"]) - stimulus_params["chirp_times"] = chirp_times - simulate_responses(stimulus_params, model_params, repeats=25, deltaf=deltaf) - if cell_id == 9: - exit() # the first 10 cell only for now! + print("Cell: %s" % model_params["cell"]) + for deltaf in deltafs: + stimulus_params["eodfs"] = {"self": model_params["EODf"], "other": model_params["EODf"] + deltaf} + stimulus_params["dt"] = model_params["deltat"] + + print("\t Deltaf: %i" % deltaf) + chirp_times = np.arange(stimulus_params["chirp_duration"], + stimulus_params["duration"] - stimulus_params["chirp_duration"], + 1./stimulus_params["chirp_frequency"]) + stimulus_params["chirp_times"] = chirp_times + simulate_responses(stimulus_params, model_params, repeats=25, deltaf=deltaf) + + +def main(): + models = load_models("models.csv") + num_cores = multiprocessing.cpu_count() - 6 + + Parallel(n_jobs=num_cores)(delayed(simulate_cell)(cell_id, models) for cell_id in range(len(models[:10]))) if __name__ == "__main__": diff --git a/response_discriminability.py b/response_discriminability.py index 01e4bff..828a0a9 100644 --- a/response_discriminability.py +++ b/response_discriminability.py @@ -4,71 +4,17 @@ import pandas as pd import nixio as nix import numpy as np import matplotlib.pyplot as plt -from matplotlib.patches import Rectangle -from matplotlib.collections import PatchCollection -from matplotlib.patches import ConnectionPatch from sklearn.metrics import roc_curve, roc_auc_score from IPython import embed +import multiprocessing +from joblib import Parallel, delayed from util import firing_rate, despine, extract_am, within_group_distance, across_group_distance +from nix_util import read_baseline, sort_blocks, get_spikes, get_signals, get_chirp_metadata + figure_folder = "figures" data_folder = "data" - -def read_baseline(block): - spikes = [] - if "baseline" not in block.name: - print("Block %s does not appear to be a baseline block!" % block.name ) - return spikes - spikes = block.data_arrays[0][:] - return spikes - - -def sort_blocks(nix_file): - block_map = {} - contrasts = [] - deltafs = [] - conditions = [] - for b in nix_file.blocks: - if "baseline" not in b.name.lower(): - name_parts = b.name.split("_") - cntrst = float(name_parts[1]) - if cntrst not in contrasts: - contrasts.append(cntrst) - cndtn = name_parts[3] - if cndtn not in conditions: - conditions.append(cndtn) - dltf = float(name_parts[5]) - if dltf not in deltafs: - deltafs.append(dltf) - block_map[(cntrst, dltf, cndtn)] = b - else: - block_map["baseline"] = b - return block_map, contrasts, deltafs, conditions - - -def get_spikes(block): - """Get the spike trains. - - Args: - block ([type]): [description] - - Returns: - list of np.ndarray: the spike trains. - """ - response_map = {} - spikes = [] - - for da in block.data_arrays: - if "spike_times" in da.type and "response" in da.name: - resp_id = int(da.name.split("_")[-1]) - response_map[resp_id] = da - for k in sorted(response_map.keys()): - spikes.append(response_map[k][:]) - - return spikes - - def get_rates(spike_trains, duration, dt, kernel_width): """Convert the spike trains (list of spike_times) to rates using a Gaussian kernel of the given size. @@ -114,126 +60,7 @@ def get_firing_rate(block_map, df, contrast, condition, kernel_width=0.0005): return time, rates, spikes -def get_signals(block): - """Read the fish signals from block. - Args: - block ([type]): the block containing the data for a given df, contrast and condition - - Raises: - ValueError: when the complete stimulus data is not found - ValueError: when the no-other animal data is not found - - Returns: - np.ndarray: the complete signal - np.ndarray: the frequency profile of the recorded fish - np.ndarray: the frequency profile of the other fish - np.ndarray: the time axis - """ - self_freq = None - other_freq = None - signal = None - time = None - if "complete stimulus" not in block.data_arrays or "self frequency" not in block.data_arrays: - raise ValueError("Signals not stored in block!") - if "no-other" not in block.name and "other frequency" not in block.data_arrays: - raise ValueError("Signals not stored in block!") - - signal = block.data_arrays["complete stimulus"][:] - time = np.asarray(block.data_arrays["complete stimulus"].dimensions[0].axis(len(signal))) - self_freq = block.data_arrays["self frequency"][:] - if "no-other" not in block.name: - other_freq = block.data_arrays["other frequency"][:] - return signal, self_freq, other_freq, time - - -def create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, current_df, figure_name=None): - conditions = ["no-other", "self", "other"] - condition_labels = ["soliloquy", "self chirping", "other chirping"] - min_time = 0.5 - max_time = min_time + 0.5 - - fig = plt.figure(figsize=(6.5, 5.5)) - fig_grid = (len(all_contrasts)*2 + 6, len(all_conditions)*3+2) - all_contrasts = sorted(all_contrasts, reverse=True) - - for i, condition in enumerate(conditions): - # plot the signals - block = block_map[(all_contrasts[0], current_df, condition)] - signal, self_freq, other_freq, time = get_signals(block) - am = extract_am(signal) - - self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"] - other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"] - - # plot frequency traces - ax = plt.subplot2grid(fig_grid, (0, i * 3 + i), rowspan=2, colspan=3, fig=fig) - ax.plot(time[(time > min_time) & (time < max_time)], self_freq[(time > min_time) & (time < max_time)], - color="#ff7f0e", label="%iHz" % self_eodf) - ax.text(min_time-0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e", va="center", ha="right", fontsize=9) - if other_freq is not None: - ax.plot(time[(time > min_time) & (time < max_time)], other_freq[(time > min_time) & (time < max_time)], - color="#1f77b4", label="%iHz" % other_eodf) - ax.text(min_time-0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4", va="center", ha="right", fontsize=9) - ax.set_title(condition_labels[i]) - despine(ax, ["top", "bottom", "left", "right"], True) - - # plot the am - ax = plt.subplot2grid(fig_grid, (3, i * 3 + i), rowspan=2, colspan=3, fig=fig) - ax.plot(time[(time > min_time) & (time < max_time)], signal[(time > min_time) & (time < max_time)], - color="#2ca02c", label="signal") - ax.plot(time[(time > min_time) & (time < max_time)], am[(time > min_time) & (time < max_time)], - color="#d62728", label="am") - despine(ax, ["top", "bottom", "left", "right"], True) - ax.set_ylim([-1.25, 1.25]) - ax.legend(ncol=2, loc=(0.01, -0.5), fontsize=7, markerscale=0.5, frameon=False) - - # for each contrast plot the firing rate - for j, contrast in enumerate(all_contrasts): - t, rates, _ = get_firing_rate(block_map, current_df, contrast, condition) - avg_resp = np.mean(rates, axis=0) - error = np.std(rates, axis=0) - ax = plt.subplot2grid(fig_grid, (j*2 + 6, i * 3 + i), rowspan=2, colspan=3) - ax.plot(t[(t > min_time) & (t < max_time)], avg_resp[(t > min_time) & (t < max_time)], color="k", lw=0.5) - ax.fill_between(t[(t > min_time) & (t < max_time)], (avg_resp - error)[(t > min_time) & (t < max_time)], - (avg_resp + error)[(t > min_time) & (t < max_time)], color="k", lw=0.0, alpha=0.25) - ax.set_ylim([0, 750]) - ax.set_xlabel("") - ax.set_ylabel("") - ax.set_xticks(np.arange(min_time, max_time+.01, 0.250)) - ax.set_xticklabels(map(int, (np.arange(min_time, max_time + .01, 0.250) - min_time) * 1000)) - ax.set_xticks(np.arange(min_time, max_time+.01, 0.125), minor=True) - if j < len(all_contrasts) -1: - ax.set_xticklabels([]) - ax.set_yticks(np.arange(0.0, 751., 500)) - ax.set_yticks(np.arange(0.0, 751., 125), minor=True) - if i > 0: - ax.set_yticklabels([]) - despine(ax, ["top", "right"], False) - if i == 2: - ax.text(max_time + 0.025*max_time, 350, "c=%.3f" % all_contrasts[j], - color="#d62728", ha="left", fontsize=7) - - if i == 1: - ax.set_xlabel("time [ms]") - if i == 0: - ax.set_ylabel("frequency [Hz]", va="center") - ax.yaxis.set_label_coords(-0.45, 3.5) - - name = figure_name if figure_name is not None else "chirp_responses.pdf" - name = (name + ".pdf") if ".pdf" not in name else name - plt.savefig(os.path.join(figure_folder, name)) - plt.close() - - -def get_chirp_metadata(block): - trial_duration = float(block.metadata["stimulus parameter"]["duration"]) - dt = float(block.metadata["stimulus parameter"]["dt"]) - chirp_duration = block.metadata["stimulus parameter"]["chirp_duration"] - chirp_size = block.metadata["stimulus parameter"]["chirp_size"] - chirp_times = block.metadata["stimulus parameter"]["chirp_times"] - - return trial_duration, dt, chirp_size, chirp_duration, chirp_times def foreign_fish_detection_beat(block_map, df, all_contrasts, all_conditions, kernel_width=0.0005, cell_name="", store_roc=False): @@ -436,79 +263,6 @@ def plot_detection_results(data_frame, df, kernel_width, cell, figure_name=None) fig.savefig(os.path.join(figure_folder, name)) -def plot_comparisons(block_map, all_dfs, all_contrasts, all_conditions, current_df): - conditions = ["no-other", "self", "other"] - condition_labels = ["soliloquy", "self chirping", "other chirping"] - min_time = 0.5 - max_time = min_time + 0.5 - - fig = plt.figure(figsize=(6.5, 2.)) - fig_grid = (3, len(all_conditions)*3+2) - axes = [] - for i, condition in enumerate(conditions): - # plot the signals - block = block_map[(all_contrasts[0], current_df, condition)] - signal, self_freq, other_freq, time = get_signals(block) - - self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"] - other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"] - - # plot frequency traces - ax = plt.subplot2grid(fig_grid, (0, i * 3 + i), rowspan=2, colspan=3, fig=fig) - ax.plot(time[(time > min_time) & (time < max_time)], self_freq[(time > min_time) & (time < max_time)], - color="#ff7f0e", label="%iHz" % self_eodf) - ax.text(min_time-0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e", va="center", ha="right", fontsize=9) - if other_freq is not None: - ax.plot(time[(time > min_time) & (time < max_time)], other_freq[(time > min_time) & (time < max_time)], - color="#1f77b4", label="%iHz" % other_eodf) - ax.text(min_time-0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4", va="center", ha="right", fontsize=9) - # ax.set_title(condition_labels[i]) - ax.set_ylim([735, 885]) - despine(ax, ["top", "bottom", "left", "right"], True) - axes.append(ax) - - rects = [] - rect = Rectangle((0.675, 740), 0.098, 140) - rects.append(rect) - rect = Rectangle((0.57, 740), 0.098, 140) - rects.append(rect) - - pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--") - axes[0].add_collection(pc) - - rects = [] - rect = Rectangle((0.675, 740), 0.098, 140) - rects.append(rect) - rect = Rectangle((0.575, 740), 0.098, 140) - rects.append(rect) - - pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--") - axes[1].add_collection(pc) - - rects = [] - rect = Rectangle((0.57, 740), 0.098, 140) - rects.append(rect) - pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--") - axes[2].add_collection(pc) - - con = ConnectionPatch(xyA=(0.625, 735), xyB=(0.625, 740), coordsA="data", coordsB="data", - axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=.35") - axes[1].add_artist(con) - con = ConnectionPatch(xyA=(0.725, 885), xyB=(0.725, 880), coordsA="data", coordsB="data", - axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=-.25") - axes[1].add_artist(con) - con = ConnectionPatch(xyA=(0.725, 735), xyB=(0.625, 740), coordsA="data", coordsB="data", - axesA=axes[1], axesB=axes[2], arrowstyle="<->", shrinkB=5, connectionstyle="arc3,rad=.35") - axes[1].add_artist(con) - - axes[0].text(1., 660, "2.") - axes[1].text(1.05, 660, "3.") - axes[0].text(1.1, 890, "1.") - fig.subplots_adjust(bottom=0.1, top=0.8, left=0.1, right=0.9) - fig.savefig(os.path.join(figure_folder, "comparisons.pdf")) - plt.close() - - def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, cell_name="", store_roc=False): dfs = [current_df] if current_df is not None else all_dfs kernels = [0.00025, 0.0005, 0.001, 0.0025] @@ -530,6 +284,7 @@ def estimate_chirp_phase(am, chirp_times): def process_cell(filename, dfs=[], contrasts=[], conditions=[]): + print(filename) nf = nix.File.open(filename, nix.FileMode.ReadOnly) block_map, all_contrasts, all_dfs, all_conditions = sort_blocks(nf) if "baseline" in block_map.keys(): @@ -542,41 +297,16 @@ def process_cell(filename, dfs=[], contrasts=[], conditions=[]): return results -def plot_examples(filename, dfs=[], contrasts=[], conditions=[]): - nf = nix.File.open(filename, nix.FileMode.ReadOnly) - block_map, all_contrasts, all_dfs, all_conditions = sort_blocks(nf) - if "baseline" in block_map.keys(): - baseline_spikes = read_baseline(block_map["baseline"]) - else: - print("ERROR: no baseline data for file %s!" % filename) - - # plot the responses - #fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf" - #create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name) - #fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf" - #create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name) - - # sketch showing the comparisons - # plot_comparisons(block_map, all_dfs, all_contrasts, all_conditions, 20) - - # plot the discrimination analyses - #cell_name = filename.split(os.path.sep)[-1].split(".nix")[0] - # results = foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=20, - # cell_name=cell_name, store_roc=True) - # pdf = pd.DataFrame(results) - # plot_detection_results(pdf, 20, 0.001, cell_name) - - nf.close() - - def main(): + num_cores = multiprocessing.cpu_count() - 6 nix_files = sorted(glob.glob(os.path.join(data_folder, "cell*.nix"))) - for nix_file in nix_files: - #plot_examples(nix_file, dfs=[20], contrasts=[20], conditions=["self"]) - results = process_cell(nix_file, dfs=[], contrasts=[20], conditions=["self"]) - # break - embed() - + + processed_list = Parallel(n_jobs=num_cores)(delayed(process_cell)(nix_file) for nix_file in nix_files) + results = [] + for pr in processed_list: + results.extend(pr) + df = pd.DataFrame(results) + df.to_csv(os.path.join(data_folder, "discimination_results.csv"), sep=";") if __name__ == "__main__": main() \ No newline at end of file