import os
from warnings import warn

import numpy as np
import matplotlib.pyplot as plt
import scipy.stats
import pyrelacs.DataLoader as dl
from numba import jit
import numba


def merge_similar_intensities(intensities, spiketimes, trans_amplitudes):
    """Merge entries whose intensities lie closer together than a margin
    derived from the mean spacing of the sorted intensities."""
    diffs = np.diff(sorted(intensities))
    # Two intensities closer than ~2/3 of the mean spacing count as the same.
    margin = np.mean(diffs) * 0.6666

    i = 0
    while i < len(intensities):
        intensities, spiketimes, trans_amplitudes = merge_intensities_similar_to_index(
            intensities, spiketimes, trans_amplitudes, i, margin)
        i += 1

    # Sort all three lists together so that intensities are increasing and the
    # trans_amplitudes stay aligned with their intensities.
    x = [list(x) for x in zip(*sorted(zip(intensities, spiketimes, trans_amplitudes),
                                      key=lambda pair: pair[0]))]
    intensities = x[0]
    spiketimes = x[1]
    trans_amplitudes = x[2]

    return intensities, spiketimes, trans_amplitudes
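

# Hedged usage sketch (not part of the original module; toy values): the first
# two intensities fall within the merge margin, so their spike trains are
# pooled while 2.0 stays separate.
#
#     intensities = [1.0, 1.01, 2.0]
#     spiketimes = [[0.1, 0.2], [0.15, 0.25], [0.1, 0.3]]
#     trans_amplitudes = [0.5, 0.5, 0.8]
#     intensities, spiketimes, trans_amplitudes = merge_similar_intensities(
#         intensities, spiketimes, trans_amplitudes)
#     # intensities -> [1.0, 2.0]; spiketimes[0] -> [0.1, 0.2, 0.15, 0.25]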


def merge_intensities_similar_to_index(intensities, spiketimes, trans_amplitudes, index, margin):
    intensity = intensities[index]

    # Collect all later entries whose intensity lies within the margin.
    indices_to_merge = []
    for i in range(index + 1, len(intensities)):
        if np.abs(intensities[i] - intensity) < margin:
            indices_to_merge.append(i)

    if len(indices_to_merge) != 0:
        # Delete from the back so the remaining indices stay valid.
        indices_to_merge.reverse()

        # All merged entries must share the trans_amplitude of the kept entry;
        # otherwise the surviving value would misrepresent the merged trials.
        all_the_same = all(trans_amplitudes[k] == trans_amplitudes[index]
                           for k in indices_to_merge)

        if all_the_same:
            for idx in indices_to_merge:
                del trans_amplitudes[idx]
        else:
            raise RuntimeError("Cannot merge: trans_amplitudes of similar intensities differ.")

        for idx in indices_to_merge:
            spiketimes[index].extend(spiketimes[idx])
            del spiketimes[idx]
            del intensities[idx]

    return intensities, spiketimes, trans_amplitudes


def all_calculate_mean_isi_frequencies(spiketimes, time_start, sampling_interval):
    """Compute a trial-averaged ISI frequency trace for every group of trials
    in the nested list spiketimes ([group][trial][spike])."""
    times = []
    mean_frequencies = []

    for i in range(len(spiketimes)):
        trial_times = []
        trial_means = []
        for j in range(len(spiketimes[i])):
            time, isi_freq = calculate_isi_frequency(spiketimes[i][j], time_start, sampling_interval)
            trial_means.append(isi_freq)
            trial_times.append(time)

        time, mean_freq = calculate_mean_frequency(trial_times, trial_means)
        times.append(time)
        mean_frequencies.append(mean_freq)

    return times, mean_frequencies
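

# Hedged usage sketch (toy values): spiketimes is nested as
# [group][trial][spike], so two trials in one group are reduced to a single
# averaged frequency trace.
#
#     spikes = [[np.array([0.1, 0.2, 0.4]), np.array([0.1, 0.3, 0.4])]]
#     times, freqs = all_calculate_mean_isi_frequencies(spikes, 0, 0.1)
#     # times[0] and freqs[0] hold the trial-averaged trace of that group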


# TODO remove additional time vector calculation!
def calculate_isi_frequency(spiketimes, time_start, sampling_interval):
    # The instantaneous rate 1/isi is held for the duration of each interval,
    # sampled on a regular grid starting at time_start.
    first_isi = spiketimes[0] - time_start
    isis = [first_isi]
    isis.extend(np.diff(spiketimes))
    time = np.arange(time_start, spiketimes[-1], sampling_interval)

    full_frequency = []
    for isi in isis:
        if isi == 0:
            warn("An ISI was zero in calculate_isi_frequency()")
            continue
        freq = 1 / isi
        frequency_step = int(round(isi / sampling_interval)) * [freq]
        full_frequency.extend(frequency_step)

    if len(full_frequency) != len(time):
        if abs(len(full_frequency) - len(time)) == 1:
            warn("calculate_isi_frequency():\nFrequency and time were off by one in length!")
            if len(full_frequency) < len(time):
                time = time[:len(full_frequency)]
            else:
                full_frequency = full_frequency[:len(time)]
        else:
            raise RuntimeError("calculate_isi_frequency():\n"
                               "Frequency and time are not the same length "
                               "(freq: {}, time: {}, diff: {})!".format(
                                   len(full_frequency), len(time),
                                   len(full_frequency) - len(time)))

    return time, full_frequency
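

# Hedged worked example (illustrative values): spikes at 0.1 s, 0.2 s and
# 0.4 s with time_start=0 give ISIs of 0.1, 0.1 and 0.2 s, i.e. instantaneous
# rates of 10, 10 and 5 Hz, each held for the length of its interval:
#
#     time, freq = calculate_isi_frequency(np.array([0.1, 0.2, 0.4]), 0, 0.1)
#     # time -> [0.0, 0.1, 0.2, 0.3]; freq -> [10.0, 10.0, 5.0, 5.0]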


def calculate_mean_frequency(trial_times, trial_freqs):
    # Trim all trials to the shortest one so they can be averaged pointwise.
    lengths = [len(t) for t in trial_times]
    shortest = min(lengths)

    time = trial_times[0][0:shortest]
    shortened_freqs = [freq[0:shortest] for freq in trial_freqs]
    mean_freq = [sum(e) / len(e) for e in zip(*shortened_freqs)]

    return time, mean_freq
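

# Hedged usage sketch (toy values): two trials of unequal length are trimmed
# to the shorter one and averaged sample by sample.
#
#     time, mean = calculate_mean_frequency([[0.0, 0.1, 0.2], [0.0, 0.1]],
#                                           [[10.0, 10.0, 5.0], [20.0, 10.0]])
#     # time -> [0.0, 0.1]; mean -> [15.0, 10.0]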


# @jit(nopython=True)  # only faster at around 30 000 calls
def calculate_coefficient_of_variation(spiketimes: np.ndarray) -> float:
    # CV: standard deviation of the ISIs divided by the mean ISI.
    isi = np.diff(spiketimes)
    std = np.std(isi)
    mean = np.mean(isi)

    return std / mean
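

# Hedged sketch (illustrative): a perfectly regular spike train has CV = 0,
# while a Poisson process (exponential ISIs) has CV close to 1.
#
#     regular = np.arange(0, 1, 0.01)
#     calculate_coefficient_of_variation(regular)   # -> ~0.0
#     poisson = np.cumsum(np.random.exponential(0.01, 1000))
#     calculate_coefficient_of_variation(poisson)   # -> ~1.0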


# @jit(nopython=True)  # maybe faster with more than ~60 000 calls
def calculate_serial_correlation(spiketimes: np.ndarray, max_lag: int) -> np.ndarray:
    isi = np.diff(spiketimes)
    # Need at least two ISI pairs at the largest lag for a correlation.
    if len(isi) < max_lag + 2:
        raise ValueError("Spike time list too short for the given max_lag.")

    cor = np.zeros(max_lag)
    for lag in range(1, max_lag + 1):
        # Correlate each ISI with the ISI lag steps later.
        first = isi[:-lag]
        second = isi[lag:]
        cor[lag - 1] = np.corrcoef(first, second)[0][1]

    return cor
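

# Hedged sketch (illustrative): for a renewal process the ISIs are
# independent, so the serial correlations at all lags scatter around zero.
#
#     spikes = np.cumsum(np.random.exponential(0.01, 1000))
#     calculate_serial_correlation(spikes, 3)   # -> three values near 0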


def __vector_strength__(relative_spike_times: np.ndarray, eod_durations: np.ndarray):
    # Adapted from Ramona.

    n = len(relative_spike_times)
    if n == 0:
        return 0

    # Map each spike time to a phase within its EOD cycle and take the length
    # of the resulting mean vector.
    phase_times = (relative_spike_times / eod_durations) * 2 * np.pi
    vs = np.sqrt((1 / n * np.sum(np.cos(phase_times))) ** 2
                 + (1 / n * np.sum(np.sin(phase_times))) ** 2)

    return vs
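

# Hedged sketch (illustrative): vector strength is 1 for perfect phase locking
# and falls toward 0 for spikes spread uniformly over the cycle.
#
#     durations = np.full(100, 0.002)             # assume a 2 ms EOD period
#     locked = np.full(100, 0.0005)               # always the same phase
#     __vector_strength__(locked, durations)      # -> 1.0
#     uniform = np.random.uniform(0, 0.002, 100)
#     __vector_strength__(uniform, durations)     # -> close to 0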