diff --git a/plots.py b/plots.py index eb5ace4..6a0cb5e 100644 --- a/plots.py +++ b/plots.py @@ -1,19 +1,27 @@ import glob import os +import argparse import nixio as nix +import numpy as np +import pandas as pd import matplotlib.pyplot as plt from matplotlib.patches import Rectangle from matplotlib.collections import PatchCollection from matplotlib.patches import ConnectionPatch +from matplotlib import cm +from mpl_toolkits.mplot3d import Axes3D # noqa: F401 unused import + +from IPython import embed from nix_util import sort_blocks, read_baseline, get_signals -from util import despine +from util import despine, extract_am +from response_discriminability import get_firing_rate, foreign_fish_detection figure_folder = "figures" data_folder = "data" -def plot_comparisons(current_df=20): +def plot_comparisons(args): files = sorted(glob.glob(os.path.join(data_folder, "*.nix"))) if len(files) < 1: print("plot comparisons: no data found!") @@ -30,7 +38,7 @@ def plot_comparisons(current_df=20): axes = [] for i, condition in enumerate(conditions): # plot the signals - block = block_map[(all_contrasts[0], current_df, condition)] + block = block_map[(all_contrasts[0], args.current_df, condition)] _, self_freq, other_freq, time = get_signals(block) self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"] @@ -88,17 +96,12 @@ def plot_comparisons(current_df=20): axes[1].text(1.05, 660, "3.") axes[0].text(1.1, 890, "1.") fig.subplots_adjust(bottom=0.1, top=0.8, left=0.1, right=0.9) - fig.savefig(os.path.join(figure_folder, "comparisons.pdf")) + fig.savefig(args.outfile) plt.close() nf.close() def create_response_plot(filename, current_df=20, figure_name=None): - files = sorted(glob.glob(os.path.join(data_folder, "*.nix"))) - if len(files) < 1: - print("plot comparisons: no data found!") - return - filename = files[0] nf = nix.File.open(filename, nix.FileMode.ReadOnly) block_map, all_contrasts, _, all_conditions = sort_blocks(nf) @@ -180,45 +183,164 @@ def create_response_plot(filename, current_df=20, figure_name=None): plt.close() nf.close() -def response_examples(*kwargs): - - if filename in kwargs: - #fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf" - #create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name) - #fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf" - #create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name) - -def main(task=None, parameter={}): - plot_tasks = {"comparisons": plot_comparisons, - "response_examples": create_response_plot} - if task is not None and task in plot_tasks.keys(): - plot_tasks[task](*parameter) - elif task is None: - for t in plot_tasks.keys(): - plot_tasks[t](*parameter) +def response_examples(): + filename = sorted(glob.glob(os.path.join(data_folder, "*.nix")))[0] + fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf" + create_response_plot(filename, 20, figure_name=fig_name) + fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf" + create_response_plot(filename, -100, figure_name=fig_name) -if __name__ == "__main__": - main("comparisons") +def plot_detection_results(data_frame, df, kernel_width, cell, figure_name=None): + cell_results = data_frame[(data_frame.cell == cell) & (data_frame.df == df)] + conditions = sorted(cell_results.detection_task.unique()) + kernels = sorted(cell_results.kernel_width.unique()) + + fig = plt.figure(figsize=(6.5, 5.5)) + fig_grid = (8, 7) + for i, c in enumerate(conditions): + condition_results = cell_results[cell_results.detection_task == c] + roc_data = condition_results[condition_results.kernel_width == kernel_width] + contrasts = roc_data.contrast.unique() + + roc_ax = plt.subplot2grid(fig_grid, (i * 2 + i, 0), colspan=3, rowspan=2) + roc_ax.set_title(c, fontsize=9, loc="left") + auc_ax = plt.subplot2grid(fig_grid, (i * 2 + i, 4), colspan=3, rowspan=2) + + for c in contrasts: + tpr = roc_data.true_positives[roc_data.contrast == c].values[0] + fpr = roc_data.false_positives[roc_data.contrast == c].values[0] + roc_ax.plot(fpr, tpr, label="%.3f" % c, zorder=2) + if i == 0: + roc_ax.legend(loc="lower right", fontsize=6, ncol=2, frameon=False, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25) + roc_ax.plot([0., 1.],[0., 1.], color="k", lw=0.5, ls="--", zorder=0) + roc_ax.set_xticks(np.arange(0.0, 1.01, 0.5)) + roc_ax.set_xticks(np.arange(0.0, 1.01, 0.25), minor=True) + roc_ax.set_yticks(np.arange(0.0, 1.01, 0.5)) + roc_ax.set_yticks(np.arange(0.0, 1.01, 0.25), minor=True) + if i == len(conditions) - 1: + roc_ax.set_xticklabels(np.arange(0.0, 1.01, 0.5), fontsize=8) + roc_ax.set_xlabel("false positive rate", fontsize=9) + roc_ax.set_ylabel("true positive rate", fontsize=9) + roc_ax.set_yticklabels(np.arange(0.0, 1.01, 0.5), fontsize=8) + for k in kernels: + contrasts = np.asarray(condition_results.contrast[condition_results.kernel_width == k]) + aucs = np.asarray(condition_results.auc[condition_results.kernel_width == k]) + aucs_sorted = aucs[np.argsort(contrasts)] + contrasts_sorted = np.sort(contrasts) + auc_ax.plot(contrasts_sorted, aucs_sorted, marker=".", label=r"$\sigma$: %.2f ms" % (k * 1000), zorder=1) + if i == len(conditions) - 1: + auc_ax.set_xlabel("contrast [%]", fontsize=9) + auc_ax.set_ylim([0.25, 1.0]) + auc_ax.set_yticks(np.arange(0.25, 1.01, 0.25)) + auc_ax.set_yticklabels(np.arange(0.25, 1.01, 0.25), fontsize=8) + auc_ax.set_ylabel("discriminability", fontsize=9) + if i == 0: + auc_ax.legend(ncol=2, fontsize=6, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25, frameon=False, loc="lower center") + auc_ax.plot([min(contrasts), max(contrasts)], [0.5, 0.5], lw=0.5, ls="--", zorder=0) + name = figure_name if figure_name is not None else "foreign_fish_detection.pdf" + name = (name + ".pdf") if ".pdf" not in name else name + fig.savefig(os.path.join(figure_folder, name)) -def plot_examples(filename, dfs=[], contrasts=[], conditions=[]): - # plot the responses - #fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf" - #create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name) - #fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf" - #create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name) + +def foreign_fish_detection_example_plot(): + files = glob.glob(os.path.join(data_folder, "*discriminations.h5")) + if len(files) == 0: + raise ValueError("no discrimination results found!") + store = pd.HDFStore(files[0]) + data_frame = store.get("discrimination_results") + embed() + plot_detection_results(data_frame, 20, 0.001, ) + pass + + +def performance_plot(args): + df = pd.read_csv(os.path.join(data_folder, "discrimination_results.csv"), sep=";") + dfs = np.sort(df.df.unique()) + contrasts = np.sort(df.contrast.unique()) + tasks = df.detection_task.unique() + kernel_widths = list(df.kernel_width.unique()) + kernel_width = args.kernel_width if args.kernel_width in kernel_widths else kernel_widths[0] + X, Y = np.meshgrid(dfs, contrasts) + Z = np.zeros_like(X) - # sketch showing the comparisons - #plot_comparisons(block_map, all_dfs, all_contrasts, all_conditions, 20) - - # plot the discrimination analyses - #cell_name = filename.split(os.path.sep)[-1].split(".nix")[0] - # results = foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=20, - # cell_name=cell_name, store_roc=True) - # pdf = pd.DataFrame(results) - # plot_detection_results(pdf, 20, 0.001, cell_name) + fig = plt.figure(figsize=(8.0, 10)) + fig_grid = (19, 10) + for index, t in enumerate(tasks): + ax = plt.subplot2grid(fig_grid, (index * 5 + index * 2, 0), colspan=4, rowspan=5, projection="3d" ) + ax.set_title(t, loc="left", pad=-0.5) + for i, d in enumerate(dfs): + for j, c in enumerate(contrasts): + Z[j, i] = np.mean(df.auc[(df.kernel_width == kernel_width) & (df.contrast == c) & (df.df == d) & (df.detection_task == t)]) + + surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm, linewidth=0.2, edgecolor="white", antialiased=True, alpha=0.85, vmin=0.5, vmax=1.0) + ax.set_xlabel(r"$\Delta_f [Hz]$") + ax.set_ylabel("contrast [%]") + ax.set_zlabel("performance") + ax.set_zlim([0.45, 1.0]) + ax.set_zticks(np.arange(0.5, 1.01, 0.25)) + ax.set_zticks(np.arange(0.5, 1.01, 0.125), minor=True) + + cntrst_ax = plt.subplot2grid(fig_grid, (index * 5 + index * 2, 6), colspan=4, rowspan=2) + performances = np.zeros_like(contrasts) + errors = np.zeros_like(contrasts) + for d in args.deltafs: + for i, c in enumerate(contrasts): + performances[i] = np.mean(df.auc[(df.kernel_width == kernel_width) & (df.contrast == c) & (df.df == d) & (df.detection_task == t)]) + errors[i] = np.std(df.auc[(df.kernel_width == kernel_width) & (df.contrast == c) & (df.df == d) & (df.detection_task == t)]) + cntrst_ax.errorbar(contrasts, performances, yerr=errors, fmt=".-", label=r"$\Delta_f:$ %i Hz" % d) + cntrst_ax.set_ylim([0.25, 1.0]) + cntrst_ax.set_ylabel("performance", fontsize=8) + cntrst_ax.set_xlabel("contrast [%]", fontsize=8) + cntrst_ax.hlines(0.5, contrasts[0], contrasts[-1], color="k", ls="--", lw=0.2) + cntrst_ax.legend(fontsize=7, ncol=4, frameon=False, loc="lower center", mode="expand", handlelength=1.0, handletextpad=0.25) + + df_ax = plt.subplot2grid(fig_grid, (index * 5 + index * 2 + 3, 6), colspan=4, rowspan=2) + performances = np.zeros_like(dfs) + errors = np.zeros_like(dfs) + for c in args.contrasts: + for i, d in enumerate(dfs): + performances[i] = np.mean(df.auc[(df.kernel_width == kernel_width) & (df.contrast == c) & (df.df == d) & (df.detection_task == t)]) + errors[i] = np.std(df.auc[(df.kernel_width == kernel_width) & (df.contrast == c) & (df.df == d) & (df.detection_task == t)]) + df_ax.errorbar(dfs, performances, yerr=errors, fmt=".-", label="%.2f" % c) + df_ax.set_ylim([0.25, 1.0]) + df_ax.set_ylabel("performance", fontsize=8) + df_ax.set_xlabel(r"$\Delta_f$ [Hz]", fontsize=8) + df_ax.hlines(0.5, dfs[0], dfs[-1], color="k", ls="--", lw=0.2) + df_ax.legend(fontsize=7, ncol=4, frameon=False, loc="lower center", mode="expand", handlelength=1.0, handletextpad=0.25) + + fig.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.975) + fig.savefig(args.outfile) + plt.close() + + +def main(): + cmd_map = {"comparisons": plot_comparisons, + "response_examples": response_examples, + "fish_detection_example": foreign_fish_detection_example_plot} - #nf.close() - pass \ No newline at end of file + parser = argparse.ArgumentParser(description="Plotting tool for chrip probing project.") + subparsers = parser.add_subparsers(title="commands", + help="Sub commands for plotting different figures", + description="", dest="explore_cmd") + comp_parser = subparsers.add_parser("comparisons", help="Create a didactic plot illustrating the comparisons") + comp_parser.add_argument("-df", "--deltaf", type=int, default=20, help="The difference frequency to used for plotting") + comp_parser.add_argument("-o", "--outfile", default=os.path.join(figure_folder, "comparisons.pdf"), help="filename of the plot") + comp_parser.set_defaults(func=plot_comparisons) + + perf_parser = subparsers.add_parser("discrimination", help="plot discrimination performance across all cells") + perf_parser.add_argument("-o", "--outfile", default=os.path.join(figure_folder, "discrimination_performances.pdf"), help="filename of the plot") + perf_parser.add_argument("-k", "--kernel_width", type=float, default=0.001, help="Kernel width to choose for plotting") + perf_parser.add_argument("-d", "--deltafs", type=float, nargs="+", default=[-100, 20, 100], help="deltaf for individual plot") + perf_parser.add_argument("-c", "--contrasts", type=float, nargs="+", default=[5, 10, 20], help="stimulus contrast for individual plot") + perf_parser.set_defaults(func=performance_plot) + + args = parser.parse_args() + args.func(args) + + + +if __name__ == "__main__": + main() diff --git a/response_discriminability.py b/response_discriminability.py index 828a0a9..b78a1a4 100644 --- a/response_discriminability.py +++ b/response_discriminability.py @@ -210,57 +210,6 @@ def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, k return detection_performances -def plot_detection_results(data_frame, df, kernel_width, cell, figure_name=None): - cell_results = data_frame[(data_frame.cell == cell) & (data_frame.df == df)] - conditions = sorted(cell_results.detection_task.unique()) - kernels = sorted(cell_results.kernel_width.unique()) - - fig = plt.figure(figsize=(6.5, 5.5)) - fig_grid = (8, 7) - for i, c in enumerate(conditions): - condition_results = cell_results[cell_results.detection_task == c] - roc_data = condition_results[condition_results.kernel_width == kernel_width] - contrasts = roc_data.contrast.unique() - - roc_ax = plt.subplot2grid(fig_grid, (i * 2 + i, 0), colspan=3, rowspan=2) - roc_ax.set_title(c, fontsize=9, loc="left") - auc_ax = plt.subplot2grid(fig_grid, (i * 2 + i, 4), colspan=3, rowspan=2) - - for c in contrasts: - tpr = roc_data.true_positives[roc_data.contrast == c].values[0] - fpr = roc_data.false_positives[roc_data.contrast == c].values[0] - roc_ax.plot(fpr, tpr, label="%.3f" % c, zorder=2) - if i == 0: - roc_ax.legend(loc="lower right", fontsize=6, ncol=2, frameon=False, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25) - roc_ax.plot([0., 1.],[0., 1.], color="k", lw=0.5, ls="--", zorder=0) - roc_ax.set_xticks(np.arange(0.0, 1.01, 0.5)) - roc_ax.set_xticks(np.arange(0.0, 1.01, 0.25), minor=True) - roc_ax.set_yticks(np.arange(0.0, 1.01, 0.5)) - roc_ax.set_yticks(np.arange(0.0, 1.01, 0.25), minor=True) - if i == len(conditions) - 1: - roc_ax.set_xticklabels(np.arange(0.0, 1.01, 0.5), fontsize=8) - roc_ax.set_xlabel("false positive rate", fontsize=9) - roc_ax.set_ylabel("true positive rate", fontsize=9) - roc_ax.set_yticklabels(np.arange(0.0, 1.01, 0.5), fontsize=8) - - for k in kernels: - contrasts = np.asarray(condition_results.contrast[condition_results.kernel_width == k]) - aucs = np.asarray(condition_results.auc[condition_results.kernel_width == k]) - aucs_sorted = aucs[np.argsort(contrasts)] - contrasts_sorted = np.sort(contrasts) - auc_ax.plot(contrasts_sorted, aucs_sorted, marker=".", label=r"$\sigma$: %.2f ms" % (k * 1000), zorder=1) - if i == len(conditions) - 1: - auc_ax.set_xlabel("contrast [%]", fontsize=9) - auc_ax.set_ylim([0.25, 1.0]) - auc_ax.set_yticks(np.arange(0.25, 1.01, 0.25)) - auc_ax.set_yticklabels(np.arange(0.25, 1.01, 0.25), fontsize=8) - auc_ax.set_ylabel("discriminability", fontsize=9) - if i == 0: - auc_ax.legend(ncol=2, fontsize=6, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25, frameon=False, loc="lower center") - auc_ax.plot([min(contrasts), max(contrasts)], [0.5, 0.5], lw=0.5, ls="--", zorder=0) - name = figure_name if figure_name is not None else "foreign_fish_detection.pdf" - name = (name + ".pdf") if ".pdf" not in name else name - fig.savefig(os.path.join(figure_folder, name)) def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, cell_name="", store_roc=False):