clean up helper functions

a.ott 2020-01-29 10:59:27 +01:00
parent 3ffa6d5dbe
commit 0adb8e98b9
2 changed files with 121 additions and 106 deletions

View File

@@ -4,68 +4,8 @@ import numpy as np
 import matplotlib.pyplot as plt
 from warnings import warn
 import scipy.stats
+from numba import jit
+import numba as numba
-def get_subfolder_paths(basepath):
-    subfolders = []
-    for content in os.listdir(basepath):
-        content_path = basepath + content
-        if os.path.isdir(content_path):
-            subfolders.append(content_path)
-    return sorted(subfolders)
-
-def get_traces(directory, trace_type, repro):
-    # trace_type = 1: Voltage p-unit
-    # trace_type = 2: EOD
-    # trace_type = 3: local EOD ~(EOD + stimulus)
-    # trace_type = 4: Stimulus
-    load_iter = dl.iload_traces(directory, repro=repro)
-    time_traces = []
-    value_traces = []
-    nothing = True
-    for info, key, time, x in load_iter:
-        nothing = False
-        time_traces.append(time)
-        value_traces.append(x[trace_type-1])
-    if nothing:
-        print("iload_traces found nothing for the BaselineActivity repro!")
-    return time_traces, value_traces
-
-def get_all_traces(directory, repro):
-    load_iter = dl.iload_traces(directory, repro=repro)
-    time_traces = []
-    v1_traces = []
-    eod_traces = []
-    local_eod_traces = []
-    stimulus_traces = []
-    nothing = True
-    for info, key, time, x in load_iter:
-        nothing = False
-        time_traces.append(time)
-        v1_traces.append(x[0])
-        eod_traces.append(x[1])
-        local_eod_traces.append(x[2])
-        stimulus_traces.append(x[3])
-        print(info)
-    traces = [v1_traces, eod_traces, local_eod_traces, stimulus_traces]
-    if nothing:
-        print("iload_traces found nothing for the BaselineActivity repro!")
-    return time_traces, traces
 def merge_similar_intensities(intensities, spiketimes, trans_amplitudes):
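
The new numba imports support the jit markers that appear (commented out) in front of the reworked helpers below. As a rough sketch of what enabling one of them might look like, assuming numba is installed (the in-code comments suggest jitting only pays off after tens of thousands of calls):

    from numba import jit

    @jit(nopython=True)  # compile to native code; worthwhile only for very many calls
    def calculate_coefficient_of_variation(spiketimes):
        isi = np.diff(spiketimes)
        return np.std(isi) / np.mean(isi)
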
@@ -183,47 +123,8 @@ def calculate_mean_frequency(trial_times, trial_freqs):
     return time, mean_freq
-def crappy_smoothing(signal:list, window_size:int = 5) -> list:
-    smoothed = []
-    for i in range(len(signal)):
-        k = window_size
-        if i < window_size:
-            k = i
-        j = window_size
-        if i + j > len(signal):
-            j = len(signal) - i
-        smoothed.append(np.mean(signal[i-k:i+j]))
-    return smoothed
-
-def plot_frequency_curve(cell_data, save_path: str = None, indices: list = None):
-    contrast = cell_data.get_fi_contrasts()
-    time_axes = cell_data.get_time_axes_mean_frequencies()
-    mean_freqs = cell_data.get_mean_isi_frequencies()
-    if indices is None:
-        indices = np.arange(len(contrast))
-    for i in indices:
-        plt.plot(time_axes[i], mean_freqs[i], label=str(round(contrast[i], 2)))
-    if save_path is None:
-        plt.show()
-    else:
-        plt.savefig(save_path + "mean_frequency_curves.png")
-        plt.close()
-
-def rectify(x):
-    if x < 0:
-        return 0
-    return x
-
-def calculate_coefficient_of_variation(spiketimes: list) -> float:
+# @jit(nopython=True) # only faster at around 30 000 calls
+def calculate_coefficient_of_variation(spiketimes: np.ndarray) -> float:
     # CV: stddev of the ISIs divided by the mean ISI (ISIs = np.diff(spiketimes))
     isi = np.diff(spiketimes)
     std = np.std(isi)
@@ -232,17 +133,31 @@ def calculate_coefficient_of_variation(spiketimes: list) -> float:
     return std/mean
-def calculate_serial_correlation(spiketimes: list, max_lag: int) -> list:
+# @jit(nopython=True) # maybe faster with more than ~60 000 calls
+def calculate_serial_correlation(spiketimes: np.ndarray, max_lag: int) -> np.ndarray:
     isi = np.diff(spiketimes)
     if len(spiketimes) < max_lag + 1:
         raise ValueError("Given list too short for the given max_lag")
-    cor = []
+    cor = np.zeros(max_lag)
     for lag in range(max_lag):
         lag = lag + 1
         first = isi[:-lag]
         second = isi[lag:]
-        cor.append(np.corrcoef(first, second)[0][1])
+        cor[lag-1] = np.corrcoef(first, second)[0][1]
     return cor
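
calculate_serial_correlation returns, for each lag from 1 to max_lag, the Pearson correlation between each inter-spike interval and the interval lag steps later. A quick sanity check with hypothetical spike times (independent ISIs should give coefficients near zero):

    rng = np.random.default_rng(0)
    spikes = np.cumsum(rng.exponential(scale=0.01, size=1000))  # Poisson-like train
    print(calculate_serial_correlation(spikes, max_lag=3))      # all three values ~0
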
+
+def __vector_strength__(relative_spike_times: np.ndarray, eod_durations: np.ndarray):
+    # adapted from Ramona
+    n = len(relative_spike_times)
+    if n == 0:
+        return 0
+
+    phase_times = (relative_spike_times / eod_durations) * 2 * np.pi
+    vs = np.sqrt((1 / n * np.sum(np.cos(phase_times))) ** 2 + (1 / n * np.sum(np.sin(phase_times))) ** 2)
+    return vs
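
__vector_strength__ computes VS = sqrt((1/n Σ cos φ_i)² + (1/n Σ sin φ_i)²) with φ_i = 2π · relative_spike_time_i / eod_duration_i, so it is 1 for perfect phase locking and near 0 for spikes spread uniformly over the EOD cycle. A small sanity check with made-up values:

    eod = np.full(1000, 0.001)               # hypothetical 1 ms EOD period
    locked = np.full(1000, 0.00025)          # every spike at the same phase
    print(__vector_strength__(locked, eod))  # 1.0
    uniform = np.random.rand(1000) * 0.001   # spikes uniform over the cycle
    print(__vector_strength__(uniform, eod)) # close to 0
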

View File

@@ -0,0 +1,100 @@
+import pyrelacs.DataLoader as dl
+import os
+import numpy as np
+import matplotlib.pyplot as plt  # used by plot_frequency_curve
+
+def get_subfolder_paths(basepath):
+    subfolders = []
+    for content in os.listdir(basepath):
+        content_path = basepath + content
+        if os.path.isdir(content_path):
+            subfolders.append(content_path)
+    return sorted(subfolders)
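
Note that get_subfolder_paths builds paths by plain string concatenation, so basepath has to end in a path separator. Hypothetical usage:

    subfolders = get_subfolder_paths("data/")  # ["data/cell_01", "data/cell_02", ...]
    # passing "data" without the trailing slash would yield "datacell_01" etc.
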
+
+def get_traces(directory, trace_type, repro):
+    # trace_type = 1: Voltage p-unit
+    # trace_type = 2: EOD
+    # trace_type = 3: local EOD ~(EOD + stimulus)
+    # trace_type = 4: Stimulus
+    load_iter = dl.iload_traces(directory, repro=repro)
+    time_traces = []
+    value_traces = []
+    nothing = True
+    for info, key, time, x in load_iter:
+        nothing = False
+        time_traces.append(time)
+        value_traces.append(x[trace_type-1])
+    if nothing:
+        print("iload_traces found nothing for the BaselineActivity repro!")
+    return time_traces, value_traces
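
Hypothetical usage, assuming a relacs recording directory that contains a BaselineActivity repro:

    times, v1 = get_traces("data/cell_01/", trace_type=1, repro="BaselineActivity")
    # times[i] / v1[i]: time axis and p-unit voltage trace of the i-th trial
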
+
+def get_all_traces(directory, repro):
+    load_iter = dl.iload_traces(directory, repro=repro)
+    time_traces = []
+    v1_traces = []
+    eod_traces = []
+    local_eod_traces = []
+    stimulus_traces = []
+    nothing = True
+    for info, key, time, x in load_iter:
+        nothing = False
+        time_traces.append(time)
+        v1_traces.append(x[0])
+        eod_traces.append(x[1])
+        local_eod_traces.append(x[2])
+        stimulus_traces.append(x[3])
+        print(info)
+    traces = [v1_traces, eod_traces, local_eod_traces, stimulus_traces]
+    if nothing:
+        print("iload_traces found nothing for the BaselineActivity repro!")
+    return time_traces, traces
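
get_all_traces returns the four trace types in the fixed order v1, EOD, local EOD, stimulus, so the result can be unpacked directly (hypothetical path):

    times, (v1, eod, local_eod, stimulus) = get_all_traces("data/cell_01/", "BaselineActivity")
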
+
+def crappy_smoothing(signal:list, window_size:int = 5) -> list:
+    smoothed = []
+    for i in range(len(signal)):
+        k = window_size
+        if i < window_size:
+            k = i
+        j = window_size
+        if i + j > len(signal):
+            j = len(signal) - i
+        smoothed.append(np.mean(signal[i-k:i+j]))
+    return smoothed
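
crappy_smoothing is a running mean whose window shrinks at the edges: signal[i-k:i+j] covers up to window_size samples before index i and window_size-1 after it, including i itself. For example:

    crappy_smoothing([1, 2, 3, 4, 5], window_size=2)
    # -> [1.5, 2.0, 2.5, 3.5, 4.0]
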
+
+def plot_frequency_curve(cell_data, save_path: str = None, indices: list = None):
+    contrast = cell_data.get_fi_contrasts()
+    time_axes = cell_data.get_time_axes_mean_frequencies()
+    mean_freqs = cell_data.get_mean_isi_frequencies()
+    if indices is None:
+        indices = np.arange(len(contrast))
+    for i in indices:
+        plt.plot(time_axes[i], mean_freqs[i], label=str(round(contrast[i], 2)))
+    if save_path is None:
+        plt.show()
+    else:
+        plt.savefig(save_path + "mean_frequency_curves.png")
+        plt.close()
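
plot_frequency_curve only relies on the three getters it calls on cell_data, so it can be exercised with a minimal stand-in; everything below is hypothetical test data:

    class FakeCellData:
        # hypothetical object exposing the interface plot_frequency_curve expects
        def get_fi_contrasts(self):
            return [0.1, 0.2]
        def get_time_axes_mean_frequencies(self):
            return [np.linspace(0, 1, 100)] * 2
        def get_mean_isi_frequencies(self):
            return [np.full(100, 80.0), np.full(100, 120.0)]

    plot_frequency_curve(FakeCellData())                    # shows the figure
    plot_frequency_curve(FakeCellData(), save_path="out/")  # writes out/mean_frequency_curves.png
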