488 lines
18 KiB
Python
488 lines
18 KiB
Python
|
|
import numpy as np
|
|
import os
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.gridspec as gridspec
|
|
from scipy.optimize import curve_fit
|
|
from scipy.stats import multivariate_normal, pearsonr
|
|
|
|
from analysis import get_parameter_values, get_filtered_fit_info, parameter_correlations, get_behaviour_values
|
|
from ModelFit import get_best_fit
|
|
import functions as fu
|
|
from Baseline import BaselineModel
|
|
from FiCurve import FICurveModel
|
|
from models.LIFACnoise import LifacNoiseModel
|
|
from Figures_results import create_correlation_plot
|
|
import Figure_constants as consts
|
|
|
|
# Per-parameter flag: True means the parameter's distribution is handled in
# log space (values are log-transformed before fitting/correlating and mapped
# back with exp after sampling).
LOG_TRANSFORM = dict(
    v_offset=False,
    input_scaling=True,
    dend_tau=True,
    tau_a=True,
    delta_a=True,
    refractory_period=False,
    noise_strength=True,
    mem_tau=True,
)

# Short axis titles for each model behaviour measure.
behaviour_titles = dict(
    baseline_frequency="Base Rate",
    Burstiness="Burst",
    coefficient_of_variation="CV",
    serial_correlation="SC",
    vector_strength="VS",
    f_inf_slope=r"$f_{\infty}$ Slope",
    f_zero_slope=r"$f_0$ Slope",
)

# (LaTeX) symbols of the model parameters, used in axis labels.
parameter_titles = dict(
    input_scaling=r"$\alpha$",
    delta_a=r"$\Delta_A$",
    mem_tau=r"$\tau_m$",
    noise_strength=r"$\sqrt{2D}$",
    refractory_period="$t_{ref}$",
    tau_a=r"$\tau_A$",
    v_offset=r"$I_{Bias}$",
    dend_tau=r"$\tau_{dend}$",
)

# Selects between re-simulating the random models (True) and loading the
# saved data (False) -- see main().
recalculate = False
# Number of random models to draw when re-simulating.
num_of_models = 100
|
|
|
|
|
|
def main():
    """Regenerate all figures of this module.

    When the module-level flag ``recalculate`` is True, ``num_of_models`` new
    parameter sets are first drawn from the fitted parameter distributions,
    simulated, and saved via ``save_behaviour``. All figures are then rebuilt
    by ``rerun_all_images``, which loads the saved random-model data.
    With ``recalculate`` False only the figures are rebuilt.
    """
    if recalculate:
        folder = "results/final_2/"
        goal_eodf = 800
        fit_infos = get_filtered_fit_info(folder, filter=True)
        param_values = get_parameter_values(fit_infos, scaled=True, goal_eodf=goal_eodf)

        keys, means, cov_matrix = calculate_means_and_covariances(param_values)
        par_list = draw_random_models(num_of_models, keys, means, cov_matrix)
        behaviour = model_behaviour_distributions(par_list, eodf=goal_eodf)
        # Persist the new models so rerun_all_images (and later runs) use them.
        save_behaviour(behaviour, par_list)

    rerun_all_images()
|
|
|
|
|
|
def rerun_all_images():
    """Rebuild every figure from the fitted cells and the stored random models.

    1000 parameter sets are drawn with seed 1 for the parameter-correlation
    comparison. The fitted parameter distributions are plotted together with
    their Gaussian fits. The previously saved random models (read by
    ``load_behavior``) are then compared with the fitted cells, both for
    behaviour and for parameters.
    """
    fit_folder = "results/final_2/"
    eod_frequency = 800

    fits = get_filtered_fit_info(fit_folder, filter=True)
    scaled_params = get_parameter_values(fits, scaled=True, goal_eodf=eod_frequency)

    # Parameter statistics: 1000 reproducible draws vs. the fitted models.
    names, centre, covariance = calculate_means_and_covariances(scaled_params)
    drawn_models = draw_random_models(1000, names, centre, covariance, seed=1)
    parameter_correlation_plot(drawn_models, fits)
    plot_distributions_with_set_fits(scaled_params)

    # Behaviour and parameters of the stored random models.
    stored_behaviour, stored_models = load_behavior()
    create_behaviour_distributions(stored_behaviour, fits)
    compare_distribution_random_vs_fitted_params(stored_models, scaled_params)
|
|
|
|
|
|
def compare_distribution_random_vs_fitted_params(par_list, scaled_param_values):
    """Overlay per-parameter histograms of the fitted and the randomly drawn models.

    :param par_list: list of parameter dicts of the randomly drawn models.
    :param scaled_param_values: dict mapping parameter name -> values of the
        fitted models.

    Parameters whose unit label contains "ms" are multiplied by 1000 before
    plotting. Both histograms of a panel share 30 equal-width bins that span
    both data sets. The figure is saved as
    ``consts.SAVE_FOLDER + "compare_parameter_dist_random_models.pdf"``.
    """
    labels = ["input_scaling", "v_offset", "mem_tau", "noise_strength",
              "tau_a", "delta_a", "dend_tau", "refractory_period"]
    x_labels = ["[cm]", "[mV]", "[ms]", r"[mV$\sqrt{s}$]", "[ms]", "[mVms]", "[ms]", "[ms]"]

    fig, axes = plt.subplots(4, 2, gridspec_kw={"left": 0.1, "hspace": 0.5}, figsize=consts.FIG_SIZE_LARGE_HIGH)
    axes_flat = axes.flatten()

    for i, l in enumerate(labels):
        rand_model_values = np.array([params[l] for params in par_list], dtype=float)
        fitted_model_values = np.array(scaled_param_values[l], dtype=float)

        if "ms" in x_labels[i]:
            rand_model_values = rand_model_values * 1000
            fitted_model_values = fitted_model_values * 1000

        lowest = min(np.nanmin(rand_model_values), np.nanmin(fitted_model_values))
        highest = max(np.nanmax(rand_model_values), np.nanmax(fitted_model_values))
        # Widen the range by 5 % of each bound's magnitude. Scaling the bounds
        # (min * 0.95, max * 1.05) shrinks the range whenever a bound is
        # negative (e.g. v_offset), which would cut off the extreme values.
        min_v = lowest - 0.05 * abs(lowest)
        max_v = highest + 0.05 * abs(highest)
        if max_v == min_v:
            # All values are zero: avoid a zero bin width.
            min_v, max_v = min_v - 0.5, max_v + 0.5

        step = (max_v - min_v) / 30
        bins = np.arange(min_v, max_v + step, step)

        axes_flat[i].hist(fitted_model_values, bins=bins, alpha=0.5, density=True)
        axes_flat[i].hist(rand_model_values, bins=bins, alpha=0.5, density=True)
        axes_flat[i].set_xlabel(parameter_titles[l] + " " + x_labels[i])
        axes_flat[i].set_yticks([])
        axes_flat[i].set_yticklabels([])

    fig.text(0.03, 0.5, 'Density', ha='center', va='center', rotation='vertical', size=12)  # shared y label
    plt.tight_layout()

    consts.set_figure_labels(xoffset=-2.5, yoffset=0)
    fig.label_axes()

    plt.savefig(consts.SAVE_FOLDER + "compare_parameter_dist_random_models.pdf")
    plt.close()
|
|
|
|
|
|
def parameter_correlation_plot(par_list, fits_info):
    """Draw the parameter correlation matrices of fitted and drawn models side by side.

    Left panel: correlations of the fitted models (``parameter_correlations``).
    Right panel: correlations within ``par_list``
    (``parameter_correlations_from_par_list``). A horizontal strip showing the
    correlation colour scale from -1 to 1 spans both columns below them.
    The figure is saved as
    ``consts.SAVE_FOLDER + "rand_parameter_correlations_comparison.pdf"``.
    """
    fig = plt.figure(tight_layout=True, figsize=consts.FIG_SIZE_MEDIUM_WIDE)
    grid = gridspec.GridSpec(2, 2, width_ratios=(1, 1), height_ratios=(5, 1), hspace=0.05, wspace=0.05)

    fit_names, fit_corr, fit_p = parameter_correlations(fits_info)
    fit_titles = [parameter_titles[name] for name in fit_names]
    create_correlation_plot(fig.add_subplot(grid[0, 0]), fit_titles, fit_corr, fit_p,
                            "Fitted Models", y_label=True)

    drawn_names, drawn_corr, drawn_p = parameter_correlations_from_par_list(par_list)
    drawn_titles = [parameter_titles[name] for name in drawn_names]
    # NOTE(review): p-values are inflated by 10e50 (= 1e51), presumably so that
    # no significance markers are drawn for the drawn models -- confirm intent.
    create_correlation_plot(fig.add_subplot(grid[0, 1]), drawn_titles, drawn_corr, drawn_p * 10e50,
                            "Drawn Models", y_label=False)

    # label_axes() runs before the colour strip is added, presumably so the
    # strip does not receive a panel label.
    consts.set_figure_labels(xoffset=-2.5, yoffset=1.5)
    fig.label_axes()

    colour_strip = fig.add_subplot(grid[1, :])
    gradient = [np.arange(-1, 1.001, 0.01)] * 10  # 201 columns from -1 to 1, 10 rows high
    colour_strip.set_xticks([0, 25, 50, 75, 100, 125, 150, 175, 200])
    colour_strip.set_xticklabels([-1, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1])
    colour_strip.set_yticks([])
    colour_strip.imshow(gradient)
    colour_strip.set_xlabel("Correlation Coefficients")

    plt.savefig(consts.SAVE_FOLDER + "rand_parameter_correlations_comparison.pdf")
    plt.close()
|
|
|
|
|
|
def parameter_correlations_from_par_list(par_list):
|
|
labels = ["input_scaling", "v_offset", "mem_tau", "noise_strength",
|
|
"tau_a", "delta_a", "dend_tau", "refractory_period"]
|
|
|
|
parameter_values = {}
|
|
for l in labels:
|
|
parameter_values[l] = []
|
|
|
|
for params in par_list:
|
|
for l in labels:
|
|
parameter_values[l].append(params[l])
|
|
|
|
corr_values = np.zeros((len(labels), len(labels)))
|
|
p_values = np.ones((len(labels), len(labels)))
|
|
|
|
for i in range(len(labels)):
|
|
for j in range(len(labels)):
|
|
c, p = pearsonr(parameter_values[labels[i]], parameter_values[labels[j]])
|
|
corr_values[i, j] = c
|
|
p_values[i, j] = p
|
|
|
|
corrected_p_values = p_values * sum(range(len(labels)))
|
|
|
|
return labels, corr_values, corrected_p_values
|
|
|
|
|
|
def model_behaviour_distributions(par_list, eodf=800):
    """Simulate each parameter set and collect its baseline and f-I curve measures.

    :param par_list: list of parameter dicts accepted by ``LifacNoiseModel``.
    :param eodf: EOD frequency passed to the baseline and f-I curve models.
    :return: dict with the keys of ``behaviour_titles``. Each value is a list
        holding one entry per model, in ``par_list`` order.

    Progress ("i of n") is printed after every model.
    """
    behaviour = {name: [] for name in behaviour_titles}
    total = len(par_list)

    for count, parameters in enumerate(par_list, start=1):
        model = LifacNoiseModel(parameters)

        baseline = BaselineModel(model, eodf)
        behaviour["baseline_frequency"].append(baseline.get_baseline_frequency())
        behaviour["Burstiness"].append(baseline.get_burstiness())
        behaviour["coefficient_of_variation"].append(baseline.get_coefficient_of_variation())
        behaviour["serial_correlation"].append(baseline.get_serial_correlation(1)[0])
        behaviour["vector_strength"].append(baseline.get_vector_strength())

        # Contrasts -0.3 .. 0.3 in steps of 0.1.
        fi_curve = FICurveModel(model, np.arange(-0.3, 0.301, 0.1), eodf)
        behaviour["f_inf_slope"].append(fi_curve.f_inf_fit[0])
        behaviour["f_zero_slope"].append(fi_curve.get_f_zero_fit_slope_at_straight())

        print("{:} of {:}".format(count, total))

    return behaviour
|
|
|
|
|
|
def save_behaviour(behaviour, par_list):
    """Store simulated behaviour and parameter sets as ``.npy`` files in ``data/``.

    Files written (read back by ``load_behavior``):
      * random_model_behaviour_data.npy  -- (n_measures, n_models) float array,
        rows in sorted key order
      * random_model_behaviour_keys.npy  -- the sorted behaviour keys
      * random_model_parameter_data.npy  -- (n_models, n_parameters) float array,
        columns in sorted key order
      * random_model_parameter_keys.npy  -- the sorted parameter keys

    The ``data/`` directory is created when it does not exist yet; otherwise
    ``np.save`` would fail with FileNotFoundError.

    :param behaviour: dict mapping measure name -> sequence with one value per model.
    :param par_list: list of parameter dicts; all dicts are assumed to share the
        keys of ``par_list[0]``.
    """
    os.makedirs("data", exist_ok=True)

    # Behaviour: one row per measure.
    keys = np.array(sorted(behaviour.keys()))
    data_points = len(behaviour[keys[0]])
    data = np.zeros((len(keys), data_points))
    for i, k in enumerate(keys):
        data[i, :] = np.array(behaviour[k])

    np.save("data/random_model_behaviour_data.npy", data)
    np.save("data/random_model_behaviour_keys.npy", keys)

    # Parameter sets: one row per model.
    par_keys = np.array(sorted(par_list[0].keys()))
    pars_data = np.zeros((len(par_list), len(par_keys)))
    for i, params in enumerate(par_list):
        pars_data[i, :] = np.array([params[k] for k in par_keys])

    np.save("data/random_model_parameter_data.npy", pars_data)
    np.save("data/random_model_parameter_keys.npy", par_keys)
|
|
|
|
|
|
def load_behavior():
    """Read the random-model behaviour and parameter sets stored in ``data/``.

    :return: ``(behaviour, par_list)``. ``behaviour`` maps each stored behaviour
        key to its row of the stored behaviour array. ``par_list`` contains one
        dict per stored model (row), mapping parameter key -> value.
    """
    behaviour_rows = np.load("data/random_model_behaviour_data.npy")
    behaviour_names = np.load("data/random_model_behaviour_keys.npy")
    behaviour = {name: behaviour_rows[idx, :] for idx, name in enumerate(behaviour_names)}

    parameter_table = np.load("data/random_model_parameter_data.npy")
    parameter_names = np.load("data/random_model_parameter_keys.npy")
    par_list = [
        {name: parameter_table[row, col] for col, name in enumerate(parameter_names)}
        for row in range(parameter_table.shape[0])
    ]

    return behaviour, par_list
|
|
|
|
|
|
def create_behaviour_distributions(drawn_model_behaviour, fits_info):
    """Overlay behaviour histograms of the drawn models and of the measured cells.

    :param drawn_model_behaviour: dict mapping measure name -> values of the
        drawn models.
    :param fits_info: fit information passed to ``get_behaviour_values`` to
        obtain the cells' behaviour.

    Bins (20 per panel) are derived from the drawn models' range only. Seven
    measures are plotted on a 4x2 grid, and the unused last panel is hidden.
    The figure is saved as ``consts.SAVE_FOLDER + "random_models_behaviour_dist.pdf"``.
    """
    fig, axes = plt.subplots(4, 2, gridspec_kw={"left": 0.1, "hspace": 0.5}, figsize=consts.FIG_SIZE_LARGE_HIGH)
    # The second return value (behaviour of the fitted models) is not needed here.
    cell_behaviour, _ = get_behaviour_values(fits_info)
    labels = ['Burstiness', 'baseline_frequency', 'coefficient_of_variation', 'f_inf_slope', 'f_zero_slope', 'serial_correlation', 'vector_strength']
    unit = ["[%ms]", "[Hz]", "", "[Hz]", "[Hz]", "", ""]

    axes_flat = axes.flatten()
    for i, l in enumerate(labels):
        drawn_values = np.asarray(drawn_model_behaviour[l], dtype=float)
        lowest = np.nanmin(drawn_values)
        highest = np.nanmax(drawn_values)
        # Widen the range by 5 % of each bound's magnitude. Scaling the bounds
        # (min * 0.95, max * 1.05) shrinks the range for negative bounds
        # (e.g. serial correlation), which would cut off the extreme values.
        min_v = lowest - 0.05 * abs(lowest)
        max_v = highest + 0.05 * abs(highest)
        if max_v == min_v:
            # All values are zero: avoid a zero bin width.
            min_v, max_v = min_v - 0.5, max_v + 0.5

        step = (max_v - min_v) / 20
        bins = np.arange(min_v, max_v + step, step)

        axes_flat[i].hist(drawn_model_behaviour[l], bins=bins, alpha=0.75, density=True, color=consts.COLOR_MODEL)
        axes_flat[i].hist(cell_behaviour[l], bins=bins, alpha=0.5, density=True, color=consts.COLOR_DATA)
        axes_flat[i].set_xlabel(behaviour_titles[l] + " " + unit[i])
        axes_flat[i].set_yticks([])
        axes_flat[i].set_yticklabels([])

    # Only 7 measures for 8 panels: hide the last one.
    axes_flat[-1].set_visible(False)

    plt.tight_layout()

    consts.set_figure_labels(xoffset=-2.5, yoffset=0)
    fig.label_axes()
    fig.text(0.03, 0.5, 'Density', ha='center', va='center', rotation='vertical', size=12)  # shared y label

    plt.savefig(consts.SAVE_FOLDER + "random_models_behaviour_dist.pdf")
    plt.close()
|
|
|
|
|
|
def test_plot_models(par_list, eodf):
    """Visual check: plot the ISI histogram and the f-I curve of every parameter set.

    A fresh ``LifacNoiseModel`` is built for the baseline and for the f-I curve
    (contrasts -0.3 .. 0.3 in steps of 0.1).
    """
    for parameters in par_list:
        BaselineModel(LifacNoiseModel(parameters), eodf).plot_interspike_interval_histogram()
        FICurveModel(LifacNoiseModel(parameters), np.arange(-0.3, 0.31, 0.1), eodf).plot_fi_curve()
|
|
|
|
|
|
def calculate_means_and_covariances(param_values):
    """Build the mean vector and covariance matrix for drawing random models.

    Parameters flagged in ``LOG_TRANSFORM`` are log-transformed. The means and
    standard deviations come from the hand-set Gaussian fits in
    ``get_gauss_fits`` (not from ``param_values``). Only the Pearson correlation
    coefficients r_ij are computed from the data, giving
    cov[i, j] = r_ij * sigma_i * sigma_j.

    :param param_values: dict mapping parameter name -> values of the fitted models.
    :return: ``(keys, means, cov_matrix)`` with ``keys`` sorted alphabetically and
        both arrays ordered like ``keys``.
    """
    keys = sorted(param_values.keys())

    transformed = {}
    for name in keys:
        raw = np.array(param_values[name])
        transformed[name] = np.log(raw) if LOG_TRANSFORM[name] else raw

    gauss_fits = get_gauss_fits()
    means = np.array([gauss_fits[name][1] for name in keys])
    sigmas = [gauss_fits[name][2] for name in keys]

    cov_matrix = np.zeros((len(keys), len(keys)))
    for row, first in enumerate(keys):
        for col, second in enumerate(keys):
            r, _ = pearsonr(transformed[first], transformed[second])
            cov_matrix[row, col] = r * sigmas[row] * sigmas[col]

    return keys, means, cov_matrix
|
|
|
|
|
|
def draw_random_models(num_of_models, keys, means, cov_matrix, seed=None, log_transform=None):
    """Draw random parameter sets from a multivariate normal distribution.

    Samples are drawn in the transformed space described by ``means`` and
    ``cov_matrix`` (both ordered like ``keys``). Parameters flagged in the
    log-transform mapping are mapped back with ``exp``.

    :param num_of_models: number of parameter sets to draw.
    :param keys: parameter names, in the order of ``means`` / ``cov_matrix``.
    :param means: mean vector in transformed space.
    :param cov_matrix: covariance matrix in transformed space.
    :param seed: optional ``random_state`` for reproducible draws. None uses
        the global NumPy random state.
    :param log_transform: mapping parameter name -> bool. Defaults to the
        module-level ``LOG_TRANSFORM``.
    :return: list of ``num_of_models`` dicts mapping parameter name -> value.
    """
    if log_transform is None:
        log_transform = LOG_TRANSFORM

    samples = multivariate_normal.rvs(mean=means, cov=cov_matrix, size=num_of_models, random_state=seed)
    # rvs squeezes singleton dimensions: a single model or a single parameter
    # yields a 1-D array. Restore the (models, parameters) layout so every row
    # is one complete parameter set.
    samples = np.reshape(samples, (num_of_models, len(keys)))

    drawn_parameters = []
    for par_set in samples:
        drawn_parameters.append(
            {k: np.exp(value) if log_transform[k] else value for k, value in zip(keys, par_set)}
        )

    return drawn_parameters
|
|
|
|
|
|
def get_gauss_fits():
    """Return the hand-set Gaussian fits of the (transformed) parameter distributions.

    Each value is ``[amplitude, mean, sigma]``. For parameters flagged in
    ``LOG_TRANSFORM`` the values refer to the log-transformed parameter.
    A new dict is built on every call.

    NOTE / TODO: the amplitudes are NOT normalised to an integral of 1.
    """
    return {
        "delta_a": [0.52555418, -2.17583514, 0.658713652],  # tweaked
        "dend_tau": [0.90518987, -5.509343763, 0.3593178],  # good fit
        "mem_tau": [0.85176348, -6.2468377, 0.42126255],  # good fit
        # tweaked; presumably the original fit was [0.37239028, 5.92264105, 1.77342945]
        "input_scaling": [0.57239028, 5., 0.6],
        "noise_strength": [0.62216977, -3.49622807, 0.58081673],  # good fit
        "tau_a": [0.82351638, -2.39879173, 0.45725644],  # good fit
        "v_offset": [1.33749859e-02, -1.91220096e+01, 1.71068108e+01],  # good fit
        "refractory_period": [1.370406256, 9.14715386e-04, 2.33470418e-04],
    }
|
|
|
|
|
|
def plot_distributions_with_set_fits(param_values):
    """Plot each fitted parameter's histogram with its hand-set Gaussian fit.

    Parameters flagged in ``LOG_TRANSFORM`` are shown log-transformed. Those
    whose unit label contains "ms" are multiplied by 1000 on the x axis, while
    the Gaussian is still evaluated on the unscaled values. The figure is saved
    as ``consts.SAVE_FOLDER + "parameter_distribution_with_gauss_fits.pdf"``.

    :param param_values: dict mapping parameter name -> values of the fitted models.
    """
    fig, axes = plt.subplots(4, 2, gridspec_kw={"left": 0.1, "hspace": 0.5}, figsize=consts.FIG_SIZE_LARGE_HIGH)

    gauss_fits = get_gauss_fits()
    bin_number = 30

    labels = ["input_scaling", "v_offset", "mem_tau", "noise_strength",
              "tau_a", "delta_a", "dend_tau", "refractory_period"]
    x_labels = ["[ln(cm)]", "[mV]", "[ln(s)]", r"[ln(mV$\sqrt{s}$)]", "[ln(s)]", "[ln(mVs)]", "[ln(s)]", "[ms]"]

    for idx, name in enumerate(labels):
        row, col = divmod(idx, 2)
        panel = axes[row, col]

        if LOG_TRANSFORM[name]:
            values = np.log(np.array(param_values[name]))
        else:
            values = param_values[name]

        lo, hi = min(values), max(values)
        x = np.arange(lo, hi, (hi - lo) / 100)
        plot_x = x.copy()
        if "ms" in x_labels[idx]:
            values = np.array(values) * 1000
            plot_x = plot_x * 1000

        amplitude, mean, sigma = gauss_fits[name]

        panel.hist(values, bins=calculate_bins(values, bin_number), density=True, alpha=0.75, color=consts.COLOR_MODEL)
        panel.plot(plot_x, fu.gauss(x, amplitude, mean, sigma), color="black")
        panel.set_xlabel(parameter_titles[name] + " " + x_labels[idx])
        panel.set_yticklabels([])
        panel.set_yticks([])

    plt.tight_layout()

    consts.set_figure_labels(xoffset=-2.5, yoffset=0)
    fig.label_axes()
    fig.text(0.03, 0.5, 'Density', ha='center', va='center', rotation='vertical', size=12)  # shared y label

    plt.savefig(consts.SAVE_FOLDER + "parameter_distribution_with_gauss_fits.pdf")
    plt.close()
|
|
|
|
|
|
def calculate_bins(values, num_of_bins):
    """Return histogram bin edges so that bins are centred on an even grid over the data.

    The first bin is centred on the minimum and the last on the maximum of
    ``values``, and the edges extend half a bin width beyond both. For
    non-constant data exactly ``num_of_bins + 1`` edges are returned. For
    constant data a single unit-width bin around the value is returned, since
    the bin width would otherwise be zero.

    :raises ValueError: if ``num_of_bins`` < 2 (the bin width is undefined).
    """
    num_of_bins = int(num_of_bins)
    if num_of_bins < 2:
        raise ValueError("num_of_bins must be at least 2, got {}".format(num_of_bins))

    minimum = np.min(values)
    maximum = np.max(values)
    if maximum == minimum:
        return np.array([minimum - 0.5, minimum + 0.5])

    step = (maximum - minimum) / (num_of_bins - 1)
    # linspace fixes the edge count and the last edge exactly, instead of
    # relying on np.arange with a float step and an overshooting stop value.
    return np.linspace(minimum - 0.5 * step, maximum + 0.5 * step, num_of_bins + 1)
|
|
|
|
|
|
def plot_distributions(param_values):
    """Exploratory plot of every parameter's histogram with a fitted Gaussian.

    Left column: raw values. Right column: log-transformed values (every
    parameter except ``v_offset``). The fitted Gaussian parameters are printed,
    and a failed fit is reported rather than drawn. The figure is shown
    interactively.

    :param param_values: dict mapping parameter name -> values.
    """
    bin_number = 30
    keys = sorted(param_values.keys())
    # squeeze=False keeps ``axes`` 2-D even for a single parameter, so
    # axes[i, 0] / axes[i, 1] always work.
    fig, axes = plt.subplots(len(keys), 2, squeeze=False)

    for i, key in enumerate(keys):
        # Raw values.
        values = param_values[key]
        normal, n_bins, _ = axes[i, 0].hist(values, bins=calculate_bins(values, bin_number), density=True)
        axes[i, 0].set_title(key)
        _plot_gauss_fit(axes[i, 0], normal, n_bins, min(values), max(values), key, "normal")

        # Log-transformed values.
        if key != "v_offset":
            log_values = np.log(np.array(values))
            log_trans, l_bins, _ = axes[i, 1].hist(log_values, bins=bin_number, density=True)
            _plot_gauss_fit(axes[i, 1], log_trans, l_bins, min(log_values), max(log_values), key, "log")

    plt.tight_layout()
    plt.show()
    plt.close()


def _plot_gauss_fit(ax, heights, edges, lower, upper, key, kind):
    """Mark the histogram's bin centres on ``ax``, fit a Gaussian to them and draw it over [lower, upper)."""
    bin_width = np.mean(np.diff(edges))
    centres = (edges + bin_width / 2)[:-1]
    ax.plot(centres, heights, 'o')
    try:
        gauss_pars = fit_gauss(centres, heights)
    except RuntimeError as err:
        # Best effort: a non-converging fit should not abort the whole plot.
        print(key, ": " + kind + " fit failed:", err)
        return

    x = np.arange(lower, upper, (upper - lower) / 100)
    ax.plot(x, fu.gauss(x, gauss_pars[0], gauss_pars[1], gauss_pars[2]))
    print(key, ": " + kind + ":", gauss_pars)
|
|
|
|
|
|
def fit_gauss(x_values, y_values):
    """Fit ``fu.gauss`` with parameters (amplitude, mean, sigma) to the points (x_values, y_values).

    Start values are the largest y value for the amplitude, and the y-weighted
    mean and standard deviation of ``x_values`` for mean and sigma. This way
    histograms whose peak is off-centre start near the peak, rather than at
    the midpoint of the x grid. When the weights sum to zero, or give a zero
    spread, the unweighted statistics are used instead.

    :return: the optimal (amplitude, mean, sigma) found by ``curve_fit``.
    :raises RuntimeError: propagated from ``curve_fit`` when the fit does not converge.
    """
    x = np.asarray(x_values, dtype=float)
    y = np.asarray(y_values, dtype=float)
    weights = np.clip(y, 0, None)

    if weights.sum() > 0:
        mean_v = np.average(x, weights=weights)
        std_v = np.sqrt(np.average((x - mean_v) ** 2, weights=weights))
    else:
        mean_v = np.mean(x)
        std_v = 0.0
    if std_v == 0:
        std_v = np.std(x)

    amp = np.max(y)
    popt, _ = curve_fit(fu.gauss, x, y, p0=(amp, mean_v, std_v))

    return popt
|
|
|
|
def get_parameter_distributions(folder, param_keys=None):
    """Collect the final parameters of each cell's best fit found in ``folder``.

    :param folder: directory with one entry per cell. Each entry path is passed
        to ``get_best_fit``. A trailing separator is optional.
    :param param_keys: parameter names to collect. Defaults to the eight model
        parameters.
    :return: dict mapping parameter name -> list of values, one per entry in
        sorted directory order.
    """
    if param_keys is None:
        param_keys = ["v_offset", 'input_scaling', 'dend_tau', 'tau_a', 'delta_a',
                      'refractory_period', 'noise_strength', 'mem_tau']

    parameter_values = {key: [] for key in param_keys}

    for cell in sorted(os.listdir(folder)):
        # os.path.join also works when ``folder`` has no trailing separator,
        # where plain concatenation produced a wrong path.
        fit = get_best_fit(os.path.join(folder, cell))
        final_params = fit.get_final_parameters()

        for key in param_keys:
            parameter_values[key].append(final_params[key])

    return parameter_values
|
|
|
|
|
|
# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|