P-unit_model/helperFunctions.py
2020-01-15 14:37:12 +01:00

215 lines
6.4 KiB
Python

import os
import pyrelacs.DataLoader as dl
import numpy as np
import matplotlib.pyplot as plt
from warnings import warn
def get_subfolder_paths(basepath):
    """Return the full paths of all immediate subdirectories of ``basepath``, sorted.

    :param basepath: directory to scan; works with or without a trailing path separator.
    :return: sorted list of subdirectory paths (regular files are ignored).
    """
    # os.path.join inserts a separator only when needed; plain string concatenation
    # produced wrong paths (e.g. "dataSUB") when basepath had no trailing separator.
    candidates = (os.path.join(basepath, name) for name in os.listdir(basepath))
    return sorted(path for path in candidates if os.path.isdir(path))
def get_traces(directory, trace_type, repro):
    """Load one selected trace of every trial of a repro via ``dl.iload_traces``.

    :param directory: relacs recording directory passed to ``dl.iload_traces``.
    :param trace_type: 1-based index of the trace to extract:
        1 = voltage of the p-unit, 2 = EOD, 3 = local EOD (~EOD + stimulus), 4 = stimulus.
    :param repro: name of the repro whose traces are loaded.
    :return: ``(time_traces, value_traces)`` - parallel lists with one entry per
        item yielded by ``iload_traces``.
    """
    time_traces = []
    value_traces = []
    for _info, _key, time, x in dl.iload_traces(directory, repro=repro):
        time_traces.append(time)
        # trace_type is 1-based while x is indexed from 0.
        value_traces.append(x[trace_type - 1])
    if not time_traces:
        # Report the repro that was actually requested instead of a hard-coded name.
        print("iload_traces found nothing for the {} repro!".format(repro))
    return time_traces, value_traces
def get_all_traces(directory, repro):
    """Load all four traces of every trial of a repro via ``dl.iload_traces``.

    The metadata (``info``) of every loaded item is printed.

    :param directory: relacs recording directory passed to ``dl.iload_traces``.
    :param repro: name of the repro whose traces are loaded.
    :return: ``(time_traces, traces)`` where ``traces`` is
        ``[v1_traces, eod_traces, local_eod_traces, stimulus_traces]``; each inner list
        holds one entry per item yielded by ``iload_traces``, parallel to ``time_traces``.
    """
    time_traces = []
    v1_traces = []
    eod_traces = []
    local_eod_traces = []
    stimulus_traces = []
    for info, _key, time, x in dl.iload_traces(directory, repro=repro):
        time_traces.append(time)
        # Channels 0-3: voltage, EOD, local EOD, stimulus (per the list names).
        v1_traces.append(x[0])
        eod_traces.append(x[1])
        local_eod_traces.append(x[2])
        stimulus_traces.append(x[3])
        print(info)
    traces = [v1_traces, eod_traces, local_eod_traces, stimulus_traces]
    if not time_traces:
        # Report the repro that was actually requested instead of a hard-coded name.
        print("iload_traces found nothing for the {} repro!".format(repro))
    return time_traces, traces
def merge_similar_intensities(intensities, spiketimes, trans_amplitudes, margin_factor=0.6666):
    """Merge stimulus intensities that lie close together, then sort by intensity.

    Two intensities count as similar when they differ by less than
    ``margin_factor`` times the mean spacing of the sorted intensities. Similar
    entries are merged by ``merge_intensities_similar_to_index``.

    :param intensities: list of intensities (modified in place by the merge step).
    :param spiketimes: list parallel to ``intensities``; presumably entry ``i`` is a
        list of trials recorded at ``intensities[i]`` (modified in place).
    :param trans_amplitudes: list parallel to ``intensities`` (modified in place).
    :param margin_factor: fraction of the mean intensity spacing below which two
        intensities are merged (default: the previously hard-coded 0.6666).
    :return: ``(intensities, spiketimes, trans_amplitudes)`` as lists sorted by
        increasing intensity, with all three kept aligned.
    """
    # With fewer than two intensities there is no spacing to derive a margin from
    # (np.mean of an empty diff would be NaN plus a RuntimeWarning), and nothing to merge.
    if len(intensities) >= 2:
        margin = np.mean(np.diff(sorted(intensities))) * margin_factor
        i = 0
        # len(intensities) shrinks as entries are merged away, so re-check it each pass.
        while i < len(intensities):
            intensities, spiketimes, trans_amplitudes = merge_intensities_similar_to_index(
                intensities, spiketimes, trans_amplitudes, i, margin)
            i += 1

    # Sort all three parallel lists with the same (stable) permutation; previously
    # trans_amplitudes was left unsorted and ended up misaligned with intensities.
    order = sorted(range(len(intensities)), key=lambda k: intensities[k])
    intensities = [intensities[k] for k in order]
    spiketimes = [spiketimes[k] for k in order]
    trans_amplitudes = [trans_amplitudes[k] for k in order]
    return intensities, spiketimes, trans_amplitudes
def merge_intensities_similar_to_index(intensities, spiketimes, trans_amplitudes, index, margin):
    """Merge every later intensity lying within ``margin`` of ``intensities[index]`` into it.

    For each merged entry its trials are appended to ``spiketimes[index]`` and the
    entry is removed from all three parallel lists. All lists are modified in place.

    :param intensities: list of intensities.
    :param spiketimes: list parallel to ``intensities``; each entry must be a list
        (it is extended with the trials of merged entries).
    :param trans_amplitudes: list parallel to ``intensities``.
    :param index: position of the entry that is kept; only positions after it are merged.
    :param margin: entries with ``abs(intensity - intensities[index]) < margin`` are merged.
    :return: ``(intensities, spiketimes, trans_amplitudes)`` (the same list objects).
    :raises RuntimeError: if a merged entry's trans_amplitude differs from that of the
        kept entry; nothing is modified in that case.
    """
    intensity = intensities[index]
    indices_to_merge = [i for i in range(index + 1, len(intensities))
                        if np.abs(intensities[i] - intensity) < margin]
    if not indices_to_merge:
        return intensities, spiketimes, trans_amplitudes

    # Compare against the kept entry itself: the old check only compared the merged
    # entries with each other and was therefore vacuous when a single entry was merged.
    reference_amplitude = trans_amplitudes[index]
    if any(trans_amplitudes[i] != reference_amplitude for i in indices_to_merge):
        raise RuntimeError("Trans_amplitudes not the same....")

    # Walk from the highest index down so deletions do not shift indices still to visit.
    for idx in reversed(indices_to_merge):
        spiketimes[index].extend(spiketimes[idx])
        del spiketimes[idx]
        del intensities[idx]
        del trans_amplitudes[idx]
    return intensities, spiketimes, trans_amplitudes
def all_calculate_mean_isi_frequencies(spiketimes, time_start, sampling_interval):
    """Compute the trial-averaged ISI-frequency trace for every group of trials.

    :param spiketimes: sequence of groups; each group is a sequence of trials and
        each trial holds the spike times passed to ``calculate_isi_frequency``.
    :param time_start: start of the time axis for every trial.
    :param sampling_interval: sample spacing of the frequency traces.
    :return: ``(times, mean_frequencies)`` - one time axis and one mean trace per group.
    """
    times = []
    mean_frequencies = []
    for trials in spiketimes:
        per_trial = [calculate_isi_frequency(trial, time_start, sampling_interval)
                     for trial in trials]
        trial_axes = [axis for axis, _ in per_trial]
        trial_traces = [trace for _, trace in per_trial]
        group_time, group_mean = calculate_mean_frequency(trial_axes, trial_traces)
        times.append(group_time)
        mean_frequencies.append(group_mean)
    return times, mean_frequencies
# TODO remove additional time vector calculation!
def calculate_isi_frequency(spiketimes, time_start, sampling_interval):
    """Convert spike times into a step-wise instantaneous (1/ISI) frequency trace.

    Each interspike interval (the first one measured from ``time_start``) is filled
    with the constant value ``1 / isi`` on a grid with spacing ``sampling_interval``
    starting at ``time_start``. Zero ISIs are skipped with a warning.

    :param spiketimes: spike times of one trial; must be non-empty (IndexError
        otherwise). Presumably sorted ascending - TODO confirm with callers.
    :param time_start: start of the time axis; the first ISI is measured from here.
    :param sampling_interval: spacing of the returned time axis.
    :return: ``(time, full_frequency)`` - the ``np.arange`` time axis from
        ``time_start`` up to (excluding) the last spike, and a list of equal length
        holding the frequency of the ISI each sample falls into.
    :raises RuntimeError: if trace and time axis differ in length by more than one sample.
    """
    time = np.arange(time_start, spiketimes[-1], sampling_interval)

    full_frequency = []
    previous_spike = time_start
    previous_boundary = 0
    for spike in spiketimes:
        isi = spike - previous_spike
        # Grid index of this spike. Deriving each step length from absolute spike
        # positions (instead of rounding every ISI on its own) keeps rounding errors
        # from accumulating, so the trace stays aligned with `time`.
        boundary = int(round((spike - time_start) / sampling_interval))
        step_length = boundary - previous_boundary
        previous_spike = spike
        previous_boundary = boundary
        if isi == 0:
            warn("An ISI was zero in FiCurve:__calculate_mean_isi_frequency__()")
            continue
        full_frequency.extend(step_length * [1 / isi])

    # The rounded spike grid and np.arange may still disagree by one sample.
    if len(full_frequency) != len(time):
        if abs(len(full_frequency) - len(time)) == 1:
            warn("FiCurve:__calculate_mean_isi_frequency__():\nFrequency and time were one of in length!")
            if len(full_frequency) < len(time):
                time = time[:len(full_frequency)]
            else:
                full_frequency = full_frequency[:len(time)]
        else:
            print("ERROR PRINT:")
            print("freq:", len(full_frequency), "time:", len(time), "diff:", len(full_frequency) - len(time))
            raise RuntimeError("FiCurve:__calculate_mean_isi_frequency__():\n"
                               "Frequency and time are not the same length!")
    return time, full_frequency
def calculate_mean_frequency(trial_times, trial_freqs):
    """Average several frequency traces sample by sample.

    Every trace is cut to the length of the shortest entry of ``trial_times``.

    :param trial_times: per-trial time axes (sliceable sequences).
    :param trial_freqs: per-trial frequency traces, parallel to ``trial_times``.
    :return: ``(time, mean_freq)`` - the first time axis cut to the common length,
        and a list with the mean over all traces at each sample.
    :raises ValueError: if ``trial_times`` is empty.
    """
    common_length = min(len(axis) for axis in trial_times)
    truncated = [trace[0:common_length] for trace in trial_freqs]
    mean_freq = []
    for samples in zip(*truncated):
        mean_freq.append(sum(samples) / len(samples))
    return trial_times[0][0:common_length], mean_freq
def crappy_smoothing(signal:list, window_size:int = 5) -> list:
    """Smooth ``signal`` with a simple moving average.

    The window for sample ``i`` covers up to ``window_size`` samples before ``i``,
    sample ``i`` itself and up to ``window_size - 1`` samples after it (slice
    ``[i - window_size, i + window_size)``), truncated at both ends of the signal.

    :param signal: sequence of numbers to smooth.
    :param window_size: half-width of the averaging window.
    :return: list with one averaged value per input sample.
    """
    # Slicing already clips the right edge; only the left edge needs clamping so
    # that a negative start does not wrap around to the end of the signal.
    return [np.mean(signal[max(0, i - window_size):i + window_size])
            for i in range(len(signal))]
def plot_frequency_curve(cell_data, save_path: str = None, indices: list = None):
    """Plot the mean ISI-frequency traces of ``cell_data``, one curve per contrast.

    :param cell_data: object providing ``get_fi_contrasts()``,
        ``get_time_axes_mean_frequencies()`` and ``get_mean_isi_frequencies()``,
        whose results are indexed in parallel.
    :param save_path: if None the figure is shown interactively; otherwise it is
        saved to ``save_path + "mean_frequency_curves.png"`` (plain concatenation, so
        ``save_path`` should end with a separator or be a file-name prefix) and closed.
    :param indices: indices of the contrasts to plot; defaults to all of them.
    """
    contrasts = cell_data.get_fi_contrasts()
    time_axes = cell_data.get_time_axes_mean_frequencies()
    mean_freqs = cell_data.get_mean_isi_frequencies()

    if indices is None:
        indices = np.arange(len(contrasts))

    for i in indices:
        plt.plot(time_axes[i], mean_freqs[i], label=str(round(contrasts[i], 2)))
    # The per-contrast labels above are only displayed once a legend is drawn;
    # skip it when nothing was plotted to avoid matplotlib's "no artists" warning.
    if len(indices) > 0:
        plt.legend()

    if save_path is None:
        plt.show()
    else:
        plt.savefig(save_path + "mean_frequency_curves.png")
        plt.close()