From 574f4a80f2db4f884e36b329406b718d7e78b1c5 Mon Sep 17 00:00:00 2001 From: "a.ott" Date: Sun, 5 Jul 2020 11:04:32 +0200 Subject: [PATCH] add loading of fi-curve values, add creation of masterplot --- ModelFit.py | 152 +++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 150 insertions(+), 2 deletions(-) diff --git a/ModelFit.py b/ModelFit.py index 7198686..d336fb8 100644 --- a/ModelFit.py +++ b/ModelFit.py @@ -1,14 +1,19 @@ import os from models.LIFACnoise import LifacNoiseModel +from Baseline import get_baseline_class +from FiCurve import get_fi_curve_class +from CellData import CellData import numpy as np +import functions as fu +import matplotlib.pyplot as plt def get_best_fit(folder_path): min_err = np.inf min_item = "" for item in os.listdir(folder_path): - err = float(item.split("_")[-1]) + err = ModelFit(os.path.join(folder_path, min_item)).comparable_error() if err < min_err: min_err = err min_item = item @@ -26,6 +31,11 @@ class ModelFit: self.isi_hist_img = "isi-histogram.png" self.isi_hist_comp_img = "isi-histogram_comparision.png" + self.model_f_inf_file = "model_fi_inf_values.npy" + self.cell_f_inf_file = "cell_fi_inf_values.npy" + self.model_f_zero_file = "model_fi_zero_values.npy" + self.cell_f_zero_file = "cell_fi_zero_values.npy" + def get_final_parameters(self): par_file_path = os.path.join(self.path, self.parameter_file_name) with open(par_file_path, 'r') as par_file: @@ -33,7 +43,7 @@ class ModelFit: line = line.strip().split('\t') if line[0] == "final_parameters:": - return dict(line[1]) + return eval(line[1]) print("Final parameters not found! - ", self.path) return {} @@ -81,3 +91,141 @@ class ModelFit: def get_model(self): return LifacNoiseModel(self.get_final_parameters()) + + def get_cell_path(self): + with open(os.path.join(self.path, "cell_data_path.txt"), "r") as f: + cell_path = f.readline().strip() + + return cell_path + + def get_cell_data(self): + return CellData(self.get_cell_path()) + + def get_model_f_inf_values(self): + path = os.path.join(self.path, self.model_f_inf_file) + return np.load(path) + + def get_model_f_zero_values(self): + path = os.path.join(self.path, self.model_f_zero_file) + return np.load(path) + + def get_cell_f_inf_values(self): + path = os.path.join(self.path, self.cell_f_inf_file) + return np.load(path) + + def get_cell_f_zero_values(self): + path = os.path.join(self.path, self.cell_f_zero_file) + return np.load(path) + + def comparable_error(self): + cell_values, model_values = self.get_behaviour_values() + + error = 0 + + bf = "baseline_frequency" + error += abs(cell_values[bf] - model_values[bf]) / 5 + vs = "vector_strength" + error += abs(cell_values[vs] - model_values[vs]) / 0.1 + sc = "serial_correlation" + error += abs(cell_values[sc] - model_values[sc]) / 0.1 + burst = "Burstiness" + error += abs(cell_values[burst] - model_values[burst]) / 0.05 + cv = "coefficient_of_variation" + error += abs(cell_values[cv] - model_values[cv]) / 0.1 + f_inf_slope = "f_inf_slope" + error += abs(cell_values[f_inf_slope] - model_values[f_inf_slope]) / 5 + + # f_zero_sloe = "f_zero_slope" + # error += abs(cell_values[f_zero_sloe] - model_values[f_zero_sloe]) / 100 + + c_f_inf_values = self.get_cell_f_inf_values() + c_f_zero_values = self.get_cell_f_zero_values() + + m_f_inf_values = self.get_model_f_inf_values() + m_f_zero_values = self.get_cell_f_zero_values() + + error_f_inf = 0 + for m_value, c_value in zip(m_f_inf_values, c_f_inf_values): + error_f_inf += abs(c_value - m_value) / 10 + + error_f_inf = error_f_inf / len(m_f_inf_values) + error += error_f_inf + + error_f_zero = 0 + for m_value, c_value in zip(m_f_zero_values, c_f_zero_values): + error_f_zero += abs(c_value - m_value) / 10 + + error_f_zero = error_f_zero / len(m_f_zero_values) + error += error_f_zero + + return error + + def generate_master_plot(self, save_path=None): + model = self.get_model() + cell = self.get_cell_data() + + fig, axes = plt.subplots(3, 1, figsize=(8, 10)) + # isi histogram: + axes[0].set_title("ISI-Histogram") + axes[0].set_xlim((0, 50)) + bins = np.arange(0, 50, 0.1) + for data, name in zip((model, cell), ("model", "cell")): + base = get_baseline_class(data, cell.get_eod_frequency(), trials=5) + isis = np.array(base.get_interspike_intervals()) * 1000 + axes[0].hist(isis, bins=bins, label=name, alpha=0.5, density=True) + + axes[0].legend() + + # fi_curve + fi_curve_cell = get_fi_curve_class(cell, cell.get_fi_contrasts(), eod_freq=cell.get_eod_frequency(), trials=15) + fi_curve_model = get_fi_curve_class(model, cell.get_fi_contrasts(), eod_freq=cell.get_eod_frequency(), trials=15) + + axes[1].set_title("Fi-Curve") + min_x = min(min(fi_curve_cell.stimulus_values), min(fi_curve_model.stimulus_values)) + max_x = max(max(fi_curve_cell.stimulus_values), max(fi_curve_model.stimulus_values)) + step = (max_x - min_x) / 5000 + x_values = np.arange(min_x, max_x + step, step) + + # plot baseline + f_base_color = ("blue", "deepskyblue") + f_inf_color = ("green", "limegreen") + f_zero_color = ("red", "orange") + + median_baseline = np.median(fi_curve_cell.get_f_baseline_frequencies()) + axes[1].plot((min_x, max_x), (median_baseline, median_baseline), color=f_base_color[0], label="cell med base") + axes[1].plot(fi_curve_model.stimulus_values, fi_curve_model.get_f_baseline_frequencies(), + 'o', color=f_base_color[1], label='model base') + + y_values = [fu.clipped_line(x, fi_curve_cell.f_inf_fit[0], fi_curve_cell.f_inf_fit[1]) for x in x_values] + axes[1].plot(x_values, y_values, color=f_inf_color[0], label='f_inf_fit cell') + axes[1].plot(fi_curve_model.stimulus_values, fi_curve_model.get_f_inf_frequencies(), + 'o', color=f_inf_color[1], label='f_inf model') + + popt = fi_curve_cell.f_zero_fit + axes[1].plot(x_values, [fu.full_boltzmann(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values], + color=f_zero_color[0], label='f_0_fit cell') + axes[1].plot(fi_curve_model.stimulus_values, fi_curve_model.get_f_zero_frequencies(), + 'o', color=f_zero_color[1], label='f_zero model') + axes[1].set_title("cell model comparision") + axes[1].set_xlabel("Stimulus value - contrast") + axes[1].legend() + + # Value table: + cell_values, model_values = self.get_behaviour_values() + + collabel = sorted(cell_values.keys()) + clust_data = [[], []] + for k in collabel: + clust_data[0].append(cell_values[k]) + clust_data[1].append(model_values[k]) + + axes[2].axis('tight') + axes[2].axis('off') + table = axes[2].table(cellText=clust_data, colLabels=collabel, rowLabels=("cell", "model"), loc='center') + fig.suptitle(cell.get_cell_name() + "_comp_err: {:.2f}".format(self.comparable_error())) + if save_path is None: + plt.show() + else: + plt.savefig(save_path + cell.get_cell_name() + "_master_plot.pdf") + +