343 lines
12 KiB
Python
343 lines
12 KiB
Python
import DataParserFactory as dpf
|
|
from warnings import warn
|
|
import os
|
|
import helperFunctions as hf
|
|
import numpy as np
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
# Running total (across all calls) of cells for which the V1 trace check in
# icelldata_of_dir() failed: empty trace, IndexError or ValueError.
COUNT = 0


def icelldata_of_dir(base_path, test_for_v1_trace=True):
    """Yield a CellData object for each cell found directly inside base_path.

    Entries are visited in sorted order. An entry is treated as cell data
    when it is a directory or its name ends with ".nix"; everything else is
    reported on stdout and skipped.

    :param base_path: directory that contains the cell data entries. It may
        be given with or without a trailing path separator.
    :param test_for_v1_trace: when True, the baseline V1 trace of each cell
        is loaded first. Cells with an empty V1 trace, or whose loading
        raises an IndexError, are counted in the global COUNT and skipped.
    :return: generator of CellData objects.

    A TypeError raised while constructing or checking a cell is turned into
    a warning and the cell is skipped.
    """
    global COUNT
    for item in sorted(os.listdir(base_path)):
        # os.path.join (instead of plain string concatenation) keeps the
        # path correct when base_path has no trailing separator.
        item_path = os.path.join(base_path, item)

        if not os.path.isdir(item_path) and not item.endswith(".nix"):
            print("ignoring path: " + item_path)
            print("It isn't expected to be cell data.")
            continue

        try:
            data = CellData(item_path)
            if not test_for_v1_trace:
                yield data
                continue

            try:
                trace = data.get_base_traces(trace_type=data.V1)
                if len(trace) == 0:
                    COUNT += 1
                    print("NO V1 TRACE FOUND: ", item_path)
                    print(COUNT)
                    continue
            except IndexError as e:
                COUNT += 1
                print(data.get_data_path(), "Threw Index error!")
                print(COUNT)
                print(str(e), "\n")
                continue
            except ValueError as e:
                # NOTE(review): unlike the IndexError case there is no
                # `continue` here, so the cell is counted as an error but
                # still yielded below. Kept as-is; confirm this is intended.
                COUNT += 1
                print(data.get_data_path(), "Threw Value error!")
                print(COUNT)
                print(str(e), "\n")

            yield data

        except TypeError as e:
            warn(str(e))

    print("Currently throw errors: {}".format(COUNT))
|
|
|
|
|
|
class CellData:
    """Lazily loaded, cached view on all recorded data of a single cell.

    Captures the data of one cell across all experiments (base rate,
    FI-curve, SAM). The class is meant to be independent of how the data is
    stored in the background (.dat directory vs .nix file): all raw access
    goes through the parser returned by DataParserFactory.get_parser().

    Every getter reads from the parser (or from a cache file) on first use
    and keeps the result on the instance for later calls.
    """

    # Index of each signal inside the traces list returned by
    # get_base_traces():
    # [[time], [voltage (v1)], [EOD], [local eod], [stimulus]]
    TIME = 0
    V1 = 1
    EOD = 2
    LOCAL_EOD = 3
    STIMULUS = 4

    def __init__(self, data_path):
        """Create the parser for data_path; no data is read yet.

        :param data_path: path of the cell's data (passed to the parser
            factory; also used as the location of the .npy cache files).
        """
        self.data_path = data_path
        self.parser = dpf.get_parser(data_path)

        # Baseline recording.
        self.base_traces = None
        self.base_spikes = None

        # FI-curve recording.
        self.fi_traces = None
        self.fi_intensities = None
        self.fi_spiketimes = None
        self.fi_trans_amplitudes = None
        self.mean_isi_frequencies = None
        self.time_axes = None
        # self.metadata = None

        # SAM recording.
        self.sam_spiketimes = None
        self.sam_contrasts = None
        self.sam_delta_fs = None
        self.sam_eod_freqs = None
        self.sam_durations = None
        self.sam_trans_amplitudes = None

        self.sampling_interval = None
        self.recording_times = None

    def get_data_path(self):
        """Return the path this cell was created with."""
        return self.data_path

    def get_cell_name(self):
        """Return the last path component of the data path."""
        return os.path.basename(self.data_path)

    def get_baseline_length(self):
        """Return the baseline length as reported by the parser."""
        return self.parser.get_baseline_length()

    def get_fi_curve_contrasts_with_trial_number(self):
        """Return the parser's FI-curve contrasts (not cached)."""
        return self.parser.get_fi_curve_contrasts()

    def get_base_traces(self, trace_type=None):
        """Return the baseline traces, loading them on first use.

        :param trace_type: one of TIME, V1, EOD, LOCAL_EOD, STIMULUS to get
            only that signal, or None to get the full list of signals.
        """
        if self.base_traces is None:
            self.base_traces = self.parser.get_baseline_traces()

        if trace_type is None:
            return self.base_traces
        return self.base_traces[trace_type]

    def get_base_spikes(self, threshold=2.5, min_length=5000, split_step=1000, re_calculate=False, only_first=False):
        """Return the baseline spike times, one entry per baseline trace.

        Lookup order: the in-memory cache, then "base_spikes_ndarray.npy"
        inside the data directory, then a fresh detection with
        helperFunctions.detect_spiketimes(). A fresh detection is stored in
        memory and, when the data path is a directory, also saved to that
        file.

        :param threshold: passed through to detect_spiketimes().
        :param min_length: passed through to detect_spiketimes().
        :param split_step: passed through to detect_spiketimes().
        :param re_calculate: ignore both caches and detect again.
        :param only_first: detect spikes in the first trace only.
        :return: numpy array of the spike times per trace.

        NOTE(review): the caches are not keyed on the detection parameters
        or on only_first, so a result computed with only_first=True is also
        what later calls with other arguments get back. Confirm with callers.
        """
        if self.base_spikes is not None and not re_calculate:
            return self.base_spikes

        saved_spikes_file = "base_spikes_ndarray.npy"
        full_path = os.path.join(self.data_path, saved_spikes_file)
        data_is_directory = os.path.isdir(self.data_path)

        if data_is_directory and os.path.exists(full_path) and not re_calculate:
            # allow_pickle is needed for object arrays (traces with different
            # spike counts); only load cache files from trusted data folders.
            self.base_spikes = np.load(full_path, allow_pickle=True)
            print("Baseline spikes loaded from file.")
            return self.base_spikes

        print("Baseline spikes are being (re-)calculated...")
        times = self.get_base_traces(self.TIME)
        v1_traces = self.get_base_traces(self.V1)

        n_traces = len(times)
        if only_first:
            n_traces = min(n_traces, 1)

        spiketimes = [
            hf.detect_spiketimes(times[i], v1_traces[i], threshold=threshold,
                                 min_length=min_length, split_step=split_step)
            for i in range(n_traces)
        ]

        # Debugging aid: plot the detected spikes on top of the first trace.
        # plt.plot(times[0], v1_traces[0])
        # idx_pos = np.array(spiketimes) / self.get_sampling_interval()
        # idx_pos = np.array(np.rint(idx_pos), np.int)
        #
        # plt.plot(spiketimes[0], np.array(v1_traces[0])[idx_pos][0, :], 'o')
        # plt.show()

        self.base_spikes = self._to_spike_array(spiketimes)

        if data_is_directory:
            np.save(full_path, self.base_spikes)
            print("Calculated spikes saved to file")

        return self.base_spikes

    @staticmethod
    def _to_spike_array(spiketimes):
        """Convert per-trace spike times to an array, tolerating ragged input.

        NumPy >= 1.24 raises ValueError when np.array() is given sequences of
        different lengths; in that case a 1-D object array with one entry per
        trace is built instead (which is what older NumPy versions produced).
        """
        try:
            return np.array(spiketimes)
        except ValueError:
            ragged = np.empty(len(spiketimes), dtype=object)
            for idx, spikes in enumerate(spiketimes):
                ragged[idx] = spikes
            return ragged

    def get_base_isis(self):
        """Return the inter-spike intervals of all baseline traces as one flat list."""
        isis = []
        for spikes in self.get_base_spikes():
            isis.extend(np.diff(spikes))
        return isis

    def get_fi_traces(self):
        """Return the FI-curve traces, loading them on first use.

        A warning is issued on the first load because the traces are not
        sorted in the same way as the FI spike times.
        """
        if self.fi_traces is None:
            warn("Fi traces not sorted in the same way as the spiketimes!!!")
            self.fi_traces = self.parser.get_fi_curve_traces()
        return self.fi_traces

    def get_fi_spiketimes(self):
        """Return the FI-curve spike times."""
        self.__read_fi_spiketimes_info__()
        return self.fi_spiketimes

    def get_fi_intensities(self):
        """Return the FI-curve stimulus intensities."""
        self.__read_fi_spiketimes_info__()
        return self.fi_intensities

    def get_fi_contrasts(self):
        """Return the FI-curve contrasts.

        Each contrast is (intensity - trans_amplitude) / trans_amplitude for
        the intensity and trans amplitude at the same index.
        """
        if self.fi_intensities is None:
            self.__read_fi_spiketimes_info__()

        return [
            (self.fi_intensities[i] - self.fi_trans_amplitudes[i]) / self.fi_trans_amplitudes[i]
            for i in range(len(self.fi_intensities))
        ]

    def get_sam_spiketimes(self):
        """Return the SAM spike times."""
        self.__read_sam_info__()
        return self.sam_spiketimes

    def get_sam_contrasts(self):
        """Return the SAM contrasts."""
        self.__read_sam_info__()
        return self.sam_contrasts

    def get_sam_delta_frequencies(self):
        """Return the SAM delta frequencies."""
        self.__read_sam_info__()
        return self.sam_delta_fs

    def get_sam_durations(self):
        """Return the SAM durations."""
        self.__read_sam_info__()
        return self.sam_durations

    def get_sam_eod_frequencies(self):
        """Return the SAM EOD frequencies."""
        self.__read_sam_info__()
        return self.sam_eod_freqs

    def get_sam_trans_amplitudes(self):
        """Return the SAM trans amplitudes."""
        self.__read_sam_info__()
        return self.sam_trans_amplitudes

    def _calculate_mean_isi_frequencies(self):
        """Compute and cache both the time axes and the mean ISI frequency traces."""
        self.time_axes, self.mean_isi_frequencies = hf.all_calculate_mean_isi_frequency_traces(
            self.get_fi_spiketimes(), self.get_sampling_interval())

    def get_mean_fi_curve_isi_frequencies(self):
        """Return the mean ISI frequency traces of the FI-curve, computing them on first use."""
        if self.mean_isi_frequencies is None:
            self._calculate_mean_isi_frequencies()
        return self.mean_isi_frequencies

    def get_time_axes_fi_curve_mean_frequencies(self):
        """Return the time axes belonging to the mean ISI frequency traces."""
        if self.time_axes is None:
            self._calculate_mean_isi_frequencies()
        return self.time_axes

    def get_base_frequency(self):
        """Return the median baseline frequency over all FI-curve frequency traces.

        For each mean ISI frequency trace the mean is taken over the samples
        from 0.025 after the trace start up to 0.025 before the end of the
        delay (both converted to indices with the sampling interval); the
        median of those means is returned.
        """
        mean_frequencies = self.get_mean_fi_curve_isi_frequencies()

        # Loop invariants: the same window applies to every trace.
        delay = self.get_delay()
        sampling_interval = self.get_sampling_interval()
        if delay < 0.1:
            warn("FICurve:__calculate_f_baseline__(): Quite short delay at the start.")

        idx_start = int(0.025 / sampling_interval)
        idx_end = int((delay - 0.025) / sampling_interval)

        base_freqs = [np.mean(freq[idx_start:idx_end]) for freq in mean_frequencies]
        return np.median(base_freqs)

    def get_sampling_interval(self) -> float:
        """Return the sampling interval reported by the parser (cached)."""
        if self.sampling_interval is None:
            self.sampling_interval = self.parser.get_sampling_interval()
        return self.sampling_interval

    def get_recording_times(self) -> list:
        """Return the recording times reported by the parser (cached).

        As used by the getters below, the entries are:
        [0] time start, [1] stimulus start, [2] stimulus duration,
        [3] after-stimulus duration.
        """
        if self.recording_times is None:
            self.recording_times = self.parser.get_recording_times()
        return self.recording_times

    def get_time_start(self) -> float:
        """Return the start time of the recording (entry 0 of the recording times)."""
        return self.get_recording_times()[0]

    def get_delay(self) -> float:
        """Return the absolute value of the recording start time."""
        return abs(self.get_recording_times()[0])

    def get_time_end(self) -> float:
        """Return stimulus duration plus after-stimulus duration."""
        recording_times = self.get_recording_times()
        return recording_times[2] + recording_times[3]

    def get_stimulus_start(self) -> float:
        """Return the stimulus start time."""
        return self.get_recording_times()[1]

    def get_stimulus_duration(self) -> float:
        """Return the stimulus duration."""
        return self.get_recording_times()[2]

    def get_stimulus_end(self) -> float:
        """Return stimulus start plus stimulus duration."""
        return self.get_stimulus_start() + self.get_stimulus_duration()

    def get_after_stimulus_duration(self) -> float:
        """Return the duration recorded after the stimulus."""
        return self.get_recording_times()[3]

    def get_eod_frequency(self, recalculate=False):
        """Return the mean EOD frequency over all baseline EOD traces.

        The value is cached in "eod_freq_peak_based.npy" inside the data
        directory. When the data path is not a directory (e.g. a single .nix
        file) there is nowhere to put the cache file, so the value is
        computed and returned without being saved.

        :param recalculate: ignore an existing cache file and compute again.
        """
        eod_freq_file_name = "eod_freq_peak_based.npy"
        eod_freq_file_path = os.path.join(self.get_data_path(), eod_freq_file_name)

        if os.path.exists(eod_freq_file_path) and not recalculate:
            print("Loaded eod_freq from file")
            return np.load(eod_freq_file_path)

        eods = self.get_base_traces(self.EOD)
        sampling_interval = self.get_sampling_interval()
        frequencies = [hf.calculate_eod_frequency(eod, sampling_interval) for eod in eods]
        mean_freq = np.mean(frequencies)

        # Same guard as in get_base_spikes(): saving below a non-directory
        # data path would raise instead of returning the result.
        if os.path.isdir(self.get_data_path()):
            np.save(eod_freq_file_path, mean_freq)
            print("Saved eod freq to file.")

        return mean_freq

    def __read_fi_spiketimes_info__(self):
        """Load FI-curve trans amplitudes, intensities and spike times once.

        If "redetected_spikes.npy" exists in the data directory, the parser's
        spike times are replaced by the redetected ones, grouped per contrast
        with the help of "fi_traces_contrasts.npy" and
        "fi_traces_contrasts_similarity.npy".
        """
        if self.fi_spiketimes is not None:
            return

        self.fi_trans_amplitudes, self.fi_intensities, self.fi_spiketimes = self.parser.get_fi_curve_spiketimes()

        data_dir = self.get_data_path()
        redetected_path = os.path.join(data_dir, "redetected_spikes.npy")
        if not os.path.exists(redetected_path):
            return

        print("overwriting fi_spiketimes with redetected ones.")
        contrasts = self.get_fi_contrasts()
        # allow_pickle: only load these files from trusted data folders.
        spikes = np.load(redetected_path, allow_pickle=True)
        trace_contrasts_idx = np.load(os.path.join(data_dir, "fi_traces_contrasts.npy"), allow_pickle=True)
        trace_max_similarity = np.load(os.path.join(data_dir, "fi_traces_contrasts_similarity.npy"),
                                       allow_pickle=True)

        spiketimes = []
        for contrast_idx in range(len(contrasts)):
            # Keep a trace only if it was assigned to this contrast and its
            # first similarity value exceeds the second by more than 0.15
            # (presumably best vs. second-best match - TODO confirm).
            contrast_list = [
                spikes[j]
                for j in range(len(trace_contrasts_idx))
                if trace_contrasts_idx[j] == contrast_idx
                and trace_max_similarity[j][0] > trace_max_similarity[j][1] + 0.15
            ]
            spiketimes.append(contrast_list)

        self.fi_spiketimes = spiketimes

    def __read_sam_info__(self):
        """Load all SAM information from the parser once and cache it."""
        if self.sam_spiketimes is not None:
            return

        spiketimes, contrasts, delta_fs, eod_freqs, durations, trans_amplitudes = self.parser.get_sam_info()

        self.sam_spiketimes = spiketimes
        self.sam_contrasts = contrasts
        self.sam_delta_fs = delta_fs
        self.sam_eod_freqs = eod_freqs
        self.sam_durations = durations
        self.sam_trans_amplitudes = trans_amplitudes

    # Metadata access is not implemented yet; the stubs are kept for reference.
    # def get_metadata(self):
    #     self.__read_metadata__()
    #     return self.metadata
    #
    # def get_metadata_item(self, item):
    #     self.__read_metadata__()
    #     if item in self.metadata.keys():
    #         return self.metadata[item]
    #     else:
    #         raise KeyError("CellData:get_metadata_item: Item not found in metadata! - " + str(item))
    #
    # def __read_metadata__(self):
    #     if self.metadata is None:
    #         # TODO!!
    #         pass
|