general working state before stim value grouping

This commit is contained in:
alexanderott 2021-03-07 12:50:30 +01:00
parent b0dfa22ef4
commit 8b10dba43b
6 changed files with 139 additions and 56 deletions

7
.gitignore vendored
View File

@ -1,10 +1,17 @@
*.dat
/data/
/temp/
/figures/
.idea/
__pycache__/
# Latex output files
*.out
*.aux
*.log
*.synctex.gz
*.toc
# other
*.csv
*.txt

View File

@ -6,7 +6,7 @@ from warnings import warn
import matplotlib.pyplot as plt
class DataProvider:
class Controller:
def __init__(self, data_path, repros_todo=(), base_detection_params=(2, 5000, 1000)):
self.data_path = data_path
@ -28,12 +28,31 @@ class DataProvider:
self.repros_todo = self.repros
self.base_detection_parameters = base_detection_params
self.thresholds = {}
self.sampling_interval = self.parser.get_sampling_interval()
self.sorting = {}
self.thresholds = {}
self.recording_times = {}
def save_parameters(self):
header = "trial_index,threshold,min_window,step_size\n"
for repro in self.thresholds.keys():
# TODO change '.' to self.data_path after testing:
with open("." + "/{}_thresholds.csv".format(repro), "w") as threshold_file:
threshold_file.write(header)
for i in range(len(self.sorting[repro])):
if len(self.thresholds[repro][i]) == 0:
# missing threshold info
threshold_file.write("{},{},{},{}\n".format(i, -1, -1, -1))
else:
# format and write given thresholding parameters
thresh = self.thresholds[repro][i][0]
min_window = self.thresholds[repro][i][1]
step = self.thresholds[repro][i][2]
threshold_file.write("{},{},{},{}\n".format(i, thresh, min_window, step))
print("Thresholds saved!")
# TODO save redetected spikes:
def get_repros(self):
return self.parser.get_measured_repros()
@ -94,6 +113,12 @@ class DataProvider:
def get_trials(self, repro, stimulus_value):
pass
def set_redetection_params(self, repro, trial_idx, params):
if repro not in self.thresholds.keys():
self.thresholds[repro] = [(),]*len(self.sorting[repro])
self.thresholds[repro][trial_idx] = params
def calculate_distance_matrix_traces_spikes(traces, spiketimes, sampling_rate, before):
ash = np.zeros((len(traces), len(spiketimes)))

View File

@ -131,7 +131,7 @@ class DatParser:
def get_baseline_spiketimes(self):
# TODO change: reading from file -> detect from v1 trace
spiketimes = []
warn("Spiketimes don't fit time-wise to the baseline traces. Causes different vector strength angle per recording.")
# warn("Spiketimes don't fit time-wise to the baseline traces. Causes different vector strength angle per recording.")
for metadata, key, data in Dl.iload(self.baseline_file):
spikes = np.array(data[:, 0]) / 1000 # timestamps are saved in ms -> conversion to seconds
@ -296,7 +296,7 @@ class DatParser:
spiketimes = []
metadata_list = []
warn("get_spiketimes():Spiketimes aren't sorted the same way as traces!")
# warn("get_spiketimes():Spiketimes aren't sorted the same way as traces!")
if spiketimes_file == "":
if repro not in self.spike_files.keys():

View File

@ -1,30 +1,29 @@
from PyQt5.QtWidgets import QWidget, QPushButton, QSizePolicy, QLineEdit,\
QMessageBox, QVBoxLayout, QHBoxLayout, QGridLayout, QLabel, QFrame,\
QDoubleSpinBox, QComboBox, QSpinBox
QDoubleSpinBox, QComboBox, QSpinBox, QCheckBox
from PyQt5.QtCore import pyqtSlot
import numpy as np
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from DataProvider import DataProvider
from Controller import Controller
from redetector import detect_spiketimes
class SpikeRedetectGui(QWidget):
def __init__(self, data_provider: DataProvider):
def __init__(self, controller: Controller):
super().__init__()
self.data_provider = data_provider
self.controller = controller
self.title = 'Spike Redetection'
self.left = 10
self.top = 10
self.width = 640
self.height = 400
self.width = 1500
self.height = 800
self.trial_idx = 0
self.initUI()
def initUI(self):
@ -35,7 +34,6 @@ class SpikeRedetectGui(QWidget):
middle = QHBoxLayout()
# Canvas Area for matplotlib figure
plot_area = QFrame()
plot_area_layout = QVBoxLayout()
self.canvas = PlotCanvas(self)
@ -48,26 +46,26 @@ class SpikeRedetectGui(QWidget):
plot_area_buttons.setLayout(plot_area_buttons_layout)
plot_area_layout.addWidget(plot_area_buttons)
button = QPushButton('Button1', self)
button = QPushButton('previous stimulus value', self)
button.setToolTip('A nice button!')
button.clicked.connect(lambda: self.threshold_spinbox.setValue(1))
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
plot_area_buttons_layout.addWidget(button)
button = QPushButton('Button2', self)
button = QPushButton('previous trial', self)
button.setToolTip('Another nice button!')
button.clicked.connect(lambda: self.threshold_spinbox.setValue(2))
button.clicked.connect(lambda: self.trial_change(self.trial_idx - 1))
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
plot_area_buttons_layout.addWidget(button)
button = QPushButton('Button3', self)
button = QPushButton('next trial', self)
button.setToolTip('Even more nice buttons!')
button.clicked.connect(lambda: self.threshold_spinbox.setValue(3))
button.clicked.connect(lambda: self.trial_change(self.trial_idx + 1))
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
plot_area_buttons_layout.addWidget(button)
button = QPushButton('Button4', self)
button.setToolTip('Even more nice buttons!')
button = QPushButton('next stimulus value', self)
button.setToolTip('But this is not a nice button!')
button.clicked.connect(lambda: self.threshold_spinbox.setValue(4))
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
plot_area_buttons_layout.addWidget(button)
@ -85,40 +83,43 @@ class SpikeRedetectGui(QWidget):
stim_val_label = QLabel("Stimulus value:")
self.stim_val_box = QComboBox()
repro_label = QLabel("Repro:")
panel_layout.addWidget(repro_label)
self.repro_box = QComboBox()
for repro in self.data_provider.get_repros():
self.repro_box.addItem(repro)
self.repro_box = QComboBox()
self.repro_box.addItems(self.controller.get_repros())
self.repro_box.currentTextChanged.connect(self.repro_change)
panel_layout.addWidget(self.repro_box)
stim_val_label = QLabel("Stimulus value:")
panel_layout.addWidget(stim_val_label)
self.stim_val_box = QComboBox()
panel_layout.addWidget(self.stim_val_box)
trial_label = QLabel("Trial:")
panel_layout.addWidget(trial_label)
trial_spinbox = QSpinBox(self)
trial_spinbox.setValue(1)
trial_spinbox.setSingleStep(1)
trial_spinbox.valueChanged.connect(lambda: self.trial_change(trial_spinbox.value()))
panel_layout.addWidget(trial_spinbox)
self.grouping_checkbox = QCheckBox()
self.grouping_checkbox.setText("Group by Stimulus Value:")
panel_layout.addWidget(self.grouping_checkbox)
filler = QFill(minh=200)
panel_layout.addWidget(filler)
self.status_label = QLabel("Done x/15 Stimulus Values")
panel_layout.addWidget(self.status_label)
trial_label = QLabel("Trial:")
panel_layout.addWidget(trial_label)
self.trial_spinbox = QSpinBox(self)
self.trial_spinbox.setValue(self.trial_idx)
self.trial_spinbox.setSingleStep(1)
self.trial_spinbox.valueChanged.connect(lambda: self.trial_change(self.trial_spinbox.value()))
panel_layout.addWidget(self.trial_spinbox)
# self.status_label = QLabel("Done x/15 Stimulus Values")
# panel_layout.addWidget(self.status_label)
filler = QFill()
panel_layout.addWidget(filler)
threshold_label = QLabel("Threshold:")
panel_layout.addWidget(threshold_label)
self.threshold_spinbox = QDoubleSpinBox(self)
self.threshold_spinbox.setValue(self.data_provider.base_detection_parameters[0])
self.threshold_spinbox.setValue(self.controller.base_detection_parameters[0])
self.threshold_spinbox.setSingleStep(0.5)
self.threshold_spinbox.valueChanged.connect(lambda: self.redetection_changed())
panel_layout.addWidget(self.threshold_spinbox)
@ -127,7 +128,7 @@ class SpikeRedetectGui(QWidget):
panel_layout.addWidget(window_label)
self.window_spinbox = QSpinBox(self)
self.window_spinbox.setMaximum(2**21)
self.window_spinbox.setValue(self.data_provider.base_detection_parameters[1])
self.window_spinbox.setValue(self.controller.base_detection_parameters[1])
self.window_spinbox.setSingleStep(500)
self.window_spinbox.valueChanged.connect(lambda: self.redetection_changed())
panel_layout.addWidget(self.window_spinbox)
@ -136,43 +137,77 @@ class SpikeRedetectGui(QWidget):
panel_layout.addWidget(step_label)
self.step_spinbox = QSpinBox(self)
self.step_spinbox.setMaximum(2 ** 21)
self.step_spinbox.setValue(self.data_provider.base_detection_parameters[2])
self.step_spinbox.setValue(self.controller.base_detection_parameters[2])
self.step_spinbox.setSingleStep(200)
self.step_spinbox.valueChanged.connect(lambda: self.redetection_changed())
panel_layout.addWidget(self.step_spinbox)
button = QPushButton('Accept!', self)
button.setToolTip('Accept the threshold for current stimulus value')
button.clicked.connect(self.accept_redetection)
panel_layout.addWidget(button)
button = QPushButton('Save thresholds!', self)
button.setToolTip('Save accepted thresholds!')
button.clicked.connect(self.save_threshold_parameters)
panel_layout.addWidget(button)
panel.setLayout(panel_layout)
middle.addWidget(panel)
self.setLayout(middle)
# init the first repro:
self.repro_change()
self.show()
@pyqtSlot()
def redetection_changed(self):
redetection = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value())
self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.data_provider, redetection)
self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.controller, redetection)
@pyqtSlot()
def trial_change(self, new_trial_idx):
# TODO test if in range of trials!
if new_trial_idx == self.trial_idx:
return
traces, spiketimes, recording_times = self.controller.get_traces_with_spiketimes(self.repro_box.currentText())
if not 0 <= new_trial_idx < len(spiketimes):
print("trial_change():Out of range trial index!")
return
self.trial_idx = new_trial_idx
if new_trial_idx != self.trial_spinbox.value():
self.trial_spinbox.setValue(new_trial_idx)
redetection = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value())
self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.data_provider, redetection)
self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.controller, redetection)
@pyqtSlot()
def repro_change(self):
repro = self.repro_box.currentText()
# reset trials and stim values
self.trial_change(0)
self.stim_val_box.clear()
for val in self.data_provider.get_stim_values(repro):
for val in self.controller.get_stim_values(repro):
self.stim_val_box.addItem(str(val))
redetection = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value())
self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.data_provider, redetection)
self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.controller, redetection)
@pyqtSlot()
def accept_redetection(self):
params = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value())
self.controller.set_redetection_params(self.repro_box.currentText(), self.trial_idx, params)
self.trial_change(self.trial_idx + 1)
@pyqtSlot()
def save_threshold_parameters(self):
self.controller.save_parameters()
class PlotCanvas(FigureCanvas):
@ -189,6 +224,7 @@ class PlotCanvas(FigureCanvas):
QSizePolicy.Expanding)
FigureCanvas.updateGeometry(self)
self.current_repro = None
self.mouse_button_pressed = False
self.mouse_button = "-1"
@ -249,7 +285,6 @@ class PlotCanvas(FigureCanvas):
self.axes.set_ylim(new_ylimits)
self.draw()
def onclick(self, event):
if event.button in (1, 3):
@ -272,7 +307,7 @@ class PlotCanvas(FigureCanvas):
self.mouse_button_pressed = False
@pyqtSlot()
def plot(self, trial_idx, repro, data_provider: DataProvider, redetection_vars):
def plot(self, trial_idx, repro, data_provider: Controller, redetection_vars):
traces, spiketimes, recording_times = data_provider.get_traces_with_spiketimes(repro)
trace = traces[trial_idx]
spiketimes = spiketimes[trial_idx]
@ -289,8 +324,12 @@ class PlotCanvas(FigureCanvas):
redetect = detect_spiketimes(time, trace, redetection_vars[0], redetection_vars[1], redetection_vars[2])
ax.eventplot(redetect, lineoffsets=max(trace) + 2, colors="red")
ax.set_title('Trial XYZ')
ax.set_xlim(xlim)
ax.set_ylim(ylim)
if self.current_repro == repro:
ax.set_xlim(xlim)
ax.set_ylim(ylim)
else:
self.current_repro = repro
self.draw()

18
main.py
View File

@ -1,17 +1,29 @@
import sys
from PyQt5.QtWidgets import QApplication
import os
from DataProvider import DataProvider
from PyQt5.QtWidgets import QApplication, QFileDialog
from Controller import Controller
from SpikeRedetectGui import SpikeRedetectGui
test_file = "../neuronModel/data/final_sam/2010-11-08-al-invivo-1"
def main():
app = QApplication(sys.argv)
data_provider = DataProvider("../neuronModel/data/final_sam/2010-11-08-al-invivo-1")
data_path = QFileDialog.getExistingDirectory(caption='Select a directory')
if os.path.isdir(data_path):
data_provider = Controller(data_path)
else:
print("Cell loading didn't work!\n Cell folder: {}".format(data_path))
return
#data_provider = Controller(test_file)
ex = SpikeRedetectGui(data_provider)
sys.exit(app.exec_())
if __name__ == '__main__':
main()

View File

@ -1,6 +1,6 @@
from DatParser import DatParser
from DataProvider import DataProvider, average_spike_height
from Controller import Controller, average_spike_height
import os
import numpy as np
@ -18,7 +18,7 @@ def main():
if cell in failure_to_read_sam:
continue
cell_folder = os.path.join(DATA_FOLDER, cell)
data_provider = DataProvider(cell_folder)
data_provider = Controller(cell_folder)
repros = test_getting_repros(data_provider)
@ -42,7 +42,7 @@ def main():
plt.close()
def test_loading_spikes(data_provider: DataProvider, repro):
def test_loading_spikes(data_provider: Controller, repro):
return data_provider.parser.get_spiketimes(repro)
@ -50,7 +50,7 @@ def test_loading_traces(data_provider, repro):
return data_provider.get_traces(repro)
def test_getting_repros(data_provider: DataProvider):
def test_getting_repros(data_provider: Controller):
return data_provider.get_repros()