commit 4ba7f6a1b399c1acf2d1b773a32161ff3ea5c41f Author: alexanderott Date: Sat Feb 13 15:24:01 2021 +0100 remove redetector from neuron model and create its own project diff --git a/DatParser.py b/DatParser.py new file mode 100644 index 0000000..73cff0e --- /dev/null +++ b/DatParser.py @@ -0,0 +1,398 @@ + +from os.path import isdir, exists +from warnings import warn +import numpy as np + +import pyrelacs.DataLoader as Dl + + +class DatParser(): + + def __init__(self, dir_path): + self.base_path = dir_path + self.info_file = self.base_path + "/info.dat" + self.fi_file = self.base_path + "/fispikes1.dat" + self.baseline_file = self.base_path + "/basespikes1.dat" + self.sam_file = self.base_path + "/samallspikes1.dat" + self.stimuli_file = self.base_path + "/stimuli.dat" + self.__test_data_file_existence__() + + self.fi_recording_times = [] + self.sampling_interval = -1 + + def has_sam_recordings(self): + return exists(self.sam_file) + + def get_measured_repros(self): + repros = [] + for metadata, key, data in Dl.iload(self.stimuli_file): + repros.extend([d["repro"] for d in metadata if "repro" in d.keys()]) + + return sorted(np.unique(repros)) + + def get_baseline_length(self): + lengths = [] + for metadata, key, data in Dl.iload(self.baseline_file): + if len(metadata) != 0: + lengths.append(float(metadata[0]["duration"][:-3])) + + return lengths + + def get_species(self): + species = "" + for metadata in Dl.load(self.info_file): + if "Species" in metadata.keys(): + species = metadata["Species"] + elif "Subject" in metadata.keys(): + if isinstance(metadata["Subject"], dict) and "Species" in metadata["Subject"].keys(): + species = metadata["Subject"]["Species"] + + return species + + def get_gender(self): + gender = "not found" + for metadata in Dl.load(self.info_file): + if "Species" in metadata.keys(): + gender = metadata["Gender"] + elif "Subject" in metadata.keys(): + if isinstance(metadata["Subject"], dict) and "Gender" in metadata["Subject"].keys(): + gender = metadata["Subject"]["Gender"] + + return gender + + def get_quality(self): + quality = "" + for metadata in Dl.load(self.info_file): + if "Recording quality" in metadata.keys(): + quality = metadata["Recording quality"] + elif "Recording" in metadata.keys(): + if isinstance(metadata["Recording"], dict) and "Recording quality" in metadata["Recording"].keys(): + quality = metadata["Recording"]["Recording quality"] + return quality + + def get_cell_type(self): + type = "" + for metadata in Dl.load(self.info_file): + if len(metadata.keys()) < 3: + return "" + if "CellType" in metadata.keys(): + type = metadata["CellType"] + elif "Cell" in metadata.keys(): + if isinstance(metadata["Cell"], dict) and "CellType" in metadata["Cell"].keys(): + type = metadata["Cell"]["CellType"] + return type + + def get_fish_size(self): + size = "" + for metadata in Dl.load(self.info_file): + if "Species" in metadata.keys(): + size = metadata["Size"] + elif "Subject" in metadata.keys(): + if isinstance(metadata["Subject"], dict) and "Species" in metadata["Subject"].keys(): + size = metadata["Subject"]["Size"] + return size[:-2] + + def get_fi_curve_contrasts(self): + """ + + :return: list of tuples [(contrast, #_of_trials), ...] + """ + contrasts = [] + contrast = [-1, float("nan")] + for metadata, key, data in Dl.iload(self.fi_file): + if len(metadata) != 0: + if contrast[0] != -1: + contrasts.append(contrast) + contrast = [-1, 1] + contrast[0] = float(metadata[-1]["intensity"][:-2]) + else: + contrast[1] += 1 + + return np.array(contrasts) + + def traces_available(self) -> bool: + return True + + def frequencies_available(self) -> bool: + return False + + def spiketimes_available(self) -> bool: + return True + + def get_sampling_interval(self): + if self.sampling_interval == -1: + self.__read_sampling_interval__() + + return self.sampling_interval + + def get_recording_times(self): + if len(self.fi_recording_times) == 0: + self.__read_fi_recording_times__() + return self.fi_recording_times + + def get_baseline_traces(self): + return self.__get_traces__("BaselineActivity") + + def get_baseline_spiketimes(self): + # TODO change: reading from file -> detect from v1 trace + spiketimes = [] + warn("Spiketimes don't fit time-wise to the baseline traces. Causes different vector strength angle per recording.") + + for metadata, key, data in Dl.iload(self.baseline_file): + spikes = np.array(data[:, 0]) / 1000 # timestamps are saved in ms -> conversion to seconds + spiketimes.append(spikes) + + return spiketimes + + def get_fi_curve_traces(self): + return self.__get_traces__("FICurve") + + def get_fi_frequency_traces(self): + raise NotImplementedError("Not possible in .dat data type.\n" + "Please check availability with the x_available functions.") + + # TODO clean up/ rewrite + def get_fi_curve_spiketimes(self): + spiketimes = [] + pre_intensities = [] + pre_durations = [] + intensities = [] + trans_amplitudes = [] + pre_duration = -1 + index = -1 + skip = False + trans_amplitude = float('nan') + for metadata, key, data in Dl.iload(self.fi_file): + if len(metadata) != 0: + + metadata_index = 0 + + if '----- Control --------------------------------------------------------' in metadata[0].keys(): + metadata_index = 1 + pre_duration = float(metadata[0]["----- Pre-Intensities ------------------------------------------------"]["preduration"][:-2]) + trans_amplitude = float(metadata[0]["trans. amplitude"][:-2]) + if pre_duration == 0: + skip = False + else: + skip = True + continue + else: + if "preduration" in metadata[0].keys(): + pre_duration = float(metadata[0]["preduration"][:-2]) + trans_amplitude = float(metadata[0]["trans. amplitude"][:-2]) + if pre_duration == 0: + skip = False + else: + skip = True + continue + + if skip: + continue + if 'intensity' in metadata[metadata_index].keys(): + intensity = float(metadata[metadata_index]['intensity'][:-2]) + pre_intensity = float(metadata[metadata_index]['preintensity'][:-2]) + else: + intensity = float(metadata[1-metadata_index]['intensity'][:-2]) + pre_intensity = float(metadata[1-metadata_index]['preintensity'][:-2]) + + intensities.append(intensity) + pre_durations.append(pre_duration) + pre_intensities.append(pre_intensity) + trans_amplitudes.append(trans_amplitude) + spiketimes.append([]) + index += 1 + + if skip: + continue + + if data.shape[1] != 1: + raise RuntimeError("DatParser:get_fi_curve_spiketimes():\n read data has more than one dimension!") + + spike_time_data = data[:, 0]/1000 + if len(spike_time_data) < 10: + print("# ignoring spike-train that contains less than 10 spikes.") + continue + if spike_time_data[-1] < 1: + print("# ignoring spike-train that ends before one second.") + continue + + spiketimes[index].append(spike_time_data) + + # TODO Check if sorting works! + new_order = np.arange(0, len(intensities), 1) + intensities, new_order = zip(*sorted(zip(intensities, new_order))) + intensities = list(intensities) + spiketimes = [spiketimes[i] for i in new_order] + trans_amplitudes = [trans_amplitudes[i] for i in new_order] + + for i in range(len(intensities)-1, -1, -1): + if len(spiketimes[i]) < 3: + del intensities[i] + del spiketimes[i] + del trans_amplitudes[i] + + return trans_amplitudes, intensities, spiketimes + + def get_sam_traces(self): + return self.__get_traces__("SAM") + + def get_sam_info(self): + contrasts = [] + delta_fs = [] + spiketimes = [] + durations = [] + eod_freqs = [] + trans_amplitudes = [] + index = -1 + for metadata, key, data in Dl.iload(self.sam_file): + factor = 1 + if key[0][0] == 'time': + if key[1][0] == 'ms': + factor = 1/1000 + elif key[1][0] == 's': + factor = 1 + else: + print("DataParser Dat: Unknown time notation:", key[1][0]) + if len(metadata) != 0: + if not "----- Stimulus -------------------------------------------------------" in metadata[0].keys(): + eod_freq = float(metadata[0]["EOD rate"][:-2]) # in Hz + trans_amplitude = metadata[0]["trans. amplitude"][:-2] # in mV + + duration = float(metadata[0]["duration"][:-2]) * factor # normally saved in ms? so change it with the factor + contrast = float(metadata[0]["contrast"][:-1]) # in percent + delta_f = float(metadata[0]["deltaf"][:-2]) + else: + stimulus_dict = metadata[0]["----- Stimulus -------------------------------------------------------"] + analysis_dict = metadata[0]["----- Analysis -------------------------------------------------------"] + eod_freq = float(metadata[0]["EOD rate"][:-2]) # in Hz + trans_amplitude = metadata[0]["trans. amplitude"][:-2] # in mV + + duration = float(stimulus_dict["duration"][:-2]) * factor # normally saved in ms? so change it with the factor + contrast = float(stimulus_dict["contrast"][:-1]) # in percent + delta_f = float(stimulus_dict["deltaf"][:-2]) + + # delta_f = metadata[0]["true deltaf"] + # contrast = metadata[0]["true contrast"] + + contrasts.append(contrast) + delta_fs.append(delta_f) + durations.append(duration) + eod_freqs.append(eod_freq) + trans_amplitudes.append(trans_amplitude) + spiketimes.append([]) + index += 1 + + if data.shape[1] != 1: + raise RuntimeError("DatParser:get_sam_spiketimes():\n read data has more than one dimension!") + + spike_time_data = data[:, 0] * factor # saved in ms so use the factor to change it. + if len(spike_time_data) < 10: + continue + if spike_time_data[-1] < 0.1: + print("# ignoring spike-train that ends before one tenth of a second.") + continue + spiketimes[index].append(spike_time_data) + + return spiketimes, contrasts, delta_fs, eod_freqs, durations, trans_amplitudes + + def __get_traces__(self, repro): + time_traces = [] + v1_traces = [] + eod_traces = [] + local_eod_traces = [] + stimulus_traces = [] + + nothing = True + + for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro): + nothing = False + time_traces.append(time) + v1_traces.append(x[0]) + eod_traces.append(x[1]) + local_eod_traces.append(x[2]) + stimulus_traces.append(x[3]) + + traces = [time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces] + + if nothing: + warn_msg = "pyrelacs: iload_traces found nothing for the " + str(repro) + " repro!" + warn(warn_msg) + + return traces + + def __iget_traces__(self, repro): + + for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro): + # time, v1, eod, local_eod, stimulus + yield time, x[0], x[1], x[2], x[3] + + def __read_fi_recording_times__(self): + + delays = [] + stim_duration = [] + pause = [] + + for metadata, key, data in Dl.iload(self.fi_file): + if len(metadata) != 0: + control_key = '----- Control --------------------------------------------------------' + if control_key in metadata[0].keys(): + delays.append(float(metadata[0][control_key]["delay"][:-2])/1000) + pause.append(float(metadata[0][control_key]["pause"][:-2])/1000) + stim_key = "----- Test-Intensities -----------------------------------------------" + stim_duration.append(float(metadata[0][stim_key]["duration"][:-2])/1000) + + if "pause" in metadata[0].keys(): + delays.append(float(metadata[0]["delay"][:-2]) / 1000) + pause.append(float(metadata[0]["pause"][:-2]) / 1000) + stim_duration.append(float(metadata[0]["duration"][:-2]) / 1000) + + for l in [delays, stim_duration, pause]: + if len(l) == 0: + raise RuntimeError("DatParser:__read_fi_recording_times__:\n" + + "Couldn't find any delay, stimulus duration and or pause in the metadata.\n" + + "In file:" + self.base_path) + elif len(set(l)) != 1: + raise RuntimeError("DatParser:__read_fi_recording_times__:\n" + + "Found multiple different delay, stimulus duration and or pause in the metadata.\n" + + "In file:" + self.base_path) + else: + self.fi_recording_times = [-delays[0], 0, stim_duration[0], pause[0] - delays[0]] + + def __read_sampling_interval__(self): + stop = False + sampling_intervals = [] + for metadata, key, data in Dl.iload(self.stimuli_file): + for md in metadata: + for i in range(4): + key = "sample interval" + str(i+1) + if key in md.keys(): + + sampling_intervals.append(float(md[key][:-2]) / 1000) + stop = True + else: + break + + if stop: + break + + if len(sampling_intervals) == 0: + raise RuntimeError("DatParser:__read_sampling_interval__:\n" + + "Sampling intervals not found in stimuli.dat this is not handled!\n" + + "with File:" + self.base_path) + + if len(set(sampling_intervals)) != 1: + raise RuntimeError("DatParser:__read_sampling_interval__:\n" + + "Sampling intervals not the same for all traces this is not handled!\n" + + "with File:" + self.base_path) + else: + self.sampling_interval = sampling_intervals[0] + + def __test_data_file_existence__(self): + if not exists(self.stimuli_file): + raise FileNotFoundError(self.stimuli_file + " file doesn't exist!") + if not exists(self.fi_file): + raise FileNotFoundError(self.fi_file + " file doesn't exist!") + if not exists(self.baseline_file): + raise FileNotFoundError(self.baseline_file + " file doesn't exist!") + # if not exists(self.sam_file): + # raise RuntimeError(self.sam_file + " file doesn't exist!") diff --git a/DataProvider.py b/DataProvider.py new file mode 100644 index 0000000..083a1fe --- /dev/null +++ b/DataProvider.py @@ -0,0 +1,20 @@ + +from DatParser import DatParser + +class DataProvider: + + def __init__(self, data_path): + self.data_path = data_path + # self.cell = CellData(data_path) + self.parser = DatParser(data_path) + self.thresholds = {} + + def get_repros(self): + pass + + def get_stim_values(self, repro): + pass + + def get_trials(self, repro, stimulus_value): + pass + diff --git a/SpikeRedetectGui.py b/SpikeRedetectGui.py new file mode 100644 index 0000000..cc4c021 --- /dev/null +++ b/SpikeRedetectGui.py @@ -0,0 +1,133 @@ + +from PyQt5.QtWidgets import QWidget, QPushButton, QSizePolicy, QLineEdit,\ + QMessageBox, QVBoxLayout, QHBoxLayout, QGridLayout, QLabel, QFrame,\ + QDoubleSpinBox, QComboBox +from PyQt5.QtCore import pyqtSlot + +import numpy as np + +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas +from matplotlib.figure import Figure + + +class SpikeRedetectGui(QWidget): + + def __init__(self, data_provider): + super().__init__() + self.data_provider = data_provider + self.title = 'Spike Redetection' + self.left = 10 + self.top = 10 + self.width = 640 + self.height = 400 + self.initUI() + + def initUI(self): + self.setWindowTitle(self.title) + self.setGeometry(self.left, self.top, self.width, self.height) + + # Middle: + middle = QHBoxLayout() + + # Canvas for matplotlib figure + m = PlotCanvas(self, width=5, height=4) + m.move(0, 0) + middle.addWidget(m) + middle.addWidget(QVLine()) + + # Side (options) panel + panel = QFrame() + panel_layout = QVBoxLayout() + + button = QPushButton('Button!', self) + button.setToolTip('A nice button!') + button.clicked.connect(lambda: threshold_spinbox.setValue(1)) + button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) + panel_layout.addWidget(button) + + repro_label = QLabel("Repro:") + panel_layout.addWidget(repro_label) + + self.repro_box = QComboBox() + self.repro_box.addItem("placeholder repro") + panel_layout.addWidget(self.repro_box) + + stim_val_label = QLabel("Stimulus value:") + panel_layout.addWidget(stim_val_label) + self.stim_val_box = QComboBox() + self.stim_val_box.addItem("placeholder stim_value") + panel_layout.addWidget(self.stim_val_box) + + filler = QFill(maxh=20) + panel_layout.addWidget(filler) + + self.status_label = QLabel("Done x/15 Stimulus Values") + panel_layout.addWidget(self.status_label) + filler = QFill() + panel_layout.addWidget(filler) + + threshold_label = QLabel("Threshold:") + panel_layout.addWidget(threshold_label) + threshold_spinbox = QDoubleSpinBox(self) + threshold_spinbox.setValue(1) + threshold_spinbox.setSingleStep(0.5) + threshold_spinbox.valueChanged.connect(lambda: m.plot(threshold_spinbox.value())) + panel_layout.addWidget(threshold_spinbox) + + button = QPushButton('Accept!', self) + button.setToolTip('Accept the threshold for current stimulus value') + panel_layout.addWidget(button) + + + panel.setLayout(panel_layout) + middle.addWidget(panel) + + self.setLayout(middle) + self.show() + + +class PlotCanvas(FigureCanvas): + + def __init__(self, parent=None, width=5, height=4, dpi=100): + fig = Figure(figsize=(width, height), dpi=dpi) + self.axes = fig.add_subplot(111) + + FigureCanvas.__init__(self, fig) + self.setParent(parent) + + FigureCanvas.setSizePolicy(self, + QSizePolicy.Expanding, + QSizePolicy.Expanding) + FigureCanvas.updateGeometry(self) + self.plot() + + @pyqtSlot() + def plot(self, mean=1): + x = np.arange(0, 1, 0.0001) + data = np.sin(x*np.pi*2*mean) + ax = self.axes + ax.clear() + ax.plot(data, 'r-') + ax.set_title('Sinus Example') + self.draw() + + +class QHLine(QFrame): + def __init__(self): + super(QHLine, self).__init__() + self.setFrameShape(QFrame.HLine) + self.setFrameShadow(QFrame.Sunken) + + +class QVLine(QFrame): + def __init__(self): + super(QVLine, self).__init__() + self.setFrameShape(QFrame.VLine) + self.setFrameShadow(QFrame.Sunken) + + +class QFill(QFrame): + def __init__(self, maxh=int(2**24)-1, maxw=int(2**24)-1): + super(QFill, self).__init__() + self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) + self.setMaximumSize(maxw, maxh) diff --git a/main.py b/main.py new file mode 100644 index 0000000..2d18526 --- /dev/null +++ b/main.py @@ -0,0 +1,17 @@ + +import sys +from PyQt5.QtWidgets import QApplication, QWidget, QPushButton + +from spike_redetection.DataProvider import DataProvider +from spike_redetection.SpikeRedetectGui import SpikeRedetectGui + + +def main(): + app = QApplication(sys.argv) + data_provider = DataProvider("../data/final_sam/2010-11-08-al-invivo-1") + ex = SpikeRedetectGui(data_provider) + sys.exit(app.exec_()) + + +if __name__ == '__main__': + main()