import os
import pyrelacs.DataLoader as dl
import numpy as np
import matplotlib.pyplot as plt
from warnings import warn


def get_subfolder_paths(basepath):
    # Collect the paths of all immediate subdirectories of basepath.
    subfolders = []
    for content in os.listdir(basepath):
        # os.path.join works whether or not basepath ends with a separator
        content_path = os.path.join(basepath, content)
        if os.path.isdir(content_path):
            subfolders.append(content_path)

    return sorted(subfolders)


def get_traces(directory, trace_type, repro):
    # trace_type = 1: Voltage p-unit
    # trace_type = 2: EOD
    # trace_type = 3: local EOD ~(EOD + stimulus)
    # trace_type = 4: Stimulus

    load_iter = dl.iload_traces(directory, repro=repro)

    time_traces = []
    value_traces = []

    nothing = True

    for info, key, time, x in load_iter:
        nothing = False
        time_traces.append(time)
        # x holds all recorded traces; trace_type is 1-based, hence the -1
        value_traces.append(x[trace_type - 1])

    if nothing:
        print("iload_traces found nothing for the repro {}!".format(repro))

    return time_traces, value_traces


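# Usage sketch (illustrative only, the data directory below is a placeholder):
# trace_type follows the mapping in the comment at the top of get_traces, so e.g.
#
#   times, local_eods = get_traces("data/2014-01-10-aa/", trace_type=3, repro="BaselineActivity")
#
# would collect the time axes and local EOD traces of all trials of that repro.

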
def get_all_traces(directory, repro):
    load_iter = dl.iload_traces(directory, repro=repro)

    time_traces = []
    v1_traces = []
    eod_traces = []
    local_eod_traces = []
    stimulus_traces = []

    nothing = True

    for info, key, time, x in load_iter:
        nothing = False
        time_traces.append(time)
        v1_traces.append(x[0])
        eod_traces.append(x[1])
        local_eod_traces.append(x[2])
        stimulus_traces.append(x[3])
        print(info)

    # Traces are returned in the order: V1 (p-unit voltage), EOD, local EOD, stimulus.
    traces = [v1_traces, eod_traces, local_eod_traces, stimulus_traces]

    if nothing:
        print("iload_traces found nothing for the repro {}!".format(repro))

    return time_traces, traces


def merge_similar_intensities(intensities, spiketimes, trans_amplitudes):
    # Intensities closer together than the margin are treated as the same
    # stimulus intensity and their trials are merged.
    diffs = np.diff(sorted(intensities))
    margin = np.mean(diffs) * 0.6666

    i = 0
    while i < len(intensities):
        intensities, spiketimes, trans_amplitudes = merge_intensities_similar_to_index(
            intensities, spiketimes, trans_amplitudes, i, margin)
        i += 1

    # Sort all three lists together so that intensities are increasing and
    # trans_amplitudes stay aligned with them.
    x = [list(x) for x in zip(*sorted(zip(intensities, spiketimes, trans_amplitudes),
                                      key=lambda triple: triple[0]))]
    intensities = x[0]
    spiketimes = x[1]
    trans_amplitudes = x[2]

    return intensities, spiketimes, trans_amplitudes


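# Worked example of the merging (made-up intensities, not from a recording):
# for intensities = [1.0, 2.0, 1.05] the sorted differences are [0.05, 0.95], so
# margin = 0.5 * 0.6666 ≈ 0.33.  The 1.05 entry lies within the margin of 1.0 and is
# merged into it (its spike trains are appended to those of 1.0), while 2.0 stays
# separate, leaving intensities == [1.0, 2.0] after the final sort.

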
def merge_intensities_similar_to_index(intensities, spiketimes, trans_amplitudes, index, margin):
    intensity = intensities[index]

    # Collect all later entries whose intensity lies within the margin of this one.
    indices_to_merge = []
    for i in range(index + 1, len(intensities)):
        if np.abs(intensities[i] - intensity) < margin:
            indices_to_merge.append(i)

    if len(indices_to_merge) != 0:
        # Work from the back so the remaining indices stay valid while deleting.
        indices_to_merge.reverse()

        trans_amplitude_values = [trans_amplitudes[k] for k in indices_to_merge]

        all_the_same = True
        for j in range(1, len(trans_amplitude_values)):
            if not trans_amplitude_values[0] == trans_amplitude_values[j]:
                all_the_same = False
                break

        if all_the_same:
            for idx in indices_to_merge:
                del trans_amplitudes[idx]
        else:
            raise RuntimeError("trans_amplitudes of the intensities to merge are not the same!")

        for idx in indices_to_merge:
            spiketimes[index].extend(spiketimes[idx])
            del spiketimes[idx]
            del intensities[idx]

    return intensities, spiketimes, trans_amplitudes


def all_calculate_mean_isi_frequencies(spiketimes, time_start, sampling_interval):
    times = []
    mean_frequencies = []

    for i in range(len(spiketimes)):
        trial_times = []
        trial_means = []
        for j in range(len(spiketimes[i])):
            time, isi_freq = calculate_isi_frequency(spiketimes[i][j], time_start, sampling_interval)
            trial_means.append(isi_freq)
            trial_times.append(time)

        time, mean_freq = calculate_mean_frequency(trial_times, trial_means)
        times.append(time)
        mean_frequencies.append(mean_freq)

    return times, mean_frequencies


# TODO remove additional time vector calculation!
def calculate_isi_frequency(spiketimes, time_start, sampling_interval):
    # Inter-spike intervals, starting with the interval from time_start to the first spike.
    first_isi = spiketimes[0] - time_start
    isis = [first_isi]
    isis.extend(np.diff(spiketimes))
    time = np.arange(time_start, spiketimes[-1], sampling_interval)

    full_frequency = []
    for isi in isis:
        if isi == 0:
            warn("An ISI was zero in calculate_isi_frequency()")
            continue
        freq = 1 / isi
        # Hold the instantaneous frequency of this interval for its duration on the time grid.
        frequency_step = int(round(isi * (1 / sampling_interval))) * [freq]
        full_frequency.extend(frequency_step)

    if len(full_frequency) != len(time):
        if abs(len(full_frequency) - len(time)) == 1:
            warn("calculate_isi_frequency(): frequency and time were off by one in length!")
            if len(full_frequency) < len(time):
                time = time[:len(full_frequency)]
            else:
                full_frequency = full_frequency[:len(time)]
        else:
            msg = ("calculate_isi_frequency(): frequency and time are not the same length! "
                   "freq: {}, time: {}, diff: {}")
            raise RuntimeError(msg.format(len(full_frequency), len(time),
                                          len(full_frequency) - len(time)))

    return time, full_frequency


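# Worked example (toy numbers, assuming spike times in seconds): for
# spiketimes = [0.01, 0.02, 0.04], time_start = 0 and sampling_interval = 0.005 the ISIs
# are [0.01, 0.01, 0.02], giving instantaneous frequencies of 100, 100 and 50 Hz held for
# 2, 2 and 4 samples, i.e. full_frequency == [100, 100, 100, 100, 50, 50, 50, 50] on a
# time axis of 8 samples.

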
def calculate_mean_frequency(trial_times, trial_freqs):
    # Truncate all trials to the shortest one, then average across trials.
    lengths = [len(t) for t in trial_times]
    shortest = min(lengths)

    time = trial_times[0][0:shortest]
    shortened_freqs = [freq[0:shortest] for freq in trial_freqs]
    mean_freq = [sum(e) / len(e) for e in zip(*shortened_freqs)]

    return time, mean_freq


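# Example of the averaging (made-up values): for trial_freqs = [[10, 20, 30], [30, 40]]
# with matching time axes, all trials are cut to the shortest length (2), so the returned
# mean frequency is [20.0, 30.0] on the first two time points of the first trial.

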
def crappy_smoothing(signal: list, window_size: int = 5) -> list:
    # Running mean with a window that is clipped at both ends of the signal:
    # at position i the window covers signal[i - k:i + j] with k and j at most window_size.
    smoothed = []

    for i in range(len(signal)):
        k = window_size
        if i < window_size:
            k = i
        j = window_size
        if i + j > len(signal):
            j = len(signal) - i

        smoothed.append(np.mean(signal[i - k:i + j]))

    return smoothed


def plot_frequency_curve(cell_data, save_path: str = None, indices: list = None):
    contrast = cell_data.get_fi_contrasts()
    time_axes = cell_data.get_time_axes_mean_frequencies()
    mean_freqs = cell_data.get_mean_isi_frequencies()

    if indices is None:
        indices = np.arange(len(contrast))

    for i in indices:
        plt.plot(time_axes[i], mean_freqs[i], label=str(round(contrast[i], 2)))
    # Without a legend the contrast labels set above would never be shown.
    plt.legend()

    if save_path is None:
        plt.show()
    else:
        plt.savefig(save_path + "mean_frequency_curves.png")
        plt.close()


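# Minimal self-test sketch (added for illustration, not part of the original pipeline):
# it only exercises the helpers that need no recorded data, using made-up spike times
# in seconds and a toy signal.
if __name__ == "__main__":
    example_spikes = [0.01, 0.02, 0.04, 0.07]
    t, f = calculate_isi_frequency(example_spikes, time_start=0, sampling_interval=0.005)
    print("time samples:", len(t), "frequency samples:", f)
    print("smoothed toy signal:", crappy_smoothing([0, 0, 1, 0, 0], window_size=1))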