From 51d8f72c0693565e6b2476b237e44338be5ca455 Mon Sep 17 00:00:00 2001 From: "a.ott" Date: Tue, 28 Jul 2020 09:30:38 +0200 Subject: [PATCH] add f_0_curve error image to master plot --- ModelFit.py | 76 ++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 66 insertions(+), 10 deletions(-) diff --git a/ModelFit.py b/ModelFit.py index 5e13eaf..bd44b66 100644 --- a/ModelFit.py +++ b/ModelFit.py @@ -1,20 +1,25 @@ import os from models.LIFACnoise import LifacNoiseModel +from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus from Baseline import get_baseline_class from FiCurve import get_fi_curve_class from CellData import CellData +import helperFunctions as hF import numpy as np import functions as fu import matplotlib.pyplot as plt -def get_best_fit(folder_path): +def get_best_fit(folder_path, use_comparable_error=True): min_err = np.inf min_item = "" for item in os.listdir(folder_path): item_path = os.path.join(folder_path, item) - err = ModelFit(item_path).comparable_error() + if use_comparable_error: + err = ModelFit(item_path).comparable_error() + else: + err = ModelFit(item_path).get_fit_routine_error() if err < min_err: min_err = err min_item = item @@ -118,6 +123,11 @@ class ModelFit: path = os.path.join(self.path, self.cell_f_zero_file) return np.load(path) + def get_fit_routine_error(self): + foldername = os.path.basename(self.path) + parts = foldername.split("_") + return float(parts[-1]) + def comparable_error(self): cell_values, model_values = self.get_behaviour_values() @@ -165,12 +175,12 @@ class ModelFit: model = self.get_model() cell = self.get_cell_data() - fig, axes = plt.subplots(3, 1, figsize=(8, 10)) + fig, axes = plt.subplots(4, 1, figsize=(8, 12)) # isi histogram: axes[0].set_title("ISI-Histogram") axes[0].set_xlim((0, 50)) bins = np.arange(0, 50, 0.1) - for data, name in zip((model, cell), ("model", "cell")): + for data, name in zip((cell, model), ("cell", "model")): base = get_baseline_class(data, cell.get_eod_frequency(), trials=5) isis = np.array(base.get_interspike_intervals()) * 1000 axes[0].hist(isis, bins=bins, label=name, alpha=0.5, density=True) @@ -178,8 +188,18 @@ class ModelFit: axes[0].legend() # fi_curve - fi_curve_cell = get_fi_curve_class(cell, cell.get_fi_contrasts(), eod_freq=cell.get_eod_frequency(), trials=15) - fi_curve_model = get_fi_curve_class(model, cell.get_fi_contrasts(), eod_freq=cell.get_eod_frequency(), trials=15) + + fi_curve = get_fi_curve_class(cell, cell.get_fi_contrasts(), save_dir=cell.get_data_path()) + f_inf_slope = fi_curve.get_f_inf_slope() + contrasts = np.array(cell.get_fi_contrasts()) + if f_inf_slope < 0: + contrasts = contrasts * -1 + # print("old contrasts:", cell_data.get_fi_contrasts()) + # print("new contrasts:", contrasts) + contrasts = sorted(contrasts) + + fi_curve_cell = get_fi_curve_class(cell, contrasts, eod_freq=cell.get_eod_frequency(), trials=15) + fi_curve_model = get_fi_curve_class(model, contrasts, eod_freq=cell.get_eod_frequency(), trials=15) axes[1].set_title("Fi-Curve") min_x = min(min(fi_curve_cell.stimulus_values), min(fi_curve_model.stimulus_values)) @@ -210,7 +230,42 @@ class ModelFit: axes[1].set_title("cell model comparision") axes[1].set_xlabel("Stimulus value - contrast") axes[1].legend() - + # comparision of f_zero_curve: + + max_contrast = max(contrasts) + test_contrast = 0.5 * max_contrast + diff_contrasts = np.abs(contrasts - test_contrast) + f_zero_curve_contrast_idx = np.argmin(diff_contrasts) + + # model: + stimulus = SinusoidalStepStimulus(cell.get_eod_frequency(), contrasts[f_zero_curve_contrast_idx], + start_time=0, duration=cell.get_stimulus_duration()) + freq_traces = [] + time_traces = [] + for i in range(10): + v1, spikes = model.simulate_fast(stimulus, cell.get_time_end() - cell.get_time_start(), cell.get_time_start()) + time, freq = hF.calculate_time_and_frequency_trace(spikes, model.get_sampling_interval()) + freq_traces.append(freq) + time_traces.append(time) + + time, freq = hF.calculate_mean_of_frequency_traces(time_traces, freq_traces, model.get_sampling_interval()) + + cell_times, cell_freqs = fi_curve_cell.get_mean_time_and_freq_traces() + axes[2].plot(cell_times[f_zero_curve_contrast_idx], cell_freqs[f_zero_curve_contrast_idx]) + axes[2].plot(time, freq) + axes[2].set_title("blue: cell, orange: model") + axes[2].set_xlim(-0.15, 0.35) + + start_idx = -1 + end_idx = -1 + for idx in range(len(cell_times[f_zero_curve_contrast_idx])): + if cell_times[f_zero_curve_contrast_idx][idx] < -0.15: + start_idx = idx + elif cell_times[f_zero_curve_contrast_idx][idx] > 0.35: + end_idx = idx + break + axes[2].set_ylim(0.9*min(cell_freqs[f_zero_curve_contrast_idx][start_idx:end_idx]), + 1.1*max(cell_freqs[f_zero_curve_contrast_idx][start_idx:end_idx])) # Value table: cell_values, model_values = self.get_behaviour_values() @@ -220,10 +275,11 @@ class ModelFit: clust_data[0].append(cell_values[k]) clust_data[1].append(model_values[k]) - axes[2].axis('tight') - axes[2].axis('off') - table = axes[2].table(cellText=clust_data, colLabels=collabel, rowLabels=("cell", "model"), loc='center') + axes[3].axis('tight') + axes[3].axis('off') + table = axes[3].table(cellText=clust_data, colLabels=collabel, rowLabels=("cell", "model"), loc='center') fig.suptitle(cell.get_cell_name() + "_comp_err: {:.2f}".format(self.comparable_error())) + plt.tight_layout() if save_path is None: plt.show() else: