From 2f0c19ea5a9d528e9704fa48834be766fafdb0f0 Mon Sep 17 00:00:00 2001 From: "a.ott" Date: Wed, 13 May 2020 16:45:12 +0200 Subject: [PATCH] lots of changes - errors in f_zero calculation? --- AdaptionCurrent.py | 4 +- Baseline.py | 130 +++++++++++++++++++++--------------- CellData.py | 8 ++- FiCurve.py | 21 ++++-- Fitter.py | 59 +++++++--------- helperFunctions.py | 36 ++++++++-- introduction/test.py | 9 +++ variableEffect.py | 155 +++++++++++++++++++++++++++++++++++++++++++ 8 files changed, 319 insertions(+), 103 deletions(-) create mode 100644 variableEffect.py diff --git a/AdaptionCurrent.py b/AdaptionCurrent.py index 332a108..0511464 100644 --- a/AdaptionCurrent.py +++ b/AdaptionCurrent.py @@ -143,9 +143,9 @@ class Adaption: # print("current tau: {:.1f}ms".format(np.median(tau_effs) * (fi_curve_slope / f_infinity_slope) * 1000)) # new way to calculate with the fi curve slope at the intersection point of it and the f_inf line - factor = self.fi_curve.get_fi_curve_slope_at_f_zero_intersection() / f_infinity_slope + factor = self.fi_curve.get_f_zero_fit_slope_at_f_inf_fit_intersection() / f_infinity_slope self.tau_real = np.median(tau_effs) * factor - print("###### tau: {:.1f}ms".format(self.tau_real*1000), "other f_0 slope:", self.fi_curve.get_fi_curve_slope_at_f_zero_intersection()) + print("###### tau: {:.1f}ms".format(self.tau_real*1000), "other f_0 slope:", self.fi_curve.get_f_zero_fit_slope_at_f_inf_fit_intersection()) def get_tau_real(self): return np.median(self.tau_real) diff --git a/Baseline.py b/Baseline.py index dd6faea..dfb1417 100644 --- a/Baseline.py +++ b/Baseline.py @@ -4,7 +4,6 @@ from models.LIFACnoise import LifacNoiseModel from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus import helperFunctions as hF import numpy as np -from warnings import warn import matplotlib.pyplot as plt @@ -35,7 +34,53 @@ class Baseline: raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS") @staticmethod - def plot_baseline_given_data(time, eod, v1, spiketimes, sampling_interval, save_path=None, time_length=0.2): + def _get_baseline_frequency_given_data(spiketimes): + base_freqs = [] + for st in spiketimes: + base_freqs.append(hF.calculate_mean_isi_freq(st)) + + return np.median(base_freqs) + + @staticmethod + def _get_serial_correlation_given_data(max_lag, spikestimes): + serial_cors = [] + for st in spikestimes: + sc = hF.calculate_serial_correlation(st, max_lag) + serial_cors.append(sc) + serial_cors = np.array(serial_cors) + + return np.mean(serial_cors, axis=0) + + @staticmethod + def _get_vector_strength_given_data(times, eods, spiketimes, sampling_interval): + vs_per_trial = [] + for i in range(len(spiketimes)): + vs = hF.calculate_vector_strength_from_spiketimes(times[i], eods[i], spiketimes[i], sampling_interval) + vs_per_trial.append(vs) + + return np.mean(vs_per_trial) + + @staticmethod + def _get_coefficient_of_variation_given_data(spiketimes): + # CV (stddev of ISI divided by mean ISI (np.diff(spiketimes)) + cvs = [] + for st in spiketimes: + st = np.array(st) + cvs.append(hF.calculate_coefficient_of_variation(st)) + + return np.mean(cvs) + + @staticmethod + def _get_interspike_intervals_given_data(spiketimes): + isis = [] + for st in spiketimes: + st = np.array(st) + isis.extend(np.diff(st)) + + return isis + + @staticmethod + def _plot_baseline_given_data(time, eod, v1, spiketimes, sampling_interval, save_path=None, time_length=0.2): """ plots the stimulus / eod, together with the v1, spiketimes and frequency :return: @@ -113,18 +158,8 @@ class BaselineCellData(Baseline): def get_baseline_frequency(self): if self.baseline_frequency == -1: - base_freqs = [] - for freq in self.data.get_mean_isi_frequencies(): - delay = self.data.get_delay() - sampling_interval = self.data.get_sampling_interval() - if delay < 0.1: - warn("BaselineCellData:get_baseline_Frequency(): Quite short delay at the start.") - - idx_start = int(0.025 / sampling_interval) - idx_end = int((delay - 0.025) / sampling_interval) - base_freqs.append(np.mean(freq[idx_start:idx_end])) - - self.baseline_frequency = np.median(base_freqs) + spiketimes = self.data.get_base_spikes() + self.baseline_frequency = self._get_baseline_frequency_given_data(spiketimes) return self.baseline_frequency @@ -132,44 +167,23 @@ class BaselineCellData(Baseline): if self.vector_strength == -1: times = self.data.get_base_traces(self.data.TIME) eods = self.data.get_base_traces(self.data.EOD) - v1_traces = self.data.get_base_traces(self.data.V1) - self.vector_strength = hF.calculate_vector_strength_from_v1_trace(times, eods, v1_traces) - + spiketimes = self.data.get_base_spikes() + sampling_interval = self.data.get_sampling_interval() + self.vector_strength = self._get_vector_strength_given_data(times, eods, spiketimes, sampling_interval) return self.vector_strength def get_serial_correlation(self, max_lag): if len(self.serial_correlation) != max_lag: - serial_cors = [] - for spiketimes in self.data.get_base_spikes(): - sc = hF.calculate_serial_correlation(spiketimes, max_lag) - serial_cors.append(sc) - serial_cors = np.array(serial_cors) - mean_sc = np.mean(serial_cors, axis=0) - - self.serial_correlation = mean_sc + self.serial_correlation = self._get_serial_correlation_given_data(max_lag, self.data.get_base_spikes()) return self.serial_correlation def get_coefficient_of_variation(self): if self.coefficient_of_variation == -1: - spiketimes = self.data.get_base_spikes() - # CV (stddev of ISI divided by mean ISI (np.diff(spiketimes)) - cvs = [] - for st in spiketimes: - st = np.array(st) - cvs.append(hF.calculate_coefficient_of_variation(st)) - - self.coefficient_of_variation = np.mean(cvs) + self.coefficient_of_variation = self._get_coefficient_of_variation_given_data(self.data.get_base_spikes()) return self.coefficient_of_variation def get_interspike_intervals(self): - spiketimes = self.data.get_base_spikes() - # CV (stddev of ISI divided by mean ISI (np.diff(spiketimes)) - isis = [] - for st in spiketimes: - st = np.array(st) - isis.extend(np.diff(st)) - - return isis + return self._get_interspike_intervals_given_data(self.data.get_base_spikes()) def plot_baseline(self, save_path=None, time_length=0.2): # eod, v1, spiketimes, frequency @@ -179,51 +193,59 @@ class BaselineCellData(Baseline): v1_trace = self.data.get_base_traces(self.data.V1)[0] spiketimes = self.data.get_base_spikes()[0] - self.plot_baseline_given_data(time, eod, v1_trace, spiketimes, - self.data.get_sampling_interval(), save_path, time_length) + self._plot_baseline_given_data(time, eod, v1_trace, spiketimes, + self.data.get_sampling_interval(), save_path, time_length) class BaselineModel(Baseline): simulation_time = 30 - def __init__(self, model: LifacNoiseModel, eod_frequency): + def __init__(self, model: LifacNoiseModel, eod_frequency, trials=1): super().__init__() self.model = model self.eod_frequency = eod_frequency self.stimulus = SinusoidalStepStimulus(eod_frequency, 0) - self.v1, self.spiketimes = model.simulate_fast(self.stimulus, self.simulation_time) self.eod = self.stimulus.as_array(0, self.simulation_time, model.get_sampling_interval()) self.time = np.arange(0, self.simulation_time, model.get_sampling_interval()) + self.v1_traces = [] + self.spiketimes = [] + for i in range(trials): + v, st = model.simulate_fast(self.stimulus, self.simulation_time) + self.v1_traces.append(v) + self.spiketimes.append(st) + def get_baseline_frequency(self): if self.baseline_frequency == -1: - self.baseline_frequency = hF.calculate_mean_isi_freq(self.spiketimes) - + self.baseline_frequency = self._get_baseline_frequency_given_data(self.spiketimes) return self.baseline_frequency def get_vector_strength(self): if self.vector_strength == -1: - self.vector_strength = hF.calculate_vector_strength_from_spiketimes(self.time, self.eod, self.spiketimes, - self.model.get_sampling_interval()) + times = [self.time] * len(self.spiketimes) + eods = [self.eod] * len(self.spiketimes) + sampling_interval = self.model.get_sampling_interval() + self.vector_strength = self._get_vector_strength_given_data(times, eods, self.spiketimes, sampling_interval) + return self.vector_strength def get_serial_correlation(self, max_lag): if len(self.serial_correlation) != max_lag: - self.serial_correlation = hF.calculate_serial_correlation(self.spiketimes, max_lag) + self.serial_correlation = self._get_serial_correlation_given_data(max_lag, self.spiketimes) return self.serial_correlation def get_coefficient_of_variation(self): if self.coefficient_of_variation == -1: - self.coefficient_of_variation = hF.calculate_coefficient_of_variation(self.spiketimes) + self.coefficient_of_variation = self._get_coefficient_of_variation_given_data(self.spiketimes) return self.coefficient_of_variation def get_interspike_intervals(self): - return np.diff(self.spiketimes) + return self._get_interspike_intervals_given_data(self.spiketimes) def plot_baseline(self, save_path=None, time_length=0.2): - self.plot_baseline_given_data(self.time, self.eod, self.v1, self.spiketimes, - self.model.get_sampling_interval(), save_path, time_length) + self._plot_baseline_given_data(self.time, self.eod, self.v1_traces[0], self.spiketimes[0], + self.model.get_sampling_interval(), save_path, time_length) def get_baseline_class(data, eod_freq=None) -> Baseline: diff --git a/CellData.py b/CellData.py index 837ecd2..6538714 100644 --- a/CellData.py +++ b/CellData.py @@ -58,8 +58,14 @@ class CellData: def get_base_spikes(self): if self.base_spikes is None: - self.base_spikes = self.parser.get_baseline_spiketimes() + times = self.get_base_traces(self.TIME) + eods = self.get_base_traces(self.EOD) + v1_traces = self.get_base_traces(self.V1) + spiketimes = [] + for i in range(len(times)): + spiketimes.append(hf.detect_spiketimes(times[i], v1_traces[i])) + self.base_spikes = spiketimes return self.base_spikes def get_base_isis(self): diff --git a/FiCurve.py b/FiCurve.py index 9100ff2..54b5664 100644 --- a/FiCurve.py +++ b/FiCurve.py @@ -82,7 +82,7 @@ class FICurve: else: return x_values[intersection_indicies[0]] - def get_fi_curve_slope_at_f_zero_intersection(self): + def get_f_zero_fit_slope_at_f_inf_fit_intersection(self): x = self.get_f_zero_and_f_inf_intersection() fit_vars = self.f_zero_fit return fu.derivative_full_boltzmann(x, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3]) @@ -321,9 +321,11 @@ class FICurveCellData(FICurve): class FICurveModel(FICurve): - def __init__(self, model, stimulus_values, eod_frequency): + def __init__(self, model, stimulus_values, eod_frequency, trials=1): self.eod_frequency = eod_frequency self.model = model + self.trials = trials + self.spiketimes = [] super().__init__(stimulus_values) def calculate_all_frequency_points(self): @@ -338,10 +340,18 @@ class FICurveModel(FICurve): for c in self.stimulus_values: stimulus = SinusoidalStepStimulus(self.eod_frequency, c, stim_start, stim_duration) - _, spiketimes = self.model.simulate_fast(stimulus, total_simulation_time) - time, frequency = hF.calculate_time_and_frequency_trace(spiketimes, sampling_interval) - if len(spiketimes) < 10 or len(time) == 0 or min(time) > stim_start \ + frequency_traces = [] + time_traces = [] + for i in range(self.trials): + _, spiketimes = self.model.simulate_fast(stimulus, total_simulation_time) + trial_time, trial_frequency = hF.calculate_time_and_frequency_trace(spiketimes, sampling_interval) + frequency_traces.append(trial_frequency) + time_traces.append(trial_time) + + time, frequency = hF.calculate_mean_of_frequency_traces(time_traces, frequency_traces, sampling_interval) + + if len(time) == 0 or min(time) > stim_start \ or max(time) < stim_start + stim_duration: print("Too few spikes to calculate f_inf, f_0 and f_base") self.f_inf_frequencies.append(0) @@ -358,6 +368,7 @@ class FICurveModel(FICurve): f_baseline = hF.detect_f_baseline_in_freq_trace(time, frequency, stim_start, sampling_interval) self.f_baseline_frequencies.append(f_baseline) + def plot_f_point_detections(self, save_path=None): raise NotImplementedError("TODO sorry... " "The model version of the FiCurve class is still missing this implementation") diff --git a/Fitter.py b/Fitter.py index 3959fe5..90776b6 100644 --- a/Fitter.py +++ b/Fitter.py @@ -5,8 +5,6 @@ from CellData import CellData, icelldata_of_dir from Baseline import get_baseline_class from FiCurve import get_fi_curve_class from AdaptionCurrent import Adaption -import helperFunctions as hF -import functions as fu import numpy as np from warnings import warn from scipy.optimize import minimize @@ -70,8 +68,6 @@ def run_with_real_data(): start_par_count += 1 print("START PARAMETERS:", start_par_count) - parameter_set_path = results_path + "start_parameter_set_{}".format(start_par_count) + "/" - start_time = time.time() fitter = Fitter() fmin, parameters = fitter.fit_model_to_data(cell_data, start_parameters) @@ -79,6 +75,7 @@ def run_with_real_data(): print(fmin) print(parameters) end_time = time.time() + parameter_set_path = results_path + "start_par_set_{}_fmin_{:.2f}".format(start_par_count, fmin["fun"]) + "/" if not os.path.exists(parameter_set_path): os.makedirs(parameter_set_path) with open(parameter_set_path + "parameters_info.txt".format(start_par_count), "w") as file: @@ -178,7 +175,7 @@ class Fitter: self.f_inf_slope = 0 self.f_zero_values = [] - self.f_zero_slope = 0 + self.f_zero_slopes = [] self.f_zero_fit = [] self.tau_a = 0 @@ -203,12 +200,13 @@ class Fitter: self.f_zero_values = fi_curve.f_zero_frequencies self.f_zero_fit = fi_curve.f_zero_fit - self.f_zero_slope = fi_curve.get_f_zero_fit_slope_at_straight() + self.f_zero_slopes = [fi_curve.get_f_zero_fit_slope_at_stimulus_value(c) for c in self.fi_contrasts] # around 1/3 of the value at straight # self.f_zero_slope = fi_curve.get_fi_curve_slope_at(fi_curve.get_f_zero_and_f_inf_intersection()) - self.delta_a = (self.f_zero_slope / self.f_inf_slope) / 1000 # seems to work if divided by 1000... + # seems to work if divided by 1000... + self.delta_a = (fi_curve.get_f_zero_fit_slope_at_straight() / self.f_inf_slope) / 1000 adaption = Adaption(data, fi_curve) self.tau_a = adaption.get_tau_real() @@ -218,7 +216,6 @@ class Fitter: return self.fit_routine_5(data, start_parameters) def fit_routine_5(self, cell_data=None, start_parameters=None): - # [error_bf, error_vs, error_sc, error_cv, error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope] self.counter = 0 # fit only v_offset, mem_tau, input_scaling, dend_tau if start_parameters is None: @@ -227,7 +224,10 @@ class Fitter: x0 = np.array([start_parameters["mem_tau"], start_parameters["noise_strength"], start_parameters["input_scaling"], self.tau_a, self.delta_a, start_parameters["dend_tau"]]) initial_simplex = create_init_simples(x0, search_scale=2) - error_weights = (0, 2, 5, 1, 1, 1, 0.5, 1) + + # error_list = [error_bf, error_vs, error_sc, error_cv, + # error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope] + error_weights = (0, 1, 1, 1, 1, 1, 1, 1) fmin = minimize(fun=self.cost_function_all, args=(error_weights,), x0=x0, method="Nelder-Mead", options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 400, "maxiter": 400}) @@ -386,22 +386,13 @@ class Fitter: coefficient_of_variation = model_baseline.get_coefficient_of_variation() # f_infinities, f_infinities_slope = self.base_model.calculate_fi_markers(self.fi_contrasts, self.eod_freq) - f_baselines, f_zeros, f_infinities = self.base_model.calculate_fi_curve(self.fi_contrasts, self.eod_freq) - try: - f_infinities_fit = hF.fit_clipped_line(self.fi_contrasts, f_infinities) - except Exception as e: - print("EXCEPTION IN FIT LINE!") - print(e) - f_infinities_fit = [0, 0] - - f_infinities_slope = f_infinities_fit[0] - try: - f_zeros_fit = hF.fit_boltzmann(self.fi_contrasts, f_zeros) - except Exception as e: - print("EXCEPTION IN FIT BOLTZMANN!") - print(e) - f_zeros_fit = [0, 0, 0, 0] - f_zero_slope = fu.full_boltzmann_straight_slope(f_zeros_fit[0], f_zeros_fit[1], f_zeros_fit[2], f_zeros_fit[3]) + + fi_curve_model = get_fi_curve_class(self.base_model, self.fi_contrasts, self.eod_freq) + f_zeros = fi_curve_model.get_f_zero_frequencies() + f_infinities = fi_curve_model.get_f_inf_frequencies() + f_infinities_slope = fi_curve_model.get_f_inf_slope() + f_zero_slopes = [fi_curve_model.get_f_zero_fit_slope_at_stimulus_value(x) for x in self.fi_contrasts] + # print("fi-curve features calculated!") # calculate errors with reference values error_bf = abs((baseline_freq - self.baseline_freq) / self.baseline_freq) @@ -413,20 +404,20 @@ class Fitter: error_sc = abs((serial_correlation[i] - self.serial_correlation[i]) / self.serial_correlation[i]) error_sc = error_sc / self.sc_max_lag - error_f_inf_slope = abs((f_infinities_slope - self.f_inf_slope) / self.f_inf_slope) * 4 - error_f_inf = calculate_f_values_error(f_infinities, self.f_inf_values) * .5 + error_f_inf_slope = abs((f_infinities_slope - self.f_inf_slope) / self.f_inf_slope) + error_f_inf = calculate_list_error(f_infinities, self.f_inf_values) - error_f_zero_slope = abs((f_zero_slope - self.f_zero_slope) / self.f_zero_slope) - error_f_zero = calculate_f_values_error(f_zeros, self.f_zero_values) + error_f_zero_slopes = calculate_list_error(f_zero_slopes, self.f_zero_slopes) + error_f_zero = calculate_list_error(f_zeros, self.f_zero_values) error_list = [error_bf, error_vs, error_sc, error_cv, - error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope] + error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slopes] if error_weights is not None and len(error_weights) == len(error_list): for i in range(len(error_weights)): error_list[i] = error_list[i] * error_weights[i] - elif len(error_weights) != len(error_list): + if len(error_weights) != len(error_list): warn("Error weights had different length than errors and were ignored!") error = sum(error_list) @@ -446,14 +437,14 @@ class Fitter: self.f_inf_slope, f_infinities_slope, error_f_inf_slope), "f-infinity values:\nexpected:", np.around(self.f_inf_values), "\ncurrent: ", np.around(f_infinities), "\nerror: {:.3f}\n".format(error_f_inf), - "f-zero slope - expected: {:.0f}, current: {:.0f}, error: {:.3f}\n".format( - self.f_zero_slope, f_zero_slope, error_f_zero_slope), + "f-zero slopes:\nexpected:", np.around(self.f_zero_slopes), "\ncurrent: ", np.around(f_zero_slopes), + "\nerror: {:.3f}".format(error_f_zero_slopes), "f-zero values:\nexpected:", np.around(self.f_zero_values), "\ncurrent: ", np.around(f_zeros), "\nerror: {:.3f}".format(error_f_zero)) return error_list -def calculate_f_values_error(fit, reference): +def calculate_list_error(fit, reference): error = 0 for i in range(len(reference)): # TODO ??? add a constant to f_inf to allow for small differences in small values diff --git a/helperFunctions.py b/helperFunctions.py index 406c195..65b404a 100644 --- a/helperFunctions.py +++ b/helperFunctions.py @@ -14,6 +14,9 @@ def fit_clipped_line(x, y): def fit_boltzmann(x, y): max_f0 = float(max(y)) + if max_f0 == 0: + return [0, 0, 0, 0] + min_f0 = 0.1 # float(min(self.f_zeros)) mean_int = float(np.mean(x)) @@ -21,10 +24,15 @@ def fit_boltzmann(x, y): total_change_int = max(x) - min(x) start_k = float((total_increase / total_change_int * 4) / max_f0) - popt, pcov = curve_fit(fu.full_boltzmann, x, y, - p0=(max_f0, min_f0, start_k, mean_int), - maxfev=10000, bounds=([0, 0, -np.inf, -np.inf], [np.inf, np.inf, np.inf, np.inf])) - + try: + popt, pcov = curve_fit(fu.full_boltzmann, x, y, + p0=(max_f0, min_f0, start_k, mean_int), + maxfev=10000, bounds=([0, 0, -np.inf, -np.inf], [np.inf, np.inf, np.inf, np.inf])) + except RuntimeError as e: + print("Error in fit boltzmann: ", str(e)) + print("x_values:", x) + print("y_values:", y) + return [0, 0, 0, 0] return popt @@ -248,7 +256,9 @@ def calculate_coefficient_of_variation(spiketimes: np.ndarray) -> float: def calculate_serial_correlation(spiketimes: np.ndarray, max_lag: int) -> np.ndarray: isi = np.diff(spiketimes) if len(spiketimes) < max_lag + 1: - raise ValueError("Given list to short, with given max_lag") + warn("Cannot compute serial correlation with list shorter than max lag...") + return np.zeros(max_lag) + # raise ValueError("Given list to short, with given max_lag") cor = np.zeros(max_lag) for lag in range(max_lag): @@ -286,7 +296,7 @@ def calculate_vector_strength_from_v1_trace(times, eods, v1_traces): print("-----LENGTH OF TIMES = 0") for recording in range(len(times)): - spiketime_idices = detect_spikes(v1_traces[recording]) + spiketime_idices = detect_spikes_indices(v1_traces[recording]) rel_spikes, eod_durs = eods_around_spikes(times[recording], eods[recording], spiketime_idices) relative_spike_times.extend(rel_spikes) eod_durations.extend(eod_durs) @@ -305,7 +315,7 @@ def calculate_vector_strength_from_spiketimes(time, eod, spiketimes, sampling_in return __vector_strength__(rel_spikes, eod_durs) -def detect_spikes(v1, split=20, threshold=3): +def detect_spikes_indices(v1, split=20, threshold=3): total = len(v1) all_peaks = [] @@ -323,6 +333,12 @@ def detect_spikes(v1, split=20, threshold=3): return all_peaks +def detect_spiketimes(time, v1, split=20, threshold=3): + all_peak_indicies = detect_spikes_indices(v1, split, threshold) + + return [time[p_idx] for p_idx in all_peak_indicies] + + def calculate_phases(relative_spike_times, eod_durations): phase_times = np.zeros(len(relative_spike_times)) @@ -455,6 +471,12 @@ def detect_f_zero_in_frequency_trace(time, frequency, stimulus_start, sampling_i # plt.plot(time[start_idx:end_idx], [f_zero for i in range(end_idx-start_idx)]) # plt.show() + max_frequency = int(1/sampling_interval) + int_f_zero = int(f_zero) + if int_f_zero > max_frequency: + raise AssertionError("Detection of f-zero went very wrong! frequency above 1/sampling_interval.") + if int_f_zero > max(frequency): + raise AssertionError("detected f_zero bigger than the highest peak in the frequency trace...") return f_zero diff --git a/introduction/test.py b/introduction/test.py index f3edf7b..27e19ec 100644 --- a/introduction/test.py +++ b/introduction/test.py @@ -1,12 +1,21 @@ import numpy as np import matplotlib.pyplot as plt import helperFunctions as hF +import functions as fu import time from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus def main(): + x_values = np.arange(-1, 1, 0.01) + popt = [0, 0, 0, 0] + y_values = [fu.full_boltzmann(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values] + + plt.plot(x_values, y_values) + plt.show() + quit() + for freq in [700, 50, 100, 500, 1000]: reps = 1000 start = time.time() diff --git a/variableEffect.py b/variableEffect.py new file mode 100644 index 0000000..0f06fcc --- /dev/null +++ b/variableEffect.py @@ -0,0 +1,155 @@ + +from models.LIFACnoise import LifacNoiseModel +from Baseline import BaselineModel +from FiCurve import FICurveModel +import numpy as np +import matplotlib.pyplot as plt +import copy + + +SEARCH_WIDTH = 3 +SEARCH_PRECISION = 30 +CONTRASTS = np.arange(-0.4, 0.45, 0.1) + +def main(): + model_parameters = {'threshold': 1, + 'step_size': 5e-05, + 'a_zero': 2, + 'delta_a': 0.2032269898801589, + 'mem_tau': 0.011314027210564803, + 'noise_strength': 0.056724809998220195, + 'v_zero': 0, + 'v_base': 0, + 'tau_a': 0.05958195972016753, + 'input_scaling': 119.81500448274554, + 'dend_tau': 0.0027746086464721723, + 'v_offset': -24.21875} + + + + parameters_to_test = ["input_scaling", "dend_tau", "mem_tau", "noise_strength", "v_offset", "delta_a", "tau_a"] + effect_data = [] + for p in parameters_to_test: + print("Working on parameter " + p) + effect_data.append(test_parameter_effect(model_parameters, p)) + + plot_effects(effect_data, "./figures/variable_effect/") + + +def test_parameter_effect(model_parameters, test_parameter): + model_parameters = copy.deepcopy(model_parameters) + start_value = model_parameters[test_parameter] + + start = start_value*(1/SEARCH_WIDTH) + end = start_value*SEARCH_WIDTH + step = (end - start) / SEARCH_PRECISION + values = np.arange(start, end+step, step) + + bf = [] + vs = [] + sc = [] + cv = [] + + f_inf_s = [] + f_inf_v = [] + f_zero_s = [] + f_zero_v = [] + + broken_i = [] + for i in range(len(values)): + model_parameters[test_parameter] = values[i] + model = LifacNoiseModel(model_parameters) + + fi_curve = FICurveModel(model, CONTRASTS, 600, trials=50) + f_inf_s.append(fi_curve.get_f_inf_slope()) + f_inf_v.append(fi_curve.get_f_inf_frequencies()) + f_zero_s.append(fi_curve.get_f_zero_fit_slope_at_stimulus_value(0.1)) + f_zero_v.append(fi_curve.get_f_zero_frequencies()) + + baseline = BaselineModel(model, 600, trials=10) + bf.append(baseline.get_baseline_frequency()) + vs.append(baseline.get_vector_strength()) + sc.append(baseline.get_serial_correlation(2)) + cv.append(baseline.get_coefficient_of_variation()) + + values = list(values) + if len(broken_i) > 0: + broken_i = sorted(broken_i, reverse=True) + for i in broken_i: + del values[i] + + return ParameterEffectData(values, test_parameter, bf, vs, sc, cv, f_inf_s, f_inf_v, f_zero_s, f_zero_v) + # plot_effects(values, test_parameter, bf, vs, sc, cv, f_inf_s, f_inf_v, f_zero_s, f_zero_v) + + +def plot_effects(par_effect_data_list, save_path=None): + + fig, axes = plt.subplots(8, len(par_effect_data_list), figsize=(32, 4*len(par_effect_data_list)), sharex="col") + + names = ("bf", "vs", "sc", "cv", "f_inf_s", "f_inf_v", "f_zero_s", "f_zero_v") + + for j in range(len(par_effect_data_list)): + ped = par_effect_data_list[j] + + ranges = ((0, max(ped.get_data("bf")) * 1.1), (0, 1), (-1, 1), (0, 1), + (0, max(ped.get_data("f_inf_s")) * 1.1), (0, 800), + (0, max(ped.get_data("f_zero_s")) * 1.1), (0, 3000)) + values = ped.values + + for i in range(len(names)): + y_data = ped.get_data(names[i]) + axes[i, j].plot(values, y_data) + axes[i, j].set_ylim(ranges[i]) + + if j == 0: + axes[i, j].set_ylabel(names[i]) + + if i == 0: + axes[i, j].set_title(ped.test_parameter) + + plt.tight_layout() + if save_path is not None: + plt.savefig(save_path + "variable_effect_master_plot.png") + else: + plt.show() + plt.close() + + +class ParameterEffectData: + data_names = ("bf", "vs", "sc", "cv", "f_inf_s", "f_inf_v" "f_zero_s", "f_zero_v") + + def __init__(self, values, test_parameter, bf, vs, sc, cv, f_inf_s, f_inf_v, f_zero_s, f_zero_v): + self.values = values + self.test_parameter = test_parameter + self.bf = bf + self.vs = vs + self.sc = sc + self.cv = cv + self.f_inf_s = f_inf_s + self.f_inf_v = f_inf_v + self.f_zero_s = f_zero_s + self.f_zero_v = f_zero_v + + def get_data(self, name): + if name == "bf": + return self.bf + elif name == "vs": + return self.vs + elif name == "sc": + return self.sc + elif name == "cv": + return self.cv + elif name == "f_inf_s": + return self.f_inf_s + elif name == "f_inf_v": + return self.f_inf_v + elif name == "f_zero_s": + return self.f_zero_s + elif name == "f_zero_v": + return self.f_zero_v + else: + raise ValueError("Unknown attribute name!") + + +if __name__ == '__main__': + main()