diff --git a/.gitignore b/.gitignore
index 05d3c00..5e37a94 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
 *.dat
 /data/
 /temp/
+/figures/
 # Latex output files
 *.out
 *.aux
diff --git a/DatParser.py b/DatParser.py
index 73cff0e..6737908 100644
--- a/DatParser.py
+++ b/DatParser.py
@@ -6,7 +6,7 @@
 import numpy as np
 import pyrelacs.DataLoader as Dl
 
-class DatParser():
+class DatParser:
 
     def __init__(self, dir_path):
         self.base_path = dir_path
@@ -15,18 +15,22 @@ class DatParser():
         self.baseline_file = self.base_path + "/basespikes1.dat"
         self.sam_file = self.base_path + "/samallspikes1.dat"
         self.stimuli_file = self.base_path + "/stimuli.dat"
-        self.__test_data_file_existence__()
-
-        self.fi_recording_times = []
+        self.spike_files = {"BaselineActivity": self.baseline_file,
+                            "FICurve": self.fi_file,
+                            "FileStimulus": self.base_path + "/stimspikes1.dat",
+                            "SAM": self.sam_file}
         self.sampling_interval = -1
+        self.fi_recording_times = []
 
-    def has_sam_recordings(self):
-        return exists(self.sam_file)
+        self.spiketimes = {}
+        self.traces = {}
+        self.metadata = {}
 
     def get_measured_repros(self):
         repros = []
         for metadata, key, data in Dl.iload(self.stimuli_file):
             repros.extend([d["repro"] for d in metadata if "repro" in d.keys()])
+            repros.extend([d["RePro"] for d in metadata if "RePro" in d.keys()])
 
         return sorted(np.unique(repros))
@@ -110,15 +114,6 @@ class DatParser():
 
         return np.array(contrasts)
 
-    def traces_available(self) -> bool:
-        return True
-
-    def frequencies_available(self) -> bool:
-        return False
-
-    def spiketimes_available(self) -> bool:
-        return True
-
     def get_sampling_interval(self):
         if self.sampling_interval == -1:
             self.__read_sampling_interval__()
@@ -131,7 +126,7 @@ class DatParser():
         return self.fi_recording_times
 
     def get_baseline_traces(self):
-        return self.__get_traces__("BaselineActivity")
+        return self.get_traces("BaselineActivity")
 
     def get_baseline_spiketimes(self):
         # TODO change: reading from file -> detect from v1 trace
         spiketimes = []
@@ -145,7 +140,7 @@ class DatParser():
         return spiketimes
 
     def get_fi_curve_traces(self):
-        return self.__get_traces__("FICurve")
+        return self.get_traces("FICurve")
 
     def get_fi_frequency_traces(self):
         raise NotImplementedError("Not possible in .dat data type.\n"
@@ -234,7 +229,7 @@ class DatParser():
         return trans_amplitudes, intensities, spiketimes
 
     def get_sam_traces(self):
-        return self.__get_traces__("SAM")
+        return self.get_traces("SAM")
 
     def get_sam_info(self):
         contrasts = []
@@ -258,7 +253,7 @@ class DatParser():
             eod_freq = float(metadata[0]["EOD rate"][:-2])  # in Hz
             trans_amplitude = metadata[0]["trans. amplitude"][:-2]  # in mV
 
-            duration = float(metadata[0]["duration"][:-2]) * factor  # normally saved in ms? so change it with the factor
+            duration = float(metadata[0]["duration"][:-2]) * factor  # normally saved in ms? so change it with the factor
             contrast = float(metadata[0]["contrast"][:-1])  # in percent
             delta_f = float(metadata[0]["deltaf"][:-2])
         else:
@@ -267,7 +262,7 @@ class DatParser():
             eod_freq = float(metadata[0]["EOD rate"][:-2])  # in Hz
             trans_amplitude = metadata[0]["trans. amplitude"][:-2]  # in mV
 
-            duration = float(stimulus_dict["duration"][:-2]) * factor  # normally saved in ms? so change it with the factor
+            duration = float(stimulus_dict["duration"][:-2]) * factor  # normally saved in ms? so change it with the factor
 
             contrast = float(stimulus_dict["contrast"][:-1])  # in percent
             delta_f = float(stimulus_dict["deltaf"][:-2])
@@ -295,7 +290,32 @@ class DatParser():
 
         return spiketimes, contrasts, delta_fs, eod_freqs, durations, trans_amplitudes
 
-    def __get_traces__(self, repro):
+    def get_spiketimes(self, repro, spiketimes_file=""):
+        if repro in self.spiketimes.keys():
+            return self.spiketimes[repro], self.metadata[repro]
+
+        spiketimes = []
+        metadata_list = []
+        warn("get_spiketimes(): Spiketimes aren't sorted the same way as traces!")
+
+        if spiketimes_file == "":
+            if repro not in self.spike_files.keys():
+                raise NotImplementedError("No spiketime file specified in DatParser for repro {}! "
+                                          "Please specify the spiketimes_file argument or add it to the spike_files dictionary in the DatParser class.".format(repro))
+            spiketimes_file = self.spike_files[repro]
+
+        for metadata, key, data in Dl.iload(spiketimes_file):
+            metadata_list.append(metadata)
+            spikes = np.array(data[:, 0]) / 1000  # timestamps are saved in ms -> conversion to seconds
+            spiketimes.append(spikes)
+
+        self.spiketimes[repro] = spiketimes
+        self.metadata[repro] = metadata_list
+        return spiketimes, metadata_list
+
+    def get_traces(self, repro, before=0, after=0):
+        if repro in self.traces.keys():
+            return self.traces[repro]
         time_traces = []
         v1_traces = []
         eod_traces = []
@@ -304,7 +324,7 @@ class DatParser():
 
         nothing = True
 
-        for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro):
+        for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro, before=before, after=after):
             nothing = False
             time_traces.append(time)
             v1_traces.append(x[0])
@@ -318,14 +338,9 @@ class DatParser():
             warn_msg = "pyrelacs: iload_traces found nothing for the " + str(repro) + " repro!"
             warn(warn_msg)
 
+        self.traces[repro] = traces
         return traces
 
-    def __iget_traces__(self, repro):
-
-        for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro):
-            # time, v1, eod, local_eod, stimulus
-            yield time, x[0], x[1], x[2], x[3]
-
     def __read_fi_recording_times__(self):
 
         delays = []
@@ -386,13 +401,3 @@ class DatParser():
                                "with File:" + self.base_path)
         else:
             self.sampling_interval = sampling_intervals[0]
-
-    def __test_data_file_existence__(self):
-        if not exists(self.stimuli_file):
-            raise FileNotFoundError(self.stimuli_file + " file doesn't exist!")
-        if not exists(self.fi_file):
-            raise FileNotFoundError(self.fi_file + " file doesn't exist!")
-        if not exists(self.baseline_file):
-            raise FileNotFoundError(self.baseline_file + " file doesn't exist!")
-        # if not exists(self.sam_file):
-        #     raise RuntimeError(self.sam_file + " file doesn't exist!")
diff --git a/DataProvider.py b/DataProvider.py
index 083a1fe..8aa02d0 100644
--- a/DataProvider.py
+++ b/DataProvider.py
@@ -1,20 +1,138 @@
 from DatParser import DatParser
+import numpy as np
+from warnings import warn
+import matplotlib.pyplot as plt
+
 
 class DataProvider:
 
-    def __init__(self, data_path):
+    def __init__(self, data_path, repros_todo=(), base_detection_params=(2, 5000, 1000)):
         self.data_path = data_path
-        # self.cell = CellData(data_path)
         self.parser = DatParser(data_path)
+        self.repros = self.get_repros()
+
+        if len(repros_todo) != 0:
+            not_available = []
+            repros_todo = list(np.unique(repros_todo))
+            for repro in repros_todo:
+                if repro not in self.repros:
+                    not_available.append(repro)
+                    print("DataProvider: Given cell hasn't measured '{}'. This repro will be skipped!".format(repro))
+            for x in not_available:
+                repros_todo.remove(x)
+
+            self.repros_todo = repros_todo
+        else:
+            self.repros_todo = self.repros
+
+        self.base_detection_parameters = base_detection_params
 
         self.thresholds = {}
+        self.sorting = {}
+        self.recording_times = {}
+
     def get_repros(self):
-        pass
+        return self.parser.get_measured_repros()
+
+    def get_unsorted_spiketimes(self, repro, file=""):
+        return self.parser.get_spiketimes(repro, file)
+
+    def get_unsorted_traces(self, repro, before=0, after=0):
+        return self.parser.get_traces(repro, before, after)
+
+    def get_traces_with_spiketimes(self, repro):
+        if repro in self.sorting.keys():
+            traces = self.get_unsorted_traces(repro, self.recording_times[repro][0], self.recording_times[repro][1])
+            v1_traces = traces[1]
+            # get_spiketimes() returns (spiketimes, metadata); keep only the spike trains
+            spiketimes, _ = self.get_unsorted_spiketimes(repro)
+
+            sorted_spiketimes = [spiketimes[i] for i in self.sorting[repro]]
+            return v1_traces, sorted_spiketimes, self.recording_times[repro]
+
+        if repro == "FICurve":
+            recording_times = self.parser.get_recording_times()
+            before = abs(recording_times[0])
+            after = recording_times[3]
+            [time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces] = self.parser.get_traces(repro, before=before, after=after)
+            spiketimes, _ = self.get_unsorted_spiketimes(repro)
+
+        else:
+            before = 0
+            after = 0
+            [time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces] = self.get_unsorted_traces(repro)
+            spiketimes, _ = self.get_unsorted_spiketimes(repro)
+
+        if len(v1_traces) != len(spiketimes):
+            warn("get_traces_with_spiketimes(): Unequal number of traces and spiketimes for repro {}! "
+                 "Returning only according to the number of traces!".format(repro))
+
+        ash = calculate_distance_matrix_traces_spikes(v1_traces, spiketimes, self.parser.get_sampling_interval(), before)
+
+        sorted_spiketimes = []
+        sorting = []
+        for i in range(len(v1_traces)):
+            best = np.argmax(ash[i, :])
+            sorting.append(best)
+            sorted_spiketimes.append(spiketimes[best])
+        self.sorting[repro] = sorting
+        self.recording_times[repro] = (before, after)
+
+        return v1_traces, sorted_spiketimes, (before, after)
 
     def get_stim_values(self, repro):
-        pass
+        if repro == "BaselineActivity":
+            return ["None"]
+        elif repro == "FICurve":
+            # TODO other function that just provides the contrasts
+            return sorted([str(x[0]) for x in self.parser.get_fi_curve_contrasts()])
+
+        return ["repro not supported"]
 
     def get_trials(self, repro, stimulus_value):
         pass
+
+
+def calculate_distance_matrix_traces_spikes(traces, spiketimes, sampling_rate, before):
+    ash = np.zeros((len(traces), len(spiketimes)))
+    total = len(traces) * len(spiketimes)
+    count = 0
+    for i, trace in enumerate(traces):
+        for j, spikes in enumerate(spiketimes):
+            count += 1
+            if count % 50000 == 0:
+                print("{} / {}".format(count, total))
+            if len(spikes) <= 1:
+                ash[i, j] = -np.inf
+            else:
+                ash[i, j] = average_spike_height(spikes, trace, sampling_rate, before)
+
+    return ash
+
+
+def average_spike_height(spike_train: np.ndarray, v1: np.ndarray, sampling_rate, before):
+    # indices = np.array([(s + before) / sampling_rate for s in spike_train], dtype=int)
+    indices = (spike_train + before) / sampling_rate
+    indices = np.array(indices, dtype=int)
+    if len(indices) <= 1:
+        return -np.inf
+    # [v1[i] for i in indices if 0 <= i < len(v1)]
+
+    applicable_indices = indices[(indices >= 0) & (indices < len(v1))]
+    spike_values = v1[applicable_indices]
+    average_height = np.mean(spike_values)
+
+    return average_height
+
+
+# SLOW:
+# def average_spike_height(spike_train, v1, sampling_rate, before):
+#     indices = np.array([(s + before) / sampling_rate for s in spike_train], dtype=np.int)
+#     if len(indices) <= 1:
+#         return -np.infty
+#     v1 = np.array(v1)
+#     spike_values = [v1[i] for i in indices if 0 <= i < len(v1)]
+#     average_height = np.mean(spike_values)
+#
+#     return average_height
\ No newline at end of file
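Note on the matching heuristic above: `get_traces_with_spiketimes()` pairs each trace with the spike train whose spike-time indices hit the highest mean voltage on that trace (the "average spike height"). A tiny synthetic check of that behaviour (all data made up; only names from the diff are used):

```python
# Synthetic check: the spike train that actually belongs to the trace
# should get the highest average-spike-height score.
import numpy as np
from DataProvider import calculate_distance_matrix_traces_spikes

sampling_interval = 0.001  # seconds, stands in for parser.get_sampling_interval()
t = np.arange(0, 1, sampling_interval)
trace = np.zeros_like(t)
true_spikes = np.array([0.1, 0.4, 0.7])
trace[(true_spikes / sampling_interval).astype(int)] = 1.0  # spikes as 1 mV bumps
wrong_spikes = np.array([0.2, 0.5, 0.8])

ash = calculate_distance_matrix_traces_spikes([trace], [true_spikes, wrong_spikes],
                                              sampling_interval, before=0)
assert np.argmax(ash[0, :]) == 0  # the matching train wins
```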
diff --git a/SpikeRedetectGui.py b/SpikeRedetectGui.py
index cc4c021..fa65eb6 100644
--- a/SpikeRedetectGui.py
+++ b/SpikeRedetectGui.py
@@ -1,18 +1,19 @@
 from PyQt5.QtWidgets import QWidget, QPushButton, QSizePolicy, QLineEdit,\
     QMessageBox, QVBoxLayout, QHBoxLayout, QGridLayout, QLabel, QFrame,\
-    QDoubleSpinBox, QComboBox
+    QDoubleSpinBox, QComboBox, QSpinBox
 from PyQt5.QtCore import pyqtSlot
 
 import numpy as np
 
 from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
 from matplotlib.figure import Figure
+from DataProvider import DataProvider
 
 
 class SpikeRedetectGui(QWidget):
 
-    def __init__(self, data_provider):
+    def __init__(self, data_provider: DataProvider):
         super().__init__()
         self.data_provider = data_provider
         self.title = 'Spike Redetection'
@@ -29,36 +30,80 @@ class SpikeRedetectGui(QWidget):
 
         # Middle:
         middle = QHBoxLayout()
 
-        # Canvas for matplotlib figure
-        m = PlotCanvas(self, width=5, height=4)
+        # Canvas Area for matplotlib figure
+
+        plot_area = QFrame()
+        plot_area_layout = QVBoxLayout()
+        m = PlotCanvas(self)
         m.move(0, 0)
-        middle.addWidget(m)
+        plot_area_layout.addWidget(m)
+
+        # plot area buttons
+        plot_area_buttons = QFrame()
+        plot_area_buttons_layout = QHBoxLayout()
+        plot_area_buttons.setLayout(plot_area_buttons_layout)
+        plot_area_layout.addWidget(plot_area_buttons)
+
+        button = QPushButton('Button1', self)
+        button.setToolTip('A nice button!')
+        button.clicked.connect(lambda: threshold_spinbox.setValue(1))
+        button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
+        plot_area_buttons_layout.addWidget(button)
+
+        button = QPushButton('Button2', self)
+        button.setToolTip('Another nice button!')
+        button.clicked.connect(lambda: threshold_spinbox.setValue(2))
+        button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
+        plot_area_buttons_layout.addWidget(button)
+
+        button = QPushButton('Button3', self)
+        button.setToolTip('Even more nice buttons!')
+        button.clicked.connect(lambda: threshold_spinbox.setValue(3))
+        button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
+        plot_area_buttons_layout.addWidget(button)
+
+        button = QPushButton('Button4', self)
+        button.setToolTip('Even more nice buttons!')
+        button.clicked.connect(lambda: threshold_spinbox.setValue(4))
+        button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
+        plot_area_buttons_layout.addWidget(button)
+
+        plot_area.setLayout(plot_area_layout)
+
+        middle.addWidget(plot_area)
 
         middle.addWidget(QVLine())
 
         # Side (options) panel
         panel = QFrame()
+        panel.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum)
+        panel.setMaximumWidth(200)
         panel_layout = QVBoxLayout()
 
-        button = QPushButton('Button!', self)
-        button.setToolTip('A nice button!')
-        button.clicked.connect(lambda: threshold_spinbox.setValue(1))
-        button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
-        panel_layout.addWidget(button)
+
+
+        stim_val_label = QLabel("Stimulus value:")
+        self.stim_val_box = QComboBox()
 
         repro_label = QLabel("Repro:")
         panel_layout.addWidget(repro_label)
-
         self.repro_box = QComboBox()
-        self.repro_box.addItem("placeholder repro")
-        panel_layout.addWidget(self.repro_box)
+        self.repro_box.currentTextChanged.connect(self.repro_change)
+        for repro in self.data_provider.get_repros():
+            self.repro_box.addItem(repro)
 
-        stim_val_label = QLabel("Stimulus value:")
+        panel_layout.addWidget(self.repro_box)
         panel_layout.addWidget(stim_val_label)
-        self.stim_val_box = QComboBox()
-        self.stim_val_box.addItem("placeholder stim_value")
         panel_layout.addWidget(self.stim_val_box)
 
-        filler = QFill(maxh=20)
+        trial_label = QLabel("Trial:")
+        panel_layout.addWidget(trial_label)
+        threshold_spinbox = QSpinBox(self)
+        threshold_spinbox.setValue(1)
+        threshold_spinbox.setSingleStep(1)
+        threshold_spinbox.valueChanged.connect(lambda value: None)  # placeholder slot; trial switching is not implemented yet
+        panel_layout.addWidget(threshold_spinbox)
+
+        filler = QFill(minh=200)
         panel_layout.addWidget(filler)
 
         self.status_label = QLabel("Done x/15 Stimulus Values")
@@ -78,18 +123,25 @@ class SpikeRedetectGui(QWidget):
         button.setToolTip('Accept the threshold for current stimulus value')
         panel_layout.addWidget(button)
 
-
         panel.setLayout(panel_layout)
         middle.addWidget(panel)
 
         self.setLayout(middle)
         self.show()
 
+    @pyqtSlot()
+    def repro_change(self):
+        repro = self.repro_box.currentText()
+        self.stim_val_box.clear()
+
+        for val in self.data_provider.get_stim_values(repro):
+            self.stim_val_box.addItem(str(val))
+
 
 class PlotCanvas(FigureCanvas):
 
-    def __init__(self, parent=None, width=5, height=4, dpi=100):
-        fig = Figure(figsize=(width, height), dpi=dpi)
+    def __init__(self, parent=None, dpi=100):
+        fig = Figure(dpi=dpi)
         self.axes = fig.add_subplot(111)
 
         FigureCanvas.__init__(self, fig)
@@ -107,7 +159,7 @@ class PlotCanvas(FigureCanvas):
         data = np.sin(x*np.pi*2*mean)
         ax = self.axes
         ax.clear()
-        ax.plot(data, 'r-')
+        ax.plot(x, data, 'r-')
         ax.set_title('Sinus Example')
         self.draw()
@@ -127,7 +179,8 @@ class QVLine(QFrame):
 
 
 class QFill(QFrame):
-    def __init__(self, maxh=int(2**24)-1, maxw=int(2**24)-1):
+    def __init__(self, maxw=int(2**24)-1, maxh=int(2**24)-1, minw=0, minh=0):
         super(QFill, self).__init__()
         self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
         self.setMaximumSize(maxw, maxh)
+        self.setMinimumSize(minw, minh)
diff --git a/main.py b/main.py
index 2d18526..ed03aa8 100644
--- a/main.py
+++ b/main.py
@@ -1,14 +1,14 @@
 import sys
 
-from PyQt5.QtWidgets import QApplication, QWidget, QPushButton
+from PyQt5.QtWidgets import QApplication
 
-from spike_redetection.DataProvider import DataProvider
-from spike_redetection.SpikeRedetectGui import SpikeRedetectGui
+from DataProvider import DataProvider
+from SpikeRedetectGui import SpikeRedetectGui
 
 
 def main():
     app = QApplication(sys.argv)
-    data_provider = DataProvider("../data/final_sam/2010-11-08-al-invivo-1")
+    data_provider = DataProvider("../neuronModel/data/final_sam/2010-11-08-al-invivo-1")
     ex = SpikeRedetectGui(data_provider)
     sys.exit(app.exec_())
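Note on the GUI changes above: the trial spinbox's `valueChanged.connect()` originally had no slot argument (a `TypeError` as soon as the widget is built), so the diff now connects a placeholder. A self-contained sketch of the intended QSpinBox signal/slot pattern (the handler is hypothetical, not part of the diff):

```python
# Minimal PyQt5 sketch of the spinbox wiring pattern used above.
import sys
from PyQt5.QtWidgets import QApplication, QWidget, QVBoxLayout, QLabel, QSpinBox

class TrialPicker(QWidget):
    def __init__(self):
        super().__init__()
        layout = QVBoxLayout(self)
        self.label = QLabel("Trial: 1")
        spinbox = QSpinBox(self)
        spinbox.setMinimum(1)
        # valueChanged(int) delivers the new value to the connected slot.
        spinbox.valueChanged.connect(self.trial_change)
        layout.addWidget(self.label)
        layout.addWidget(spinbox)

    def trial_change(self, value):
        self.label.setText("Trial: {}".format(value))

if __name__ == '__main__':
    app = QApplication(sys.argv)
    w = TrialPicker()
    w.show()
    sys.exit(app.exec_())
```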
diff --git a/redetector.py b/redetector.py
new file mode 100644
index 0000000..0062949
--- /dev/null
+++ b/redetector.py
@@ -0,0 +1,71 @@
+
+import numpy as np
+
+from thunderfish.eventdetection import detect_peaks
+
+
+def detect_spiketimes(time: np.ndarray, v1, threshold=2.0, min_length=5000, split_step=1000):
+    all_peak_indices = detect_spike_indices_automatic_split(v1, threshold, min_length, split_step)
+
+    return time[all_peak_indices]
+
+
+def detect_spike_indices_automatic_split(v1, threshold, min_length=5000, split_step=1000):
+    split_start = 0
+    step_size = split_step
+    break_threshold = 0.25
+    splits = []
+
+    if len(v1) < min_length:
+        splits = [(0, len(v1))]
+    else:
+        last_max = max(v1[0:min_length])
+        last_min = min(v1[0:min_length])
+        idx = min_length
+
+        while idx < len(v1):
+            if idx + step_size > len(v1):
+                splits.append((split_start, len(v1)))
+                break
+
+            max_similar = abs((max(v1[idx:idx+step_size]) - last_max) / last_max) < break_threshold
+            min_similar = abs((min(v1[idx:idx+step_size]) - last_min) / last_min) < break_threshold
+
+            if not max_similar or not min_similar:
+                # print("new split")
+                end_idx = np.argmin(v1[idx-20:idx+21]) - 20
+                splits.append((split_start, idx+end_idx))
+                split_start = idx+end_idx
+                last_max = max(v1[split_start:split_start + min_length])
+                last_min = min(v1[split_start:split_start + min_length])
+                idx = split_start + min_length
+                continue
+            else:
+                pass
+                # print("elongated!")
+
+            idx += step_size
+
+    # guard against an empty split list (loop can exit without appending)
+    if len(splits) == 0 or splits[-1][1] != len(v1):
+        splits.append((split_start, len(v1)))
+
+    # plt.plot(v1)
+    # for s in splits:
+    #     plt.plot(s, (max(v1[s[0]:s[1]]), max(v1[s[0]:s[1]])))
+
+    all_peaks = []
+    for s in splits:
+        first_index = s[0]
+        last_index = s[1]
+        std = np.std(v1[first_index:last_index])
+        peaks, _ = detect_peaks(v1[first_index:last_index], std * threshold)
+        peaks = peaks + first_index
+        # plt.plot(peaks, [np.max(v1[first_index:last_index]) for _ in peaks], 'o')
+        all_peaks.extend(peaks)
+
+    # plt.show()
+    # plt.close()
+    # all_peaks = np.array(all_peaks)
+
+    return all_peaks
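Note on redetector.py above: `detect_spiketimes()` splits the trace wherever the local max/min envelope drifts by more than 25 % and then runs thunderfish's `detect_peaks()` on each segment with a threshold of `threshold` standard deviations. A synthetic smoke test (requires thunderfish installed; the trace and spike positions are made up):

```python
# Synthetic smoke test for detect_spiketimes (requires thunderfish).
import numpy as np
from redetector import detect_spiketimes

sampling_interval = 0.0001  # 10 kHz
time = np.arange(0, 2, sampling_interval)
v1 = np.random.normal(0, 0.05, len(time))  # baseline noise
true_spikes = np.arange(0.05, 2, 0.05)     # one spike every 50 ms
v1[(true_spikes / sampling_interval).astype(int)] += 1.0

detected = detect_spiketimes(time, v1, threshold=2.0)
print(len(true_spikes), len(detected))  # counts should roughly agree
```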
diff --git a/testing.py b/testing.py
new file mode 100644
index 0000000..eeb330b
--- /dev/null
+++ b/testing.py
@@ -0,0 +1,80 @@
+
+from DatParser import DatParser
+from DataProvider import DataProvider, average_spike_height
+
+import os
+import numpy as np
+import matplotlib.pyplot as plt
+
+from redetector import detect_spiketimes
+
+DATA_FOLDER = "../neuronModel/data/final/"
+
+failure_to_read_sam = ["2012-06-27-an-invivo-1", "2012-12-13-ag-invivo-1"]
+
+
+def main():
+    for cell in sorted(os.listdir(DATA_FOLDER)):
+        if cell in failure_to_read_sam:
+            continue
+        cell_folder = os.path.join(DATA_FOLDER, cell)
+        data_provider = DataProvider(cell_folder)
+
+        repros = test_getting_repros(data_provider)
+
+        print("\n", cell)
+        for repro in repros:
+            if repro not in data_provider.parser.spike_files.keys():
+                continue
+
+            print(repro)
+            traces, spiketimes, rec_times = data_provider.get_traces_with_spiketimes(repro)
+            sampling_interval = data_provider.parser.get_sampling_interval()
+            for i in range(len(traces)):
+                time = np.arange(len(traces[i])) * sampling_interval - rec_times[0]
+                plt.figure(figsize=(10, 5))
+                plt.plot(time, traces[i])
+                plt.eventplot(spiketimes[i], lineoffsets=max(traces[i]) + 1, colors="black")
+                redetect = detect_spiketimes(time, traces[i])
+                plt.eventplot(redetect, lineoffsets=max(traces[i]) + 2, colors="red")
+                # plt.savefig("figures/best_spikes_test/" + cell + "_" + repro + str(i) + ".png")
+                plt.show()
+                plt.close()
+
+
+def test_loading_spikes(data_provider: DataProvider, repro):
+    return data_provider.parser.get_spiketimes(repro)
+
+
+def test_loading_traces(data_provider, repro):
+    return data_provider.get_unsorted_traces(repro)
+
+
+def test_getting_repros(data_provider: DataProvider):
+    return data_provider.get_repros()
+
+
+# def calculate_distance_matrix_traces_spikes(traces, spiketimes, sampling_rate, before):
+#     ash = np.zeros((len(traces), len(spiketimes)))
+#
+#     for i, trace in enumerate(traces):
+#         for j, spikes in enumerate(spiketimes):
+#             if len(spikes) <= 1:
+#                 ash[i, j] = -np.infty
+#             else:
+#                 ash[i, j] = average_spike_height(spikes, trace, sampling_rate, before)
+#
+#     return ash
+
+
+# def average_spike_height(spike_train, v1, sampling_rate, before):
+#     indices = np.array([(s - before) / sampling_rate for s in spike_train], dtype=np.int)
+#     v1 = np.array(v1)
+#     spike_values = [v1[i] for i in indices if 0 <= i < len(v1)]
+#     average_height = np.mean(spike_values)
+#
+#     return average_height
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
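testing.py above is an interactive smoke test (it opens one matplotlib window per trace). A hedged sketch of how its non-plotting part could become an automated assertion, assuming pytest and the same DATA_FOLDER layout:

```python
# Hedged sketch (assumes pytest; cell folder layout as in testing.py).
import os
from DataProvider import DataProvider

DATA_FOLDER = "../neuronModel/data/final/"

def test_traces_match_spiketimes():
    cell_folder = os.path.join(DATA_FOLDER, sorted(os.listdir(DATA_FOLDER))[0])
    data_provider = DataProvider(cell_folder)
    for repro in data_provider.get_repros():
        if repro not in data_provider.parser.spike_files:
            continue
        traces, spiketimes, rec_times = data_provider.get_traces_with_spiketimes(repro)
        assert len(traces) == len(spiketimes)
```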