import parser.DataParserFactory as dpf
from warnings import warn
import os
from my_util import helperFunctions as hf
import numpy as np

COUNT = 0


def icelldata_of_dir(base_path, test_for_v1_trace=True):
    """Generator yielding a CellData object for every valid cell recording found in base_path."""
    global COUNT
    for item in sorted(os.listdir(base_path)):
        item_path = os.path.join(base_path, item)

        if not os.path.isdir(item_path) and not item.endswith(".nix"):
            print("ignoring path: " + item_path)
            print("It isn't expected to be cell data.")
            continue

        try:
            data = CellData(item_path)

            if test_for_v1_trace:
                try:
                    trace = data.get_base_traces(trace_type=data.V1)
                    if len(trace) == 0:
                        COUNT += 1
                        print("NO V1 TRACE FOUND:", item_path)
                        print(COUNT)
                        continue
                except IndexError as e:
                    COUNT += 1
                    print(data.get_data_path(), "threw an IndexError!")
                    print(COUNT)
                    print(str(e), "\n")
                    continue
                except ValueError as e:
                    COUNT += 1
                    print(data.get_data_path(), "threw a ValueError!")
                    print(COUNT)
                    print(str(e), "\n")
                    continue  # skip this cell, consistent with the IndexError case above

            yield data

        except TypeError as e:
            warn(str(e))

    print("Cells that threw errors: {}".format(COUNT))


class CellData:
    """Captures all the data of a single cell across all experiments (baseline, FI-curve, SAM, ...).

    Abstracts away how the data is stored on disk (.dat vs .nix).
    Traces are lists of lists: [[time], [voltage (v1)], [EOD], [local EOD], [stimulus]].
    """

    TIME = 0
    V1 = 1
    EOD = 2
    LOCAL_EOD = 3
    STIMULUS = 4

    def __init__(self, data_path):
        self.data_path = data_path
        self.parser = dpf.get_parser(data_path)

        self.base_traces = None
        self.base_spikes = None

        self.fi_traces = None
        self.fi_intensities = None
        self.fi_spiketimes = None
        self.fi_trans_amplitudes = None
        self.mean_isi_frequencies = None
        self.time_axes = None
        # self.metadata = None

        self.sam_spiketimes = None
        self.sam_contrasts = None
        self.sam_delta_fs = None
        self.sam_eod_freqs = None
        self.sam_durations = None
        self.sam_trans_amplitudes = None
        self.sam_traces = None

        self.sampling_interval = None
        self.recording_times = None

    def get_data_path(self):
        return self.data_path

    def get_cell_name(self):
        return os.path.basename(self.data_path)

    def has_sam_recordings(self):
        return self.parser.has_sam_recordings()

    def get_baseline_length(self):
        return self.parser.get_baseline_length()

    def get_fi_curve_contrasts_with_trial_number(self):
        return self.parser.get_fi_curve_contrasts()

    def get_base_traces(self, trace_type=None):
        if self.base_traces is None:
            self.base_traces = self.parser.get_baseline_traces()

        if trace_type is None:
            return self.base_traces
        else:
            return self.base_traces[trace_type]

    def get_base_spikes(self, threshold=2.5, min_length=5000, split_step=1000, re_calculate=False, only_first=False):
        if self.base_spikes is not None and not re_calculate:
            return self.base_spikes

        saved_spikes_file = "base_spikes_ndarray.npy"
        full_path = os.path.join(self.data_path, saved_spikes_file)
        if os.path.isdir(self.data_path) and os.path.exists(full_path) and not re_calculate:
            self.base_spikes = np.load(full_path, allow_pickle=True)
            print("Baseline spikes loaded from file.")
            return self.base_spikes

        print("Baseline spikes are being (re-)calculated...")
        times = self.get_base_traces(self.TIME)
        v1_traces = self.get_base_traces(self.V1)
        spiketimes = []
        for i in range(len(times)):
            if only_first and i > 0:
                break
            spiketimes.append(hf.detect_spiketimes(times[i], v1_traces[i], threshold=threshold,
                                                   min_length=min_length, split_step=split_step))
        # dtype=object: the spike trains can differ in length, and newer numpy
        # versions require this to be stated explicitly for ragged arrays
        self.base_spikes = np.array(spiketimes, dtype=object)

        if os.path.isdir(self.data_path):
            np.save(full_path, self.base_spikes)
            print("Calculated spikes saved to file.")

        return self.base_spikes

    def get_base_isis(self):
        spiketimes = self.get_base_spikes()
        isis = []
        for spikes in spiketimes:
            isis.extend(np.diff(spikes))
        return isis

    def get_fi_traces(self):
        if self.fi_traces is None:
            warn("FI traces are not sorted in the same way as the spiketimes!")
            self.fi_traces = self.parser.get_fi_curve_traces()
        return self.fi_traces

    def get_fi_spiketimes(self):
        self.__read_fi_spiketimes_info__()
        return self.fi_spiketimes

    def get_fi_intensities(self):
        self.__read_fi_spiketimes_info__()
        return self.fi_intensities

    def get_fi_contrasts(self):
        if self.fi_intensities is None:
            self.__read_fi_spiketimes_info__()
        contrasts = []
        for i in range(len(self.fi_intensities)):
            contrasts.append((self.fi_intensities[i] - self.fi_trans_amplitudes[i]) / self.fi_trans_amplitudes[i])
        return contrasts

    def get_sam_traces(self):
        if self.sam_traces is None:
            warn("SAM traces might not be sorted in the same way as the spiketimes!")
            self.sam_traces = self.parser.get_sam_traces()
        return self.sam_traces

    def get_sam_spiketimes(self):
        self.__read_sam_info__()
        return self.sam_spiketimes

    def get_sam_contrasts(self):
        self.__read_sam_info__()
        return self.sam_contrasts

    def get_sam_delta_frequencies(self):
        self.__read_sam_info__()
        return self.sam_delta_fs

    def get_sam_durations(self):
        self.__read_sam_info__()
        return self.sam_durations

    def get_sam_eod_frequencies(self):
        self.__read_sam_info__()
        return self.sam_eod_freqs

    def get_sam_trans_amplitudes(self):
        self.__read_sam_info__()
        return self.sam_trans_amplitudes

    def get_mean_fi_curve_isi_frequencies(self):
        if self.mean_isi_frequencies is None:
            self.time_axes, self.mean_isi_frequencies = hf.all_calculate_mean_isi_frequency_traces(
                self.get_fi_spiketimes(), self.get_sampling_interval())

        return self.mean_isi_frequencies

    def get_time_axes_fi_curve_mean_frequencies(self):
        if self.time_axes is None:
            self.time_axes, self.mean_isi_frequencies = hf.all_calculate_mean_isi_frequency_traces(
                self.get_fi_spiketimes(), self.get_sampling_interval())

        return self.time_axes

    def get_base_frequency(self):
        delay = self.get_delay()
        sampling_interval = self.get_sampling_interval()
        if delay < 0.1:
            warn("CellData:get_base_frequency(): Quite short delay at the start.")

        # average the ISI frequency in a window within the pre-stimulus delay,
        # leaving 25 ms of margin on either side
        idx_start = int(0.025 / sampling_interval)
        idx_end = int((delay - 0.025) / sampling_interval)

        base_freqs = []
        for freq in self.get_mean_fi_curve_isi_frequencies():
            base_freqs.append(np.mean(freq[idx_start:idx_end]))

        return np.median(base_freqs)

    def get_sampling_interval(self) -> float:
        if self.sampling_interval is None:
            self.sampling_interval = self.parser.get_sampling_interval()
        return self.sampling_interval
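
    # Layout of recording_times, inferred from the accessors below (an assumption,
    # not documented in the original source): index 0 is the recording start relative
    # to stimulus onset (a negative delay, hence the abs() in get_delay), index 1 the
    # stimulus start, index 2 the stimulus duration, index 3 the time recorded after
    # the stimulus.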
    def get_recording_times(self) -> list:
        if self.recording_times is None:
            self.recording_times = self.parser.get_recording_times()
        return self.recording_times

    def get_time_start(self) -> float:
        if self.recording_times is None:
            self.recording_times = self.parser.get_recording_times()
        return self.recording_times[0]

    def get_delay(self) -> float:
        if self.recording_times is None:
            self.recording_times = self.parser.get_recording_times()
        return abs(self.recording_times[0])

    def get_time_end(self) -> float:
        if self.recording_times is None:
            self.recording_times = self.parser.get_recording_times()
        return self.recording_times[2] + self.recording_times[3]

    def get_stimulus_start(self) -> float:
        if self.recording_times is None:
            self.recording_times = self.parser.get_recording_times()
        return self.recording_times[1]

    def get_stimulus_duration(self) -> float:
        if self.recording_times is None:
            self.recording_times = self.parser.get_recording_times()
        return self.recording_times[2]

    def get_stimulus_end(self) -> float:
        if self.recording_times is None:
            self.recording_times = self.parser.get_recording_times()
        return self.get_stimulus_start() + self.get_stimulus_duration()

    def get_after_stimulus_duration(self) -> float:
        if self.recording_times is None:
            self.recording_times = self.parser.get_recording_times()
        return self.recording_times[3]

    def get_eod_frequency(self, recalculate=False):
        eod_freq_file_name = "eod_freq_peak_based.npy"
        eod_freq_file_path = os.path.join(self.get_data_path(), eod_freq_file_name)

        if os.path.exists(eod_freq_file_path) and not recalculate:
            print("Loaded eod_freq from file.")
            return np.load(eod_freq_file_path)

        eods = self.get_base_traces(self.EOD)
        sampling_interval = self.get_sampling_interval()
        frequencies = []
        for eod in eods:
            frequencies.append(hf.calculate_eod_frequency(eod, sampling_interval))
        mean_freq = np.mean(frequencies)
        np.save(eod_freq_file_path, mean_freq)
        print("Saved eod freq to file.")
        return mean_freq

    def __read_fi_spiketimes_info__(self):
        if self.fi_spiketimes is None:
            self.fi_trans_amplitudes, self.fi_intensities, self.fi_spiketimes = self.parser.get_fi_curve_spiketimes()

            redetected_path = os.path.join(self.get_data_path(), "redetected_spikes.npy")
            if os.path.exists(redetected_path):
                print("Overwriting fi_spiketimes with redetected ones.")
                contrasts = self.get_fi_contrasts()
                spikes = np.load(redetected_path, allow_pickle=True)
                trace_contrasts_idx = np.load(os.path.join(self.get_data_path(), "fi_traces_contrasts.npy"),
                                              allow_pickle=True)
                trace_max_similarity = np.load(os.path.join(self.get_data_path(), "fi_traces_contrasts_similarity.npy"),
                                               allow_pickle=True)

                # group the redetected spike trains by contrast; keep a trace only if its
                # best similarity score exceeds the second-best by more than 0.15
                spiketimes = []
                for i in range(len(contrasts)):
                    contrast_list = []
                    for j in range(len(trace_contrasts_idx)):
                        if trace_contrasts_idx[j] == i and trace_max_similarity[j][0] > trace_max_similarity[j][1] + 0.15:
                            contrast_list.append(spikes[j])
                    spiketimes.append(contrast_list)
                self.fi_spiketimes = spiketimes

    def __read_sam_info__(self):
        if self.sam_spiketimes is None:
            spiketimes, contrasts, delta_fs, eod_freqs, durations, trans_amplitudes = self.parser.get_sam_info()
            self.sam_spiketimes = spiketimes
            self.sam_contrasts = contrasts
            self.sam_delta_fs = delta_fs
            self.sam_eod_freqs = eod_freqs
            self.sam_durations = durations
            self.sam_trans_amplitudes = trans_amplitudes

    # def get_metadata(self):
    #     self.__read_metadata__()
    #     return self.metadata
    #
    # def get_metadata_item(self, item):
    #     self.__read_metadata__()
    #     if item in self.metadata.keys():
    #         return self.metadata[item]
    #     else:
    #         raise KeyError("CellData:get_metadata_item: Item not found in metadata! - " + str(item))
    #
    # def __read_metadata__(self):
    #     if self.metadata is None:
    #         # TODO!!
    #         pass
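

# Minimal usage sketch (an illustration, not part of the original module): iterate
# over a directory of cell recordings and print a few basic quantities per cell.
# The directory name "data/" is an assumption; it must contain one sub-directory
# or .nix file per recorded cell.
if __name__ == "__main__":
    for cell in icelldata_of_dir("data/"):
        v1 = cell.get_base_traces(trace_type=CellData.V1)  # baseline membrane voltage trace(s)
        print("{}: {} baseline trace(s), baseline rate {:.1f} Hz, EOD frequency {:.1f} Hz".format(
            cell.get_cell_name(), len(v1), cell.get_base_frequency(), cell.get_eod_frequency()))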