P-unit_model/Baseline.py
2020-05-11 15:28:03 +02:00

232 lines
8.0 KiB
Python

from CellData import CellData
from models.LIFACnoise import LifacNoiseModel
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
import helperFunctions as hF
import numpy as np
from warnings import warn
import matplotlib.pyplot as plt
class Baseline:
    """Abstract base for the baseline (unstimulated) firing characteristics of a cell or model.

    Subclasses implement the measurement getters and ``plot_baseline``.
    This class holds the lazily-filled result caches and provides the ISI
    histogram and serial-correlation plots, which are built on those getters.
    """

    def __init__(self):
        # Sentinels meaning "not computed yet"; the subclass getters fill them lazily.
        self.baseline_frequency = -1
        self.serial_correlation = []
        self.vector_strength = -1
        self.coefficient_of_variation = -1

    def get_baseline_frequency(self):
        """Return the baseline firing frequency. Must be overridden."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_serial_correlation(self, max_lag):
        """Return one serial-correlation value per lag 1..max_lag. Must be overridden."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_vector_strength(self):
        """Return the vector strength of the spikes relative to the EOD. Must be overridden."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_coefficient_of_variation(self):
        """Return the coefficient of variation of the interspike intervals. Must be overridden."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_interspike_intervals(self):
        """Return the interspike intervals in seconds. Must be overridden."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def plot_baseline(self, save_path=None):
        """Plot the stimulus / EOD together with the V1 trace, spike times and frequency.

        :param save_path: file to save the figure to; if None the figure is shown.
        Must be overridden.
        """
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def _interspike_intervals_ms(self):
        """Return the interspike intervals as a float ndarray in milliseconds.

        Subclasses may return the ISIs as a plain list. Converting to an array
        first makes ``* 1000`` scale the values instead of repeating the list
        1000 times.
        """
        return np.asarray(self.get_interspike_intervals(), dtype=float) * 1000

    def plot_inter_spike_interval_histogram(self, save_path=None, bin_width=0.1):
        """Plot a histogram of the baseline interspike intervals in milliseconds.

        The bins run from 0 to 1.01 times the largest ISI.

        :param save_path: file to save the figure to; if None the figure is shown.
        :param bin_width: histogram bin width in milliseconds (default 0.1 ms).
        :raises ValueError: if there are no interspike intervals (max of an empty array).
        """
        isi = self._interspike_intervals_ms()
        maximum = isi.max()
        bins = np.arange(0, maximum * 1.01, bin_width)
        plt.title('Baseline ISIs')
        plt.xlabel('ISI in ms')
        plt.ylabel('Count')
        plt.hist(isi, bins=bins)
        if save_path is not None:
            plt.savefig(save_path)
        else:
            plt.show()
        plt.close()

    def plot_serial_correlation(self, max_lag, save_path=None):
        """Plot the serial correlation for lags 1..max_lag on a fixed [-1, 1] y axis.

        :param max_lag: largest lag to compute and plot.
        :param save_path: file to save the figure to; if None the figure is shown.
        """
        plt.title("Baseline Serial correlation")
        plt.xlabel("Lag")
        plt.ylabel("Correlation")
        plt.ylim((-1, 1))
        plt.plot(np.arange(1, max_lag + 1, 1), self.get_serial_correlation(max_lag))
        if save_path is not None:
            plt.savefig(save_path)
        else:
            plt.show()
        plt.close()
class BaselineCellData(Baseline):
    """Baseline characteristics computed from recorded cell data.

    Every measure is computed lazily from ``self.data`` and cached in the
    attributes initialised by ``Baseline.__init__``.
    """

    def __init__(self, cell_data: CellData):
        super().__init__()
        self.data = cell_data

    def get_baseline_frequency(self):
        """Return the median, over traces, of the mean ISI frequency before the delay.

        Each trace from ``get_mean_isi_frequencies()`` is averaged over the
        samples from 25 ms to ``get_delay() - 25 ms``. This assumes the trace
        starts at t=0 and is sampled at ``get_sampling_interval()`` (TODO
        confirm). A warning is issued if the delay is shorter than 0.1.
        """
        if self.baseline_frequency == -1:
            # Both values are identical for every trace, so read them once.
            delay = self.data.get_delay()
            sampling_interval = self.data.get_sampling_interval()
            if delay < 0.1:
                warn("BaselineCellData:get_baseline_frequency(): Quite short delay at the start.")
            # Drop 25 ms at both ends of the pre-delay window (presumably to avoid edge effects).
            idx_start = int(0.025 / sampling_interval)
            idx_end = int((delay - 0.025) / sampling_interval)
            base_freqs = [np.mean(freq[idx_start:idx_end])
                          for freq in self.data.get_mean_isi_frequencies()]
            self.baseline_frequency = np.median(base_freqs)
        return self.baseline_frequency

    def get_vector_strength(self):
        """Return the vector strength computed from the baseline time, EOD and V1 traces (cached)."""
        if self.vector_strength == -1:
            times = self.data.get_base_traces(self.data.TIME)
            eods = self.data.get_base_traces(self.data.EOD)
            v1_traces = self.data.get_base_traces(self.data.V1)
            self.vector_strength = hF.calculate_vector_strength_from_v1_trace(times, eods, v1_traces)
        return self.vector_strength

    def get_serial_correlation(self, max_lag):
        """Return the per-lag mean, over baseline spike trains, of the serial correlation.

        The cache is recomputed whenever its length differs from ``max_lag``.
        """
        if len(self.serial_correlation) != max_lag:
            serial_cors = [hF.calculate_serial_correlation(spiketimes, max_lag)
                           for spiketimes in self.data.get_base_spikes()]
            self.serial_correlation = np.mean(np.array(serial_cors), axis=0)
        return self.serial_correlation

    def get_coefficient_of_variation(self):
        """Return the mean, over baseline spike trains, of the ISI coefficient of variation (cached)."""
        if self.coefficient_of_variation == -1:
            cvs = [hF.calculate_coefficient_of_variation(np.array(st))
                   for st in self.data.get_base_spikes()]
            self.coefficient_of_variation = np.mean(cvs)
        return self.coefficient_of_variation

    def get_interspike_intervals(self):
        """Return the ISIs of all baseline spike trains concatenated into one ndarray.

        An ndarray is returned, matching ``BaselineModel``. Arithmetic such as
        ``* 1000`` therefore scales the values rather than repeating a list.
        """
        isis = []
        for st in self.data.get_base_spikes():
            isis.extend(np.diff(np.array(st)))
        return np.array(isis)

    def plot_baseline(self, save_path=None):
        """Plot EOD, V1, spike raster and frequency trace of every baseline trial.

        Four rows share one x axis, and all trials are overlaid.

        :param save_path: file to save the figure to; if None the figure is shown.
        """
        times = self.data.get_base_traces(self.data.TIME)
        eods = self.data.get_base_traces(self.data.EOD)
        v1_traces = self.data.get_base_traces(self.data.V1)
        spiketimes = self.data.get_base_spikes()
        sampling_interval = self.data.get_sampling_interval()

        _, axes = plt.subplots(4, 1, sharex='col')
        for time, eod, v1, spikes in zip(times, eods, v1_traces, spiketimes):
            axes[0].plot(time, eod)
            axes[1].plot(time, v1)
            # Raster row: every spike drawn at y = 1.
            axes[2].plot(spikes, np.ones(len(spikes)), 'o')
            t, f = hF.calculate_time_and_frequency_trace(spikes, sampling_interval)
            axes[3].plot(t, f)

        if save_path is not None:
            plt.savefig(save_path)
        else:
            plt.show()
        plt.close()
class BaselineModel(Baseline):
    """Baseline characteristics of a simulated LIFAC noise model.

    The constructor drives the model with ``SinusoidalStepStimulus(eod_frequency, 0)``
    for ``simulation_time`` time units (presumably seconds — TODO confirm). It
    stores the V1 trace, spike times, EOD array and time axis. All measures are
    then derived lazily from these and cached.
    """

    simulation_time = 30

    def __init__(self, model: LifacNoiseModel, eod_frequency):
        super().__init__()
        self.model = model
        self.eod_frequency = eod_frequency
        stimulus = SinusoidalStepStimulus(eod_frequency, 0)
        self.stimulus = stimulus
        v1, spikes = model.simulate_fast(stimulus, self.simulation_time)
        self.v1 = v1
        self.spiketimes = spikes
        self.eod = stimulus.as_array(0, self.simulation_time, model.get_sampling_interval())
        self.time = np.arange(0, self.simulation_time, model.get_sampling_interval())

    def get_baseline_frequency(self):
        """Return the mean ISI frequency of the simulated spikes (cached)."""
        if self.baseline_frequency != -1:
            return self.baseline_frequency
        self.baseline_frequency = hF.calculate_mean_isi_freq(self.spiketimes)
        return self.baseline_frequency

    def get_vector_strength(self):
        """Return the vector strength of the simulated spikes relative to the EOD (cached)."""
        if self.vector_strength != -1:
            return self.vector_strength
        dt = self.model.get_sampling_interval()
        self.vector_strength = hF.calculate_vector_strength_from_spiketimes(
            self.time, self.eod, self.spiketimes, dt)
        return self.vector_strength

    def get_serial_correlation(self, max_lag):
        """Return the serial correlation for lags 1..max_lag; recomputed when max_lag changes."""
        if len(self.serial_correlation) == max_lag:
            return self.serial_correlation
        self.serial_correlation = hF.calculate_serial_correlation(self.spiketimes, max_lag)
        return self.serial_correlation

    def get_coefficient_of_variation(self):
        """Return the coefficient of variation of the simulated ISIs (cached)."""
        if self.coefficient_of_variation != -1:
            return self.coefficient_of_variation
        self.coefficient_of_variation = hF.calculate_coefficient_of_variation(self.spiketimes)
        return self.coefficient_of_variation

    def get_interspike_intervals(self):
        """Return the differences between consecutive simulated spike times as an ndarray."""
        spikes = self.spiketimes
        return np.diff(spikes)

    def plot_baseline(self, save_path=None):
        """Plot EOD, V1, spike raster and frequency trace of the simulation.

        Four rows share one x axis.

        :param save_path: file to save the figure to; if None the figure is shown.
        """
        _, axes = plt.subplots(4, 1, sharex="col")
        eod_ax, v1_ax, spike_ax, freq_ax = axes
        eod_ax.plot(self.time, self.eod)
        v1_ax.plot(self.time, self.v1)
        spike_ax.plot(self.spiketimes, [1] * len(self.spiketimes), 'o')
        t, f = hF.calculate_time_and_frequency_trace(self.spiketimes, self.model.get_sampling_interval())
        freq_ax.plot(t, f)
        if save_path is None:
            plt.show()
        else:
            plt.savefig(save_path)
        plt.close()
def get_baseline_class(data, eod_freq=None) -> Baseline:
    """Return the Baseline implementation matching the type of ``data``.

    :param data: a CellData (recorded cell) or a LifacNoiseModel (simulation).
    :param eod_freq: EOD frequency; required when ``data`` is a LifacNoiseModel.
    :raises ValueError: if ``data`` has an unsupported type, or if ``eod_freq``
        is missing for a model.
    """
    if isinstance(data, CellData):
        return BaselineCellData(data)
    if not isinstance(data, LifacNoiseModel):
        raise ValueError("Unknown type: Cannot find corresponding Baseline class. data was type:" + str(type(data)))
    if eod_freq is None:
        raise ValueError("The EOD frequency is needed for the BaselineModel Class.")
    return BaselineModel(data, eod_freq)