232 lines
8.0 KiB
Python
232 lines
8.0 KiB
Python
|
|
from CellData import CellData
|
|
from models.LIFACnoise import LifacNoiseModel
|
|
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
|
|
import helperFunctions as hF
|
|
import numpy as np
|
|
from warnings import warn
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
class Baseline:
    """Abstract interface for baseline (no-stimulus) firing statistics.

    Subclasses supply the data source and override the ``get_*`` methods and
    ``plot_baseline``. This base class provides the shared ISI-histogram and
    serial-correlation plots, which are built on top of those getters.

    The attributes initialised in ``__init__`` act as result caches for the
    subclasses. A value of ``-1`` (or an empty list for ``serial_correlation``)
    means "not computed yet".
    """

    def __init__(self):
        # Cached results; -1 / [] mark "not yet computed".
        self.baseline_frequency = -1
        self.serial_correlation = []
        self.vector_strength = -1
        self.coefficient_of_variation = -1

    def get_baseline_frequency(self):
        """Return the baseline firing frequency (must be overridden)."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_serial_correlation(self, max_lag):
        """Return the ISI serial correlations for lags 1..max_lag (must be overridden)."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_vector_strength(self):
        """Return the vector strength of the baseline firing (must be overridden)."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_coefficient_of_variation(self):
        """Return the coefficient of variation of the ISIs (must be overridden)."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_interspike_intervals(self):
        """Return the interspike intervals (must be overridden).

        The histogram plot multiplies the result by 1000 to convert it to
        milliseconds, so implementations are expected to return seconds.
        Either a list or a numpy array is accepted.
        """
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def plot_baseline(self, save_path=None):
        """Plot the stimulus/EOD, the v1 trace, the spike times and the frequency trace.

        Must be overridden.

        :param save_path: if given, save the figure there instead of showing it
        """
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def _isi_histogram_data(self, bin_width=0.1):
        """Return the ISIs in milliseconds and the histogram bin edges.

        :param bin_width: width of one histogram bin in milliseconds
        :return: tuple ``(isi_ms, bins)``; both are numpy arrays
        :raises ValueError: if there are no interspike intervals
        """
        # np.asarray is essential: a subclass may return a plain list, and
        # ``list * 1000`` would repeat the list 1000 times instead of scaling it.
        isi = np.asarray(self.get_interspike_intervals(), dtype=float) * 1000  # change unit to milliseconds
        if isi.size == 0:
            raise ValueError("No interspike intervals available for the histogram.")
        # Let the last bin extend slightly past the largest ISI so that ISI is included.
        bins = np.arange(0, isi.max() * 1.01, bin_width)
        return isi, bins

    @staticmethod
    def _save_or_show(save_path):
        """Save the current figure to ``save_path`` (or show it if None), then close it."""
        if save_path is not None:
            plt.savefig(save_path)
        else:
            plt.show()
        plt.close()

    def plot_inter_spike_interval_histogram(self, save_path=None, bin_width=0.1):
        """Plot a histogram of the baseline ISIs in milliseconds.

        :param save_path: if given, save the figure there instead of showing it
        :param bin_width: histogram bin width in milliseconds (default 0.1)
        :raises ValueError: if there are no interspike intervals
        """
        isi, bins = self._isi_histogram_data(bin_width)

        plt.title('Baseline ISIs')
        plt.xlabel('ISI in ms')
        plt.ylabel('Count')
        plt.hist(isi, bins=bins)

        self._save_or_show(save_path)

    def plot_serial_correlation(self, max_lag, save_path=None):
        """Plot the serial correlation against the lag for lags 1..max_lag.

        :param max_lag: largest lag to plot
        :param save_path: if given, save the figure there instead of showing it
        """
        plt.title("Baseline Serial correlation")
        plt.xlabel("Lag")
        plt.ylabel("Correlation")

        plt.plot(np.arange(1, max_lag + 1, 1), self.get_serial_correlation(max_lag))

        self._save_or_show(save_path)
|
|
|
|
|
|
class BaselineCellData(Baseline):
    """Baseline statistics computed from recorded cell data.

    Every value is derived from the baseline recordings exposed by the wrapped
    ``CellData`` object and cached on first access. The relevant accessors are
    ``get_base_traces``, ``get_base_spikes`` and ``get_mean_isi_frequencies``.
    """

    def __init__(self, cell_data: CellData):
        super().__init__()
        self.data = cell_data

    def get_baseline_frequency(self):
        """Return the median over traces of the mean pre-stimulus frequency.

        For each trace from ``data.get_mean_isi_frequencies()``, the samples are
        averaged over a window that starts 0.025 after the trace start and ends
        0.025 before the end of the delay. The 0.025 margin is in the units of
        ``get_sampling_interval()``, presumably seconds (i.e. 25 ms).
        """
        if self.baseline_frequency == -1:
            # Delay and sampling interval do not depend on the trace, so they
            # are read (and the short-delay warning is emitted) only once.
            delay = self.data.get_delay()
            sampling_interval = self.data.get_sampling_interval()
            if delay < 0.1:
                warn("BaselineCellData:get_baseline_frequency(): Quite short delay at the start.")

            idx_start = int(0.025 / sampling_interval)
            idx_end = int((delay - 0.025) / sampling_interval)

            base_freqs = [np.mean(freq[idx_start:idx_end])
                          for freq in self.data.get_mean_isi_frequencies()]

            self.baseline_frequency = np.median(base_freqs)

        return self.baseline_frequency

    def get_vector_strength(self):
        """Return the vector strength computed from the baseline EOD and v1 traces."""
        if self.vector_strength == -1:
            times = self.data.get_base_traces(self.data.TIME)
            eods = self.data.get_base_traces(self.data.EOD)
            v1_traces = self.data.get_base_traces(self.data.V1)
            self.vector_strength = hF.calculate_vector_strength_from_v1_trace(times, eods, v1_traces)

        return self.vector_strength

    def get_serial_correlation(self, max_lag):
        """Return the serial correlation for lags up to ``max_lag``, averaged over all baseline spike trains.

        The result is recomputed whenever the cached value's length differs
        from ``max_lag``.
        """
        if len(self.serial_correlation) != max_lag:
            serial_cors = np.array([hF.calculate_serial_correlation(spiketimes, max_lag)
                                    for spiketimes in self.data.get_base_spikes()])
            self.serial_correlation = np.mean(serial_cors, axis=0)

        return self.serial_correlation

    def get_coefficient_of_variation(self):
        """Return the mean over baseline spike trains of the per-train ISI coefficient of variation."""
        if self.coefficient_of_variation == -1:
            cvs = [hF.calculate_coefficient_of_variation(np.array(st))
                   for st in self.data.get_base_spikes()]
            self.coefficient_of_variation = np.mean(cvs)

        return self.coefficient_of_variation

    def get_interspike_intervals(self):
        """Return the ISIs of all baseline spike trains, concatenated into one numpy array.

        Intervals are only taken within a trace, never across two traces. An
        ndarray is returned, matching ``BaselineModel``, so that callers can
        scale the result arithmetically.
        """
        isis = [np.diff(np.asarray(st)) for st in self.data.get_base_spikes()]
        if not isis:
            return np.array([])
        return np.concatenate(isis)

    def plot_baseline(self, save_path=None):
        """Plot EOD, v1, spike raster and frequency trace of every baseline trace.

        :param save_path: if given, save the figure there instead of showing it
        """
        times = self.data.get_base_traces(self.data.TIME)
        eods = self.data.get_base_traces(self.data.EOD)
        v1_traces = self.data.get_base_traces(self.data.V1)
        spiketimes = self.data.get_base_spikes()
        sampling_interval = self.data.get_sampling_interval()

        # sharex must be a bool (or 'all'/'row'/'col'/'none'); the string "True" is rejected.
        _, axes = plt.subplots(4, 1, sharex=True)

        for time, eod, v1, spikes in zip(times, eods, v1_traces, spiketimes):
            axes[0].plot(time, eod)
            axes[1].plot(time, v1)
            # Plot only this trace's spikes, not every spike train on each iteration.
            axes[2].plot(spikes, [1] * len(spikes), 'o')

            t, f = hF.calculate_time_and_frequency_trace(spikes, sampling_interval)
            axes[3].plot(t, f)

        if save_path is not None:
            plt.savefig(save_path)
        else:
            plt.show()

        plt.close()
|
|
|
|
|
|
class BaselineModel(Baseline):
    """Baseline statistics of a LIFAC noise model driven by a sinusoidal EOD stimulus.

    The model is simulated once in ``__init__`` for ``simulation_time``. All
    statistics are then computed lazily from that single run and cached.
    """

    # Duration passed to model.simulate_fast and used for the time axis.
    # Presumably seconds; TODO confirm against LifacNoiseModel.
    simulation_time = 30

    def __init__(self, model: LifacNoiseModel, eod_frequency):
        super().__init__()
        self.model = model
        self.eod_frequency = eod_frequency
        # NOTE(review): the second argument 0 presumably means zero contrast,
        # i.e. the unmodulated EOD. Confirm against SinusoidalStepStimulus.
        self.stimulus = SinusoidalStepStimulus(eod_frequency, 0)
        self.v1, self.spiketimes = model.simulate_fast(self.stimulus, self.simulation_time)
        sampling_interval = model.get_sampling_interval()
        self.eod = self.stimulus.as_array(0, self.simulation_time, sampling_interval)
        self.time = np.arange(0, self.simulation_time, sampling_interval)

    def get_baseline_frequency(self):
        """Return the mean ISI frequency of the simulated spike train (cached)."""
        if self.baseline_frequency == -1:
            self.baseline_frequency = hF.calculate_mean_isi_freq(self.spiketimes)

        return self.baseline_frequency

    def get_vector_strength(self):
        """Return the vector strength of the simulated spikes relative to the EOD (cached)."""
        if self.vector_strength == -1:
            self.vector_strength = hF.calculate_vector_strength_from_spiketimes(self.time, self.eod, self.spiketimes,
                                                                                self.model.get_sampling_interval())
        return self.vector_strength

    def get_serial_correlation(self, max_lag):
        """Return the ISI serial correlation for lags up to ``max_lag``.

        The result is recomputed whenever the cached value's length differs
        from ``max_lag``.
        """
        if len(self.serial_correlation) != max_lag:
            self.serial_correlation = hF.calculate_serial_correlation(self.spiketimes, max_lag)
        return self.serial_correlation

    def get_coefficient_of_variation(self):
        """Return the ISI coefficient of variation of the simulated spike train (cached)."""
        if self.coefficient_of_variation == -1:
            self.coefficient_of_variation = hF.calculate_coefficient_of_variation(self.spiketimes)
        return self.coefficient_of_variation

    def get_interspike_intervals(self):
        """Return the interspike intervals of the simulated spike train as a numpy array."""
        return np.diff(self.spiketimes)

    def plot_baseline(self, save_path=None):
        """Plot EOD, v1, spike raster and frequency trace of the simulation.

        :param save_path: if given, save the figure there instead of showing it
        """
        # sharex must be a bool (or 'all'/'row'/'col'/'none'); the string "True" is rejected.
        _, axes = plt.subplots(4, 1, sharex=True)

        axes[0].plot(self.time, self.eod)
        axes[1].plot(self.time, self.v1)
        axes[2].plot(self.spiketimes, [1] * len(self.spiketimes), 'o')

        t, f = hF.calculate_time_and_frequency_trace(self.spiketimes, self.model.get_sampling_interval())
        axes[3].plot(t, f)

        if save_path is not None:
            plt.savefig(save_path)
        else:
            plt.show()

        plt.close()
|
|
|
|
|
|
def get_baseline_class(data, eod_freq=None) -> Baseline:
    """Build the Baseline implementation that matches the type of ``data``.

    :param data: a ``CellData`` (recorded cell) or a ``LifacNoiseModel``
    :param eod_freq: EOD frequency; required only when ``data`` is a model
    :return: a ``BaselineCellData`` or ``BaselineModel`` instance
    :raises ValueError: if ``data`` is a model and ``eod_freq`` is None, or if
        ``data`` has an unsupported type
    """
    if isinstance(data, CellData):
        baseline = BaselineCellData(data)
    elif isinstance(data, LifacNoiseModel):
        if eod_freq is None:
            raise ValueError("The EOD frequency is needed for the BaselineModel Class.")
        baseline = BaselineModel(data, eod_freq)
    else:
        raise ValueError(f"Unknown type: Cannot find corresponding Baseline class. data was type:{type(data)}")
    return baseline
|