import os
from warnings import warn

import matplotlib.pyplot as plt
import numpy as np

from DatParser import DatParser
from redetector import detect_spiketimes


class Controller:
    """Load traces and spike times of one recording and manage spike redetection.

    The spike-time lists delivered by the parser are matched ("sorted") to
    their voltage traces, redetection parameters can be set per trial, and
    both the parameters and the redetected spikes can be written to disk.
    """

    def __init__(self, data_path, repros_todo=(), base_detection_params=(2, 5000, 1000)):
        """Open ``data_path`` with ``DatParser`` and select the repros to process.

        Args:
            data_path: Path handed to ``DatParser``.
            repros_todo: Repro names to process. Duplicates are dropped (and
                the result is sorted by ``np.unique``); names the recording
                does not contain are reported and skipped. When empty, every
                measured repro is used.
            base_detection_params: Stored as ``base_detection_parameters``;
                presumably default (threshold, min_window, step) values for
                redetection -- confirm against callers.
        """
        self.data_path = data_path
        self.parser = DatParser(data_path)
        self.repros = self.get_repros()

        # len() rather than truthiness so numpy arrays are accepted as well.
        if len(repros_todo) != 0:
            selected = []
            for repro in np.unique(repros_todo):
                if repro in self.repros:
                    selected.append(repro)
                else:
                    print("DataProvider: Given cell hasn't measured '{}'. This repro will be skipped!".format(repro))
            self.repros_todo = selected
        else:
            self.repros_todo = self.repros

        self.base_detection_parameters = base_detection_params
        self.sampling_interval = self.parser.get_sampling_interval()

        self.sorting = {}          # repro -> index into parser spike lists, one per trace
        self.thresholds = {}       # repro -> per-trial redetection params, () when unset
        self.stim_values = {}      # repro -> per-trial stimulus value
        self.recording_times = {}  # repro -> (before, after) used when loading traces

    def save_parameters(self, folder_path):
        """Write the redetection parameters of every repro to CSV files.

        Writes one ``<repro>_thresholds.csv`` into ``folder_path`` for each
        repro that has parameters, with one row per trial. Trials without
        parameters are written as ``-1`` in every parameter column.
        """
        header = "trial_index,threshold,min_window,step_size\n"
        for repro, trial_params in self.thresholds.items():
            with open(folder_path + "/{}_thresholds.csv".format(repro), "w") as threshold_file:
                threshold_file.write(header)
                for i in range(len(self.sorting[repro])):
                    params = trial_params[i]
                    if len(params) == 0:
                        # No redetection parameters were set for this trial.
                        row = (i, -1, -1, -1)
                    else:
                        row = (i, params[0], params[1], params[2])
                    threshold_file.write("{},{},{},{}\n".format(*row))
        print("Thresholds saved!")

    def save_redetected_spikes(self, folder_path):
        """Re-run spike detection for every trial with parameters and save the results.

        Each result is saved with ``np.save`` as
        ``<folder_path>/redetected_spikes/spikes_repro_<repro>_trial_<i>.npy``.
        Trials without parameters are skipped.
        """
        redetection_folder = folder_path + "/redetected_spikes/"
        os.makedirs(redetection_folder, exist_ok=True)

        for repro, trial_params in self.thresholds.items():
            traces = None
            recording_times = None
            for i in range(len(self.sorting[repro])):
                params = trial_params[i]
                if len(params) == 0:
                    continue  # without parameters a redetection file makes no sense
                if traces is None:
                    # Load the traces once per repro (not once per trial).
                    traces, _, recording_times = self.get_traces_with_spiketimes(repro)
                thresh, min_window, step = params[0], params[1], params[2]
                trace = traces[i]
                # Time axis shifted so the recording window starts at -before.
                time_axis = np.arange(len(trace)) * self.sampling_interval - recording_times[0]
                redetect = detect_spiketimes(time_axis, trace, thresh, min_window, step)
                np.save(redetection_folder + "spikes_repro_{}_trial_{}.npy".format(repro, i), np.array(redetect))
        print("Redetected spikes saved!")

    def get_repros(self):
        """Return the repros measured in this recording (from the parser)."""
        return self.parser.get_measured_repros()

    def get_unsorted_spiketimes(self, repro, file=""):
        """Return the parser's ``(spiketimes, metadata)`` for ``repro`` in parser order."""
        return self.parser.get_spiketimes(repro, file)

    def get_unsorted_traces(self, repro, before=0, after=0):
        """Return the parser's trace lists for ``repro`` in parser order."""
        return self.parser.get_traces(repro, before, after)

    def get_traces_with_spiketimes(self, repro):
        """Return V1 traces with the spike-time list matched to each trace.

        On the first call for ``repro``, every trace is matched to the spike
        list with the highest mean voltage at its spike times (see
        ``calculate_distance_matrix_traces_spikes``). The first call also
        stores the stimulus values, the matching, and the (before, after)
        window. Later calls reuse the stored matching.

        Returns:
            ``(v1_traces, sorted_spiketimes, (before, after))``, where
            ``sorted_spiketimes[i]`` belongs to ``v1_traces[i]``.
        """
        if repro in self.sorting:
            before, after = self.recording_times[repro]
            v1_traces = self.get_unsorted_traces(repro, before, after)[1]
            spiketimes, _ = self.get_unsorted_spiketimes(repro)
            # List (not object ndarray) so both code paths return the same type.
            sorted_spiketimes = [spiketimes[idx] for idx in self.sorting[repro]]
            return v1_traces, sorted_spiketimes, self.recording_times[repro]

        if repro == "FICurve":
            # NOTE(review): index 0 is presumably the (negative) pre-stimulus
            # offset and index 3 the post-stimulus duration -- confirm in DatParser.
            recording_times = self.parser.get_recording_times()
            before = abs(recording_times[0])
            after = recording_times[3]
            all_traces = self.parser.get_traces(repro, before=before, after=after)
        else:
            before = 0
            after = 0
            all_traces = self.get_unsorted_traces(repro)
        # Order: time, v1, eod, local eod, stimulus.
        _, v1_traces, _, _, _ = all_traces
        spiketimes, metadata = self.get_unsorted_spiketimes(repro)

        self.stim_values[repro] = self.extract_stim_values(repro, metadata)

        if len(v1_traces) != len(spiketimes):
            warn("get_traces_with_spiketimes():Unequal number of traces ({}) and spiketimes ({}) for repro {}! Returning only according to the number of traces!".format(len(v1_traces), len(spiketimes), repro))

        scores = calculate_distance_matrix_traces_spikes(v1_traces, spiketimes, self.sampling_interval, before)
        sorting = [np.argmax(scores[i, :]) for i in range(len(v1_traces))]
        sorted_spiketimes = [spiketimes[best] for best in sorting]

        self.sorting[repro] = sorting
        self.recording_times[repro] = (before, after)
        return v1_traces, sorted_spiketimes, (before, after)

    def get_stim_values(self, repro):
        """Return the sorted distinct stimulus values recorded for ``repro``.

        Requires a prior ``get_traces_with_spiketimes(repro)`` call (raises
        KeyError otherwise).
        """
        return sorted(np.unique(self.stim_values[repro]))

    def extract_stim_values(self, repro, metadata_list):
        """Derive one stimulus value per trial from the parser metadata.

        Supported repros: BaselineActivity ("None"), FICurve (intensity as
        float), FileStimulus (stimulus file basename) and SAM
        ("<deltaf> <contrast>"). For any other repro, or when no value was
        produced, ``["Stimulus values of repro not supported"]`` is returned.
        """
        stim_values = []
        if repro == "BaselineActivity":
            stim_values = ["None"] * len(metadata_list)
        elif repro == "FICurve":
            current_intensity = None
            for metadata in metadata_list:
                # Trials without metadata inherit the last seen intensity.
                if len(metadata) != 0:
                    # [:-2] strips a two-character unit suffix (presumably
                    # "mV" or "dB" -- confirm against the data files).
                    current_intensity = float(metadata[-1]["intensity"][:-2])
                stim_values.append(current_intensity)
        elif repro == "FileStimulus":
            stim_values = [os.path.basename(metadata[0]["file"]) for metadata in metadata_list]
        elif repro == "SAM":
            stimulus_key = "----- Stimulus -------------------------------------------------------"
            for metadata in metadata_list:
                stimulus = metadata[0][stimulus_key]
                stim_values.append("{} {}".format(stimulus["deltaf"], stimulus["contrast"]))

        if len(stim_values) == 0:
            return ["Stimulus values of repro not supported"]
        return stim_values

    def get_trials(self, repro, stimulus_value):
        """Not implemented yet; always returns None."""
        pass

    def set_redetection_params(self, repro, trial_idx, params):
        """Store redetection ``params`` (threshold, min_window, step) for one trial.

        Requires the trace matching for ``repro`` to exist already (raises
        KeyError otherwise).
        """
        if repro not in self.thresholds:
            # One empty tuple per trial means "no parameters set".
            self.thresholds[repro] = [()] * len(self.sorting[repro])
        self.thresholds[repro][trial_idx] = params


def calculate_distance_matrix_traces_spikes(traces, spiketimes, sampling_rate, before):
    """Score every (trace, spike-time list) pair by mean voltage at the spikes.

    Despite the name this is a similarity matrix: larger means a better
    match. Entry ``[i, j]`` is ``average_spike_height(spiketimes[j],
    traces[i], ...)``, or ``-inf`` when list ``j`` has at most one spike.
    Progress is printed every 50000 pairs.

    Args:
        traces: Sequence of voltage traces.
        spiketimes: Sequence of spike-time lists.
        sampling_rate: Sampling *interval* (time per sample); the caller
            passes ``DatParser.get_sampling_interval()``.
        before: Time offset added to spike times before conversion to
            sample indices.

    Returns:
        Array of shape ``(len(traces), len(spiketimes))``.
    """
    scores = np.zeros((len(traces), len(spiketimes)))
    total = len(traces) * len(spiketimes)
    count = 0
    for i, trace in enumerate(traces):
        for j, spikes in enumerate(spiketimes):
            count += 1
            if count % 50000 == 0:
                print("{} / {}".format(count, total))
            if len(spikes) <= 1:
                scores[i, j] = -np.inf
            else:
                scores[i, j] = average_spike_height(spikes, trace, sampling_rate, before)
    return scores


def average_spike_height(spike_train: np.ndarray, v1: np.ndarray, sampling_rate, before):
    """Return the mean of ``v1`` at the sample indices of ``spike_train``.

    Spike times are shifted by ``before`` and divided by ``sampling_rate``
    (a sampling interval) to get sample indices, truncated toward zero.
    Indices outside ``[0, len(v1))`` are ignored.

    Returns:
        The mean voltage, or ``-inf`` when the train has at most one spike or
        none of its spikes falls inside the trace. The ``-inf`` keeps
        ``np.argmax`` from choosing such a train; a NaN would win.
    """
    spike_train = np.asarray(spike_train, dtype=float)
    if len(spike_train) <= 1:
        return -np.inf
    indices = ((spike_train + before) / sampling_rate).astype(int)
    applicable_indices = indices[(indices >= 0) & (indices < len(v1))]
    if len(applicable_indices) == 0:
        return -np.inf
    return np.mean(np.asarray(v1)[applicable_indices])