"""Plotting tool for the chirp probing project.

Each sub command defined in ``main`` renders one figure from the data found in
``data_folder``; the default output files live in ``figure_folder``.
"""
import glob
import os
import argparse
import nixio as nix
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection
from matplotlib.patches import ConnectionPatch
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 unused import
from IPython import embed
from nix_util import sort_blocks, read_baseline, get_signals
from util import despine, extract_am
from response_discriminability import get_firing_rate, foreign_fish_detection

figure_folder = "figures"
data_folder = "data"


def _time_window(time, min_time, max_time):
    """Return a boolean mask selecting the entries of *time* strictly inside (min_time, max_time)."""
    return (time > min_time) & (time < max_time)


def _add_window_boxes(ax, box_starts, labels):
    """Draw dashed, lightly shaded boxes marking analysis windows on *ax*.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes (data coordinates are used).
    box_starts : sequence of float
        x positions of the boxes' left edges; every box is 0.098 wide and
        spans y = 740 ... 890.
    labels : sequence of (float, str)
        x position and text of each label; labels are drawn at y = 860.
    """
    rects = [Rectangle((x, 740), 0.098, 150) for x in box_starts]
    pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
    ax.add_collection(pc)
    for x, text in labels:
        ax.text(x, 860, text, ha="center", fontsize=7)


def plot_comparisons(args):
    """Create the didactic figure illustrating which response windows are compared.

    Uses the first ``*.nix`` file (sorted by name) in ``data_folder``. For the
    "no-other", "self" and "other" conditions it plots the EOD frequency traces
    between 0.5 s and 1.0 s. Dashed boxes mark the analysis windows, and
    numbered double-headed arrows connect the windows that are compared.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``deltaf``, ``chirpsize`` and ``outfile`` (see ``main``).
    """
    files = sorted(glob.glob(os.path.join(data_folder, "*.nix")))
    if len(files) < 1:
        print("plot comparisons: no data found!")
        return
    filename = files[0]
    nf = nix.File.open(filename, nix.FileMode.ReadOnly)
    try:
        block_map, all_contrasts, _, _, all_conditions = sort_blocks(nf)
        conditions = ["no-other", "self", "other"]
        min_time = 0.5
        max_time = min_time + 0.5

        fig = plt.figure(figsize=(6.5, 2.))
        fig_grid = (3, len(all_conditions) * 3 + 2)
        axes = []
        for i, condition in enumerate(conditions):
            # NOTE(review): this lookup uses (contrast, df, chirpsize, condition) keys
            # while create_response_plot uses (contrast, df, condition); confirm which
            # key layout sort_blocks actually produces.
            block = block_map[(all_contrasts[0], args.deltaf, args.chirpsize, condition)]
            _, self_freq, other_freq, time = get_signals(block)
            self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"]
            other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"]
            window = _time_window(time, min_time, max_time)

            # frequency traces
            ax = plt.subplot2grid(fig_grid, (0, i * 3 + i), rowspan=2, colspan=3, fig=fig)
            ax.plot(time[window], self_freq[window], color="#ff7f0e", label="%iHz" % self_eodf)
            ax.text(min_time - 0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e",
                    va="center", ha="right", fontsize=9)
            if other_freq is not None:
                ax.plot(time[window], other_freq[window], color="#1f77b4",
                        label="%iHz" % other_eodf)
                ax.text(min_time - 0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4",
                        va="center", ha="right", fontsize=9)
            ax.set_ylim([735, 895])
            despine(ax, ["top", "bottom", "left", "right"], True)
            axes.append(ax)

        # shaded boxes marking the analysis windows, labelled a) ... e)
        _add_window_boxes(axes[0], [0.675, 0.57], [(0.625, "a)"), (0.724, "b)")])
        _add_window_boxes(axes[1], [0.675, 0.57], [(0.625, "c)"), (0.724, "d)")])
        _add_window_boxes(axes[2], [0.57], [(0.625, "e)")])

        # double-headed arrows between compared windows:
        # (xyA, xyB, index of axesA, index of axesB, connection style)
        arrows = [
            ((0.625, 735), (0.625, 740), 0, 1, "arc3,rad=.35"),
            ((0.725, 895), (0.725, 890), 0, 1, "arc3,rad=-.25"),
            ((0.725, 735), (0.625, 740), 1, 2, "arc3,rad=.35"),
            ((0.625, 895), (0.725, 890), 0, 1, "arc3,rad=-.35"),
            ((0.615, 735), (0.735, 745), 1, 1, "arc3,rad=1."),
            ((0.625, 895), (0.625, 895), 1, 2, "arc3,rad=-.35"),
        ]
        for xy_a, xy_b, idx_a, idx_b, style in arrows:
            con = ConnectionPatch(xyA=xy_a, xyB=xy_b, coordsA="data", coordsB="data",
                                  axesA=axes[idx_a], axesB=axes[idx_b], arrowstyle="<->",
                                  shrinkB=5, connectionstyle=style)
            # all arrows are attached to the middle axes, as in the original layout
            axes[1].add_artist(con)

        # numeric labels of the comparisons: (index of axes, x, y, text)
        numbers = [
            (0, 1., 655, "2."),
            (1, 1.05, 655, "3."),
            (0, 1.1, 895, "1."),
            (0, 0.6, 925, "4."),
            (1, 0.675, 680, "5."),
            (1, 1.05, 895, "6."),
        ]
        for idx, x, y, text in numbers:
            axes[idx].text(x, y, text)

        fig.subplots_adjust(bottom=0.1, top=0.8, left=0.1, right=0.9)
        fig.savefig(args.outfile)
        plt.close()
    finally:
        # release the nix file even if plotting fails
        nf.close()


def create_response_plot(filename, current_df=20, figure_name=None):
    """Plot stimuli and responses of one cell for the three chirp conditions.

    Each condition ("no-other", "self", "other") gets one column with these panels:
    * the EOD frequency traces of the two fish;
    * the stimulus signal together with its amplitude modulation (``extract_am``);
    * for every contrast (largest first), the mean firing rate, with a shaded
      +/- one standard deviation band.

    The mean and standard deviation are taken over the first axis of the
    rates returned by ``get_firing_rate`` (presumably trials; confirm).
    Only the window 0.5 s ... 1.0 s is shown.

    Parameters
    ----------
    filename : str
        Path of the nix file to read.
    current_df : int
        Difference frequency used to select the blocks.
    figure_name : str
        Output file passed to ``savefig``.
    """
    nf = nix.File.open(filename, nix.FileMode.ReadOnly)
    try:
        block_map, all_contrasts, _, _, all_conditions = sort_blocks(nf)
        conditions = ["no-other", "self", "other"]
        condition_labels = ["soliloquy", "self chirping", "other chirping"]
        min_time = 0.5
        max_time = min_time + 0.5

        fig = plt.figure(figsize=(6.5, 5.5))
        fig_grid = (len(all_contrasts) * 2 + 6, len(all_conditions) * 3 + 2)
        all_contrasts = sorted(all_contrasts, reverse=True)

        # x ticks of the rate panels, labelled in ms relative to min_time
        major_ticks = np.arange(min_time, max_time + .01, 0.250)
        minor_ticks = np.arange(min_time, max_time + .01, 0.125)
        major_tick_labels = [int(v) for v in (major_ticks - min_time) * 1000]

        for i, condition in enumerate(conditions):
            block = block_map[(all_contrasts[0], current_df, condition)]
            signal, self_freq, other_freq, time = get_signals(block)
            am = extract_am(signal)
            self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"]
            other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"]
            window = _time_window(time, min_time, max_time)

            # frequency traces
            ax = plt.subplot2grid(fig_grid, (0, i * 3 + i), rowspan=2, colspan=3, fig=fig)
            ax.plot(time[window], self_freq[window], color="#ff7f0e", label="%iHz" % self_eodf)
            ax.text(min_time - 0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e",
                    va="center", ha="right", fontsize=9)
            if other_freq is not None:
                ax.plot(time[window], other_freq[window], color="#1f77b4",
                        label="%iHz" % other_eodf)
                ax.text(min_time - 0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4",
                        va="center", ha="right", fontsize=9)
            ax.set_title(condition_labels[i])
            despine(ax, ["top", "bottom", "left", "right"], True)

            # signal and its amplitude modulation
            ax = plt.subplot2grid(fig_grid, (3, i * 3 + i), rowspan=2, colspan=3, fig=fig)
            ax.plot(time[window], signal[window], color="#2ca02c", label="signal")
            ax.plot(time[window], am[window], color="#d62728", label="am")
            despine(ax, ["top", "bottom", "left", "right"], True)
            ax.set_ylim([-1.25, 1.25])
            ax.legend(ncol=2, loc=(0.01, -0.5), fontsize=7, markerscale=0.5, frameon=False)

            # one firing-rate panel per contrast
            for j, contrast in enumerate(all_contrasts):
                t, rates, _ = get_firing_rate(block_map, current_df, contrast, condition)
                avg_resp = np.mean(rates, axis=0)
                error = np.std(rates, axis=0)
                rate_window = _time_window(t, min_time, max_time)

                ax = plt.subplot2grid(fig_grid, (j * 2 + 6, i * 3 + i), rowspan=2, colspan=3)
                ax.plot(t[rate_window], avg_resp[rate_window], color="k", lw=0.5)
                ax.fill_between(t[rate_window], (avg_resp - error)[rate_window],
                                (avg_resp + error)[rate_window], color="k", lw=0.0, alpha=0.25)
                ax.set_ylim([0, 750])
                ax.set_xlabel("")
                ax.set_ylabel("")
                ax.set_xticks(major_ticks)
                ax.set_xticklabels(major_tick_labels)
                ax.set_xticks(minor_ticks, minor=True)
                if j < len(all_contrasts) - 1:
                    # tick labels only on the bottom row
                    ax.set_xticklabels([])
                ax.set_yticks(np.arange(0.0, 751., 500))
                ax.set_yticks(np.arange(0.0, 751., 125), minor=True)
                if i > 0:
                    ax.set_yticklabels([])
                despine(ax, ["top", "right"], False)
                if i == 2:
                    ax.text(max_time + 0.025 * max_time, 350, "c=%.3f" % contrast,
                            color="#d62728", ha="left", fontsize=7)

            # Axis labels go on the bottom-most rate panel only (the last one created
            # above). The y label is shifted upwards with set_label_coords, presumably
            # so that it sits beside the stack of rate panels.
            if i == 1:
                ax.set_xlabel("time [ms]")
            if i == 0:
                ax.set_ylabel("frequency [Hz]", va="center")
                ax.yaxis.set_label_coords(-0.45, 3.5)

        fig.savefig(figure_name)
        plt.close()
    finally:
        # release the nix file even if plotting fails
        nf.close()


def response_examples(args):
    """Plot example responses of the first file matching ``args.cell + "*"``.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``cell``, ``deltaf`` and ``outfile`` (see ``main``).

    Raises
    ------
    ValueError
        If no file matches the requested cell name.
    """
    files = sorted(glob.glob(args.cell + "*"))
    if len(files) < 1:
        raise ValueError("Cell data with name %s not found" % args.cell)
    filename = files[0]
    create_response_plot(filename, args.deltaf, figure_name=args.outfile)


def plot_detection_results(data_frame, df, kernel_width, cell=None, figure_name=None):
    """Plot ROC curves and discriminability (AUC) vs. contrast for one cell.

    Each detection task gets one row. The left panel shows the ROC curves for
    every contrast at the requested kernel width. The right panel shows the AUC
    as a function of contrast for every kernel width, plus a dashed 0.5
    reference line.

    Parameters
    ----------
    data_frame : pandas.DataFrame
        Needs the columns cell, df, detection_task, kernel_width, contrast,
        true_positives, false_positives and auc.
    df : number
        Difference frequency to show. It must be present in ``data_frame.df``.
    kernel_width : float
        Kernel width used for the ROC curves. It must be present for the cell.
    cell : optional
        Cell to show. Defaults to the first cell in ``data_frame``.
    figure_name : str
        Output file passed to ``savefig``.

    Raises
    ------
    ValueError
        If ``df`` or ``kernel_width`` is not available.
    """
    if cell is None:
        cell = data_frame.cell.unique()[0]
    dfs = np.sort(data_frame.df.unique())
    if df not in dfs:
        raise ValueError("requested deltaf not present, valid choices are: " + str(dfs))

    cell_results = data_frame[(data_frame.cell == cell) & (data_frame.df == df)]
    conditions = sorted(cell_results.detection_task.unique())
    kernels = sorted(cell_results.kernel_width.unique())
    if kernel_width not in kernels:
        raise ValueError("requested kernel not present, valid choices are: " + str(kernels))

    fig = plt.figure(figsize=(6.5, 5.5))
    fig_grid = (8, 7)
    for i, task in enumerate(conditions):
        condition_results = cell_results[cell_results.detection_task == task]
        roc_data = condition_results[condition_results.kernel_width == kernel_width]
        contrasts = roc_data.contrast.unique()

        roc_ax = plt.subplot2grid(fig_grid, (i * 2 + i, 0), colspan=3, rowspan=2)
        roc_ax.set_title(task, fontsize=9, loc="left")
        auc_ax = plt.subplot2grid(fig_grid, (i * 2 + i, 4), colspan=3, rowspan=2)

        # ROC curve per contrast; a distinct loop name avoids shadowing `task`
        for contrast in contrasts:
            selection = roc_data.contrast == contrast
            tpr = roc_data.true_positives[selection].values[0]
            fpr = roc_data.false_positives[selection].values[0]
            roc_ax.plot(fpr, tpr, label="%.3f" % contrast, zorder=2)
        if i == 0:
            roc_ax.legend(loc="lower right", fontsize=6, ncol=2, frameon=False, handletextpad=0.4,
                          columnspacing=1.0, labelspacing=0.25)
        roc_ax.plot([0., 1.], [0., 1.], color="k", lw=0.5, ls="--", zorder=0)
        roc_ax.set_xticks(np.arange(0.0, 1.01, 0.5))
        roc_ax.set_xticks(np.arange(0.0, 1.01, 0.25), minor=True)
        roc_ax.set_yticks(np.arange(0.0, 1.01, 0.5))
        roc_ax.set_yticks(np.arange(0.0, 1.01, 0.25), minor=True)
        if i == len(conditions) - 1:
            roc_ax.set_xticklabels(np.arange(0.0, 1.01, 0.5), fontsize=8)
            roc_ax.set_xlabel("false positive rate", fontsize=9)
        roc_ax.set_ylabel("true positive rate", fontsize=9)
        roc_ax.set_yticklabels(np.arange(0.0, 1.01, 0.5), fontsize=8)

        # AUC vs. contrast, one line per kernel width
        for k in kernels:
            kernel_results = condition_results[condition_results.kernel_width == k]
            kernel_contrasts = np.asarray(kernel_results.contrast)
            aucs = np.asarray(kernel_results.auc)
            order = np.argsort(kernel_contrasts)
            auc_ax.plot(kernel_contrasts[order], aucs[order], marker=".",
                        label=r"$\sigma$: %.2f ms" % (k * 1000), zorder=1)
        if i == len(conditions) - 1:
            auc_ax.set_xlabel("contrast [%]", fontsize=9)
        auc_ax.set_ylim([0.25, 1.0])
        auc_ax.set_yticks(np.arange(0.25, 1.01, 0.25))
        auc_ax.set_yticklabels(np.arange(0.25, 1.01, 0.25), fontsize=8)
        auc_ax.set_ylabel("discriminability", fontsize=9)
        if i == 0:
            auc_ax.legend(ncol=2, fontsize=6, handletextpad=0.4, columnspacing=1.0,
                          labelspacing=0.25, frameon=False, loc="lower center")
        # 0.5 reference line across the full contrast range of this task
        # (previously it used whatever contrasts the last kernel happened to have)
        auc_ax.plot([condition_results.contrast.min(), condition_results.contrast.max()],
                    [0.5, 0.5], lw=0.5, ls="--", zorder=0)
    fig.savefig(figure_name)
    plt.close(fig)


def foreign_fish_detection_example_plot(args):
    """Plot the ROC analysis of an example cell.

    The data come from the first ``*discriminations.h5`` file (sorted by name)
    in ``data_folder``.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``deltaf``, ``kernel_width`` and ``outfile`` (see ``main``).

    Raises
    ------
    ValueError
        If no discrimination result file is found.
    """
    files = sorted(glob.glob(os.path.join(data_folder, "*discriminations.h5")))
    if len(files) == 0:
        raise ValueError("no discrimination results found!")
    # the context manager closes the store even if reading fails
    with pd.HDFStore(files[0], mode="r") as store:
        data_frame = store.get("discrimination_results")
    plot_detection_results(data_frame, args.deltaf, args.kernel_width, figure_name=args.outfile)


def _auc_stats(data_frame, kernel_width, contrast, df, task, chirpsize):
    """Return (mean, std) of the ``auc`` column over rows matching all given parameters."""
    selection = data_frame[(data_frame.kernel_width == kernel_width) &
                           (data_frame.contrast == contrast) &
                           (data_frame.df == df) &
                           (data_frame.detection_task == task) &
                           (data_frame.chirpsize == chirpsize)]
    return np.mean(selection.auc), np.std(selection.auc)


def plot_surfaces(data_frame, dfs, contrasts, tasks, selected_dfs, selected_contrasts, kernel_width,
                  chirpsize, filename):
    """Plot discrimination performance across difference frequencies and contrasts.

    Each task gets one row with these panels:
    * a 3D surface of the mean AUC over (df, contrast);
    * mean +/- std AUC vs. contrast, for each of ``selected_dfs``;
    * mean +/- std AUC vs. df, for each of ``selected_contrasts``.

    All values use the given ``kernel_width`` and ``chirpsize``. The figure is
    written to ``filename``.
    """
    X, Y = np.meshgrid(dfs, contrasts)
    fig = plt.figure(figsize=(8.0, 10))
    fig_grid = (19, 10)
    for index, t in enumerate(tasks):
        row = index * 5 + index * 2

        ax = plt.subplot2grid(fig_grid, (row, 0), colspan=4, rowspan=5, projection="3d")
        ax.set_title(t, loc="left", pad=-0.5)
        # Explicit float dtype. np.zeros_like(X) inherits the dtype of dfs (often
        # integer), which truncates the AUC means and cannot hold NaN.
        Z = np.zeros(X.shape, dtype=float)
        for i, d in enumerate(dfs):
            for j, c in enumerate(contrasts):
                Z[j, i], _ = _auc_stats(data_frame, kernel_width, c, d, t, chirpsize)
        ax.plot_surface(X, Y, Z, cmap=cm.coolwarm, linewidth=0.2, edgecolor="white",
                        antialiased=True, alpha=0.85, vmin=0.5, vmax=1.0)
        ax.set_xlabel(r"$\Delta f [Hz]$", fontsize=8)
        ax.set_ylabel("contrast [%]", fontsize=8)
        ax.set_zlabel("performance", fontsize=8, rotation=180)
        ax.set_zlim([0.45, 1.0])
        ax.set_zticks(np.arange(0.5, 1.01, 0.25))
        ax.set_zticks(np.arange(0.5, 1.01, 0.125), minor=True)

        # performance vs. contrast for the selected difference frequencies
        cntrst_ax = plt.subplot2grid(fig_grid, (row, 6), colspan=4, rowspan=2)
        for d in selected_dfs:
            # fresh float buffers per line (see the dtype note above)
            performances = np.zeros(len(contrasts))
            errors = np.zeros(len(contrasts))
            for i, c in enumerate(contrasts):
                performances[i], errors[i] = _auc_stats(data_frame, kernel_width, c, d, t, chirpsize)
            cntrst_ax.errorbar(contrasts, performances, yerr=errors, fmt=".-",
                               label=r"$\Delta f:$ %i Hz" % d)
        cntrst_ax.set_ylim([0.25, 1.0])
        cntrst_ax.set_ylabel("performance", fontsize=8)
        cntrst_ax.set_xlabel("contrast [%]", fontsize=8)
        cntrst_ax.hlines(0.5, contrasts[0], contrasts[-1], color="k", ls="--", lw=0.2)
        cntrst_ax.legend(fontsize=7, ncol=4, frameon=False, loc="lower center", mode="expand",
                         handlelength=1.0, handletextpad=0.25)

        # performance vs. difference frequency for the selected contrasts
        df_ax = plt.subplot2grid(fig_grid, (row + 3, 6), colspan=4, rowspan=2)
        for c in selected_contrasts:
            performances = np.zeros(len(dfs))
            errors = np.zeros(len(dfs))
            for i, d in enumerate(dfs):
                performances[i], errors[i] = _auc_stats(data_frame, kernel_width, c, d, t, chirpsize)
            df_ax.errorbar(dfs, performances, yerr=errors, fmt=".-", label="%.2f" % c)
        df_ax.set_ylim([0.25, 1.0])
        df_ax.set_ylabel("performance", fontsize=8)
        df_ax.set_xlabel(r"$\Delta f$ [Hz]", fontsize=8)
        df_ax.hlines(0.5, dfs[0], dfs[-1], color="k", ls="--", lw=0.2)
        df_ax.legend(fontsize=7, ncol=4, frameon=False, loc="lower center", mode="expand",
                     handlelength=1.0, handletextpad=0.25)

    fig.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.975)
    fig.savefig(filename)
    plt.close()


def performance_plot(args):
    """Plot discrimination performance across all cells.

    Reads a ';'-separated CSV file and writes up to two figures: one for the
    first three tasks and, if more tasks exist, one for the rest. The figures
    get "_1" and "_2" appended to the stem of ``args.outfile``.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``inputfile``, ``kernel_width``, ``chirpsize``,
        ``deltafs``, ``contrasts`` and ``outfile`` (see ``main``).

    Raises
    ------
    ValueError
        If the input file does not exist or the chirpsize is not in the data.
    """
    if not os.path.exists(args.inputfile):
        raise ValueError("Error plotting discrimination performance. Input file (%s) not found!"
                         % args.inputfile)
    df = pd.read_csv(args.inputfile, sep=";")
    all_dfs = np.sort(df.df.unique())
    all_contrasts = np.sort(df.contrast.unique())
    tasks = df.detection_task.unique()
    kernel_widths = list(df.kernel_width.unique())
    # fall back to the first available kernel width if the requested one is missing
    kernel_width = args.kernel_width if args.kernel_width in kernel_widths else kernel_widths[0]
    chirpsizes = list(df.chirpsize.unique())
    if args.chirpsize not in chirpsizes:
        raise ValueError("Error plotting discrimination performance. Requested chirpsize (%i Hz) is not found in the data. Available chirpsizes are: " % args.chirpsize + str(chirpsizes))
    selected_dfs = args.deltafs
    selected_contrasts = args.contrasts

    # splitext keeps dots in directory names intact (e.g. "./figures/x.pdf"),
    # unlike splitting the whole path on "."
    root, ext = os.path.splitext(args.outfile)
    plot_surfaces(df, all_dfs, all_contrasts, tasks[:3], selected_dfs, selected_contrasts,
                  kernel_width, args.chirpsize, root + "_1" + ext)
    if len(tasks) > 3:
        plot_surfaces(df, all_dfs, all_contrasts, tasks[3:], selected_dfs, selected_contrasts,
                      kernel_width, args.chirpsize, root + "_2" + ext)


def main():
    """Parse the command line and run the selected plotting sub command."""
    parser = argparse.ArgumentParser(description="Plotting tool for chirp probing project.")
    subparsers = parser.add_subparsers(title="commands",
                                       help="Sub commands for plotting different figures",
                                       description="", dest="explore_cmd")

    comp_parser = subparsers.add_parser("comparisons",
                                        help="Create a didactic plot illustrating the comparisons")
    comp_parser.add_argument("-df", "--deltaf", type=int, default=20,
                             help="The difference frequency to use for plotting. Defaults to 20 Hz")
    comp_parser.add_argument("-cs", "--chirpsize", type=int, default=60,
                             help="The chirpsize. Defaults to 60 Hz.")
    comp_parser.add_argument("-o", "--outfile", default=os.path.join(figure_folder, "comparisons.pdf"),
                             help="filename of the plot")
    comp_parser.set_defaults(func=plot_comparisons)

    roc_parser = subparsers.add_parser("roc", help="plot roc analysis of example cell")
    roc_parser.add_argument("-o", "--outfile", default=os.path.join(figure_folder, "roc_analysis.pdf"),
                            help="filename of the plot")
    roc_parser.add_argument("-d", "--deltaf", type=int, default=20, help="deltaf for individual plot")
    roc_parser.add_argument("-k", "--kernel_width", type=float, default=0.001,
                            help="Kernel width to choose for plotting, defaults to 0.001s")
    roc_parser.set_defaults(func=foreign_fish_detection_example_plot)

    perf_parser = subparsers.add_parser("discrimination",
                                        help="plot discrimination performance across all cells")
    perf_parser.add_argument("-o", "--outfile",
                             default=os.path.join(figure_folder, "discrimination_performances.pdf"),
                             help="filename of the plot")
    infile = os.path.join(data_folder, "discrimination_results.csv")
    perf_parser.add_argument("-i", "--inputfile", default=infile,
                             help="Filename of the file containing the discrimination results. Defaults to %s" % infile)
    perf_parser.add_argument("-k", "--kernel_width", type=float, default=0.001,
                             help="Kernel width to choose for plotting")
    perf_parser.add_argument("-d", "--deltafs", type=float, nargs="+", default=[-200, 5, 200],
                             help="deltaf for individual plot")
    perf_parser.add_argument("-c", "--contrasts", type=float, nargs="+", default=[5, 10, 20],
                             help="stimulus contrast for individual plot")
    perf_parser.add_argument("-cs", "--chirpsize", type=int, default=60,
                             help="The chirpsize. Defaults to 60Hz.")
    perf_parser.set_defaults(func=performance_plot)

    resps_parser = subparsers.add_parser("responses", help="plot responses from an example cell")
    resps_parser.add_argument("-o", "--outfile",
                              default=os.path.join(figure_folder, "response_example.pdf"),
                              help="filename of the plot")
    resps_parser.add_argument("-d", "--deltaf", type=int, default=20, help="deltaf for individual plot")
    dflt_cell = os.path.join(data_folder, "cell_2010-11-08-al")
    resps_parser.add_argument("-c", "--cell", type=str, default=dflt_cell,
                              help="cell name, defaults to %s" % dflt_cell)
    resps_parser.set_defaults(func=response_examples)

    args = parser.parse_args()
    if not hasattr(args, "func"):
        # no sub command given: argparse leaves `func` unset
        parser.print_help()
        return
    args.func(args)


if __name__ == "__main__":
    main()