extract more plotting to plot.py use argparse
This commit is contained in:
parent
d0bde3b673
commit
6b501618be
210
plots.py
210
plots.py
@ -1,19 +1,27 @@
|
||||
import glob
|
||||
import os
|
||||
import argparse
|
||||
import nixio as nix
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.patches import Rectangle
|
||||
from matplotlib.collections import PatchCollection
|
||||
from matplotlib.patches import ConnectionPatch
|
||||
from matplotlib import cm
|
||||
from mpl_toolkits.mplot3d import Axes3D # noqa: F401 unused import
|
||||
|
||||
from IPython import embed
|
||||
|
||||
from nix_util import sort_blocks, read_baseline, get_signals
|
||||
from util import despine
|
||||
from util import despine, extract_am
|
||||
from response_discriminability import get_firing_rate, foreign_fish_detection
|
||||
|
||||
figure_folder = "figures"
|
||||
data_folder = "data"
|
||||
|
||||
|
||||
def plot_comparisons(current_df=20):
|
||||
def plot_comparisons(args):
|
||||
files = sorted(glob.glob(os.path.join(data_folder, "*.nix")))
|
||||
if len(files) < 1:
|
||||
print("plot comparisons: no data found!")
|
||||
@ -30,7 +38,7 @@ def plot_comparisons(current_df=20):
|
||||
axes = []
|
||||
for i, condition in enumerate(conditions):
|
||||
# plot the signals
|
||||
block = block_map[(all_contrasts[0], current_df, condition)]
|
||||
block = block_map[(all_contrasts[0], args.current_df, condition)]
|
||||
_, self_freq, other_freq, time = get_signals(block)
|
||||
|
||||
self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"]
|
||||
@ -88,17 +96,12 @@ def plot_comparisons(current_df=20):
|
||||
axes[1].text(1.05, 660, "3.")
|
||||
axes[0].text(1.1, 890, "1.")
|
||||
fig.subplots_adjust(bottom=0.1, top=0.8, left=0.1, right=0.9)
|
||||
fig.savefig(os.path.join(figure_folder, "comparisons.pdf"))
|
||||
fig.savefig(args.outfile)
|
||||
plt.close()
|
||||
nf.close()
|
||||
|
||||
|
||||
def create_response_plot(filename, current_df=20, figure_name=None):
|
||||
files = sorted(glob.glob(os.path.join(data_folder, "*.nix")))
|
||||
if len(files) < 1:
|
||||
print("plot comparisons: no data found!")
|
||||
return
|
||||
filename = files[0]
|
||||
nf = nix.File.open(filename, nix.FileMode.ReadOnly)
|
||||
block_map, all_contrasts, _, all_conditions = sort_blocks(nf)
|
||||
|
||||
@ -180,45 +183,164 @@ def create_response_plot(filename, current_df=20, figure_name=None):
|
||||
plt.close()
|
||||
nf.close()
|
||||
|
||||
def response_examples(*kwargs):
|
||||
|
||||
if filename in kwargs:
|
||||
|
||||
#fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf"
|
||||
#create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name)
|
||||
#fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf"
|
||||
#create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name)
|
||||
|
||||
def main(task=None, parameter={}):
|
||||
plot_tasks = {"comparisons": plot_comparisons,
|
||||
"response_examples": create_response_plot}
|
||||
if task is not None and task in plot_tasks.keys():
|
||||
plot_tasks[task](*parameter)
|
||||
elif task is None:
|
||||
for t in plot_tasks.keys():
|
||||
plot_tasks[t](*parameter)
|
||||
def response_examples():
|
||||
filename = sorted(glob.glob(os.path.join(data_folder, "*.nix")))[0]
|
||||
fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf"
|
||||
create_response_plot(filename, 20, figure_name=fig_name)
|
||||
fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf"
|
||||
create_response_plot(filename, -100, figure_name=fig_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main("comparisons")
|
||||
def plot_detection_results(data_frame, df, kernel_width, cell, figure_name=None):
|
||||
cell_results = data_frame[(data_frame.cell == cell) & (data_frame.df == df)]
|
||||
conditions = sorted(cell_results.detection_task.unique())
|
||||
kernels = sorted(cell_results.kernel_width.unique())
|
||||
|
||||
fig = plt.figure(figsize=(6.5, 5.5))
|
||||
fig_grid = (8, 7)
|
||||
for i, c in enumerate(conditions):
|
||||
condition_results = cell_results[cell_results.detection_task == c]
|
||||
roc_data = condition_results[condition_results.kernel_width == kernel_width]
|
||||
contrasts = roc_data.contrast.unique()
|
||||
|
||||
roc_ax = plt.subplot2grid(fig_grid, (i * 2 + i, 0), colspan=3, rowspan=2)
|
||||
roc_ax.set_title(c, fontsize=9, loc="left")
|
||||
auc_ax = plt.subplot2grid(fig_grid, (i * 2 + i, 4), colspan=3, rowspan=2)
|
||||
|
||||
for c in contrasts:
|
||||
tpr = roc_data.true_positives[roc_data.contrast == c].values[0]
|
||||
fpr = roc_data.false_positives[roc_data.contrast == c].values[0]
|
||||
roc_ax.plot(fpr, tpr, label="%.3f" % c, zorder=2)
|
||||
if i == 0:
|
||||
roc_ax.legend(loc="lower right", fontsize=6, ncol=2, frameon=False, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25)
|
||||
roc_ax.plot([0., 1.],[0., 1.], color="k", lw=0.5, ls="--", zorder=0)
|
||||
roc_ax.set_xticks(np.arange(0.0, 1.01, 0.5))
|
||||
roc_ax.set_xticks(np.arange(0.0, 1.01, 0.25), minor=True)
|
||||
roc_ax.set_yticks(np.arange(0.0, 1.01, 0.5))
|
||||
roc_ax.set_yticks(np.arange(0.0, 1.01, 0.25), minor=True)
|
||||
if i == len(conditions) - 1:
|
||||
roc_ax.set_xticklabels(np.arange(0.0, 1.01, 0.5), fontsize=8)
|
||||
roc_ax.set_xlabel("false positive rate", fontsize=9)
|
||||
roc_ax.set_ylabel("true positive rate", fontsize=9)
|
||||
roc_ax.set_yticklabels(np.arange(0.0, 1.01, 0.5), fontsize=8)
|
||||
|
||||
for k in kernels:
|
||||
contrasts = np.asarray(condition_results.contrast[condition_results.kernel_width == k])
|
||||
aucs = np.asarray(condition_results.auc[condition_results.kernel_width == k])
|
||||
aucs_sorted = aucs[np.argsort(contrasts)]
|
||||
contrasts_sorted = np.sort(contrasts)
|
||||
auc_ax.plot(contrasts_sorted, aucs_sorted, marker=".", label=r"$\sigma$: %.2f ms" % (k * 1000), zorder=1)
|
||||
if i == len(conditions) - 1:
|
||||
auc_ax.set_xlabel("contrast [%]", fontsize=9)
|
||||
auc_ax.set_ylim([0.25, 1.0])
|
||||
auc_ax.set_yticks(np.arange(0.25, 1.01, 0.25))
|
||||
auc_ax.set_yticklabels(np.arange(0.25, 1.01, 0.25), fontsize=8)
|
||||
auc_ax.set_ylabel("discriminability", fontsize=9)
|
||||
if i == 0:
|
||||
auc_ax.legend(ncol=2, fontsize=6, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25, frameon=False, loc="lower center")
|
||||
auc_ax.plot([min(contrasts), max(contrasts)], [0.5, 0.5], lw=0.5, ls="--", zorder=0)
|
||||
name = figure_name if figure_name is not None else "foreign_fish_detection.pdf"
|
||||
name = (name + ".pdf") if ".pdf" not in name else name
|
||||
fig.savefig(os.path.join(figure_folder, name))
|
||||
|
||||
def plot_examples(filename, dfs=[], contrasts=[], conditions=[]):
|
||||
# plot the responses
|
||||
#fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_20Hz.pdf"
|
||||
#create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, 20, figure_name=fig_name)
|
||||
#fig_name = filename.split(os.path.sep)[-1].split(".nix")[0] + "_df_-100Hz.pdf"
|
||||
#create_response_plot(block_map, all_dfs, all_contrasts, all_conditions, -100, figure_name=fig_name)
|
||||
|
||||
def foreign_fish_detection_example_plot():
|
||||
files = glob.glob(os.path.join(data_folder, "*discriminations.h5"))
|
||||
if len(files) == 0:
|
||||
raise ValueError("no discrimination results found!")
|
||||
store = pd.HDFStore(files[0])
|
||||
data_frame = store.get("discrimination_results")
|
||||
embed()
|
||||
plot_detection_results(data_frame, 20, 0.001, )
|
||||
pass
|
||||
|
||||
|
||||
def performance_plot(args):
|
||||
df = pd.read_csv(os.path.join(data_folder, "discrimination_results.csv"), sep=";")
|
||||
dfs = np.sort(df.df.unique())
|
||||
contrasts = np.sort(df.contrast.unique())
|
||||
tasks = df.detection_task.unique()
|
||||
kernel_widths = list(df.kernel_width.unique())
|
||||
kernel_width = args.kernel_width if args.kernel_width in kernel_widths else kernel_widths[0]
|
||||
X, Y = np.meshgrid(dfs, contrasts)
|
||||
Z = np.zeros_like(X)
|
||||
|
||||
# sketch showing the comparisons
|
||||
#plot_comparisons(block_map, all_dfs, all_contrasts, all_conditions, 20)
|
||||
|
||||
# plot the discrimination analyses
|
||||
#cell_name = filename.split(os.path.sep)[-1].split(".nix")[0]
|
||||
# results = foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=20,
|
||||
# cell_name=cell_name, store_roc=True)
|
||||
# pdf = pd.DataFrame(results)
|
||||
# plot_detection_results(pdf, 20, 0.001, cell_name)
|
||||
fig = plt.figure(figsize=(8.0, 10))
|
||||
fig_grid = (19, 10)
|
||||
for index, t in enumerate(tasks):
|
||||
ax = plt.subplot2grid(fig_grid, (index * 5 + index * 2, 0), colspan=4, rowspan=5, projection="3d" )
|
||||
ax.set_title(t, loc="left", pad=-0.5)
|
||||
for i, d in enumerate(dfs):
|
||||
for j, c in enumerate(contrasts):
|
||||
Z[j, i] = np.mean(df.auc[(df.kernel_width == kernel_width) & (df.contrast == c) & (df.df == d) & (df.detection_task == t)])
|
||||
|
||||
surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm, linewidth=0.2, edgecolor="white", antialiased=True, alpha=0.85, vmin=0.5, vmax=1.0)
|
||||
ax.set_xlabel(r"$\Delta_f [Hz]$")
|
||||
ax.set_ylabel("contrast [%]")
|
||||
ax.set_zlabel("performance")
|
||||
ax.set_zlim([0.45, 1.0])
|
||||
ax.set_zticks(np.arange(0.5, 1.01, 0.25))
|
||||
ax.set_zticks(np.arange(0.5, 1.01, 0.125), minor=True)
|
||||
|
||||
cntrst_ax = plt.subplot2grid(fig_grid, (index * 5 + index * 2, 6), colspan=4, rowspan=2)
|
||||
performances = np.zeros_like(contrasts)
|
||||
errors = np.zeros_like(contrasts)
|
||||
for d in args.deltafs:
|
||||
for i, c in enumerate(contrasts):
|
||||
performances[i] = np.mean(df.auc[(df.kernel_width == kernel_width) & (df.contrast == c) & (df.df == d) & (df.detection_task == t)])
|
||||
errors[i] = np.std(df.auc[(df.kernel_width == kernel_width) & (df.contrast == c) & (df.df == d) & (df.detection_task == t)])
|
||||
cntrst_ax.errorbar(contrasts, performances, yerr=errors, fmt=".-", label=r"$\Delta_f:$ %i Hz" % d)
|
||||
cntrst_ax.set_ylim([0.25, 1.0])
|
||||
cntrst_ax.set_ylabel("performance", fontsize=8)
|
||||
cntrst_ax.set_xlabel("contrast [%]", fontsize=8)
|
||||
cntrst_ax.hlines(0.5, contrasts[0], contrasts[-1], color="k", ls="--", lw=0.2)
|
||||
cntrst_ax.legend(fontsize=7, ncol=4, frameon=False, loc="lower center", mode="expand", handlelength=1.0, handletextpad=0.25)
|
||||
|
||||
df_ax = plt.subplot2grid(fig_grid, (index * 5 + index * 2 + 3, 6), colspan=4, rowspan=2)
|
||||
performances = np.zeros_like(dfs)
|
||||
errors = np.zeros_like(dfs)
|
||||
for c in args.contrasts:
|
||||
for i, d in enumerate(dfs):
|
||||
performances[i] = np.mean(df.auc[(df.kernel_width == kernel_width) & (df.contrast == c) & (df.df == d) & (df.detection_task == t)])
|
||||
errors[i] = np.std(df.auc[(df.kernel_width == kernel_width) & (df.contrast == c) & (df.df == d) & (df.detection_task == t)])
|
||||
df_ax.errorbar(dfs, performances, yerr=errors, fmt=".-", label="%.2f" % c)
|
||||
df_ax.set_ylim([0.25, 1.0])
|
||||
df_ax.set_ylabel("performance", fontsize=8)
|
||||
df_ax.set_xlabel(r"$\Delta_f$ [Hz]", fontsize=8)
|
||||
df_ax.hlines(0.5, dfs[0], dfs[-1], color="k", ls="--", lw=0.2)
|
||||
df_ax.legend(fontsize=7, ncol=4, frameon=False, loc="lower center", mode="expand", handlelength=1.0, handletextpad=0.25)
|
||||
|
||||
fig.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.975)
|
||||
fig.savefig(args.outfile)
|
||||
plt.close()
|
||||
|
||||
|
||||
def main():
|
||||
cmd_map = {"comparisons": plot_comparisons,
|
||||
"response_examples": response_examples,
|
||||
"fish_detection_example": foreign_fish_detection_example_plot}
|
||||
|
||||
#nf.close()
|
||||
pass
|
||||
parser = argparse.ArgumentParser(description="Plotting tool for chrip probing project.")
|
||||
subparsers = parser.add_subparsers(title="commands",
|
||||
help="Sub commands for plotting different figures",
|
||||
description="", dest="explore_cmd")
|
||||
comp_parser = subparsers.add_parser("comparisons", help="Create a didactic plot illustrating the comparisons")
|
||||
comp_parser.add_argument("-df", "--deltaf", type=int, default=20, help="The difference frequency to used for plotting")
|
||||
comp_parser.add_argument("-o", "--outfile", default=os.path.join(figure_folder, "comparisons.pdf"), help="filename of the plot")
|
||||
comp_parser.set_defaults(func=plot_comparisons)
|
||||
|
||||
perf_parser = subparsers.add_parser("discrimination", help="plot discrimination performance across all cells")
|
||||
perf_parser.add_argument("-o", "--outfile", default=os.path.join(figure_folder, "discrimination_performances.pdf"), help="filename of the plot")
|
||||
perf_parser.add_argument("-k", "--kernel_width", type=float, default=0.001, help="Kernel width to choose for plotting")
|
||||
perf_parser.add_argument("-d", "--deltafs", type=float, nargs="+", default=[-100, 20, 100], help="deltaf for individual plot")
|
||||
perf_parser.add_argument("-c", "--contrasts", type=float, nargs="+", default=[5, 10, 20], help="stimulus contrast for individual plot")
|
||||
perf_parser.set_defaults(func=performance_plot)
|
||||
|
||||
args = parser.parse_args()
|
||||
args.func(args)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -210,57 +210,6 @@ def foreign_fish_detection_chirp(block_map, df, all_contrasts, all_conditions, k
|
||||
return detection_performances
|
||||
|
||||
|
||||
def plot_detection_results(data_frame, df, kernel_width, cell, figure_name=None):
|
||||
cell_results = data_frame[(data_frame.cell == cell) & (data_frame.df == df)]
|
||||
conditions = sorted(cell_results.detection_task.unique())
|
||||
kernels = sorted(cell_results.kernel_width.unique())
|
||||
|
||||
fig = plt.figure(figsize=(6.5, 5.5))
|
||||
fig_grid = (8, 7)
|
||||
for i, c in enumerate(conditions):
|
||||
condition_results = cell_results[cell_results.detection_task == c]
|
||||
roc_data = condition_results[condition_results.kernel_width == kernel_width]
|
||||
contrasts = roc_data.contrast.unique()
|
||||
|
||||
roc_ax = plt.subplot2grid(fig_grid, (i * 2 + i, 0), colspan=3, rowspan=2)
|
||||
roc_ax.set_title(c, fontsize=9, loc="left")
|
||||
auc_ax = plt.subplot2grid(fig_grid, (i * 2 + i, 4), colspan=3, rowspan=2)
|
||||
|
||||
for c in contrasts:
|
||||
tpr = roc_data.true_positives[roc_data.contrast == c].values[0]
|
||||
fpr = roc_data.false_positives[roc_data.contrast == c].values[0]
|
||||
roc_ax.plot(fpr, tpr, label="%.3f" % c, zorder=2)
|
||||
if i == 0:
|
||||
roc_ax.legend(loc="lower right", fontsize=6, ncol=2, frameon=False, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25)
|
||||
roc_ax.plot([0., 1.],[0., 1.], color="k", lw=0.5, ls="--", zorder=0)
|
||||
roc_ax.set_xticks(np.arange(0.0, 1.01, 0.5))
|
||||
roc_ax.set_xticks(np.arange(0.0, 1.01, 0.25), minor=True)
|
||||
roc_ax.set_yticks(np.arange(0.0, 1.01, 0.5))
|
||||
roc_ax.set_yticks(np.arange(0.0, 1.01, 0.25), minor=True)
|
||||
if i == len(conditions) - 1:
|
||||
roc_ax.set_xticklabels(np.arange(0.0, 1.01, 0.5), fontsize=8)
|
||||
roc_ax.set_xlabel("false positive rate", fontsize=9)
|
||||
roc_ax.set_ylabel("true positive rate", fontsize=9)
|
||||
roc_ax.set_yticklabels(np.arange(0.0, 1.01, 0.5), fontsize=8)
|
||||
|
||||
for k in kernels:
|
||||
contrasts = np.asarray(condition_results.contrast[condition_results.kernel_width == k])
|
||||
aucs = np.asarray(condition_results.auc[condition_results.kernel_width == k])
|
||||
aucs_sorted = aucs[np.argsort(contrasts)]
|
||||
contrasts_sorted = np.sort(contrasts)
|
||||
auc_ax.plot(contrasts_sorted, aucs_sorted, marker=".", label=r"$\sigma$: %.2f ms" % (k * 1000), zorder=1)
|
||||
if i == len(conditions) - 1:
|
||||
auc_ax.set_xlabel("contrast [%]", fontsize=9)
|
||||
auc_ax.set_ylim([0.25, 1.0])
|
||||
auc_ax.set_yticks(np.arange(0.25, 1.01, 0.25))
|
||||
auc_ax.set_yticklabels(np.arange(0.25, 1.01, 0.25), fontsize=8)
|
||||
auc_ax.set_ylabel("discriminability", fontsize=9)
|
||||
if i == 0:
|
||||
auc_ax.legend(ncol=2, fontsize=6, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25, frameon=False, loc="lower center")
|
||||
auc_ax.plot([min(contrasts), max(contrasts)], [0.5, 0.5], lw=0.5, ls="--", zorder=0)
|
||||
name = figure_name if figure_name is not None else "foreign_fish_detection.pdf"
|
||||
name = (name + ".pdf") if ".pdf" not in name else name
|
||||
fig.savefig(os.path.join(figure_folder, name))
|
||||
|
||||
|
||||
def foreign_fish_detection(block_map, all_dfs, all_contrasts, all_conditions, current_df=None, cell_name="", store_roc=False):
|
||||
|
Loading…
Reference in New Issue
Block a user