P-unit_model/helperFunctions.py
import numpy as np
from warnings import warn
from thunderfish.eventdetection import detect_peaks, threshold_crossing_times, threshold_crossings
def merge_similar_intensities(intensities, spiketimes, trans_amplitudes):
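    """
    Merges neighboring intensities that lie within a margin of each other
    (2/3 of the mean difference between the sorted intensities): their spiketimes
    are pooled and the duplicate entries are removed.
    Returns the merged lists sorted by increasing intensity.
    """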
i = 0
diffs = np.diff(sorted(intensities))
margin = np.mean(diffs) * 0.6666
while True:
if i >= len(intensities):
break
intensities, spiketimes, trans_amplitudes = merge_intensities_similar_to_index(intensities, spiketimes, trans_amplitudes, i, margin)
i += 1
    # Sort all three lists together so that the intensities are increasing
    x = [list(x) for x in zip(*sorted(zip(intensities, spiketimes, trans_amplitudes), key=lambda pair: pair[0]))]
    intensities = x[0]
    spiketimes = x[1]
    trans_amplitudes = x[2]
return intensities, spiketimes, trans_amplitudes
def merge_intensities_similar_to_index(intensities, spiketimes, trans_amplitudes, index, margin):
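    """
    Merges all entries after 'index' whose intensity differs from intensities[index] by less
    than 'margin' into the entry at 'index': their spiketimes are pooled into spiketimes[index]
    and the merged entries are removed. The trans_amplitudes of the merged entries must be equal.
    """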
intensity = intensities[index]
indices_to_merge = []
for i in range(index+1, len(intensities)):
if np.abs(intensities[i]-intensity) < margin:
indices_to_merge.append(i)
if len(indices_to_merge) != 0:
indices_to_merge.reverse()
trans_amplitude_values = [trans_amplitudes[k] for k in indices_to_merge]
all_the_same = True
for j in range(1, len(trans_amplitude_values)):
if not trans_amplitude_values[0] == trans_amplitude_values[j]:
all_the_same = False
break
if all_the_same:
for idx in indices_to_merge:
del trans_amplitudes[idx]
else:
            raise RuntimeError("trans_amplitudes of the intensities to merge are not all the same.")
for idx in indices_to_merge:
spiketimes[index].extend(spiketimes[idx])
del spiketimes[idx]
del intensities[idx]
return intensities, spiketimes, trans_amplitudes
def all_calculate_mean_isi_frequencies(spiketimes, time_start, sampling_interval):
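    """
    Calculates the trial-averaged ISI frequency trace for every group of trials in spiketimes
    (a list of lists of spike time arrays, e.g. one list of trials per stimulus intensity).
    Returns one time axis and one mean frequency trace per group. time_start is currently unused.
    """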
times = []
mean_frequencies = []
for i in range(len(spiketimes)):
trial_times = []
trial_means = []
for j in range(len(spiketimes[i])):
            # calculate_isi_frequency() only returns the frequency trace, so the matching
            # time axis is built here; it starts at the first spike of the trial.
            isi_freq = calculate_isi_frequency(spiketimes[i][j], sampling_interval)
            time = spiketimes[i][j][0] + np.arange(len(isi_freq)) * sampling_interval
            trial_means.append(isi_freq)
            trial_times.append(time)
time, mean_freq = calculate_mean_frequency(trial_times, trial_means)
times.append(time)
mean_frequencies.append(mean_freq)
return times, mean_frequencies
def calculate_isi_frequency(spiketimes, sampling_interval, time_in_ms=True):
"""
Calculates the frequency over time according to the inter spike intervals.
:param spiketimes: time points spikes were measured array_like
:param sampling_interval: the sampling interval in which the frequency should be given back
:param time_in_ms: whether the time is in ms or in s for BOTH the spiketimes and the sampling interval
:return: an np.array with the isi frequency starting at the time of first spike and ending at the time of the last spike
"""
isis = np.diff(spiketimes)
if time_in_ms:
isis = isis / 1000
sampling_interval = sampling_interval / 1000
full_frequency = np.array([])
for isi in isis:
if isi == 0:
warn("An ISI was zero in FiCurve:__calculate_mean_isi_frequency__()")
print("ISI was zero:", spiketimes)
continue
freq = 1 / isi
frequency_step = np.full(int(round(isi * (1 / sampling_interval))), freq)
full_frequency = np.concatenate((full_frequency, frequency_step))
return full_frequency
def calculate_mean_frequency(trial_times, trial_freqs):
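    """
    Averages the frequency traces of all trials after truncating them to the length of the
    shortest trial. Returns the (truncated) time axis of the first trial and the mean trace.
    """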
lengths = [len(t) for t in trial_times]
shortest = min(lengths)
time = trial_times[0][0:shortest]
    shortened_freqs = [freq[0:shortest] for freq in trial_freqs]
    mean_freq = [sum(e) / len(e) for e in zip(*shortened_freqs)]
return time, mean_freq
def mean_freq_of_spiketimes_after_time_x(spiketimes, sampling_interval, time_x, time_in_ms=False):
""" Calculates the mean frequency of the portion of spiketimes that is after last_x_time """
if len(spiketimes) <= 1:
return 0
freq = calculate_isi_frequency(spiketimes, sampling_interval, time_in_ms)
    # the returned frequency trace starts at the time of the first spike
idx = int((time_x-spiketimes[0]) / sampling_interval)
mean_freq = np.mean(freq[idx:])
return mean_freq
# @jit(nopython=True) # only faster at around 30 000 calls
def calculate_coefficient_of_variation(spiketimes: np.ndarray) -> float:
    """ Coefficient of variation: standard deviation of the ISIs divided by the mean ISI. """
isi = np.diff(spiketimes)
std = np.std(isi)
mean = np.mean(isi)
return std/mean
# @jit(nopython=True) # maybe faster with more than ~60 000 calls
def calculate_serial_correlation(spiketimes: np.ndarray, max_lag: int) -> np.ndarray:
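    """
    Serial correlation of the inter-spike intervals: for every lag from 1 to max_lag the
    Pearson correlation coefficient between the ISI sequence and itself shifted by that lag.
    """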
isi = np.diff(spiketimes)
    if len(isi) <= max_lag:
        raise ValueError("Given list of spiketimes is too short for the given max_lag")
    cor = np.zeros(max_lag)
    for lag in range(1, max_lag + 1):
        first = isi[:-lag]
        second = isi[lag:]
        cor[lag - 1] = np.corrcoef(first, second)[0][1]
return cor
def calculate_eod_frequency(time, eod):
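    """ Estimates the EOD frequency as the inverse of the mean time between successive upward zero crossings. """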
# TODO for few samples very volatile measure!
    up_indices, down_indices = threshold_crossings(eod, 0)
    up_times, down_times = threshold_crossing_times(time, eod, 0, up_indices, down_indices)
if len(up_times) == 0:
return 0
durations = np.diff(up_times)
mean_duration = np.mean(durations)
return 1/mean_duration
def calculate_vector_strength(times, eods, v1_traces):
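    """
    Detects spikes in every V1 trace, finds the EOD cycle around each spike and
    computes the vector strength of the resulting spike phases over all recordings.
    """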
    # vector strength (use EOD frequency from header (metadata)), VS > 0.8
# dl.iload_traces(repro='BaselineActivity')
relative_spike_times = []
eod_durations = []
if len(times) == 0:
print("-----LENGTH OF TIMES = 0")
for recording in range(len(times)):
        spiketime_indices = detect_spikes(v1_traces[recording])
        rel_spikes, eod_durs = eods_around_spikes(times[recording], eods[recording], spiketime_indices)
relative_spike_times.extend(rel_spikes)
eod_durations.extend(eod_durs)
print(__vector_strength__(np.array(rel_spikes), np.array(eod_durs)))
relative_spike_times = np.array(relative_spike_times)
eod_durations = np.array(eod_durations)
return __vector_strength__(relative_spike_times, eod_durations)
def detect_spikes(v1, split=20, threshold=3):
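    """
    Simple threshold-based spike detection: the trace is cut into 'split' chunks and peaks that
    stand out by more than 'threshold' times the standard deviation of the chunk are detected.
    Returns the peak indices relative to the full trace.
    """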
total = len(v1)
all_peaks = []
    for n in range(split):
        length = int(total / split)
        first_index = n * length
        # let the last chunk run to the end of the trace so the remainder is not skipped
        last_index = total if n == split - 1 else (n + 1) * length
std = np.std(v1[first_index:last_index])
peaks, _ = detect_peaks(v1[first_index:last_index], std * threshold)
peaks = peaks + first_index
all_peaks.extend(peaks)
all_peaks = np.array(all_peaks)
return all_peaks
def calculate_phases(relative_spike_times, eod_durations):
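    """ Converts spike times relative to the start of their EOD cycle into phases in radians (0 to 2*pi). """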
phase_times = np.zeros(len(relative_spike_times))
for i in range(len(relative_spike_times)):
phase_times[i] = (relative_spike_times[i] / eod_durations[i]) * 2 * np.pi
return phase_times
def eods_around_spikes(time, eod, spiketime_indices):
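    """
    For every spike finds the surrounding EOD cycle and returns the spike time relative to the
    start of that cycle together with the duration of the cycle.
    """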
eod_durations = []
relative_spike_times = []
    for spike_idx in spiketime_indices:
start_time, end_time = search_eod_start_and_end_times(time, eod, spike_idx)
eod_durations.append(end_time-start_time)
spiketime = time[spike_idx]
relative_spike_times.append(spiketime - start_time)
return relative_spike_times, eod_durations
def search_eod_start_and_end_times(time, eod, index):
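    """
    Searches the upward zero crossings of the EOD before and after the sample at 'index' and
    returns their times, linearly interpolated between the two samples around each crossing.
    """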
# TODO might break if a spike is in the cut off first or last eod!
# search start_time:
previous = index
working_idx = index-1
while True:
if eod[working_idx] < 0 < eod[previous]:
first_value = eod[working_idx]
second_value = eod[previous]
dif = second_value - first_value
part = np.abs(first_value/dif)
time_dif = np.abs(time[previous] - time[working_idx])
start_time = time[working_idx] + time_dif*part
break
previous = working_idx
working_idx -= 1
# search end_time
previous = index
working_idx = index + 1
while True:
if eod[previous] < 0 < eod[working_idx]:
first_value = eod[previous]
second_value = eod[working_idx]
dif = second_value - first_value
part = np.abs(first_value / dif)
time_dif = np.abs(time[previous] - time[working_idx])
            end_time = time[previous] + time_dif * part
break
previous = working_idx
working_idx += 1
return start_time, end_time
def __vector_strength__(relative_spike_times: np.ndarray, eod_durations: np.ndarray):
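    """
    Vector strength of the spike phases:
    VS = sqrt((1/n * sum(cos(phi)))**2 + (1/n * sum(sin(phi)))**2)
    with phi = 2*pi * relative_spike_time / eod_duration. Returns -1 if no spikes are given.
    """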
# adapted from Ramona
n = len(relative_spike_times)
if n == 0:
return -1
phase_times = (relative_spike_times / eod_durations) * 2 * np.pi
vs = np.sqrt((1 / n * np.sum(np.cos(phase_times))) ** 2 + (1 / n * np.sum(np.sin(phase_times))) ** 2)
return vs