general working state before stim value grouping
This commit is contained in:
parent
b0dfa22ef4
commit
8b10dba43b
9
.gitignore
vendored
9
.gitignore
vendored
@ -1,10 +1,17 @@
|
||||
*.dat
|
||||
|
||||
/data/
|
||||
/temp/
|
||||
/figures/
|
||||
.idea/
|
||||
__pycache__/
|
||||
# Latex output files
|
||||
*.out
|
||||
*.aux
|
||||
*.log
|
||||
*.synctex.gz
|
||||
*.toc
|
||||
*.toc
|
||||
|
||||
# other
|
||||
*.csv
|
||||
*.txt
|
@ -6,7 +6,7 @@ from warnings import warn
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
class DataProvider:
|
||||
class Controller:
|
||||
|
||||
def __init__(self, data_path, repros_todo=(), base_detection_params=(2, 5000, 1000)):
|
||||
self.data_path = data_path
|
||||
@ -28,12 +28,31 @@ class DataProvider:
|
||||
self.repros_todo = self.repros
|
||||
|
||||
self.base_detection_parameters = base_detection_params
|
||||
self.thresholds = {}
|
||||
self.sampling_interval = self.parser.get_sampling_interval()
|
||||
|
||||
self.sorting = {}
|
||||
self.thresholds = {}
|
||||
self.recording_times = {}
|
||||
|
||||
def save_parameters(self):
|
||||
header = "trial_index,threshold,min_window,step_size\n"
|
||||
for repro in self.thresholds.keys():
|
||||
# TODO change '.' to self.data_path after testing:
|
||||
with open("." + "/{}_thresholds.csv".format(repro), "w") as threshold_file:
|
||||
threshold_file.write(header)
|
||||
for i in range(len(self.sorting[repro])):
|
||||
if len(self.thresholds[repro][i]) == 0:
|
||||
# missing threshold info
|
||||
threshold_file.write("{},{},{},{}\n".format(i, -1, -1, -1))
|
||||
else:
|
||||
# format and write given thresholding parameters
|
||||
thresh = self.thresholds[repro][i][0]
|
||||
min_window = self.thresholds[repro][i][1]
|
||||
step = self.thresholds[repro][i][2]
|
||||
threshold_file.write("{},{},{},{}\n".format(i, thresh, min_window, step))
|
||||
print("Thresholds saved!")
|
||||
# TODO save redetected spikes:
|
||||
|
||||
def get_repros(self):
|
||||
return self.parser.get_measured_repros()
|
||||
|
||||
@ -94,6 +113,12 @@ class DataProvider:
|
||||
def get_trials(self, repro, stimulus_value):
|
||||
pass
|
||||
|
||||
def set_redetection_params(self, repro, trial_idx, params):
|
||||
if repro not in self.thresholds.keys():
|
||||
self.thresholds[repro] = [(),]*len(self.sorting[repro])
|
||||
|
||||
self.thresholds[repro][trial_idx] = params
|
||||
|
||||
|
||||
def calculate_distance_matrix_traces_spikes(traces, spiketimes, sampling_rate, before):
|
||||
ash = np.zeros((len(traces), len(spiketimes)))
|
@ -131,7 +131,7 @@ class DatParser:
|
||||
def get_baseline_spiketimes(self):
|
||||
# TODO change: reading from file -> detect from v1 trace
|
||||
spiketimes = []
|
||||
warn("Spiketimes don't fit time-wise to the baseline traces. Causes different vector strength angle per recording.")
|
||||
# warn("Spiketimes don't fit time-wise to the baseline traces. Causes different vector strength angle per recording.")
|
||||
|
||||
for metadata, key, data in Dl.iload(self.baseline_file):
|
||||
spikes = np.array(data[:, 0]) / 1000 # timestamps are saved in ms -> conversion to seconds
|
||||
@ -296,7 +296,7 @@ class DatParser:
|
||||
|
||||
spiketimes = []
|
||||
metadata_list = []
|
||||
warn("get_spiketimes():Spiketimes aren't sorted the same way as traces!")
|
||||
# warn("get_spiketimes():Spiketimes aren't sorted the same way as traces!")
|
||||
|
||||
if spiketimes_file == "":
|
||||
if repro not in self.spike_files.keys():
|
||||
|
@ -1,30 +1,29 @@
|
||||
|
||||
from PyQt5.QtWidgets import QWidget, QPushButton, QSizePolicy, QLineEdit,\
|
||||
QMessageBox, QVBoxLayout, QHBoxLayout, QGridLayout, QLabel, QFrame,\
|
||||
QDoubleSpinBox, QComboBox, QSpinBox
|
||||
QDoubleSpinBox, QComboBox, QSpinBox, QCheckBox
|
||||
from PyQt5.QtCore import pyqtSlot
|
||||
|
||||
import numpy as np
|
||||
|
||||
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
|
||||
from matplotlib.figure import Figure
|
||||
from DataProvider import DataProvider
|
||||
from Controller import Controller
|
||||
from redetector import detect_spiketimes
|
||||
|
||||
|
||||
class SpikeRedetectGui(QWidget):
|
||||
|
||||
def __init__(self, data_provider: DataProvider):
|
||||
def __init__(self, controller: Controller):
|
||||
super().__init__()
|
||||
self.data_provider = data_provider
|
||||
self.controller = controller
|
||||
self.title = 'Spike Redetection'
|
||||
self.left = 10
|
||||
self.top = 10
|
||||
self.width = 640
|
||||
self.height = 400
|
||||
self.width = 1500
|
||||
self.height = 800
|
||||
|
||||
self.trial_idx = 0
|
||||
|
||||
self.initUI()
|
||||
|
||||
def initUI(self):
|
||||
@ -35,7 +34,6 @@ class SpikeRedetectGui(QWidget):
|
||||
middle = QHBoxLayout()
|
||||
|
||||
# Canvas Area for matplotlib figure
|
||||
|
||||
plot_area = QFrame()
|
||||
plot_area_layout = QVBoxLayout()
|
||||
self.canvas = PlotCanvas(self)
|
||||
@ -48,26 +46,26 @@ class SpikeRedetectGui(QWidget):
|
||||
plot_area_buttons.setLayout(plot_area_buttons_layout)
|
||||
plot_area_layout.addWidget(plot_area_buttons)
|
||||
|
||||
button = QPushButton('Button1', self)
|
||||
button = QPushButton('previous stimulus value', self)
|
||||
button.setToolTip('A nice button!')
|
||||
button.clicked.connect(lambda: self.threshold_spinbox.setValue(1))
|
||||
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
|
||||
plot_area_buttons_layout.addWidget(button)
|
||||
|
||||
button = QPushButton('Button2', self)
|
||||
button = QPushButton('previous trial', self)
|
||||
button.setToolTip('Another nice button!')
|
||||
button.clicked.connect(lambda: self.threshold_spinbox.setValue(2))
|
||||
button.clicked.connect(lambda: self.trial_change(self.trial_idx - 1))
|
||||
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
|
||||
plot_area_buttons_layout.addWidget(button)
|
||||
|
||||
button = QPushButton('Button3', self)
|
||||
button = QPushButton('next trial', self)
|
||||
button.setToolTip('Even more nice buttons!')
|
||||
button.clicked.connect(lambda: self.threshold_spinbox.setValue(3))
|
||||
button.clicked.connect(lambda: self.trial_change(self.trial_idx + 1))
|
||||
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
|
||||
plot_area_buttons_layout.addWidget(button)
|
||||
|
||||
button = QPushButton('Button4', self)
|
||||
button.setToolTip('Even more nice buttons!')
|
||||
button = QPushButton('next stimulus value', self)
|
||||
button.setToolTip('But this is not a nice button!')
|
||||
button.clicked.connect(lambda: self.threshold_spinbox.setValue(4))
|
||||
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
|
||||
plot_area_buttons_layout.addWidget(button)
|
||||
@ -85,40 +83,43 @@ class SpikeRedetectGui(QWidget):
|
||||
|
||||
|
||||
|
||||
stim_val_label = QLabel("Stimulus value:")
|
||||
self.stim_val_box = QComboBox()
|
||||
|
||||
repro_label = QLabel("Repro:")
|
||||
panel_layout.addWidget(repro_label)
|
||||
self.repro_box = QComboBox()
|
||||
|
||||
for repro in self.data_provider.get_repros():
|
||||
self.repro_box.addItem(repro)
|
||||
self.repro_box = QComboBox()
|
||||
self.repro_box.addItems(self.controller.get_repros())
|
||||
self.repro_box.currentTextChanged.connect(self.repro_change)
|
||||
|
||||
panel_layout.addWidget(self.repro_box)
|
||||
|
||||
stim_val_label = QLabel("Stimulus value:")
|
||||
panel_layout.addWidget(stim_val_label)
|
||||
|
||||
self.stim_val_box = QComboBox()
|
||||
panel_layout.addWidget(self.stim_val_box)
|
||||
|
||||
trial_label = QLabel("Trial:")
|
||||
panel_layout.addWidget(trial_label)
|
||||
trial_spinbox = QSpinBox(self)
|
||||
trial_spinbox.setValue(1)
|
||||
trial_spinbox.setSingleStep(1)
|
||||
trial_spinbox.valueChanged.connect(lambda: self.trial_change(trial_spinbox.value()))
|
||||
panel_layout.addWidget(trial_spinbox)
|
||||
self.grouping_checkbox = QCheckBox()
|
||||
self.grouping_checkbox.setText("Group by Stimulus Value:")
|
||||
panel_layout.addWidget(self.grouping_checkbox)
|
||||
|
||||
filler = QFill(minh=200)
|
||||
panel_layout.addWidget(filler)
|
||||
|
||||
self.status_label = QLabel("Done x/15 Stimulus Values")
|
||||
panel_layout.addWidget(self.status_label)
|
||||
trial_label = QLabel("Trial:")
|
||||
panel_layout.addWidget(trial_label)
|
||||
self.trial_spinbox = QSpinBox(self)
|
||||
self.trial_spinbox.setValue(self.trial_idx)
|
||||
self.trial_spinbox.setSingleStep(1)
|
||||
self.trial_spinbox.valueChanged.connect(lambda: self.trial_change(self.trial_spinbox.value()))
|
||||
panel_layout.addWidget(self.trial_spinbox)
|
||||
|
||||
# self.status_label = QLabel("Done x/15 Stimulus Values")
|
||||
# panel_layout.addWidget(self.status_label)
|
||||
filler = QFill()
|
||||
panel_layout.addWidget(filler)
|
||||
|
||||
threshold_label = QLabel("Threshold:")
|
||||
panel_layout.addWidget(threshold_label)
|
||||
self.threshold_spinbox = QDoubleSpinBox(self)
|
||||
self.threshold_spinbox.setValue(self.data_provider.base_detection_parameters[0])
|
||||
self.threshold_spinbox.setValue(self.controller.base_detection_parameters[0])
|
||||
self.threshold_spinbox.setSingleStep(0.5)
|
||||
self.threshold_spinbox.valueChanged.connect(lambda: self.redetection_changed())
|
||||
panel_layout.addWidget(self.threshold_spinbox)
|
||||
@ -127,7 +128,7 @@ class SpikeRedetectGui(QWidget):
|
||||
panel_layout.addWidget(window_label)
|
||||
self.window_spinbox = QSpinBox(self)
|
||||
self.window_spinbox.setMaximum(2**21)
|
||||
self.window_spinbox.setValue(self.data_provider.base_detection_parameters[1])
|
||||
self.window_spinbox.setValue(self.controller.base_detection_parameters[1])
|
||||
self.window_spinbox.setSingleStep(500)
|
||||
self.window_spinbox.valueChanged.connect(lambda: self.redetection_changed())
|
||||
panel_layout.addWidget(self.window_spinbox)
|
||||
@ -136,43 +137,77 @@ class SpikeRedetectGui(QWidget):
|
||||
panel_layout.addWidget(step_label)
|
||||
self.step_spinbox = QSpinBox(self)
|
||||
self.step_spinbox.setMaximum(2 ** 21)
|
||||
self.step_spinbox.setValue(self.data_provider.base_detection_parameters[2])
|
||||
self.step_spinbox.setValue(self.controller.base_detection_parameters[2])
|
||||
self.step_spinbox.setSingleStep(200)
|
||||
self.step_spinbox.valueChanged.connect(lambda: self.redetection_changed())
|
||||
panel_layout.addWidget(self.step_spinbox)
|
||||
|
||||
button = QPushButton('Accept!', self)
|
||||
button.setToolTip('Accept the threshold for current stimulus value')
|
||||
button.clicked.connect(self.accept_redetection)
|
||||
panel_layout.addWidget(button)
|
||||
|
||||
button = QPushButton('Save thresholds!', self)
|
||||
button.setToolTip('Save accepted thresholds!')
|
||||
button.clicked.connect(self.save_threshold_parameters)
|
||||
panel_layout.addWidget(button)
|
||||
|
||||
panel.setLayout(panel_layout)
|
||||
middle.addWidget(panel)
|
||||
|
||||
self.setLayout(middle)
|
||||
|
||||
# init the first repro:
|
||||
self.repro_change()
|
||||
|
||||
self.show()
|
||||
|
||||
@pyqtSlot()
|
||||
def redetection_changed(self):
|
||||
redetection = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value())
|
||||
self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.data_provider, redetection)
|
||||
self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.controller, redetection)
|
||||
|
||||
@pyqtSlot()
|
||||
def trial_change(self, new_trial_idx):
|
||||
# TODO test if in range of trials!
|
||||
if new_trial_idx == self.trial_idx:
|
||||
return
|
||||
|
||||
traces, spiketimes, recording_times = self.controller.get_traces_with_spiketimes(self.repro_box.currentText())
|
||||
if not 0 <= new_trial_idx < len(spiketimes):
|
||||
print("trial_change():Out of range trial index!")
|
||||
return
|
||||
self.trial_idx = new_trial_idx
|
||||
|
||||
if new_trial_idx != self.trial_spinbox.value():
|
||||
self.trial_spinbox.setValue(new_trial_idx)
|
||||
|
||||
redetection = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value())
|
||||
self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.data_provider, redetection)
|
||||
self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.controller, redetection)
|
||||
|
||||
@pyqtSlot()
|
||||
def repro_change(self):
|
||||
repro = self.repro_box.currentText()
|
||||
|
||||
# reset trials and stim values
|
||||
self.trial_change(0)
|
||||
self.stim_val_box.clear()
|
||||
|
||||
for val in self.data_provider.get_stim_values(repro):
|
||||
for val in self.controller.get_stim_values(repro):
|
||||
self.stim_val_box.addItem(str(val))
|
||||
|
||||
redetection = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value())
|
||||
self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.data_provider, redetection)
|
||||
self.canvas.plot(self.trial_idx, self.repro_box.currentText(), self.controller, redetection)
|
||||
|
||||
@pyqtSlot()
|
||||
def accept_redetection(self):
|
||||
params = (self.threshold_spinbox.value(), self.window_spinbox.value(), self.step_spinbox.value())
|
||||
self.controller.set_redetection_params(self.repro_box.currentText(), self.trial_idx, params)
|
||||
|
||||
self.trial_change(self.trial_idx + 1)
|
||||
|
||||
@pyqtSlot()
|
||||
def save_threshold_parameters(self):
|
||||
self.controller.save_parameters()
|
||||
|
||||
|
||||
class PlotCanvas(FigureCanvas):
|
||||
@ -189,6 +224,7 @@ class PlotCanvas(FigureCanvas):
|
||||
QSizePolicy.Expanding)
|
||||
FigureCanvas.updateGeometry(self)
|
||||
|
||||
self.current_repro = None
|
||||
self.mouse_button_pressed = False
|
||||
self.mouse_button = "-1"
|
||||
|
||||
@ -249,7 +285,6 @@ class PlotCanvas(FigureCanvas):
|
||||
self.axes.set_ylim(new_ylimits)
|
||||
self.draw()
|
||||
|
||||
|
||||
def onclick(self, event):
|
||||
|
||||
if event.button in (1, 3):
|
||||
@ -272,7 +307,7 @@ class PlotCanvas(FigureCanvas):
|
||||
self.mouse_button_pressed = False
|
||||
|
||||
@pyqtSlot()
|
||||
def plot(self, trial_idx, repro, data_provider: DataProvider, redetection_vars):
|
||||
def plot(self, trial_idx, repro, data_provider: Controller, redetection_vars):
|
||||
traces, spiketimes, recording_times = data_provider.get_traces_with_spiketimes(repro)
|
||||
trace = traces[trial_idx]
|
||||
spiketimes = spiketimes[trial_idx]
|
||||
@ -289,8 +324,12 @@ class PlotCanvas(FigureCanvas):
|
||||
redetect = detect_spiketimes(time, trace, redetection_vars[0], redetection_vars[1], redetection_vars[2])
|
||||
ax.eventplot(redetect, lineoffsets=max(trace) + 2, colors="red")
|
||||
ax.set_title('Trial XYZ')
|
||||
ax.set_xlim(xlim)
|
||||
ax.set_ylim(ylim)
|
||||
|
||||
if self.current_repro == repro:
|
||||
ax.set_xlim(xlim)
|
||||
ax.set_ylim(ylim)
|
||||
else:
|
||||
self.current_repro = repro
|
||||
self.draw()
|
||||
|
||||
|
||||
|
18
main.py
18
main.py
@ -1,17 +1,29 @@
|
||||
|
||||
import sys
|
||||
from PyQt5.QtWidgets import QApplication
|
||||
import os
|
||||
|
||||
from DataProvider import DataProvider
|
||||
from PyQt5.QtWidgets import QApplication, QFileDialog
|
||||
|
||||
from Controller import Controller
|
||||
from SpikeRedetectGui import SpikeRedetectGui
|
||||
|
||||
test_file = "../neuronModel/data/final_sam/2010-11-08-al-invivo-1"
|
||||
|
||||
|
||||
def main():
|
||||
app = QApplication(sys.argv)
|
||||
data_provider = DataProvider("../neuronModel/data/final_sam/2010-11-08-al-invivo-1")
|
||||
data_path = QFileDialog.getExistingDirectory(caption='Select a directory')
|
||||
if os.path.isdir(data_path):
|
||||
data_provider = Controller(data_path)
|
||||
else:
|
||||
print("Cell loading didn't work!\n Cell folder: {}".format(data_path))
|
||||
return
|
||||
#data_provider = Controller(test_file)
|
||||
ex = SpikeRedetectGui(data_provider)
|
||||
sys.exit(app.exec_())
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
@ -1,6 +1,6 @@
|
||||
|
||||
from DatParser import DatParser
|
||||
from DataProvider import DataProvider, average_spike_height
|
||||
from Controller import Controller, average_spike_height
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
@ -18,7 +18,7 @@ def main():
|
||||
if cell in failure_to_read_sam:
|
||||
continue
|
||||
cell_folder = os.path.join(DATA_FOLDER, cell)
|
||||
data_provider = DataProvider(cell_folder)
|
||||
data_provider = Controller(cell_folder)
|
||||
|
||||
repros = test_getting_repros(data_provider)
|
||||
|
||||
@ -42,7 +42,7 @@ def main():
|
||||
plt.close()
|
||||
|
||||
|
||||
def test_loading_spikes(data_provider: DataProvider, repro):
|
||||
def test_loading_spikes(data_provider: Controller, repro):
|
||||
return data_provider.parser.get_spiketimes(repro)
|
||||
|
||||
|
||||
@ -50,7 +50,7 @@ def test_loading_traces(data_provider, repro):
|
||||
return data_provider.get_traces(repro)
|
||||
|
||||
|
||||
def test_getting_repros(data_provider: DataProvider):
|
||||
def test_getting_repros(data_provider: Controller):
|
||||
return data_provider.get_repros()
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user