454 lines
16 KiB
Python
454 lines
16 KiB
Python
import numpy as np
|
|
from warnings import warn
|
|
from thunderfish.eventdetection import detect_peaks, threshold_crossing_times, threshold_crossings
|
|
|
|
|
|
def merge_similar_intensities(intensities, spiketimes, trans_amplitudes):
    """
    Merge stimulus intensities that lie closer together than a data-derived margin
    and return the three parallel lists sorted by increasing intensity.

    The margin is 0.6666 times the mean gap between consecutive (sorted) intensities.
    The actual merging is delegated to merge_intensities_similar_to_index(), which
    deletes merged entries from the given lists in place and appends their spikes to
    the entry that is kept.

    :param intensities: list of stimulus intensities (modified in place by the merge)
    :param spiketimes: list index-aligned with intensities holding the spike data per intensity
    :param trans_amplitudes: list index-aligned with intensities holding the transfer amplitudes
    :return: (intensities, spiketimes, trans_amplitudes) as new lists sorted by intensity;
             all three stay index-aligned
    """
    if len(intensities) > 1:
        # With fewer than two values there is no gap to derive a margin from
        # (np.mean of an empty diff only yields nan plus a RuntimeWarning).
        margin = np.mean(np.diff(sorted(intensities))) * 0.6666

        i = 0
        # The lists shrink while merging, so the length is re-checked on every pass.
        while i < len(intensities):
            intensities, spiketimes, trans_amplitudes = merge_intensities_similar_to_index(
                intensities, spiketimes, trans_amplitudes, i, margin)
            i += 1

    # Sort all three parallel lists with the same (stable) permutation so that
    # trans_amplitudes stays aligned with its intensity and spike times.
    order = sorted(range(len(intensities)), key=lambda k: intensities[k])
    intensities = [intensities[k] for k in order]
    spiketimes = [spiketimes[k] for k in order]
    trans_amplitudes = [trans_amplitudes[k] for k in order]

    return intensities, spiketimes, trans_amplitudes
|
|
|
|
|
|
def merge_intensities_similar_to_index(intensities, spiketimes, trans_amplitudes, index, margin):
    """
    Merge every entry after `index` whose intensity differs from intensities[index]
    by less than `margin` into the entry at `index`.

    The spikes of a merged entry are appended to spiketimes[index]. The merged entries
    are then deleted from all three lists, in place. Merging is only allowed when
    every merged entry has the same transfer amplitude as the entry it is merged into.

    :param intensities: list of intensities (modified in place)
    :param spiketimes: list of spike-time lists, index-aligned with intensities (modified in place)
    :param trans_amplitudes: list of transfer amplitudes, index-aligned with intensities (modified in place)
    :param index: position of the entry the similar ones are merged into
    :param margin: maximal (exclusive) absolute intensity difference for merging
    :return: (intensities, spiketimes, trans_amplitudes) - the same list objects
    :raises RuntimeError: if a merge candidate has a different transfer amplitude
    """
    intensity = intensities[index]

    indices_to_merge = [i for i in range(index + 1, len(intensities))
                        if np.abs(intensities[i] - intensity) < margin]

    if not indices_to_merge:
        return intensities, spiketimes, trans_amplitudes

    # Compare against the kept entry as well, not just among the candidates.
    # Otherwise a single candidate with a different amplitude would be merged silently.
    reference_amplitude = trans_amplitudes[index]
    if any(trans_amplitudes[k] != reference_amplitude for k in indices_to_merge):
        raise RuntimeError("Trans_amplitudes not the same....")

    # Delete from the back so that the remaining indices stay valid.
    for idx in reversed(indices_to_merge):
        spiketimes[index].extend(spiketimes[idx])
        del trans_amplitudes[idx]
        del spiketimes[idx]
        del intensities[idx]

    return intensities, spiketimes, trans_amplitudes
|
|
|
|
|
|
def all_calculate_mean_isi_frequency_traces(spiketimes, sampling_interval, stimulus_start=0, time_in_ms=False):
    """
    Expects spiketimes to be a 3dim list with the first dimension being the trial,
    the second the runs of spikes and the last the individual spike times:
    [[[trial1-run1-spike1, trial1-run1-spike2, ...],[trial1-run2-spike1, ...]],[[trial2-run1-spike1, ...], [..]]]

    For every trial the ISI frequency trace of each run is computed, and the traces are
    averaged over their common time range. Some runs are skipped:
    - runs with fewer than two spikes (no ISI, so no trace),
    - runs whose frequency trace starts after stimulus_start.
    A trial without any usable run yields an empty time array and an empty mean list,
    so the outputs stay index-aligned with the trials.

    :param spiketimes: time points of action potentials (trial x run x spike)
    :param sampling_interval: the sampling interval used / will also be used for the frequency trace
    :param stimulus_start: the time point at which the actual stimulus starts
    :param time_in_ms: whether the time is in ms or seconds
    :return: (times, mean_frequencies) - per trial the time trace and the mean frequency trace
    """
    times = []
    mean_frequencies = []

    for trial in spiketimes:
        trial_time_trace = []
        trial_freq_trace = []

        for run in trial:
            # Fewer than two spikes give no ISI and thus an empty trace; time[0] would fail.
            if len(run) < 2:
                continue

            time, isi_freq = calculate_time_and_frequency_trace(run, sampling_interval, time_in_ms)
            if len(time) == 0:
                continue

            if time[0] > stimulus_start:
                print("Trial not used as its frequency trace started after the stimulus start!")
                continue

            trial_freq_trace.append(isi_freq)
            trial_time_trace.append(time)

        if not trial_time_trace:
            # No usable run: keep the trial slot so results stay aligned with the input.
            times.append(np.array([]))
            mean_frequencies.append([])
            continue

        time, mean_freq = calculate_mean_of_frequency_traces(trial_time_trace, trial_freq_trace, sampling_interval)
        times.append(time)
        mean_frequencies.append(mean_freq)

    return times, mean_frequencies
|
|
|
|
|
|
def calculate_isi_frequency_trace(spiketimes, sampling_interval, time_in_ms=False):
    """
    Calculates the frequency over time according to the inter spike intervals.

    Every ISI contributes round(isi / sampling_interval) samples of value 1 / isi (in Hz),
    so the trace starts at the time of the first spike and ends at the last one.

    :param spiketimes: sorted time points spikes were measured array_like
    :param sampling_interval: the sampling interval in which the frequency should be given back
    :param time_in_ms: whether the time is in ms or in s for BOTH the spiketimes and the sampling interval
    :return: np.array with the isi frequency; empty if fewer than two spikes (or only zero ISIs) are given
    :raises ValueError: if the spiketimes are not sorted, or if the sampling interval is
        bigger than the smallest non-zero ISI
    """
    if len(spiketimes) <= 1:
        return np.array([])

    isis = np.diff(spiketimes)

    # Validate ordering first. Otherwise a negative ISI trips the sampling-interval
    # check below with a misleading message.
    if np.any(isis < 0):
        raise ValueError("There was a negative interspike interval, the spiketimes need to be sorted")

    # Zero ISIs (duplicate spike times) have no defined frequency and span zero samples.
    # Drop them before the sampling-interval check, which they would otherwise always fail.
    if np.any(isis == 0):
        warn("An ISI was zero in FiCurve:__calculate_mean_isi_frequency__()")
        print("ISI was zero:", spiketimes)
        isis = isis[isis > 0]
        if len(isis) == 0:
            return np.array([])

    if sampling_interval > round(np.min(isis), 7):
        raise ValueError("The sampling interval is bigger than the some isis! cannot accurately compute the trace.\n"
                         "Sampling interval {:.5f}, smallest isi: {:.5f}".format(sampling_interval, np.min(isis)))

    if time_in_ms:
        isis = isis / 1000
        sampling_interval = sampling_interval / 1000

    # Samples spanned by each ISI (np.round rounds half to even, like Python's round()).
    samples_per_isi = np.round(isis * (1 / sampling_interval)).astype(int)

    # One np.repeat call instead of concatenating inside a loop (which was O(n^2)).
    return np.repeat(1 / isis, samples_per_isi)
|
|
|
|
|
|
def calculate_time_and_frequency_trace(spiketimes, sampling_interval, time_in_ms=False):
    """
    Compute the ISI frequency trace of one spike train together with a matching time axis.

    The time axis runs from the first spike in steps of sampling_interval, in the same
    unit as spiketimes.

    :param spiketimes: sorted spike times
    :param sampling_interval: sampling interval of the returned traces (same unit as spiketimes)
    :param time_in_ms: whether spiketimes and sampling_interval are given in ms
    :return: (time, frequency) of equal length; both empty if fewer than two spikes are given
    """
    if len(spiketimes) < 2:
        # Without an inter-spike interval there is no frequency trace.
        return np.array([]), np.array([])

    frequency = calculate_isi_frequency_trace(spiketimes, sampling_interval, time_in_ms)
    time = np.arange(spiketimes[0], spiketimes[-1], sampling_interval)

    # np.arange and the per-ISI sample rounding of the frequency trace can disagree
    # by a sample in either direction. Trim both to the common length so they stay aligned.
    length = min(len(time), len(frequency))

    return time[:length], frequency[:length]
|
|
|
|
|
|
def calculate_mean_of_frequency_traces(trial_time_traces, trial_frequency_traces, sampling_interval):
    """
    Calculate the mean trace of the given frequency traces (the mean at each time point)
    for traces that start and end at different times.

    Only the range covered by all traces is used: from the latest start to the earliest end.

    :param trial_time_traces: list of time arrays, one per frequency trace
    :param trial_frequency_traces: list of frequency traces, index-aligned with the time traces
    :param sampling_interval: sampling interval shared by all traces
    :return: (time, mean_freq) of equal length; mean_freq is a list.
             Both are empty if no traces are given.
    """
    if len(trial_time_traces) == 0:
        return np.array([]), []

    latest_start = max(t[0] for t in trial_time_traces)
    earliest_end = min(t[-1] for t in trial_time_traces)

    shortened_time = np.arange(latest_start, earliest_end, sampling_interval)

    shortened_freqs = []
    for time_trace, freq_trace in zip(trial_time_traces, trial_frequency_traces):
        start_idx = int(round((latest_start - time_trace[0]) / sampling_interval))
        end_idx = int(round((earliest_end - time_trace[0]) / sampling_interval))
        shortened_freqs.append(freq_trace[start_idx:end_idx])

    # zip() truncates to the shortest slice, so the mean can be shorter than the arange.
    mean_freq = [sum(e) / len(e) for e in zip(*shortened_freqs)]

    # Keep the time axis and the mean trace the same length.
    length = min(len(shortened_time), len(mean_freq))

    return shortened_time[:length], mean_freq[:length]
|
|
|
|
|
|
def mean_freq_of_spiketimes_after_time_x(spiketimes, time_x, time_in_ms=False):
    """
    Calculate the time-weighted mean ISI frequency of the spikes after time_x.

    Weighting each ISI frequency 1/isi by its duration isi reduces the weighted mean to
    (number of intervals) / (time between the first and the last relevant spike).
    Zero-length intervals (duplicate spike times) span no time and are not counted.
    This avoids the division by zero the explicit weighting would hit.

    :param spiketimes: sorted spike times
    :param time_x: only spikes strictly later than time_x are used (same unit as spiketimes)
    :param time_in_ms: whether the spike times are in ms; the result is in 1/s either way
    :return: the mean frequency, or 0 if fewer than two relevant spikes or zero duration
    """
    spiketimes = np.asarray(spiketimes)
    if len(spiketimes) <= 1:
        return 0

    relevant_spikes = spiketimes[spiketimes > time_x]
    if len(relevant_spikes) <= 1:
        return 0

    if time_in_ms:
        relevant_spikes = relevant_spikes / 1000

    isis = np.diff(relevant_spikes)
    duration = relevant_spikes[-1] - relevant_spikes[0]
    if duration <= 0:
        return 0

    return np.count_nonzero(isis > 0) / duration
|
|
|
|
|
|
def calculate_mean_isi_freq(spiketimes, time_in_ms=False):
    """
    Calculate the time-weighted mean ISI frequency of a spike train.

    Weighting each ISI frequency 1/isi by its duration isi reduces the weighted mean to
    (number of non-zero intervals) / (total duration of the intervals). Zero ISIs
    (duplicate spike times) therefore no longer cause a division by zero.

    :param spiketimes: sorted spike times
    :param time_in_ms: whether the spike times are in ms; the result is in 1/s either way
    :return: the mean frequency, or 0 if fewer than two spikes or zero total duration
    """
    if len(spiketimes) < 2:
        return 0

    isis = np.diff(spiketimes)
    if time_in_ms:
        isis = isis / 1000

    duration = np.sum(isis)
    if duration <= 0:
        return 0

    return np.count_nonzero(isis > 0) / duration
|
|
|
|
|
|
|
|
|
|
# @jit(nopython=True) # only faster at around 30 000 calls
def calculate_coefficient_of_variation(spiketimes: np.ndarray) -> float:
    """
    Return the coefficient of variation (CV) of the inter-spike intervals:
    the standard deviation of the ISIs divided by their mean.
    """
    intervals = np.diff(spiketimes)
    return np.std(intervals) / np.mean(intervals)
|
|
|
|
|
|
# @jit(nopython=True) # maybe faster with more than ~60 000 calls
def calculate_serial_correlation(spiketimes: np.ndarray, max_lag: int) -> np.ndarray:
    """
    Return the serial correlation coefficients of the inter-spike intervals for the
    lags 1..max_lag. Entry k-1 of the result is the Pearson correlation between the
    ISI sequence and itself shifted by k.

    :raises ValueError: if there are fewer than max_lag + 1 spike times
    """
    intervals = np.diff(spiketimes)
    if len(spiketimes) < max_lag + 1:
        raise ValueError("Given list to short, with given max_lag")

    correlations = np.zeros(max_lag)
    for shift in range(1, max_lag + 1):
        leading = intervals[:-shift]
        trailing = intervals[shift:]
        correlations[shift - 1] = np.corrcoef(leading, trailing)[0, 1]

    return correlations
|
|
|
|
|
|
def calculate_eod_frequency(time, eod):
    """
    Estimate the EOD frequency as 1 / (mean interval between upward zero crossings).

    Crossing times come from thunderfish's threshold_crossings /
    threshold_crossing_times with threshold 0.

    :param time: time values of the eod samples
    :param eod: EOD trace
    :return: the frequency (inverse unit of `time`), or 0 if fewer than two upward crossings are found
    """
    # TODO for few samples very volatile measure!
    up_indicies, down_indicies = threshold_crossings(eod, 0)
    up_times, down_times = threshold_crossing_times(time, eod, 0, up_indicies, down_indicies)

    # Measuring one period needs at least two upward crossings. With a single crossing
    # np.diff is empty and np.mean would return nan (plus a RuntimeWarning).
    if len(up_times) < 2:
        return 0

    mean_duration = np.mean(np.diff(up_times))

    return 1 / mean_duration
|
|
|
|
|
|
def calculate_vector_strength_from_v1_trace(times, eods, v1_traces):
    """
    Compute the vector strength of the spikes in several recordings relative to their EOD.

    For each recording, spikes are detected in the voltage trace with detect_spikes().
    They are related to the surrounding EOD cycle via eods_around_spikes(). The pooled
    relative spike times and cycle durations of all recordings go into __vector_strength__().

    :param times: per recording, the time values of the samples
    :param eods: per recording, the EOD trace
    :param v1_traces: per recording, the voltage trace spikes are detected in
    :return: the vector strength of all pooled spikes (-1 if there are none)
    """
    # Vector strength (use EOD frequency from header (metadata)) VS > 0.8
    # dl.iload_traces(repro='BaselineActivity')
    if len(times) == 0:
        print("-----LENGTH OF TIMES = 0")

    pooled_relative_spikes = []
    pooled_eod_durations = []

    for rec_idx, recording_time in enumerate(times):
        spike_indices = detect_spikes(v1_traces[rec_idx])
        rel_spikes, eod_durs = eods_around_spikes(recording_time, eods[rec_idx], spike_indices)
        pooled_relative_spikes.extend(rel_spikes)
        pooled_eod_durations.extend(eod_durs)

    return __vector_strength__(np.array(pooled_relative_spikes), np.array(pooled_eod_durations))
|
|
|
|
|
|
def calculate_vector_strength_from_spiketimes(time, eod, spiketimes, sampling_interval):
    """
    Compute the vector strength of spikes given as times, relative to the EOD trace.

    Spike times are converted to sample indices of `time` via (t - time[0]) / sampling_interval.
    Spikes outside the recorded trace are dropped.

    :param time: time values of the eod samples (evenly spaced by sampling_interval)
    :param eod: EOD trace
    :param spiketimes: spike times in the same time frame and unit as `time`
    :param sampling_interval: sampling interval of `time`
    :return: the vector strength (-1 if no spike falls into a complete EOD cycle)
    """
    time = np.asarray(time)

    # Index of time t in an evenly sampled axis starting at time[0] is (t - time[0]) / dt.
    # The previous "+ time[0]" was only correct when the trace started at 0.
    spiketime_indices = np.around((np.asarray(spiketimes) - time[0]) / sampling_interval).astype(int)

    # Negative indices would silently wrap around; too large ones would raise later.
    in_range = (spiketime_indices >= 0) & (spiketime_indices < len(time))

    rel_spikes, eod_durs = eods_around_spikes(time, eod, spiketime_indices[in_range])

    return __vector_strength__(rel_spikes, eod_durs)
|
|
|
|
|
|
def detect_spikes(v1, split=20, threshold=3):
    """
    Detect spike peaks in a voltage trace using a piecewise adaptive threshold.

    The trace is cut into `split` consecutive chunks. Within each chunk, peaks are
    detected with detect_peaks() using threshold * (standard deviation of that chunk).
    The last chunk also takes the remainder samples when len(v1) is not divisible
    by `split`, so the end of the trace is searched too.

    :param v1: voltage trace
    :param split: number of chunks the threshold is computed for
    :param threshold: threshold in multiples of the chunk's standard deviation
    :return: np.ndarray (dtype int) of the sample indices of the detected peaks
    """
    total = len(v1)
    length = int(total / split)
    all_peaks = []

    for n in range(split):
        first_index = n * length
        # The last chunk runs to the end so that remainder samples are not skipped.
        last_index = total if n == split - 1 else (n + 1) * length
        chunk = v1[first_index:last_index]
        if len(chunk) == 0:
            continue

        std = np.std(chunk)
        if std <= 0:
            # A flat chunk contains no peaks; a zero threshold is not meaningful.
            continue

        peaks, _ = detect_peaks(chunk, std * threshold)
        all_peaks.extend(peaks + first_index)

    # An explicit int dtype keeps an empty result usable as an index array.
    return np.array(all_peaks, dtype=int)
|
|
|
|
|
|
def calculate_phases(relative_spike_times, eod_durations):
    """
    Convert spike times relative to their EOD cycle start into phases in radians:
    2 * pi * relative_spike_time / eod_duration, computed element by element.

    :return: float np.ndarray with one phase per relative spike time
    """
    return np.array(
        [(rel_time / eod_durations[k]) * 2 * np.pi for k, rel_time in enumerate(relative_spike_times)],
        dtype=float,
    )
|
|
|
|
|
|
def eods_around_spikes(time, eod, spiketime_idices):
    """
    For each spike, find the EOD cycle containing it. Return the spike time relative to
    the cycle start together with the cycle duration.

    Cycles are delimited by upward zero crossings of `eod`. A crossing is the sample
    index i where the sign changes between eod[i] and eod[i + 1] while the trace
    increases. Spikes before the first crossing, or at/after the last crossing, lie in
    an incomplete cycle and are skipped.

    :param time: time values of the eod samples
    :param eod: EOD trace
    :param spiketime_idices: sample indices of the spikes
    :return: (relative_spike_times, eod_durations) as np.arrays of equal length
    """
    time = np.asarray(time)
    eod = np.asarray(eod)
    spike_indices = np.asarray(spiketime_idices, dtype=int)

    sign_changes = np.sign(eod[:-1]) != np.sign(eod[1:])
    eod_trace_increasing = eod[:-1] < eod[1:]
    eod_zero_crossings_indices = np.where(sign_changes & eod_trace_increasing)[0]

    # At least two crossings are needed to delimit one complete cycle.
    if len(eod_zero_crossings_indices) < 2 or len(spike_indices) == 0:
        return np.array([]), np.array([])

    # Position of the first crossing strictly after each spike = end of its cycle.
    end_positions = np.searchsorted(eod_zero_crossings_indices, spike_indices, side="right")

    # Position 0: spike before the first crossing; position len: spike at/after the
    # last crossing. In both cases the cycle is incomplete, so skip the spike.
    inside = (end_positions > 0) & (end_positions < len(eod_zero_crossings_indices))
    spike_indices = spike_indices[inside]
    end_positions = end_positions[inside]

    end_time_idx = eod_zero_crossings_indices[end_positions]
    start_time_idx = eod_zero_crossings_indices[end_positions - 1]

    eod_durations = time[end_time_idx] - time[start_time_idx]
    relative_spike_times = time[spike_indices] - time[start_time_idx]

    return relative_spike_times, eod_durations
|
|
|
|
|
|
def _interpolate_upward_zero_crossing(time, eod, neg_idx, pos_idx):
    """Linearly interpolate the time at which eod crosses zero between samples neg_idx (< 0) and pos_idx (> 0)."""
    first_value = eod[neg_idx]
    second_value = eod[pos_idx]

    part = np.abs(first_value / (second_value - first_value))
    time_dif = np.abs(time[pos_idx] - time[neg_idx])

    # The interpolation starts at the negative (earlier) sample.
    return time[neg_idx] + time_dif * part


def search_eod_start_and_end_times(time, eod, index):
    """
    Find the start and end time of the EOD cycle that contains sample `index`.

    The cycle boundaries are the nearest upward zero crossings (eod goes from < 0 to > 0)
    before and after `index`. Their times are linearly interpolated between the two
    samples that bracket each crossing.

    :param time: time values of the eod samples
    :param eod: EOD trace
    :param index: sample index inside the cycle of interest
    :return: (start_time, end_time)
    :raises IndexError: if there is no upward crossing before or after `index`
        (the index lies in the cut-off first or last EOD cycle)
    """
    # search start_time: walk backwards over the sample pairs (i, i + 1).
    # The range stops at 0 so negative indices cannot wrap around to the trace end.
    for working_idx in range(index - 1, -1, -1):
        previous = working_idx + 1
        if eod[working_idx] < 0 < eod[previous]:
            start_time = _interpolate_upward_zero_crossing(time, eod, working_idx, previous)
            break
    else:
        raise IndexError("No upward zero crossing of the EOD before index {}".format(index))

    # search end_time: walk forwards over the sample pairs (i - 1, i)
    for working_idx in range(index + 1, len(eod)):
        previous = working_idx - 1
        if eod[previous] < 0 < eod[working_idx]:
            end_time = _interpolate_upward_zero_crossing(time, eod, previous, working_idx)
            break
    else:
        raise IndexError("No upward zero crossing of the EOD after index {}".format(index))

    return start_time, end_time
|
|
|
|
|
|
def __vector_strength__(relative_spike_times: np.ndarray, eod_durations: np.ndarray):
    """
    Vector strength of spike phases relative to the EOD (adapted from Ramona).

    Each spike's phase is 2 * pi * relative_spike_time / eod_duration. The vector
    strength is the length of the mean unit vector of these phases, a value between
    0 and 1.

    :return: the vector strength, or -1 if no spikes are given
    """
    count = len(relative_spike_times)
    if count == 0:
        return -1

    phases = (relative_spike_times / eod_durations) * 2 * np.pi
    mean_cos = 1 / count * np.sum(np.cos(phases))
    mean_sin = 1 / count * np.sum(np.sin(phases))

    return np.sqrt(mean_cos ** 2 + mean_sin ** 2)
|
|
|
|
|
|
def detect_f_zero_in_frequency_trace(time, frequency, stimulus_start, sampling_interval, peak_buffer_percent=0.05, buffer=0.025):
    """
    Detect the onset response f_0 in a frequency trace.

    The search window runs from 0.1 * buffer before to buffer after the stimulus start.
    The extreme value (minimum or maximum) of that window that deviates most from the
    pre-stimulus mean is taken as f_0. If that extreme lies within the pre-stimulus
    range (widened by peak_buffer_percent of that range), the mean of the first half
    of the search window is returned instead.

    :param time: time values of the frequency trace (only time[0] is used)
    :param frequency: frequency trace sampled with sampling_interval
    :param stimulus_start: stimulus start in the frame of `time`
    :param sampling_interval: sampling interval of the trace
    :param peak_buffer_percent: fraction of the pre-stimulus range added on both sides
    :param buffer: time margin used around the pre-stimulus and search windows
    :return: the detected f_0
    """
    # time start is generally != 0 and != delay
    onset = stimulus_start - time[0]

    pre_stimulus = frequency[int(buffer / sampling_interval):int((onset - buffer) / sampling_interval)]
    pre_min = min(pre_stimulus)
    pre_max = max(pre_stimulus)
    pre_mean = np.mean(pre_stimulus)

    # window in which f_0 is searched
    window_first = int((onset - 0.1 * buffer) / sampling_interval)
    window_last = int((onset + buffer) / sampling_interval)
    window = frequency[window_first:window_last]
    window_min = min(window)
    window_max = max(window)

    deviation_down = abs(pre_mean - window_min)
    deviation_up = abs(window_max - pre_mean)
    f_zero = window_min if deviation_down > deviation_up else window_max

    tolerance = (pre_max - pre_min) * peak_buffer_percent
    if pre_min - tolerance <= f_zero <= pre_max + tolerance:
        # No distinct onset peak: fall back to the mean of the first half of the window.
        half_end = window_first + int((window_last - window_first) / 2)
        f_zero = np.mean(frequency[window_first:half_end])

    return f_zero
|
|
|
|
|
|
def detect_f_infinity_in_freq_trace(time, frequency, stimulus_start, stimulus_duration, sampling_interval, length=0.1, buffer=0.025):
    """
    Return the steady-state frequency f_inf: the mean of the frequency trace over a
    window of `length` that ends `buffer` before the stimulus end.

    :param time: time values of the frequency trace (only time[0] is used)
    :param frequency: frequency trace sampled with sampling_interval
    :param stimulus_start: stimulus start in the frame of `time`
    :param stimulus_duration: duration of the stimulus
    :param sampling_interval: sampling interval of the trace
    :param length: length of the averaging window
    :param buffer: gap between the window end and the stimulus end
    """
    # stimulus end relative to the first sample of the trace
    stim_end = stimulus_start + stimulus_duration - time[0]

    window_end = int((stim_end - buffer) / sampling_interval)
    window_start = int((stim_end - length - buffer) / sampling_interval)

    return np.mean(frequency[window_start:window_end])
|
|
|
|
|
|
def detect_f_baseline_in_freq_trace(time, frequency, stimulus_start, sampling_interval, buffer=0.025):
    """
    Return the baseline frequency: the mean of the frequency trace from `buffer`
    after the trace start until `buffer` before the stimulus start.

    Warns when the pre-stimulus delay is shorter than 0.1 (in the unit of `time`).

    :param time: time values of the frequency trace (only time[0] is used)
    :param frequency: frequency trace sampled with sampling_interval
    :param stimulus_start: stimulus start in the frame of `time`
    :param sampling_interval: sampling interval of the trace
    :param buffer: margin excluded at both ends of the pre-stimulus window
    """
    delay = stimulus_start - time[0]

    if delay < 0.1:
        warn("FICurve:__calculate_f_baseline__(): Quite short delay at the start.")

    baseline_window = frequency[int(buffer / sampling_interval):int((delay - buffer) / sampling_interval)]

    return np.mean(baseline_window)