P-unit_model/parser/DataParserFactory.py

from os.path import isdir, exists
from warnings import warn

import numpy as np
import pyrelacs.DataLoader as Dl

from models.AbstractModel import AbstractModel

# data format constants, determined by __test_for_format__() and
# consumed by get_parser() at the bottom of this file
UNKNOWN = -1
DAT_FORMAT = 0
NIX_FORMAT = 1
MODEL = 2

class AbstractParser:
    # def cell_get_metadata(self):
    #     raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_baseline_length(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def has_sam_recordings(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_fi_curve_contrasts(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_baseline_traces(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_baseline_spiketimes(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_fi_curve_traces(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_fi_curve_spiketimes(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_fi_frequency_traces(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_sam_traces(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_sam_info(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_sampling_interval(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_recording_times(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def traces_available(self) -> bool:
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def spiketimes_available(self) -> bool:
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def frequencies_available(self) -> bool:
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
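
# A minimal sketch of how another backend could plug into this interface,
# e.g. for the NIX format that get_parser() below does not support yet.
# The class name and its internals are hypothetical:
#
#   class NixParser(AbstractParser):
#       def __init__(self, nix_path):
#           self.nix_path = nix_path
#
#       def traces_available(self) -> bool:
#           return True
#
#       # ... override the remaining AbstractParser methods analogously,
#       # then return a NixParser from the NIX_FORMAT branch in get_parser().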

class DatParser(AbstractParser):

    def __init__(self, dir_path):
        self.base_path = dir_path
        self.info_file = self.base_path + "/info.dat"
        self.fi_file = self.base_path + "/fispikes1.dat"
        self.baseline_file = self.base_path + "/basespikes1.dat"
        self.sam_file = self.base_path + "/samallspikes1.dat"
        self.stimuli_file = self.base_path + "/stimuli.dat"
        self.__test_data_file_existence__()

        self.fi_recording_times = []
        self.sampling_interval = -1

    def has_sam_recordings(self):
        return exists(self.sam_file)

    def get_measured_repros(self):
        repros = []
        for metadata, key, data in Dl.iload(self.stimuli_file):
            repros.extend([d["repro"] for d in metadata if "repro" in d.keys()])

        return sorted(np.unique(repros))

    def get_baseline_length(self):
        lengths = []
        for metadata, key, data in Dl.iload(self.baseline_file):
            if len(metadata) != 0:
                lengths.append(float(metadata[0]["duration"][:-3]))

        return lengths

    def get_species(self):
        species = ""
        for metadata in Dl.load(self.info_file):
            if "Species" in metadata.keys():
                species = metadata["Species"]
            elif "Subject" in metadata.keys():
                if isinstance(metadata["Subject"], dict) and "Species" in metadata["Subject"].keys():
                    species = metadata["Subject"]["Species"]

        return species

    def get_gender(self):
        gender = "not found"
        for metadata in Dl.load(self.info_file):
            if "Gender" in metadata.keys():
                gender = metadata["Gender"]
            elif "Subject" in metadata.keys():
                if isinstance(metadata["Subject"], dict) and "Gender" in metadata["Subject"].keys():
                    gender = metadata["Subject"]["Gender"]

        return gender

    def get_quality(self):
        quality = ""
        for metadata in Dl.load(self.info_file):
            if "Recording quality" in metadata.keys():
                quality = metadata["Recording quality"]
            elif "Recording" in metadata.keys():
                if isinstance(metadata["Recording"], dict) and "Recording quality" in metadata["Recording"].keys():
                    quality = metadata["Recording"]["Recording quality"]

        return quality

    def get_cell_type(self):
        cell_type = ""
        for metadata in Dl.load(self.info_file):
            if len(metadata.keys()) < 3:
                return ""
            if "CellType" in metadata.keys():
                cell_type = metadata["CellType"]
            elif "Cell" in metadata.keys():
                if isinstance(metadata["Cell"], dict) and "CellType" in metadata["Cell"].keys():
                    cell_type = metadata["Cell"]["CellType"]

        return cell_type

    def get_fish_size(self):
        size = ""
        for metadata in Dl.load(self.info_file):
            if "Size" in metadata.keys():
                size = metadata["Size"]
            elif "Subject" in metadata.keys():
                if isinstance(metadata["Subject"], dict) and "Size" in metadata["Subject"].keys():
                    size = metadata["Subject"]["Size"]

        return size[:-2]  # strip the unit suffix

    def get_fi_curve_contrasts(self):
        """
        :return: array of pairs [(contrast, #_of_trials), ...]
        """
        contrasts = []
        contrast = [-1, float("nan")]
        for metadata, key, data in Dl.iload(self.fi_file):
            if len(metadata) != 0:
                # a new metadata block starts the next contrast
                if contrast[0] != -1:
                    contrasts.append(contrast)
                contrast = [-1, 1]
                contrast[0] = float(metadata[-1]["intensity"][:-2])
            else:
                contrast[1] += 1

        # don't drop the last contrast of the file
        if contrast[0] != -1:
            contrasts.append(contrast)

        return np.array(contrasts)

    def traces_available(self) -> bool:
        return True

    def frequencies_available(self) -> bool:
        return False

    def spiketimes_available(self) -> bool:
        return True

    def get_sampling_interval(self):
        if self.sampling_interval == -1:
            self.__read_sampling_interval__()

        return self.sampling_interval

    def get_recording_times(self):
        if len(self.fi_recording_times) == 0:
            self.__read_fi_recording_times__()

        return self.fi_recording_times

    def get_baseline_traces(self):
        return self.__get_traces__("BaselineActivity")

    def get_baseline_spiketimes(self):
        # TODO change: reading from file -> detect from v1 trace
        spiketimes = []
        warn("Spiketimes don't fit time-wise to the baseline traces. Causes different vector strength angle per recording.")
        for metadata, key, data in Dl.iload(self.baseline_file):
            spikes = np.array(data[:, 0]) / 1000  # timestamps are saved in ms -> conversion to seconds
            spiketimes.append(spikes)

        return spiketimes

    def get_fi_curve_traces(self):
        return self.__get_traces__("FICurve")

    def get_fi_frequency_traces(self):
        raise NotImplementedError("Not possible in .dat data type.\n"
                                  "Please check availability with the x_available functions.")

    # TODO clean up / rewrite
    def get_fi_curve_spiketimes(self):
        spiketimes = []
        pre_intensities = []
        pre_durations = []
        intensities = []
        trans_amplitudes = []
        pre_duration = -1
        index = -1
        skip = False
        trans_amplitude = float('nan')
        for metadata, key, data in Dl.iload(self.fi_file):
            if len(metadata) != 0:
                metadata_index = 0

                if '----- Control --------------------------------------------------------' in metadata[0].keys():
                    metadata_index = 1
                    pre_duration = float(metadata[0]["----- Pre-Intensities ------------------------------------------------"]["preduration"][:-2])
                    trans_amplitude = float(metadata[0]["trans. amplitude"][:-2])
                    if pre_duration == 0:
                        skip = False
                    else:
                        skip = True
                        continue
                else:
                    if "preduration" in metadata[0].keys():
                        pre_duration = float(metadata[0]["preduration"][:-2])
                        trans_amplitude = float(metadata[0]["trans. amplitude"][:-2])
                        if pre_duration == 0:
                            skip = False
                        else:
                            skip = True
                            continue

                if skip:
                    continue

                if 'intensity' in metadata[metadata_index].keys():
                    intensity = float(metadata[metadata_index]['intensity'][:-2])
                    pre_intensity = float(metadata[metadata_index]['preintensity'][:-2])
                else:
                    intensity = float(metadata[1 - metadata_index]['intensity'][:-2])
                    pre_intensity = float(metadata[1 - metadata_index]['preintensity'][:-2])

                intensities.append(intensity)
                pre_durations.append(pre_duration)
                pre_intensities.append(pre_intensity)
                trans_amplitudes.append(trans_amplitude)
                spiketimes.append([])
                index += 1

            if skip:
                continue

            if data.shape[1] != 1:
                raise RuntimeError("DatParser:get_fi_curve_spiketimes():\n read data has more than one dimension!")

            spike_time_data = data[:, 0] / 1000  # timestamps are saved in ms -> conversion to seconds
            if len(spike_time_data) < 10:
                print("# ignoring spike-train that contains less than 10 spikes.")
                continue
            if spike_time_data[-1] < 1:
                print("# ignoring spike-train that ends before one second.")
                continue

            spiketimes[index].append(spike_time_data)

        # TODO check if sorting works!
        # sort all parallel lists by ascending intensity
        new_order = np.arange(0, len(intensities), 1)
        intensities, new_order = zip(*sorted(zip(intensities, new_order)))
        intensities = list(intensities)
        spiketimes = [spiketimes[i] for i in new_order]
        trans_amplitudes = [trans_amplitudes[i] for i in new_order]

        # drop intensities with fewer than three accepted trials
        for i in range(len(intensities) - 1, -1, -1):
            if len(spiketimes[i]) < 3:
                del intensities[i]
                del spiketimes[i]
                del trans_amplitudes[i]

        return trans_amplitudes, intensities, spiketimes
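
    # A short usage sketch for the method above, assuming a DatParser
    # instance `parser` (the variable names are hypothetical):
    #
    #   trans_amplitudes, intensities, spiketimes = parser.get_fi_curve_spiketimes()
    #   for intensity, trials in zip(intensities, spiketimes):
    #       print(intensity, "->", len(trials), "accepted trials")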

    def get_sam_traces(self):
        return self.__get_traces__("SAM")

    def get_sam_info(self):
        contrasts = []
        delta_fs = []
        spiketimes = []
        durations = []
        eod_freqs = []
        trans_amplitudes = []

        index = -1
        for metadata, key, data in Dl.iload(self.sam_file):
            factor = 1
            if key[0][0] == 'time':
                if key[1][0] == 'ms':
                    factor = 1 / 1000
                elif key[1][0] == 's':
                    factor = 1
                else:
                    print("DataParser Dat: Unknown time notation:", key[1][0])

            if len(metadata) != 0:
                if "----- Stimulus -------------------------------------------------------" not in metadata[0].keys():
                    eod_freq = float(metadata[0]["EOD rate"][:-2])  # in Hz
                    trans_amplitude = metadata[0]["trans. amplitude"][:-2]  # in mV
                    duration = float(metadata[0]["duration"][:-2]) * factor  # usually saved in ms; the factor converts to seconds
                    contrast = float(metadata[0]["contrast"][:-1])  # in percent
                    delta_f = float(metadata[0]["deltaf"][:-2])
                else:
                    stimulus_dict = metadata[0]["----- Stimulus -------------------------------------------------------"]
                    analysis_dict = metadata[0]["----- Analysis -------------------------------------------------------"]
                    eod_freq = float(metadata[0]["EOD rate"][:-2])  # in Hz
                    trans_amplitude = metadata[0]["trans. amplitude"][:-2]  # in mV
                    duration = float(stimulus_dict["duration"][:-2]) * factor  # usually saved in ms; the factor converts to seconds
                    contrast = float(stimulus_dict["contrast"][:-1])  # in percent
                    delta_f = float(stimulus_dict["deltaf"][:-2])
                    # delta_f = metadata[0]["true deltaf"]
                    # contrast = metadata[0]["true contrast"]

                contrasts.append(contrast)
                delta_fs.append(delta_f)
                durations.append(duration)
                eod_freqs.append(eod_freq)
                trans_amplitudes.append(trans_amplitude)
                spiketimes.append([])
                index += 1

            if data.shape[1] != 1:
                raise RuntimeError("DatParser:get_sam_info():\n read data has more than one dimension!")

            spike_time_data = data[:, 0] * factor  # saved in ms; the factor converts to seconds
            if len(spike_time_data) < 10:
                continue
            if spike_time_data[-1] < 0.1:
                print("# ignoring spike-train that ends before one tenth of a second.")
                continue

            spiketimes[index].append(spike_time_data)

        return spiketimes, contrasts, delta_fs, eod_freqs, durations, trans_amplitudes

    def __get_traces__(self, repro):
        time_traces = []
        v1_traces = []
        eod_traces = []
        local_eod_traces = []
        stimulus_traces = []

        nothing = True
        for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro):
            nothing = False
            time_traces.append(time)
            v1_traces.append(x[0])
            eod_traces.append(x[1])
            local_eod_traces.append(x[2])
            stimulus_traces.append(x[3])

        traces = [time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces]

        if nothing:
            warn_msg = "pyrelacs: iload_traces found nothing for the " + str(repro) + " repro!"
            warn(warn_msg)

        return traces

    def __iget_traces__(self, repro):
        for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro):
            # time, v1, eod, local_eod, stimulus
            yield time, x[0], x[1], x[2], x[3]

    def __read_fi_recording_times__(self):
        delays = []
        stim_duration = []
        pause = []

        for metadata, key, data in Dl.iload(self.fi_file):
            if len(metadata) != 0:
                control_key = '----- Control --------------------------------------------------------'
                if control_key in metadata[0].keys():
                    delays.append(float(metadata[0][control_key]["delay"][:-2]) / 1000)
                    pause.append(float(metadata[0][control_key]["pause"][:-2]) / 1000)
                    stim_key = "----- Test-Intensities -----------------------------------------------"
                    stim_duration.append(float(metadata[0][stim_key]["duration"][:-2]) / 1000)
                if "pause" in metadata[0].keys():
                    delays.append(float(metadata[0]["delay"][:-2]) / 1000)
                    pause.append(float(metadata[0]["pause"][:-2]) / 1000)
                    stim_duration.append(float(metadata[0]["duration"][:-2]) / 1000)

        for l in [delays, stim_duration, pause]:
            if len(l) == 0:
                raise RuntimeError("DatParser:__read_fi_recording_times__:\n" +
                                   "Couldn't find any delay, stimulus duration, or pause in the metadata.\n" +
                                   "In: " + self.base_path)
            elif len(set(l)) != 1:
                raise RuntimeError("DatParser:__read_fi_recording_times__:\n" +
                                   "Found multiple different delays, stimulus durations, or pauses in the metadata.\n" +
                                   "In: " + self.base_path)

        # [-delay, stimulus onset, stimulus duration, pause - delay], all in seconds
        self.fi_recording_times = [-delays[0], 0, stim_duration[0], pause[0] - delays[0]]

    def __read_sampling_interval__(self):
        stop = False
        sampling_intervals = []
        for metadata, key, data in Dl.iload(self.stimuli_file):
            for md in metadata:
                # the traces are numbered: "sample interval1" ... "sample interval4"
                for i in range(4):
                    interval_key = "sample interval" + str(i + 1)
                    if interval_key in md.keys():
                        sampling_intervals.append(float(md[interval_key][:-2]) / 1000)
                        stop = True
                    else:
                        break

            if stop:
                break

        if len(sampling_intervals) == 0:
            raise RuntimeError("DatParser:__read_sampling_interval__:\n" +
                               "Sampling intervals not found in stimuli.dat; this is not handled!\n" +
                               "In: " + self.base_path)

        if len(set(sampling_intervals)) != 1:
            raise RuntimeError("DatParser:__read_sampling_interval__:\n" +
                               "Sampling intervals are not the same for all traces; this is not handled!\n" +
                               "In: " + self.base_path)
        else:
            self.sampling_interval = sampling_intervals[0]

    def __test_data_file_existence__(self):
        if not exists(self.stimuli_file):
            raise FileNotFoundError(self.stimuli_file + " file doesn't exist!")
        if not exists(self.fi_file):
            raise FileNotFoundError(self.fi_file + " file doesn't exist!")
        if not exists(self.baseline_file):
            raise FileNotFoundError(self.baseline_file + " file doesn't exist!")
        # the SAM file is optional; see has_sam_recordings()
        # if not exists(self.sam_file):
        #     raise RuntimeError(self.sam_file + " file doesn't exist!")


def get_parser(data_path) -> AbstractParser:
    data_format = __test_for_format__(data_path)

    if data_format == DAT_FORMAT:
        return DatParser(data_path)
    elif data_format == NIX_FORMAT:
        raise NotImplementedError("DataParserFactory:get_parser(data_path): nix format doesn't have a parser yet")
    elif data_format == MODEL:
        raise NotImplementedError("DataParserFactory:get_parser(data_path): Model doesn't have a parser yet")
    else:  # UNKNOWN
        raise TypeError("DataParserFactory:get_parser(data_path):\nCannot determine type of data for: " + data_path)


def __test_for_format__(data_path):
    if isinstance(data_path, AbstractModel):
        return MODEL

    if isdir(data_path):
        if exists(data_path + "/fispikes1.dat"):
            return DAT_FORMAT
    elif data_path.endswith(".nix"):
        return NIX_FORMAT

    # also reached for a directory without a fispikes1.dat
    return UNKNOWN
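
# A minimal usage sketch of the factory (the data path is hypothetical):
#
#   parser = get_parser("data/2012-06-27-ah")  # a relacs .dat directory
#   dt = parser.get_sampling_interval()        # in seconds
#   if parser.has_sam_recordings():
#       sam_data = parser.get_sam_info()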