extract more plotting to plot.py use argparse

This commit is contained in:
Jan Grewe 2020-09-27 16:32:34 +02:00
parent d0bde3b673
commit 6b501618be
2 changed files with 166 additions and 95 deletions

210
plots.py
View File

@ -1,19 +1,27 @@
import glob
import os
import argparse
import nixio as nix
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection
from matplotlib.patches import ConnectionPatch
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D # noqa: F401 unused import
from IPython import embed
from nix_util import sort_blocks, read_baseline, get_signals
from util import despine
from util import despine, extract_am
from response_discriminability import get_firing_rate, foreign_fish_detection
figure_folder = "figures"
data_folder = "data"
def plot_comparisons(current_df=20):
def plot_comparisons(args):
files = sorted(glob.glob(os.path.join(data_folder, "*.nix")))
if len(files) < 1:
print("plot comparisons: no data found!")
@ -30,7 +38,7 @@ def plot_comparisons(current_df=20):
axes = []
for i, condition in enumerate(conditions):
# plot the signals
block = block_map[(all_contrasts[0], current_df, condition)]
block = block_map[(all_contrasts[0], args.current_df, condition)]
_, self_freq, other_freq, time = get_signals(block)
self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"]
@ -88,17 +96,12 @@ def plot_comparisons(current_df=20):
axes[1].text(1.05, 660, "3.")
axes[0].text(1.1, 890, "1.")
fig.subplots_adjust(bottom=0.1, top=0.8, left=0.1, right=0.9)
fig.savefig(os.path.join(figure_folder, "comparisons.pdf"))
fig.savefig(args.outfile)
plt.close()
nf.close()
def create_response_plot(filename, current_df=20, figure_name=None):
files = sorted(glob.glob(os.path.join(data_folder, "*.nix")))
if len(files) < 1:
print("plot comparisons: no data found!")
return
filename = files[0]
nf = nix.File.open(filename, nix.FileMode.ReadOnly)
block_map, all_contrasts, _, all_conditions = sort_blocks(nf)
@ -180,45 +183,164 @@ def create_response_plot(filename, current_df=20, figure_name=None):
plt.close()
nf.close()
def response_examples(*kwargs):
if filename in kwargs:
#fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf"
#create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name)
#fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf"
#create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name)
def main(task=None, parameter={}):
plot_tasks = {"comparisons": plot_comparisons,
"response_examples": create_response_plot}
if task is not None and task in plot_tasks.keys():
plot_tasks[task](*parameter)
elif task is None:
for t in plot_tasks.keys():
plot_tasks[t](*parameter)
def response_examples():
filename = sorted(glob.glob(os.path.join(data_folder, "*.nix")))[0]
fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf"
create_response_plot(filename, 20, figure_name=fig_name)
fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf"
create_response_plot(filename, -100, figure_name=fig_name)
if __name__ == "__main__":
main("comparisons")
def plot_detection_results(data_frame, df, kernel_width, cell, figure_name=None):
cell_results = data_frame[(data_frame.cell == cell) & (data_frame.df == df)]
conditions = sorted(cell_results.detection_task.unique())
kernels = sorted(cell_results.kernel_width.unique())
fig = plt.figure(figsize=(6.5, 5.5))
fig_grid = (8, 7)
for i, c in enumerate(conditions):
condition_results = cell_results[cell_results.detection_task == c]
roc_data = condition_results[condition_results.kernel_width == kernel_width]
contrasts = roc_data.contrast.unique()
roc_ax = plt.subplot2grid(fig_grid, (i * 2 + i, 0), colspan=3, rowspan=2)
roc_ax.set_title(c, fontsize=9, loc="left")
auc_ax = plt.subplot2grid(fig_grid, (i * 2 + i, 4), colspan=3, rowspan=2)
for c in contrasts:
tpr = roc_data.true_positives[roc_data.contrast == c].values[0]
fpr = roc_data.false_positives[roc_data.contrast == c].values[0]
roc_ax.plot(fpr, tpr, label="%.3f" % c, zorder=2)
if i == 0:
roc_ax.legend(loc="lower right", fontsize=6, ncol=2, frameon=False, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25)
roc_ax.plot([0., 1.],[0., 1.], color="k", lw=0.5, ls="--", zorder=0)
roc_ax.set_xticks(np.arange(0.0, 1.01, 0.5))
roc_ax.set_xticks(np.arange(0.0, 1.01, 0.25), minor=True)
roc_ax.set_yticks(np.arange(0.0, 1.01, 0.5))
roc_ax.set_yticks(np.arange(0.0, 1.01, 0.25), minor=True)
if i == len(conditions) - 1:
roc_ax.set_xticklabels(np.arange(0.0, 1.01, 0.5), fontsize=8)
roc_ax.set_xlabel("false positive rate", fontsize=9)
roc_ax.set_ylabel("true positive rate", fontsize=9)
roc_ax.set_yticklabels(np.arange(0.0, 1.01, 0.5), fontsize=8)
for k in kernels:
contrasts = np.asarray(condition_results.contrast[condition_results.kernel_width == k])
aucs = np.asarray(condition_results.auc[condition_results.kernel_width == k])
aucs_sorted = aucs[np.argsort(contrasts)]
contrasts_sorted = np.sort(contrasts)
auc_ax.plot(contrasts_sorted, aucs_sorted, marker=".", label=r"$\sigma$: %.2f ms" % (k * 1000), zorder=1)
if i == len(conditions) - 1:
auc_ax.set_xlabel("contrast [%]", fontsize=9)
auc_ax.set_ylim([0.25, 1.0])
auc_ax.set_yticks(np.arange(0.25, 1.01, 0.25))
auc_ax.set_yticklabels(np.arange(0.25, 1.01, 0.25), fontsize=8)
auc_ax.set_ylabel("discriminability", fontsize=9)
if i == 0:
auc_ax.legend(ncol=2, fontsize=6, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25, frameon=False, loc="lower center")
auc_ax.plot([min(contrasts), max(contrasts)], [0.5, 0.5], lw=0.5, ls="--", zorder=0)
name = figure_name if figure_name is not None else "foreign_fish_detection.pdf"
name = (name + ".pdf") if ".pdf" not in name else name
fig.savefig(os.path.join(figure_folder, name))
def plot_examples(filename, dfs=[], contrasts=[], conditions=[]):
# plot the responses
#fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf"
#create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name)
#fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf"
#create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name)
def foreign_fish_detection_example_plot():
files = glob.glob(os.path.join(data_folder, "*discriminations.h5"))
if len(files) == 0:
raise ValueError("no discrimination results found!")
store = pd.HDFStore(files[0])
data_frame = store.get("discrimination_results")
embed()
plot_detection_results(data_frame, 20, 0.001, )
pass
def performance_plot(args):
df = pd.read_csv(os.path.join(data_folder, "discrimination_results.csv"), sep=";")
dfs = np.sort(df.df.unique())
contrasts = np.sort(df.contrast.unique())
tasks = df.detection_task.unique()
kernel_widths = list(df.kernel_width.unique())
kernel_width = args.kernel_width if args.kernel_width in kernel_widths else kernel_widths[0]
X, Y = np.meshgrid(dfs, contrasts)
Z = np.zeros_like(X)
# sketch showing the comparisons
#plot_comparisons(block_map, all_dfs, all_contrasts, all_conditions, 20)
# plot the discrimination analyses
#cell_name = filename.split(os.path.sep)[-1].split(".nix")[0]
# results = foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=20,
# cell_name=cell_name, store_roc=True)
# pdf = pd.DataFrame(results)
# plot_detection_results(pdf, 20, 0.001, cell_name)
fig = plt.figure(figsize=(8.0, 10))
fig_grid = (19, 10)
for index, t in enumerate(tasks):
ax = plt.subplot2grid(fig_grid, (index * 5 + index * 2, 0), colspan=4, rowspan=5, projection="3d" )
ax.set_title(t, loc="left", pad=-0.5)
for i, d in enumerate(dfs):
for j, c in enumerate(contrasts):
Z[j, i] = np.mean(df.auc[(df.kernel_width == kernel_width) & (df.contrast == c) & (df.df == d) & (df.detection_task == t)])
surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm, linewidth=0.2, edgecolor="white", antialiased=True, alpha=0.85, vmin=0.5, vmax=1.0)
ax.set_xlabel(r"$\Delta_f [Hz]$")
ax.set_ylabel("contrast [%]")
ax.set_zlabel("performance")
ax.set_zlim([0.45, 1.0])
ax.set_zticks(np.arange(0.5, 1.01, 0.25))
ax.set_zticks(np.arange(0.5, 1.01, 0.125), minor=True)
cntrst_ax = plt.subplot2grid(fig_grid, (index * 5 + index * 2, 6), colspan=4, rowspan=2)
performances = np.zeros_like(contrasts)
errors = np.zeros_like(contrasts)
for d in args.deltafs:
for i, c in enumerate(contrasts):
performances[i] = np.mean(df.auc[(df.kernel_width == kernel_width) & (df.contrast == c) & (df.df == d) & (df.detection_task == t)])
errors[i] = np.std(df.auc[(df.kernel_width == kernel_width) & (df.contrast == c) & (df.df == d) & (df.detection_task == t)])
cntrst_ax.errorbar(contrasts, performances, yerr=errors, fmt=".-", label=r"$\Delta_f:$ %i Hz" % d)
cntrst_ax.set_ylim([0.25, 1.0])
cntrst_ax.set_ylabel("performance", fontsize=8)
cntrst_ax.set_xlabel("contrast [%]", fontsize=8)
cntrst_ax.hlines(0.5, contrasts[0], contrasts[-1], color="k", ls="--", lw=0.2)
cntrst_ax.legend(fontsize=7, ncol=4, frameon=False, loc="lower center", mode="expand", handlelength=1.0, handletextpad=0.25)
df_ax = plt.subplot2grid(fig_grid, (index * 5 + index * 2 + 3, 6), colspan=4, rowspan=2)
performances = np.zeros_like(dfs)
errors = np.zeros_like(dfs)
for c in args.contrasts:
for i, d in enumerate(dfs):
performances[i] = np.mean(df.auc[(df.kernel_width == kernel_width) & (df.contrast == c) & (df.df == d) & (df.detection_task == t)])
errors[i] = np.std(df.auc[(df.kernel_width == kernel_width) & (df.contrast == c) & (df.df == d) & (df.detection_task == t)])
df_ax.errorbar(dfs, performances, yerr=errors, fmt=".-", label="%.2f" % c)
df_ax.set_ylim([0.25, 1.0])
df_ax.set_ylabel("performance", fontsize=8)
df_ax.set_xlabel(r"$\Delta_f$ [Hz]", fontsize=8)
df_ax.hlines(0.5, dfs[0], dfs[-1], color="k", ls="--", lw=0.2)
df_ax.legend(fontsize=7, ncol=4, frameon=False, loc="lower center", mode="expand", handlelength=1.0, handletextpad=0.25)
fig.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.975)
fig.savefig(args.outfile)
plt.close()
def main():
cmd_map = {"comparisons": plot_comparisons,
"response_examples": response_examples,
"fish_detection_example": foreign_fish_detection_example_plot}
#nf.close()
pass
parser = argparse.ArgumentParser(description="Plotting tool for chrip probing project.")
subparsers = parser.add_subparsers(title="commands",
help="Sub commands for plotting different figures",
description="", dest="explore_cmd")
comp_parser = subparsers.add_parser("comparisons", help="Create a didactic plot illustrating the comparisons")
comp_parser.add_argument("-df", "--deltaf", type=int, default=20, help="The difference frequency to used for plotting")
comp_parser.add_argument("-o", "--outfile", default=os.path.join(figure_folder, "comparisons.pdf"), help="filename of the plot")
comp_parser.set_defaults(func=plot_comparisons)
perf_parser = subparsers.add_parser("discrimination", help="plot discrimination performance across all cells")
perf_parser.add_argument("-o", "--outfile", default=os.path.join(figure_folder, "discrimination_performances.pdf"), help="filename of the plot")
perf_parser.add_argument("-k", "--kernel_width", type=float, default=0.001, help="Kernel width to choose for plotting")
perf_parser.add_argument("-d", "--deltafs", type=float, nargs="+", default=[-100, 20, 100], help="deltaf for individual plot")
perf_parser.add_argument("-c", "--contrasts", type=float, nargs="+", default=[5, 10, 20], help="stimulus contrast for individual plot")
perf_parser.set_defaults(func=performance_plot)
args = parser.parse_args()
args.func(args)
if __name__ == "__main__":
main()

View File

@ -210,57 +210,6 @@ def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, k
return detection_performances
def plot_detection_results(data_frame, df, kernel_width, cell, figure_name=None):
cell_results = data_frame[(data_frame.cell == cell) & (data_frame.df == df)]
conditions = sorted(cell_results.detection_task.unique())
kernels = sorted(cell_results.kernel_width.unique())
fig = plt.figure(figsize=(6.5, 5.5))
fig_grid = (8, 7)
for i, c in enumerate(conditions):
condition_results = cell_results[cell_results.detection_task == c]
roc_data = condition_results[condition_results.kernel_width == kernel_width]
contrasts = roc_data.contrast.unique()
roc_ax = plt.subplot2grid(fig_grid, (i * 2 + i, 0), colspan=3, rowspan=2)
roc_ax.set_title(c, fontsize=9, loc="left")
auc_ax = plt.subplot2grid(fig_grid, (i * 2 + i, 4), colspan=3, rowspan=2)
for c in contrasts:
tpr = roc_data.true_positives[roc_data.contrast == c].values[0]
fpr = roc_data.false_positives[roc_data.contrast == c].values[0]
roc_ax.plot(fpr, tpr, label="%.3f" % c, zorder=2)
if i == 0:
roc_ax.legend(loc="lower right", fontsize=6, ncol=2, frameon=False, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25)
roc_ax.plot([0., 1.],[0., 1.], color="k", lw=0.5, ls="--", zorder=0)
roc_ax.set_xticks(np.arange(0.0, 1.01, 0.5))
roc_ax.set_xticks(np.arange(0.0, 1.01, 0.25), minor=True)
roc_ax.set_yticks(np.arange(0.0, 1.01, 0.5))
roc_ax.set_yticks(np.arange(0.0, 1.01, 0.25), minor=True)
if i == len(conditions) - 1:
roc_ax.set_xticklabels(np.arange(0.0, 1.01, 0.5), fontsize=8)
roc_ax.set_xlabel("false positive rate", fontsize=9)
roc_ax.set_ylabel("true positive rate", fontsize=9)
roc_ax.set_yticklabels(np.arange(0.0, 1.01, 0.5), fontsize=8)
for k in kernels:
contrasts = np.asarray(condition_results.contrast[condition_results.kernel_width == k])
aucs = np.asarray(condition_results.auc[condition_results.kernel_width == k])
aucs_sorted = aucs[np.argsort(contrasts)]
contrasts_sorted = np.sort(contrasts)
auc_ax.plot(contrasts_sorted, aucs_sorted, marker=".", label=r"$\sigma$: %.2f ms" % (k * 1000), zorder=1)
if i == len(conditions) - 1:
auc_ax.set_xlabel("contrast [%]", fontsize=9)
auc_ax.set_ylim([0.25, 1.0])
auc_ax.set_yticks(np.arange(0.25, 1.01, 0.25))
auc_ax.set_yticklabels(np.arange(0.25, 1.01, 0.25), fontsize=8)
auc_ax.set_ylabel("discriminability", fontsize=9)
if i == 0:
auc_ax.legend(ncol=2, fontsize=6, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25, frameon=False, loc="lower center")
auc_ax.plot([min(contrasts), max(contrasts)], [0.5, 0.5], lw=0.5, ls="--", zorder=0)
name = figure_name if figure_name is not None else "foreign_fish_detection.pdf"
name = (name + ".pdf") if ".pdf" not in name else name
fig.savefig(os.path.join(figure_folder, name))
def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, cell_name="", store_roc=False):