lots of changes - errors in f_zero calculation?

This commit is contained in:
a.ott 2020-05-13 16:45:12 +02:00
parent 1f82f06209
commit 2f0c19ea5a
8 changed files with 319 additions and 103 deletions

View File

@ -143,9 +143,9 @@ class Adaption:
# print("current tau: {:.1f}ms".format(np.median(tau_effs) * (fi_curve_slope / f_infinity_slope) * 1000))
# new way to calculate with the fi curve slope at the intersection point of it and the f_inf line
factor = self.fi_curve.get_fi_curve_slope_at_f_zero_intersection() / f_infinity_slope
factor = self.fi_curve.get_f_zero_fit_slope_at_f_inf_fit_intersection() / f_infinity_slope
self.tau_real = np.median(tau_effs) * factor
print("###### tau: {:.1f}ms".format(self.tau_real*1000), "other f_0 slope:", self.fi_curve.get_fi_curve_slope_at_f_zero_intersection())
print("###### tau: {:.1f}ms".format(self.tau_real*1000), "other f_0 slope:", self.fi_curve.get_f_zero_fit_slope_at_f_inf_fit_intersection())
def get_tau_real(self):
return np.median(self.tau_real)

View File

@ -4,7 +4,6 @@ from models.LIFACnoise import LifacNoiseModel
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
import helperFunctions as hF
import numpy as np
from warnings import warn
import matplotlib.pyplot as plt
@ -35,7 +34,53 @@ class Baseline:
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
@staticmethod
def plot_baseline_given_data(time, eod, v1, spiketimes, sampling_interval, save_path=None, time_length=0.2):
def _get_baseline_frequency_given_data(spiketimes):
base_freqs = []
for st in spiketimes:
base_freqs.append(hF.calculate_mean_isi_freq(st))
return np.median(base_freqs)
@staticmethod
def _get_serial_correlation_given_data(max_lag, spikestimes):
serial_cors = []
for st in spikestimes:
sc = hF.calculate_serial_correlation(st, max_lag)
serial_cors.append(sc)
serial_cors = np.array(serial_cors)
return np.mean(serial_cors, axis=0)
@staticmethod
def _get_vector_strength_given_data(times, eods, spiketimes, sampling_interval):
vs_per_trial = []
for i in range(len(spiketimes)):
vs = hF.calculate_vector_strength_from_spiketimes(times[i], eods[i], spiketimes[i], sampling_interval)
vs_per_trial.append(vs)
return np.mean(vs_per_trial)
@staticmethod
def _get_coefficient_of_variation_given_data(spiketimes):
# CV (stddev of ISI divided by mean ISI (np.diff(spiketimes))
cvs = []
for st in spiketimes:
st = np.array(st)
cvs.append(hF.calculate_coefficient_of_variation(st))
return np.mean(cvs)
@staticmethod
def _get_interspike_intervals_given_data(spiketimes):
isis = []
for st in spiketimes:
st = np.array(st)
isis.extend(np.diff(st))
return isis
@staticmethod
def _plot_baseline_given_data(time, eod, v1, spiketimes, sampling_interval, save_path=None, time_length=0.2):
"""
plots the stimulus / eod, together with the v1, spiketimes and frequency
:return:
@ -113,18 +158,8 @@ class BaselineCellData(Baseline):
def get_baseline_frequency(self):
if self.baseline_frequency == -1:
base_freqs = []
for freq in self.data.get_mean_isi_frequencies():
delay = self.data.get_delay()
sampling_interval = self.data.get_sampling_interval()
if delay < 0.1:
warn("BaselineCellData:get_baseline_Frequency(): Quite short delay at the start.")
idx_start = int(0.025 / sampling_interval)
idx_end = int((delay - 0.025) / sampling_interval)
base_freqs.append(np.mean(freq[idx_start:idx_end]))
self.baseline_frequency = np.median(base_freqs)
spiketimes = self.data.get_base_spikes()
self.baseline_frequency = self._get_baseline_frequency_given_data(spiketimes)
return self.baseline_frequency
@ -132,44 +167,23 @@ class BaselineCellData(Baseline):
if self.vector_strength == -1:
times = self.data.get_base_traces(self.data.TIME)
eods = self.data.get_base_traces(self.data.EOD)
v1_traces = self.data.get_base_traces(self.data.V1)
self.vector_strength = hF.calculate_vector_strength_from_v1_trace(times, eods, v1_traces)
spiketimes = self.data.get_base_spikes()
sampling_interval = self.data.get_sampling_interval()
self.vector_strength = self._get_vector_strength_given_data(times, eods, spiketimes, sampling_interval)
return self.vector_strength
def get_serial_correlation(self, max_lag):
if len(self.serial_correlation) != max_lag:
serial_cors = []
for spiketimes in self.data.get_base_spikes():
sc = hF.calculate_serial_correlation(spiketimes, max_lag)
serial_cors.append(sc)
serial_cors = np.array(serial_cors)
mean_sc = np.mean(serial_cors, axis=0)
self.serial_correlation = mean_sc
self.serial_correlation = self._get_serial_correlation_given_data(max_lag, self.data.get_base_spikes())
return self.serial_correlation
def get_coefficient_of_variation(self):
if self.coefficient_of_variation == -1:
spiketimes = self.data.get_base_spikes()
# CV (stddev of ISI divided by mean ISI (np.diff(spiketimes))
cvs = []
for st in spiketimes:
st = np.array(st)
cvs.append(hF.calculate_coefficient_of_variation(st))
self.coefficient_of_variation = np.mean(cvs)
self.coefficient_of_variation = self._get_coefficient_of_variation_given_data(self.data.get_base_spikes())
return self.coefficient_of_variation
def get_interspike_intervals(self):
spiketimes = self.data.get_base_spikes()
# CV (stddev of ISI divided by mean ISI (np.diff(spiketimes))
isis = []
for st in spiketimes:
st = np.array(st)
isis.extend(np.diff(st))
return isis
return self._get_interspike_intervals_given_data(self.data.get_base_spikes())
def plot_baseline(self, save_path=None, time_length=0.2):
# eod, v1, spiketimes, frequency
@ -179,51 +193,59 @@ class BaselineCellData(Baseline):
v1_trace = self.data.get_base_traces(self.data.V1)[0]
spiketimes = self.data.get_base_spikes()[0]
self.plot_baseline_given_data(time, eod, v1_trace, spiketimes,
self.data.get_sampling_interval(), save_path, time_length)
self._plot_baseline_given_data(time, eod, v1_trace, spiketimes,
self.data.get_sampling_interval(), save_path, time_length)
class BaselineModel(Baseline):
simulation_time = 30
def __init__(self, model: LifacNoiseModel, eod_frequency):
def __init__(self, model: LifacNoiseModel, eod_frequency, trials=1):
super().__init__()
self.model = model
self.eod_frequency = eod_frequency
self.stimulus = SinusoidalStepStimulus(eod_frequency, 0)
self.v1, self.spiketimes = model.simulate_fast(self.stimulus, self.simulation_time)
self.eod = self.stimulus.as_array(0, self.simulation_time, model.get_sampling_interval())
self.time = np.arange(0, self.simulation_time, model.get_sampling_interval())
self.v1_traces = []
self.spiketimes = []
for i in range(trials):
v, st = model.simulate_fast(self.stimulus, self.simulation_time)
self.v1_traces.append(v)
self.spiketimes.append(st)
def get_baseline_frequency(self):
if self.baseline_frequency == -1:
self.baseline_frequency = hF.calculate_mean_isi_freq(self.spiketimes)
self.baseline_frequency = self._get_baseline_frequency_given_data(self.spiketimes)
return self.baseline_frequency
def get_vector_strength(self):
if self.vector_strength == -1:
self.vector_strength = hF.calculate_vector_strength_from_spiketimes(self.time, self.eod, self.spiketimes,
self.model.get_sampling_interval())
times = [self.time] * len(self.spiketimes)
eods = [self.eod] * len(self.spiketimes)
sampling_interval = self.model.get_sampling_interval()
self.vector_strength = self._get_vector_strength_given_data(times, eods, self.spiketimes, sampling_interval)
return self.vector_strength
def get_serial_correlation(self, max_lag):
if len(self.serial_correlation) != max_lag:
self.serial_correlation = hF.calculate_serial_correlation(self.spiketimes, max_lag)
self.serial_correlation = self._get_serial_correlation_given_data(max_lag, self.spiketimes)
return self.serial_correlation
def get_coefficient_of_variation(self):
if self.coefficient_of_variation == -1:
self.coefficient_of_variation = hF.calculate_coefficient_of_variation(self.spiketimes)
self.coefficient_of_variation = self._get_coefficient_of_variation_given_data(self.spiketimes)
return self.coefficient_of_variation
def get_interspike_intervals(self):
return np.diff(self.spiketimes)
return self._get_interspike_intervals_given_data(self.spiketimes)
def plot_baseline(self, save_path=None, time_length=0.2):
self.plot_baseline_given_data(self.time, self.eod, self.v1, self.spiketimes,
self.model.get_sampling_interval(), save_path, time_length)
self._plot_baseline_given_data(self.time, self.eod, self.v1_traces[0], self.spiketimes[0],
self.model.get_sampling_interval(), save_path, time_length)
def get_baseline_class(data, eod_freq=None) -> Baseline:

View File

@ -58,8 +58,14 @@ class CellData:
def get_base_spikes(self):
if self.base_spikes is None:
self.base_spikes = self.parser.get_baseline_spiketimes()
times = self.get_base_traces(self.TIME)
eods = self.get_base_traces(self.EOD)
v1_traces = self.get_base_traces(self.V1)
spiketimes = []
for i in range(len(times)):
spiketimes.append(hf.detect_spiketimes(times[i], v1_traces[i]))
self.base_spikes = spiketimes
return self.base_spikes
def get_base_isis(self):

View File

@ -82,7 +82,7 @@ class FICurve:
else:
return x_values[intersection_indicies[0]]
def get_fi_curve_slope_at_f_zero_intersection(self):
def get_f_zero_fit_slope_at_f_inf_fit_intersection(self):
x = self.get_f_zero_and_f_inf_intersection()
fit_vars = self.f_zero_fit
return fu.derivative_full_boltzmann(x, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])
@ -321,9 +321,11 @@ class FICurveCellData(FICurve):
class FICurveModel(FICurve):
def __init__(self, model, stimulus_values, eod_frequency):
def __init__(self, model, stimulus_values, eod_frequency, trials=1):
self.eod_frequency = eod_frequency
self.model = model
self.trials = trials
self.spiketimes = []
super().__init__(stimulus_values)
def calculate_all_frequency_points(self):
@ -338,10 +340,18 @@ class FICurveModel(FICurve):
for c in self.stimulus_values:
stimulus = SinusoidalStepStimulus(self.eod_frequency, c, stim_start, stim_duration)
_, spiketimes = self.model.simulate_fast(stimulus, total_simulation_time)
time, frequency = hF.calculate_time_and_frequency_trace(spiketimes, sampling_interval)
if len(spiketimes) < 10 or len(time) == 0 or min(time) > stim_start \
frequency_traces = []
time_traces = []
for i in range(self.trials):
_, spiketimes = self.model.simulate_fast(stimulus, total_simulation_time)
trial_time, trial_frequency = hF.calculate_time_and_frequency_trace(spiketimes, sampling_interval)
frequency_traces.append(trial_frequency)
time_traces.append(trial_time)
time, frequency = hF.calculate_mean_of_frequency_traces(time_traces, frequency_traces, sampling_interval)
if len(time) == 0 or min(time) > stim_start \
or max(time) < stim_start + stim_duration:
print("Too few spikes to calculate f_inf, f_0 and f_base")
self.f_inf_frequencies.append(0)
@ -358,6 +368,7 @@ class FICurveModel(FICurve):
f_baseline = hF.detect_f_baseline_in_freq_trace(time, frequency, stim_start, sampling_interval)
self.f_baseline_frequencies.append(f_baseline)
def plot_f_point_detections(self, save_path=None):
raise NotImplementedError("TODO sorry... "
"The model version of the FiCurve class is still missing this implementation")

View File

@ -5,8 +5,6 @@ from CellData import CellData, icelldata_of_dir
from Baseline import get_baseline_class
from FiCurve import get_fi_curve_class
from AdaptionCurrent import Adaption
import helperFunctions as hF
import functions as fu
import numpy as np
from warnings import warn
from scipy.optimize import minimize
@ -70,8 +68,6 @@ def run_with_real_data():
start_par_count += 1
print("START PARAMETERS:", start_par_count)
parameter_set_path = results_path + "start_parameter_set_{}".format(start_par_count) + "/"
start_time = time.time()
fitter = Fitter()
fmin, parameters = fitter.fit_model_to_data(cell_data, start_parameters)
@ -79,6 +75,7 @@ def run_with_real_data():
print(fmin)
print(parameters)
end_time = time.time()
parameter_set_path = results_path + "start_par_set_{}_fmin_{:.2f}".format(start_par_count, fmin["fun"]) + "/"
if not os.path.exists(parameter_set_path):
os.makedirs(parameter_set_path)
with open(parameter_set_path + "parameters_info.txt".format(start_par_count), "w") as file:
@ -178,7 +175,7 @@ class Fitter:
self.f_inf_slope = 0
self.f_zero_values = []
self.f_zero_slope = 0
self.f_zero_slopes = []
self.f_zero_fit = []
self.tau_a = 0
@ -203,12 +200,13 @@ class Fitter:
self.f_zero_values = fi_curve.f_zero_frequencies
self.f_zero_fit = fi_curve.f_zero_fit
self.f_zero_slope = fi_curve.get_f_zero_fit_slope_at_straight()
self.f_zero_slopes = [fi_curve.get_f_zero_fit_slope_at_stimulus_value(c) for c in self.fi_contrasts]
# around 1/3 of the value at straight
# self.f_zero_slope = fi_curve.get_fi_curve_slope_at(fi_curve.get_f_zero_and_f_inf_intersection())
self.delta_a = (self.f_zero_slope / self.f_inf_slope) / 1000 # seems to work if divided by 1000...
# seems to work if divided by 1000...
self.delta_a = (fi_curve.get_f_zero_fit_slope_at_straight() / self.f_inf_slope) / 1000
adaption = Adaption(data, fi_curve)
self.tau_a = adaption.get_tau_real()
@ -218,7 +216,6 @@ class Fitter:
return self.fit_routine_5(data, start_parameters)
def fit_routine_5(self, cell_data=None, start_parameters=None):
# [error_bf, error_vs, error_sc, error_cv, error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope]
self.counter = 0
# fit only v_offset, mem_tau, input_scaling, dend_tau
if start_parameters is None:
@ -227,7 +224,10 @@ class Fitter:
x0 = np.array([start_parameters["mem_tau"], start_parameters["noise_strength"],
start_parameters["input_scaling"], self.tau_a, self.delta_a, start_parameters["dend_tau"]])
initial_simplex = create_init_simples(x0, search_scale=2)
error_weights = (0, 2, 5, 1, 1, 1, 0.5, 1)
# error_list = [error_bf, error_vs, error_sc, error_cv,
# error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope]
error_weights = (0, 1, 1, 1, 1, 1, 1, 1)
fmin = minimize(fun=self.cost_function_all,
args=(error_weights,), x0=x0, method="Nelder-Mead",
options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 400, "maxiter": 400})
@ -386,22 +386,13 @@ class Fitter:
coefficient_of_variation = model_baseline.get_coefficient_of_variation()
# f_infinities, f_infinities_slope = self.base_model.calculate_fi_markers(self.fi_contrasts, self.eod_freq)
f_baselines, f_zeros, f_infinities = self.base_model.calculate_fi_curve(self.fi_contrasts, self.eod_freq)
try:
f_infinities_fit = hF.fit_clipped_line(self.fi_contrasts, f_infinities)
except Exception as e:
print("EXCEPTION IN FIT LINE!")
print(e)
f_infinities_fit = [0, 0]
f_infinities_slope = f_infinities_fit[0]
try:
f_zeros_fit = hF.fit_boltzmann(self.fi_contrasts, f_zeros)
except Exception as e:
print("EXCEPTION IN FIT BOLTZMANN!")
print(e)
f_zeros_fit = [0, 0, 0, 0]
f_zero_slope = fu.full_boltzmann_straight_slope(f_zeros_fit[0], f_zeros_fit[1], f_zeros_fit[2], f_zeros_fit[3])
fi_curve_model = get_fi_curve_class(self.base_model, self.fi_contrasts, self.eod_freq)
f_zeros = fi_curve_model.get_f_zero_frequencies()
f_infinities = fi_curve_model.get_f_inf_frequencies()
f_infinities_slope = fi_curve_model.get_f_inf_slope()
f_zero_slopes = [fi_curve_model.get_f_zero_fit_slope_at_stimulus_value(x) for x in self.fi_contrasts]
# print("fi-curve features calculated!")
# calculate errors with reference values
error_bf = abs((baseline_freq - self.baseline_freq) / self.baseline_freq)
@ -413,20 +404,20 @@ class Fitter:
error_sc = abs((serial_correlation[i] - self.serial_correlation[i]) / self.serial_correlation[i])
error_sc = error_sc / self.sc_max_lag
error_f_inf_slope = abs((f_infinities_slope - self.f_inf_slope) / self.f_inf_slope) * 4
error_f_inf = calculate_f_values_error(f_infinities, self.f_inf_values) * .5
error_f_inf_slope = abs((f_infinities_slope - self.f_inf_slope) / self.f_inf_slope)
error_f_inf = calculate_list_error(f_infinities, self.f_inf_values)
error_f_zero_slope = abs((f_zero_slope - self.f_zero_slope) / self.f_zero_slope)
error_f_zero = calculate_f_values_error(f_zeros, self.f_zero_values)
error_f_zero_slopes = calculate_list_error(f_zero_slopes, self.f_zero_slopes)
error_f_zero = calculate_list_error(f_zeros, self.f_zero_values)
error_list = [error_bf, error_vs, error_sc, error_cv,
error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope]
error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slopes]
if error_weights is not None and len(error_weights) == len(error_list):
for i in range(len(error_weights)):
error_list[i] = error_list[i] * error_weights[i]
elif len(error_weights) != len(error_list):
if len(error_weights) != len(error_list):
warn("Error weights had different length than errors and were ignored!")
error = sum(error_list)
@ -446,14 +437,14 @@ class Fitter:
self.f_inf_slope, f_infinities_slope, error_f_inf_slope),
"f-infinity values:\nexpected:", np.around(self.f_inf_values), "\ncurrent: ", np.around(f_infinities),
"\nerror: {:.3f}\n".format(error_f_inf),
"f-zero slope - expected: {:.0f}, current: {:.0f}, error: {:.3f}\n".format(
self.f_zero_slope, f_zero_slope, error_f_zero_slope),
"f-zero slopes:\nexpected:", np.around(self.f_zero_slopes), "\ncurrent: ", np.around(f_zero_slopes),
"\nerror: {:.3f}".format(error_f_zero_slopes),
"f-zero values:\nexpected:", np.around(self.f_zero_values), "\ncurrent: ", np.around(f_zeros),
"\nerror: {:.3f}".format(error_f_zero))
return error_list
def calculate_f_values_error(fit, reference):
def calculate_list_error(fit, reference):
error = 0
for i in range(len(reference)):
# TODO ??? add a constant to f_inf to allow for small differences in small values

View File

@ -14,6 +14,9 @@ def fit_clipped_line(x, y):
def fit_boltzmann(x, y):
max_f0 = float(max(y))
if max_f0 == 0:
return [0, 0, 0, 0]
min_f0 = 0.1 # float(min(self.f_zeros))
mean_int = float(np.mean(x))
@ -21,10 +24,15 @@ def fit_boltzmann(x, y):
total_change_int = max(x) - min(x)
start_k = float((total_increase / total_change_int * 4) / max_f0)
popt, pcov = curve_fit(fu.full_boltzmann, x, y,
p0=(max_f0, min_f0, start_k, mean_int),
maxfev=10000, bounds=([0, 0, -np.inf, -np.inf], [np.inf, np.inf, np.inf, np.inf]))
try:
popt, pcov = curve_fit(fu.full_boltzmann, x, y,
p0=(max_f0, min_f0, start_k, mean_int),
maxfev=10000, bounds=([0, 0, -np.inf, -np.inf], [np.inf, np.inf, np.inf, np.inf]))
except RuntimeError as e:
print("Error in fit boltzmann: ", str(e))
print("x_values:", x)
print("y_values:", y)
return [0, 0, 0, 0]
return popt
@ -248,7 +256,9 @@ def calculate_coefficient_of_variation(spiketimes: np.ndarray) -> float:
def calculate_serial_correlation(spiketimes: np.ndarray, max_lag: int) -> np.ndarray:
isi = np.diff(spiketimes)
if len(spiketimes) < max_lag + 1:
raise ValueError("Given list to short, with given max_lag")
warn("Cannot compute serial correlation with list shorter than max lag...")
return np.zeros(max_lag)
# raise ValueError("Given list to short, with given max_lag")
cor = np.zeros(max_lag)
for lag in range(max_lag):
@ -286,7 +296,7 @@ def calculate_vector_strength_from_v1_trace(times, eods, v1_traces):
print("-----LENGTH OF TIMES = 0")
for recording in range(len(times)):
spiketime_idices = detect_spikes(v1_traces[recording])
spiketime_idices = detect_spikes_indices(v1_traces[recording])
rel_spikes, eod_durs = eods_around_spikes(times[recording], eods[recording], spiketime_idices)
relative_spike_times.extend(rel_spikes)
eod_durations.extend(eod_durs)
@ -305,7 +315,7 @@ def calculate_vector_strength_from_spiketimes(time, eod, spiketimes, sampling_in
return __vector_strength__(rel_spikes, eod_durs)
def detect_spikes(v1, split=20, threshold=3):
def detect_spikes_indices(v1, split=20, threshold=3):
total = len(v1)
all_peaks = []
@ -323,6 +333,12 @@ def detect_spikes(v1, split=20, threshold=3):
return all_peaks
def detect_spiketimes(time, v1, split=20, threshold=3):
all_peak_indicies = detect_spikes_indices(v1, split, threshold)
return [time[p_idx] for p_idx in all_peak_indicies]
def calculate_phases(relative_spike_times, eod_durations):
phase_times = np.zeros(len(relative_spike_times))
@ -455,6 +471,12 @@ def detect_f_zero_in_frequency_trace(time, frequency, stimulus_start, sampling_i
# plt.plot(time[start_idx:end_idx], [f_zero for i in range(end_idx-start_idx)])
# plt.show()
max_frequency = int(1/sampling_interval)
int_f_zero = int(f_zero)
if int_f_zero > max_frequency:
raise AssertionError("Detection of f-zero went very wrong! frequency above 1/sampling_interval.")
if int_f_zero > max(frequency):
raise AssertionError("detected f_zero bigger than the highest peak in the frequency trace...")
return f_zero

View File

@ -1,12 +1,21 @@
import numpy as np
import matplotlib.pyplot as plt
import helperFunctions as hF
import functions as fu
import time
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
def main():
x_values = np.arange(-1, 1, 0.01)
popt = [0, 0, 0, 0]
y_values = [fu.full_boltzmann(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values]
plt.plot(x_values, y_values)
plt.show()
quit()
for freq in [700, 50, 100, 500, 1000]:
reps = 1000
start = time.time()

155
variableEffect.py Normal file
View File

@ -0,0 +1,155 @@
from models.LIFACnoise import LifacNoiseModel
from Baseline import BaselineModel
from FiCurve import FICurveModel
import numpy as np
import matplotlib.pyplot as plt
import copy
SEARCH_WIDTH = 3
SEARCH_PRECISION = 30
CONTRASTS = np.arange(-0.4, 0.45, 0.1)
def main():
model_parameters = {'threshold': 1,
'step_size': 5e-05,
'a_zero': 2,
'delta_a': 0.2032269898801589,
'mem_tau': 0.011314027210564803,
'noise_strength': 0.056724809998220195,
'v_zero': 0,
'v_base': 0,
'tau_a': 0.05958195972016753,
'input_scaling': 119.81500448274554,
'dend_tau': 0.0027746086464721723,
'v_offset': -24.21875}
parameters_to_test = ["input_scaling", "dend_tau", "mem_tau", "noise_strength", "v_offset", "delta_a", "tau_a"]
effect_data = []
for p in parameters_to_test:
print("Working on parameter " + p)
effect_data.append(test_parameter_effect(model_parameters, p))
plot_effects(effect_data, "./figures/variable_effect/")
def test_parameter_effect(model_parameters, test_parameter):
model_parameters = copy.deepcopy(model_parameters)
start_value = model_parameters[test_parameter]
start = start_value*(1/SEARCH_WIDTH)
end = start_value*SEARCH_WIDTH
step = (end - start) / SEARCH_PRECISION
values = np.arange(start, end+step, step)
bf = []
vs = []
sc = []
cv = []
f_inf_s = []
f_inf_v = []
f_zero_s = []
f_zero_v = []
broken_i = []
for i in range(len(values)):
model_parameters[test_parameter] = values[i]
model = LifacNoiseModel(model_parameters)
fi_curve = FICurveModel(model, CONTRASTS, 600, trials=50)
f_inf_s.append(fi_curve.get_f_inf_slope())
f_inf_v.append(fi_curve.get_f_inf_frequencies())
f_zero_s.append(fi_curve.get_f_zero_fit_slope_at_stimulus_value(0.1))
f_zero_v.append(fi_curve.get_f_zero_frequencies())
baseline = BaselineModel(model, 600, trials=10)
bf.append(baseline.get_baseline_frequency())
vs.append(baseline.get_vector_strength())
sc.append(baseline.get_serial_correlation(2))
cv.append(baseline.get_coefficient_of_variation())
values = list(values)
if len(broken_i) > 0:
broken_i = sorted(broken_i, reverse=True)
for i in broken_i:
del values[i]
return ParameterEffectData(values, test_parameter, bf, vs, sc, cv, f_inf_s, f_inf_v, f_zero_s, f_zero_v)
# plot_effects(values, test_parameter, bf, vs, sc, cv, f_inf_s, f_inf_v, f_zero_s, f_zero_v)
def plot_effects(par_effect_data_list, save_path=None):
fig, axes = plt.subplots(8, len(par_effect_data_list), figsize=(32, 4*len(par_effect_data_list)), sharex="col")
names = ("bf", "vs", "sc", "cv", "f_inf_s", "f_inf_v", "f_zero_s", "f_zero_v")
for j in range(len(par_effect_data_list)):
ped = par_effect_data_list[j]
ranges = ((0, max(ped.get_data("bf")) * 1.1), (0, 1), (-1, 1), (0, 1),
(0, max(ped.get_data("f_inf_s")) * 1.1), (0, 800),
(0, max(ped.get_data("f_zero_s")) * 1.1), (0, 3000))
values = ped.values
for i in range(len(names)):
y_data = ped.get_data(names[i])
axes[i, j].plot(values, y_data)
axes[i, j].set_ylim(ranges[i])
if j == 0:
axes[i, j].set_ylabel(names[i])
if i == 0:
axes[i, j].set_title(ped.test_parameter)
plt.tight_layout()
if save_path is not None:
plt.savefig(save_path + "variable_effect_master_plot.png")
else:
plt.show()
plt.close()
class ParameterEffectData:
data_names = ("bf", "vs", "sc", "cv", "f_inf_s", "f_inf_v" "f_zero_s", "f_zero_v")
def __init__(self, values, test_parameter, bf, vs, sc, cv, f_inf_s, f_inf_v, f_zero_s, f_zero_v):
self.values = values
self.test_parameter = test_parameter
self.bf = bf
self.vs = vs
self.sc = sc
self.cv = cv
self.f_inf_s = f_inf_s
self.f_inf_v = f_inf_v
self.f_zero_s = f_zero_s
self.f_zero_v = f_zero_v
def get_data(self, name):
if name == "bf":
return self.bf
elif name == "vs":
return self.vs
elif name == "sc":
return self.sc
elif name == "cv":
return self.cv
elif name == "f_inf_s":
return self.f_inf_s
elif name == "f_inf_v":
return self.f_inf_v
elif name == "f_zero_s":
return self.f_zero_s
elif name == "f_zero_v":
return self.f_zero_v
else:
raise ValueError("Unknown attribute name!")
if __name__ == '__main__':
main()