P-unit_model/Figures_results.py
alexanderott 9d68799c63 stuff
2021-06-07 09:26:21 +02:00

865 lines
35 KiB
Python

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib as mpl
from scipy.stats import pearsonr
import os
from analysis import get_filtered_fit_info, get_behaviour_values, get_parameter_values, behaviour_correlations, parameter_correlations
from fitting.ModelFit import get_best_fit
from experiments.Baseline import BaselineModel, BaselineCellData
from experiments.FiCurve import FICurveModel, FICurveCellData
from parser.CellData import CellData
from my_util import functions as fu
from my_util import save_load
import Figure_constants as consts
# Figure titles (LaTeX) for the model parameters, keyed by parameter name.
parameter_titles = {
    "input_scaling": r"$\alpha$",
    "delta_a": r"$\Delta_A$",
    "mem_tau": r"$\tau_m$",
    "noise_strength": r"$\sqrt{2D}$",
    "refractory_period": r"$t_{ref}$",
    "tau_a": r"$\tau_A$",
    "v_offset": r"$I_{Bias}$",
    "dend_tau": r"$\tau_{dend}$",
}

# Axis labels for the model parameters, keyed by parameter name.
parameter_xlabels = {
    "input_scaling": "cm",
    "delta_a": r"$\Delta_A$",
    "mem_tau": r"$\tau_m$",
    "noise_strength": r"$\sqrt{2D}$",
    "refractory_period": r"$t_{ref}$",
    "tau_a": r"$\tau_A$",
    "v_offset": r"$I_{Bias}$",
    "dend_tau": r"$\tau_{dend}$",
}

# Short display names for the behaviour measures, keyed by behaviour name.
behaviour_titles = {
    "baseline_frequency": "Base Rate",
    "Burstiness": "Burst",
    "coefficient_of_variation": "CV",
    "serial_correlation": "SC",
    "vector_strength": "VS",
    "f_inf_slope": r"$f_{\infty}$ Slope",
    "f_zero_slope": r"$f_0$ Slope",
    "f_zero_middle": r"$f_0$ middle",
    "eodf": "EODf",
}
def main():
    """Create the currently selected result figures for the fits in ``results/final_2/``.

    Loads the filtered fit information, prints how many cells are left and
    then draws the adaption comparison, the behaviour correlation and the
    parameter correlation figures. The commented-out calls are alternative
    figures that can be re-enabled by hand.
    """
    # run_all_images()
    # quit()
    dir_path = "results/final_2/"

    # dend_tau_and_ref_effect()
    # quit()

    fits_info = get_filtered_fit_info(dir_path, filter=True)
    # visualize_tested_correlations(fits_info)
    # An unconditional quit() used to follow the line above; it made every
    # statement below unreachable, so the script loaded the fits and exited
    # without producing any figure.

    print("Cells left:", len(fits_info))
    cell_behaviour, model_behaviour = get_behaviour_values(fits_info)

    # plot_cell_model_comp_baseline(cell_behaviour, model_behaviour)
    # plot_cell_model_comp_burstiness(cell_behaviour, model_behaviour)
    plot_cell_model_comp_adaption(cell_behaviour, model_behaviour)

    behaviour_correlations_plot(fits_info)
    parameter_correlation_plot(fits_info)
    #
    # create_parameter_distributions(get_parameter_values(fits_info))
    # create_parameter_distributions(get_parameter_values(fits_info, scaled=True, goal_eodf=800), "scaled_to_800_")
    # errors = calculate_percent_errors(fits_info)
    # create_boxplots(errors)
    # example_bad_hist_fits(dir_path)
    # example_good_fi_fits(dir_path)
    # example_bad_fi_fits(dir_path)
def run_all_images(dir_path, filter=True, pre_analysis_path="", recalculate=False):
    """Create all population-level result figures for the fits in ``dir_path``.

    :param dir_path: directory that is handed to ``get_filtered_fit_info``.
    :param filter: forwarded to ``get_filtered_fit_info`` as its ``filter`` argument.
    :param pre_analysis_path: if not empty, the fit information and the
        behaviour values are cached as ``.npy`` files in this directory and
        loaded from there on later calls.
    :param recalculate: if True, existing cache files are ignored and overwritten.
    """
    if pre_analysis_path != "":
        fit_info_path = os.path.join(pre_analysis_path, "figures_res_fit_info.npy")
        behaviours_path = os.path.join(pre_analysis_path, "figures_res_behaviour.npy")

        fits_recomputed = recalculate or not os.path.exists(fit_info_path)
        if fits_recomputed:
            fits_info = get_filtered_fit_info(dir_path, filter=filter)
            save_load.save(fits_info, fit_info_path)
        else:
            fits_info = save_load.load(fit_info_path)

        # Behaviour values are derived from fits_info, so a cached copy is
        # stale whenever the fit information itself had to be recomputed.
        if fits_recomputed or not os.path.exists(behaviours_path):
            cell_behaviour, model_behaviour = get_behaviour_values(fits_info)
            save_load.save([cell_behaviour, model_behaviour], behaviours_path)
        else:
            cell_behaviour, model_behaviour = save_load.load(behaviours_path)
    else:
        # Honour the caller's choice; this branch used to hard-code filter=True.
        fits_info = get_filtered_fit_info(dir_path, filter=filter)
        cell_behaviour, model_behaviour = get_behaviour_values(fits_info)

    plot_cell_model_comp_baseline(cell_behaviour, model_behaviour)
    plot_cell_model_comp_adaption(cell_behaviour, model_behaviour)
    plot_cell_model_comp_burstiness(cell_behaviour, model_behaviour)

    behaviour_correlations_plot(fits_info)
    parameter_correlation_plot(fits_info)

    create_parameter_distributions(get_parameter_values(fits_info))
    create_parameter_distributions(get_parameter_values(fits_info, scaled=True, goal_eodf=800), "scaled_to_800_")

    # Plots using example cells:
    # dend_tau_and_ref_effect()
    # example_good_hist_fits(dir_path)
    # example_bad_hist_fits(dir_path)
    # example_good_fi_fits(dir_path)
    # example_bad_fi_fits(dir_path)
def visualize_tested_correlations(fits_info):
    """Plot how often each behaviour correlation stays significant when cells are left out.

    For every number of left-out cells from 1 to 10 one figure is saved to
    ``figures/consistency_correlations_removed_<n>.pdf``. The left panel
    prints the fraction of cell subsets in which a correlation was
    significant, the right panel prints the absolute count; both panels are
    shaded by the fraction.

    :param fits_info: fit information as expected by ``test_correlations``.
    """
    for leave_out in range(1, 11):
        significance_count, total_count, labels = test_correlations(fits_info, leave_out, model_values=False)
        if total_count == 0:
            # No subset could be formed (fewer cells than leave_out): nothing to plot.
            print("No cell subsets for leave_out = {}, figure skipped.".format(leave_out))
            continue

        percentages = significance_count / total_count
        tick_labels = [behaviour_titles[l] for l in labels]

        fig = plt.figure(tight_layout=True, figsize=consts.FIG_SIZE_MEDIUM_WIDE)
        gs = gridspec.GridSpec(2, 2, width_ratios=(1, 1), height_ratios=(5, 0.5), hspace=0.5, wspace=0.4, left=0.2)

        _draw_annotated_heatmap(fig.add_subplot(gs[0, 0]), percentages, percentages, "{:.2f}",
                                tick_labels, "Percent: removed {}".format(leave_out))
        _draw_annotated_heatmap(fig.add_subplot(gs[0, 1]), percentages, significance_count, "{:.0f}",
                                tick_labels, "Counts - removed {}".format(leave_out))

        # Hand-made colour strip from 0 to 1 below both panels.
        ax_col = fig.add_subplot(gs[1, :])
        data = [np.arange(0, 1.001, 0.01)] * 10
        ax_col.set_xticks([0, 25, 50, 75, 100])
        ax_col.set_xticklabels([0, 0.25, 0.5, 0.75, 1])
        ax_col.set_yticks([])
        ax_col.imshow(data)
        # NOTE(review): the strip encodes the fraction of significant tests,
        # not correlation coefficients - confirm whether this label is intended.
        ax_col.set_xlabel("Correlation Coefficients")

        plt.tight_layout()
        plt.savefig("figures/consistency_correlations_removed_{}.pdf".format(leave_out))
        # Close the figure; otherwise every loop iteration leaves one open.
        plt.close(fig)


def _draw_annotated_heatmap(ax, shading, annotations, number_format, tick_labels, title):
    """Show ``shading`` as a heat map on ``ax`` and print ``annotations`` into its cells.

    :param ax: axes to draw into.
    :param shading: square 2D array of values in [0, 1] that defines the cell colours.
    :param annotations: 2D array with the numbers that are printed into the cells.
    :param number_format: format string used for the printed numbers.
    :param tick_labels: labels for both axes.
    :param title: title of the panel.
    """
    # Fix the colour range to 0..1 so that the panel matches the colour strip.
    ax.imshow(shading, vmin=0, vmax=1)

    positions = np.arange(len(tick_labels))
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels)
    ax.set_yticks(positions)
    ax.set_yticklabels(tick_labels)

    # remove frame:
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    ax.set_title(title)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Dark text on bright cells, bright text on dark cells.
    for i in range(len(tick_labels)):
        for j in range(len(tick_labels)):
            text_color = "black" if shading[i, j] > 0.5 else "white"
            ax.text(j, i, number_format.format(annotations[i, j]), ha="center", va="center",
                    color=text_color, size=6)
def test_correlations(fits_info, left_out, model_values=False):
    """Count how often behaviour correlations are significant when cells are left out.

    For every mask produced by ``iall_masks`` the masked cells are removed and
    the Pearson correlation of each pair of behaviours (lower triangle
    including the diagonal) is tested. A pair counts as significant if its
    p-value multiplied by the correction factor is below 0.05.

    :param fits_info: fit information handed to ``get_behaviour_values``.
    :param left_out: number of cells that are removed per subset.
    :param model_values: if True the model behaviour is used, otherwise the cell behaviour.
    :return: tuple ``(significance_counts, total_count, labels)``: the matrix of
        significant results per behaviour pair, the number of tested subsets
        and the behaviour names that index the matrix.
    """
    bv_cell, bv_model = get_behaviour_values(fits_info)
    # eod_frequencies = [fits_info[cell][3] for cell in sorted(fits_info.keys())]
    behaviour_values = bv_model if model_values else bv_cell

    labels = ["baseline_frequency", "serial_correlation", "vector_strength", "coefficient_of_variation",
              "Burstiness", "f_inf_slope", "f_zero_slope"]  # , "eodf"]

    # Convert every behaviour once instead of once per subset and pair.
    columns = [np.array(behaviour_values[label]) for label in labels]
    cell_count = len(behaviour_values["f_inf_slope"])

    significance_counts = np.zeros((len(labels), len(labels)))
    # Number of distinct behaviour pairs, used to scale the p-values.
    correction_factor = sum(range(len(labels)))
    total_count = 0

    for mask in iall_masks(cell_count, left_out):
        total_count += 1

        # Boolean selection of the cells that stay in. The builtin bool is used
        # because the np.bool alias no longer exists in current NumPy releases.
        keep = np.ones(cell_count, dtype=bool)
        keep[np.asarray(mask, dtype=int)] = False

        for i in range(len(labels)):
            for j in range(i + 1):
                _, p = pearsonr(columns[i][keep], columns[j][keep])
                if p * correction_factor < 0.05:
                    significance_counts[i, j] += 1

    return significance_counts, total_count, labels
def iall_masks(values_count: int, left_out: int):
    """Yield every way to choose ``left_out`` indices out of ``range(values_count)``.

    The masks are produced in lexicographic order. Each mask is a new integer
    array, so collected masks stay valid after the generator has advanced.
    If ``left_out`` is larger than ``values_count`` nothing is yielded.

    :param values_count: number of available indices.
    :param left_out: number of indices per mask.
    """
    from itertools import combinations

    for chosen in combinations(range(values_count), left_out):
        yield np.array(chosen, dtype=int)
def dend_tau_and_ref_effect():
    """Compare ISI histograms of three example cells for three model variants.

    Each row shows one cell, each column the best fit found in one result
    folder. The figure is saved as ``dend_ref_effect.pdf`` in ``consts.SAVE_FOLDER``.
    """
    cells = ["2012-12-21-am-invivo-1", "2014-03-19-ad-invivo-1", "2014-03-25-aa-invivo-1"]
    cell_type = ["no burster", "burster", "strong burster"]
    folders = ["results/ref_and_tau/no_dend_tau/", "results/ref_and_tau/no_ref_period/", "results/final_2/"]
    title = [r"without $\tau_{dend}$", r"without $t_{ref}$", "with both"]

    fig, axes = plt.subplots(len(cells), 3, figsize=consts.FIG_SIZE_LARGE, sharey="row", sharex="all")

    for row, cell in enumerate(cells):
        cell_data = CellData("data/final/" + cell)
        cell_baseline = BaselineCellData(cell_data)
        cell_baseline.load_values(cell_data.get_data_path())
        eodf = cell_data.get_eod_frequency()

        print(cell)
        print("EODf:", eodf)
        print("base rate:", cell_baseline.get_baseline_frequency())
        print("bursty:", cell_baseline.get_burstiness())
        print()

        for col, folder in enumerate(folders):
            panel = axes[row, col]
            fit = get_best_fit(folder + cell)
            model_baseline = BaselineModel(fit.get_model(), eodf)

            # Express the intervals and the bin edges in EOD periods.
            cell_isis = cell_baseline.get_interspike_intervals() * eodf
            model_isis = model_baseline.get_interspike_intervals() * eodf
            bins = np.arange(0, 0.025, 0.0001) * eodf

            # Only the top right panel carries the legend.
            with_legend = row == 0 and col == 2
            panel.hist(model_isis, density=True, bins=bins, color=consts.COLOR_MODEL, alpha=0.75,
                       label="model" if with_legend else None)
            panel.hist(cell_isis, density=True, bins=bins, color=consts.COLOR_DATA, alpha=0.5,
                       label="data" if with_legend else None)
            if with_legend:
                panel.legend(loc="upper right", frameon=False)

            if col == 0:
                panel.set_ylabel(cell_type[row])
            panel.set_yticklabels([])

            if row == 0:
                axes[0, col].set_title(title[col])

    plt.xlim(0, 17.5)

    fig.text(0.5, 0.04, 'Time [EOD periods]', ha='center', va='center')  # shared x label
    fig.text(0.06, 0.5, 'ISI Density', ha='center', va='center', rotation='vertical')  # shared y label

    # Panel letters above the three columns.
    for x_position, letter in ((0.11, 'A'), (0.3825, 'B'), (0.655, 'C')):
        fig.text(x_position, 0.9, letter, ha='center', va='center', rotation='horizontal', size=16, family='serif')

    # fig.text(0.11, 0.86, '1', ha='center', va='center', rotation='horizontal', size=16, family='serif')
    # fig.text(0.11, 0.59, '2', ha='center', va='center', rotation='horizontal', size=16, family='serif')
    # fig.text(0.11, 0.32, '3', ha='center', va='center', rotation='horizontal', size=16, family='serif')

    plt.savefig(consts.SAVE_FOLDER + "dend_ref_effect.pdf", transparent=True)
    plt.close()
def create_parameter_distributions(par_values, prefix=""):
    """Plot one histogram per model parameter and save the figure as a pdf.

    The figure is written to ``consts.SAVE_FOLDER`` as
    ``<prefix>parameter_distributions.pdf``. ``par_values`` is not modified.

    :param par_values: dict that maps each parameter name to its values over all fits.
    :param prefix: text that is put in front of the file name.
    """
    fig, axes = plt.subplots(4, 2, gridspec_kw={"left": 0.1, "hspace": 0.5}, figsize=consts.FIG_SIZE_LARGE_HIGH)

    if len(par_values) != 8:
        print("not eight parameters")

    labels = ["input_scaling", "v_offset", "mem_tau", "noise_strength",
              "tau_a", "delta_a", "dend_tau", "refractory_period"]
    x_labels = ["[cm]", "[mV]", "[ms]", r"[mV$\sqrt{s}$]", "[ms]", "[mVms]", "[ms]", "[ms]"]

    for ax, label, unit in zip(axes.flatten(), labels, x_labels):
        # Work on a local copy so the caller's dictionary keeps its values.
        values = np.array(par_values[label])
        bins = calculate_bins(values, 20)

        # Every unit that contains "ms" is scaled by 1000; the substring test
        # also matches "[mVms]".
        # NOTE(review): presumably these values are stored in seconds - confirm.
        if "ms" in unit:
            bins = bins * 1000
            values = values * 1000

        ax.hist(values, bins=bins, color=consts.COLOR_MODEL, alpha=0.75)
        # ax.set_title(parameter_titles[label])
        ax.set_xlabel(parameter_titles[label] + " " + unit)

    fig.text(0.03, 0.5, 'Count', ha='center', va='center', rotation='vertical', size=12)  # shared y label

    plt.tight_layout()
    consts.set_figure_labels(xoffset=-2.5, yoffset=1.5)
    # fig.label_axes()

    plt.savefig(consts.SAVE_FOLDER + prefix + "parameter_distributions.pdf")
    plt.close()
def behaviour_correlations_plot(fits_info):
    """Plot the behaviour correlations of the data next to those of the models.

    The figure consists of one correlation matrix per source and a shared
    colour strip below them. It is saved as ``behaviour_correlations.pdf`` in
    ``consts.SAVE_FOLDER``.

    :param fits_info: fit information handed to ``behaviour_correlations``.
    """
    figure = plt.figure(tight_layout=True, figsize=consts.FIG_SIZE_MEDIUM_WIDE)
    grid = gridspec.GridSpec(2, 2, width_ratios=(1, 1), height_ratios=(5, 0.5), hspace=0.5, wspace=0.15, left=0.2)
    # fig, axes = plt.subplots(1, 2, figsize=consts.FIG_SIZE_MEDIUM_WIDE)

    # Left panel: data with y labels, right panel: model without y labels.
    for column, (use_model, panel_title) in enumerate(((False, "Data"), (True, "Model"))):
        keys, corr_values, corrected_p_values = behaviour_correlations(fits_info, model_values=use_model)
        names = [behaviour_titles[key] for key in keys]
        panel = figure.add_subplot(grid[0, column])
        create_correlation_plot(panel, names, corr_values, corrected_p_values, panel_title,
                                y_label=(column == 0))

    # cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    # Hand-made colour strip from -1 to 1 that spans the whole bottom row.
    strip = figure.add_subplot(grid[1, :])
    gradient = [np.arange(-1, 1.001, 0.01)] * 10
    strip.set_xticks([0, 25, 50, 75, 100, 125, 150, 175, 200])
    strip.set_xticklabels([-1, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1])
    strip.set_yticks([])
    strip.imshow(gradient)
    strip.set_xlabel("Correlation Coefficients")

    plt.tight_layout()
    plt.savefig(consts.SAVE_FOLDER + "behaviour_correlations.pdf")
    plt.close()
def parameter_correlation_plot(fits_info):
    """Plot the correlation matrix of the model parameters with a colour bar.

    The figure is saved as ``parameter_correlations.pdf`` in ``consts.SAVE_FOLDER``.

    :param fits_info: fit information handed to ``parameter_correlations``.
    """
    names, coefficients, p_corrected = parameter_correlations(fits_info)
    titles = [parameter_titles[name] for name in names]

    figure, axis = plt.subplots(1, 1, figsize=consts.FIG_SIZE_MEDIUM)
    # ax, labels, correlations, p_values, title, y_label=True
    image = create_correlation_plot(axis, titles, coefficients, p_corrected, "")
    figure.colorbar(image, ax=axis)

    plt.savefig(consts.SAVE_FOLDER + "parameter_correlations.pdf")
    plt.close()
def create_correlation_plot(ax, labels, correlations, p_values, title, y_label=True):
    """Draw the significant lower-triangle correlations as an annotated heat map.

    Entries whose absolute p-value is not below 0.05 and all entries above
    the diagonal are left blank.

    :param ax: axes to draw into.
    :param labels: tick labels, one per row/column of the matrix.
    :param correlations: 2D array of correlation coefficients.
    :param p_values: 2D array of p-values with the same shape as ``correlations``.
    :param title: title of the axes.
    :param y_label: if False the y tick labels are hidden.
    :return: the image created by ``ax.imshow``.
    """
    correlations = np.asarray(correlations, dtype=float)
    significant = np.abs(np.asarray(p_values, dtype=float)) < 0.05

    # np.nan is used because the np.NAN alias was removed in NumPy 2.0.
    cleaned_cors = np.where(significant, correlations, np.nan)
    # Blank everything above the diagonal; the matrix is symmetric.
    cleaned_cors[np.triu_indices_from(cleaned_cors, k=1)] = np.nan

    im = ax.imshow(cleaned_cors, vmin=-1, vmax=1)

    # We want to show all ticks...
    ax.set_xticks(np.arange(len(labels)))
    ax.set_xticklabels(labels)

    # remove frame:
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # ... and label them with the respective list entries
    if y_label:
        ax.set_yticks(np.arange(len(labels)))
        ax.set_yticklabels(labels)
    else:
        ax.set_yticklabels([])

    ax.set_title(title)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Print the coefficient into every cell that is shown: black text for
    # positive values, white text otherwise.
    for i in range(len(labels)):
        for j in range(len(labels)):
            value = cleaned_cors[i, j]
            if np.isnan(value):
                continue
            text_color = "black" if value > 0 else "white"
            ax.text(j, i, "{:.2f}".format(value), ha="center", va="center", color=text_color, size=6)

    return im
def example_good_hist_fits(dir_path):
    """Plot ISI histograms of three example cells together with their best fits.

    The figure is saved as ``example_good_isi_hist_fits.pdf`` in ``consts.SAVE_FOLDER``.

    :param dir_path: directory that contains one fit folder per cell.
    """
    example_cells = (
        "2012-12-21-am-invivo-1",  # non bursty cell
        "2014-03-19-ad-invivo-1",  # bursty cell
        "2018-05-08-ac-invivo-1",  # strong bursty cell
    )

    fig, axes = plt.subplots(1, 3, sharex="all", figsize=(8, 4))

    for panel, cell_name in zip(axes, example_cells):
        best_fit = get_best_fit(dir_path + cell_name + "/")
        data = best_fit.get_cell_data()
        eod_frequency = data.get_eod_frequency()

        simulated_baseline = BaselineModel(best_fit.get_model(), eod_frequency, trials=5)

        # Intervals and bin edges are expressed in EOD periods.
        simulated_isi = np.array(simulated_baseline.get_interspike_intervals()) * eod_frequency
        recorded_isi = BaselineCellData(data).get_interspike_intervals() * eod_frequency
        edges = np.arange(0, 0.025, 0.0001) * eod_frequency

        panel.hist(simulated_isi, bins=edges, density=True, alpha=0.75, color=consts.COLOR_MODEL)
        panel.hist(recorded_isi, bins=edges, density=True, alpha=0.5, color=consts.COLOR_DATA)
        panel.set_xlabel("ISI in EOD periods")

    axes[0].set_ylabel("Density")

    plt.tight_layout()
    consts.set_figure_labels(xoffset=-2.5)
    fig.label_axes()

    plt.savefig(consts.SAVE_FOLDER + "example_good_isi_hist_fits.pdf", transparent=True)
    plt.close()
def example_bad_hist_fits(dir_path):
    """Plot ISI histograms of three example cells together with their best fits.

    Some properties of each cell are printed as well. The figure is saved as
    ``example_bad_isi_hist_fits.pdf`` in ``consts.SAVE_FOLDER``.

    :param dir_path: directory that contains one fit folder per cell.
    """
    example_cells = (
        "2014-06-06-ag-invivo-1",  # bursty cell
        "2018-05-08-ab-invivo-1",  # strong bursty cell
        "2014-12-11-ad-invivo-1",  # extra structure cell
    )

    fig, axes = plt.subplots(1, 3, sharex="all", figsize=consts.FIG_SIZE_SMALL_EXTRA_WIDE)  # , gridspec_kw={"top": 0.95})

    for index, cell_name in enumerate(example_cells):
        panel = axes[index]
        best_fit = get_best_fit(dir_path + cell_name + "/")
        data = best_fit.get_cell_data()
        eod_frequency = data.get_eod_frequency()

        simulated_baseline = BaselineModel(best_fit.get_model(), eod_frequency, trials=5)
        recorded_baseline = BaselineCellData(data)

        print(cell_name)
        print("EODf:", eod_frequency)
        print("base rate:", recorded_baseline.get_baseline_frequency())
        print("bursty:", recorded_baseline.get_burstiness())
        print()

        # Intervals and bin edges are expressed in EOD periods.
        simulated_isi = np.array(simulated_baseline.get_interspike_intervals()) * eod_frequency
        recorded_isi = recorded_baseline.get_interspike_intervals() * eod_frequency
        edges = np.arange(0, 0.025, 0.0001) * eod_frequency

        # Only the first panel gets labelled histograms and a legend.
        first_panel = index == 0
        panel.hist(simulated_isi, bins=edges, density=True, alpha=0.75, color=consts.COLOR_MODEL,
                   label="model" if first_panel else None)
        panel.hist(recorded_isi, bins=edges, density=True, alpha=0.5, color=consts.COLOR_DATA,
                   label="data" if first_panel else None)
        if first_panel:
            panel.legend(loc="upper right", frameon=False)

        panel.set_xlabel("ISI [EOD periods]")

    axes[0].set_ylabel("Density")

    plt.tight_layout()
    consts.set_figure_labels(xoffset=-2.5, yoffset=1.25)
    fig.label_axes()

    plt.savefig(consts.SAVE_FOLDER + "example_bad_isi_hist_fits.pdf", transparent=True)
    plt.close()
def example_good_fi_fits(dir_path):
    """Plot the f-I curves of three example cells together with those of their best fits.

    For every cell the measured and the simulated onset (f_0) and steady
    state (f_inf) frequencies are drawn as markers, and the functions fitted
    to the measured data are drawn as lines. Some properties of each cell are
    printed as well. The figure is saved as ``example_good_fi_fits.pdf`` in
    ``consts.SAVE_FOLDER``.

    :param dir_path: directory that contains one fit folder per cell.
    """
    fig, axes = plt.subplots(1, 3, figsize=consts.FIG_SIZE_SMALL_EXTRA_WIDE, sharey="all")

    for i, cell in enumerate(["2012-12-21-am-invivo-1", "2014-03-19-ae-invivo-1", "2014-03-25-aa-invivo-1"]):
        fit_dir = dir_path + cell + "/"
        fit = get_best_fit(fit_dir)

        cell_data = fit.get_cell_data()
        eodf = cell_data.get_eod_frequency()
        cell_baseline = BaselineCellData(cell_data)

        print(cell)
        print("EODf:", eodf)
        print("base rate:", cell_baseline.get_baseline_frequency())
        print("bursty:", cell_baseline.get_burstiness())
        print()

        model = fit.get_model()

        contrasts = cell_data.get_fi_contrasts()
        fi_curve_data = FICurveCellData(cell_data, contrasts, save_dir=cell_data.get_data_path())
        contrasts = fi_curve_data.stimulus_values

        x_values = np.arange(min(contrasts), max(contrasts), 0.001)

        fi_curve_model = FICurveModel(model, contrasts, eodf, trials=10)

        f_zero_fit = fi_curve_data.f_zero_fit
        f_inf_fit = fi_curve_data.f_inf_fit

        # The measured points are drawn as markers without a connecting line.
        # The marker is given only by keyword: combining it with a ',' format
        # string defines it twice and makes Matplotlib emit a warning.
        ax = axes[i]

        # f zero response
        ax.plot(contrasts, fi_curve_data.get_f_zero_frequencies(), linestyle="None",
                marker=consts.f0_marker, alpha=0.75, color=consts.COLOR_DATA_f0, label=r"data $f_0$")
        ax.plot(x_values, fu.full_boltzmann(x_values, f_zero_fit[0], f_zero_fit[1], f_zero_fit[2], f_zero_fit[3]),
                color=consts.COLOR_DATA_f0, alpha=0.75)
        ax.plot(contrasts, fi_curve_model.get_f_zero_frequencies(), linestyle="None",
                marker=consts.f0_marker, alpha=0.75, color=consts.COLOR_MODEL_f0, label=r"model $f_0$")

        # f inf response
        ax.plot(contrasts, fi_curve_data.get_f_inf_frequencies(), linestyle="None",
                marker=consts.finf_marker, alpha=0.5, color=consts.COLOR_DATA_finf, label=r"data $f_{\infty}$")
        ax.plot(x_values, fu.clipped_line(x_values, f_inf_fit[0], f_inf_fit[1]),
                color=consts.COLOR_DATA_finf, alpha=0.5)
        ax.plot(contrasts, fi_curve_model.get_f_inf_frequencies(), linestyle="None",
                marker=consts.finf_marker, alpha=0.75, color=consts.COLOR_MODEL_finf, label=r"model $f_{\infty}$")

        ax.set_xlabel("Contrast")
        ax.set_xlim((-0.22, 0.22))

    axes[0].legend(loc="upper left", frameon=False)
    axes[0].set_ylabel("Frequency [Hz]")

    plt.tight_layout()
    consts.set_figure_labels(xoffset=-2.5)
    fig.label_axes()

    plt.savefig(consts.SAVE_FOLDER + "example_good_fi_fits.pdf", transparent=True)
    plt.close()
def example_bad_fi_fits(dir_path):
    """Plot the f-I curves of two example cells together with those of their best fits.

    For every cell the measured and the simulated onset (f_0) and steady
    state (f_inf) frequencies are drawn as markers, and the functions fitted
    to the measured data are drawn as lines. Some properties of each cell are
    printed as well. The figure is saved as ``example_bad_fi_fits.pdf`` in
    ``consts.SAVE_FOLDER``.

    :param dir_path: directory that contains one fit folder per cell.
    """
    fig, axes = plt.subplots(1, 2, figsize=consts.FIG_SIZE_SMALL_EXTRA_WIDE)

    # "2013-01-08-aa-invivo-1" candidate cell
    for i, cell in enumerate(["2012-12-13-ao-invivo-1", "2014-01-23-ab-invivo-1"]):
        fit_dir = dir_path + cell + "/"
        fit = get_best_fit(fit_dir)

        cell_data = fit.get_cell_data()
        eodf = cell_data.get_eod_frequency()
        cell_baseline = BaselineCellData(cell_data)

        print(cell)
        print("EODf:", eodf)
        print("base rate:", cell_baseline.get_baseline_frequency())
        print("bursty:", cell_baseline.get_burstiness())
        print()

        model = fit.get_model()

        contrasts = cell_data.get_fi_contrasts()
        fi_curve_data = FICurveCellData(cell_data, contrasts, save_dir=cell_data.get_data_path())
        contrasts = fi_curve_data.stimulus_values

        x_values = np.arange(min(contrasts), max(contrasts), 0.001)

        fi_curve_model = FICurveModel(model, contrasts, eodf, trials=10)

        f_zero_fit = fi_curve_data.f_zero_fit
        f_inf_fit = fi_curve_data.f_inf_fit

        # The measured points are drawn as markers without a connecting line.
        # The marker is given only by keyword: combining it with a ',' format
        # string defines it twice and makes Matplotlib emit a warning.
        ax = axes[i]

        # f zero response
        ax.plot(contrasts, fi_curve_data.get_f_zero_frequencies(), linestyle="None",
                marker=consts.f0_marker, alpha=0.75, color=consts.COLOR_DATA_f0, label=r"data $f_0$")
        ax.plot(x_values, fu.full_boltzmann(x_values, f_zero_fit[0], f_zero_fit[1], f_zero_fit[2], f_zero_fit[3]),
                color=consts.COLOR_DATA_f0, alpha=0.75)
        ax.plot(contrasts, fi_curve_model.get_f_zero_frequencies(), linestyle="None",
                marker=consts.f0_marker, alpha=0.75, color=consts.COLOR_MODEL_f0, label=r"model $f_0$")

        # f inf response
        ax.plot(contrasts, fi_curve_data.get_f_inf_frequencies(), linestyle="None",
                marker=consts.finf_marker, alpha=0.5, color=consts.COLOR_DATA_finf, label=r"data $f_{\infty}$")
        ax.plot(x_values, fu.clipped_line(x_values, f_inf_fit[0], f_inf_fit[1]),
                color=consts.COLOR_DATA_finf, alpha=0.5)
        ax.plot(contrasts, fi_curve_model.get_f_inf_frequencies(), linestyle="None",
                marker=consts.finf_marker, alpha=0.75, color=consts.COLOR_MODEL_finf, label=r"model $f_{\infty}$")

        ax.set_xlabel("Contrast")
        ax.set_xlim((-0.22, 0.2))

    axes[0].set_ylabel("Frequency [Hz]")
    axes[0].legend(loc="upper left", frameon=False)

    plt.tight_layout()
    consts.set_figure_labels(xoffset=-2.5)
    fig.label_axes()

    plt.savefig(consts.SAVE_FOLDER + "example_bad_fi_fits.pdf", transparent=True)
    plt.close()
def create_boxplots(errors):
    """Show one box plot per error type and print the median of each.

    :param errors: dict that maps an error name to its list of values.
    """
    tick_labels = []
    box_data = []
    # Boxes are ordered by the sorted error names.
    for name in sorted(errors.keys()):
        values = errors[name]
        tick_labels.append("{}_n:{}".format(name, len(values)))
        print("{}: median %-error: {:.2f}".format(name, np.median(values)))
        box_data.append(values)

    plt.boxplot(box_data)
    tick_positions = np.arange(1, len(box_data) + 1, 1)
    plt.xticks(tick_positions, tick_labels, rotation=45)
    plt.tight_layout()
    plt.show()
    plt.close()
def plot_cell_model_comp_baseline(cell_behavior, model_behaviour):
    """Compare the baseline behaviour of the cells with that of their models.

    One panel each is drawn for the baseline frequency, the vector strength
    and the serial correlation. Every panel consists of a scatter plot of
    model against cell values with a histogram of both on top. The figure is
    saved as ``fit_baseline_comparison.pdf`` in ``consts.SAVE_FOLDER``.

    :param cell_behavior: dict that maps behaviour names to the values of the cells.
    :param model_behaviour: dict that maps behaviour names to the values of the models.
    """
    fig = plt.figure(figsize=(8, 4))
    gs = fig.add_gridspec(2, 3, width_ratios=[5, 5, 5], height_ratios=[3, 7],
                          left=0.1, right=0.95, bottom=0.1, top=0.9,
                          wspace=0.4, hspace=0.2)
    num_of_bins = 20

    # (behaviour name, x label, y label), one entry per column.
    panels = (("baseline_frequency", "Cell [Hz]", "Model [Hz]"),
              ("vector_strength", "Cell", "Model"),
              ("serial_correlation", "Cell", "Model"))

    print("Cells in cell_model_comp_baseline:", len(cell_behavior["vector_strength"]))

    for column, (key, x_label, y_label) in enumerate(panels):
        cell = cell_behavior[key]
        model = model_behaviour[key]

        # Common, equally wide bins over the range of both data sets.
        minimum = min(min(cell), min(model))
        maximum = max(max(cell), max(model))
        step = (maximum - minimum) / num_of_bins
        bins = np.arange(minimum, maximum + step, step)

        ax = fig.add_subplot(gs[1, column])
        ax_histx = fig.add_subplot(gs[0, column], sharex=ax)
        scatter_hist(cell, model, ax, ax_histx, behaviour_titles[key], bins)

        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax_histx.set_ylabel("Count")

    # Panel letters above the three columns.
    fig.text(0.09, 0.925, 'A', ha='center', va='center', rotation='horizontal', size=16, family='serif')
    fig.text(0.375, 0.925, 'B', ha='center', va='center', rotation='horizontal', size=16, family='serif')
    fig.text(0.6625, 0.925, 'C', ha='center', va='center', rotation='horizontal', size=16, family='serif')

    plt.tight_layout()
    plt.savefig(consts.SAVE_FOLDER + "fit_baseline_comparison.pdf", transparent=True)
    plt.close()
def plot_cell_model_comp_burstiness(cell_behavior, model_behaviour):
    """Compare burstiness and coefficient of variation of the cells with their models.

    Every panel consists of a scatter plot of model against cell values,
    coloured by the burstiness of the cell, with a histogram of both on top.
    The figure is saved as ``fit_burstiness_comparison.pdf`` in ``consts.SAVE_FOLDER``.

    :param cell_behavior: dict that maps behaviour names to the values of the cells.
    :param model_behaviour: dict that maps behaviour names to the values of the models.
    """
    figure = plt.figure(figsize=consts.FIG_SIZE_MEDIUM_WIDE)

    # A grid with a small histogram row above the row of scatter plots.
    grid = figure.add_gridspec(2, 2, width_ratios=[5, 5], height_ratios=[3, 7],
                               left=0.1, right=0.9, bottom=0.1, top=0.9,
                               wspace=0.3, hspace=0.2)
    bin_count = 20
    colour_map = 'jet'
    # The burstiness of the cells colours the points of both panels.
    burstiness_of_cells = cell_behavior["Burstiness"]

    # (behaviour name, x label, y label), one entry per column.
    panel_specs = (("Burstiness", "Cell [%ms]", "Model [%ms]"),
                   ("coefficient_of_variation", "Cell", "Model"))

    for column, (key, x_text, y_text) in enumerate(panel_specs):
        cell_values = cell_behavior[key]
        model_values = model_behaviour[key]

        lowest = min(min(cell_values), min(model_values))
        highest = max(max(cell_values), max(model_values))
        width = (highest - lowest) / bin_count
        edges = np.arange(lowest, highest + width, width)

        main_axis = figure.add_subplot(grid[1, column])
        top_axis = figure.add_subplot(grid[0, column], sharex=main_axis)
        scatter_hist(cell_values, model_values, main_axis, top_axis, behaviour_titles[key], edges,
                     colour_map, burstiness_of_cells)

        main_axis.set_xlabel(x_text)
        main_axis.set_ylabel(y_text)
        top_axis.set_ylabel("Count")

    plt.tight_layout()

    for x_position, letter in ((0.085, 'A'), (0.53, 'B')):
        figure.text(x_position, 0.925, letter, ha='center', va='center', rotation='horizontal', size=16,
                    family='serif')

    plt.savefig(consts.SAVE_FOLDER + "fit_burstiness_comparison.pdf", transparent=True)
    plt.close()
def plot_cell_model_comp_adaption(cell_behavior, model_behaviour):
    """Compare the f-I curve slopes of the cells with those of their models.

    Three panels are drawn: the f_inf slope, the f_0 slope (only pairs where
    both values are below 25000) and the ratio of f_0 slope to f_inf slope
    (only pairs where both ratios are below 60). Every panel consists of a
    scatter plot of model against cell values with a histogram of both on
    top. The figure is saved as ``fit_adaption_comparison_with_ratio.pdf`` in
    ``consts.SAVE_FOLDER``.

    :param cell_behavior: dict that maps behaviour names to the values of the cells.
    :param model_behaviour: dict that maps behaviour names to the values of the models.
    """
    num_of_bins = 20

    def shared_bins(cell, model):
        # Equally wide bins over the common range of both data sets.
        minimum = min(min(cell), min(model))
        maximum = max(max(cell), max(model))
        step = (maximum - minimum) / num_of_bins
        return np.arange(minimum, maximum + step, step)

    # The tick formatter limits are changed for this figure only. The context
    # manager restores the previous setting afterwards, also when an error
    # occurs, instead of overwriting it with a hard-coded value.
    with mpl.rc_context({"axes.formatter.limits": (-5, 3)}):
        fig = plt.figure(figsize=(8, 4))
        gs = fig.add_gridspec(2, 3, width_ratios=[5, 5, 5], height_ratios=[3, 7],
                              left=0.1, right=0.95, bottom=0.1, top=0.9,
                              wspace=0.4, hspace=0.3)

        def add_panel(column, cell, model, title, bins, x_label, y_label):
            # Scatter plot in the lower row, histogram with shared x axis above it.
            ax = fig.add_subplot(gs[1, column])
            ax_histx = fig.add_subplot(gs[0, column], sharex=ax)
            scatter_hist(cell, model, ax, ax_histx, title, bins)
            ax.set_xlabel(x_label)
            ax.set_ylabel(y_label)
            ax_histx.set_ylabel("Count")

        cell_inf = cell_behavior["f_inf_slope"]
        model_inf = model_behaviour["f_inf_slope"]
        cell_zero = cell_behavior["f_zero_slope"]
        model_zero = model_behaviour["f_zero_slope"]

        # f_inf slope:
        add_panel(0, cell_inf, model_inf, behaviour_titles["f_inf_slope"],
                  shared_bins(cell_inf, model_inf), "Cell [Hz]", "Model [Hz]")

        # f_zero slope, without the pairs where either value is 25000 or larger:
        cell = np.array(cell_zero)
        model = np.array(model_zero)
        keep = (cell < 25000) & (model < 25000)
        print("removed {} values from f_zero_slope plot.".format(len(cell) - np.count_nonzero(keep)))
        cell = cell[keep]
        model = model[keep]
        add_panel(1, cell, model, behaviour_titles["f_zero_slope"],
                  shared_bins(cell, model), "Cell [Hz]", "Model [Hz]")

        # ratio, without the pairs where either ratio is 60 or larger:
        cell_ratio = np.array([zero / inf for zero, inf in zip(cell_zero, cell_inf)])
        model_ratio = np.array([zero / inf for zero, inf in zip(model_zero, model_inf)])
        keep = (cell_ratio < 60) & (model_ratio < 60)
        cell_ratio = cell_ratio[keep]
        model_ratio = model_ratio[keep]
        bins = calculate_bins(np.concatenate((cell_ratio, model_ratio)), num_of_bins)
        add_panel(2, cell_ratio, model_ratio, r"$f_0$ / $f_{\infty}$ slope ratio", bins, "Cell", "Model")

        plt.tight_layout()
        # fig.text(0.085, 0.925, 'A', ha='center', va='center', rotation='horizontal', size=16, family='serif')
        # fig.text(0.54, 0.925, 'B', ha='center', va='center', rotation='horizontal', size=16, family='serif')

        plt.savefig(consts.SAVE_FOLDER + "fit_adaption_comparison_with_ratio.pdf", transparent=True)
        plt.close()
def scatter_hist(cell_values, model_values, ax, ax_histx, behaviour, bins, cmap=None, color_values=None):
    """Draw a scatter plot of model against cell values with a histogram of both.

    :param cell_values: values of the cells (x axis of the scatter plot).
    :param model_values: values of the models (y axis of the scatter plot).
    :param ax: axes for the scatter plot.
    :param ax_histx: axes for the histograms.
    :param behaviour: title that is put on the histogram axes.
    :param bins: bins used for both histograms.
    :param cmap: colour map for the points; if None all points are black.
    :param color_values: values mapped to colours, only used together with ``cmap``.
    """
    # adapted from the matplotlib examples
    lower = min(min(cell_values), min(model_values))
    upper = max(max(cell_values), max(model_values))

    # Identity line: points on it have equal cell and model values.
    ax.plot((lower, upper), (lower, upper), color="grey")

    if cmap is not None:
        ax.scatter(cell_values, model_values, c=color_values, cmap=cmap)
    else:
        ax.scatter(cell_values, model_values, color="black")

    for values, colour, opacity in ((model_values, consts.COLOR_MODEL, 0.75),
                                    (cell_values, consts.COLOR_DATA, 0.50)):
        ax_histx.hist(values, bins=bins, color=colour, alpha=opacity)

    ax_histx.set_title(behaviour)
def calculate_bins(values, num_of_bins):
    """Return bin edges that cover ``values`` with ``num_of_bins`` equally wide bins.

    The first bin is centred on the minimum and the last bin on the maximum
    of the values. If all values are equal, a single bin of width one around
    that value is returned; for ``num_of_bins == 1`` a single bin from the
    minimum to the maximum is returned.

    :param values: non-empty sequence of numbers.
    :param num_of_bins: number of bins, at least 1.
    :return: array of bin edges in ascending order.
    :raises ValueError: if ``num_of_bins`` is smaller than 1.
    """
    if num_of_bins < 1:
        raise ValueError("num_of_bins must be at least 1, got {}".format(num_of_bins))

    minimum = np.min(values)
    maximum = np.max(values)

    if maximum == minimum:
        # A step of zero cannot be handed to np.arange.
        return np.array([minimum - 0.5, maximum + 0.5])
    if num_of_bins == 1:
        # Avoids the division by zero in the step calculation below.
        return np.array([minimum, maximum], dtype=float)

    step = (maximum - minimum) / (num_of_bins - 1)
    # Shift the edges by half a step so that minimum and maximum lie in the
    # centre of the outermost bins.
    bins = np.arange(minimum - 0.5 * step, maximum + step, step)
    return bins
# Create the figures only when this file is run as a script, not when it is imported.
if __name__ == '__main__':
    main()