add refratory period to fitting, add burstiness error

This commit is contained in:
alexanderott 2020-05-27 13:36:14 +02:00
parent 44c9024f1f
commit 2a55078894
6 changed files with 125 additions and 292 deletions

View File

@ -27,6 +27,53 @@ class Baseline:
def get_coefficient_of_variation(self): def get_coefficient_of_variation(self):
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS") raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
def get_burstiness(self):
isis = np.array(self.get_interspike_intervals()) * 1000 # change unit to ms
if len(isis) <= 10:
return 0
step = 0.1
bins = np.arange(0, min(isis) * 3, step)
num_spikes_per_bin = np.zeros(bins.shape)
for i, bin in enumerate(bins):
num_of_spikes = np.sum(isis[(isis >= bin) & (isis < bin + step)])
num_spikes_per_bin[i] = num_of_spikes
max_found = -1
end_of_peak = -1
if max(num_spikes_per_bin) < 10:
return 0
for i, num in enumerate(num_spikes_per_bin):
if i + 1 >= len(num_spikes_per_bin):
return 0
if max_found == -1:
if num_spikes_per_bin[i+1] > num:
continue
elif num > 10:
max_found = i
else:
if num_spikes_per_bin[i + 1] > num:
end_of_peak = i +1
break
burstiness = sum(num_spikes_per_bin[:end_of_peak]) / len(isis)
# bins = np.arange(0, max(isis) * 1.01, 0.1)
#
# plt.title('Baseline ISIs - burstiness {:.2f}'.format(burstiness))
# plt.xlabel('ISI in ms')
# plt.ylabel('Count')
# plt.hist(isis, bins=bins)
# plt.plot((0.5*step, bins[end_of_peak-1] + 0.5*step,), (0, 0), 'o')
# plt.show()
return burstiness
def get_interspike_intervals(self): def get_interspike_intervals(self):
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS") raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
@ -298,12 +345,12 @@ class BaselineModel(Baseline):
save_path, position, time_length) save_path, position, time_length)
def get_baseline_class(data, eod_freq=None) -> Baseline: def get_baseline_class(data, eod_freq=None, trials=1) -> Baseline:
if isinstance(data, CellData): if isinstance(data, CellData):
return BaselineCellData(data) return BaselineCellData(data)
if isinstance(data, LifacNoiseModel): if isinstance(data, LifacNoiseModel):
if eod_freq is None: if eod_freq is None:
raise ValueError("The EOD frequency is needed for the BaselineModel Class.") raise ValueError("The EOD frequency is needed for the BaselineModel Class.")
return BaselineModel(data, eod_freq) return BaselineModel(data, eod_freq, trials=trials)
raise ValueError("Unknown type: Cannot find corresponding Baseline class. data was type:" + str(type(data))) raise ValueError("Unknown type: Cannot find corresponding Baseline class. data was type:" + str(type(data)))

View File

@ -10,7 +10,14 @@ def icelldata_of_dir(base_path):
item_path = base_path + item item_path = base_path + item
try: try:
yield CellData(item_path) data = CellData(item_path)
trace = data.get_base_traces(trace_type=data.V1)
if len(trace) == 0:
print("NO V1 TRACE FOUND: ", item_path)
continue
else:
yield data
except TypeError as e: except TypeError as e:
warn_msg = str(e) warn_msg = str(e)
warn(warn_msg) warn(warn_msg)

View File

@ -456,12 +456,12 @@ class FICurveModel(FICurve):
plt.close() plt.close()
def get_fi_curve_class(data, stimulus_values, eod_freq=None) -> FICurve: def get_fi_curve_class(data, stimulus_values, eod_freq=None, trials=5) -> FICurve:
if isinstance(data, CellData): if isinstance(data, CellData):
return FICurveCellData(data, stimulus_values) return FICurveCellData(data, stimulus_values)
if isinstance(data, LifacNoiseModel): if isinstance(data, LifacNoiseModel):
if eod_freq is None: if eod_freq is None:
raise ValueError("The FiCurveModel needs the eod variable to work") raise ValueError("The FiCurveModel needs the eod variable to work")
return FICurveModel(data, stimulus_values, eod_freq) return FICurveModel(data, stimulus_values, eod_freq, trials=trials)
raise ValueError("Unknown type: Cannot find corresponding Baseline class. Data was type:" + str(type(data))) raise ValueError("Unknown type: Cannot find corresponding Baseline class. Data was type:" + str(type(data)))

274
Fitter.py
View File

@ -1,208 +1,13 @@
from models.LIFACnoise import LifacNoiseModel from models.LIFACnoise import LifacNoiseModel
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
from CellData import CellData, icelldata_of_dir from CellData import CellData
from Baseline import get_baseline_class from Baseline import get_baseline_class
from FiCurve import get_fi_curve_class from FiCurve import get_fi_curve_class
from AdaptionCurrent import Adaption from AdaptionCurrent import Adaption
import numpy as np import numpy as np
from warnings import warn from warnings import warn
from scipy.optimize import minimize from scipy.optimize import minimize
import time
import os
SAVE_PATH_PREFIX = ""
FIT_ROUTINE = ""
def main():
# fitter = Fitter()
# run_with_real_data(fitter, fitter.fit_routine_3)
test_fit_routines()
def test_fit_routines():
fitter = Fitter()
names = ("routine_1", "routine_2", "routine_3")
global FIT_ROUTINE
for i, routine in enumerate([fitter.fit_routine_1, fitter.fit_routine_2, fitter.fit_routine_3]):
FIT_ROUTINE = names[i]
run_with_real_data(fitter, routine)
best = []
cells = sorted(os.listdir("test_routines/" + names[0] + "/"))
for name in names:
save_path = "test_routines/" + name + "/"
cell_best = []
for directory in sorted(os.listdir(save_path)):
path = os.path.join(save_path, directory)
if os.path.isdir(path):
cell_best.append(find_best_run(path))
best.append(cell_best)
with open("test_routines/comparision.csv", "w") as res_file:
res_file.write("routine")
for cell in cells:
res_file.write("," + cell)
for i, routine_results in enumerate(best):
res_file.write(names[i])
for cell_best in routine_results:
res_file.write("," + str(cell_best))
def find_best_run(cell_path):
values = []
for directory in sorted(os.listdir(cell_path)):
start_par_path = os.path.join(cell_path, directory)
if os.path.isdir(start_par_path):
values.append(float(start_par_path.split("_")[-1]))
return min(values)
def iget_start_parameters():
# mem_tau, input_scaling, noise_strength, dend_tau,
# expand by tau_a, delta_a ?
mem_tau_list = [0.01]
input_scaling_list = [40, 60]
noise_strength_list = [0.03] # [0.02, 0.06]
dend_tau_list = [0.001, 0.002]
delta_a_list = [0.035, 0.065]
for mem_tau in mem_tau_list:
for input_scaling in input_scaling_list:
for noise_strength in noise_strength_list:
for dend_tau in dend_tau_list:
for delta_a in delta_a_list:
yield {"mem_tau": mem_tau, "input_scaling": input_scaling,
"noise_strength": noise_strength, "dend_tau": dend_tau,
"delta_a": delta_a}
def run_with_real_data(fitter, fit_routine_func, parallel=False):
count = 0
for cell_data in icelldata_of_dir("./data/"):
count += 1
if count < 7:
pass
#continue
print("cell:", cell_data.get_data_path())
trace = cell_data.get_base_traces(trace_type=cell_data.V1)
if len(trace) == 0:
print("NO V1 TRACE FOUND")
continue
global FIT_ROUTINE
# results_path = "results/" + os.path.split(cell_data.get_data_path())[-1] + "/"
results_path = "test_routines/" + FIT_ROUTINE + "/" + os.path.split(cell_data.get_data_path())[-1] + "/"
print("results at:", results_path)
if not os.path.exists(results_path):
os.makedirs(results_path)
# plot cell images:
cell_save_path = results_path + "cell/"
if not os.path.exists(cell_save_path):
os.makedirs(cell_save_path)
data_baseline = get_baseline_class(cell_data)
data_baseline.plot_baseline(cell_save_path)
data_baseline.plot_interspike_interval_histogram(cell_save_path)
data_baseline.plot_serial_correlation(6, cell_save_path)
data_fi_curve = get_fi_curve_class(cell_data, cell_data.get_fi_contrasts())
data_fi_curve.plot_fi_curve(cell_save_path)
start_par_count = 0
for start_parameters in iget_start_parameters():
start_par_count += 1
print("START PARAMETERS:", start_par_count)
start_time = time.time()
# fitter = Fitter()
fmin, parameters = fitter.fit_model_to_data(cell_data, start_parameters, fit_routine_func)
print(fmin)
print(parameters)
end_time = time.time()
parameter_set_path = results_path + "start_par_set_{}_fmin_{:.2f}".format(start_par_count, fmin["fun"]) + "/"
if not os.path.exists(parameter_set_path):
os.makedirs(parameter_set_path)
with open(parameter_set_path + "parameters_info.txt".format(start_par_count), "w") as file:
file.writelines(["start_parameters:\t" + str(start_parameters),
"\nfinal_parameters:\t" + str(parameters),
"\nfinal_fmin:\t" + str(fmin)])
print('Fitting of cell took function took {:.3f} s'.format((end_time - start_time)))
# print(results_path)
print_comparision_cell_model(cell_data, data_baseline, data_fi_curve, parameters,
plot=True, save_path=parameter_set_path)
# from Sounds import play_finished_sound
# play_finished_sound()
pass
def print_comparision_cell_model(cell_data, data_baseline, data_fi_curve, parameters, plot=False, save_path=None):
model = LifacNoiseModel(parameters)
eod_frequency = cell_data.get_eod_frequency()
model_baseline = get_baseline_class(model, eod_frequency)
m_bf = model_baseline.get_baseline_frequency()
m_vs = model_baseline.get_vector_strength()
m_sc = model_baseline.get_serial_correlation(1)
m_cv = model_baseline.get_coefficient_of_variation()
model_ficurve = get_fi_curve_class(model, cell_data.get_fi_contrasts(), eod_frequency)
m_f_infinities = model_ficurve.get_f_inf_frequencies()
m_f_zeros = model_ficurve.get_f_zero_frequencies()
m_f_infinities_slope = model_ficurve.get_f_inf_slope()
m_f_zero_slope = model_ficurve.get_f_zero_fit_slope_at_straight()
c_bf = data_baseline.get_baseline_frequency()
c_vs = data_baseline.get_vector_strength()
c_sc = data_baseline.get_serial_correlation(1)
c_cv = data_baseline.get_coefficient_of_variation()
c_f_inf_slope = data_fi_curve.get_f_inf_slope()
c_f_inf_values = data_fi_curve.f_inf_frequencies
c_f_zero_slope = data_fi_curve.get_f_zero_fit_slope_at_straight()
c_f_zero_values = data_fi_curve.f_zero_frequencies
print("EOD-frequency: {:.2f}".format(cell_data.get_eod_frequency()))
print("bf: cell - {:.2f} vs model {:.2f}".format(c_bf, m_bf))
print("vs: cell - {:.2f} vs model {:.2f}".format(c_vs, m_vs))
print("sc: cell - {:.2f} vs model {:.2f}".format(c_sc[0], m_sc[0]))
print("cv: cell - {:.2f} vs model {:.2f}".format(c_cv, m_cv))
print("f_inf_slope: cell - {:.2f} vs model {:.2f}".format(c_f_inf_slope, m_f_infinities_slope))
print("f infinity values:\n cell -", c_f_inf_values, "\n model -", m_f_infinities)
print("f_zero_slope: cell - {:.2f} vs model {:.2f}".format(c_f_zero_slope, m_f_zero_slope))
print("f zero values:\n cell -", c_f_zero_values, "\n model -", m_f_zeros)
if save_path is not None:
with open(save_path + "value_comparision.tsv", 'w') as value_file:
value_file.write("Variable\tCell\tModel\n")
value_file.write("baseline_frequency\t{:.2f}\t{:.2f}\n".format(c_bf, m_bf))
value_file.write("vector_strength\t{:.2f}\t{:.2f}\n".format(c_vs, m_vs))
value_file.write("serial_correlation\t{:.2f}\t{:.2f}\n".format(c_sc[0], m_sc[0]))
value_file.write("coefficient_of_variation\t{:.2f}\t{:.2f}\n".format(c_cv, m_cv))
value_file.write("f_inf_slope\t{:.2f}\t{:.2f}\n".format(c_f_inf_slope, m_f_infinities_slope))
value_file.write("f_zero_slope\t{:.2f}\t{:.2f}\n".format(c_f_zero_slope, m_f_zero_slope))
if plot:
# plot model images
model_baseline.plot_baseline(save_path)
model_baseline.plot_interspike_interval_histogram(save_path)
model_baseline.plot_serial_correlation(6, save_path)
model_ficurve.plot_fi_curve(save_path)
model_ficurve.plot_fi_curve_comparision(data_fi_curve, model_ficurve, save_path)
class Fitter: class Fitter:
@ -219,13 +24,14 @@ class Fitter:
self.fi_contrasts = [] self.fi_contrasts = []
self.eod_freq = 0 self.eod_freq = 0
self.sc_max_lag = 1 self.sc_max_lag = 2
# values to be replicated: # values to be replicated:
self.baseline_freq = 0 self.baseline_freq = 0
self.vector_strength = -1 self.vector_strength = -1
self.serial_correlation = [] self.serial_correlation = []
self.coefficient_of_variation = 0 self.coefficient_of_variation = 0
self.burstiness = -1
self.f_inf_values = [] self.f_inf_values = []
self.f_inf_slope = 0 self.f_inf_slope = 0
@ -249,6 +55,7 @@ class Fitter:
self.vector_strength = data_baseline.get_vector_strength() self.vector_strength = data_baseline.get_vector_strength()
self.serial_correlation = data_baseline.get_serial_correlation(self.sc_max_lag) self.serial_correlation = data_baseline.get_serial_correlation(self.sc_max_lag)
self.coefficient_of_variation = data_baseline.get_coefficient_of_variation() self.coefficient_of_variation = data_baseline.get_coefficient_of_variation()
self.burstiness = data_baseline.get_burstiness()
fi_curve = get_fi_curve_class(cell_data, cell_data.get_fi_contrasts()) fi_curve = get_fi_curve_class(cell_data, cell_data.get_fi_contrasts())
self.fi_contrasts = fi_curve.stimulus_values self.fi_contrasts = fi_curve.stimulus_values
@ -277,54 +84,27 @@ class Fitter:
x0 = np.array([start_parameters["mem_tau"], start_parameters["noise_strength"], x0 = np.array([start_parameters["mem_tau"], start_parameters["noise_strength"],
start_parameters["input_scaling"], self.tau_a, start_parameters["delta_a"], start_parameters["input_scaling"], self.tau_a, start_parameters["delta_a"],
start_parameters["dend_tau"]]) start_parameters["dend_tau"], start_parameters["refractory_period"]])
initial_simplex = create_init_simples(x0, search_scale=2) initial_simplex = create_init_simples(x0, search_scale=2)
# error_list = [error_bf, error_vs, error_sc, error_cv, # error_list = [error_bf, error_vs, error_sc, error_cv,
# error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope] # error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope]
error_weights = (0, 1, 1, 1, 1, 1, 1, 1) error_weights = (0, 1, 1, 1, 1, 1, 1, 1, 1)
fmin = minimize(fun=self.cost_function_all, fmin = minimize(fun=self.cost_function_all,
args=(error_weights,), x0=x0, method="Nelder-Mead", args=(error_weights,), x0=x0, method="Nelder-Mead",
options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 200, "maxiter": 400}) options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 200, "maxiter": 400})
return fmin, self.base_model.get_parameters() return fmin, self.base_model.get_parameters()
# similar results to fit routine 1
def fit_routine_2(self, start_parameters): def fit_routine_2(self, start_parameters):
self.counter = 0 self.counter = 0
# fit only v_offset, mem_tau, input_scaling, dend_tau
x0 = np.array([start_parameters["mem_tau"], start_parameters["noise_strength"],
start_parameters["input_scaling"], self.tau_a, start_parameters["delta_a"],
start_parameters["dend_tau"]])
initial_simplex = create_init_simples(x0, search_scale=2)
# error_list = [error_bf, error_vs, error_sc, error_cv,
# error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope]
error_weights = (0, 2, 2, 2, 1, 1, 1, 1)
fmin = minimize(fun=self.cost_function_all,
args=(error_weights,), x0=x0, method="Nelder-Mead",
options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 100, "maxiter": 400})
best_pars = fmin.x
x0 = np.array([best_pars[0], best_pars[2], # mem_tau, input_scaling
best_pars[4], best_pars[5]]) # delta_a, dend_tau
initial_simplex = create_init_simples(x0, search_scale=2)
error_weights = (0, 1, 1, 1, 3, 2, 3, 2)
fmin = minimize(fun=self.cost_function_only_adaption,
args=(error_weights,), x0=x0, method="Nelder-Mead",
options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 100, "maxiter": 400})
return fmin, self.base_model.get_parameters()
def fit_routine_3(self, start_parameters):
self.counter = 0
x0 = np.array([start_parameters["mem_tau"], start_parameters["input_scaling"], # mem_tau, input_scaling x0 = np.array([start_parameters["mem_tau"], start_parameters["input_scaling"], # mem_tau, input_scaling
start_parameters["delta_a"], start_parameters["dend_tau"]]) # delta_a, dend_tau start_parameters["delta_a"], start_parameters["dend_tau"]]) # delta_a, dend_tau
initial_simplex = create_init_simples(x0, search_scale=2) initial_simplex = create_init_simples(x0, search_scale=2)
error_weights = (0, 1, 1, 1, 3, 2, 3, 2) error_weights = (0, 1, 1, 1, 1, 3, 2, 3, 2)
fmin = minimize(fun=self.cost_function_only_adaption, fmin = minimize(fun=self.cost_function_only_adaption,
args=(error_weights,), x0=x0, method="Nelder-Mead", args=(error_weights,), x0=x0, method="Nelder-Mead",
options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 100, "maxiter": 400}) options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 100, "maxiter": 400})
@ -335,7 +115,7 @@ class Fitter:
initial_simplex = create_init_simples(x0, search_scale=2) initial_simplex = create_init_simples(x0, search_scale=2)
# error_list = [error_bf, error_vs, error_sc, error_cv, # error_list = [error_bf, error_vs, error_sc, error_cv,
# error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope] # error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope]
error_weights = (0, 2, 2, 2, 1, 1, 1, 1) error_weights = (0, 2, 2, 2, 2, 1, 1, 1, 1)
fmin = minimize(fun=self.cost_function_all, fmin = minimize(fun=self.cost_function_all,
args=(error_weights,), x0=x0, method="Nelder-Mead", args=(error_weights,), x0=x0, method="Nelder-Mead",
options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 100, "maxiter": 400}) options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 100, "maxiter": 400})
@ -349,6 +129,7 @@ class Fitter:
self.base_model.set_variable("tau_a", X[3]) self.base_model.set_variable("tau_a", X[3])
self.base_model.set_variable("delta_a", X[4]) self.base_model.set_variable("delta_a", X[4])
self.base_model.set_variable("dend_tau", X[5]) self.base_model.set_variable("dend_tau", X[5])
self.base_model.set_variable("refractory_period", X[6])
base_stimulus = SinusoidalStepStimulus(self.eod_freq, 0) base_stimulus = SinusoidalStepStimulus(self.eod_freq, 0)
# find right v-offset # find right v-offset
@ -471,6 +252,8 @@ class Fitter:
vector_strength = model_baseline.get_vector_strength() vector_strength = model_baseline.get_vector_strength()
serial_correlation = model_baseline.get_serial_correlation(self.sc_max_lag) serial_correlation = model_baseline.get_serial_correlation(self.sc_max_lag)
coefficient_of_variation = model_baseline.get_coefficient_of_variation() coefficient_of_variation = model_baseline.get_coefficient_of_variation()
burstiness = model_baseline.get_burstiness()
fi_curve_model = get_fi_curve_class(model, self.fi_contrasts, self.eod_freq) fi_curve_model = get_fi_curve_class(model, self.fi_contrasts, self.eod_freq)
f_zeros = fi_curve_model.get_f_zero_frequencies() f_zeros = fi_curve_model.get_f_zero_frequencies()
@ -483,11 +266,14 @@ class Fitter:
error_bf = abs((baseline_freq - self.baseline_freq) / self.baseline_freq) error_bf = abs((baseline_freq - self.baseline_freq) / self.baseline_freq)
error_vs = abs((vector_strength - self.vector_strength) / 0.1) error_vs = abs((vector_strength - self.vector_strength) / 0.1)
error_cv = abs((coefficient_of_variation - self.coefficient_of_variation) / 0.1) error_cv = abs((coefficient_of_variation - self.coefficient_of_variation) / 0.1)
error_bursty = (abs(burstiness - self.burstiness) / 0.02)
error_sc = 0 error_sc = 0
for i in range(self.sc_max_lag): for i in range(self.sc_max_lag):
error_sc = abs((serial_correlation[i] - self.serial_correlation[i]) / 0.1) error_sc += abs((serial_correlation[i] - self.serial_correlation[i]) / 0.1)
error_sc = error_sc / self.sc_max_lag # error_sc = error_sc / self.sc_max_lag
error_f_inf_slope = abs((f_infinities_slope - self.f_inf_slope) / (self.f_inf_slope/20)) error_f_inf_slope = abs((f_infinities_slope - self.f_inf_slope) / (self.f_inf_slope/20))
error_f_inf = calculate_list_error(f_infinities, self.f_inf_values) error_f_inf = calculate_list_error(f_infinities, self.f_inf_values)
@ -497,7 +283,7 @@ class Fitter:
/ (self.f_zero_slope_at_straight / 10) / (self.f_zero_slope_at_straight / 10)
error_f_zero = calculate_list_error(f_zeros, self.f_zero_values) error_f_zero = calculate_list_error(f_zeros, self.f_zero_values)
error_list = [error_bf, error_vs, error_sc, error_cv, error_list = [error_bf, error_vs, error_sc, error_cv, error_bursty,
error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope_at_straight] error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope_at_straight]
if error_weights is not None and len(error_weights) == len(error_list): if error_weights is not None and len(error_weights) == len(error_list):
@ -506,28 +292,6 @@ class Fitter:
elif error_weights is not None: elif error_weights is not None:
warn("Error: weights had different length than errors and were ignored!") warn("Error: weights had different length than errors and were ignored!")
# error = sum(error_list)
# self.counter += 1
# if self.counter % 200 == 0: # and False:
# print("\nCost function run times: {:}\n".format(self.counter),
# "Total weighted error: {:.4f}\n".format(error),
# "Baseline frequency - expected: {:.0f}, current: {:.0f}, error: {:.3f}\n".format(
# self.baseline_freq, baseline_freq, error_bf),
# "Vector strength - expected: {:.2f}, current: {:.2f}, error: {:.3f}\n".format(
# self.vector_strength, vector_strength, error_vs),
# "Serial correlation - expected: {:.2f}, current: {:.2f}, error: {:.3f}\n".format(
# self.serial_correlation[0], serial_correlation[0], error_sc),
# "Coefficient of variation - expected: {:.2f}, current: {:.2f}, error: {:.3f}\n".format(
# self.coefficient_of_variation, coefficient_of_variation, error_cv),
# "f-infinity slope - expected: {:.0f}, current: {:.0f}, error: {:.3f}\n".format(
# self.f_inf_slope, f_infinities_slope, error_f_inf_slope),
# "f-infinity values:\nexpected:", np.around(self.f_inf_values), "\ncurrent: ", np.around(f_infinities),
# "\nerror: {:.3f}\n".format(error_f_inf),
# "f-zero slope - expected: {:.0f}, current: {:.0f}, error: {:.3f}\n".format(
# self.f_zero_slope_at_straight, f_zero_slope_at_straight, error_f_zero_slope_at_straight),
# "f-zero values:\nexpected:", np.around(self.f_zero_values), "\ncurrent: ", np.around(f_zeros),
# "\nerror: {:.3f}".format(error_f_zero))
return error_list return error_list
@ -559,4 +323,4 @@ def create_init_simples(x0, search_scale=3.):
if __name__ == '__main__': if __name__ == '__main__':
main() print("use run_fitter.py to run the Fitter.")

View File

@ -1,6 +1,6 @@
from models.LIFACnoise import LifacNoiseModel from models.LIFACnoise import LifacNoiseModel
from CellData import CellData, icelldata_of_dir from CellData import icelldata_of_dir
from Baseline import get_baseline_class from Baseline import get_baseline_class
from FiCurve import get_fi_curve_class from FiCurve import get_fi_curve_class
from Fitter import Fitter from Fitter import Fitter
@ -11,24 +11,25 @@ import os
import multiprocessing as mp import multiprocessing as mp
SAVE_PATH_PREFIX = ""
FIT_ROUTINE = ""
def main(): def main():
count = 0 count = 0
for data in icelldata_of_dir("./data/"): for data in icelldata_of_dir("./data/"):
count += 1 count += 1
if count <= 3: # if count <= 3:
continue # continue
trace = data.get_base_traces(trace_type=data.V1)
if len(trace) == 0:
print("NO V1 TRACE FOUND")
continue
fit_cell_parrallel(data, [p for p in iget_start_parameters()]) fit_cell_parrallel(data, [p for p in iget_start_parameters()])
break
def fit_cell_parrallel(cell_data, start_parameters): def fit_cell_parrallel(cell_data, start_parameters):
cell_path = os.path.basename(cell_data.get_data_path()) cell_path = os.path.basename(cell_data.get_data_path())
print(cell_path) print(cell_path)
core_count = mp.cpu_count() core_count = mp.cpu_count()
pool = mp.Pool(core_count - 3) pool = mp.Pool(core_count - 1)
fitter = Fitter() fitter = Fitter()
fitter.set_data_reference_values(cell_data) fitter.set_data_reference_values(cell_data)
@ -38,14 +39,16 @@ def fit_cell_parrallel(cell_data, start_parameters):
print("Time taken for all start parameters ({:}): {:.2f}s".format(len(start_parameters), time2-time1)) print("Time taken for all start parameters ({:}): {:.2f}s".format(len(start_parameters), time2-time1))
for i, (fmin, fin_pars) in enumerate(outputs): for i, (fmin, fin_pars) in enumerate(outputs):
error = fitter.calculate_errors(model=LifacNoiseModel(fin_pars)) error = fitter.calculate_errors(model=LifacNoiseModel(fin_pars))
print_comparision_cell_model(cell_data, fin_pars, plot=True, save_path="./test_routines/" + cell_path + "/start_parameter_{:}_err_{:.2f}/".format(i+1, sum(error))) save_path = "./test_routines/" + cell_path + "/start_parameter_{:}_err_{:.2f}/".format(i+1, sum(error))
save_fitting_run_info(cell_data, fin_pars, start_parameters[i],
plot=True, save_path=save_path)
def test_fit_routines(): def test_fit_routines():
fitter = Fitter() fitter = Fitter()
names = ("routine_1", "routine_2", "routine_3") names = ("routine_1", "routine_2")
global FIT_ROUTINE global FIT_ROUTINE
for i, routine in enumerate([fitter.fit_routine_1, fitter.fit_routine_2, fitter.fit_routine_3]): for i, routine in enumerate([fitter.fit_routine_1, fitter.fit_routine_2]):
FIT_ROUTINE = names[i] FIT_ROUTINE = names[i]
run_with_real_data(fitter, routine) run_with_real_data(fitter, routine)
@ -91,15 +94,17 @@ def iget_start_parameters():
noise_strength_list = [0.03] # [0.02, 0.06] noise_strength_list = [0.03] # [0.02, 0.06]
dend_tau_list = [0.001, 0.002] dend_tau_list = [0.001, 0.002]
delta_a_list = [0.035, 0.065] delta_a_list = [0.035, 0.065]
ref_time_list = [0.00065]
for mem_tau in mem_tau_list: for mem_tau in mem_tau_list:
for input_scaling in input_scaling_list: for input_scaling in input_scaling_list:
for noise_strength in noise_strength_list: for noise_strength in noise_strength_list:
for dend_tau in dend_tau_list: for dend_tau in dend_tau_list:
for delta_a in delta_a_list: for delta_a in delta_a_list:
for ref_time in ref_time_list:
yield {"mem_tau": mem_tau, "input_scaling": input_scaling, yield {"mem_tau": mem_tau, "input_scaling": input_scaling,
"noise_strength": noise_strength, "dend_tau": dend_tau, "noise_strength": noise_strength, "dend_tau": dend_tau,
"delta_a": delta_a} "delta_a": delta_a, "refractory_period": ref_time}
def run_with_real_data(fitter, fit_routine_func, parallel=False): def run_with_real_data(fitter, fit_routine_func, parallel=False):
@ -149,16 +154,10 @@ def run_with_real_data(fitter, fit_routine_func, parallel=False):
print(parameters) print(parameters)
end_time = time.time() end_time = time.time()
parameter_set_path = results_path + "start_par_set_{}_fmin_{:.2f}".format(start_par_count, fmin["fun"]) + "/" parameter_set_path = results_path + "start_par_set_{}_fmin_{:.2f}".format(start_par_count, fmin["fun"]) + "/"
if not os.path.exists(parameter_set_path):
os.makedirs(parameter_set_path)
with open(parameter_set_path + "parameters_info.txt".format(start_par_count), "w") as file:
file.writelines(["start_parameters:\t" + str(start_parameters),
"\nfinal_parameters:\t" + str(parameters),
"\nfinal_fmin:\t" + str(fmin)])
print('Fitting of cell took function took {:.3f} s'.format((end_time - start_time))) print('Fitting of cell took function took {:.3f} s'.format((end_time - start_time)))
# print(results_path) # print(results_path)
print_comparision_cell_model(cell_data, parameters, save_fitting_run_info(cell_data, parameters, start_parameters,
plot=True, save_path=parameter_set_path) plot=True, save_path=parameter_set_path)
# from Sounds import play_finished_sound # from Sounds import play_finished_sound
@ -166,10 +165,15 @@ def run_with_real_data(fitter, fit_routine_func, parallel=False):
pass pass
def print_comparision_cell_model(cell_data, parameters, plot=False, save_path=None): def save_fitting_run_info(cell_data, parameters, start_parameters, plot=False, save_path=None):
if save_path is not None: if save_path is not None:
if not os.path.exists(save_path): if not os.path.exists(save_path):
os.makedirs(save_path) os.makedirs(save_path)
with open(save_path + "parameters_info.txt", "w") as file:
file.writelines(["start_parameters:\t" + str(start_parameters),
"\nfinal_parameters:\t" + str(parameters)])
model = LifacNoiseModel(parameters) model = LifacNoiseModel(parameters)
eod_frequency = cell_data.get_eod_frequency() eod_frequency = cell_data.get_eod_frequency()
@ -178,6 +182,7 @@ def print_comparision_cell_model(cell_data, parameters, plot=False, save_path=No
m_vs = model_baseline.get_vector_strength() m_vs = model_baseline.get_vector_strength()
m_sc = model_baseline.get_serial_correlation(1) m_sc = model_baseline.get_serial_correlation(1)
m_cv = model_baseline.get_coefficient_of_variation() m_cv = model_baseline.get_coefficient_of_variation()
m_burst = model_baseline.get_burstiness()
model_ficurve = get_fi_curve_class(model, cell_data.get_fi_contrasts(), eod_frequency) model_ficurve = get_fi_curve_class(model, cell_data.get_fi_contrasts(), eod_frequency)
m_f_infinities = model_ficurve.get_f_inf_frequencies() m_f_infinities = model_ficurve.get_f_inf_frequencies()
@ -190,6 +195,7 @@ def print_comparision_cell_model(cell_data, parameters, plot=False, save_path=No
c_vs = data_baseline.get_vector_strength() c_vs = data_baseline.get_vector_strength()
c_sc = data_baseline.get_serial_correlation(1) c_sc = data_baseline.get_serial_correlation(1)
c_cv = data_baseline.get_coefficient_of_variation() c_cv = data_baseline.get_coefficient_of_variation()
c_burst = data_baseline.get_burstiness()
data_fi_curve = get_fi_curve_class(cell_data, cell_data.get_fi_contrasts()) data_fi_curve = get_fi_curve_class(cell_data, cell_data.get_fi_contrasts())
c_f_inf_slope = data_fi_curve.get_f_inf_slope() c_f_inf_slope = data_fi_curve.get_f_inf_slope()
@ -213,6 +219,7 @@ def print_comparision_cell_model(cell_data, parameters, plot=False, save_path=No
value_file.write("baseline_frequency\t{:.2f}\t{:.2f}\n".format(c_bf, m_bf)) value_file.write("baseline_frequency\t{:.2f}\t{:.2f}\n".format(c_bf, m_bf))
value_file.write("vector_strength\t{:.2f}\t{:.2f}\n".format(c_vs, m_vs)) value_file.write("vector_strength\t{:.2f}\t{:.2f}\n".format(c_vs, m_vs))
value_file.write("serial_correlation\t{:.2f}\t{:.2f}\n".format(c_sc[0], m_sc[0])) value_file.write("serial_correlation\t{:.2f}\t{:.2f}\n".format(c_sc[0], m_sc[0]))
value_file.write("Burstiness\t{:.2f}\t{:.2f}\n".format(c_burst, m_burst))
value_file.write("coefficient_of_variation\t{:.2f}\t{:.2f}\n".format(c_cv, m_cv)) value_file.write("coefficient_of_variation\t{:.2f}\t{:.2f}\n".format(c_cv, m_cv))
value_file.write("f_inf_slope\t{:.2f}\t{:.2f}\n".format(c_f_inf_slope, m_f_infinities_slope)) value_file.write("f_inf_slope\t{:.2f}\t{:.2f}\n".format(c_f_inf_slope, m_f_infinities_slope))
value_file.write("f_zero_slope\t{:.2f}\t{:.2f}\n".format(c_f_zero_slope, m_f_zero_slope)) value_file.write("f_zero_slope\t{:.2f}\t{:.2f}\n".format(c_f_zero_slope, m_f_zero_slope))

View File

@ -3,17 +3,25 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
from DataParserFactory import get_parser from DataParserFactory import get_parser
import pprint import pprint
from Baseline import get_baseline_class
from FiCurve import get_fi_curve_class
from CellData import icelldata_of_dir
from models.LIFACnoise import LifacNoiseModel
parameter_bursty_model = {'step_size': 5e-05, 'mem_tau': 0.0066693150193490695, 'v_base': 0, 'v_zero': 0,
'threshold': 1, 'v_offset': -45.703125, 'input_scaling': 172.13861987237314,
'delta_a': 0.06148215166012024, 'tau_a': 0.03391674075000068, 'a_zero': 2,
'noise_strength': 0.0684136549210377, 'dend_tau': 0.0013694103932013805,
'refractory_period': 0.001}
eod = 752
model = LifacNoiseModel(parameter_bursty_model)
baseline_model = get_baseline_class(model, 752, trials=2)
baseline_model.get_burstiness()
base_path = "../data/2012-06-27-an-invivo-1" quit()
fi_file = base_path + "/fispikes1.dat"
baseline_file = base_path + "/basespikes1.dat"
sam_file = base_path + "/samallspikes1.dat"
stimuli_file = base_path + "/stimuli.dat"
parser = get_parser(base_path)
spiketimes, contrasts, delta_fs, eod_freqs, durations, trans_amplitudes = parser.__get_sam_spiketimes__() for cell_data in icelldata_of_dir("../data/"):
baseline = get_baseline_class(cell_data)
baseline.get_burstiness()
print(eod_freqs)