# chirp_probing/plots.py
# 2020-09-27 17:16:24 +02:00
#
# 361 lines
# 18 KiB
# Python

import glob
import os
import argparse
import nixio as nix
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection
from matplotlib.patches import ConnectionPatch
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D # noqa: F401 unused import
from IPython import embed
from nix_util import sort_blocks, read_baseline, get_signals
from util import despine, extract_am
from response_discriminability import get_firing_rate, foreign_fish_detection
# Relative directories used throughout this module: rendered figures are
# written below the first, input data (*.nix, *.h5, *.csv) is read from the second.
figure_folder, data_folder = "figures", "data"
def _add_window_boxes(ax, x_starts):
    """Draw dashed, translucent boxes marking analysis windows on ``ax``.

    One box is drawn per entry of ``x_starts`` (data coordinates of the time
    axis). Each box is 0.098 wide and spans 740 to 880 on the y axis.
    """
    rects = [Rectangle((x, 740), 0.098, 140) for x in x_starts]
    pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
    ax.add_collection(pc)


def plot_comparisons(args):
    """Create the didactic figure illustrating which response windows are compared.

    Opens the first ``*.nix`` file (alphabetically) in ``data_folder``. For
    the conditions "no-other", "self" and "other", it plots the EOD frequency
    traces of the self fish and, if present, the other fish between 0.5 and
    1.0 on the time axis. Dashed boxes mark the analysis windows, and arrows
    connect the windows that are compared.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``deltaf`` (the difference frequency, as defined by the
        "comparisons" sub-command) and ``outfile`` (path of the figure). If
        ``deltaf`` is missing, a ``current_df`` attribute is used instead.

    If no data file is found, a message is printed and nothing is plotted.
    """
    files = sorted(glob.glob(os.path.join(data_folder, "*.nix")))
    if len(files) < 1:
        print("plot comparisons: no data found!")
        return
    # The "comparisons" sub-command stores the difference frequency in
    # ``args.deltaf``. Reading ``args.current_df`` unconditionally always raised
    # an AttributeError, so it is only kept as a fallback.
    current_df = getattr(args, "deltaf", None)
    if current_df is None:
        current_df = args.current_df
    filename = files[0]
    conditions = ["no-other", "self", "other"]
    min_time = 0.5
    max_time = min_time + 0.5
    nf = nix.File.open(filename, nix.FileMode.ReadOnly)
    try:
        block_map, all_contrasts, _, all_conditions = sort_blocks(nf)
        fig = plt.figure(figsize=(6.5, 2.))
        fig_grid = (3, len(all_conditions) * 3 + 2)
        axes = []
        for i, condition in enumerate(conditions):
            block = block_map[(all_contrasts[0], current_df, condition)]
            _, self_freq, other_freq, time = get_signals(block)
            self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"]
            other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"]
            # only samples inside the displayed time window are drawn
            window = (time > min_time) & (time < max_time)
            # frequency traces; each panel is 3 grid columns wide plus 1 spacer column
            ax = plt.subplot2grid(fig_grid, (0, 4 * i), rowspan=2, colspan=3, fig=fig)
            ax.plot(time[window], self_freq[window], color="#ff7f0e", label="%iHz" % self_eodf)
            ax.text(min_time - 0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e",
                    va="center", ha="right", fontsize=9)
            if other_freq is not None:
                ax.plot(time[window], other_freq[window], color="#1f77b4", label="%iHz" % other_eodf)
                ax.text(min_time - 0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4",
                        va="center", ha="right", fontsize=9)
            ax.set_ylim([735, 885])
            despine(ax, ["top", "bottom", "left", "right"], True)
            axes.append(ax)

        # dashed boxes marking the windows that get compared
        _add_window_boxes(axes[0], [0.675, 0.57])
        _add_window_boxes(axes[1], [0.675, 0.575])
        _add_window_boxes(axes[2], [0.57])

        # arrows connecting the compared windows across panels
        con = ConnectionPatch(xyA=(0.625, 735), xyB=(0.625, 740), coordsA="data", coordsB="data",
                              axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5,
                              connectionstyle="arc3,rad=.35")
        axes[1].add_artist(con)
        con = ConnectionPatch(xyA=(0.725, 885), xyB=(0.725, 880), coordsA="data", coordsB="data",
                              axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5,
                              connectionstyle="arc3,rad=-.25")
        axes[1].add_artist(con)
        con = ConnectionPatch(xyA=(0.725, 735), xyB=(0.625, 740), coordsA="data", coordsB="data",
                              axesA=axes[1], axesB=axes[2], arrowstyle="<->", shrinkB=5,
                              connectionstyle="arc3,rad=.35")
        axes[1].add_artist(con)
        # numbering of the comparisons, placed outside the y limits on purpose
        axes[0].text(1., 660, "2.")
        axes[1].text(1.05, 660, "3.")
        axes[0].text(1.1, 890, "1.")
        fig.subplots_adjust(bottom=0.1, top=0.8, left=0.1, right=0.9)
        fig.savefig(args.outfile)
        plt.close(fig)
    finally:
        # release the NIX file even if reading or plotting fails
        nf.close()
def create_response_plot(filename, current_df=20, figure_name=None):
    """Plot stimulus and firing-rate responses of one cell for all conditions.

    One column is drawn per condition ("no-other", "self", "other"). Each
    column shows, for the time window 0.5 to 1.0:

    * the EOD frequency traces of the self fish and, if present, the other fish,
    * the stimulus signal together with its amplitude modulation
      (``extract_am``),
    * for every contrast (largest first), the mean +/- std of the rates
      returned by ``get_firing_rate``. Both are taken over axis 0, presumably
      the trials.

    Parameters
    ----------
    filename : str
        Path of the NIX file holding the recordings of the cell.
    current_df : int, optional
        Difference frequency of the blocks to plot, 20 by default.
    figure_name : str or path-like
        Output file, passed unchanged to ``savefig``. The None default is not
        replaced by a fallback, so callers are expected to supply a path.
    """
    nf = nix.File.open(filename, nix.FileMode.ReadOnly)
    try:
        block_map, all_contrasts, _, all_conditions = sort_blocks(nf)
        conditions = ["no-other", "self", "other"]
        condition_labels = ["soliloquy", "self chirping", "other chirping"]
        min_time = 0.5
        max_time = min_time + 0.5
        fig = plt.figure(figsize=(6.5, 5.5))
        fig_grid = (len(all_contrasts) * 2 + 6, len(all_conditions) * 3 + 2)
        all_contrasts = sorted(all_contrasts, reverse=True)
        # tick positions are shared by all rate panels; computed once
        major_ticks = np.arange(min_time, max_time + .01, 0.250)
        # labels in ms relative to the start of the displayed window
        major_labels = [int(tick) for tick in (major_ticks - min_time) * 1000]
        minor_ticks = np.arange(min_time, max_time + .01, 0.125)
        for i, condition in enumerate(conditions):
            # each column is 3 grid columns wide plus 1 spacer column
            col = 4 * i
            block = block_map[(all_contrasts[0], current_df, condition)]
            signal, self_freq, other_freq, time = get_signals(block)
            am = extract_am(signal)
            self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"]
            other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"]
            window = (time > min_time) & (time < max_time)
            t_win = time[window]

            # frequency traces
            ax = plt.subplot2grid(fig_grid, (0, col), rowspan=2, colspan=3, fig=fig)
            ax.plot(t_win, self_freq[window], color="#ff7f0e", label="%iHz" % self_eodf)
            ax.text(min_time - 0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e",
                    va="center", ha="right", fontsize=9)
            if other_freq is not None:
                ax.plot(t_win, other_freq[window], color="#1f77b4", label="%iHz" % other_eodf)
                ax.text(min_time - 0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4",
                        va="center", ha="right", fontsize=9)
            ax.set_title(condition_labels[i])
            despine(ax, ["top", "bottom", "left", "right"], True)

            # stimulus signal and its amplitude modulation
            ax = plt.subplot2grid(fig_grid, (3, col), rowspan=2, colspan=3, fig=fig)
            ax.plot(t_win, signal[window], color="#2ca02c", label="signal")
            ax.plot(t_win, am[window], color="#d62728", label="am")
            despine(ax, ["top", "bottom", "left", "right"], True)
            ax.set_ylim([-1.25, 1.25])
            ax.legend(ncol=2, loc=(0.01, -0.5), fontsize=7, markerscale=0.5, frameon=False)

            # firing rate for each contrast
            for j, contrast in enumerate(all_contrasts):
                t, rates, _ = get_firing_rate(block_map, current_df, contrast, condition)
                avg_resp = np.mean(rates, axis=0)
                error = np.std(rates, axis=0)
                rate_window = (t > min_time) & (t < max_time)
                ax = plt.subplot2grid(fig_grid, (j * 2 + 6, col), rowspan=2, colspan=3, fig=fig)
                ax.plot(t[rate_window], avg_resp[rate_window], color="k", lw=0.5)
                ax.fill_between(t[rate_window], (avg_resp - error)[rate_window],
                                (avg_resp + error)[rate_window], color="k", lw=0.0, alpha=0.25)
                ax.set_ylim([0, 750])
                ax.set_xlabel("")
                ax.set_ylabel("")
                ax.set_xticks(major_ticks)
                ax.set_xticklabels(major_labels)
                ax.set_xticks(minor_ticks, minor=True)
                if j < len(all_contrasts) - 1:
                    # only the bottom row keeps its time tick labels
                    ax.set_xticklabels([])
                ax.set_yticks(np.arange(0.0, 751., 500))
                ax.set_yticks(np.arange(0.0, 751., 125), minor=True)
                if i > 0:
                    ax.set_yticklabels([])
                despine(ax, ["top", "right"], False)
                if i == len(conditions) - 1:
                    # contrast annotation to the right of the last column
                    ax.text(max_time + 0.025 * max_time, 350, "c=%.3f" % contrast,
                            color="#d62728", ha="left", fontsize=7)
            # axis labels are attached to the bottom-most rate panel of a column
            if i == 1:
                ax.set_xlabel("time [ms]")
            if i == 0:
                ax.set_ylabel("frequency [Hz]", va="center")
                ax.yaxis.set_label_coords(-0.45, 3.5)
        fig.savefig(figure_name)
        plt.close(fig)
    finally:
        # release the NIX file even if reading or plotting fails
        nf.close()
def response_examples(args):
    """Plot the example responses of the cell selected on the command line.

    ``args.cell`` is treated as a filename prefix. The alphabetically first
    matching file is handed to ``create_response_plot``, together with
    ``args.deltaf`` and ``args.outfile``.

    Raises
    ------
    ValueError
        If no file name starts with ``args.cell``.
    """
    candidates = sorted(glob.glob(args.cell + "*"))
    if not candidates:
        raise ValueError("Cell data with name %s not found" % args.cell)
    first_match, *_ = candidates
    create_response_plot(first_match, args.deltaf, figure_name=args.outfile)
def plot_detection_results(data_frame, df, kernel_width, cell=None, figure_name=None):
    """Plot ROC curves and discriminability of one cell for each detection task.

    For every detection task of ``cell`` at difference frequency ``df``, one
    row of two panels is drawn:

    * left: the ROC curves (one per contrast) for ``kernel_width``,
    * right: the AUC ("discriminability") as a function of contrast, one line
      per kernel width.

    Parameters
    ----------
    data_frame : pandas.DataFrame
        Discrimination results. The columns read here are ``cell``, ``df``,
        ``detection_task``, ``kernel_width``, ``contrast``, ``auc``,
        ``true_positives`` and ``false_positives``.
    df : number
        Difference frequency to select.
    kernel_width : float
        Kernel width whose ROC curves are shown.
    cell : optional
        Cell to plot. Defaults to the first cell in ``data_frame``.
    figure_name : str or path-like
        Output file, passed unchanged to ``savefig``.

    Raises
    ------
    ValueError
        If ``df`` is not present in ``data_frame`` or ``kernel_width`` is not
        present for the selected cell.
    """
    if cell is None:
        cell = data_frame.cell.unique()[0]
    dfs = np.sort(data_frame.df.unique())
    if df not in dfs:
        raise ValueError("requested deltaf not present, valid choices are: " + str(dfs))
    cell_results = data_frame[(data_frame.cell == cell) & (data_frame.df == df)]
    conditions = sorted(cell_results.detection_task.unique())
    kernels = sorted(cell_results.kernel_width.unique())
    if kernel_width not in kernels:
        raise ValueError("requested kernel not present, valid choices are: " + str(kernels))

    fig = plt.figure(figsize=(6.5, 5.5))
    # 3 grid rows per task (2 for the panels, 1 spacer); the trailing spacer is
    # dropped. This gives (8, 7) for three tasks, as before, but supports any number.
    fig_grid = (max(len(conditions) * 3 - 1, 1), 7)
    last_row = len(conditions) - 1
    for i, task in enumerate(conditions):
        task_results = cell_results[cell_results.detection_task == task]
        roc_data = task_results[task_results.kernel_width == kernel_width]
        roc_ax = plt.subplot2grid(fig_grid, (i * 3, 0), colspan=3, rowspan=2, fig=fig)
        roc_ax.set_title(task, fontsize=9, loc="left")
        auc_ax = plt.subplot2grid(fig_grid, (i * 3, 4), colspan=3, rowspan=2, fig=fig)

        # ROC curve per contrast (loop variable no longer shadows the task name)
        for contrast in roc_data.contrast.unique():
            selected = roc_data[roc_data.contrast == contrast]
            tpr = selected.true_positives.values[0]
            fpr = selected.false_positives.values[0]
            roc_ax.plot(fpr, tpr, label="%.3f" % contrast, zorder=2)
        if i == 0:
            roc_ax.legend(loc="lower right", fontsize=6, ncol=2, frameon=False, handletextpad=0.4,
                          columnspacing=1.0, labelspacing=0.25)
        # chance diagonal
        roc_ax.plot([0., 1.], [0., 1.], color="k", lw=0.5, ls="--", zorder=0)
        roc_ax.set_xticks(np.arange(0.0, 1.01, 0.5))
        roc_ax.set_xticks(np.arange(0.0, 1.01, 0.25), minor=True)
        roc_ax.set_yticks(np.arange(0.0, 1.01, 0.5))
        roc_ax.set_yticks(np.arange(0.0, 1.01, 0.25), minor=True)
        if i == last_row:
            roc_ax.set_xticklabels(np.arange(0.0, 1.01, 0.5), fontsize=8)
            roc_ax.set_xlabel("false positive rate", fontsize=9)
        roc_ax.set_ylabel("true positive rate", fontsize=9)
        roc_ax.set_yticklabels(np.arange(0.0, 1.01, 0.5), fontsize=8)

        # AUC over contrast, one line per kernel width
        for k in kernels:
            kernel_results = task_results[task_results.kernel_width == k]
            contrasts = np.asarray(kernel_results.contrast)
            aucs = np.asarray(kernel_results.auc)
            order = np.argsort(contrasts)
            auc_ax.plot(contrasts[order], aucs[order], marker=".",
                        label=r"$\sigma$: %.2f ms" % (k * 1000), zorder=1)
        if i == last_row:
            auc_ax.set_xlabel("contrast [%]", fontsize=9)
        auc_ax.set_ylim([0.25, 1.0])
        auc_ax.set_yticks(np.arange(0.25, 1.01, 0.25))
        auc_ax.set_yticklabels(np.arange(0.25, 1.01, 0.25), fontsize=8)
        auc_ax.set_ylabel("discriminability", fontsize=9)
        if i == 0:
            auc_ax.legend(ncol=2, fontsize=6, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25,
                          frameon=False, loc="lower center")
        # Chance level, spanning all contrasts of the task. Before, the extent came
        # from the last kernel only, which fails when that kernel has no rows.
        auc_ax.plot([task_results.contrast.min(), task_results.contrast.max()], [0.5, 0.5],
                    color="k", lw=0.5, ls="--", zorder=0)
    fig.savefig(figure_name)
    # close the figure so repeated calls do not accumulate open figures
    plt.close(fig)
def foreign_fish_detection_example_plot(args):
    """Plot the ROC / discriminability figure of an example cell.

    Loads the "discrimination_results" table from the alphabetically first
    ``*discriminations.h5`` file in ``data_folder``. The table is passed to
    ``plot_detection_results`` together with ``args.deltaf``,
    ``args.kernel_width`` and ``args.outfile``.

    Raises
    ------
    ValueError
        If no discrimination result file is found.
    """
    # sorted so the chosen file does not depend on file-system listing order
    files = sorted(glob.glob(os.path.join(data_folder, "*discriminations.h5")))
    if len(files) == 0:
        raise ValueError("no discrimination results found!")
    # Open read-only and release the HDF5 handle before plotting. The store is
    # then closed even if plotting raises; previously it leaked in that case.
    with pd.HDFStore(files[0], mode="r") as store:
        data_frame = store.get("discrimination_results")
    plot_detection_results(data_frame, args.deltaf, args.kernel_width, figure_name=args.outfile)
def performance_plot(args):
    """Plot discrimination performance (AUC) across all cells.

    Reads ``discrimination_results.csv`` (``;``-separated) from
    ``data_folder``. For every detection task, one row of panels is drawn:

    * a 3-D surface of the mean AUC over difference frequency and contrast,
    * mean +/- std AUC as a function of contrast, one line per ``args.deltafs``,
    * mean +/- std AUC as a function of difference frequency, one line per
      ``args.contrasts``.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``kernel_width`` (the first kernel width in the table is
        used if the requested one is absent), ``deltafs``, ``contrasts`` and
        ``outfile``.
    """
    results = pd.read_csv(os.path.join(data_folder, "discrimination_results.csv"), sep=";")
    dfs = np.sort(results.df.unique())
    contrasts = np.sort(results.contrast.unique())
    tasks = results.detection_task.unique()
    kernel_widths = list(results.kernel_width.unique())
    kernel_width = args.kernel_width if args.kernel_width in kernel_widths else kernel_widths[0]
    X, Y = np.meshgrid(dfs, contrasts)
    fig = plt.figure(figsize=(8.0, 10))
    # 7 grid rows per task (5 for the surface plus 2 spacers); the trailing
    # spacers are dropped. This gives (19, 10) for three tasks, as before.
    fig_grid = (max(len(tasks) * 7 - 2, 1), 10)

    def select_aucs(task_results, contrast, deltaf):
        """Return the AUC values of ``task_results`` for one contrast and deltaf."""
        return task_results.auc[(task_results.contrast == contrast) & (task_results.df == deltaf)]

    for index, task in enumerate(tasks):
        top = index * 7
        # filter by kernel width and task once, instead of once per grid cell
        task_results = results[(results.kernel_width == kernel_width) & (results.detection_task == task)]

        ax = plt.subplot2grid(fig_grid, (top, 0), colspan=4, rowspan=5, projection="3d")
        ax.set_title(task, loc="left", pad=-0.5)
        # Explicitly float. np.zeros_like(X) inherited the dtype of the
        # difference frequencies, so with integer deltafs every AUC was
        # truncated to 0 or 1.
        Z = np.zeros(X.shape, dtype=float)
        for i, d in enumerate(dfs):
            for j, c in enumerate(contrasts):
                Z[j, i] = np.mean(select_aucs(task_results, c, d))
        ax.plot_surface(X, Y, Z, cmap=cm.coolwarm, linewidth=0.2, edgecolor="white",
                        antialiased=True, alpha=0.85, vmin=0.5, vmax=1.0)
        ax.set_xlabel(r"$\Delta_f [Hz]$")
        ax.set_ylabel("contrast [%]")
        ax.set_zlabel("performance")
        ax.set_zlim([0.45, 1.0])
        ax.set_zticks(np.arange(0.5, 1.01, 0.25))
        ax.set_zticks(np.arange(0.5, 1.01, 0.125), minor=True)

        # performance vs. contrast, one line per requested deltaf
        cntrst_ax = plt.subplot2grid(fig_grid, (top, 6), colspan=4, rowspan=2)
        for d in args.deltafs:
            selections = [select_aucs(task_results, c, d) for c in contrasts]
            # float arrays for the same reason as Z: zeros_like(contrasts) truncated
            performances = np.array([np.mean(s) for s in selections], dtype=float)
            errors = np.array([np.std(s) for s in selections], dtype=float)
            cntrst_ax.errorbar(contrasts, performances, yerr=errors, fmt=".-", label=r"$\Delta_f:$ %i Hz" % d)
        cntrst_ax.set_ylim([0.25, 1.0])
        cntrst_ax.set_ylabel("performance", fontsize=8)
        cntrst_ax.set_xlabel("contrast [%]", fontsize=8)
        cntrst_ax.hlines(0.5, contrasts[0], contrasts[-1], color="k", ls="--", lw=0.2)
        cntrst_ax.legend(fontsize=7, ncol=4, frameon=False, loc="lower center", mode="expand",
                         handlelength=1.0, handletextpad=0.25)

        # performance vs. deltaf, one line per requested contrast
        df_ax = plt.subplot2grid(fig_grid, (top + 3, 6), colspan=4, rowspan=2)
        for c in args.contrasts:
            selections = [select_aucs(task_results, c, d) for d in dfs]
            performances = np.array([np.mean(s) for s in selections], dtype=float)
            errors = np.array([np.std(s) for s in selections], dtype=float)
            df_ax.errorbar(dfs, performances, yerr=errors, fmt=".-", label="%.2f" % c)
        df_ax.set_ylim([0.25, 1.0])
        df_ax.set_ylabel("performance", fontsize=8)
        df_ax.set_xlabel(r"$\Delta_f$ [Hz]", fontsize=8)
        df_ax.hlines(0.5, dfs[0], dfs[-1], color="k", ls="--", lw=0.2)
        df_ax.legend(fontsize=7, ncol=4, frameon=False, loc="lower center", mode="expand",
                     handlelength=1.0, handletextpad=0.25)
    fig.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.975)
    fig.savefig(args.outfile)
    plt.close(fig)
def main():
    """Parse the command line and run the selected plotting sub-command.

    Sub-commands: ``comparisons``, ``roc``, ``discrimination`` and
    ``responses``. Each sub-parser registers its plotting function through
    ``set_defaults(func=...)``. If no sub-command is given, the help text is
    printed and the program exits with status 1.
    """
    parser = argparse.ArgumentParser(description="Plotting tool for chirp probing project.")
    subparsers = parser.add_subparsers(title="commands",
                                       help="Sub commands for plotting different figures",
                                       description="", dest="explore_cmd")

    comp_parser = subparsers.add_parser("comparisons", help="Create a didactic plot illustrating the comparisons")
    comp_parser.add_argument("-df", "--deltaf", type=int, default=20,
                             help="The difference frequency to be used for plotting")
    comp_parser.add_argument("-o", "--outfile", default=os.path.join(figure_folder, "comparisons.pdf"),
                             help="filename of the plot")
    comp_parser.set_defaults(func=plot_comparisons)

    roc_parser = subparsers.add_parser("roc", help="plot roc analysis of example cell")
    roc_parser.add_argument("-o", "--outfile", default=os.path.join(figure_folder, "roc_analysis.pdf"),
                            help="filename of the plot")
    roc_parser.add_argument("-d", "--deltaf", type=int, default=20, help="deltaf for individual plot")
    roc_parser.add_argument("-k", "--kernel_width", type=float, default=0.001,
                            help="Kernel width to choose for plotting, defaults to 0.001s")
    roc_parser.set_defaults(func=foreign_fish_detection_example_plot)

    perf_parser = subparsers.add_parser("discrimination", help="plot discrimination performance across all cells")
    perf_parser.add_argument("-o", "--outfile", default=os.path.join(figure_folder, "discrimination_performances.pdf"),
                             help="filename of the plot")
    perf_parser.add_argument("-k", "--kernel_width", type=float, default=0.001,
                             help="Kernel width to choose for plotting")
    perf_parser.add_argument("-d", "--deltafs", type=float, nargs="+", default=[-100, 20, 100],
                             help="deltaf for individual plot")
    perf_parser.add_argument("-c", "--contrasts", type=float, nargs="+", default=[5, 10, 20],
                             help="stimulus contrast for individual plot")
    perf_parser.set_defaults(func=performance_plot)

    resps_parser = subparsers.add_parser("responses", help="plot responses from an example cell")
    resps_parser.add_argument("-o", "--outfile", default=os.path.join(figure_folder, "response_example.pdf"),
                              help="filename of the plot")
    resps_parser.add_argument("-d", "--deltaf", type=int, default=20, help="deltaf for individual plot")
    dflt_cell = os.path.join(data_folder, "cell_2010-11-08-al")
    resps_parser.add_argument("-c", "--cell", type=str, default=dflt_cell,
                              help="cell name, defaults to %s" % dflt_cell)
    resps_parser.set_defaults(func=response_examples)

    args = parser.parse_args()
    if getattr(args, "func", None) is None:
        # Without a sub-command argparse never sets ``func``; calling it raised
        # an AttributeError. Show the usage instead.
        parser.print_help()
        parser.exit(1)
    args.func(args)


if __name__ == "__main__":
    main()