diff --git a/CellData.py b/CellData.py index 40f0725..ad0d165 100644 --- a/CellData.py +++ b/CellData.py @@ -1,4 +1,3 @@ - import DataParserFactory as dpf from warnings import warn from os import listdir @@ -88,16 +87,16 @@ class CellData: def get_mean_isi_frequencies(self): if self.mean_isi_frequencies is None: - self.time_axes, self.mean_isi_frequencies = hf.all_calculate_mean_isi_frequencies(self.get_fi_spiketimes(), - self.get_time_start(), - self.get_sampling_interval()) + self.time_axes, self.mean_isi_frequencies = hf.all_calculate_mean_isi_frequency_traces( + self.get_fi_spiketimes(), self.get_sampling_interval()) + return self.mean_isi_frequencies def get_time_axes_mean_frequencies(self): if self.time_axes is None: - self.time_axes, self.mean_isi_frequencies = hf.all_calculate_mean_isi_frequencies(self.get_fi_spiketimes(), - self.get_time_start(), - self.get_sampling_interval()) + self.time_axes, self.mean_isi_frequencies = hf.all_calculate_mean_isi_frequency_traces( + self.get_fi_spiketimes(), self.get_sampling_interval()) + return self.time_axes def get_base_frequency(self): @@ -163,7 +162,7 @@ class CellData: sampling_interval = self.get_sampling_interval() frequencies = [] for eod in eods: - time = np.arange(0, len(eod)*sampling_interval, sampling_interval) + time = np.arange(0, len(eod) * sampling_interval, sampling_interval) frequencies.append(hf.calculate_eod_frequency(time, eod)) return np.mean(frequencies) @@ -172,7 +171,8 @@ class CellData: if self.fi_spiketimes is None: trans_amplitudes, intensities, spiketimes = self.parser.get_fi_curve_spiketimes() - self.fi_intensities, self.fi_spiketimes, self.fi_trans_amplitudes = hf.merge_similar_intensities(intensities, spiketimes, trans_amplitudes) + self.fi_intensities, self.fi_spiketimes, self.fi_trans_amplitudes = hf.merge_similar_intensities( + intensities, spiketimes, trans_amplitudes) # def get_metadata(self): # self.__read_metadata__() diff --git a/RelationAdaptionVariables.py b/RelationAdaptionVariables.py index d330d96..55fc47f 100644 --- a/RelationAdaptionVariables.py +++ b/RelationAdaptionVariables.py @@ -36,7 +36,7 @@ def find_fitting_line(lifac_model, stimulus_strengths): if len(spiketimes) == 0: frequencies.append(0) continue - time, freq = hf.calculate_isi_frequency(spiketimes, 0, lifac_model.get_sampling_interval() / 1000) + time, freq = hf.calculate_isi_frequency_trace(spiketimes, 0, lifac_model.get_sampling_interval() / 1000) frequencies.append(freq[-1]) @@ -72,7 +72,7 @@ def find_relation(lifac, line_vars, stimulus_strengths, parameter="", value=0, c stimulus = StepStimulus(0, duration, stim) lifac.simulate(stimulus, duration) spiketimes = lifac.get_spiketimes() - time, freq = hf.calculate_isi_frequency(spiketimes, 0, lifac.get_sampling_interval()/1000) + time, freq = hf.calculate_isi_frequency_trace(spiketimes, 0, lifac.get_sampling_interval() / 1000) adapted_frequencies.append(freq[-1]) goal_adapted_freq = freq[-1] diff --git a/fit_lifacnoise.py b/fit_lifacnoise.py index bf52acb..59bd5c3 100644 --- a/fit_lifacnoise.py +++ b/fit_lifacnoise.py @@ -12,8 +12,14 @@ import matplotlib.pyplot as plt def main(): - run_test_with_fixed_model() + # run_test_with_fixed_model() + # quit() + fitter = Fitter() + fmin, params = fitter.fit_model_to_values(700, 1400, [-0.3], 1, [0.7, 0.8, 0.9, 1, 1.1, 1.2, 1.3], [1370, 1380, 1390, 1400, 1410, 1420, 1430], 100, 0.02, 0.01) + + print("calculated parameters:") + print(params) def run_with_real_data(): for celldata in icelldata_of_dir("./data/"): @@ -26,7 +32,7 @@ def run_with_real_data(): end_time = time.time() print('Fitting of cell took function took {:.3f} s'.format((end_time - start_time))) - break + pass @@ -35,14 +41,13 @@ def run_test_with_fixed_model(): a_delta = 0.08 parameters = {'mem_tau': 5, 'delta_a': a_delta, 'input_scaling': 100, - 'v_offset': 50, 'threshold': 1, 'v_base': 0, 'step_size': 0.05, 'tau_a': a_tau, + 'v_offset': 80, 'threshold': 1, 'v_base': 0, 'step_size': 0.00005, 'tau_a': a_tau, 'a_zero': 0, 'v_zero': 0, 'noise_strength': 0.5} model = LifacNoiseModel(parameters) eod_freq = 750 contrasts = np.arange(0.5, 1.51, 0.1) modulation_freq = 10 - print(contrasts) baseline_freq, vector_strength, serial_correlation = model.calculate_baseline_markers(eod_freq) f_infinities, f_infinities_slope = model.calculate_fi_markers(contrasts, eod_freq, modulation_freq) @@ -61,7 +66,7 @@ class Fitter: if step_size is not None: self.model = LifacNoiseModel({"step_size": step_size}) else: - self.model = LifacNoiseModel({"step_size": 0.05}) + self.model = LifacNoiseModel({"step_size": 0.0005}) # self.data = data self.fi_contrasts = [] self.eod_freq = 0 @@ -113,11 +118,12 @@ class Fitter: # minimize the difference in baseline_freq first by fitting v_offset # v_offset = self.__fit_v_offset_to_baseline_frequency__() - v_offset = self.model.find_v_offset(self.baseline_freq, self.eod_freq) + base_stimulus = SinusAmplitudeModulationStimulus(self.eod_freq, 0, 0) + + v_offset = self.model.find_v_offset(self.baseline_freq, base_stimulus) self.model.set_variable("v_offset", v_offset) # only eod with amplitude 1 and no modulation - base_stimulus = SinusAmplitudeModulationStimulus(self.eod_freq, 0, 0) _, spiketimes = self.model.simulate_fast(base_stimulus, 30) baseline_freq = hF.mean_freq_of_spiketimes_after_time_x(spiketimes, freq_sampling_rate, 5) @@ -131,12 +137,12 @@ class Fitter: f_infinities = [] for contrast in self.fi_contrasts: stimulus = SinusAmplitudeModulationStimulus(self.eod_freq, contrast, self.modulation_frequency) - _, spiketimes = self.model.simulate_fast(stimulus, 0.5) + _, spiketimes = self.model.simulate_fast(stimulus, 1) if len(spiketimes) < 2: f_infinities.append(0) else: - f_infinity = hF.mean_freq_of_spiketimes_after_time_x(spiketimes, freq_sampling_rate, 0.4) + f_infinity = hF.mean_freq_of_spiketimes_after_time_x(spiketimes, freq_sampling_rate, 0.5) f_infinities.append(f_infinity) popt, pcov = curve_fit(fu.line, self.fi_contrasts, f_infinities, maxfev=10000) @@ -161,87 +167,6 @@ class Fitter: print("Cost function run times:", self.counter, "error sum:", sum(errors), errors) return error_bf + error_vs + error_sc + error_f_inf_slope + error_f_inf - def __fit_v_offset_to_baseline_frequency__(self): - test_model = self.model.get_model_copy() - voltage_step_size = 1000 - simulation_time = 2 - v_offset_start = 0 - v_offset_current = v_offset_start - - test_model.set_variable("v_offset", v_offset_current) - base_stimulus = SinusAmplitudeModulationStimulus(self.eod_freq, 0, 0) - _, spiketimes = test_model.simulate_fast(base_stimulus, simulation_time) - if len(spiketimes) < 5: - baseline_freq = 0 - else: - baseline_freq = hF.mean_freq_of_spiketimes_after_time_x(spiketimes, simulation_time/2) - - if baseline_freq < self.baseline_freq: - upwards = True - v_offset_current += voltage_step_size - else: - upwards = False - v_offset_current -= voltage_step_size - - # search for a value below and above the baseline freq: - while True: - # print(self.counter, baseline_freq, self.baseline_freq, v_offset_current) - # self.counter += 1 - test_model.set_variable("v_offset", v_offset_current) - base_stimulus = SinusAmplitudeModulationStimulus(self.eod_freq, 0, 0) - _, spiketimes = test_model.simulate_fast(base_stimulus, simulation_time) - - if len(spiketimes) < 2: - baseline_freq = 0 - else: - baseline_freq = hF.mean_freq_of_spiketimes_after_time_x(spiketimes, simulation_time/2) - - if baseline_freq < self.baseline_freq and upwards: - v_offset_current += voltage_step_size - - elif baseline_freq < self.baseline_freq and not upwards: - break - - elif baseline_freq > self.baseline_freq and upwards: - break - - elif baseline_freq > self.baseline_freq and not upwards: - v_offset_current -= voltage_step_size - - elif baseline_freq == self.baseline_freq: - return v_offset_current - - # found the edges use them to allow binary search: - if upwards: - lower_bound = v_offset_current - voltage_step_size - upper_bound = v_offset_current - else: - lower_bound = v_offset_current - upper_bound = v_offset_current + voltage_step_size - - while True: - middle = lower_bound + (upper_bound - lower_bound)/2 - # print(self.counter, "measured_freq:", baseline_freq, "wanted_freq:", self.baseline_freq, "current middle:", middle) - # self.counter += 1 - test_model.set_variable("v_offset", middle) - base_stimulus = SinusAmplitudeModulationStimulus(self.eod_freq, 0, 0) - _, spiketimes = test_model.simulate_fast(base_stimulus, simulation_time) - - if len(spiketimes) < 2: - baseline_freq = 0 - else: - baseline_freq = hF.mean_freq_of_spiketimes_after_time_x(spiketimes, simulation_time/2) - - if abs(baseline_freq - self.baseline_freq) < 5: - # print("close enough:", baseline_freq, self.baseline_freq, abs(baseline_freq - self.baseline_freq)) - break - elif baseline_freq < self.baseline_freq: - lower_bound = middle - else: - upper_bound = middle - - return middle - def fit_model_to_data(self, data: CellData): self.calculate_needed_values_from_data(data) return self.fit_model() diff --git a/helperFunctions.py b/helperFunctions.py index c97906d..74a0d2f 100644 --- a/helperFunctions.py +++ b/helperFunctions.py @@ -55,35 +55,50 @@ def merge_intensities_similar_to_index(intensities, spiketimes, trans_amplitudes return intensities, spiketimes, trans_amplitudes -def all_calculate_mean_isi_frequencies(spiketimes, time_start, sampling_interval): +def all_calculate_mean_isi_frequency_traces(spiketimes, sampling_interval, time_in_ms=True): + """ + Expects spiketimes to be a 3dim list with the first dimension being the trial + the second the count of runs of spikes and the last the individual spikes_times: + [[[trial1-run1-spike1, trial1-run1-spike2, ...],[trial1-run2-spike1, ...]],[[trial2-run1-spike1, ...], [..]]] + :param spiketimes: + :param sampling_interval: + :param time_in_ms: + :return: the mean frequency trace for each trial and its time trace + """ times = [] mean_frequencies = [] for i in range(len(spiketimes)): - trial_times = [] - trial_means = [] + trial_time_trace = [] + trial_freq_trace = [] for j in range(len(spiketimes[i])): - time, isi_freq = calculate_isi_frequency(spiketimes[i][j], time_start, sampling_interval) - trial_means.append(isi_freq) - trial_times.append(time) + time, isi_freq = calculate_time_and_frequency_trace(spiketimes[i][j], sampling_interval, time_in_ms) + trial_freq_trace.append(isi_freq) + trial_time_trace.append(time) - time, mean_freq = calculate_mean_frequency(trial_times, trial_means) + time, mean_freq = calculate_mean_of_frequency_traces(trial_time_trace, trial_freq_trace, sampling_interval) times.append(time) mean_frequencies.append(mean_freq) return times, mean_frequencies -def calculate_isi_frequency(spiketimes, sampling_interval, time_in_ms=True): +def calculate_isi_frequency_trace(spiketimes, sampling_interval, time_in_ms=False): """ Calculates the frequency over time according to the inter spike intervals. - :param spiketimes: time points spikes were measured array_like + :param spiketimes: sorted time points spikes were measured array_like :param sampling_interval: the sampling interval in which the frequency should be given back :param time_in_ms: whether the time is in ms or in s for BOTH the spiketimes and the sampling interval :return: an np.array with the isi frequency starting at the time of first spike and ending at the time of the last spike """ + + if len(spiketimes) <= 1: + return [] + isis = np.diff(spiketimes) + if sampling_interval > min(isis): + raise ValueError("The sampling interval is bigger than the some isis! cannot accurately compute the trace.") if time_in_ms: isis = isis / 1000 @@ -91,6 +106,8 @@ def calculate_isi_frequency(spiketimes, sampling_interval, time_in_ms=True): full_frequency = np.array([]) for isi in isis: + if isi < 0: + raise ValueError("There was a negative interspike interval, the spiketimes need to be sorted") if isi == 0: warn("An ISI was zero in FiCurve:__calculate_mean_isi_frequency__()") print("ISI was zero:", spiketimes) @@ -102,15 +119,40 @@ def calculate_isi_frequency(spiketimes, sampling_interval, time_in_ms=True): return full_frequency -def calculate_mean_frequency(trial_times, trial_freqs): - lengths = [len(t) for t in trial_times] - shortest = min(lengths) +def calculate_time_and_frequency_trace(spiketimes, sampling_interval, time_in_ms=False): + frequency = calculate_isi_frequency_trace(spiketimes, sampling_interval, time_in_ms) + + time = np.arange(spiketimes[0], spiketimes[-1], sampling_interval) + + return time, frequency + + +def calculate_mean_of_frequency_traces(trial_time_traces, trial_frequency_traces, sampling_interval): + """ + calculates the mean_trace of the given frequency traces -> mean at each time point + for traces starting at different times + :param trial_time_traces: + :param trial_frequency_traces: + :param sampling_interval: + :return: + """ + ends = [t[-1] for t in trial_time_traces] + starts = [t[0] for t in trial_time_traces] + latest_start = max(starts) + earliest_end = min(ends) + + shortened_time = np.arange(latest_start, earliest_end+sampling_interval, sampling_interval) + + shortened_freqs = [] + for i in range(len(trial_frequency_traces)): + start_idx = int((latest_start - trial_time_traces[i][0]) / sampling_interval) + end_idx = int((earliest_end - trial_time_traces[i][0]) / sampling_interval) - time = trial_times[0][0:shortest] - shortend_freqs = [freq[0:shortest] for freq in trial_freqs] - mean_freq = [sum(e) / len(e) for e in zip(*shortend_freqs)] + shortened_freqs.append(trial_frequency_traces[i][start_idx:end_idx]) - return time, mean_freq + mean_freq = [sum(e) / len(e) for e in zip(*shortened_freqs)] + + return shortened_time, mean_freq def mean_freq_of_spiketimes_after_time_x(spiketimes, sampling_interval, time_x, time_in_ms=False): @@ -119,14 +161,29 @@ def mean_freq_of_spiketimes_after_time_x(spiketimes, sampling_interval, time_x, if len(spiketimes) <= 1: return 0 - freq = calculate_isi_frequency(spiketimes, sampling_interval, time_in_ms) + freq = calculate_isi_frequency_trace(spiketimes, sampling_interval, time_in_ms) # returned frequency starts at the idx = int((time_x-spiketimes[0]) / sampling_interval) - - mean_freq = np.mean(freq[idx:]) + rest_array = freq[idx:] + mean_freq = np.mean(rest_array) return mean_freq +def calculate_mean_isi_freq(spiketimes, time_in_ms=False): + if len(spiketimes) < 2: + return 0 + + isis = np.diff(spiketimes) + if time_in_ms: + isis = isis / 1000 + freqs = 1 / isis + weights = isis / np.min(isis) + + return sum(freqs * weights) / sum(weights) + + + + # @jit(nopython=True) # only faster at around 30 000 calls def calculate_coefficient_of_variation(spiketimes: np.ndarray) -> float: # CV (stddev of ISI divided by mean ISI (np.diff(spiketimes)) diff --git a/introduction/test.py b/introduction/test.py index 5792884..e120870 100644 --- a/introduction/test.py +++ b/introduction/test.py @@ -1,10 +1,65 @@ import numpy as np import matplotlib.pyplot as plt - +import helperFunctions as hF +import time def main(): - pass + for freq in [700, 50, 100, 500, 1000]: + reps = 1000 + start = time.time() + for i in range(reps): + mean_isi = 1 / freq + n = 0.7 + phase_locking_strength = 0.7 + size = 100000 + + final_isis = np.array([]) + while len(final_isis) < size: + + isis = np.random.normal(mean_isi, mean_isi*n, size) + isi_phase = (isis % mean_isi) / mean_isi + diff = abs_phase_diff(isi_phase, 0.5) + chance = np.random.random(size) + + isis_phase_cleaned = [] + for i in range(len(diff)): + if 1-diff[i]**0.05 > chance[i]: + isis_phase_cleaned.append(isis[i]) + + final_isis = np.concatenate((final_isis, isis_phase_cleaned)) + + spikes = np.cumsum(final_isis) + spikes = np.sort(spikes[spikes > 0]) + clean_isis = np.diff(spikes) + + bins = np.arange(-0.01, 0.01, 0.0001) + plt.hist(clean_isis, alpha=0.5, bins=bins) + plt.hist(isis, alpha=0.5, bins=bins) + plt.show() + quit() + + end = time.time() + + print("It took {:.2f} s to simulate 10s of spikes at {} Hz".format(end-start, freq)) + + +def abs_phase_diff(rel_phases:list, ref_phase:float): + """ + + :param rel_phases: relative phases as a list of values between 0 and 1 + :param ref_phase: reference phase to which the difference is calculated (between 0 and 1) + :return: list of absolute differences + """ + + diff = [abs(min(x-ref_phase, x-ref_phase+1)) for x in rel_phases] + + return diff if __name__ == '__main__': - main() + print(-2.4%0.35, (int(-2.4/0.35)-1)*0.35) + + hF.calculate_isi_frequency_trace([0, 2, 1, 3], 0.5) + + + #main() diff --git a/models/LIFACnoise.py b/models/LIFACnoise.py index a3eef19..40b3a11 100644 --- a/models/LIFACnoise.py +++ b/models/LIFACnoise.py @@ -238,7 +238,8 @@ def binary_search_base_freq(model: LifacNoiseModel, base_stimulus, goal_frequenc def test_v_offset(model: LifacNoiseModel, v_offset, base_stimulus, simulation_length): model.set_variable("v_offset", v_offset) _, spiketimes = model.simulate_fast(base_stimulus, simulation_length) - freq = hF.mean_freq_of_spiketimes_after_time_x(spiketimes, 0.005, simulation_length/3) + + freq = hF.mean_freq_of_spiketimes_after_time_x(spiketimes, 0.0005, simulation_length/3) return freq diff --git a/tests/ModelTests.py b/tests/ModelTests.py index 8ca28b3..b79aad9 100644 --- a/tests/ModelTests.py +++ b/tests/ModelTests.py @@ -73,7 +73,7 @@ def test_lifac_noise(): axes[1].set_title("Voltage trace") axes[1].set_ylabel("voltage") - t, f = hf.calculate_isi_frequency(model.get_spiketimes(), 0, step_size) + t, f = hf.calculate_isi_frequency_trace(model.get_spiketimes(), 0, step_size) axes[2].plot(t, f) axes[2].set_title("ISI frequency trace") axes[2].set_ylabel("Frequency") @@ -85,7 +85,7 @@ def test_lifac_noise(): print(model.get_adaption_trace()[int(0.1/(0.01/1000))]) step_size = model.get_parameters()["step_size"] / 1000 time = np.arange(0, total_time, step_size) - t, f = hf.calculate_isi_frequency(model.get_spiketimes(), 0, step_size) + t, f = hf.calculate_isi_frequency_trace(model.get_spiketimes(), 0, step_size) axes[1].plot(time, model.get_voltage_trace()) axes[2].plot(t, f) diff --git a/unittests/testFrequencyFunctions.py b/unittests/testFrequencyFunctions.py new file mode 100644 index 0000000..ed342ce --- /dev/null +++ b/unittests/testFrequencyFunctions.py @@ -0,0 +1,241 @@ +import unittest +import numpy as np +import helperFunctions as hF +import matplotlib.pyplot as plt +from warnings import warn + + +class FrequencyFunctionsTester(unittest.TestCase): + + noise_levels = [0, 0.05, 0.1, 0.2] + frequencies = [0, 1, 5, 30, 100, 500, 750, 1000] + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_calculate_eod_frequency(self): + start = 0 + end = 5 + step = 0.1 / 1000 + freqs = [0, 1, 10, 500, 700, 1000] + for freq in freqs: + time = np.arange(start, end, step) + eod = np.sin(freq*(2*np.pi) * time) + self.assertEqual(freq, round(hF.calculate_eod_frequency(time, eod), 2)) + + def test_mean_freq_of_spiketimes_after_time_x(self): + simulation_time = 8 + for freq in self.frequencies: + for n in self.noise_levels: + spikes = generate_jittered_spiketimes(freq, n, end=simulation_time) + sim_freq = hF.mean_freq_of_spiketimes_after_time_x(spikes, 0.00005, simulation_time/4, time_in_ms=False) + + max_diff = round(n*(10+0.7*np.sqrt(freq)), 2) + # print("noise: {:.2f}".format(n), "\texpected: {:.2f}".format(freq), "\tgotten: {:.2f}".format(round(sim_freq, 2)), "\tfreq diff: {:.2f}".format(abs(freq-round(sim_freq, 2))), "\tmax_diff:", max_diff) + self.assertTrue(abs(freq-round(sim_freq)) <= max_diff, msg="expected freq: {:.2f} vs calculated: {:.2f}. max diff was {:.2f}".format(freq, sim_freq, max_diff)) + + def test_calculate_isi_frequency(self): + simulation_time = 1 + sampling_interval = 0.00005 + + for freq in self.frequencies: + for n in self.noise_levels: + spikes = generate_jittered_spiketimes(freq, n, end=simulation_time) + sim_freq = hF.calculate_isi_frequency_trace(spikes, sampling_interval, time_in_ms=False) + + isis = np.diff(spikes) + step_length = isis / sampling_interval + rounded_step_length = np.around(step_length) + expected_length = sum(rounded_step_length) + + length = len(sim_freq) + self.assertEqual(expected_length, length) + + def test_calculate_isi_frequency_trace(self): + sampling_intervals = [0.00005, 0.001, 0.01, 0.2, 0.5, 1] + + test1 = [0, 1, 2, 3, 4] # 1-1-1-1 only 1s in the result + test2 = [0, 1, 3, 5, 6] # 1-2-2-1 + test3 = [0, 3, 10, 12, 15] # 3-7-2-3 + pos_tests = [test1, test2, test3] + + test4 = generate_jittered_spiketimes(100, 0.2) + test5 = generate_jittered_spiketimes(500, 0.2) + test6 = generate_jittered_spiketimes(1000, 0) + realistic_tests = [test4, test5, test6] + + test_neg_isi = [0, 3, 4, 2, 5] # should raise error non sorted spiketimes + test_too_small_sampling_rate = [0.001, 0.0015, 0.002] + neg_tests = [test_neg_isi, test_too_small_sampling_rate] + + for test in pos_tests: + for sampling_interval in sampling_intervals: + calculated_trace = hF.calculate_isi_frequency_trace(test, sampling_interval, time_in_ms=False) + diffs = np.diff(test) + j = 0 + count = 0 + value = 1/diffs[j] + for i in range(len(calculated_trace)): + if calculated_trace[i] == value: + count += 1 + else: + expected_length = round(diffs[j] / sampling_interval) + + # if there are multiple isis of the same length after each other add them together + while expected_length < count and value == 1/diffs[j+1]: + j += 1 + expected_length += round(diffs[j] / sampling_interval, 0) + + self.assertEqual(count, expected_length, msg="Length of isi frequency part is not right: expected {:.1f} vs {:.1f}".format(float(count), expected_length)) + j += 1 + value = 1/diffs[j] + count = 1 + + for test in neg_tests: + self.assertRaises(ValueError, hF.calculate_isi_frequency_trace, test, 0.2, False) + + def test_calculate_time_and_frequency_trace(self): + + # !!! the produced frequency trace is tested in the test function for specifically the freq_Trace function + sampling_intervals = [0.0001, 0.1, 0.5, 1] + + test1 = [0, 1, 2, 5, 7] + test2 = [1, 3, 5, 6, 7, 10] + test3 = [-1, 2, 4, 5, 11] + + pos_tests = [test1, test2, test3] + + for sampling_interval in sampling_intervals: + for test in pos_tests: + time, freq = hF.calculate_time_and_frequency_trace(test, sampling_interval, time_in_ms=False) + + self.assertEqual(test[0], time[0]) + self.assertEqual(test[-1], round(time[-1]+sampling_interval)) + + def test_calculate_mean_of_frequency_traces(self): + # TODO expand this test to more than this single test case + test1_f = [0.5, 0.5, 0.5, 0.5, 1, 1, 1, 1] + test1_t = np.arange(0, 8, 0.5) + test2_f = [1, 2, 2, 3, 3, 4] + test2_t = np.arange(0.5, 7.5, 0.5) + + time_traces = [test1_t, test2_t] + freq_traces = [test1_f, test2_f] + time, mean = hF.calculate_mean_of_frequency_traces(time_traces, freq_traces, 0.5) + + expected_time = np.arange(0.5, 7.5, 0.5) + + expected_mean = [0.75, 1.25, 1.25, 2, 2, 2.5] + time_equal = np.all([time[i] == expected_time[i] for i in range(len(time))]) + mean_equal = np.all([mean[i] == expected_mean[i] for i in range(len(mean))]) + self.assertTrue(time_equal) + self.assertTrue(mean_equal, msg="expected:\n" + str(expected_mean) + "\n actual: \n" + str(mean)) + self.assertEqual(len(expected_mean), len(mean)) + self.assertEqual(len(expected_time), len(time), msg="expected:\n" + str(expected_time) + "\n actual: \n" + str(time)) + + # TODO: + # all_calculate_mean_isi_frequency_traces(spiketimes, sampling_interval, time_in_ms=True): + + +def generate_jittered_spiketimes(frequency, noise_level=0., start=0, end=5, method='normal'): + + if method is 'normal': + return normal_dist_jittered_spikes(frequency, noise_level, start, end) + + elif method is 'poisson': + if noise_level != 0: + warn("Poisson jittered spike trains don't support a noise level! ") + return poisson_jittered_spikes(frequency, start, end) + + +def poisson_jittered_spikes(frequency, start, end): + if frequency == 0: + return [] + + mean_isi = 1 / frequency + + spikes = [] + for part in np.arange(start, end+mean_isi, mean_isi): + num_spikes_in_part = np.random.poisson(1) + positions = np.sort(np.random.random(num_spikes_in_part)) + + while not __poisson_min_dist_test__(positions): + positions = np.sort(np.random.random(num_spikes_in_part)) + + for pos in positions: + spikes.append(part+pos*mean_isi) + + while spikes[-1] > end: + del spikes[-1] + + return spikes + + +def __poisson_min_dist_test__(positions): + if len(positions) > 1: + diffs = np.diff(positions) + if len(diffs[diffs < 0.0001]) > 0: + return False + + return True + + +def normal_dist_jittered_spikes(frequency, noise_level, start, end): + if frequency == 0: + return [] + + mean_isi = 1 / frequency + if noise_level == 0: + return np.arange(start, end, mean_isi) + + isis = np.random.normal(mean_isi, noise_level*mean_isi, int((end-start)*1.05/mean_isi)) + spikes = np.cumsum(isis) + start + spikes = np.sort(spikes) + + if spikes[-1] > end: + return spikes[spikes < end] + + else: + additional_spikes = [spikes[-1] + np.random.normal(mean_isi, noise_level*mean_isi)] + + while additional_spikes[-1] < end: + next_isi = np.random.normal(mean_isi, noise_level*mean_isi) + additional_spikes.append(additional_spikes[-1] + next_isi) + + additional_spikes = np.sort(np.array(additional_spikes[:-1])) + spikes = np.concatenate((spikes, additional_spikes)) + + return spikes + + +def test_distribution(): + simulation_time = 5 + freqs = [5, 30, 100, 500, 1000] + noise_level = [0.05, 0.1, 0.2, 0.3] + repetitions = 1000 + for freq in freqs: + diffs_per_noise = [] + for n in noise_level: + diffs = [] + print("#### - freq:", freq, "noise level:", n ) + for reps in range(repetitions): + spikes = generate_jittered_spiketimes(freq, n, end=simulation_time) + sim_freq = hF.mean_freq_of_spiketimes_after_time_x(spikes, 0.0002, simulation_time / 4, time_in_ms=False) + diffs.append(sim_freq-freq) + + diffs_per_noise.append(diffs) + + fig, axs = plt.subplots(1, len(noise_level), figsize=(3.5*len(noise_level), 4), sharex='all') + + for i in range(len(diffs_per_noise)): + max_diff = np.max(np.abs(diffs_per_noise[i])) + print("Freq: ", freq, "noise: {:.2f}".format(noise_level[i]), "mean: {:.2f}".format(np.mean(diffs_per_noise[i])), "max_diff: {:.4f}".format(max_diff)) + bins = np.arange(-max_diff, max_diff, 2*max_diff/100) + axs[i].hist(diffs_per_noise[i], bins=bins) + axs[i].set_title('Noise level: {:.2f}'.format(noise_level[i])) + + plt.show() + plt.close() \ No newline at end of file diff --git a/unittests/testHelperFunctions.py b/unittests/testHelperFunctions.py index 47351e0..7618176 100644 --- a/unittests/testHelperFunctions.py +++ b/unittests/testHelperFunctions.py @@ -15,16 +15,6 @@ class HelperFunctionsTester(unittest.TestCase): def tearDown(self): pass - def test_calculate_eod_frequency(self): - start = 0 - end = 5 - step = 0.1 / 1000 - freqs = [0, 1, 10, 500, 700, 1000] - for freq in freqs: - time = np.arange(start, end, step) - eod = np.sin(freq*(2*np.pi) * time) - self.assertEqual(freq, round(hF.calculate_eod_frequency(time, eod), 2)) - def test__vector_strength__is_1(self): length = 2000 rel_spike_times = np.full(length, 0.3) @@ -40,91 +30,8 @@ class HelperFunctionsTester(unittest.TestCase): self.assertEqual(0, round(hF.__vector_strength__(rel_spike_times, eod_durations), 5)) - def test_mean_freq_of_spiketimes_after_time_x(self): - simulation_time = 8 - for freq in self.frequencies: - for n in self.noise_levels: - spikes = generate_jittered_spiketimes(freq, n, end=simulation_time) - sim_freq = hF.mean_freq_of_spiketimes_after_time_x(spikes, 0.00005, simulation_time/4, time_in_ms=False) - - max_diff = round(n*(10+0.7*np.sqrt(freq)), 2) - # print("noise: {:.2f}".format(n), "\texpected: {:.2f}".format(freq), "\tgotten: {:.2f}".format(round(sim_freq, 2)), "\tfreq diff: {:.2f}".format(abs(freq-round(sim_freq, 2))), "\tmax_diff:", max_diff) - self.assertTrue(abs(freq-round(sim_freq)) <= max_diff, msg="expected freq: {:.2f} vs calculated: {:.2f}. max diff was {:.2f}".format(freq, sim_freq, max_diff)) - - def test_calculate_isi_frequency(self): - simulation_time = 1 - sampling_interval = 0.00005 - - for freq in self.frequencies: - for n in self.noise_levels: - spikes = generate_jittered_spiketimes(freq, n, end=simulation_time) - sim_freq = hF.calculate_isi_frequency(spikes, sampling_interval, time_in_ms=False) - - isis = np.diff(spikes) - step_length = isis / sampling_interval - rounded_step_length = np.around(step_length) - expected_length = sum(rounded_step_length) - - length = len(sim_freq) - self.assertEqual(expected_length, length) - # def test(self): # test_distribution() -def generate_jittered_spiketimes(frequency, noise_level, start=0, end=5): - if frequency == 0: - return [] - - mean_isi = 1 / frequency - if noise_level == 0: - return np.arange(start, end, mean_isi) - - spikes = [start] - count = 0 - while True: - next_isi = np.random.normal(mean_isi, noise_level*mean_isi) - if next_isi <= 0: - count += 1 - continue - next_spike = spikes[-1] + next_isi - if next_spike > end: - break - spikes.append(spikes[-1] + next_isi) - - # print("count: {:} percentage of missed: {:.2f}".format(count, count/len(spikes))) - if count > 0.01*len(spikes): - print("!!! Danger of lowering actual simulated frequency") - pass - return spikes - - -def test_distribution(): - simulation_time = 5 - freqs = [5, 30, 100, 500, 1000] - noise_level = [0.05, 0.1, 0.2, 0.3] - repetitions = 1000 - for freq in freqs: - diffs_per_noise = [] - for n in noise_level: - diffs = [] - print("#### - freq:", freq, "noise level:", n ) - for reps in range(repetitions): - spikes = generate_jittered_spiketimes(freq, n, end=simulation_time) - sim_freq = hF.mean_freq_of_spiketimes_after_time_x(spikes, 0.0002, simulation_time / 4, time_in_ms=False) - diffs.append(sim_freq-freq) - - diffs_per_noise.append(diffs) - - fig, axs = plt.subplots(1, len(noise_level), figsize=(3.5*len(noise_level), 4), sharex='all') - - for i in range(len(diffs_per_noise)): - max_diff = np.max(np.abs(diffs_per_noise[i])) - print("Freq: ", freq, "noise: {:.2f}".format(noise_level[i]), "mean: {:.2f}".format(np.mean(diffs_per_noise[i])), "max_diff: {:.4f}".format(max_diff)) - bins = np.arange(-max_diff, max_diff, 2*max_diff/100) - axs[i].hist(diffs_per_noise[i], bins=bins) - axs[i].set_title('Noise level: {:.2f}'.format(noise_level[i])) - - plt.show() - plt.close() \ No newline at end of file