save graphs during fitting process

This commit is contained in:
a.ott 2020-05-11 15:28:03 +02:00
parent 11c37c9f2a
commit c1db288f97
3 changed files with 49 additions and 102 deletions

View File

@ -59,7 +59,7 @@ class Baseline:
plt.title("Baseline Serial correlation")
plt.xlabel("Lag")
plt.ylabel("Correlation")
plt.ylim((-1, 1))
plt.plot(np.arange(1,max_lag+1, 1), self.get_serial_correlation(max_lag))
if save_path is not None:
@ -144,12 +144,12 @@ class BaselineCellData(Baseline):
v1_traces = self.data.get_base_traces(self.data.V1)
spiketimes = self.data.get_base_spikes()
fig, axes = plt.subplots(4, 1, sharex="True")
fig, axes = plt.subplots(4, 1, sharex='col')
for i in range(len(times)):
axes[0].plot(times[i], eods[i])
axes[1].plot(times[i], v1_traces[i])
axes[2].plot(spiketimes, [1]*len(spiketimes), 'o')
axes[2].plot(spiketimes[i], [1 for i in range(len(spiketimes[i]))], 'o')
t, f = hF.calculate_time_and_frequency_trace(spiketimes[i], self.data.get_sampling_interval())
axes[3].plot(t, f)
@ -203,11 +203,11 @@ class BaselineModel(Baseline):
def plot_baseline(self, save_path=None):
# eod, v1, spiketimes, frequency
fig, axes = plt.subplots(4, 1, sharex="True")
fig, axes = plt.subplots(4, 1, sharex="col")
axes[0].plot(self.time, self.eod)
axes[1].plot(self.time, self.v1)
axes[2].plot(self.spiketimes, [1]*len(self.spiketimes), 'o')
axes[2].plot(self.spiketimes, [1 for i in range(len(self.spiketimes))], 'o')
t, f = hF.calculate_time_and_frequency_trace(self.spiketimes, self.model.get_sampling_interval())
axes[3].plot(t, f)

View File

@ -61,88 +61,6 @@ class FICurve:
stimulus_start, stimulus_duration, sampling_interval)
self.f_infinities.append(f_infinity)
# def __calculate_f_baseline__(self, time, frequency, buffer=0.025):
#
# stim_start = self.cell_data.get_stimulus_start() - time[0]
# sampling_interval = self.cell_data.get_sampling_interval()
# if stim_start < 0.1:
# warn("FICurve:__calculate_f_baseline__(): Quite short delay at the start.")
#
# start_idx = 0
# end_idx = int((stim_start-buffer)/sampling_interval)
# f_baseline = np.mean(frequency[start_idx:end_idx])
#
# return f_baseline
#
# def __calculate_f_zero__(self, time, frequency, peak_buffer_percent=0.05, buffer=0.025):
#
# stimulus_start = self.cell_data.get_stimulus_start() - time[0] # time start is generally != 0 and != delay
# sampling_interval = self.cell_data.get_sampling_interval()
#
# freq_before = frequency[0:int((stimulus_start - buffer) / sampling_interval)]
# min_before = min(freq_before)
# max_before = max(freq_before)
# mean_before = np.mean(freq_before)
#
# # time where the f-zero is searched in
# start_idx = int((stimulus_start-0.1*buffer) / sampling_interval)
# end_idx = int((stimulus_start + buffer) / sampling_interval)
#
# min_during_start_of_stim = min(frequency[start_idx:end_idx])
# max_during_start_of_stim = max(frequency[start_idx:end_idx])
#
# if abs(mean_before-min_during_start_of_stim) > abs(max_during_start_of_stim-mean_before):
# f_zero = min_during_start_of_stim
# else:
# f_zero = max_during_start_of_stim
#
# peak_buffer = (max_before - min_before) * peak_buffer_percent
# if min_before - peak_buffer <= f_zero <= max_before + peak_buffer:
# end_idx = start_idx + int((end_idx-start_idx)/2)
# f_zero = np.mean(frequency[start_idx:end_idx])
#
# return f_zero
#
# # start_idx = int(stimulus_start / sampling_interval)
# # end_idx = int((stimulus_start + buffer*2) / sampling_interval)
# #
# # freq_before = frequency[start_idx-(int(length_of_mean/sampling_interval)):start_idx]
# # fb_mean = np.mean(freq_before)
# # fb_std = np.std(freq_before)
# #
# # peak_frequency = fb_mean
# # count = 0
# # for i in range(start_idx + 1, end_idx):
# # if fb_mean-3*fb_std <= frequency[i] <= fb_mean+3*fb_std:
# # continue
# #
# # if abs(frequency[i] - fb_mean) > abs(peak_frequency - fb_mean):
# # peak_frequency = frequency[i]
# # count += 1
#
# # return peak_frequency
#
# def __calculate_f_infinity__(self, time, frequency, length=0.1, buffer=0.025):
# stimulus_end_time = self.cell_data.get_stimulus_start() + self.cell_data.get_stimulus_duration() - time[0]
#
# start_idx = int((stimulus_end_time - length - buffer) / self.cell_data.get_sampling_interval())
# end_idx = int((stimulus_end_time - buffer) / self.cell_data.get_sampling_interval())
#
# # TODO add way to plot detected f_zero, f_inf, f_base. With detection of remaining slope?
# # x = np.arange(start_idx, end_idx, 1) # time[start_idx:end_idx]
# # slope, intercept, r_value, p_value, std_err = linregress(x, frequency[start_idx:end_idx])
# # if p_value < 0.0001:
# # plt.title("significant slope: {:.2f}, p: {:.5f}, r: {:.5f}".format(slope, p_value, r_value))
# # plt.plot(x, [i*slope + intercept for i in x], color="black")
# #
# #
# # plt.plot((start_idx, end_idx), (np.mean(frequency[start_idx:end_idx]), np.mean(frequency[start_idx:end_idx])), label="f_inf")
# # plt.legend()
# # plt.show()
# # plt.close()
#
# return np.mean(frequency[start_idx:end_idx])
def get_f_zero_inverse_at_frequency(self, frequency):
b_vars = self.boltzmann_fit_vars
return fu.inverse_full_boltzmann(frequency, b_vars[0], b_vars[1], b_vars[2], b_vars[3])

View File

@ -52,6 +52,7 @@ def run_with_real_data():
results_path = "results/" + os.path.split(cell_data.get_data_path())[-1] + "/"
print("results at:", results_path)
results_path += "parameter_set_{}".format(start_par_count) + "/"
start_time = time.time()
fitter = Fitter()
@ -64,12 +65,11 @@ def run_with_real_data():
if not os.path.exists(results_path):
os.makedirs(results_path)
with open(results_path + "fit_parameters_start_{}.txt".format(start_par_count), "w") as file:
with open(results_path + "parameters_info.txt".format(start_par_count), "w") as file:
file.writelines(["start_parameters:\t" + str(start_parameters),
"\nfinal_parameters:\t" + str(parameters),
"\nfinal_fmin:\t" + str(fmin)])
results_path += SAVE_PATH_PREFIX + "par_set_" + str(start_par_count) + "_"
print('Fitting of cell took function took {:.3f} s'.format((end_time - start_time)))
# print(results_path)
print_comparision_cell_model(cell_data, parameters, plot=True, savepath=results_path)
@ -90,12 +90,12 @@ def print_comparision_cell_model(cell_data: CellData, parameters, plot=False, sa
m_sc = model_baseline.get_serial_correlation(1)
m_cv = model_baseline.get_coefficient_of_variation()
f_baselines, f_zeros, m_f_infinities = res_model.calculate_fi_curve(fi_curve.stimulus_value,
_, m_f_zeros, m_f_infinities = res_model.calculate_fi_curve(fi_curve.stimulus_value,
cell_data.get_eod_frequency())
f_infinities_fit = hF.fit_clipped_line(fi_curve.stimulus_value, m_f_infinities)
m_f_infinities_slope = f_infinities_fit[0]
f_zeros_fit = hF.fit_boltzmann(fi_curve.stimulus_value, f_zeros)
f_zeros_fit = hF.fit_boltzmann(fi_curve.stimulus_value, m_f_zeros)
m_f_zero_slope = fu.full_boltzmann_straight_slope(f_zeros_fit[0], f_zeros_fit[1], f_zeros_fit[2], f_zeros_fit[3])
data_baseline = get_baseline_class(cell_data)
@ -104,20 +104,47 @@ def print_comparision_cell_model(cell_data: CellData, parameters, plot=False, sa
c_sc = data_baseline.get_serial_correlation(1)
c_cv = data_baseline.get_coefficient_of_variation()
c_f_slope = fi_curve.get_f_infinity_slope()
c_f_values = fi_curve.f_infinities
c_f_inf_slope = fi_curve.get_f_infinity_slope()
c_f_inf_values = fi_curve.f_infinities
c_f_zero_slope = fi_curve.get_fi_curve_slope_of_straight()
c_f_zero_values = fi_curve.f_zeros
print("EOD-frequency: {:.2f}".format(cell_data.get_eod_frequency()))
print("bf: cell - {:.2f} vs model {:.2f}".format(c_bf, m_bf))
print("vs: cell - {:.2f} vs model {:.2f}".format(c_vs, m_vs))
print("sc: cell - {:.2f} vs model {:.2f}".format(c_sc[0], m_sc[0]))
print("cv: cell - {:.2f} vs model {:.2f}".format(c_cv, m_cv))
print("f_inf_slope: cell - {:.2f} vs model {:.2f}".format(c_f_slope, m_f_infinities_slope))
print("f infinity values:\n cell -", c_f_values, "\n model -", m_f_infinities)
print("f_inf_slope: cell - {:.2f} vs model {:.2f}".format(c_f_inf_slope, m_f_infinities_slope))
print("f infinity values:\n cell -", c_f_inf_values, "\n model -", m_f_infinities)
print("f_zero_slope: cell - {:.2f} vs model {:.2f}".format(c_f_zero_slope, m_f_zero_slope))
print("f zero values:\n cell -", c_f_zero_values, "\n model -", f_zeros)
print("f zero values:\n cell -", c_f_zero_values, "\n model -", m_f_zeros)
if savepath is not None:
with open(savepath + "value_comparision.tsv", 'w') as value_file:
value_file.write("Variable\tCell\tModel\n")
value_file.write("baseline_frequency\t{:.2f}\t{:.2f}\n".format(c_bf, m_bf))
value_file.write("vector_strength\t{:.2f}\t{:.2f}\n".format(c_vs, m_vs))
value_file.write("serial_correlation\t{:.2f}\t{:.2f}\n".format(c_sc[0], m_sc[0]))
value_file.write("coefficient_of_variation\t{:.2f}\t{:.2f}\n".format(c_cv, m_cv))
value_file.write("f_inf_slope\t{:.2f}\t{:.2f}\n".format(c_f_inf_slope, m_f_infinities_slope))
value_file.write("f_zero_slope\t{:.2f}\t{:.2f}\n".format(c_f_zero_slope, m_f_zero_slope))
if plot:
# plot cell images:
cell_save_path = savepath + "/cell/"
if not os.path.exists(cell_save_path):
os.makedirs(cell_save_path)
# data_baseline.plot_baseline(cell_save_path + "baseline.png")
data_baseline.plot_inter_spike_interval_histogram(cell_save_path + "isi-histogram.png")
data_baseline.plot_serial_correlation(6, cell_save_path + "serial_correlation.png")
# plot model images
model_save_path = savepath + "/model/"
if not os.path.exists(model_save_path):
os.makedirs(model_save_path)
# model_baseline.plot_baseline(model_save_path + "baseline.png")
model_baseline.plot_inter_spike_interval_histogram(model_save_path + "isi-histogram.png")
model_baseline.plot_serial_correlation(6, model_save_path + "serial_correlation.png")
if plot:
f_b, f_zero, f_inf = res_model.calculate_fi_curve(cell_data.get_fi_contrasts(), cell_data.get_eod_frequency())
@ -139,7 +166,7 @@ class Fitter:
self.fi_contrasts = []
self.eod_freq = 0
self.sc_max_lag = 1
self.sc_max_lag = 2
# values to be replicated:
self.baseline_freq = 0
@ -192,9 +219,7 @@ class Fitter:
# return self.fit_model(fit_adaption=False)
def fit_routine_5(self, cell_data=None, start_parameters=None):
global SAVE_PATH_PREFIX
SAVE_PATH_PREFIX = "fit_routine_5_"
# errors: [error_bf, error_vs, error_sc, error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope]
# [error_bf, error_vs, error_sc, error_cv, error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope]
self.counter = 0
# fit only v_offset, mem_tau, input_scaling, dend_tau
if start_parameters is None:
@ -203,7 +228,7 @@ class Fitter:
x0 = np.array([start_parameters["mem_tau"], start_parameters["noise_strength"],
start_parameters["input_scaling"], self.tau_a, self.delta_a, start_parameters["dend_tau"]])
initial_simplex = create_init_simples(x0, search_scale=2)
error_weights = (0, 1, 1, 1, 1, 1, 2, 1)
error_weights = (0, 2, 3, 1, 1, 1, 0.5, 1)
fmin = minimize(fun=self.cost_function_all,
args=(error_weights,), x0=x0, method="Nelder-Mead",
options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 400, "maxiter": 400})
@ -382,9 +407,13 @@ class Fitter:
# calculate errors with reference values
error_bf = abs((baseline_freq - self.baseline_freq) / self.baseline_freq)
error_vs = abs((vector_strength - self.vector_strength) / self.vector_strength)
error_sc = abs((serial_correlation[0] - self.serial_correlation[0]) / self.serial_correlation[0])
error_cv = abs((coefficient_of_variation - self.coefficient_of_variation) / self.coefficient_of_variation)
error_sc = 0
for i in range(self.sc_max_lag):
error_sc = abs((serial_correlation[i] - self.serial_correlation[i]) / self.serial_correlation[i])
error_sc = error_sc / self.sc_max_lag
error_f_inf_slope = abs((f_infinities_slope - self.f_inf_slope) / self.f_inf_slope) * 4
error_f_inf = calculate_f_values_error(f_infinities, self.f_inf_values) * .5