add f_0_curve error image to master plot

This commit is contained in:
a.ott 2020-07-28 09:30:38 +02:00
parent 3803628749
commit 51d8f72c06

View File

@ -1,20 +1,25 @@
import os import os
from models.LIFACnoise import LifacNoiseModel from models.LIFACnoise import LifacNoiseModel
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
from Baseline import get_baseline_class from Baseline import get_baseline_class
from FiCurve import get_fi_curve_class from FiCurve import get_fi_curve_class
from CellData import CellData from CellData import CellData
import helperFunctions as hF
import numpy as np import numpy as np
import functions as fu import functions as fu
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
def get_best_fit(folder_path): def get_best_fit(folder_path, use_comparable_error=True):
min_err = np.inf min_err = np.inf
min_item = "" min_item = ""
for item in os.listdir(folder_path): for item in os.listdir(folder_path):
item_path = os.path.join(folder_path, item) item_path = os.path.join(folder_path, item)
err = ModelFit(item_path).comparable_error() if use_comparable_error:
err = ModelFit(item_path).comparable_error()
else:
err = ModelFit(item_path).get_fit_routine_error()
if err < min_err: if err < min_err:
min_err = err min_err = err
min_item = item min_item = item
@ -118,6 +123,11 @@ class ModelFit:
path = os.path.join(self.path, self.cell_f_zero_file) path = os.path.join(self.path, self.cell_f_zero_file)
return np.load(path) return np.load(path)
def get_fit_routine_error(self):
foldername = os.path.basename(self.path)
parts = foldername.split("_")
return float(parts[-1])
def comparable_error(self): def comparable_error(self):
cell_values, model_values = self.get_behaviour_values() cell_values, model_values = self.get_behaviour_values()
@ -165,12 +175,12 @@ class ModelFit:
model = self.get_model() model = self.get_model()
cell = self.get_cell_data() cell = self.get_cell_data()
fig, axes = plt.subplots(3, 1, figsize=(8, 10)) fig, axes = plt.subplots(4, 1, figsize=(8, 12))
# isi histogram: # isi histogram:
axes[0].set_title("ISI-Histogram") axes[0].set_title("ISI-Histogram")
axes[0].set_xlim((0, 50)) axes[0].set_xlim((0, 50))
bins = np.arange(0, 50, 0.1) bins = np.arange(0, 50, 0.1)
for data, name in zip((model, cell), ("model", "cell")): for data, name in zip((cell, model), ("cell", "model")):
base = get_baseline_class(data, cell.get_eod_frequency(), trials=5) base = get_baseline_class(data, cell.get_eod_frequency(), trials=5)
isis = np.array(base.get_interspike_intervals()) * 1000 isis = np.array(base.get_interspike_intervals()) * 1000
axes[0].hist(isis, bins=bins, label=name, alpha=0.5, density=True) axes[0].hist(isis, bins=bins, label=name, alpha=0.5, density=True)
@ -178,8 +188,18 @@ class ModelFit:
axes[0].legend() axes[0].legend()
# fi_curve # fi_curve
fi_curve_cell = get_fi_curve_class(cell, cell.get_fi_contrasts(), eod_freq=cell.get_eod_frequency(), trials=15)
fi_curve_model = get_fi_curve_class(model, cell.get_fi_contrasts(), eod_freq=cell.get_eod_frequency(), trials=15) fi_curve = get_fi_curve_class(cell, cell.get_fi_contrasts(), save_dir=cell.get_data_path())
f_inf_slope = fi_curve.get_f_inf_slope()
contrasts = np.array(cell.get_fi_contrasts())
if f_inf_slope < 0:
contrasts = contrasts * -1
# print("old contrasts:", cell_data.get_fi_contrasts())
# print("new contrasts:", contrasts)
contrasts = sorted(contrasts)
fi_curve_cell = get_fi_curve_class(cell, contrasts, eod_freq=cell.get_eod_frequency(), trials=15)
fi_curve_model = get_fi_curve_class(model, contrasts, eod_freq=cell.get_eod_frequency(), trials=15)
axes[1].set_title("Fi-Curve") axes[1].set_title("Fi-Curve")
min_x = min(min(fi_curve_cell.stimulus_values), min(fi_curve_model.stimulus_values)) min_x = min(min(fi_curve_cell.stimulus_values), min(fi_curve_model.stimulus_values))
@ -210,7 +230,42 @@ class ModelFit:
axes[1].set_title("cell model comparision") axes[1].set_title("cell model comparision")
axes[1].set_xlabel("Stimulus value - contrast") axes[1].set_xlabel("Stimulus value - contrast")
axes[1].legend() axes[1].legend()
# comparision of f_zero_curve:
max_contrast = max(contrasts)
test_contrast = 0.5 * max_contrast
diff_contrasts = np.abs(contrasts - test_contrast)
f_zero_curve_contrast_idx = np.argmin(diff_contrasts)
# model:
stimulus = SinusoidalStepStimulus(cell.get_eod_frequency(), contrasts[f_zero_curve_contrast_idx],
start_time=0, duration=cell.get_stimulus_duration())
freq_traces = []
time_traces = []
for i in range(10):
v1, spikes = model.simulate_fast(stimulus, cell.get_time_end() - cell.get_time_start(), cell.get_time_start())
time, freq = hF.calculate_time_and_frequency_trace(spikes, model.get_sampling_interval())
freq_traces.append(freq)
time_traces.append(time)
time, freq = hF.calculate_mean_of_frequency_traces(time_traces, freq_traces, model.get_sampling_interval())
cell_times, cell_freqs = fi_curve_cell.get_mean_time_and_freq_traces()
axes[2].plot(cell_times[f_zero_curve_contrast_idx], cell_freqs[f_zero_curve_contrast_idx])
axes[2].plot(time, freq)
axes[2].set_title("blue: cell, orange: model")
axes[2].set_xlim(-0.15, 0.35)
start_idx = -1
end_idx = -1
for idx in range(len(cell_times[f_zero_curve_contrast_idx])):
if cell_times[f_zero_curve_contrast_idx][idx] < -0.15:
start_idx = idx
elif cell_times[f_zero_curve_contrast_idx][idx] > 0.35:
end_idx = idx
break
axes[2].set_ylim(0.9*min(cell_freqs[f_zero_curve_contrast_idx][start_idx:end_idx]),
1.1*max(cell_freqs[f_zero_curve_contrast_idx][start_idx:end_idx]))
# Value table: # Value table:
cell_values, model_values = self.get_behaviour_values() cell_values, model_values = self.get_behaviour_values()
@ -220,10 +275,11 @@ class ModelFit:
clust_data[0].append(cell_values[k]) clust_data[0].append(cell_values[k])
clust_data[1].append(model_values[k]) clust_data[1].append(model_values[k])
axes[2].axis('tight') axes[3].axis('tight')
axes[2].axis('off') axes[3].axis('off')
table = axes[2].table(cellText=clust_data, colLabels=collabel, rowLabels=("cell", "model"), loc='center') table = axes[3].table(cellText=clust_data, colLabels=collabel, rowLabels=("cell", "model"), loc='center')
fig.suptitle(cell.get_cell_name() + "_comp_err: {:.2f}".format(self.comparable_error())) fig.suptitle(cell.get_cell_name() + "_comp_err: {:.2f}".format(self.comparable_error()))
plt.tight_layout()
if save_path is None: if save_path is None:
plt.show() plt.show()
else: else: