# spikeRedetector/DataProvider.py
from DatParser import DatParser
import numpy as np
from warnings import warn
import matplotlib.pyplot as plt
class DataProvider:
    """Provide voltage traces and spike times for the repros of one recording.

    Wraps a DatParser and matches each raw v1 trace to the spike train it most
    likely belongs to (via average spike height), caching that assignment per
    repro so repeated calls are cheap.
    """

    def __init__(self, data_path, repros_todo=(), base_detection_params=(2, 5000, 1000)):
        """
        :param data_path: path of the recording handled by DatParser
        :param repros_todo: optional iterable of repro names to restrict to;
            names not measured for this cell are skipped with a message
        :param base_detection_params: default spike-detection parameters
        """
        self.data_path = data_path
        self.parser = DatParser(data_path)
        self.repros = self.get_repros()

        if repros_todo:
            requested = list(np.unique(repros_todo))
            # Drop requested repros that were never measured for this cell.
            for repro in [r for r in requested if r not in self.repros]:
                print("DataProvider: Given cell hasn't measured '{}'. This repro will be skipped!".format(repro))
                requested.remove(repro)
            self.repros_todo = requested
        else:
            self.repros_todo = self.repros

        self.base_detection_parameters = base_detection_params
        self.thresholds = {}
        self.sampling_interval = self.parser.get_sampling_interval()
        # repro name -> list of spiketrain indices, one per trace (matching order)
        self.sorting = {}
        # repro name -> (before, after) offsets used when the traces were read
        self.recording_times = {}

    def get_repros(self):
        """Return the list of repros measured in this recording."""
        return self.parser.get_measured_repros()

    def get_unsorted_spiketimes(self, repro, file=""):
        """Return (spiketimes, metadata) for a repro, unmatched to traces."""
        return self.parser.get_spiketimes(repro, file)

    def get_unsorted_traces(self, repro, before=0, after=0):
        """Return the raw traces [time, v1, eod, local_eod, stimulus] for a repro."""
        return self.parser.get_traces(repro, before, after)

    def get_traces_with_spiketimes(self, repro):
        """Return (v1_traces, spiketimes, (before, after)) with spike trains matched to traces.

        The trace-to-spiketrain matching is computed once with a distance
        matrix (average spike height) and cached in self.sorting /
        self.recording_times; later calls reuse the cached assignment.
        """
        if repro in self.sorting:
            before, after = self.recording_times[repro]
            traces = self.get_unsorted_traces(repro, before, after)
            v1_traces = traces[1]
            spiketimes, metadata = self.get_unsorted_spiketimes(repro)
            sorted_spiketimes = np.array(spiketimes, dtype=object)[self.sorting[repro]]
            return v1_traces, sorted_spiketimes, self.recording_times[repro]

        if repro == "FICurve":
            # FICurve traces include padding before/after the stimulus window.
            recording_times = self.parser.get_recording_times()
            before = abs(recording_times[0])
            after = recording_times[3]
            [time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces] = \
                self.parser.get_traces(repro, before=before, after=after)
        else:
            before = 0
            after = 0
            [time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces] = \
                self.get_unsorted_traces(repro)
        spiketimes, metadata = self.get_unsorted_spiketimes(repro)

        if len(v1_traces) != len(spiketimes):
            warn("get_traces_with_spiketimes():Unequal number of traces and spiketimes for repro {}!"
                 "Returning only according to the number of traces!".format(repro))

        # Score every (trace, spiketrain) pair; higher score = better match.
        ash = calculate_distance_matrix_traces_spikes(
            v1_traces, spiketimes, self.parser.get_sampling_interval(), before)

        sorted_spiketimes = []
        sorting = []
        for i in range(len(v1_traces)):
            best = np.argmax(ash[i, :])
            sorting.append(best)
            sorted_spiketimes.append(spiketimes[best])

        self.sorting[repro] = sorting
        self.recording_times[repro] = (before, after)
        return v1_traces, sorted_spiketimes, (before, after)

    def get_stim_values(self, repro):
        """Return the stimulus values (as strings) available for a repro."""
        if repro == "BaselineActivity":
            return ["None"]
        elif repro == "FICurve":
            # TODO other function that just provides the contrasts
            return sorted([str(x[0]) for x in self.parser.get_fi_curve_contrasts()])
        return ["repro not supported"]

    def get_trials(self, repro, stimulus_value):
        # Not implemented yet.
        pass
def calculate_distance_matrix_traces_spikes(traces, spiketimes, sampling_rate, before):
    """Build an (n_traces, n_spiketrains) score matrix of average spike heights.

    Entry [i, j] is the mean voltage of trace i sampled at the spike times of
    train j; a higher value means trace i more likely belongs to train j.
    Trains with fewer than two spikes score -inf so argmax never picks them.

    :param traces: sequence of voltage traces
    :param spiketimes: sequence of spike-time arrays
    :param sampling_rate: sampling interval used to convert times to indices
    :param before: time offset added to every spike time before conversion
    :return: 2-D numpy float array of scores
    """
    ash = np.zeros((len(traces), len(spiketimes)))
    total = len(traces) * len(spiketimes)
    count = 0
    for i, trace in enumerate(traces):
        for j, spikes in enumerate(spiketimes):
            count += 1
            # Progress output: the pairing grid can be very large.
            if count % 50000 == 0:
                print("{} / {}".format(count, total))
            if len(spikes) <= 1:
                # np.inf instead of np.infty: the latter was removed in NumPy 2.0.
                ash[i, j] = -np.inf
            else:
                ash[i, j] = average_spike_height(spikes, trace, sampling_rate, before)
    return ash
def average_spike_height(spike_train: np.ndarray, v1: np.ndarray, sampling_rate, before):
    """Return the mean trace voltage at the given spike times.

    Spike times are converted to sample indices with (t + before) / sampling_rate
    (the caller passes the sampling *interval* here); indices outside the trace
    are discarded. Trains with fewer than two spikes score -inf so they lose any
    argmax comparison.

    :param spike_train: 1-D array of spike times
    :param v1: voltage trace sampled at `sampling_rate` intervals
    :param sampling_rate: sampling interval (seconds per sample), used as divisor
    :param before: offset added to each spike time before index conversion
    :return: mean voltage at the spike indices, or -inf
    """
    # Builtin int dtype: the np.int alias was removed from NumPy (>= 1.24).
    indices = np.asarray((spike_train + before) / sampling_rate, dtype=int)
    if len(indices) <= 1:
        return -np.inf
    # Keep only indices inside the trace; index 0 is valid too (the original
    # `indices > 0` wrongly dropped a spike landing on the first sample).
    applicable_indices = indices[(indices >= 0) & (indices < len(v1))]
    if len(applicable_indices) == 0:
        # No spike falls inside the trace: avoid NaN from a mean of an empty slice.
        return -np.inf
    return float(np.mean(np.asarray(v1)[applicable_indices]))
# SLOW:
# def average_spike_height(spike_train, v1, sampling_rate, before):
# indices = np.array([(s + before) / sampling_rate for s in spike_train], dtype=np.int)
# if len(indices) <= 1:
# return -np.infty
# v1 = np.array(v1)
# spike_values = [v1[i] for i in indices if 0 <= i < len(v1)]
# average_height = np.mean(spike_values)
#
# return average_height