import numpy as np
from warnings import warn
from thunderfish.eventdetection import threshold_crossing_times, threshold_crossings, detect_peaks
from scipy.optimize import curve_fit
import functions as fu
from numba import jit
import matplotlib.pyplot as plt


def fit_clipped_line(x, y):
    popt, pcov = curve_fit(fu.clipped_line, x, y)
    return popt


def fit_boltzmann(x, y):
    max_f0 = float(max(y))
    if max_f0 == 0:
        return [0, 0, 0, 0]
    min_f0 = 0.1  # float(min(self.f_zeros))
    mean_int = float(np.mean(x))

    total_increase = max_f0 - min_f0
    total_change_int = max(x) - min(x)
    start_k = float((total_increase / total_change_int * 4) / max_f0)

    try:
        popt, pcov = curve_fit(fu.full_boltzmann, x, y,
                               p0=(max_f0, min_f0, start_k, mean_int),
                               maxfev=10000,
                               bounds=([0, 0, -np.inf, -np.inf], [np.inf, np.inf, np.inf, np.inf]))
    except RuntimeError as e:
        print("Error in fit_boltzmann:", str(e))
        print("x_values:", x)
        print("y_values:", y)
        return [0, 0, 0, 0]

    return popt


@jit(nopython=True)
def rectify_stimulus_array(stimulus_array: np.ndarray):
    # Half-wave rectification: negative samples are set to zero.
    return np.maximum(stimulus_array, 0)


def merge_similar_intensities(intensities, spiketimes, trans_amplitudes):
    i = 0
    diffs = np.diff(sorted(intensities))
    margin = np.mean(diffs) * 0.6666

    while True:
        if i >= len(intensities):
            break
        intensities, spiketimes, trans_amplitudes = merge_intensities_similar_to_index(
            intensities, spiketimes, trans_amplitudes, i, margin)
        i += 1

    # Sort all three lists together so that intensities are increasing.
    x = [list(x) for x in zip(*sorted(zip(intensities, spiketimes, trans_amplitudes), key=lambda t: t[0]))]
    intensities = x[0]
    spiketimes = x[1]
    trans_amplitudes = x[2]

    return intensities, spiketimes, trans_amplitudes


def merge_intensities_similar_to_index(intensities, spiketimes, trans_amplitudes, index, margin):
    intensity = intensities[index]

    indices_to_merge = []
    for i in range(index + 1, len(intensities)):
        if np.abs(intensities[i] - intensity) < margin:
            indices_to_merge.append(i)

    if len(indices_to_merge) != 0:
        indices_to_merge.reverse()

        trans_amplitude_values = [trans_amplitudes[k] for k in indices_to_merge]
        all_the_same = True
        for j in range(1, len(trans_amplitude_values)):
            if not trans_amplitude_values[0] == trans_amplitude_values[j]:
                all_the_same = False
                break

        if all_the_same:
            for idx in indices_to_merge:
                del trans_amplitudes[idx]
        else:
            raise RuntimeError("Trans_amplitudes are not the same within a merge group!")

        for idx in indices_to_merge:
            spiketimes[index].extend(spiketimes[idx])
            del spiketimes[idx]
            del intensities[idx]

    return intensities, spiketimes, trans_amplitudes
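
# --- Illustrative sketch (not part of the analysis pipeline) -------------
# A minimal, made-up example of how merge_similar_intensities() collapses
# near-identical stimulus intensities: values closer together than ~2/3 of
# the mean intensity step are merged and their spiketime lists are pooled;
# trans_amplitudes must agree within a merged group. All numbers are toy
# values chosen for demonstration only.
def _example_merge_similar_intensities():
    intensities = [0.10, 0.11, 0.50, 0.52, 1.00]
    spiketimes = [[[0.01, 0.02]], [[0.015]], [[0.03]], [[0.035]], [[0.05]]]
    trans_amplitudes = [0.1, 0.1, 0.5, 0.5, 1.0]
    merged_ints, merged_spikes, merged_amps = merge_similar_intensities(
        intensities, spiketimes, trans_amplitudes)
    # Expected: the 0.10/0.11 and 0.50/0.52 pairs are merged -> 3 intensities.
    print(merged_ints, len(merged_spikes), merged_amps)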


def all_calculate_mean_isi_frequency_traces(spiketimes, sampling_interval, stimulus_start=0, time_in_ms=False):
    """
    Expects spiketimes to be a 3-dim list: the first dimension is the trial, the second
    the run within the trial, and the last the individual spike times:
    [[[trial1-run1-spike1, trial1-run1-spike2, ...], [trial1-run2-spike1, ...]],
     [[trial2-run1-spike1, ...], [...]]]

    :param stimulus_start: the time point at which the actual stimulus starts
    :param spiketimes: time points of action potentials
    :param sampling_interval: the sampling interval used / will also be used for the frequency trace
    :param time_in_ms: whether the time is in ms or seconds
    :return: the mean frequency trace for each trial and its time trace
    """
    times = []
    mean_frequencies = []

    for i in range(len(spiketimes)):
        trial_time_trace = []
        trial_freq_trace = []
        for j in range(len(spiketimes[i])):
            time, isi_freq = calculate_time_and_frequency_trace(spiketimes[i][j], sampling_interval, time_in_ms)
            if time[0] > stimulus_start:
                print("Trial not used as its frequency trace started after the stimulus start!")
                continue
            trial_freq_trace.append(isi_freq)
            trial_time_trace.append(time)

        time, mean_freq = calculate_mean_of_frequency_traces(trial_time_trace, trial_freq_trace, sampling_interval)
        times.append(time)
        mean_frequencies.append(mean_freq)

    return times, mean_frequencies


def calculate_isi_frequency_trace(spiketimes, sampling_interval, time_in_ms=False):
    """
    Calculates the frequency over time according to the inter-spike intervals (ISIs).

    :param spiketimes: sorted time points at which spikes were measured, array_like
    :param sampling_interval: the sampling interval in which the frequency should be given back
    :param time_in_ms: whether the time is in ms or in s for BOTH the spiketimes and the sampling interval
    :return: an np.array with the ISI frequency, starting at the time of the first spike
             and ending at the time of the last spike
    """
    if len(spiketimes) <= 1:
        return []

    isis = np.diff(spiketimes)
    if sampling_interval > round(min(isis), 7):
        raise ValueError("The sampling interval is bigger than some of the ISIs! Cannot accurately compute the trace.\n"
                         "Sampling interval {:.5f}, smallest ISI: {:.5f}".format(sampling_interval, min(isis)))

    if time_in_ms:
        isis = isis / 1000
        sampling_interval = sampling_interval / 1000

    full_frequency = np.array([])
    for isi in isis:
        if isi < 0:
            raise ValueError("There was a negative inter-spike interval; the spiketimes need to be sorted.")
        if isi == 0:
            warn("An ISI was zero in FiCurve:__calculate_mean_isi_frequency__()")
            print("ISI was zero:", spiketimes)
            continue
        freq = 1 / isi
        frequency_step = np.full(int(round(isi * (1 / sampling_interval))), freq)
        full_frequency = np.concatenate((full_frequency, frequency_step))

    return full_frequency


def calculate_time_and_frequency_trace(spiketimes, sampling_interval, time_in_ms=False):
    if len(spiketimes) < 2:
        return [0], [0]
        # raise ValueError("Cannot compute a time and frequency vector with fewer than 2 spikes")

    frequency = calculate_isi_frequency_trace(spiketimes, sampling_interval, time_in_ms)
    time = np.arange(spiketimes[0], spiketimes[-1], sampling_interval)

    # Rounding during trace construction can leave the two vectors off by a sample;
    # truncate the longer one so they match.
    if len(time) != len(frequency):
        if len(time) > len(frequency):
            time = time[:len(frequency)]
        else:
            frequency = frequency[:len(time)]

    return time, frequency
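
# --- Illustrative sketch (not part of the analysis pipeline) -------------
# Converting spike times into an ISI frequency trace: three spikes with ISIs
# of 10 ms and 20 ms yield a step from 100 Hz down to 50 Hz. The spike times
# below are made up for demonstration.
def _example_isi_frequency_trace():
    spiketimes = [0.00, 0.01, 0.03]   # seconds
    sampling_interval = 0.001         # 1 ms
    time, freq = calculate_time_and_frequency_trace(spiketimes, sampling_interval)
    # freq holds 10 samples of 100 Hz followed by 20 samples of 50 Hz.
    print(time[0], time[-1], freq[0], freq[-1])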


def calculate_mean_of_frequency_traces(trial_time_traces, trial_frequency_traces, sampling_interval):
    """
    Calculates the mean trace of the given frequency traces: the mean at each time point,
    for traces starting at different times.

    :param trial_time_traces: time traces of the individual trials
    :param trial_frequency_traces: frequency traces of the individual trials
    :param sampling_interval: sampling interval shared by all traces
    :return: the common time trace and the mean frequency trace
    """
    ends = [t[-1] for t in trial_time_traces]
    starts = [t[0] for t in trial_time_traces]

    latest_start = max(starts)
    earliest_end = min(ends)
    length = int(round((earliest_end - latest_start) / sampling_interval))
    shortened_time = (np.arange(0, length) * sampling_interval) + latest_start

    # Cut all traces down to the time span they share, then average sample-wise.
    shortened_freqs = []
    for i in range(len(trial_frequency_traces)):
        start_idx = int(round((latest_start - trial_time_traces[i][0]) / sampling_interval))
        end_idx = int(round((earliest_end - trial_time_traces[i][0]) / sampling_interval))
        shortened_freqs.append(trial_frequency_traces[i][start_idx:end_idx])

    mean_freq = [sum(e) / len(e) for e in zip(*shortened_freqs)]

    # for i in range(len(trial_time_traces)):
    #     if i > 5:
    #         break
    #     plt.plot(trial_time_traces[i], trial_frequency_traces[i])
    #
    # plt.plot(shortened_time, mean_freq, color="black")
    # plt.show()
    # plt.close()

    # if len(mean_freq) == len(shortened_time):
    #     print("time and freq trace worked out.")
    # else:
    #     print("time and freq trace were of different length. time - freq:"
    #           + str(len(shortened_time) - len(mean_freq)))

    return shortened_time, mean_freq


def mean_freq_of_spiketimes_after_time_x(spiketimes, time_x, time_in_ms=False):
    """ Calculates the mean frequency of the portion of spiketimes that lies after time_x. """
    spiketimes = np.array(spiketimes)
    if len(spiketimes) <= 1:
        return 0

    relevant_spikes = spiketimes[spiketimes > time_x]
    if len(relevant_spikes) <= 1:
        return 0

    return calculate_mean_isi_freq(relevant_spikes, time_in_ms)


def calculate_mean_isi_freq(spiketimes, time_in_ms=False):
    if len(spiketimes) < 2:
        return 0

    isis = np.diff(spiketimes)
    if time_in_ms:
        isis = isis / 1000
    freqs = 1 / isis
    # Weight each instantaneous frequency by the duration of its ISI,
    # i.e. compute a time-weighted mean.
    weights = isis / np.min(isis)
    return sum(freqs * weights) / sum(weights)
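
# --- Illustrative sketch (not part of the analysis pipeline) -------------
# The ISI-duration weighting in calculate_mean_isi_freq() makes the result
# algebraically equal to (number of ISIs) / (total time spanned):
#   sum((1/isi) * isi) / sum(isi) = n / sum(isis).
# Quick numeric check with made-up spike times (ISIs 10, 20, 30 ms -> 50 Hz):
def _check_mean_isi_freq():
    spiketimes = np.array([0.0, 0.01, 0.03, 0.06])
    isis = np.diff(spiketimes)
    assert abs(calculate_mean_isi_freq(spiketimes) - len(isis) / np.sum(isis)) < 1e-9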


# @jit(nopython=True)  # only faster at around 30 000 calls
def calculate_coefficient_of_variation(spiketimes: np.ndarray) -> float:
    # CV: standard deviation of the ISIs divided by the mean ISI.
    if len(spiketimes) <= 2:
        return 0
    isi = np.diff(spiketimes)
    std = np.std(isi)
    mean = np.mean(isi)
    return std / mean


# @jit(nopython=True)  # maybe faster with more than ~60 000 calls
def calculate_serial_correlation(spiketimes: np.ndarray, max_lag: int) -> np.ndarray:
    isi = np.diff(spiketimes)
    if len(spiketimes) < max_lag + 1 or len(spiketimes) < 20:
        warn("Cannot compute serial correlation with a list shorter than the max lag...")
        return np.zeros(max_lag)
        # raise ValueError("Given list too short for the given max_lag")

    cor = np.zeros(max_lag)
    for lag in range(1, max_lag + 1):
        first = isi[:-lag]
        second = isi[lag:]
        cor[lag - 1] = np.corrcoef(first, second)[0][1]

    return cor


def calculate_eod_frequency(time, eod):
    # TODO: for few samples this is a very volatile measure!
    up_indices, down_indices = threshold_crossings(eod, 0)
    up_times, down_times = threshold_crossing_times(time, eod, 0, up_indices, down_indices)

    if len(up_times) == 0:
        return 0
    durations = np.diff(up_times)
    mean_duration = np.mean(durations)

    return 1 / mean_duration


def calculate_vector_strength_from_spiketimes(time, eod, spiketimes, sampling_interval):
    spiketime_indices = np.array(np.around((np.array(spiketimes) + time[0]) / sampling_interval), dtype=int)
    rel_spikes, eod_durs = eods_around_spikes(time, eod, spiketime_indices)
    return __vector_strength__(rel_spikes, eod_durs)


def detect_spike_indices_automatic_split(v1, threshold, min_length=5000, split_step=1000):
    split_start = 0
    step_size = split_step
    break_threshold = 0.25
    splits = []

    if len(v1) < min_length:
        splits = [(0, len(v1))]
    else:
        last_max = max(v1[0:min_length])
        last_min = min(v1[0:min_length])
        idx = min_length
        while idx < len(v1):
            if idx + step_size > len(v1):
                splits.append((split_start, len(v1)))
                break

            # max_dif = abs((max(v1[idx:idx+step_size]) / last_max) - 1)
            # min_dif = abs((min(v1[idx:idx+step_size]) / last_min) - 1)
            # print("last_max: {:.2f}, current_max: {:.2f}".format(last_max, max(v1[idx:idx+step_size])))
            # print("max_dif: {:.2f}, min_dif: {:.2f}".format(max_dif, min_dif))

            max_similar = abs((max(v1[idx:idx + step_size]) - last_max) / last_max) < break_threshold
            min_similar = abs((min(v1[idx:idx + step_size]) - last_min) / last_min) < break_threshold

            if not max_similar or not min_similar:
                # print("new split")
                end_idx = np.argmin(v1[idx - 20:idx + 21]) - 20
                splits.append((split_start, idx + end_idx))
                split_start = idx + end_idx
                last_max = max(v1[split_start:split_start + min_length])
                last_min = min(v1[split_start:split_start + min_length])
                idx = split_start + min_length
                continue
            else:
                pass
                # print("elongated!")

            idx += step_size

    # Guard against an empty split list (can happen when len(v1) == min_length).
    if len(splits) == 0 or splits[-1][1] != len(v1):
        splits.append((split_start, len(v1)))

    # plt.plot(v1)
    # for s in splits:
    #     plt.plot(s, (max(v1[s[0]:s[1]]), max(v1[s[0]:s[1]])))

    all_peaks = []
    for s in splits:
        first_index = s[0]
        last_index = s[1]
        std = np.std(v1[first_index:last_index])
        peaks, _ = detect_peaks(v1[first_index:last_index], std * threshold)
        peaks = peaks + first_index
        # plt.plot(peaks, [np.max(v1[first_index:last_index]) for _ in peaks], 'o')
        all_peaks.extend(peaks)

    # plt.show()
    # plt.close()
    # all_peaks = np.array(all_peaks)
    return all_peaks


def detect_spiketimes(time, v1, threshold=2.0, min_length=5000, split_step=1000):
    all_peak_indices = detect_spike_indices_automatic_split(v1, threshold=threshold,
                                                            min_length=min_length, split_step=split_step)
    return [time[p_idx] for p_idx in all_peak_indices]
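
# --- Illustrative sketch (not part of the analysis pipeline) -------------
# Spike detection on a synthetic trace: baseline noise with spikes added at
# known times; detect_spiketimes() should recover times close to them. The
# sampling rate, noise level, and spike amplitude are made-up values chosen
# so that the peaks clearly exceed the std-based threshold.
def _example_detect_spiketimes():
    rng = np.random.default_rng(42)
    sampling_interval = 1e-4                 # 10 kHz
    time = np.arange(0, 2, sampling_interval)
    v1 = rng.normal(0, 0.05, len(time))
    true_spikes = np.arange(0.1, 2.0, 0.05)  # 20 Hz regular spiking
    for t in true_spikes:
        v1[int(round(t / sampling_interval))] += 1.0
    detected = detect_spiketimes(time, v1, threshold=2.0)
    print(len(true_spikes), len(detected))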


def eods_around_spikes(time, eod, spiketime_indices):
    eod_durations = []
    relative_spike_times = []

    # Indices of the rising zero crossings of the EOD.
    sign_changes = np.sign(eod[:-1]) != np.sign(eod[1:])
    eod_trace_increasing = eod[:-1] < eod[1:]
    eod_zero_crossings_indices = np.where(sign_changes & eod_trace_increasing)[0]

    for spike_idx in spiketime_indices:
        # Skip spikes that don't lie between two detected crossings.
        if not eod_zero_crossings_indices[0] < spike_idx < eod_zero_crossings_indices[-1]:
            continue

        zero_crossing_index_of_eod_end = np.argmax(eod_zero_crossings_indices > spike_idx)
        end_time_idx = eod_zero_crossings_indices[zero_crossing_index_of_eod_end]
        start_time_idx = eod_zero_crossings_indices[zero_crossing_index_of_eod_end - 1]

        eod_durations.append(time[end_time_idx] - time[start_time_idx])
        relative_spike_times.append(time[spike_idx] - time[start_time_idx])

        # try:
        #     start_time, end_time = search_eod_start_and_end_times(time, eod, spike_idx)
        #     eod_durations.append(end_time - start_time)
        #     spiketime = time[spike_idx]
        #     relative_spike_times.append(spiketime - start_time)
        # except IndexError as e:
        #     continue

    return np.array(relative_spike_times), np.array(eod_durations)


# def search_eod_start_and_end_times(time, eod, index):
#     # TODO might break if a spike is in the cut-off first or last EOD!
#
#     # search start_time:
#     previous = index
#     working_idx = index - 1
#     while True:
#         if eod[working_idx] < 0 < eod[previous]:
#             first_value = eod[working_idx]
#             second_value = eod[previous]
#
#             dif = second_value - first_value
#             part = np.abs(first_value / dif)
#
#             time_dif = np.abs(time[previous] - time[working_idx])
#             start_time = time[working_idx] + time_dif * part
#
#             break
#
#         previous = working_idx
#         working_idx -= 1
#
#     # search end_time
#     previous = index
#     working_idx = index + 1
#     while True:
#         if eod[previous] < 0 < eod[working_idx]:
#             first_value = eod[previous]
#             second_value = eod[working_idx]
#
#             dif = second_value - first_value
#             part = np.abs(first_value / dif)
#
#             time_dif = np.abs(time[previous] - time[working_idx])
#             end_time = time[working_idx] + time_dif * part
#
#             break
#
#         previous = working_idx
#         working_idx += 1
#
#     return start_time, end_time


def __vector_strength__(relative_spike_times: np.ndarray, eod_durations: np.ndarray):
    # adapted from Ramona
    n = len(relative_spike_times)
    if n == 0:
        return -1

    phase_times = (relative_spike_times / eod_durations) * 2 * np.pi
    vs = np.sqrt((1 / n * np.sum(np.cos(phase_times))) ** 2 + (1 / n * np.sum(np.sin(phase_times))) ** 2)

    return vs
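
# --- Illustrative sketch (not part of the analysis pipeline) -------------
# Vector strength sanity check with made-up phases: spikes perfectly locked
# to one EOD phase give vs = 1, while phases spread uniformly over the cycle
# give vs close to 0.
def _check_vector_strength():
    eod_durations = np.full(100, 0.001)                  # 1 kHz EOD
    locked = np.full(100, 0.00025)                       # always at a quarter period
    spread = np.linspace(0, 0.001, 100, endpoint=False)  # uniform over the cycle
    print(__vector_strength__(locked, eod_durations))    # -> 1.0
    print(__vector_strength__(spread, eod_durations))    # -> ~0.0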


def detect_f_zero_in_frequency_trace(time, frequency, stimulus_start, sampling_interval,
                                     peak_buffer_percent=0.05, buffer=0.025):
    if time[0] + 2 * buffer > stimulus_start:
        print("F_zero detection: Not enough frequency trace before the start of the stimulus.")
        return 0
    # Indices are relative to the start of the trace.
    freq_before = frequency[int(buffer / sampling_interval):
                            int((stimulus_start - time[0] - buffer) / sampling_interval)]
    min_before = min(freq_before)
    max_before = max(freq_before)
    mean_before = np.mean(freq_before)

    # Time window in which f_zero is searched.
    start_idx, end_idx = time_window_detect_f_zero(time[0], stimulus_start, sampling_interval, buffer)
    if start_idx < 0:
        raise ValueError("Time window to detect f_zero starts at a negative index!")

    min_during_start_of_stim = min(frequency[start_idx:end_idx])
    max_during_start_of_stim = max(frequency[start_idx:end_idx])

    # Take whichever extremum deviates more from the pre-stimulus mean.
    if abs(mean_before - min_during_start_of_stim) > abs(max_during_start_of_stim - mean_before):
        f_zero = min_during_start_of_stim
    else:
        f_zero = max_during_start_of_stim

    # If the candidate stays within the pre-stimulus range (plus a small margin),
    # there is no clear onset peak; fall back to the window mean.
    peak_buffer = (max_before - min_before) * peak_buffer_percent
    if min_before - peak_buffer <= f_zero <= max_before + peak_buffer:
        end_idx = start_idx + int((end_idx - start_idx) / 2)
        f_zero = np.mean(frequency[start_idx:end_idx])

    # import matplotlib.pyplot as plt
    # plt.plot(time, frequency)
    # plt.plot(time[start_idx:end_idx], [f_zero for i in range(end_idx - start_idx)])
    # plt.show()

    max_frequency = int(1 / sampling_interval)
    int_f_zero = int(f_zero)
    if int_f_zero > max_frequency:
        raise AssertionError("Detection of f_zero went very wrong! Frequency above 1/sampling_interval.")
    if int_f_zero > max(frequency):
        raise AssertionError("Detected f_zero is bigger than the highest peak in the frequency trace...")
    return f_zero


def time_window_detect_f_zero(time_start, stimulus_start, sampling_interval, buffer=0.025):
    stimulus_start = stimulus_start - time_start
    start_idx = int((stimulus_start - 0.5 * buffer) / sampling_interval)
    end_idx = int((stimulus_start + buffer) / sampling_interval)
    return start_idx, end_idx


def detect_f_infinity_in_freq_trace(time, frequency, stimulus_start, stimulus_duration, sampling_interval,
                                    length=0.1, buffer=0.025):
    start_idx, end_idx = time_window_detect_f_infinity(time[0], stimulus_start, stimulus_duration,
                                                       sampling_interval, length, buffer)
    return np.mean(frequency[start_idx:end_idx])


def time_window_detect_f_infinity(time_start, stimulus_start, stimulus_duration, sampling_interval,
                                  length=0.1, buffer=0.025):
    stimulus_end_time = stimulus_start + stimulus_duration - time_start
    start_idx = int((stimulus_end_time - length - buffer) / sampling_interval)
    end_idx = int((stimulus_end_time - buffer) / sampling_interval)
    return start_idx, end_idx


def detect_f_baseline_in_freq_trace(time, frequency, stimulus_start, sampling_interval, buffer=0.025):
    start_idx, end_idx = time_window_detect_f_baseline(time[0], stimulus_start, sampling_interval, buffer)
    f_baseline = np.mean(frequency[start_idx:end_idx])
    return f_baseline


def time_window_detect_f_baseline(time_start, stimulus_start, sampling_interval, buffer=0.025):
    stim_start = stimulus_start - time_start
    if stim_start < 0.1:
        warn("FICurve:__calculate_f_baseline__(): Quite short delay at the start.")
    start_idx = int(buffer / sampling_interval)
    end_idx = int((stim_start - buffer) / sampling_interval)
    return start_idx, end_idx
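
# --- Illustrative sketch (not part of the analysis pipeline) -------------
# Where the three measures are taken, relative to the stimulus. Made-up
# numbers: trace starts at t = -0.5 s, stimulus from 0 s to 0.4 s, 1 ms
# sampling.
#   f_baseline: from `buffer` after trace start up to `buffer` before onset
#   f_zero:     a short window straddling stimulus onset (onset response)
#   f_infinity: `length` seconds ending `buffer` before offset (steady state)
def _example_analysis_windows():
    si = 0.001
    print(time_window_detect_f_baseline(-0.5, 0.0, si))       # -> (25, 475)
    print(time_window_detect_f_zero(-0.5, 0.0, si))           # -> (487, 525)
    print(time_window_detect_f_infinity(-0.5, 0.0, 0.4, si))  # -> (775, 875)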