from DatParser import DatParser
import numpy as np
from warnings import warn
import os
import matplotlib.pyplot as plt

class Controller:

    def __init__(self, data_path, repros_todo=(), base_detection_params=(2, 5000, 1000)):
        self.data_path = data_path
        self.parser = DatParser(data_path)
        self.repros = self.get_repros()

        if len(repros_todo) != 0:
            # keep only the requested repros that were actually measured for this cell
            not_available = []
            repros_todo = list(np.unique(repros_todo))
            for repro in repros_todo:
                if repro not in self.repros:
                    not_available.append(repro)
                    print("DataProvider: The given cell was not measured with repro '{}'. This repro will be skipped!".format(repro))
            for x in not_available:
                repros_todo.remove(x)

            self.repros_todo = repros_todo
        else:
            self.repros_todo = self.repros

        self.base_detection_parameters = base_detection_params
        self.sampling_interval = self.parser.get_sampling_interval()

        self.sorting = {}
        self.thresholds = {}
        self.stim_values = {}
        self.recording_times = {}

    def save_parameters(self, folder_path):
        # write one CSV per repro with the detection parameters of every trial
        header = "trial_index,threshold,min_window,step_size\n"
        for repro in self.thresholds.keys():
            # TODO change '.' to self.data_path after testing:
            with open(folder_path + "/{}_thresholds.csv".format(repro), "w") as threshold_file:
                threshold_file.write(header)
                for i in range(len(self.sorting[repro])):
                    if len(self.thresholds[repro][i]) == 0:
                        # missing threshold info: mark all fields with -1
                        threshold_file.write("{},{},{},{}\n".format(i, -1, -1, -1))
                    else:
                        # format and write the given thresholding parameters
                        thresh = self.thresholds[repro][i][0]
                        min_window = self.thresholds[repro][i][1]
                        step = self.thresholds[repro][i][2]
                        threshold_file.write("{},{},{},{}\n".format(i, thresh, min_window, step))
        print("Thresholds saved!")

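    # A minimal sketch (an assumption, not part of the original API) of how the
    # CSV written by save_parameters() could be read back; rows with -1 mark
    # trials without threshold info:
    #
    #   import csv
    #   def load_thresholds(csv_path):
    #       params = {}
    #       with open(csv_path) as f:
    #           for row in csv.DictReader(f):
    #               values = (float(row["threshold"]), float(row["min_window"]), float(row["step_size"]))
    #               params[int(row["trial_index"])] = () if values[0] == -1 else values
    #       return params
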
    def save_redetected_spikes(self, folder_path):
        # TODO save redetected spikes:
        header = "trial_index,threshold,min_window,step_size\n"
        for repro in self.thresholds.keys():
            # TODO change '.' to self.data_path after testing:
            with open(folder_path + "/{}_thresholds.csv".format(repro), "w") as threshold_file:
                threshold_file.write(header)

    def get_repros(self):
        return self.parser.get_measured_repros()

    def get_unsorted_spiketimes(self, repro, file=""):
        return self.parser.get_spiketimes(repro, file)

    def get_unsorted_traces(self, repro, before=0, after=0):
        return self.parser.get_traces(repro, before, after)

    def get_traces_with_spiketimes(self, repro):
        # already matched once: reuse the cached sorting
        if repro in self.sorting.keys():
            traces = self.get_unsorted_traces(repro, self.recording_times[repro][0], self.recording_times[repro][1])
            v1_traces = traces[1]
            spiketimes, metadata = self.get_unsorted_spiketimes(repro)

            sorted_spiketimes = np.array(spiketimes, dtype=object)[self.sorting[repro]]
            return v1_traces, sorted_spiketimes, self.recording_times[repro]

        if repro == "FICurve":
            recording_times = self.parser.get_recording_times()
            before = abs(recording_times[0])
            after = recording_times[3]
            [time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces] = self.parser.get_traces(repro, before=before, after=after)
            (spiketimes, metadata) = self.get_unsorted_spiketimes(repro)
        else:
            before = 0
            after = 0
            [time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces] = self.get_unsorted_traces(repro)
            (spiketimes, metadata) = self.get_unsorted_spiketimes(repro)

        repro_stim_values = self.extract_stim_values(repro, metadata)
        self.stim_values[repro] = repro_stim_values
        if len(v1_traces) != len(spiketimes):
            warn("get_traces_with_spiketimes(): unequal number of traces ({}) and spiketimes ({}) for repro {}! Returning only as many spike trains as there are traces.".format(len(v1_traces), len(spiketimes), repro))

        # match every trace to the spike train whose average spike height on
        # that trace is largest (the stored order of the two need not agree)
        ash = calculate_distance_matrix_traces_spikes(v1_traces, spiketimes, self.parser.get_sampling_interval(), before)

        sorted_spiketimes = []
        sorting = []
        for i in range(len(v1_traces)):
            best = np.argmax(ash[i, :])
            sorting.append(best)
            sorted_spiketimes.append(spiketimes[best])
        self.sorting[repro] = sorting
        self.recording_times[repro] = (before, after)

        return v1_traces, sorted_spiketimes, (before, after)

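    # Typical call (a sketch, assuming a Controller instance named
    # 'controller'): the returned spike trains are aligned with the traces,
    # so trial i of 'traces' belongs to trial i of 'spikes':
    #
    #   traces, spikes, (before, after) = controller.get_traces_with_spiketimes("FICurve")
    #   for trace, spike_train in zip(traces, spikes):
    #       print(len(trace), len(spike_train))
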
    def get_stim_values(self, repro):
        return sorted(np.unique(self.stim_values[repro]))

    def extract_stim_values(self, repro, metadata_list):
        stim_values = []
        if repro == "BaselineActivity":
            # baseline recordings have no stimulus
            for i in range(len(metadata_list)):
                stim_values.append("None")
        elif repro == "FICurve":
            # strip the two-character unit suffix from the intensity string;
            # trials without metadata reuse the last seen intensity
            current_intensity = None
            for i in range(len(metadata_list)):
                if len(metadata_list[i]) != 0:
                    current_intensity = float(metadata_list[i][-1]["intensity"][:-2])

                stim_values.append(current_intensity)

        elif repro == "FileStimulus":
            for i in range(len(metadata_list)):
                stim_values.append(os.path.basename(metadata_list[i][0]["file"]))

        elif repro == "SAM":
            stimulus_key = "----- Stimulus -------------------------------------------------------"
            for i in range(len(metadata_list)):
                deltaf = metadata_list[i][0][stimulus_key]["deltaf"]
                contrast = metadata_list[i][0][stimulus_key]["contrast"]
                stim_values.append("{} {}".format(deltaf, contrast))

        if len(stim_values) == 0:
            return ["Stimulus values of repro not supported"]

        return stim_values

    def get_trials(self, repro, stimulus_value):
        # TODO not implemented yet
        pass

    def set_redetection_params(self, repro, trial_idx, params):
        if repro not in self.thresholds.keys():
            # start with empty parameter tuples for every trial of the repro
            self.thresholds[repro] = [(), ] * len(self.sorting[repro])

        self.thresholds[repro][trial_idx] = params


def calculate_distance_matrix_traces_spikes(traces, spiketimes, sampling_interval, before):
    # ash[i, j]: average spike height of spike train j measured on trace i;
    # -inf marks spike trains that are too short to evaluate
    ash = np.zeros((len(traces), len(spiketimes)))
    total = len(traces) * len(spiketimes)
    count = 0
    for i, trace in enumerate(traces):
        for j, spikes in enumerate(spiketimes):
            count += 1
            if count % 50000 == 0:
                print("{} / {}".format(count, total))
            if len(spikes) <= 1:
                ash[i, j] = -np.inf
            else:
                ash[i, j] = average_spike_height(spikes, trace, sampling_interval, before)

    return ash


def average_spike_height(spike_train: np.ndarray, v1: np.ndarray, sampling_interval, before):
    # convert spike times (in seconds) into sample indices of the voltage
    # trace; 'before' shifts the times because the trace starts that long
    # before the stimulus onset
    indices = np.array((spike_train + before) / sampling_interval, dtype=int)
    if len(indices) <= 1:
        return -np.inf

    # keep only indices that fall inside the trace
    applicable_indices = indices[(indices >= 0) & (indices < len(v1))]
    spike_values = v1[applicable_indices]
    average_height = np.mean(spike_values)

    return average_height


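# Worked example (a sketch, assuming spike times in seconds): with
# sampling_interval = 0.0001 s and before = 0.2 s, a spike at t = 0.05 s maps
# to index (0.05 + 0.2) / 0.0001 = 2500, i.e. sample 2500 of the trace.

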
# SLOW (original list-comprehension version, kept for reference):
# def average_spike_height(spike_train, v1, sampling_interval, before):
#     indices = np.array([(s + before) / sampling_interval for s in spike_train], dtype=int)
#     if len(indices) <= 1:
#         return -np.inf
#     v1 = np.array(v1)
#     spike_values = [v1[i] for i in indices if 0 <= i < len(v1)]
#     average_height = np.mean(spike_values)
#
#     return average_height
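

# Minimal demo (an assumption, not part of the original module). The synthetic
# data below exercises the trace/spike-train matching; the Controller call is
# commented out because it needs a real recording ("path/to/cell" is a
# hypothetical placeholder).
if __name__ == "__main__":
    interval = 0.5  # hypothetical sampling interval (s); 0.5 is exact in binary

    # two synthetic traces with unit-height spikes at different times
    trace_a = np.zeros(20)
    trace_a[[4, 8, 12]] = 1.0   # spikes at t = 2.0, 4.0, 6.0 s
    trace_b = np.zeros(20)
    trace_b[[6, 10, 14]] = 1.0  # spikes at t = 3.0, 5.0, 7.0 s

    # spike trains listed in swapped order relative to the traces
    spiketimes = [np.array([3.0, 5.0, 7.0]), np.array([2.0, 4.0, 6.0])]

    ash = calculate_distance_matrix_traces_spikes([trace_a, trace_b], spiketimes, interval, before=0)
    print(np.argmax(ash, axis=1))  # -> [1 0]: each trace found its own spike train

    # controller = Controller("path/to/cell")
    # traces, spikes, (before, after) = controller.get_traces_with_spiketimes("BaselineActivity")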