"""Plotting tool for the chirp probing project.

Every figure is produced by a sub command registered in ``main``. Input data
are read from ``data_folder``; figures are written to ``figure_folder``
unless an explicit output file is given on the command line.
"""
import glob
import os
import argparse
import nixio as nix
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection
from matplotlib.patches import ConnectionPatch
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 unused import
from IPython import embed  # noqa: F401 kept for interactive debugging
from nix_util import sort_blocks, read_baseline, get_signals  # noqa: F401
from util import despine, extract_am
from response_discriminability import get_firing_rate, foreign_fish_detection  # noqa: F401

figure_folder = "figures"
data_folder = "data"

# trace colors used for the own ("self") and the foreign ("other") fish
SELF_COLOR = "#ff7f0e"
OTHER_COLOR = "#1f77b4"


def _window_mask(time, min_time, max_time):
    """Return a boolean mask selecting samples strictly between min_time and max_time."""
    return (time > min_time) & (time < max_time)


def _plot_frequency_traces(ax, time, self_freq, other_freq, self_eodf, other_eodf, min_time, max_time):
    """Plot the self and (if present) other frequency traces inside the time window.

    Each trace is annotated with its EOD frequency slightly left of ``min_time``.
    ``other_freq`` may be ``None``; then only the self trace is drawn.
    """
    mask = _window_mask(time, min_time, max_time)
    ax.plot(time[mask], self_freq[mask], color=SELF_COLOR, label="%iHz" % self_eodf)
    ax.text(min_time - 0.05, self_eodf, "%iHz" % self_eodf, color=SELF_COLOR,
            va="center", ha="right", fontsize=9)
    if other_freq is not None:
        ax.plot(time[mask], other_freq[mask], color=OTHER_COLOR, label="%iHz" % other_eodf)
        ax.text(min_time - 0.05, other_eodf, "%iHz" % other_eodf, color=OTHER_COLOR,
                va="center", ha="right", fontsize=9)


def _add_dashed_boxes(ax, x_positions, y=740, width=0.098, height=140):
    """Add translucent rectangles with dashed outlines, one per x position, to ``ax``."""
    rects = [Rectangle((x, y), width, height) for x in x_positions]
    ax.add_collection(PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--"))


def _select_aucs(results, contrast, deltaf):
    """Return the auc column of ``results`` restricted to one contrast and one difference frequency."""
    return results.auc[(results.contrast == contrast) & (results.df == deltaf)]


def plot_comparisons(args):
    """Create the didactic figure illustrating the compared conditions.

    The frequency traces of the "no-other", "self" and "other" conditions
    (first contrast returned by ``sort_blocks``) are drawn side by side.
    Dashed boxes mark time windows and arrows connect the boxes that are
    compared with each other.

    Args:
        args: parsed command line arguments. Uses ``deltaf`` (or
            ``current_df`` if a caller provides it) as the difference frequency
            and ``outfile`` as the output path.
    """
    files = sorted(glob.glob(os.path.join(data_folder, "*.nix")))
    if len(files) < 1:
        print("plot comparisons: no data found!")
        return
    # The "comparisons" sub command stores the difference frequency in
    # args.deltaf; accept args.current_df as well for programmatic callers.
    current_df = getattr(args, "current_df", None)
    if current_df is None:
        current_df = args.deltaf

    nf = nix.File.open(files[0], nix.FileMode.ReadOnly)
    try:
        block_map, all_contrasts, _, all_conditions = sort_blocks(nf)
        conditions = ["no-other", "self", "other"]
        min_time = 0.5
        max_time = min_time + 0.5

        fig = plt.figure(figsize=(6.5, 2.))
        fig_grid = (3, len(all_conditions) * 3 + 2)
        axes = []
        for i, condition in enumerate(conditions):
            block = block_map[(all_contrasts[0], current_df, condition)]
            _, self_freq, other_freq, time = get_signals(block)
            eodfs = block.metadata["stimulus parameter"]["eodfs"]

            # each panel is 3 grid columns wide with one empty column as gap
            ax = plt.subplot2grid(fig_grid, (0, i * 4), rowspan=2, colspan=3, fig=fig)
            _plot_frequency_traces(ax, time, self_freq, other_freq, eodfs["self"], eodfs["other"],
                                   min_time, max_time)
            ax.set_ylim([735, 885])
            despine(ax, ["top", "bottom", "left", "right"], True)
            axes.append(ax)

        # dashed boxes marking the time windows of each condition
        _add_dashed_boxes(axes[0], [0.675, 0.57])
        _add_dashed_boxes(axes[1], [0.675, 0.575])
        _add_dashed_boxes(axes[2], [0.57])

        # arrows connecting compared windows: (xyA, xyB, axesA, axesB, connection style)
        connections = [((0.625, 735), (0.625, 740), axes[0], axes[1], "arc3,rad=.35"),
                       ((0.725, 885), (0.725, 880), axes[0], axes[1], "arc3,rad=-.25"),
                       ((0.725, 735), (0.625, 740), axes[1], axes[2], "arc3,rad=.35")]
        for xy_a, xy_b, ax_a, ax_b, style in connections:
            con = ConnectionPatch(xyA=xy_a, xyB=xy_b, coordsA="data", coordsB="data", axesA=ax_a, axesB=ax_b,
                                  arrowstyle="<->", shrinkB=5, connectionstyle=style)
            axes[1].add_artist(con)

        axes[0].text(1., 660, "2.")
        axes[1].text(1.05, 660, "3.")
        axes[0].text(1.1, 890, "1.")
        fig.subplots_adjust(bottom=0.1, top=0.8, left=0.1, right=0.9)
        fig.savefig(args.outfile)
        plt.close()
    finally:
        nf.close()


def create_response_plot(filename, current_df=20, figure_name=None):
    """Plot stimuli and averaged firing rates for each chirp condition.

    Columns are the conditions "no-other", "self" and "other". Rows show, for
    the largest contrast, the frequency traces and the signal with its
    extracted amplitude modulation. Below them is one panel per contrast
    (largest on top) with the trial-averaged firing rate +- one standard
    deviation.

    Args:
        filename: path of the nix file to read.
        current_df: difference frequency whose blocks are plotted.
        figure_name: output file name inside ``figure_folder``. ".pdf" is
            appended if it is missing. Defaults to "chirp_responses.pdf".
    """
    nf = nix.File.open(filename, nix.FileMode.ReadOnly)
    try:
        block_map, all_contrasts, _, all_conditions = sort_blocks(nf)
        conditions = ["no-other", "self", "other"]
        condition_labels = ["soliloquy", "self chirping", "other chirping"]
        min_time = 0.5
        max_time = min_time + 0.5

        fig = plt.figure(figsize=(6.5, 5.5))
        fig_grid = (len(all_contrasts) * 2 + 6, len(all_conditions) * 3 + 2)
        all_contrasts = sorted(all_contrasts, reverse=True)

        major_ticks = np.arange(min_time, max_time + .01, 0.250)
        # tick labels in ms relative to the start of the window
        major_labels = [int(round((tick - min_time) * 1000)) for tick in major_ticks]
        minor_ticks = np.arange(min_time, max_time + .01, 0.125)

        for i, condition in enumerate(conditions):
            col = i * 4
            block = block_map[(all_contrasts[0], current_df, condition)]
            signal, self_freq, other_freq, time = get_signals(block)
            am = extract_am(signal)
            eodfs = block.metadata["stimulus parameter"]["eodfs"]

            # frequency traces
            ax = plt.subplot2grid(fig_grid, (0, col), rowspan=2, colspan=3, fig=fig)
            _plot_frequency_traces(ax, time, self_freq, other_freq, eodfs["self"], eodfs["other"],
                                   min_time, max_time)
            ax.set_title(condition_labels[i])
            despine(ax, ["top", "bottom", "left", "right"], True)

            # signal and its amplitude modulation
            mask = _window_mask(time, min_time, max_time)
            ax = plt.subplot2grid(fig_grid, (3, col), rowspan=2, colspan=3, fig=fig)
            ax.plot(time[mask], signal[mask], color="#2ca02c", label="signal")
            ax.plot(time[mask], am[mask], color="#d62728", label="am")
            despine(ax, ["top", "bottom", "left", "right"], True)
            ax.set_ylim([-1.25, 1.25])
            ax.legend(ncol=2, loc=(0.01, -0.5), fontsize=7, markerscale=0.5, frameon=False)

            # one firing rate panel per contrast
            for j, contrast in enumerate(all_contrasts):
                t, rates, _ = get_firing_rate(block_map, current_df, contrast, condition)
                avg_resp = np.mean(rates, axis=0)
                error = np.std(rates, axis=0)
                t_mask = _window_mask(t, min_time, max_time)

                ax = plt.subplot2grid(fig_grid, (j * 2 + 6, col), rowspan=2, colspan=3, fig=fig)
                ax.plot(t[t_mask], avg_resp[t_mask], color="k", lw=0.5)
                ax.fill_between(t[t_mask], (avg_resp - error)[t_mask], (avg_resp + error)[t_mask],
                                color="k", lw=0.0, alpha=0.25)
                ax.set_ylim([0, 750])
                ax.set_xlabel("")
                ax.set_ylabel("")
                ax.set_xticks(major_ticks)
                ax.set_xticklabels(major_labels)
                ax.set_xticks(minor_ticks, minor=True)
                if j < len(all_contrasts) - 1:
                    ax.set_xticklabels([])  # only the bottom panel shows time labels
                ax.set_yticks(np.arange(0.0, 751., 500))
                ax.set_yticks(np.arange(0.0, 751., 125), minor=True)
                if i > 0:
                    ax.set_yticklabels([])
                despine(ax, ["top", "right"], False)
                if i == 2:
                    ax.text(max_time + 0.025 * max_time, 350, "c=%.3f" % contrast, color="#d62728",
                            ha="left", fontsize=7)

            # after the contrast loop ``ax`` is the bottom panel of this column;
            # the y label is shifted up by 3.5 axes heights
            if i == 1:
                ax.set_xlabel("time [ms]")
            if i == 0:
                ax.set_ylabel("frequency [Hz]", va="center")
                ax.yaxis.set_label_coords(-0.45, 3.5)

        name = figure_name if figure_name is not None else "chirp_responses.pdf"
        name = (name + ".pdf") if ".pdf" not in name else name
        plt.savefig(os.path.join(figure_folder, name))
        plt.close()
    finally:
        nf.close()


def response_examples(args=None):
    """Create response plots of the first nix file for difference frequencies of 20 Hz and -100 Hz.

    Args:
        args: unused; accepted so the function can serve as a sub command.
    """
    files = sorted(glob.glob(os.path.join(data_folder, "*.nix")))
    if len(files) < 1:
        print("response examples: no data found!")
        return
    filename = files[0]
    base_name = os.path.basename(filename).split(".nix")[0]
    create_response_plot(filename, 20, figure_name=base_name + "_df_20Hz.pdf")
    create_response_plot(filename, -100, figure_name=base_name + "_df_-100Hz.pdf")


def plot_detection_results(data_frame, df, kernel_width, cell, figure_name=None):
    """Plot ROC curves and discrimination performance of a single cell.

    Each detection task gets one row. The left panel shows the ROC curve of
    every contrast for ``kernel_width``. The right panel shows the auc as a
    function of contrast for every kernel width.

    Args:
        data_frame: results table. The columns read here are cell, df,
            detection_task, kernel_width, contrast, true_positives,
            false_positives and auc.
        df: difference frequency to plot.
        kernel_width: kernel width whose ROC curves are shown.
        cell: identifier of the cell to plot.
        figure_name: output file name inside ``figure_folder``. ".pdf" is
            appended if it is missing. Defaults to "foreign_fish_detection.pdf".
    """
    cell_results = data_frame[(data_frame.cell == cell) & (data_frame.df == df)]
    conditions = sorted(cell_results.detection_task.unique())
    kernels = sorted(cell_results.kernel_width.unique())
    major_ticks = np.arange(0.0, 1.01, 0.5)
    minor_ticks = np.arange(0.0, 1.01, 0.25)

    fig = plt.figure(figsize=(6.5, 5.5))
    fig_grid = (8, 7)
    for i, task in enumerate(conditions):
        is_last_row = i == len(conditions) - 1
        condition_results = cell_results[cell_results.detection_task == task]
        roc_data = condition_results[condition_results.kernel_width == kernel_width]

        roc_ax = plt.subplot2grid(fig_grid, (i * 3, 0), colspan=3, rowspan=2, fig=fig)
        roc_ax.set_title(task, fontsize=9, loc="left")
        auc_ax = plt.subplot2grid(fig_grid, (i * 3, 4), colspan=3, rowspan=2, fig=fig)

        # ROC curves, one per contrast
        for contrast in roc_data.contrast.unique():
            selection = roc_data.contrast == contrast
            tpr = roc_data.true_positives[selection].values[0]
            fpr = roc_data.false_positives[selection].values[0]
            roc_ax.plot(fpr, tpr, label="%.3f" % contrast, zorder=2)
        if i == 0:
            roc_ax.legend(loc="lower right", fontsize=6, ncol=2, frameon=False, handletextpad=0.4,
                          columnspacing=1.0, labelspacing=0.25)
        roc_ax.plot([0., 1.], [0., 1.], color="k", lw=0.5, ls="--", zorder=0)
        roc_ax.set_xticks(major_ticks)
        roc_ax.set_xticks(minor_ticks, minor=True)
        roc_ax.set_yticks(major_ticks)
        roc_ax.set_yticks(minor_ticks, minor=True)
        if is_last_row:
            roc_ax.set_xticklabels(major_ticks, fontsize=8)
            roc_ax.set_xlabel("false positive rate", fontsize=9)
        roc_ax.set_ylabel("true positive rate", fontsize=9)
        roc_ax.set_yticklabels(major_ticks, fontsize=8)

        # auc vs. contrast, one curve per kernel width
        for k in kernels:
            kernel_selection = condition_results.kernel_width == k
            contrasts = np.asarray(condition_results.contrast[kernel_selection])
            aucs = np.asarray(condition_results.auc[kernel_selection])
            order = np.argsort(contrasts)
            auc_ax.plot(contrasts[order], aucs[order], marker=".",
                        label=r"$\sigma$: %.2f ms" % (k * 1000), zorder=1)
        if is_last_row:
            auc_ax.set_xlabel("contrast [%]", fontsize=9)
        auc_ax.set_ylim([0.25, 1.0])
        auc_ax.set_yticks(np.arange(0.25, 1.01, 0.25))
        auc_ax.set_yticklabels(np.arange(0.25, 1.01, 0.25), fontsize=8)
        auc_ax.set_ylabel("discriminability", fontsize=9)
        if i == 0:
            auc_ax.legend(ncol=2, fontsize=6, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25,
                          frameon=False, loc="lower center")
        # chance level across the full range of plotted contrasts of this task
        auc_ax.plot([condition_results.contrast.min(), condition_results.contrast.max()], [0.5, 0.5],
                    lw=0.5, ls="--", zorder=0)

    name = figure_name if figure_name is not None else "foreign_fish_detection.pdf"
    name = (name + ".pdf") if ".pdf" not in name else name
    fig.savefig(os.path.join(figure_folder, name))
    plt.close(fig)


def foreign_fish_detection_example_plot(args=None):
    """Plot the foreign fish detection results of an example cell.

    Reads the "discrimination_results" table from the first
    ``*discriminations.h5`` file in ``data_folder``. Plots the first cell in
    that table for a difference frequency of 20 and a kernel width of 0.001.

    Args:
        args: unused; accepted so the function can serve as a sub command.

    Raises:
        ValueError: if no result file exists or the table contains no cell.
    """
    files = sorted(glob.glob(os.path.join(data_folder, "*discriminations.h5")))
    if len(files) == 0:
        raise ValueError("no discrimination results found!")
    with pd.HDFStore(files[0], mode="r") as store:
        data_frame = store.get("discrimination_results")
    cells = data_frame.cell.unique()
    if len(cells) == 0:
        raise ValueError("no cells found in discrimination results!")
    plot_detection_results(data_frame, 20, 0.001, cells[0])


def performance_plot(args):
    """Plot discrimination performance across all cells.

    Each detection task gets one row of panels:
    - a 3D surface of the mean auc over difference frequency and contrast;
    - mean +- std auc vs. contrast, one curve per entry of ``args.deltafs``;
    - mean +- std auc vs. difference frequency, one curve per entry of
      ``args.contrasts``.

    Args:
        args: parsed command line arguments. Uses ``kernel_width`` (falls back
            to the first kernel width in the data if absent there),
            ``deltafs``, ``contrasts`` and ``outfile``.
    """
    df = pd.read_csv(os.path.join(data_folder, "discrimination_results.csv"), sep=";")
    dfs = np.sort(df.df.unique())
    contrasts = np.sort(df.contrast.unique())
    tasks = df.detection_task.unique()
    kernel_widths = list(df.kernel_width.unique())
    kernel_width = args.kernel_width if args.kernel_width in kernel_widths else kernel_widths[0]
    X, Y = np.meshgrid(dfs, contrasts)

    fig = plt.figure(figsize=(8.0, 10))
    fig_grid = (19, 10)
    for index, t in enumerate(tasks):
        row = index * 7  # 5 rows for the surface plus 2 rows spacing per task
        task_results = df[(df.kernel_width == kernel_width) & (df.detection_task == t)]

        ax = plt.subplot2grid(fig_grid, (row, 0), colspan=4, rowspan=5, projection="3d", fig=fig)
        ax.set_title(t, loc="left", pad=-0.5)
        # explicit float buffers: zeros_like() would inherit an integer dtype
        # from the df/contrast columns and truncate the auc values
        Z = np.zeros(X.shape, dtype=float)
        for i, d in enumerate(dfs):
            for j, c in enumerate(contrasts):
                Z[j, i] = np.mean(_select_aucs(task_results, c, d))
        ax.plot_surface(X, Y, Z, cmap=cm.coolwarm, linewidth=0.2, edgecolor="white", antialiased=True,
                        alpha=0.85, vmin=0.5, vmax=1.0)
        ax.set_xlabel(r"$\Delta_f [Hz]$")
        ax.set_ylabel("contrast [%]")
        ax.set_zlabel("performance")
        ax.set_zlim([0.45, 1.0])
        ax.set_zticks(np.arange(0.5, 1.01, 0.25))
        ax.set_zticks(np.arange(0.5, 1.01, 0.125), minor=True)

        cntrst_ax = plt.subplot2grid(fig_grid, (row, 6), colspan=4, rowspan=2, fig=fig)
        for d in args.deltafs:
            # fresh arrays per curve so already drawn lines are never mutated
            selections = [_select_aucs(task_results, c, d) for c in contrasts]
            performances = np.array([np.mean(s) for s in selections], dtype=float)
            errors = np.array([np.std(s) for s in selections], dtype=float)
            cntrst_ax.errorbar(contrasts, performances, yerr=errors, fmt=".-", label=r"$\Delta_f:$ %i Hz" % d)
        cntrst_ax.set_ylim([0.25, 1.0])
        cntrst_ax.set_ylabel("performance", fontsize=8)
        cntrst_ax.set_xlabel("contrast [%]", fontsize=8)
        cntrst_ax.hlines(0.5, contrasts[0], contrasts[-1], color="k", ls="--", lw=0.2)
        cntrst_ax.legend(fontsize=7, ncol=4, frameon=False, loc="lower center", mode="expand",
                         handlelength=1.0, handletextpad=0.25)

        df_ax = plt.subplot2grid(fig_grid, (row + 3, 6), colspan=4, rowspan=2, fig=fig)
        for c in args.contrasts:
            selections = [_select_aucs(task_results, c, d) for d in dfs]
            performances = np.array([np.mean(s) for s in selections], dtype=float)
            errors = np.array([np.std(s) for s in selections], dtype=float)
            df_ax.errorbar(dfs, performances, yerr=errors, fmt=".-", label="%.2f" % c)
        df_ax.set_ylim([0.25, 1.0])
        df_ax.set_ylabel("performance", fontsize=8)
        df_ax.set_xlabel(r"$\Delta_f$ [Hz]", fontsize=8)
        df_ax.hlines(0.5, dfs[0], dfs[-1], color="k", ls="--", lw=0.2)
        df_ax.legend(fontsize=7, ncol=4, frameon=False, loc="lower center", mode="expand",
                     handlelength=1.0, handletextpad=0.25)

    fig.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.975)
    fig.savefig(args.outfile)
    plt.close()


def main():
    """Parse the command line and run the selected plotting sub command."""
    parser = argparse.ArgumentParser(description="Plotting tool for chirp probing project.")
    subparsers = parser.add_subparsers(title="commands", help="Sub commands for plotting different figures",
                                       description="", dest="explore_cmd")

    comp_parser = subparsers.add_parser("comparisons", help="Create a didactic plot illustrating the comparisons")
    comp_parser.add_argument("-df", "--deltaf", type=int, default=20,
                             help="The difference frequency to use for plotting")
    comp_parser.add_argument("-o", "--outfile", default=os.path.join(figure_folder, "comparisons.pdf"),
                             help="filename of the plot")
    comp_parser.set_defaults(func=plot_comparisons)

    resp_parser = subparsers.add_parser("response_examples",
                                        help="plot example responses for difference frequencies of 20 and -100 Hz")
    resp_parser.set_defaults(func=response_examples)

    detect_parser = subparsers.add_parser("fish_detection_example",
                                          help="plot foreign fish detection results of an example cell")
    detect_parser.set_defaults(func=foreign_fish_detection_example_plot)

    perf_parser = subparsers.add_parser("discrimination", help="plot discrimination performance across all cells")
    perf_parser.add_argument("-o", "--outfile", default=os.path.join(figure_folder, "discrimination_performances.pdf"),
                             help="filename of the plot")
    perf_parser.add_argument("-k", "--kernel_width", type=float, default=0.001,
                             help="Kernel width to choose for plotting")
    perf_parser.add_argument("-d", "--deltafs", type=float, nargs="+", default=[-100, 20, 100],
                             help="deltaf for individual plot")
    perf_parser.add_argument("-c", "--contrasts", type=float, nargs="+", default=[5, 10, 20],
                             help="stimulus contrast for individual plot")
    perf_parser.set_defaults(func=performance_plot)

    args = parser.parse_args()
    if not hasattr(args, "func"):
        # no sub command given
        parser.print_help()
        parser.exit(1)
    args.func(args)


if __name__ == "__main__":
    main()