# chirp_probing/plots.py
# 2023-04-01 18:16:13 +02:00
#
# 404 lines
# 21 KiB
# Python

import glob
import os
import argparse
import nixio as nix
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.collections import PatchCollection
from matplotlib.patches import ConnectionPatch
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D # noqa: F401 unused import
from IPython import embed
from nix_util import sort_blocks, read_baseline, get_signals
from util import despine, extract_am
from response_discriminability import get_firing_rate, foreign_fish_detection
# Default output folder for figures (used for the default --outfile values in main)
# and the folder searched for input data (.nix recordings, discrimination results).
figure_folder, data_folder = "figures", "data"
def plot_comparisons(args):
    """Create a didactic figure illustrating which response comparisons are made.

    Uses the first ``*.nix`` file (sorted by name) in ``data_folder``. For the
    "no-other", "self" and "other" conditions, the EOD frequency traces of the
    "self" fish (and of the "other" fish, if present) are drawn side by side in
    the window 0.5 s - 1.0 s. The analysis windows are marked with dashed boxes
    labelled a) - e). Double-headed arrows connect the windows that are
    compared, and the arrows are annotated with the numbers 1. - 6.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``deltaf`` and ``chirpsize`` (used to select the blocks)
        and ``outfile`` (file the figure is saved to).

    If no ``*.nix`` file is found, a message is printed and nothing is plotted.
    """
    files = sorted(glob.glob(os.path.join(data_folder, "*.nix")))
    if len(files) < 1:
        print("plot comparisons: no data found!")
        return
    filename = files[0]
    nf = nix.File.open(filename, nix.FileMode.ReadOnly)
    try:  # make sure the nix file is closed even if plotting fails
        block_map, all_contrasts, _, _, all_conditions = sort_blocks(nf)
        conditions = ["no-other", "self", "other"]
        min_time = 0.5
        max_time = min_time + 0.5
        fig = plt.figure(figsize=(6.5, 2.))
        fig_grid = (3, len(all_conditions)*3+2)
        axes = []
        for i, condition in enumerate(conditions):
            # plot the signals
            block = block_map[(all_contrasts[0], args.deltaf, args.chirpsize, condition)]
            _, self_freq, other_freq, time = get_signals(block)
            self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"]
            other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"]
            window = (time > min_time) & (time < max_time)
            # plot frequency traces; each condition occupies 3 grid columns plus a 1-column gap
            ax = plt.subplot2grid(fig_grid, (0, i * 3 + i), rowspan=2, colspan=3, fig=fig)
            ax.plot(time[window], self_freq[window], color="#ff7f0e", label="%iHz" % self_eodf)
            ax.text(min_time-0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e", va="center", ha="right", fontsize=9)
            if other_freq is not None:
                ax.plot(time[window], other_freq[window], color="#1f77b4", label="%iHz" % other_eodf)
                ax.text(min_time-0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4", va="center", ha="right", fontsize=9)
            ax.set_ylim([735, 895])
            despine(ax, ["top", "bottom", "left", "right"], True)
            axes.append(ax)

        def add_windows(ax, x_starts, labels):
            """Draw dashed analysis-window boxes starting at x_starts and place their text labels."""
            rects = [Rectangle((x, 740), 0.098, 150) for x in x_starts]
            pc = PatchCollection(rects, facecolor=None, alpha=0.15, edgecolor="k", ls="--")
            ax.add_collection(pc)
            for x, text in labels:
                ax.text(x, 860, text, ha="center", fontsize=7)

        add_windows(axes[0], [0.675, 0.57], [(0.625, "a)"), (0.724, "b)")])
        add_windows(axes[1], [0.675, 0.57], [(0.625, "c)"), (0.724, "d)")])
        add_windows(axes[2], [0.57], [(0.625, "e)")])

        # (xyA, xyB, index of axesA, index of axesB, connection style) of each comparison arrow
        arrows = [
            ((0.625, 735), (0.625, 740), 0, 1, "arc3,rad=.35"),
            ((0.725, 895), (0.725, 890), 0, 1, "arc3,rad=-.25"),
            ((0.725, 735), (0.625, 740), 1, 2, "arc3,rad=.35"),
            ((0.625, 895), (0.725, 890), 0, 1, "arc3,rad=-.35"),
            ((0.615, 735), (0.735, 745), 1, 1, "arc3,rad=1."),
            ((0.625, 895), (0.625, 895), 1, 2, "arc3,rad=-.35"),
        ]
        for xy_a, xy_b, idx_a, idx_b, style in arrows:
            con = ConnectionPatch(xyA=xy_a, xyB=xy_b, coordsA="data", coordsB="data",
                                  axesA=axes[idx_a], axesB=axes[idx_b], arrowstyle="<->", shrinkB=5,
                                  connectionstyle=style)
            # all arrows are attached to the middle axes so they are drawn on top of it
            axes[1].add_artist(con)

        # comparison numbers placed next to the arrows
        axes[0].text(1., 655, "2.")
        axes[1].text(1.05, 655, "3.")
        axes[0].text(1.1, 895, "1.")
        axes[0].text(0.6, 925, "4.")
        axes[1].text(0.675, 680, "5.")
        axes[1].text(1.05, 895, "6.")
        fig.subplots_adjust(bottom=0.1, top=0.8, left=0.1, right=0.9)
        fig.savefig(args.outfile)
        plt.close(fig)
    finally:
        nf.close()
def create_response_plot(filename, current_df=20, figure_name=None):
    """Plot stimulus and firing-rate responses of one cell for the three chirp conditions.

    One column is drawn per condition ("no-other", "self", "other"), each restricted
    to the window 0.5 s - 1.0 s:

    * top: EOD frequency traces of the "self" fish and, if present, of the "other" fish,
      taken from the block with the highest contrast;
    * middle: the stimulus signal and its amplitude modulation (``extract_am``);
    * bottom: one panel per contrast (highest on top) with the mean +- std of the
      firing rates returned by ``get_firing_rate`` (computed over axis 0,
      presumably trials - NOTE(review): confirm).

    Parameters
    ----------
    filename : str
        Path of the nix file with the cell's recordings.
    current_df : int
        Difference frequency used to select the blocks. Defaults to 20.
    figure_name : str
        File name the figure is saved to (passed to ``Figure.savefig``).

    NOTE(review): ``block_map`` is indexed here with (contrast, deltaf, condition),
    while ``plot_comparisons`` also includes the chirp size - confirm the key layout
    produced by ``sort_blocks``.
    """
    nf = nix.File.open(filename, nix.FileMode.ReadOnly)
    try:  # make sure the nix file is closed even if plotting fails
        block_map, all_contrasts, _, _, all_conditions = sort_blocks(nf)
        conditions = ["no-other", "self", "other"]
        condition_labels = ["soliloquy", "self chirping", "other chirping"]
        min_time = 0.5
        max_time = min_time + 0.5
        fig = plt.figure(figsize=(6.5, 5.5))
        fig_grid = (len(all_contrasts)*2 + 6, len(all_conditions)*3+2)
        all_contrasts = sorted(all_contrasts, reverse=True)
        last_row = len(all_contrasts) - 1
        # major x-ticks every 250 ms, labelled in ms relative to min_time
        major_ticks = np.arange(min_time, max_time + .01, 0.250)
        major_labels = [int(v) for v in (major_ticks - min_time) * 1000]
        minor_ticks = np.arange(min_time, max_time + .01, 0.125)
        for i, condition in enumerate(conditions):
            column = i * 3 + i  # 3 grid columns per condition plus a 1-column gap
            # plot the signals
            block = block_map[(all_contrasts[0], current_df, condition)]
            signal, self_freq, other_freq, time = get_signals(block)
            am = extract_am(signal)
            self_eodf = block.metadata["stimulus parameter"]["eodfs"]["self"]
            other_eodf = block.metadata["stimulus parameter"]["eodfs"]["other"]
            window = (time > min_time) & (time < max_time)
            # plot frequency traces
            ax = plt.subplot2grid(fig_grid, (0, column), rowspan=2, colspan=3, fig=fig)
            ax.plot(time[window], self_freq[window], color="#ff7f0e", label="%iHz" % self_eodf)
            ax.text(min_time-0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e", va="center", ha="right", fontsize=9)
            if other_freq is not None:
                ax.plot(time[window], other_freq[window], color="#1f77b4", label="%iHz" % other_eodf)
                ax.text(min_time-0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4", va="center", ha="right", fontsize=9)
            ax.set_title(condition_labels[i])
            despine(ax, ["top", "bottom", "left", "right"], True)
            # plot the stimulus and its am
            ax = plt.subplot2grid(fig_grid, (3, column), rowspan=2, colspan=3, fig=fig)
            ax.plot(time[window], signal[window], color="#2ca02c", label="signal")
            ax.plot(time[window], am[window], color="#d62728", label="am")
            despine(ax, ["top", "bottom", "left", "right"], True)
            ax.set_ylim([-1.25, 1.25])
            ax.legend(ncol=2, loc=(0.01, -0.5), fontsize=7, markerscale=0.5, frameon=False)
            # for each contrast plot the firing rate
            for j, contrast in enumerate(all_contrasts):
                t, rates, _ = get_firing_rate(block_map, current_df, contrast, condition)
                avg_resp = np.mean(rates, axis=0)
                error = np.std(rates, axis=0)
                rate_window = (t > min_time) & (t < max_time)
                ax = plt.subplot2grid(fig_grid, (j*2 + 6, column), rowspan=2, colspan=3, fig=fig)
                ax.plot(t[rate_window], avg_resp[rate_window], color="k", lw=0.5)
                ax.fill_between(t[rate_window], (avg_resp - error)[rate_window],
                                (avg_resp + error)[rate_window], color="k", lw=0.0, alpha=0.25)
                ax.set_ylim([0, 750])
                ax.set_xticks(major_ticks)
                ax.set_xticklabels(major_labels)
                ax.set_xticks(minor_ticks, minor=True)
                if j < last_row:
                    ax.set_xticklabels([])
                ax.set_yticks(np.arange(0.0, 751., 500))
                ax.set_yticks(np.arange(0.0, 751., 125), minor=True)
                if i > 0:
                    ax.set_yticklabels([])
                despine(ax, ["top", "right"], False)
                if i == 2:
                    ax.text(max_time + 0.025*max_time, 350, "c=%.3f" % contrast,
                            color="#d62728", ha="left", fontsize=7)
                # the time axis is labelled once, below the bottom row (the only row with tick labels)
                if i == 1 and j == last_row:
                    ax.set_xlabel("time [ms]")
                # label drawn once: on the top rate panel, 3.5 axes heights up lands it
                # next to the frequency-trace panel; on lower rows it would duplicate elsewhere
                if i == 0 and j == 0:
                    ax.set_ylabel("frequency [Hz]", va="center")
                    ax.yaxis.set_label_coords(-0.45, 3.5)
        fig.savefig(figure_name)
        plt.close(fig)
    finally:
        nf.close()
def response_examples(args):
    """Plot example responses for the first recording whose path starts with ``args.cell``.

    Matching paths are sorted and the first is handed to ``create_response_plot``
    together with ``args.deltaf``; the figure is written to ``args.outfile``.

    Raises
    ------
    ValueError
        If no path starts with ``args.cell``.
    """
    matches = sorted(glob.glob(args.cell + "*"))
    if not matches:
        raise ValueError("Cell data with name %s not found" % args.cell)
    create_response_plot(matches[0], args.deltaf, figure_name=args.outfile)
def plot_detection_results(data_frame, df, kernel_width, cell=None, figure_name=None):
    """Plot ROC curves and discriminability (AUC) of one cell for each detection task.

    For every detection task (sorted) one row is drawn:

    * left: ROC curves (``true_positives`` vs ``false_positives``) for each contrast,
      using only rows with the requested ``kernel_width``;
    * right: AUC vs contrast, one line per kernel width, with a dashed line at 0.5.

    Parameters
    ----------
    data_frame : pandas.DataFrame
        Discrimination results with at least the columns ``cell``, ``df``,
        ``detection_task``, ``kernel_width``, ``contrast``, ``auc``,
        ``true_positives`` and ``false_positives``.
    df : number
        Difference frequency to plot; must be present in ``data_frame.df``.
    kernel_width : float
        Kernel width used for the ROC curves; must be present for the cell.
    cell : optional
        Cell to plot; defaults to the first cell in ``data_frame``.
    figure_name : str
        File name the figure is saved to.

    Raises
    ------
    ValueError
        If ``df`` or ``kernel_width`` is not present in the data.
    """
    if cell is None:
        cell = data_frame.cell.unique()[0]
    dfs = np.sort(data_frame.df.unique())
    if df not in dfs:
        raise ValueError("requested deltaf not present, valid choices are: " + str(dfs))
    cell_results = data_frame[(data_frame.cell == cell) & (data_frame.df == df)]
    conditions = sorted(cell_results.detection_task.unique())
    kernels = sorted(cell_results.kernel_width.unique())
    if kernel_width not in kernels:
        raise ValueError("requested kernel not present, valid choices are: " + str(kernels))
    fig = plt.figure(figsize=(6.5, 5.5))
    fig_grid = (8, 7)
    last_row = len(conditions) - 1
    for i, task in enumerate(conditions):
        condition_results = cell_results[cell_results.detection_task == task]
        roc_data = condition_results[condition_results.kernel_width == kernel_width]
        roc_contrasts = roc_data.contrast.unique()
        # each task occupies 2 grid rows plus a 1-row gap
        roc_ax = plt.subplot2grid(fig_grid, (i * 3, 0), colspan=3, rowspan=2)
        roc_ax.set_title(task, fontsize=9, loc="left")
        auc_ax = plt.subplot2grid(fig_grid, (i * 3, 4), colspan=3, rowspan=2)
        for contrast in roc_contrasts:
            selection = roc_data.contrast == contrast
            tpr = roc_data.true_positives[selection].values[0]
            fpr = roc_data.false_positives[selection].values[0]
            roc_ax.plot(fpr, tpr, label="%.3f" % contrast, zorder=2)
        if i == 0 and len(roc_contrasts) > 0:
            roc_ax.legend(loc="lower right", fontsize=6, ncol=2, frameon=False, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25)
        roc_ax.plot([0., 1.], [0., 1.], color="k", lw=0.5, ls="--", zorder=0)  # chance level
        roc_ax.set_xticks(np.arange(0.0, 1.01, 0.5))
        roc_ax.set_xticks(np.arange(0.0, 1.01, 0.25), minor=True)
        roc_ax.set_yticks(np.arange(0.0, 1.01, 0.5))
        roc_ax.set_yticks(np.arange(0.0, 1.01, 0.25), minor=True)
        if i == last_row:
            roc_ax.set_xticklabels(np.arange(0.0, 1.01, 0.5), fontsize=8)
            roc_ax.set_xlabel("false positive rate", fontsize=9)
        roc_ax.set_ylabel("true positive rate", fontsize=9)
        roc_ax.set_yticklabels(np.arange(0.0, 1.01, 0.5), fontsize=8)
        contrasts = []
        for k in kernels:
            kernel_results = condition_results[condition_results.kernel_width == k]
            contrasts = np.asarray(kernel_results.contrast)
            aucs = np.asarray(kernel_results.auc)
            order = np.argsort(contrasts)
            auc_ax.plot(contrasts[order], aucs[order], marker=".", label=r"$\sigma$: %.2f ms" % (k * 1000), zorder=1)
        if i == last_row:
            auc_ax.set_xlabel("contrast [%]", fontsize=9)
        auc_ax.set_ylim([0.25, 1.0])
        auc_ax.set_yticks(np.arange(0.25, 1.01, 0.25))
        auc_ax.set_yticklabels(np.arange(0.25, 1.01, 0.25), fontsize=8)
        auc_ax.set_ylabel("discriminability", fontsize=9)
        if i == 0:
            auc_ax.legend(ncol=2, fontsize=6, handletextpad=0.4, columnspacing=1.0, labelspacing=0.25, frameon=False, loc="lower center")
        # chance level, spanning the contrasts of the last kernel width
        auc_ax.plot([min(contrasts), max(contrasts)], [0.5, 0.5], lw=0.5, ls="--", zorder=0)
    fig.savefig(figure_name)
    plt.close(fig)  # release the figure, as the other plotting routines do
def foreign_fish_detection_example_plot(args):
    """Plot ROC/discriminability results of an example cell.

    Reads the "discrimination_results" table from the first (in sorted order)
    ``*discriminations.h5`` file in ``data_folder`` and passes it to
    ``plot_detection_results`` with ``args.deltaf`` and ``args.kernel_width``;
    the figure is written to ``args.outfile``.

    Raises
    ------
    ValueError
        If no ``*discriminations.h5`` file is found.
    """
    # sorted, so the chosen file does not depend on file-system listing order
    files = sorted(glob.glob(os.path.join(data_folder, "*discriminations.h5")))
    if len(files) == 0:
        raise ValueError("no discrimination results found!")
    # read-only, and closed even if reading fails; the table is in memory afterwards
    with pd.HDFStore(files[0], mode="r") as store:
        data_frame = store.get("discrimination_results")
    plot_detection_results(data_frame, args.deltaf, args.kernel_width, figure_name=args.outfile)
def plot_surfaces(data_frame, dfs, contrasts, tasks, selected_dfs, selected_contrasts, kernel_width, chirpsize, filename):
    """Plot discrimination performance (AUC) across deltaf and contrast for each task.

    Only rows of ``data_frame`` with the given ``kernel_width`` and ``chirpsize``
    are used. Each task in ``tasks`` gets one row of panels:

    * a 3-D surface of the mean AUC over the ``dfs`` x ``contrasts`` grid;
    * mean +- std AUC vs contrast, one line per deltaf in ``selected_dfs``;
    * mean +- std AUC vs deltaf, one line per contrast in ``selected_contrasts``.

    The figure is saved to ``filename``.
    """
    def auc_stats(task, contrast, deltaf):
        """Return (mean, std) of the AUCs for one task/contrast/deltaf combination."""
        selection = data_frame[(data_frame.kernel_width == kernel_width) & (data_frame.contrast == contrast) &
                               (data_frame.df == deltaf) & (data_frame.detection_task == task) &
                               (data_frame.chirpsize == chirpsize)]
        return np.mean(selection.auc), np.std(selection.auc)

    X, Y = np.meshgrid(dfs, contrasts)
    fig = plt.figure(figsize=(8.0, 10))
    fig_grid = (19, 10)
    for index, t in enumerate(tasks):
        row = index * 5 + index * 2  # 5 grid rows per task plus a 2-row gap
        ax = plt.subplot2grid(fig_grid, (row, 0), colspan=4, rowspan=5, projection="3d")
        ax.set_title(t, loc="left", pad=-0.5)
        # explicit float arrays: zeros_like(X) would inherit an integer dtype from
        # integer-valued deltafs and truncate the AUCs to 0 or 1
        Z = np.zeros(X.shape, dtype=float)
        for i, d in enumerate(dfs):
            for j, c in enumerate(contrasts):
                Z[j, i] = auc_stats(t, c, d)[0]
        ax.plot_surface(X, Y, Z, cmap=cm.coolwarm, linewidth=0.2, edgecolor="white", antialiased=True, alpha=0.85, vmin=0.5, vmax=1.0)
        ax.set_xlabel(r"$\Delta f [Hz]$", fontsize=8)
        ax.set_ylabel("contrast [%]", fontsize=8)
        ax.set_zlabel("performance", fontsize=8, rotation=180)
        ax.set_zlim([0.45, 1.0])
        ax.set_zticks(np.arange(0.5, 1.01, 0.25))
        ax.set_zticks(np.arange(0.5, 1.01, 0.125), minor=True)

        cntrst_ax = plt.subplot2grid(fig_grid, (row, 6), colspan=4, rowspan=2)
        for d in selected_dfs:
            # fresh arrays per line so earlier errorbar lines are never overwritten in place
            performances = np.zeros(len(contrasts), dtype=float)
            errors = np.zeros(len(contrasts), dtype=float)
            for i, c in enumerate(contrasts):
                performances[i], errors[i] = auc_stats(t, c, d)
            cntrst_ax.errorbar(contrasts, performances, yerr=errors, fmt=".-", label=r"$\Delta f:$ %i Hz" % d)
        cntrst_ax.set_ylim([0.25, 1.0])
        cntrst_ax.set_ylabel("performance", fontsize=8)
        cntrst_ax.set_xlabel("contrast [%]", fontsize=8)
        cntrst_ax.hlines(0.5, contrasts[0], contrasts[-1], color="k", ls="--", lw=0.2)
        cntrst_ax.legend(fontsize=7, ncol=4, frameon=False, loc="lower center", mode="expand", handlelength=1.0, handletextpad=0.25)

        df_ax = plt.subplot2grid(fig_grid, (row + 3, 6), colspan=4, rowspan=2)
        for c in selected_contrasts:
            performances = np.zeros(len(dfs), dtype=float)
            errors = np.zeros(len(dfs), dtype=float)
            for i, d in enumerate(dfs):
                performances[i], errors[i] = auc_stats(t, c, d)
            df_ax.errorbar(dfs, performances, yerr=errors, fmt=".-", label="%.2f" % c)
        df_ax.set_ylim([0.25, 1.0])
        df_ax.set_ylabel("performance", fontsize=8)
        df_ax.set_xlabel(r"$\Delta f$ [Hz]", fontsize=8)
        df_ax.hlines(0.5, dfs[0], dfs[-1], color="k", ls="--", lw=0.2)
        df_ax.legend(fontsize=7, ncol=4, frameon=False, loc="lower center", mode="expand", handlelength=1.0, handletextpad=0.25)
    fig.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.975)
    fig.savefig(filename)
    plt.close(fig)
def performance_plot(args):
    """Plot discrimination performance across all cells from a results CSV file.

    Reads ``args.inputfile`` (';'-separated). The first three detection tasks are
    plotted to ``<root>_1<ext>`` and the remaining tasks to ``<root>_2<ext>``, where
    ``<root>``/``<ext>`` come from ``args.outfile``. If ``args.kernel_width`` is not
    in the data, the first available kernel width is used instead.

    Raises
    ------
    ValueError
        If the input file does not exist or ``args.chirpsize`` is not in the data.
    """
    if not os.path.exists(args.inputfile):
        raise ValueError("Error plotting discrimination performance. Input file (%s) not found!" % args.inputfile)
    df = pd.read_csv(args.inputfile, sep=";")
    all_dfs = np.sort(df.df.unique())
    all_contrasts = np.sort(df.contrast.unique())
    tasks = df.detection_task.unique()
    kernel_widths = list(df.kernel_width.unique())
    # silently fall back to the first kernel width present in the data
    kernel_width = args.kernel_width if args.kernel_width in kernel_widths else kernel_widths[0]
    chirpsizes = list(df.chirpsize.unique())
    if args.chirpsize not in chirpsizes:
        raise ValueError("Error plotting discrimination performance. Requested chirpsize (%i Hz) is not found in the data. Available chirpsizes are: " % args.chirpsize + str(chirpsizes))
    # splitext only splits off the final extension, so dots in directory or
    # file names (e.g. "./figures/x.pdf") are preserved
    root, ext = os.path.splitext(args.outfile)
    for page, page_tasks in ((1, tasks[:3]), (2, tasks[3:])):
        filename = "%s_%i%s" % (root, page, ext)
        plot_surfaces(df, all_dfs, all_contrasts, page_tasks, args.deltafs, args.contrasts,
                      kernel_width, args.chirpsize, filename)
def main():
    """Parse the command line and run the selected plotting sub command.

    Sub commands: ``comparisons``, ``roc``, ``discrimination`` and ``responses``.
    If no sub command is given, the help text is printed and the program exits
    with status 1 (instead of failing with an AttributeError on ``args.func``).
    """
    parser = argparse.ArgumentParser(description="Plotting tool for chirp probing project.")
    subparsers = parser.add_subparsers(title="commands",
                                       help="Sub commands for plotting different figures",
                                       description="", dest="explore_cmd")

    comp_parser = subparsers.add_parser("comparisons", help="Create a didactic plot illustrating the comparisons")
    comp_parser.add_argument("-df", "--deltaf", type=int, default=20, help="The difference frequency used for plotting. Defaults to 20 Hz")
    comp_parser.add_argument("-cs", "--chirpsize", type=int, default=60, help="The chirpsize. Defaults to 60 Hz.")
    comp_parser.add_argument("-o", "--outfile", default=os.path.join(figure_folder, "comparisons.pdf"), help="filename of the plot")
    comp_parser.set_defaults(func=plot_comparisons)

    roc_parser = subparsers.add_parser("roc", help="plot roc analysis of example cell")
    roc_parser.add_argument("-o", "--outfile", default=os.path.join(figure_folder, "roc_analysis.pdf"), help="filename of the plot")
    roc_parser.add_argument("-d", "--deltaf", type=int, default=20, help="deltaf for individual plot")
    roc_parser.add_argument("-k", "--kernel_width", type=float, default=0.001, help="Kernel width to choose for plotting, defaults to 0.001s")
    roc_parser.set_defaults(func=foreign_fish_detection_example_plot)

    perf_parser = subparsers.add_parser("discrimination", help="plot discrimination performance across all cells")
    perf_parser.add_argument("-o", "--outfile", default=os.path.join(figure_folder, "discrimination_performances.pdf"), help="filename of the plot")
    infile = os.path.join(data_folder, "discrimination_results.csv")
    perf_parser.add_argument("-i", "--inputfile", default=infile, help="Filename of the file containing the discrimination results. Defaults to %s" % infile)
    perf_parser.add_argument("-k", "--kernel_width", type=float, default=0.001, help="Kernel width to choose for plotting")
    perf_parser.add_argument("-d", "--deltafs", type=float, nargs="+", default=[-200, 5, 200], help="deltaf for individual plot")
    perf_parser.add_argument("-c", "--contrasts", type=float, nargs="+", default=[5, 10, 20], help="stimulus contrast for individual plot")
    perf_parser.add_argument("-cs", "--chirpsize", type=int, default=60, help="The chirpsize. Defaults to 60Hz.")
    perf_parser.set_defaults(func=performance_plot)

    resps_parser = subparsers.add_parser("responses", help="plot responses from an example cell")
    resps_parser.add_argument("-o", "--outfile", default=os.path.join(figure_folder, "response_example.pdf"), help="filename of the plot")
    resps_parser.add_argument("-d", "--deltaf", type=int, default=20, help="deltaf for individual plot")
    dflt_cell = os.path.join(data_folder, "cell_2010-11-08-al")
    resps_parser.add_argument("-c", "--cell", type=str, default=dflt_cell, help="cell name, defaults to %s" % dflt_cell)
    resps_parser.set_defaults(func=response_examples)

    args = parser.parse_args()
    # without a sub command no "func" default was set
    if not hasattr(args, "func"):
        parser.print_help()
        parser.exit(1)
    args.func(args)


if __name__ == "__main__":
    main()