diff --git a/AdaptionCurrent.py b/AdaptionCurrent.py index 97e54d2..f94b215 100644 --- a/AdaptionCurrent.py +++ b/AdaptionCurrent.py @@ -139,6 +139,9 @@ class Adaption: self.tau_real = np.median(taus) + def get_tau_real(self): + return np.median(self.tau_real) + def plot_exponential_fits(self, save_path: str = None, indices: list = None, delete_previous: bool = False): if delete_previous: for val in self.cell_data.get_fi_contrasts(): diff --git a/fit_lifacnoise.py b/fit_lifacnoise.py new file mode 100644 index 0000000..5549c52 --- /dev/null +++ b/fit_lifacnoise.py @@ -0,0 +1,226 @@ +from models.LIFACnoise import LifacNoiseModel +from CellData import CellData, icelldata_of_dir +from FiCurve import FICurve +from AdaptionCurrent import Adaption +from stimuli.SinusAmplitudeModulation import SinusAmplitudeModulationStimulus +import helperFunctions as hF +import numpy as np +from scipy.optimize import curve_fit, minimize +import functions as fu +import time +import matplotlib.pyplot as plt + +def main(): + + for celldata in icelldata_of_dir("./data/"): + start_time = time.time() + fitter = Fitter(celldata) + fmin, parameters = fitter.fit_model_to_data() + + print(fmin) + print(parameters) + end_time = time.time() + + print('Fitting of cell took function took {:.3f} s'.format((end_time - start_time))) + break + pass + + +class Fitter: + + def __init__(self, data: CellData, step_size=None): + if step_size is not None: + self.model = LifacNoiseModel({"step_size": step_size}) + else: + self.model = LifacNoiseModel({"step_size": 0.05}) + self.data = data + self.fi_contrasts = [] + self.eod_freq = 0 + + self.modulation_frequency = 10 + self.sc_max_lag = 1 + + # expected values the model has to replicate + self.baseline_freq = 0 + self.vector_strength = -1 + self.serial_correlation = [] + + self.f_infinities = [] + self.f_infinities_slope = 0 + + # fixed values needed to fit model + self.a_tau = 0 + self.a_delta = 0 + + self.counter = 0 + self.calculate_needed_values_from_data() + + def calculate_needed_values_from_data(self): + self.eod_freq = self.data.get_eod_frequency() + + self.baseline_freq = self.data.get_base_frequency() + self.vector_strength = self.data.get_vector_strength() + self.serial_correlation = self.data.get_serial_correlation(self.sc_max_lag) + + fi_curve = FICurve(self.data, contrast=True) + self.fi_contrasts = fi_curve.stimulus_value + self.f_infinities = fi_curve.f_infinities + self.f_infinities_slope = fi_curve.get_f_infinity_slope() + + f_zero_slope = fi_curve.get_fi_curve_slope_of_straight() + self.a_delta = f_zero_slope / self.f_infinities_slope + + adaption = Adaption(self.data, fi_curve) + self.a_tau = adaption.get_tau_real() + + # mem_tau, (threshold?), (v_offset), noise_strength, input_scaling + def cost_function(self, X, tau_a=10, delta_a=3, error_scaling=()): + # set model parameters to the given ones: + self.model.set_variable("mem_tau", X[0]) + self.model.set_variable("noise_strength", X[1]) + self.model.set_variable("input_scaling", X[2]) + self.model.set_variable("tau_a", tau_a) + self.model.set_variable("delta_a", delta_a) + + # minimize the difference in baseline_freq first by fitting v_offset + v_offset = self.__fit_v_offset_to_baseline_frequency__() + self.model.set_variable("v_offset", v_offset) + + # only eod with amplitude 1 and no modulation + base_stimulus = SinusAmplitudeModulationStimulus(self.eod_freq, 0, 0) + _, spiketimes = self.model.simulate_fast(base_stimulus, 30) + + baseline_freq = hF.mean_freq_of_spiketimes_after_time_x(spiketimes, 5) + # print("model:", baseline_freq, "data:", self.baseline_freq) + + relative_spiketimes = np.array([s % (1/self.eod_freq) for s in spiketimes]) + eod_durations = np.full((len(spiketimes)), 1/self.eod_freq) + vector_strength = hF.__vector_strength__(relative_spiketimes, eod_durations) + serial_correlation = hF.calculate_serial_correlation(np.array(spiketimes), self.sc_max_lag) + + f_infinities = [] + for contrast in self.fi_contrasts: + stimulus = SinusAmplitudeModulationStimulus(self.eod_freq, contrast, self.modulation_frequency) + _, spiketimes = self.model.simulate_fast(stimulus, 0.5) + + if len(spiketimes) < 2: + f_infinities.append(0) + else: + f_infinity = hF.mean_freq_of_spiketimes_after_time_x(spiketimes, 0.4) + f_infinities.append(f_infinity) + + popt, pcov = curve_fit(fu.line, self.fi_contrasts, f_infinities, maxfev=10000) + + f_infinities_slope = popt[0] + + error_bf = abs((baseline_freq - self.baseline_freq) / self.baseline_freq) + error_vs = abs((vector_strength - self.vector_strength) / self.vector_strength) + error_sc = abs((serial_correlation[0] - self.serial_correlation[0]) / self.serial_correlation[0]) + error_f_inf_slope = abs((f_infinities_slope - self.f_infinities_slope) / self.f_infinities_slope) + #print("vs:", vector_strength, self.vector_strength) + #print("sc", serial_correlation[0], self.serial_correlation[0]) + #print("f slope:", f_infinities_slope, self.f_infinities_slope) + error_f_inf = 0 + for i in range(len(f_infinities)): + error_f_inf += abs((f_infinities[i] - self.f_infinities[i]) / f_infinities[i]) + + error_f_inf = error_f_inf / len(f_infinities) + self.counter += 1 + # print("mem_tau:", X[0], "noise:", X[0], "input_scaling:", X[2]) + print("Cost function run times:", self.counter, "errors:", [error_bf, error_vs, error_sc, error_f_inf_slope, error_f_inf]) + return error_bf + error_vs + error_sc + error_f_inf_slope + error_f_inf + + def __fit_v_offset_to_baseline_frequency__(self): + test_model = self.model.get_model_copy() + voltage_step_size = 1000 + simulation_time = 2 + v_offset_start = 0 + v_offset_current = v_offset_start + + test_model.set_variable("v_offset", v_offset_current) + base_stimulus = SinusAmplitudeModulationStimulus(self.eod_freq, 0, 0) + _, spiketimes = test_model.simulate_fast(base_stimulus, simulation_time) + if len(spiketimes) < 5: + baseline_freq = 0 + else: + baseline_freq = hF.mean_freq_of_spiketimes_after_time_x(spiketimes, simulation_time/2) + + if baseline_freq < self.baseline_freq: + upwards = True + v_offset_current += voltage_step_size + else: + upwards = False + v_offset_current -= voltage_step_size + + # search for a value below and above the baseline freq: + while True: + # print(self.counter, baseline_freq, self.baseline_freq, v_offset_current) + # self.counter += 1 + test_model.set_variable("v_offset", v_offset_current) + base_stimulus = SinusAmplitudeModulationStimulus(self.eod_freq, 0, 0) + _, spiketimes = test_model.simulate_fast(base_stimulus, simulation_time) + + if len(spiketimes) < 2: + baseline_freq = 0 + else: + baseline_freq = hF.mean_freq_of_spiketimes_after_time_x(spiketimes, simulation_time/2) + + if baseline_freq < self.baseline_freq and upwards: + v_offset_current += voltage_step_size + + elif baseline_freq < self.baseline_freq and not upwards: + break + + elif baseline_freq > self.baseline_freq and upwards: + break + + elif baseline_freq > self.baseline_freq and not upwards: + v_offset_current -= voltage_step_size + + elif baseline_freq == self.baseline_freq: + return v_offset_current + + # found the edges use them to allow binary search: + if upwards: + lower_bound = v_offset_current - voltage_step_size + upper_bound = v_offset_current + else: + lower_bound = v_offset_current + upper_bound = v_offset_current + voltage_step_size + + while True: + middle = lower_bound + (upper_bound - lower_bound)/2 + # print(self.counter, "measured_freq:", baseline_freq, "wanted_freq:", self.baseline_freq, "current middle:", middle) + # self.counter += 1 + test_model.set_variable("v_offset", middle) + base_stimulus = SinusAmplitudeModulationStimulus(self.eod_freq, 0, 0) + _, spiketimes = test_model.simulate_fast(base_stimulus, simulation_time) + + if len(spiketimes) < 2: + baseline_freq = 0 + else: + baseline_freq = hF.mean_freq_of_spiketimes_after_time_x(spiketimes, simulation_time/2) + + if abs(baseline_freq - self.baseline_freq) < 1: + # print("close enough:", baseline_freq, self.baseline_freq, abs(baseline_freq - self.baseline_freq)) + break + elif baseline_freq < self.baseline_freq: + lower_bound = middle + else: + upper_bound = middle + + return middle + + def fit_model_to_data(self): + x0 = np.array([20, 15, 75]) + init_simplex = np.array([np.array([2, 1, 10]), np.array([40, 100, 140]), np.array([20, 50, 70]), np.array([150, 1, 200])]) + fmin = minimize(fun=self.cost_function, x0=x0, args=(self.a_tau, self.a_delta), method="Nelder-Mead", options={"initial_simplex": init_simplex}) + + + #fmin = minimize(fun=self.cost_function, x0=x0, args=(self.a_tau, self.a_delta), method="BFGS") + + return fmin, self.model.get_parameters() + + +if __name__ == '__main__': + main() diff --git a/generalTests.py b/generalTests.py index 3f1aa74..79a2453 100644 --- a/generalTests.py +++ b/generalTests.py @@ -1,4 +1,100 @@ -g = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] -print(g[3:]) -print(g[:-3]) + +import helperFunctions as hf +from CellData import icelldata_of_dir +import functions as fu +import numpy as np +import time +import matplotlib.pyplot as plt +import os +from scipy.signal import argrelmax +from thunderfish.eventdetection import detect_peaks +from stimuli.SinusAmplitudeModulation import SinusAmplitudeModulationStimulus +from models.LIFACnoise import LifacNoiseModel + +def time_test_function(): + for n in [1000]: # number of calls + print("calls:", n) + start = time.time() + for i in range(n): + data = np.random.normal(size=10000) + y = [fu.rectify(x) for x in data] + end = time.time() + print("time:", end - start) + + +def test_cell_data(): + for cell_data in icelldata_of_dir("./data/"): + #if "2012-12-20-ad" not in cell_data.get_data_path(): + # continue + print() + print(cell_data.get_data_path()) + if len(cell_data.get_base_traces(cell_data.TIME)) != 0: + # print("works!") + #print("VS:", cell_data.get_vector_strength()) + #print("SC:", cell_data.get_serial_correlation(5)) + print("Eod freq:", cell_data.get_eod_frequency()) + else: + pass + #print("NNNOOOOOOOOOO!") + #print("spiketimes:", len(cell_data.get_base_spikes())) + #print("Times:", len(cell_data.get_base_traces(cell_data.TIME))) + #print("EOD:", len(cell_data.get_base_traces(cell_data.EOD))) + + +def test_peak_detection(): + for cell_data in icelldata_of_dir("./data/"): + print() + print(cell_data.get_data_path()) + times = cell_data.get_base_traces(cell_data.TIME) + eod = cell_data.get_base_traces(cell_data.EOD) + v1 = cell_data.get_base_traces(cell_data.V1) + for i in range(len(v1)): + pieces = 20 + v1_trace = v1[i] + total = len(v1_trace) + all_peaks = [] + plt.plot(times[i], v1[i]) + + for n in range(pieces): + length = int(total/pieces) + first_index = n*length + last_index = (n+1)*length + std = np.std(v1_trace[first_index:last_index]) + peaks, _ = detect_peaks(v1_trace[first_index:last_index], std * 3) + peaks = peaks + first_index + all_peaks.extend(peaks) + plt.plot(times[i][first_index], v1_trace[first_index], 'o', color="green") + + all_peaks = np.array(all_peaks) + + plt.plot(times[i][all_peaks], v1[i][all_peaks], 'o', color='red') + + plt.show() + + +def test_simulation_speed(): + parameters = {'mem_tau': 21.348990483539083, 'delta_a': 20.41809814660199, 'input_scaling': 3.0391541280864196, 'v_offset': 26.25, 'threshold': 1, 'v_base': 0, 'step_size': 0.01, 'tau_a': 158.0404259501454, 'a_zero': 0, 'v_zero': 0, 'noise_strength': 2.87718460648148} + model = LifacNoiseModel(parameters) + repetitions = 30 + seconds = 10 + stimulus = SinusAmplitudeModulationStimulus(750, 0.3, 10) + t_start = time.time() + for i in range(repetitions): + v, spikes = model.simulate_fast(stimulus, seconds) + + plt.plot(v) + plt.show() + + t_end = time.time() + + + print("took:", round((t_end-t_start)/repetitions, 2), "seconds for " + str(seconds) + "s simulation", "step size:", parameters["step_size"]) + + +if __name__ == '__main__': + # time_test_function() + # test_cell_data() + # test_peak_detection() + test_simulation_speed() + pass diff --git a/helperFunctions.py b/helperFunctions.py index 79d09cc..bef6538 100644 --- a/helperFunctions.py +++ b/helperFunctions.py @@ -1,11 +1,6 @@ -import os -import pyrelacs.DataLoader as dl import numpy as np -import matplotlib.pyplot as plt from warnings import warn -import scipy.stats -from numba import jit -import numba as numba +from thunderfish.eventdetection import detect_peaks, threshold_crossing_times, threshold_crossings def merge_similar_intensities(intensities, spiketimes, trans_amplitudes): @@ -123,6 +118,27 @@ def calculate_mean_frequency(trial_times, trial_freqs): return time, mean_freq +def mean_freq_of_spiketimes_after_time_x(spiketimes, time_x): + """ Calculates the mean frequency of the portion of spiketimes that is after last_x_time """ + + idx = -1 + if time_x < spiketimes[int(len(spiketimes)/2)]: + for i in range(len(spiketimes)): + if spiketimes[i] > time_x: + idx = i + 1 + break + else: + for i in range(len(spiketimes) - 1, -1, -1): + if spiketimes[i] < time_x: + idx = i + 1 + break + + all_isi = np.diff(spiketimes[idx:]) / 1000 + if len(all_isi) < 5: + return 0 + mean_freq = np.mean([1 / isi for isi in all_isi]) + return mean_freq + # @jit(nopython=True) # only faster at around 30 000 calls def calculate_coefficient_of_variation(spiketimes: np.ndarray) -> float: # CV (stddev of ISI divided by mean ISI (np.diff(spiketimes)) @@ -150,12 +166,132 @@ def calculate_serial_correlation(spiketimes: np.ndarray, max_lag: int) -> np.nda return cor +def calculate_eod_frequency(time, eod): + + up_indicies, down_indicies = threshold_crossings(eod, 0) + up_times, down_times = threshold_crossing_times(time, eod, 0, up_indicies, down_indicies) + + durations = np.diff(up_times) + mean_duration = np.mean(durations) + + return 1/mean_duration + + +def calculate_vector_strength(times, eods, v1_traces): + # Vectorstaerke (use EOD frequency from header (metadata)) VS > 0.8 + # dl.iload_traces(repro='BaselineActivity') + + relative_spike_times = [] + eod_durations = [] + + if len(times) == 0: + print("-----LENGTH OF TIMES = 0") + + for recording in range(len(times)): + spiketime_idices = detect_spikes(v1_traces[recording]) + rel_spikes, eod_durs = eods_around_spikes(times[recording], eods[recording], spiketime_idices) + relative_spike_times.extend(rel_spikes) + eod_durations.extend(eod_durs) + print(__vector_strength__(np.array(rel_spikes), np.array(eod_durs))) + + relative_spike_times = np.array(relative_spike_times) + eod_durations = np.array(eod_durations) + + return __vector_strength__(relative_spike_times, eod_durations) + + +def detect_spikes(v1, split=20, threshold=3): + total = len(v1) + all_peaks = [] + + for n in range(split): + length = int(total / split) + first_index = n * length + last_index = (n + 1) * length + std = np.std(v1[first_index:last_index]) + peaks, _ = detect_peaks(v1[first_index:last_index], std * threshold) + peaks = peaks + first_index + all_peaks.extend(peaks) + + all_peaks = np.array(all_peaks) + + return all_peaks + + +def calculate_phases(relative_spike_times, eod_durations): + phase_times = np.zeros(len(relative_spike_times)) + + for i in range(len(relative_spike_times)): + phase_times[i] = (relative_spike_times[i] / eod_durations[i]) * 2 * np.pi + + return phase_times + + +def eods_around_spikes(time, eod, spiketime_idices): + eod_durations = [] + relative_spike_times = [] + + for spike_idx in spiketime_idices: + + start_time, end_time = search_eod_start_and_end_times(time, eod, spike_idx) + + eod_durations.append(end_time-start_time) + spiketime = time[spike_idx] + relative_spike_times.append(spiketime - start_time) + + return relative_spike_times, eod_durations + + +def search_eod_start_and_end_times(time, eod, index): + # TODO might break if a spike is in the cut off first or last eod! + + # search start_time: + previous = index + working_idx = index-1 + while True: + if eod[working_idx] < 0 < eod[previous]: + first_value = eod[working_idx] + second_value = eod[previous] + + dif = second_value - first_value + part = np.abs(first_value/dif) + + time_dif = np.abs(time[previous] - time[working_idx]) + start_time = time[working_idx] + time_dif*part + + break + + previous = working_idx + working_idx -= 1 + + # search end_time + previous = index + working_idx = index + 1 + while True: + if eod[previous] < 0 < eod[working_idx]: + first_value = eod[previous] + second_value = eod[working_idx] + + dif = second_value - first_value + part = np.abs(first_value / dif) + + time_dif = np.abs(time[previous] - time[working_idx]) + end_time = time[working_idx] + time_dif * part + + break + + previous = working_idx + working_idx += 1 + + return start_time, end_time + + def __vector_strength__(relative_spike_times: np.ndarray, eod_durations: np.ndarray): # adapted from Ramona n = len(relative_spike_times) if n == 0: - return 0 + return -1 phase_times = (relative_spike_times / eod_durations) * 2 * np.pi vs = np.sqrt((1 / n * np.sum(np.cos(phase_times))) ** 2 + (1 / n * np.sum(np.sin(phase_times))) ** 2) diff --git a/introduction/introductionBaseline.py b/introduction/introductionBaseline.py index 6cbb14a..76b944b 100644 --- a/introduction/introductionBaseline.py +++ b/introduction/introductionBaseline.py @@ -106,7 +106,7 @@ def calculate_vector_strength(times, eods, spiketimes, v1s): relative_spike_times.extend(rel_spikes) eod_durations.extend(eod_durs) - vs = __vector_strength__(rel_spikes, eod_durs) + vs = __vector_strength__(np.array(rel_spikes), eod_durs) phases = calculate_phases(rel_spikes, eod_durs) plot_polar(phases, "test_phase_locking_" + str(recording) + "_with_vs:" + str(round(vs, 3)) + ".png") @@ -114,7 +114,7 @@ def calculate_vector_strength(times, eods, spiketimes, v1s): plot_phaselocking_testfigures(times[recording], eods[recording], spiketimes[recording], v1s[recording]) - return __vector_strength__(relative_spike_times, eod_durations) + return __vector_strength__(np.array(relative_spike_times), eod_durations) def eods_around_spikes(time, eod, spiketimes): diff --git a/introduction/test_minimize.py b/introduction/test_minimize.py index e2fc298..00e85db 100644 --- a/introduction/test_minimize.py +++ b/introduction/test_minimize.py @@ -6,7 +6,7 @@ import numpy as np def main(): guess = np.zeros(3) fmin = minimize(fun=cost1, x0=guess, args=(3, 20, -25), method="Nelder-Mead") - + print(np.mean(fmin["final_simplex"][0], axis=0)) print(fmin) diff --git a/models/AbstractModel.py b/models/AbstractModel.py index dc601a2..9bbe623 100644 --- a/models/AbstractModel.py +++ b/models/AbstractModel.py @@ -67,7 +67,7 @@ class AbstractModel: for k in params.keys(): self.parameters[k] = params[k] - for key in range(len(self.DEFAULT_VALUES.keys())): + for key in self.DEFAULT_VALUES.keys(): if key not in self.parameters.keys(): self.parameters[key] = self.DEFAULT_VALUES[key] diff --git a/models/LIFACnoise.py b/models/LIFACnoise.py index ba62546..906f1c6 100644 --- a/models/LIFACnoise.py +++ b/models/LIFACnoise.py @@ -3,6 +3,8 @@ from stimuli.AbstractStimulus import AbstractStimulus from models.AbstractModel import AbstractModel import numpy as np import functions as fu +from numba import jit +import time class LifacNoiseModel(AbstractModel): @@ -30,6 +32,7 @@ class LifacNoiseModel(AbstractModel): # self.frequency_trace = [] def simulate(self, stimulus: AbstractStimulus, total_time_s): + self.stimulus = stimulus output_voltage = [] adaption = [] @@ -61,6 +64,31 @@ class LifacNoiseModel(AbstractModel): return output_voltage, spiketimes + def simulate_fast(self, stimulus: AbstractStimulus, total_time_s): + + v_zero = self.parameters["v_zero"] + a_zero = self.parameters["a_zero"] + step_size = self.parameters["step_size"] + threshold = self.parameters["threshold"] + v_base = self.parameters["v_base"] + delta_a = self.parameters["delta_a"] + tau_a = self.parameters["tau_a"] + v_offset = self.parameters["v_offset"] + mem_tau = self.parameters["mem_tau"] + noise_strength = self.parameters["noise_strength"] + + stimulus_array = stimulus.as_array(total_time_s, step_size) + + parameters = np.array([v_zero, a_zero, step_size, threshold, v_base, delta_a, tau_a, v_offset, mem_tau, noise_strength]) + voltage_trace, adaption, spiketimes = simulate_fast(stimulus_array, total_time_s, parameters) + + self.stimulus = stimulus + self.voltage_trace = voltage_trace + self.adaption_trace = adaption + self.spiketimes = spiketimes + + return voltage_trace, spiketimes + def _calculate_voltage_step(self, current_v, input_v): v_base = self.parameters["v_base"] step_size = self.parameters["step_size"] @@ -114,3 +142,80 @@ class LifacNoiseModel(AbstractModel): total_time = len(self.voltage_trace) / self.parameters["step_size"] return [delay, start, duration, total_time] + + def get_model_copy(self): + return LifacNoiseModel(self.parameters) + + def calculate_baseline_markers(self, stimulus_freq=750): + """ + calculates the baseline markers baseline frequency, vector strength and serial correlation + based on simulated 30 seconds with a standard Sinusoidal stimulus with the given frequency + + :return: baseline_freq, vs, sc + """ + + + + + pass + + def calculate_fi_markers(self, contrasts, ): + """ + calculates the fi markers f_infinity, f_infinity_slope for given contrasts + based on simulated 2 seconds for each contrast + :return: + """ + + +def stimulus_to_numpy_array(stimulus: AbstractStimulus, total_time_s, step_size): + total_time_points = int(total_time_s * 1000 / step_size) + stimulus_values = np.zeros(total_time_points) + for idx in range(len(stimulus_values)): + # rectified input: + stimulus_values[idx] = fu.rectify(stimulus.value_at_time_in_ms(step_size*idx)) + + return stimulus_values + + +@jit(nopython=True) +def simulate_fast(rectified_stimulus_array, total_time_s, parameters: np.ndarray): + + v_zero = parameters[0] + a_zero = parameters[1] + step_size = parameters[2] + threshold = parameters[3] + v_base = parameters[4] + delta_a = parameters[5] + tau_a = parameters[6] + v_offset = parameters[7] + mem_tau = parameters[8] + noise_strength = parameters[9] + + time = np.arange(0, total_time_s * 1000, step_size) + length = len(time) + output_voltage = np.zeros(length) + adaption = np.zeros(length) + stimulus_values = rectified_stimulus_array + + spiketimes = [] + output_voltage[0] = v_zero + adaption[0] = a_zero + + for i in range(len(time)-1): + noise_value = np.random.normal() + noise = noise_strength * noise_value / np.sqrt(step_size) + + output_voltage[i] = output_voltage[i-1] + ((v_base - output_voltage[i-1] + v_offset + stimulus_values[i] - adaption[i-1] + noise) / mem_tau) * step_size + adaption[i] = adaption[i-1] + ((-adaption[i-1]) / tau_a) * step_size + + if output_voltage[i] > threshold: + output_voltage[i] = v_base + spiketimes.append(i*step_size) + adaption[i] += delta_a / (tau_a / 1000) + + return output_voltage, adaption, spiketimes + + + + + diff --git a/stimuli/AbstractStimulus.py b/stimuli/AbstractStimulus.py index 649f7d6..fdf8f30 100644 --- a/stimuli/AbstractStimulus.py +++ b/stimuli/AbstractStimulus.py @@ -8,22 +8,25 @@ class AbstractStimulus: raise NotImplementedError("This is an abstract class!") def get_stimulus_start_ms(self): - raise NotImplementedError("This is an abstract class!") + return self.get_stimulus_start_s() * 1000 def get_stimulus_start_s(self): - return self.get_stimulus_start_ms() / 1000 + raise NotImplementedError("This is an abstract class!") def get_stimulus_duration_ms(self): - raise NotImplementedError("This is an abstract class!") + return self.get_stimulus_duration_s() * 1000 def get_stimulus_duration_s(self): - return self.get_stimulus_duration_ms() / 1000 + raise NotImplementedError("This is an abstract class!") def get_stimulus_end_ms(self): return self.get_stimulus_start_ms() + self.get_stimulus_duration_ms() def get_stimulus_end_s(self): - return self.get_stimulus_end_ms() / 1000 + return self.get_stimulus_start_s() + self.get_stimulus_duration_s() def get_amplitude(self): raise NotImplementedError("This is an abstract class!") + + def as_array(self, time_start, total_time, step_size): + raise NotImplementedError("This is an abstract class!") diff --git a/stimuli/StepStimulus.py b/stimuli/StepStimulus.py index cae3bf9..34a6ac4 100644 --- a/stimuli/StepStimulus.py +++ b/stimuli/StepStimulus.py @@ -5,8 +5,9 @@ from stimuli.AbstractStimulus import AbstractStimulus class StepStimulus(AbstractStimulus): def __init__(self, start, duration, value, base_value=0, seconds=True): - self.start = 0 - self.duration = 0 + if duration < 0: + raise ValueError("Duration cannot be negative") + self.base_value = base_value self.value = value if seconds: @@ -22,10 +23,10 @@ class StepStimulus(AbstractStimulus): else: return self.base_value - def get_stimulus_start_ms(self): + def get_stimulus_start_s(self): return self.start - def get_stimulus_duration_ms(self): + def get_stimulus_duration_s(self): return self.duration def get_amplitude(self): diff --git a/stimuli/StimulusSequence.py b/stimuli/StimulusSequence.py new file mode 100644 index 0000000..917cc12 --- /dev/null +++ b/stimuli/StimulusSequence.py @@ -0,0 +1,20 @@ + +from stimuli.AbstractStimulus import AbstractStimulus + + +class StimulusSequence(AbstractStimulus): + + def __init__(self, stimulus_list): + self.stimuli = stimulus_list + + def value_at_time_in_s(self, time_point): + pass + + def get_stimulus_start_ms(self): + pass + + def get_stimulus_duration_ms(self): + pass + + def get_amplitude(self): + pass \ No newline at end of file diff --git a/tests/ModelTests.py b/tests/ModelTests.py index 659a447..8ca28b3 100644 --- a/tests/ModelTests.py +++ b/tests/ModelTests.py @@ -3,10 +3,10 @@ import matplotlib.pyplot as plt import numpy as np import helperFunctions as hf from models.FirerateModel import FirerateModel -from models.LIFAC import LIFACModel from models.LIFACnoise import LifacNoiseModel from stimuli.StepStimulus import StepStimulus from stimuli.SinusAmplitudeModulation import SinusAmplitudeModulationStimulus +import functions as fu def main(): @@ -14,6 +14,7 @@ def main(): # test_stepsize_influence() test_lifac_noise() + def test_stepsize_influence(): # model = LIFACModel() model = FirerateModel() @@ -60,10 +61,10 @@ def test_lifac_noise(): model.simulate(stimulus, total_time) - fig, axes = plt.subplots(nrows=3, sharex=True) + fig, axes = plt.subplots(nrows=3, sharex="col") sparse_time = np.arange(0, total_time, 1/5000) axes[0].plot(sparse_time, [stimulus.value_at_time_in_s(x) for x in sparse_time], label="given stimulus") - axes[0].plot(sparse_time, [hf.rectify(stimulus.value_at_time_in_s(x)) for x in sparse_time], label="seen stimulus") + axes[0].plot(sparse_time, [fu.rectify(stimulus.value_at_time_in_s(x)) for x in sparse_time], label="seen stimulus") axes[0].set_title("Stimulus") axes[0].set_ylabel("stimulus strength") axes[0].legend() @@ -79,7 +80,6 @@ def test_lifac_noise(): spiketimes_small_step = model.get_spiketimes() - model.set_variable("step_size", 0.02) model.simulate(stimulus, total_time) print(model.get_adaption_trace()[int(0.1/(0.01/1000))]) @@ -119,6 +119,5 @@ def test_lifac_noise(): plt.show() - if __name__ == '__main__': main() \ No newline at end of file diff --git a/tests/stimuli/TestStepStimulus.py b/tests/stimuli/TestStepStimulus.py new file mode 100644 index 0000000..4d3c0a7 --- /dev/null +++ b/tests/stimuli/TestStepStimulus.py @@ -0,0 +1,55 @@ +import unittest +from stimuli.StepStimulus import StepStimulus +import numpy as np + + +class TestStepStimulus(unittest.TestCase): + + def test_time_getters(self): + starts = [2, -5, 0, 10] + durations = [2, 1000, 0.5] + value = 10 + + for start in starts: + for duration in durations: + stimulus = StepStimulus(start, duration, value) + + self.assertEqual(start, stimulus.get_stimulus_start_s(), "reported start (s) was wrong") + self.assertEqual(duration, stimulus.get_stimulus_duration_s(), "reported duration (s) was wrong") + self.assertEqual(start + duration, stimulus.get_stimulus_end_s(), "reported end (s) was wrong") + + self.assertEqual(start * 1000, stimulus.get_stimulus_start_ms(), "reported start (ms) was wrong") + self.assertEqual(duration * 1000, stimulus.get_stimulus_duration_ms(), "reported duration (ms) was wrong") + self.assertEqual((start + duration)*1000, stimulus.get_stimulus_end_ms(), "reported end (ms) was wrong") + + def test_duration_must_be_positive(self): + self.assertRaises(ValueError, StepStimulus, 1, -1, 3) + + def test_value_at(self): + start = 0 + duration = 2 + value = 5 + base_value = -1 + stimulus = StepStimulus(start, duration, value, base_value) + + for i in np.arange(start-1, start+duration+1, 0.1): + if i < start or i > start+duration: + self.assertEqual(stimulus.value_at_time_in_s(i), base_value) + self.assertEqual(stimulus.value_at_time_in_ms(i*1000), base_value) + else: + self.assertEqual(stimulus.value_at_time_in_s(i), value) + self.assertEqual(stimulus.value_at_time_in_ms(i * 1000), value) + + def test_amplitude(self): + stim_values = [-10, -5, 0, 15, 20] + base_values = [-10, -5, -2, 0, 1, 15] + + for s_value in stim_values: + for b_value in base_values: + stimulus = StepStimulus(0, 1, s_value, b_value) + + self.assertEqual(stimulus.get_amplitude(), s_value-b_value) + + +if __name__ == '__main__': + unittest.main()