spikeRedetector/Controller.py
2021-07-02 11:27:12 +02:00

196 lines
7.9 KiB
Python

from DatParser import DatParser
import numpy as np
from warnings import warn
import os
import matplotlib.pyplot as plt
class Controller:
    """Match the voltage traces of a recorded cell to its spike trains and keep redetection parameters.

    File access is delegated to ``DatParser``. Per repro the controller caches:

    * ``sorting[repro]``: for every v1 trace, the index of the spike train matched to it.
    * ``thresholds[repro]``: per-trace redetection parameters (``()`` when not set).
    * ``stim_values[repro]``: per-trial stimulus descriptions.
    * ``recording_times[repro]``: the ``(before, after)`` padding the traces were loaded with.
    """

    def __init__(self, data_path, repros_todo=(), base_detection_params=(2, 5000, 1000)):
        """Open the recording at ``data_path`` and select the repros to process.

        :param data_path: path handed to ``DatParser``.
        :param repros_todo: repro names to process. Names the cell did not measure
            are reported and dropped; an empty collection selects every measured repro.
        :param base_detection_params: stored unchanged as ``base_detection_parameters``
            (presumably ``(threshold, min_window, step_size)`` -- TODO confirm).
        """
        self.data_path = data_path
        self.parser = DatParser(data_path)
        self.repros = self.get_repros()
        if len(repros_todo) != 0:
            # np.unique removes duplicates (and sorts) the requested names.
            requested = list(np.unique(repros_todo))
            for repro in requested:
                if repro not in self.repros:
                    print("DataProvider: Given cell hasn't measured '{}'. This repro will be skipped!".format(repro))
            self.repros_todo = [repro for repro in requested if repro in self.repros]
        else:
            self.repros_todo = self.repros
        self.base_detection_parameters = base_detection_params
        self.sampling_interval = self.parser.get_sampling_interval()
        self.sorting = {}
        self.thresholds = {}
        self.stim_values = {}
        self.recording_times = {}

    def save_parameters(self, folder_path):
        """Write one ``<repro>_thresholds.csv`` per repro with redetection parameters into ``folder_path``.

        Each row is ``trial_index,threshold,min_window,step_size``. Trials without
        parameters are written as ``-1,-1,-1``.
        """
        # TODO: consider defaulting folder_path to self.data_path once testing is done.
        header = "trial_index,threshold,min_window,step_size\n"
        for repro, trial_params in self.thresholds.items():
            file_path = os.path.join(folder_path, "{}_thresholds.csv".format(repro))
            with open(file_path, "w") as threshold_file:
                threshold_file.write(header)
                for i in range(len(self.sorting[repro])):
                    params = trial_params[i]
                    if len(params) == 0:
                        # missing threshold info
                        threshold_file.write("{},{},{},{}\n".format(i, -1, -1, -1))
                    else:
                        thresh, min_window, step = params[0], params[1], params[2]
                        threshold_file.write("{},{},{},{}\n".format(i, thresh, min_window, step))
        print("Thresholds saved!")

    def save_redetected_spikes(self, folder_path):
        """Placeholder for saving the redetected spike times into ``folder_path``.

        Not implemented: the controller does not store redetected spikes yet.
        It deliberately writes no file, because re-creating ``<repro>_thresholds.csv``
        here would erase the parameters written by ``save_parameters``.
        """
        # TODO: save redetected spikes once they are kept by the controller.
        warn("save_redetected_spikes(): saving redetected spikes is not implemented yet; "
             "nothing was written to '{}'.".format(folder_path))

    def get_repros(self):
        """Return the repros measured in this recording (as reported by the parser)."""
        return self.parser.get_measured_repros()

    def get_unsorted_spiketimes(self, repro, file=""):
        """Return ``(spiketimes, metadata)`` for ``repro`` straight from the parser, unmatched to traces."""
        return self.parser.get_spiketimes(repro, file)

    def get_unsorted_traces(self, repro, before=0, after=0):
        """Return the parser's trace lists for ``repro``.

        The lists are ``[time, v1, eod, local_eod, stimulus]``, loaded with
        ``before``/``after`` padding.
        """
        return self.parser.get_traces(repro, before, after)

    def get_traces_with_spiketimes(self, repro):
        """Return the v1 traces of ``repro`` together with the spike train matched to each trace.

        On the first call for a repro, every trace is paired with the spike train
        that scores highest in ``calculate_distance_matrix_traces_spikes``. The
        pairing, the stimulus values and the padding are cached, so later calls
        only reload the data.

        :return: ``(v1_traces, sorted_spiketimes, (before, after))``. Here
            ``sorted_spiketimes[i]`` is the train matched to ``v1_traces[i]``, and
            ``(before, after)`` is the padding the traces were loaded with.
        """
        if repro in self.sorting:
            before, after = self.recording_times[repro]
            v1_traces = self.get_unsorted_traces(repro, before, after)[1]
            spiketimes, _ = self.get_unsorted_spiketimes(repro)
            # Return a list, as the first call does; np.array(..., dtype=object)
            # would turn equal-length spike trains into a 2-D array.
            sorted_spiketimes = [spiketimes[idx] for idx in self.sorting[repro]]
            return v1_traces, sorted_spiketimes, (before, after)

        if repro == "FICurve":
            # FICurve traces are loaded with padding taken from the parser's
            # recording times (entries 0 and 3; units as the parser reports them).
            recording_times = self.parser.get_recording_times()
            before = abs(recording_times[0])
            after = recording_times[3]
        else:
            before = 0
            after = 0
        _time_traces, v1_traces, _eod_traces, _local_eod_traces, _stimulus_traces = \
            self.get_unsorted_traces(repro, before, after)
        spiketimes, metadata = self.get_unsorted_spiketimes(repro)
        self.stim_values[repro] = self.extract_stim_values(repro, metadata)

        if len(v1_traces) != len(spiketimes):
            warn("get_traces_with_spiketimes():Unequal number of traces ({}) and spiketimes ({}) for repro {}! Returning only according to the number of traces!".format(len(v1_traces), len(spiketimes), repro))

        ash = calculate_distance_matrix_traces_spikes(v1_traces, spiketimes, self.sampling_interval, before)
        # One row per trace: keep the spike train with the highest score.
        # NOTE(review): np.argmax raises if there are no spike trains at all.
        sorting = [int(np.argmax(row)) for row in ash]
        sorted_spiketimes = [spiketimes[best] for best in sorting]

        self.sorting[repro] = sorting
        self.recording_times[repro] = (before, after)
        return v1_traces, sorted_spiketimes, (before, after)

    def get_stim_values(self, repro):
        """Return the distinct stimulus values of ``repro`` in sorted order (requires a prior load)."""
        return sorted(np.unique(self.stim_values[repro]))

    def extract_stim_values(self, repro, metadata_list):
        """Describe the stimulus of every trial of ``repro`` from its metadata.

        * BaselineActivity: ``"None"`` for each trial.
        * FICurve: the ``"intensity"`` of the trial's last metadata entry as a float,
          with its last two characters (presumably a unit) stripped. Trials without
          metadata repeat the previous intensity (``None`` before the first one).
        * FileStimulus: base name of the trial's stimulus ``"file"``.
        * SAM: ``"<deltaf> <contrast>"`` from the stimulus section.

        Unsupported repros (or an empty ``metadata_list``) give a one-element
        placeholder list.
        """
        stim_values = []
        if repro == "BaselineActivity":
            stim_values = ["None"] * len(metadata_list)
        elif repro == "FICurve":
            current_intensity = None
            for trial_metadata in metadata_list:
                if len(trial_metadata) != 0:
                    current_intensity = float(trial_metadata[-1]["intensity"][:-2])
                stim_values.append(current_intensity)
        elif repro == "FileStimulus":
            stim_values = [os.path.basename(trial_metadata[0]["file"]) for trial_metadata in metadata_list]
        elif repro == "SAM":
            stimulus_section = "----- Stimulus -------------------------------------------------------"
            for trial_metadata in metadata_list:
                stimulus = trial_metadata[0][stimulus_section]
                stim_values.append("{} {}".format(stimulus["deltaf"], stimulus["contrast"]))

        if len(stim_values) == 0:
            return ["Stimulus values of repro not supported"]
        return stim_values

    def get_trials(self, repro, stimulus_value):
        """Not implemented yet; always returns ``None``."""
        pass

    def set_redetection_params(self, repro, trial_idx, params):
        """Store ``params`` as the redetection parameters of trial ``trial_idx`` of ``repro``.

        The per-repro list is created on first use with one empty tuple per
        sorted trace, so ``sorting[repro]`` must already exist
        (``get_traces_with_spiketimes`` fills it).
        """
        if repro not in self.thresholds:
            self.thresholds[repro] = [()] * len(self.sorting[repro])
        self.thresholds[repro][trial_idx] = params
def calculate_distance_matrix_traces_spikes(traces, spiketimes, sampling_rate, before):
    """Score every (trace, spike train) pair by the mean trace value at the train's spike times.

    :param traces: sequence of voltage traces, one per trial.
    :param spiketimes: sequence of spike trains.
    :param sampling_rate: time step between samples. Despite the name, spike
        times are *divided* by it to get sample indices (see ``average_spike_height``).
    :param before: offset added to every spike time before indexing.
    :return: array of shape ``(len(traces), len(spiketimes))``. Entry ``[i, j]`` is
        ``average_spike_height`` of train ``j`` on trace ``i``, or ``-inf`` for
        trains with fewer than two spikes.
    """
    ash = np.zeros((len(traces), len(spiketimes)))
    total = len(traces) * len(spiketimes)
    count = 0
    for i, trace in enumerate(traces):
        for j, spikes in enumerate(spiketimes):
            count += 1
            # Progress output for long pairings.
            if count % 50000 == 0:
                print("{} / {}".format(count, total))
            if len(spikes) <= 1:
                # np.inf (np.infty was removed in NumPy 2.0).
                ash[i, j] = -np.inf
            else:
                ash[i, j] = average_spike_height(spikes, trace, sampling_rate, before)
    return ash
def average_spike_height(spike_train: np.ndarray, v1: np.ndarray, sampling_rate, before) -> float:
    """Return the mean of ``v1`` at the sample indices of ``spike_train``.

    Each spike time is shifted by ``before``, divided by ``sampling_rate`` (a
    sampling interval, despite the name), and truncated to an integer index.
    Indices outside ``[0, len(v1))`` are ignored.

    :return: the mean value, or ``-inf`` when the train has fewer than two spikes
        or none of its spikes falls inside ``v1``.
    """
    spike_train = np.asarray(spike_train, dtype=float)
    # Builtin int as dtype: np.int was removed in NumPy 1.24.
    indices = ((spike_train + before) / sampling_rate).astype(int)
    if len(indices) <= 1:
        return -np.inf
    # Sample 0 is a valid index, so the lower bound is inclusive.
    applicable_indices = indices[(indices >= 0) & (indices < len(v1))]
    if len(applicable_indices) == 0:
        # np.mean of an empty selection is NaN, and np.argmax (used to pick the
        # best train) selects NaN entries, so report "no match" explicitly.
        return -np.inf
    return np.mean(np.asarray(v1)[applicable_indices])
# SLOW:
# def average_spike_height(spike_train, v1, sampling_rate, before):
# indices = np.array([(s + before) / sampling_rate for s in spike_train], dtype=np.int)
# if len(indices) <= 1:
# return -np.infty
# v1 = np.array(v1)
# spike_values = [v1[i] for i in indices if 0 <= i < len(v1)]
# average_height = np.mean(spike_values)
#
# return average_height