add option to recalculate saved values

This commit is contained in:
a.ott 2020-08-01 12:02:06 +02:00
parent c5b72bec26
commit b1be741042

View File

@ -14,7 +14,7 @@ from sys import stderr
class FICurve: class FICurve:
def __init__(self, stimulus_values, save_dir=None): def __init__(self, stimulus_values, save_dir=None, recalculate=False):
self.save_file_name = "fi_curve_values.pkl" self.save_file_name = "fi_curve_values.pkl"
self.stimulus_values = stimulus_values self.stimulus_values = stimulus_values
@ -29,6 +29,10 @@ class FICurve:
if save_dir is None: if save_dir is None:
self.initialize() self.initialize()
else:
if recalculate:
self.initialize()
self.save_values(save_dir)
else: else:
if not self.load_values(save_dir): if not self.load_values(save_dir):
self.initialize() self.initialize()
@ -253,7 +257,7 @@ class FICurve:
values = pickle.load(file) values = pickle.load(file)
if set(values["stimulus_values"]) != set(self.stimulus_values): if set(values["stimulus_values"]) != set(self.stimulus_values):
stderr.write("Fi-Curve:load_values() - Given stimulus values are different to the loaded ones!:\n " stderr.write("Fi-Curve:load_values() - Given stimulus values are different to the loaded ones!:\n "
"given: {}\n loaded: {}".format(str(self.stimulus_values), str(values["stimulus_values"]))) "given: {}\n loaded: {}\n".format(str(self.stimulus_values), str(values["stimulus_values"])))
self.stimulus_values = values["stimulus_values"] self.stimulus_values = values["stimulus_values"]
self.f_baseline_frequencies = values["f_baseline_frequencies"] self.f_baseline_frequencies = values["f_baseline_frequencies"]
@ -267,9 +271,9 @@ class FICurve:
class FICurveCellData(FICurve): class FICurveCellData(FICurve):
def __init__(self, cell_data: CellData, stimulus_values, save_dir=None): def __init__(self, cell_data: CellData, stimulus_values, save_dir=None, recalculate=False):
self.cell_data = cell_data self.cell_data = cell_data
super().__init__(stimulus_values, save_dir) super().__init__(stimulus_values, save_dir, recalculate)
def calculate_all_frequency_points(self): def calculate_all_frequency_points(self):
mean_frequencies = self.cell_data.get_mean_fi_curve_isi_frequencies() mean_frequencies = self.cell_data.get_mean_fi_curve_isi_frequencies()
@ -509,9 +513,9 @@ class FICurveModel(FICurve):
plt.close() plt.close()
def get_fi_curve_class(data, stimulus_values, eod_freq=None, trials=5, save_dir=None) -> FICurve: def get_fi_curve_class(data, stimulus_values, eod_freq=None, trials=5, save_dir=None, recalculate=False) -> FICurve:
if isinstance(data, CellData): if isinstance(data, CellData):
return FICurveCellData(data, stimulus_values, save_dir) return FICurveCellData(data, stimulus_values, save_dir, recalculate)
if isinstance(data, LifacNoiseModel): if isinstance(data, LifacNoiseModel):
if eod_freq is None: if eod_freq is None:
raise ValueError("The FiCurveModel needs the eod variable to work") raise ValueError("The FiCurveModel needs the eod variable to work")