remove redetector from neuron model and create its own project

This commit is contained in:
alexanderott 2021-02-13 15:24:01 +01:00
commit 4ba7f6a1b3
4 changed files with 568 additions and 0 deletions

398
DatParser.py Normal file
View File

@ -0,0 +1,398 @@
from os.path import isdir, exists
from warnings import warn
import numpy as np
import pyrelacs.DataLoader as Dl
class DatParser():
def __init__(self, dir_path):
self.base_path = dir_path
self.info_file = self.base_path + "/info.dat"
self.fi_file = self.base_path + "/fispikes1.dat"
self.baseline_file = self.base_path + "/basespikes1.dat"
self.sam_file = self.base_path + "/samallspikes1.dat"
self.stimuli_file = self.base_path + "/stimuli.dat"
self.__test_data_file_existence__()
self.fi_recording_times = []
self.sampling_interval = -1
def has_sam_recordings(self):
return exists(self.sam_file)
def get_measured_repros(self):
repros = []
for metadata, key, data in Dl.iload(self.stimuli_file):
repros.extend([d["repro"] for d in metadata if "repro" in d.keys()])
return sorted(np.unique(repros))
def get_baseline_length(self):
lengths = []
for metadata, key, data in Dl.iload(self.baseline_file):
if len(metadata) != 0:
lengths.append(float(metadata[0]["duration"][:-3]))
return lengths
def get_species(self):
species = ""
for metadata in Dl.load(self.info_file):
if "Species" in metadata.keys():
species = metadata["Species"]
elif "Subject" in metadata.keys():
if isinstance(metadata["Subject"], dict) and "Species" in metadata["Subject"].keys():
species = metadata["Subject"]["Species"]
return species
def get_gender(self):
gender = "not found"
for metadata in Dl.load(self.info_file):
if "Species" in metadata.keys():
gender = metadata["Gender"]
elif "Subject" in metadata.keys():
if isinstance(metadata["Subject"], dict) and "Gender" in metadata["Subject"].keys():
gender = metadata["Subject"]["Gender"]
return gender
def get_quality(self):
quality = ""
for metadata in Dl.load(self.info_file):
if "Recording quality" in metadata.keys():
quality = metadata["Recording quality"]
elif "Recording" in metadata.keys():
if isinstance(metadata["Recording"], dict) and "Recording quality" in metadata["Recording"].keys():
quality = metadata["Recording"]["Recording quality"]
return quality
def get_cell_type(self):
type = ""
for metadata in Dl.load(self.info_file):
if len(metadata.keys()) < 3:
return ""
if "CellType" in metadata.keys():
type = metadata["CellType"]
elif "Cell" in metadata.keys():
if isinstance(metadata["Cell"], dict) and "CellType" in metadata["Cell"].keys():
type = metadata["Cell"]["CellType"]
return type
def get_fish_size(self):
size = ""
for metadata in Dl.load(self.info_file):
if "Species" in metadata.keys():
size = metadata["Size"]
elif "Subject" in metadata.keys():
if isinstance(metadata["Subject"], dict) and "Species" in metadata["Subject"].keys():
size = metadata["Subject"]["Size"]
return size[:-2]
def get_fi_curve_contrasts(self):
"""
:return: list of tuples [(contrast, #_of_trials), ...]
"""
contrasts = []
contrast = [-1, float("nan")]
for metadata, key, data in Dl.iload(self.fi_file):
if len(metadata) != 0:
if contrast[0] != -1:
contrasts.append(contrast)
contrast = [-1, 1]
contrast[0] = float(metadata[-1]["intensity"][:-2])
else:
contrast[1] += 1
return np.array(contrasts)
def traces_available(self) -> bool:
return True
def frequencies_available(self) -> bool:
return False
def spiketimes_available(self) -> bool:
return True
def get_sampling_interval(self):
if self.sampling_interval == -1:
self.__read_sampling_interval__()
return self.sampling_interval
def get_recording_times(self):
if len(self.fi_recording_times) == 0:
self.__read_fi_recording_times__()
return self.fi_recording_times
def get_baseline_traces(self):
return self.__get_traces__("BaselineActivity")
def get_baseline_spiketimes(self):
# TODO change: reading from file -> detect from v1 trace
spiketimes = []
warn("Spiketimes don't fit time-wise to the baseline traces. Causes different vector strength angle per recording.")
for metadata, key, data in Dl.iload(self.baseline_file):
spikes = np.array(data[:, 0]) / 1000 # timestamps are saved in ms -> conversion to seconds
spiketimes.append(spikes)
return spiketimes
def get_fi_curve_traces(self):
return self.__get_traces__("FICurve")
def get_fi_frequency_traces(self):
raise NotImplementedError("Not possible in .dat data type.\n"
"Please check availability with the x_available functions.")
# TODO clean up/ rewrite
def get_fi_curve_spiketimes(self):
spiketimes = []
pre_intensities = []
pre_durations = []
intensities = []
trans_amplitudes = []
pre_duration = -1
index = -1
skip = False
trans_amplitude = float('nan')
for metadata, key, data in Dl.iload(self.fi_file):
if len(metadata) != 0:
metadata_index = 0
if '----- Control --------------------------------------------------------' in metadata[0].keys():
metadata_index = 1
pre_duration = float(metadata[0]["----- Pre-Intensities ------------------------------------------------"]["preduration"][:-2])
trans_amplitude = float(metadata[0]["trans. amplitude"][:-2])
if pre_duration == 0:
skip = False
else:
skip = True
continue
else:
if "preduration" in metadata[0].keys():
pre_duration = float(metadata[0]["preduration"][:-2])
trans_amplitude = float(metadata[0]["trans. amplitude"][:-2])
if pre_duration == 0:
skip = False
else:
skip = True
continue
if skip:
continue
if 'intensity' in metadata[metadata_index].keys():
intensity = float(metadata[metadata_index]['intensity'][:-2])
pre_intensity = float(metadata[metadata_index]['preintensity'][:-2])
else:
intensity = float(metadata[1-metadata_index]['intensity'][:-2])
pre_intensity = float(metadata[1-metadata_index]['preintensity'][:-2])
intensities.append(intensity)
pre_durations.append(pre_duration)
pre_intensities.append(pre_intensity)
trans_amplitudes.append(trans_amplitude)
spiketimes.append([])
index += 1
if skip:
continue
if data.shape[1] != 1:
raise RuntimeError("DatParser:get_fi_curve_spiketimes():\n read data has more than one dimension!")
spike_time_data = data[:, 0]/1000
if len(spike_time_data) < 10:
print("# ignoring spike-train that contains less than 10 spikes.")
continue
if spike_time_data[-1] < 1:
print("# ignoring spike-train that ends before one second.")
continue
spiketimes[index].append(spike_time_data)
# TODO Check if sorting works!
new_order = np.arange(0, len(intensities), 1)
intensities, new_order = zip(*sorted(zip(intensities, new_order)))
intensities = list(intensities)
spiketimes = [spiketimes[i] for i in new_order]
trans_amplitudes = [trans_amplitudes[i] for i in new_order]
for i in range(len(intensities)-1, -1, -1):
if len(spiketimes[i]) < 3:
del intensities[i]
del spiketimes[i]
del trans_amplitudes[i]
return trans_amplitudes, intensities, spiketimes
def get_sam_traces(self):
return self.__get_traces__("SAM")
def get_sam_info(self):
contrasts = []
delta_fs = []
spiketimes = []
durations = []
eod_freqs = []
trans_amplitudes = []
index = -1
for metadata, key, data in Dl.iload(self.sam_file):
factor = 1
if key[0][0] == 'time':
if key[1][0] == 'ms':
factor = 1/1000
elif key[1][0] == 's':
factor = 1
else:
print("DataParser Dat: Unknown time notation:", key[1][0])
if len(metadata) != 0:
if not "----- Stimulus -------------------------------------------------------" in metadata[0].keys():
eod_freq = float(metadata[0]["EOD rate"][:-2]) # in Hz
trans_amplitude = metadata[0]["trans. amplitude"][:-2] # in mV
duration = float(metadata[0]["duration"][:-2]) * factor # normally saved in ms? so change it with the factor
contrast = float(metadata[0]["contrast"][:-1]) # in percent
delta_f = float(metadata[0]["deltaf"][:-2])
else:
stimulus_dict = metadata[0]["----- Stimulus -------------------------------------------------------"]
analysis_dict = metadata[0]["----- Analysis -------------------------------------------------------"]
eod_freq = float(metadata[0]["EOD rate"][:-2]) # in Hz
trans_amplitude = metadata[0]["trans. amplitude"][:-2] # in mV
duration = float(stimulus_dict["duration"][:-2]) * factor # normally saved in ms? so change it with the factor
contrast = float(stimulus_dict["contrast"][:-1]) # in percent
delta_f = float(stimulus_dict["deltaf"][:-2])
# delta_f = metadata[0]["true deltaf"]
# contrast = metadata[0]["true contrast"]
contrasts.append(contrast)
delta_fs.append(delta_f)
durations.append(duration)
eod_freqs.append(eod_freq)
trans_amplitudes.append(trans_amplitude)
spiketimes.append([])
index += 1
if data.shape[1] != 1:
raise RuntimeError("DatParser:get_sam_spiketimes():\n read data has more than one dimension!")
spike_time_data = data[:, 0] * factor # saved in ms so use the factor to change it.
if len(spike_time_data) < 10:
continue
if spike_time_data[-1] < 0.1:
print("# ignoring spike-train that ends before one tenth of a second.")
continue
spiketimes[index].append(spike_time_data)
return spiketimes, contrasts, delta_fs, eod_freqs, durations, trans_amplitudes
def __get_traces__(self, repro):
time_traces = []
v1_traces = []
eod_traces = []
local_eod_traces = []
stimulus_traces = []
nothing = True
for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro):
nothing = False
time_traces.append(time)
v1_traces.append(x[0])
eod_traces.append(x[1])
local_eod_traces.append(x[2])
stimulus_traces.append(x[3])
traces = [time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces]
if nothing:
warn_msg = "pyrelacs: iload_traces found nothing for the " + str(repro) + " repro!"
warn(warn_msg)
return traces
def __iget_traces__(self, repro):
for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro):
# time, v1, eod, local_eod, stimulus
yield time, x[0], x[1], x[2], x[3]
def __read_fi_recording_times__(self):
delays = []
stim_duration = []
pause = []
for metadata, key, data in Dl.iload(self.fi_file):
if len(metadata) != 0:
control_key = '----- Control --------------------------------------------------------'
if control_key in metadata[0].keys():
delays.append(float(metadata[0][control_key]["delay"][:-2])/1000)
pause.append(float(metadata[0][control_key]["pause"][:-2])/1000)
stim_key = "----- Test-Intensities -----------------------------------------------"
stim_duration.append(float(metadata[0][stim_key]["duration"][:-2])/1000)
if "pause" in metadata[0].keys():
delays.append(float(metadata[0]["delay"][:-2]) / 1000)
pause.append(float(metadata[0]["pause"][:-2]) / 1000)
stim_duration.append(float(metadata[0]["duration"][:-2]) / 1000)
for l in [delays, stim_duration, pause]:
if len(l) == 0:
raise RuntimeError("DatParser:__read_fi_recording_times__:\n" +
"Couldn't find any delay, stimulus duration and or pause in the metadata.\n" +
"In file:" + self.base_path)
elif len(set(l)) != 1:
raise RuntimeError("DatParser:__read_fi_recording_times__:\n" +
"Found multiple different delay, stimulus duration and or pause in the metadata.\n" +
"In file:" + self.base_path)
else:
self.fi_recording_times = [-delays[0], 0, stim_duration[0], pause[0] - delays[0]]
def __read_sampling_interval__(self):
stop = False
sampling_intervals = []
for metadata, key, data in Dl.iload(self.stimuli_file):
for md in metadata:
for i in range(4):
key = "sample interval" + str(i+1)
if key in md.keys():
sampling_intervals.append(float(md[key][:-2]) / 1000)
stop = True
else:
break
if stop:
break
if len(sampling_intervals) == 0:
raise RuntimeError("DatParser:__read_sampling_interval__:\n" +
"Sampling intervals not found in stimuli.dat this is not handled!\n" +
"with File:" + self.base_path)
if len(set(sampling_intervals)) != 1:
raise RuntimeError("DatParser:__read_sampling_interval__:\n" +
"Sampling intervals not the same for all traces this is not handled!\n" +
"with File:" + self.base_path)
else:
self.sampling_interval = sampling_intervals[0]
def __test_data_file_existence__(self):
if not exists(self.stimuli_file):
raise FileNotFoundError(self.stimuli_file + " file doesn't exist!")
if not exists(self.fi_file):
raise FileNotFoundError(self.fi_file + " file doesn't exist!")
if not exists(self.baseline_file):
raise FileNotFoundError(self.baseline_file + " file doesn't exist!")
# if not exists(self.sam_file):
# raise RuntimeError(self.sam_file + " file doesn't exist!")

20
DataProvider.py Normal file
View File

@ -0,0 +1,20 @@
from DatParser import DatParser
class DataProvider:
def __init__(self, data_path):
self.data_path = data_path
# self.cell = CellData(data_path)
self.parser = DatParser(data_path)
self.thresholds = {}
def get_repros(self):
pass
def get_stim_values(self, repro):
pass
def get_trials(self, repro, stimulus_value):
pass

133
SpikeRedetectGui.py Normal file
View File

@ -0,0 +1,133 @@
from PyQt5.QtWidgets import QWidget, QPushButton, QSizePolicy, QLineEdit,\
QMessageBox, QVBoxLayout, QHBoxLayout, QGridLayout, QLabel, QFrame,\
QDoubleSpinBox, QComboBox
from PyQt5.QtCore import pyqtSlot
import numpy as np
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
class SpikeRedetectGui(QWidget):
def __init__(self, data_provider):
super().__init__()
self.data_provider = data_provider
self.title = 'Spike Redetection'
self.left = 10
self.top = 10
self.width = 640
self.height = 400
self.initUI()
def initUI(self):
self.setWindowTitle(self.title)
self.setGeometry(self.left, self.top, self.width, self.height)
# Middle:
middle = QHBoxLayout()
# Canvas for matplotlib figure
m = PlotCanvas(self, width=5, height=4)
m.move(0, 0)
middle.addWidget(m)
middle.addWidget(QVLine())
# Side (options) panel
panel = QFrame()
panel_layout = QVBoxLayout()
button = QPushButton('Button!', self)
button.setToolTip('A nice button!')
button.clicked.connect(lambda: threshold_spinbox.setValue(1))
button.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
panel_layout.addWidget(button)
repro_label = QLabel("Repro:")
panel_layout.addWidget(repro_label)
self.repro_box = QComboBox()
self.repro_box.addItem("placeholder repro")
panel_layout.addWidget(self.repro_box)
stim_val_label = QLabel("Stimulus value:")
panel_layout.addWidget(stim_val_label)
self.stim_val_box = QComboBox()
self.stim_val_box.addItem("placeholder stim_value")
panel_layout.addWidget(self.stim_val_box)
filler = QFill(maxh=20)
panel_layout.addWidget(filler)
self.status_label = QLabel("Done x/15 Stimulus Values")
panel_layout.addWidget(self.status_label)
filler = QFill()
panel_layout.addWidget(filler)
threshold_label = QLabel("Threshold:")
panel_layout.addWidget(threshold_label)
threshold_spinbox = QDoubleSpinBox(self)
threshold_spinbox.setValue(1)
threshold_spinbox.setSingleStep(0.5)
threshold_spinbox.valueChanged.connect(lambda: m.plot(threshold_spinbox.value()))
panel_layout.addWidget(threshold_spinbox)
button = QPushButton('Accept!', self)
button.setToolTip('Accept the threshold for current stimulus value')
panel_layout.addWidget(button)
panel.setLayout(panel_layout)
middle.addWidget(panel)
self.setLayout(middle)
self.show()
class PlotCanvas(FigureCanvas):
def __init__(self, parent=None, width=5, height=4, dpi=100):
fig = Figure(figsize=(width, height), dpi=dpi)
self.axes = fig.add_subplot(111)
FigureCanvas.__init__(self, fig)
self.setParent(parent)
FigureCanvas.setSizePolicy(self,
QSizePolicy.Expanding,
QSizePolicy.Expanding)
FigureCanvas.updateGeometry(self)
self.plot()
@pyqtSlot()
def plot(self, mean=1):
x = np.arange(0, 1, 0.0001)
data = np.sin(x*np.pi*2*mean)
ax = self.axes
ax.clear()
ax.plot(data, 'r-')
ax.set_title('Sinus Example')
self.draw()
class QHLine(QFrame):
def __init__(self):
super(QHLine, self).__init__()
self.setFrameShape(QFrame.HLine)
self.setFrameShadow(QFrame.Sunken)
class QVLine(QFrame):
def __init__(self):
super(QVLine, self).__init__()
self.setFrameShape(QFrame.VLine)
self.setFrameShadow(QFrame.Sunken)
class QFill(QFrame):
def __init__(self, maxh=int(2**24)-1, maxw=int(2**24)-1):
super(QFill, self).__init__()
self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
self.setMaximumSize(maxw, maxh)

17
main.py Normal file
View File

@ -0,0 +1,17 @@
import sys
from PyQt5.QtWidgets import QApplication, QWidget, QPushButton
from spike_redetection.DataProvider import DataProvider
from spike_redetection.SpikeRedetectGui import SpikeRedetectGui
def main():
app = QApplication(sys.argv)
data_provider = DataProvider("../data/final_sam/2010-11-08-al-invivo-1")
ex = SpikeRedetectGui(data_provider)
sys.exit(app.exec_())
if __name__ == '__main__':
main()