139 lines
5.1 KiB
Python
139 lines
5.1 KiB
Python
|
|
from DatParser import DatParser
|
|
|
|
import numpy as np
|
|
from warnings import warn
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
class DataProvider:
    """Wrap a ``DatParser`` for one recording and pair V1 traces with spike times.

    Spike-time trains delivered by the parser are assigned to voltage traces by
    scoring every (trace, train) pair with
    ``calculate_distance_matrix_traces_spikes`` and picking, per trace, the
    highest-scoring train. The assignment and the recording window used are
    cached per repro in ``self.sorting`` / ``self.recording_times``.
    """

    def __init__(self, data_path, repros_todo=(), base_detection_params=(2, 5000, 1000)):
        """Open ``data_path`` with ``DatParser`` and select the repros to process.

        Args:
            data_path: Path handed to ``DatParser``.
            repros_todo: Repro names to process. Duplicates are removed and the
                names are sorted; names the cell has not measured are reported
                with ``print`` and skipped. Empty (the default) selects every
                measured repro.
            base_detection_params: Stored unchanged as
                ``self.base_detection_parameters`` (not read by any method here).
        """
        self.data_path = data_path
        self.parser = DatParser(data_path)
        self.repros = self.get_repros()

        if len(repros_todo) != 0:
            # De-duplicate and sort, as np.unique did, but keep plain str items.
            requested = sorted(set(repros_todo))
            for repro in requested:
                if repro not in self.repros:
                    print("DataProvider: Given cell hasn't measured '{}'. This repro will be skipped!".format(repro))
            self.repros_todo = [repro for repro in requested if repro in self.repros]
        else:
            self.repros_todo = self.repros

        self.base_detection_parameters = base_detection_params
        self.thresholds = {}  # not populated by any method shown here
        self.sampling_interval = self.parser.get_sampling_interval()

        # repro name -> list of spike-train indices, one per V1 trace
        self.sorting = {}
        # repro name -> (before, after) window the traces were loaded with
        self.recording_times = {}

    def get_repros(self):
        """Return the repros measured in this recording (as reported by the parser)."""
        return self.parser.get_measured_repros()

    def get_unsorted_spiketimes(self, repro, file=""):
        """Return the parser's ``(spiketimes, metadata)`` for ``repro`` without matching."""
        return self.parser.get_spiketimes(repro, file)

    def get_unsorted_traces(self, repro, before=0, after=0):
        """Return the parser's raw trace groups for ``repro``.

        Index 1 of the result holds the V1 traces (the caller below unpacks it
        as time, V1, EOD, local EOD and stimulus traces).
        """
        return self.parser.get_traces(repro, before, after)

    def get_traces_with_spiketimes(self, repro):
        """Return the V1 traces of ``repro`` paired with their spike-time trains.

        On the first call for a repro every trace is scored against every spike
        train and the best-scoring train is assigned to it; the assignment and
        the loading window are cached. Later calls reuse that assignment.

        Returns:
            ``(v1_traces, sorted_spiketimes, (before, after))`` where
            ``sorted_spiketimes`` is a list and ``sorted_spiketimes[i]`` is the
            train assigned to ``v1_traces[i]``; ``before``/``after`` is the extra
            time requested from the parser around each trace.

        Raises:
            ValueError: if the parser returns traces but no spike-time trains,
                so no assignment is possible.
        """
        if repro in self.sorting:
            before, after = self.recording_times[repro]
            v1_traces = self.get_unsorted_traces(repro, before, after)[1]
            spiketimes, _metadata = self.get_unsorted_spiketimes(repro)
            # Plain indexing keeps each train as-is; wrapping the trains in an
            # object array turned equal-length trains into a 2-D object array.
            sorted_spiketimes = [spiketimes[index] for index in self.sorting[repro]]
            return v1_traces, sorted_spiketimes, (before, after)

        if repro == "FICurve":
            # NOTE(review): presumably index 0 is the (negative) start offset and
            # index 3 the end offset of the recording window - verify in DatParser.
            recording_times = self.parser.get_recording_times()
            before = abs(recording_times[0])
            after = recording_times[3]
        else:
            before = 0
            after = 0

        v1_traces = self.get_unsorted_traces(repro, before, after)[1]
        spiketimes, _metadata = self.get_unsorted_spiketimes(repro)

        if len(v1_traces) != len(spiketimes):
            warn("get_traces_with_spiketimes(): Unequal number of traces and spiketimes for repro {}! "
                 "Returning only according to the number of traces!".format(repro))

        if len(v1_traces) > 0 and len(spiketimes) == 0:
            raise ValueError("get_traces_with_spiketimes(): repro {} has traces but no spiketimes "
                             "to assign to them.".format(repro))

        ash = calculate_distance_matrix_traces_spikes(v1_traces, spiketimes, self.sampling_interval, before)

        # Row i scores every spike train against trace i; keep the best one.
        sorting = [int(np.argmax(row)) for row in ash]
        sorted_spiketimes = [spiketimes[best] for best in sorting]

        self.sorting[repro] = sorting
        self.recording_times[repro] = (before, after)

        return v1_traces, sorted_spiketimes, (before, after)

    def get_stim_values(self, repro):
        """Return the stimulus values of ``repro`` as strings for display.

        ``"BaselineActivity"`` yields ``["None"]``; ``"FICurve"`` yields the
        first element of each entry from ``get_fi_curve_contrasts`` as a sorted
        list of strings; any other repro yields ``["repro not supported"]``.
        """
        if repro == "BaselineActivity":
            return ["None"]
        elif repro == "FICurve":
            # TODO other function that just provides the contrasts
            # NOTE(review): values are sorted as strings (e.g. "-0.1" before
            # "-0.2"); confirm whether numeric order is intended.
            return sorted([str(x[0]) for x in self.parser.get_fi_curve_contrasts()])

        return ["repro not supported"]

    def get_trials(self, repro, stimulus_value):
        """Not implemented yet; always returns ``None``."""
        pass
|
|
|
|
|
|
def calculate_distance_matrix_traces_spikes(traces, spiketimes, sampling_rate, before):
    """Score every (trace, spike-train) pair by the average voltage at the spikes.

    Args:
        traces: Sequence of voltage traces; each is converted with ``np.asarray``.
        spiketimes: Sequence of spike-time trains.
        sampling_rate: Time per sample of the traces. Spike time ``t`` maps to
            sample ``(t + before) / sampling_rate``, so despite the name this is
            the sampling *interval*.
        before: Time offset added to every spike time before that conversion.

    Returns:
        Float array of shape ``(len(traces), len(spiketimes))``. Entry ``[i, j]``
        is ``average_spike_height(spiketimes[j], traces[i], sampling_rate, before)``,
        or ``-inf`` when train ``j`` has fewer than two spikes.

    Progress is printed every 50000 evaluated pairs.
    """
    ash = np.zeros((len(traces), len(spiketimes)))
    total = len(traces) * len(spiketimes)
    count = 0

    for i, trace in enumerate(traces):
        # Convert once per trace (not once per pair) so list traces also
        # support the array indexing done in average_spike_height.
        trace = np.asarray(trace)
        for j, spikes in enumerate(spiketimes):
            count += 1
            if count % 50000 == 0:
                print("{} / {}".format(count, total))

            if len(spikes) <= 1:
                # np.inf: the np.infty alias was removed in NumPy 2.0.
                ash[i, j] = -np.inf
            else:
                ash[i, j] = average_spike_height(spikes, trace, sampling_rate, before)

    return ash
|
|
|
|
|
|
def average_spike_height(spike_train: np.ndarray, v1: np.ndarray, sampling_rate, before):
    """Return the mean of ``v1`` sampled at the spike times in ``spike_train``.

    Args:
        spike_train: Spike times (any array-like of numbers).
        v1: Voltage trace to sample.
        sampling_rate: Time per sample of ``v1``. Spike time ``t`` maps to
            index ``floor((t + before) / sampling_rate)``, so despite the name
            this is the sampling *interval*.
        before: Offset added to every spike time before the conversion.

    Returns:
        The mean of ``v1`` over the spikes whose index lies in
        ``[0, len(v1))``. Returns ``-inf`` when the train has fewer than two
        spikes or when no spike falls inside the trace. ``-inf`` rather than the
        ``nan`` of an empty mean keeps such trains from winning an ``np.argmax``
        (which treats ``nan`` as the maximum).
    """
    spike_train = np.asarray(spike_train, dtype=float)
    if len(spike_train) <= 1:
        return -np.inf

    v1 = np.asarray(v1)
    # floor (not truncation toward zero) so times just before the trace start
    # map to a negative index and are dropped; plain int replaces the removed
    # np.int alias.
    indices = np.floor((spike_train + before) / sampling_rate).astype(int)

    # Index 0 is a valid sample; the previous `> 0` filter dropped it.
    applicable_indices = indices[(indices >= 0) & (indices < len(v1))]
    if applicable_indices.size == 0:
        return -np.inf

    return np.mean(v1[applicable_indices])
|
|
|
|
|
|
# SLOW:
|
|
# def average_spike_height(spike_train, v1, sampling_rate, before):
|
|
# indices = np.array([(s + before) / sampling_rate for s in spike_train], dtype=np.int)
|
|
# if len(indices) <= 1:
|
|
# return -np.infty
|
|
# v1 = np.array(v1)
|
|
# spike_values = [v1[i] for i in indices if 0 <= i < len(v1)]
|
|
# average_height = np.mean(spike_values)
|
|
#
|
|
# return average_height |