import numpy as np
from warnings import warn
from thunderfish.eventdetection import detect_peaks, threshold_crossing_times, threshold_crossings


def merge_similar_intensities(intensities, spiketimes, trans_amplitudes):
    """
    Merge stimulus intensities that lie closer together than a margin and sort the result.

    The margin is two thirds (0.6666) of the mean gap between the sorted intensities. For each
    intensity (in list order) every later intensity closer than the margin is removed and its
    spiketimes entries are appended to the kept entry (see merge_intensities_similar_to_index).

    The given lists are modified in place (entries are deleted, kept spiketimes entries are
    extended) before sorted copies are returned.

    :param intensities: list of stimulus intensities
    :param spiketimes: list parallel to intensities; each entry is a list that can be extended
    :param trans_amplitudes: list parallel to intensities
    :return: (intensities, spiketimes, trans_amplitudes) as lists sorted by increasing intensity
    :raises RuntimeError: if entries that would be merged have different trans_amplitudes
    """
    diffs = np.diff(sorted(intensities))
    # With fewer than two intensities there is nothing to merge; margin 0 merges nothing and
    # avoids the "mean of empty slice" warning.
    margin = np.mean(diffs) * 0.6666 if len(diffs) > 0 else 0

    i = 0
    while i < len(intensities):
        intensities, spiketimes, trans_amplitudes = merge_intensities_similar_to_index(
            intensities, spiketimes, trans_amplitudes, i, margin)
        i += 1

    # Sort all three parallel lists with the same (stable) permutation so that
    # trans_amplitudes stays aligned with intensities and spiketimes.
    order = sorted(range(len(intensities)), key=lambda k: intensities[k])
    intensities = [intensities[k] for k in order]
    spiketimes = [spiketimes[k] for k in order]
    trans_amplitudes = [trans_amplitudes[k] for k in order]
    return intensities, spiketimes, trans_amplitudes


def merge_intensities_similar_to_index(intensities, spiketimes, trans_amplitudes, index, margin):
    """
    Merge all entries after `index` whose intensity differs from intensities[index] by less than margin.

    The spiketimes entries of the merged indices are appended (highest index first) to
    spiketimes[index]. Then the merged entries are deleted from all three lists in place.

    :param intensities: list of intensities
    :param spiketimes: list parallel to intensities; spiketimes[index] is extended in place
    :param trans_amplitudes: list parallel to intensities
    :param index: index of the entry to merge into
    :param margin: maximal absolute intensity difference (exclusive) for merging
    :return: (intensities, spiketimes, trans_amplitudes) - the same, now modified, lists
    :raises RuntimeError: if the entries to merge do not all share one trans_amplitude;
        nothing is modified in that case
    """
    intensity = intensities[index]
    indices_to_merge = [i for i in range(index + 1, len(intensities))
                        if np.abs(intensities[i] - intensity) < margin]
    if not indices_to_merge:
        return intensities, spiketimes, trans_amplitudes

    # NOTE(review): only the entries being merged are compared with each other, not with
    # trans_amplitudes[index] itself - confirm whether the kept entry should be included.
    reference_amplitude = trans_amplitudes[indices_to_merge[-1]]
    if any(not reference_amplitude == trans_amplitudes[k] for k in indices_to_merge):
        raise RuntimeError("Trans_amplitudes not the same....")

    # Delete from the highest index downwards so the remaining indices stay valid.
    for idx in reversed(indices_to_merge):
        del trans_amplitudes[idx]
    for idx in reversed(indices_to_merge):
        spiketimes[index].extend(spiketimes[idx])
        del spiketimes[idx]
        del intensities[idx]
    return intensities, spiketimes, trans_amplitudes


def all_calculate_mean_isi_frequency_traces(spiketimes, sampling_interval, time_in_ms=True):
    """
    Compute, for every trial, the mean ISI frequency trace over all runs of that trial.

    Expects spiketimes to be a 3-dim nested list: the first dimension is the trial, the second
    the run within the trial, and the last the individual spike times:
    [[[trial1-run1-spike1, trial1-run1-spike2, ...], [trial1-run2-spike1, ...]],
     [[trial2-run1-spike1, ...], [...]]]

    :param spiketimes: nested spike times as described above
    :param sampling_interval: sampling interval of the frequency traces (same unit as spiketimes)
    :param time_in_ms: whether spiketimes and sampling_interval are given in ms (otherwise s)
    :return: (times, mean_frequencies) - one time trace and one mean frequency trace per trial
    """
    times = []
    mean_frequencies = []
    for trial in spiketimes:
        trial_time_traces = []
        trial_freq_traces = []
        for run in trial:
            time, isi_freq = calculate_time_and_frequency_trace(run, sampling_interval, time_in_ms)
            trial_time_traces.append(time)
            trial_freq_traces.append(isi_freq)
        time, mean_freq = calculate_mean_of_frequency_traces(trial_time_traces, trial_freq_traces,
                                                             sampling_interval)
        times.append(time)
        mean_frequencies.append(mean_freq)
    return times, mean_frequencies


def calculate_isi_frequency_trace(spiketimes, sampling_interval, time_in_ms=False):
    """
    Calculate the frequency over time from the inter-spike intervals (ISIs).

    Each ISI contributes round(isi / sampling_interval) samples of value 1 / isi.

    :param spiketimes: sorted time points at which spikes were measured (array_like)
    :param sampling_interval: the sampling interval in which the frequency trace is returned
    :param time_in_ms: whether BOTH spiketimes and sampling_interval are in ms (otherwise s);
        the frequency is always returned in 1/s
    :return: np.array with the ISI frequency, starting at the first and ending at the last
        spike; an empty list if there are fewer than two spikes
    :raises ValueError: if sampling_interval is not positive, if the spiketimes are not sorted,
        or if sampling_interval is larger than the smallest ISI (this includes zero ISIs)
    """
    if len(spiketimes) <= 1:
        return []
    if sampling_interval <= 0:
        raise ValueError("The sampling interval must be positive.")

    isis = np.diff(spiketimes)
    # Check ordering first so unsorted input gets the matching error message.
    if np.any(isis < 0):
        raise ValueError("There was a negative interspike interval, the spiketimes need to be sorted")
    if sampling_interval > np.min(isis):
        raise ValueError("The sampling interval is bigger than the some isis! cannot accurately compute the trace.")

    if time_in_ms:
        isis = isis / 1000
        sampling_interval = sampling_interval / 1000

    freqs = 1 / isis
    samples_per_isi = np.round(isis * (1 / sampling_interval)).astype(int)
    # np.repeat builds the step function in one pass instead of concatenating per ISI.
    return np.repeat(freqs, samples_per_isi).astype(float)


def calculate_time_and_frequency_trace(spiketimes, sampling_interval, time_in_ms=False):
    """
    Return the ISI frequency trace together with a time axis.

    The time axis runs from the first spike up to (excluding) the last spike in steps of
    sampling_interval, in the unit of spiketimes. Its length can differ slightly from the
    frequency trace, because the trace rounds every ISI to whole samples.

    :param spiketimes: sorted spike times
    :param sampling_interval: sampling interval (same unit as spiketimes)
    :param time_in_ms: whether spiketimes and sampling_interval are given in ms (otherwise s)
    :return: (time, frequency)
    """
    frequency = calculate_isi_frequency_trace(spiketimes, sampling_interval, time_in_ms)
    time = np.arange(spiketimes[0], spiketimes[-1], sampling_interval)
    return time, frequency


def calculate_mean_of_frequency_traces(trial_time_traces, trial_frequency_traces, sampling_interval):
    """
    Average frequency traces that start and end at different times.

    All traces are cut to the interval shared by all of them (latest start to earliest end)
    and then averaged sample by sample.

    :param trial_time_traces: list of time arrays, one per frequency trace
    :param trial_frequency_traces: list of frequency traces sampled at sampling_interval
    :param sampling_interval: spacing of the time traces (same unit as the time traces)
    :return: (time array of the shared interval, list with the mean frequency per sample)
    """
    latest_start = max(t[0] for t in trial_time_traces)
    earliest_end = min(t[-1] for t in trial_time_traces)
    shortened_time = np.arange(latest_start, earliest_end + sampling_interval, sampling_interval)

    shortened_freqs = []
    for time_trace, freq_trace in zip(trial_time_traces, trial_frequency_traces):
        start_idx = int((latest_start - time_trace[0]) / sampling_interval)
        end_idx = int((earliest_end - time_trace[0]) / sampling_interval)
        shortened_freqs.append(freq_trace[start_idx:end_idx])

    # NOTE(review): zip stops at the shortest cut trace, so mean_freq can be shorter than
    # shortened_time (which also includes the end point).
    mean_freq = [sum(e) / len(e) for e in zip(*shortened_freqs)]
    return shortened_time, mean_freq


def mean_freq_of_spiketimes_after_time_x(spiketimes, sampling_interval, time_x, time_in_ms=False):
    """
    Calculate the mean ISI frequency of the part of the spike train after time_x.

    :param spiketimes: sorted spike times
    :param sampling_interval: sampling interval of the internal frequency trace (unit of spiketimes)
    :param time_x: start time of the averaged portion (unit of spiketimes); times before the
        first spike average the whole trace
    :param time_in_ms: whether spiketimes, sampling_interval and time_x are in ms (otherwise s)
    :return: mean frequency in 1/s, or 0 for fewer than two spikes
    """
    if len(spiketimes) <= 1:
        return 0

    freq = calculate_isi_frequency_trace(spiketimes, sampling_interval, time_in_ms)
    # The returned frequency trace starts at the time of the first spike and has one sample
    # per sampling_interval. Clamp at 0 so a time_x before the first spike does not become a
    # negative index (which would select the tail of the trace).
    idx = max(int((time_x - spiketimes[0]) / sampling_interval), 0)
    return np.mean(freq[idx:])


def calculate_mean_isi_freq(spiketimes, time_in_ms=False):
    """
    Mean of the ISI frequencies, each weighted by its ISI length.

    Mathematically this equals the number of ISIs divided by the total ISI duration.

    :param spiketimes: sorted spike times
    :param time_in_ms: whether spiketimes are in ms (otherwise s); the result is in 1/s
    :return: weighted mean frequency, or 0 for fewer than two spikes
    """
    if len(spiketimes) < 2:
        return 0

    isis = np.diff(spiketimes)
    if time_in_ms:
        isis = isis / 1000
    freqs = 1 / isis
    weights = isis / np.min(isis)
    return sum(freqs * weights) / sum(weights)


# NOTE: numba @jit(nopython=True) was considered; reportedly only faster at around 30 000 calls.
def calculate_coefficient_of_variation(spiketimes: np.ndarray) -> float:
    """Return the coefficient of variation of the ISIs: std(ISI) / mean(ISI)."""
    isi = np.diff(spiketimes)
    std = np.std(isi)
    mean = np.mean(isi)
    return std / mean


# NOTE: numba @jit(nopython=True) was considered; maybe faster with more than ~60 000 calls.
def calculate_serial_correlation(spiketimes: np.ndarray, max_lag: int) -> np.ndarray:
    """
    Return the serial correlation coefficients of the ISIs for lags 1..max_lag.

    :param spiketimes: sorted spike times
    :param max_lag: largest lag (in ISIs) to compute
    :return: array of length max_lag; entry k-1 is the correlation at lag k
    :raises ValueError: if there are fewer than max_lag + 1 spike times
    """
    isi = np.diff(spiketimes)
    if len(spiketimes) < max_lag + 1:
        raise ValueError("Given list to short, with given max_lag")

    # NOTE(review): with fewer than max_lag + 3 spikes the highest lags have fewer than two
    # ISI pairs and np.corrcoef yields nan for them.
    cor = np.zeros(max_lag)
    for lag in range(1, max_lag + 1):
        cor[lag - 1] = np.corrcoef(isi[:-lag], isi[lag:])[0][1]
    return cor


def calculate_eod_frequency(time, eod):
    """
    Estimate the EOD frequency as 1 / mean interval between upward zero crossings.

    TODO: for few samples this is a very volatile measure!

    :param time: time array of the recording
    :param eod: EOD signal sampled at `time`
    :return: frequency in 1/(unit of time), or 0 if fewer than two upward crossings exist
    """
    up_indices, down_indices = threshold_crossings(eod, 0)
    up_times, down_times = threshold_crossing_times(time, eod, 0, up_indices, down_indices)
    # At least two upward crossings are needed to measure one period (np.diff of a single
    # crossing is empty and its mean would be nan).
    if len(up_times) < 2:
        return 0
    return 1 / np.mean(np.diff(up_times))


def calculate_vector_strength(times, eods, v1_traces):
    """
    Compute the vector strength of spikes relative to the EOD cycle over all recordings.

    Spikes are detected in each v1 trace, located within their surrounding EOD period, and
    the pooled phases are combined into one vector strength. The vector strength of every
    single recording is printed.

    TODO (from original notes): use the EOD frequency from the header (metadata); expected VS > 0.8.

    :param times: list of time arrays, one per recording
    :param eods: list of EOD traces parallel to times
    :param v1_traces: list of voltage traces parallel to times
    :return: vector strength, or -1 if no spikes were found
    """
    relative_spike_times = []
    eod_durations = []

    if len(times) == 0:
        print("-----LENGTH OF TIMES = 0")

    for recording in range(len(times)):
        spiketime_idices = detect_spikes(v1_traces[recording])
        rel_spikes, eod_durs = eods_around_spikes(times[recording], eods[recording], spiketime_idices)
        relative_spike_times.extend(rel_spikes)
        eod_durations.extend(eod_durs)
        print(__vector_strength__(np.array(rel_spikes), np.array(eod_durs)))

    return __vector_strength__(np.array(relative_spike_times), np.array(eod_durations))


def detect_spikes(v1, split=20, threshold=3):
    """
    Detect spike peaks in a voltage trace using a chunk-wise adaptive threshold.

    The trace is cut into `split` chunks. Peaks in each chunk are detected with a threshold of
    `threshold` times that chunk's standard deviation. The last chunk extends to the end of
    the trace, so samples left over by the integer division are searched too.

    :param v1: voltage trace
    :param split: number of chunks
    :param threshold: threshold in multiples of the chunk standard deviation
    :return: np.array of peak indices into v1
    """
    total = len(v1)
    length = int(total / split)
    all_peaks = []
    for n in range(split):
        first_index = n * length
        last_index = total if n == split - 1 else (n + 1) * length
        chunk = v1[first_index:last_index]
        peaks, _ = detect_peaks(chunk, np.std(chunk) * threshold)
        all_peaks.extend(np.asarray(peaks) + first_index)
    return np.array(all_peaks)


def calculate_phases(relative_spike_times, eod_durations):
    """Convert spike times relative to their EOD period start into phases in radians (0..2*pi)."""
    relative_spike_times = np.asarray(relative_spike_times)
    eod_durations = np.asarray(eod_durations)
    return (relative_spike_times / eod_durations * 2 * np.pi).astype(float)


def eods_around_spikes(time, eod, spiketime_idices):
    """
    For every spike, find the surrounding EOD period (between upward zero crossings).

    :param time: time array of the recording
    :param eod: EOD trace sampled at `time`
    :param spiketime_idices: spike indices into time/eod
    :return: (spike times relative to their period start, durations of those periods)
    """
    eod_durations = []
    relative_spike_times = []
    for spike_idx in spiketime_idices:
        start_time, end_time = search_eod_start_and_end_times(time, eod, spike_idx)
        eod_durations.append(end_time - start_time)
        relative_spike_times.append(time[spike_idx] - start_time)
    return relative_spike_times, eod_durations


def _interpolate_zero_crossing(time, eod, lower_idx, upper_idx):
    """Linearly interpolate the time at which eod crosses 0 between two adjacent samples."""
    first_value = eod[lower_idx]
    second_value = eod[upper_idx]
    part = np.abs(first_value / (second_value - first_value))
    time_dif = np.abs(time[upper_idx] - time[lower_idx])
    return time[lower_idx] + time_dif * part


def search_eod_start_and_end_times(time, eod, index):
    """
    Find the upward zero crossings of the EOD directly before and after `index`.

    Both crossing times are linearly interpolated between the neighbouring samples.

    TODO: a spike in the cut-off first or last EOD period raises IndexError (no crossing).

    :param time: time array of the recording
    :param eod: EOD trace sampled at `time`
    :param index: sample index (e.g. of a spike)
    :return: (start_time, end_time) of the EOD period containing `index`
    :raises IndexError: if no upward crossing exists before or after `index`
    """
    # Search start_time: walk backwards until eod[lower] < 0 < eod[lower + 1].
    lower = index - 1
    while lower >= 0 and not (eod[lower] < 0 < eod[lower + 1]):
        lower -= 1
    if lower < 0:
        # Stop at the array start instead of wrapping around via negative indices.
        raise IndexError("No upward zero crossing of the EOD before index {}".format(index))
    start_time = _interpolate_zero_crossing(time, eod, lower, lower + 1)

    # Search end_time: walk forwards until eod[upper - 1] < 0 < eod[upper].
    upper = index + 1
    while upper < len(eod) and not (eod[upper - 1] < 0 < eod[upper]):
        upper += 1
    if upper >= len(eod):
        raise IndexError("No upward zero crossing of the EOD after index {}".format(index))
    # Interpolate from the earlier sample (upper - 1); the old code started from the later one
    # and overshot by one sample interval.
    end_time = _interpolate_zero_crossing(time, eod, upper - 1, upper)

    return start_time, end_time


def __vector_strength__(relative_spike_times: np.ndarray, eod_durations: np.ndarray):
    """
    Vector strength of spike phases (adapted from Ramona).

    :param relative_spike_times: spike times relative to the start of their EOD period
    :param eod_durations: durations of the corresponding EOD periods
    :return: vector strength in [0, 1], or -1 if no spikes are given
    """
    n = len(relative_spike_times)
    if n == 0:
        return -1

    phase_times = (relative_spike_times / eod_durations) * 2 * np.pi
    vs = np.sqrt((1 / n * np.sum(np.cos(phase_times))) ** 2 + (1 / n * np.sum(np.sin(phase_times))) ** 2)
    return vs