P-unit_model/Baseline.py
2020-07-04 11:28:33 +02:00

400 lines
15 KiB
Python

from CellData import CellData
from models.LIFACnoise import LifacNoiseModel
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
import helperFunctions as hF
import numpy as np
import matplotlib.pyplot as plt
import pickle
from os.path import join, exists
class Baseline:
    """Abstract base for baseline-activity statistics of a cell or a model.

    Subclasses supply the data access (spike times, traces, EOD frequency) by
    overriding the ``get_*`` methods. This class provides the shared
    computations on such data, plotting helpers and (de)serialization of the
    computed values. Cached values use ``-1`` (or an empty list for the serial
    correlation) as the "not yet computed/loaded" sentinel.
    """

    def __init__(self):
        self.save_file_name = "baseline_values.pkl"
        # -1 / [] mark values that have not been computed or loaded yet.
        self.baseline_frequency = -1
        self.serial_correlation = []
        self.vector_strength = -1
        self.coefficient_of_variation = -1
        self.burstiness = -1

    def get_baseline_frequency(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_serial_correlation(self, max_lag):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_vector_strength(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_coefficient_of_variation(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_burstiness(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def __get_burstiness__(self, eod_freq):
        """Return the fraction of interspike intervals shorter than 1.5 EOD periods.

        :param eod_freq: EOD frequency; the burst threshold is ``1.5 / eod_freq``
            in the time unit of the ISIs.
        :return: fraction in [0, 1]; 0 when there are no ISIs.
        """
        isis = np.array(self.get_interspike_intervals())
        if len(isis) == 0:
            return 0
        bursts = isis[isis < 1.5 * (1.0 / eod_freq)]
        return len(bursts) / float(len(isis))

    def get_interspike_intervals(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_spiketime_phases(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def plot_baseline(self, save_path=None, position=0.5, time_length=0.2):
        # Signature matches the subclass overrides (save_path, position, time_length).
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    @staticmethod
    def _get_baseline_frequency_given_data(spiketimes):
        """Return the median over trials of ``hF.calculate_mean_isi_freq`` per trial."""
        return np.median([hF.calculate_mean_isi_freq(st) for st in spiketimes])

    @staticmethod
    def _get_serial_correlation_given_data(max_lag, spikestimes):
        """Return the trial average of ``hF.calculate_serial_correlation(st, max_lag)``.

        :param max_lag: number of lags passed to the helper.
        :param spikestimes: iterable of per-trial spike-time sequences.
        :return: element-wise mean over trials (numpy array).
        """
        serial_cors = np.array([hF.calculate_serial_correlation(st, max_lag) for st in spikestimes])
        return np.mean(serial_cors, axis=0)

    @staticmethod
    def _get_vector_strength_given_data(times, eods, spiketimes, sampling_interval):
        """Return the mean over trials of the per-trial vector strength.

        ``times[i]``, ``eods[i]`` and ``spiketimes[i]`` belong to trial ``i``.
        """
        vs_per_trial = [
            hF.calculate_vector_strength_from_spiketimes(times[i], eods[i], spiketimes[i], sampling_interval)
            for i in range(len(spiketimes))
        ]
        return np.mean(vs_per_trial)

    @staticmethod
    def _get_coefficient_of_variation_given_data(spiketimes):
        """Return the mean over trials of the ISI coefficient of variation.

        The CV is the ISI standard deviation divided by the mean ISI
        (ISIs being ``np.diff(spiketimes)``).
        """
        return np.mean([hF.calculate_coefficient_of_variation(np.array(st)) for st in spiketimes])

    @staticmethod
    def _get_interspike_intervals_given_data(spiketimes):
        """Return the ISIs of all trials concatenated into one flat list."""
        isis = []
        for st in spiketimes:
            isis.extend(np.diff(np.array(st)))
        return isis

    @staticmethod
    def _plot_baseline_given_data(time, eod, v1, spiketimes, sampling_interval, eod_freq="", save_path=None,
                                  position=0.5, time_length=0.2):
        """Plot a window of the baseline: stimulus/EOD, V1 trace with spikes, and ISI frequency.

        :param time: time axis of the traces (same time frame as ``spiketimes``).
        :param eod: stimulus / EOD trace sampled on ``time``.
        :param v1: V1 trace sampled on ``time``.
        :param spiketimes: spike times of the plotted trial.
        :param sampling_interval: time step between consecutive samples.
        :param eod_freq: text appended to the stimulus y-label.
        :param save_path: if given, the figure is saved to ``save_path + "baseline.png"``;
            otherwise it is shown interactively.
        :param position: relative start of the window in the trace (0 = start, 1 = end).
        :param time_length: length of the plotted window, in the unit of ``time``.
        """
        n_samples = len(time)
        length_data_points = int(time_length / sampling_interval)
        # Clamp the start so that time[start_idx] always exists (e.g. position == 1).
        start_idx = min(max(int(n_samples * position), 0), n_samples - 1)
        end_idx = min(int(n_samples * position + length_data_points) + 1, n_samples)

        spiketimes = np.array(spiketimes)
        # time[end_idx] does not exist when the window reaches the end of the trace;
        # in that case every spike after the window start is kept.
        upper_time = time[end_idx] if end_idx < n_samples else np.inf
        spiketimes_part = spiketimes[(spiketimes >= time[start_idx]) & (spiketimes < upper_time)]

        fig, axes = plt.subplots(3, 1, sharex="col", figsize=(12, 8))
        fig.suptitle("Baseline middle part ({:.2f} seconds)".format(time_length))

        axes[0].plot(time[start_idx:end_idx], eod[start_idx:end_idx])
        axes[0].set_ylabel("Stimulus [mV] - Freq:" + eod_freq)

        v1_part = v1[start_idx:end_idx]
        axes[1].plot(time[start_idx:end_idx], v1_part)
        # Spikes are marked as dots at the maximum of the visible V1 trace.
        axes[1].plot(spiketimes_part, [max(v1_part)] * len(spiketimes_part), 'o', color='orange')
        axes[1].set_ylabel("V1-Trace [mV]")

        t, f = hF.calculate_time_and_frequency_trace(spiketimes_part, sampling_interval)
        axes[2].plot(t, f)
        axes[2].set_ylabel("ISI-Frequency [Hz]")
        axes[2].set_xlabel("Time [s]")

        if save_path is not None:
            plt.savefig(save_path + "baseline.png")
        else:
            plt.show()
        plt.close()

    @staticmethod
    def plot_isi_histogram_comparision(cell_isis, model_isis, save_path=None):
        """Plot overlaid, density-normalized ISI histograms (in ms) of a cell and a model.

        :param cell_isis: ISIs of the cell (converted to ms by multiplying by 1000).
        :param model_isis: ISIs of the model (same unit as ``cell_isis``).
        :param save_path: prefix for ``"isi-histogram_comparision.png"``; shown if None.
        """
        cell_isis = np.array(cell_isis) * 1000
        model_isis = np.array(model_isis) * 1000
        all_isis = np.concatenate((cell_isis, model_isis))
        if all_isis.size == 0:
            # Nothing to bin; fall back to the same default range as the single histogram.
            bins = np.arange(0, 1, 0.1)
        else:
            bins = np.arange(0, all_isis.max() * 1.01, 0.1)
        plt.title('Baseline ISIs')
        plt.xlabel('ISI in ms')
        plt.ylabel('Count')
        plt.hist(cell_isis, bins=bins, label="cell", alpha=0.5, density=True)
        plt.hist(model_isis, bins=bins, label="model", alpha=0.5, density=True)
        plt.legend()
        if save_path is not None:
            plt.savefig(save_path + "isi-histogram_comparision.png")
        else:
            plt.show()
        plt.close()

    def plot_polar_vector_strength(self, save_path=None):
        """Plot a polar histogram of the spike phases from ``get_spiketime_phases``.

        Bins are 0.1 rad wide, starting at 0 and ending below 2*pi.
        """
        phases = self.get_spiketime_phases()
        fig = plt.figure()
        ax = fig.add_subplot(111, polar=True)
        bins = np.arange(0, np.pi * 2, 0.1)
        ax.hist(phases, bins=bins)
        if save_path is not None:
            plt.savefig(save_path + "vector_strength_polar_plot.png")
        else:
            plt.show()
        plt.close()

    def plot_interspike_interval_histogram(self, save_path=None):
        """Plot a histogram of the ISIs in ms (0.1 ms bins); flags the no-spike case in the title."""
        isi = np.array(self.get_interspike_intervals()) * 1000  # change unit to milliseconds
        if len(isi) == 0:
            print("NON SPIKES IN BASELINE OF CELL/MODEL")
            title = 'Baseline ISIs - NO SPIKES!'
            bins = np.arange(0, 1, 0.1)
        else:
            title = 'Baseline ISIs'
            bins = np.arange(0, max(isi) * 1.01, 0.1)
        plt.title(title)
        plt.xlabel('ISI in ms')
        plt.ylabel('Count')
        plt.hist(isi, bins=bins)
        if save_path is not None:
            plt.savefig(save_path + "isi-histogram.png")
        else:
            plt.show()
        plt.close()

    def plot_serial_correlation(self, max_lag, save_path=None):
        """Plot the serial correlation for lags 1..max_lag (y-axis fixed to [-1, 1])."""
        plt.title("Baseline Serial correlation")
        plt.xlabel("Lag")
        plt.ylabel("Correlation")
        plt.ylim((-1, 1))
        plt.plot(np.arange(1, max_lag + 1, 1), self.get_serial_correlation(max_lag))
        if save_path is not None:
            plt.savefig(save_path + "serial_correlation.png")
        else:
            plt.show()
        plt.close()

    def save_values(self, save_directory, max_lag=10):
        """Compute all baseline values and pickle them to ``save_directory/save_file_name``.

        :param save_directory: directory to write the pickle file into.
        :param max_lag: number of serial-correlation lags to store (default 10).
        """
        values = {
            "baseline_frequency": self.get_baseline_frequency(),
            "serial correlation": self.get_serial_correlation(max_lag=max_lag),
            "vector strength": self.get_vector_strength(),
            "coefficient of variation": self.get_coefficient_of_variation(),
            "burstiness": self.get_burstiness(),
        }
        with open(join(save_directory, self.save_file_name), "wb") as values_file:
            pickle.dump(values, values_file)
        print("Baseline: Values saved!")

    def load_values(self, save_directory):
        """Load values previously written by ``save_values`` into the cache attributes.

        The file is unpickled, so only load files this program wrote itself
        (pickle must never be used on untrusted data).

        :return: True if the file existed and was loaded, False otherwise.
        """
        file_path = join(save_directory, self.save_file_name)
        if not exists(file_path):
            print("Baseline: No file to load")
            return False
        with open(file_path, "rb") as values_file:
            values = pickle.load(values_file)
        self.baseline_frequency = values["baseline_frequency"]
        self.serial_correlation = values["serial correlation"]
        self.vector_strength = values["vector strength"]
        self.coefficient_of_variation = values["coefficient of variation"]
        self.burstiness = values["burstiness"]
        print("Baseline: Values loaded!")
        return True
class BaselineCellData(Baseline):
    """Baseline statistics of a recorded cell, computed lazily from ``CellData`` and cached."""

    def __init__(self, cell_data: CellData):
        super().__init__()
        self.data = cell_data

    def get_baseline_frequency(self):
        """Return the median over trials of the mean ISI frequency (cached)."""
        if self.baseline_frequency == -1:
            spiketimes = self.data.get_base_spikes()
            self.baseline_frequency = self._get_baseline_frequency_given_data(spiketimes)
        return self.baseline_frequency

    def get_vector_strength(self):
        """Return the mean over trials of the vector strength w.r.t. the recorded EOD (cached)."""
        if self.vector_strength == -1:
            times = self.data.get_base_traces(self.data.TIME)
            eods = self.data.get_base_traces(self.data.EOD)
            spiketimes = self.data.get_base_spikes()
            sampling_interval = self.data.get_sampling_interval()
            self.vector_strength = self._get_vector_strength_given_data(times, eods, spiketimes, sampling_interval)
        return self.vector_strength

    def get_serial_correlation(self, max_lag):
        """Return the trial-averaged serial correlation for the first ``max_lag`` lags.

        Recomputes only when fewer than ``max_lag`` lags are cached; otherwise
        the cached values are sliced.
        """
        if len(self.serial_correlation) < max_lag:
            self.serial_correlation = self._get_serial_correlation_given_data(max_lag, self.data.get_base_spikes())
        return self.serial_correlation[:max_lag]

    def get_coefficient_of_variation(self):
        """Return the mean over trials of the ISI coefficient of variation (cached)."""
        if self.coefficient_of_variation == -1:
            self.coefficient_of_variation = self._get_coefficient_of_variation_given_data(self.data.get_base_spikes())
        return self.coefficient_of_variation

    def get_interspike_intervals(self):
        """Return the ISIs of all baseline trials as one flat list."""
        return self._get_interspike_intervals_given_data(self.data.get_base_spikes())

    def get_spiketime_phases(self):
        """Return the phase (radians) of every baseline spike relative to the EOD, pooled over trials.

        The phase is ``rel_spikes / eod_durs * 2 * pi`` as returned by
        ``hF.eods_around_spikes``; presumably the spike's offset within its
        EOD cycle divided by that cycle's duration (TODO confirm in helper).
        """
        times = self.data.get_base_traces(self.data.TIME)
        spiketimes = self.data.get_base_spikes()
        eods = self.data.get_base_traces(self.data.EOD)
        sampling_interval = self.data.get_sampling_interval()

        phase_list = []
        for i, time in enumerate(times):
            # Spike times live on the same time axis as the trace, so the sample
            # index of spike s is (s - time[0]) / sampling_interval.
            spiketime_indices = np.array(
                np.around((np.array(spiketimes[i]) - time[0]) / sampling_interval), dtype=int)
            rel_spikes, eod_durs = hF.eods_around_spikes(time, eods[i], spiketime_indices)
            phase_list.extend((rel_spikes / eod_durs) * 2 * np.pi)
        return phase_list

    def get_burstiness(self):
        """Return the fraction of ISIs shorter than 1.5 EOD periods (cached)."""
        if self.burstiness == -1:
            self.burstiness = self.__get_burstiness__(self.data.get_eod_frequency())
        return self.burstiness

    def plot_baseline(self, save_path=None, position=0.5, time_length=0.2):
        """Plot EOD, V1 trace, spikes and ISI frequency of the first baseline trial."""
        time = self.data.get_base_traces(self.data.TIME)[0]
        eod = self.data.get_base_traces(self.data.EOD)[0]
        v1_trace = self.data.get_base_traces(self.data.V1)[0]
        spiketimes = self.data.get_base_spikes()[0]
        self._plot_baseline_given_data(time, eod, v1_trace, spiketimes,
                                       self.data.get_sampling_interval(),
                                       "{:.0f}".format(self.data.get_eod_frequency()),
                                       save_path, position, time_length)
class BaselineModel(Baseline):
    """Baseline statistics of a ``LifacNoiseModel`` stimulated with an EOD-like sinusoid.

    On construction the model is simulated ``trials`` times for
    ``simulation_time`` (in the unit of the time axis, presumably seconds)
    and the V1 traces and spike times are kept. All statistics are computed
    lazily from those simulations and cached.
    """

    simulation_time = 30

    def __init__(self, model: LifacNoiseModel, eod_frequency, trials=1):
        """
        :param model: model to simulate.
        :param eod_frequency: frequency of the sinusoidal stimulus.
        :param trials: number of independent simulations to run.
        """
        super().__init__()
        self.model = model
        self.eod_frequency = eod_frequency
        sampling_interval = model.get_sampling_interval()
        # NOTE(review): the second argument 0 is presumably a zero-contrast step,
        # i.e. the unmodulated EOD - confirm against SinusoidalStepStimulus.
        self.stimulus = SinusoidalStepStimulus(eod_frequency, 0)
        self.eod = self.stimulus.as_array(0, self.simulation_time, sampling_interval)
        self.time = np.arange(0, self.simulation_time, sampling_interval)
        self.v1_traces = []
        self.spiketimes = []
        for _ in range(trials):
            v1, spikes = model.simulate_fast(self.stimulus, self.simulation_time)
            self.v1_traces.append(v1)
            self.spiketimes.append(spikes)

    def get_baseline_frequency(self):
        """Return the median over trials of the mean ISI frequency (cached)."""
        if self.baseline_frequency == -1:
            self.baseline_frequency = self._get_baseline_frequency_given_data(self.spiketimes)
        return self.baseline_frequency

    def get_vector_strength(self):
        """Return the mean over trials of the vector strength w.r.t. the stimulus (cached)."""
        if self.vector_strength == -1:
            # Every trial shares the same time axis and stimulus.
            times = [self.time] * len(self.spiketimes)
            eods = [self.eod] * len(self.spiketimes)
            sampling_interval = self.model.get_sampling_interval()
            self.vector_strength = self._get_vector_strength_given_data(times, eods, self.spiketimes, sampling_interval)
        return self.vector_strength

    def get_serial_correlation(self, max_lag):
        """Return the trial-averaged serial correlation for the first ``max_lag`` lags.

        Same caching policy as the cell-data variant: recompute only when fewer
        than ``max_lag`` lags are cached, otherwise slice the cached values.
        """
        if len(self.serial_correlation) < max_lag:
            self.serial_correlation = self._get_serial_correlation_given_data(max_lag, self.spiketimes)
        return self.serial_correlation[:max_lag]

    def get_coefficient_of_variation(self):
        """Return the mean over trials of the ISI coefficient of variation (cached)."""
        if self.coefficient_of_variation == -1:
            self.coefficient_of_variation = self._get_coefficient_of_variation_given_data(self.spiketimes)
        return self.coefficient_of_variation

    def get_interspike_intervals(self):
        """Return the ISIs of all simulated trials as one flat list."""
        return self._get_interspike_intervals_given_data(self.spiketimes)

    def get_burstiness(self):
        """Return the fraction of ISIs shorter than 1.5 stimulus periods (cached)."""
        if self.burstiness == -1:
            self.burstiness = self.__get_burstiness__(self.eod_frequency)
        return self.burstiness

    def get_spiketime_phases(self):
        """Return the phase (radians) of every simulated spike relative to the stimulus, pooled over trials."""
        sampling_interval = self.model.get_sampling_interval()
        t0 = self.time[0]
        phase_list = []
        for trial_spikes in self.spiketimes:
            # Sample index of spike s on self.time is (s - t0) / sampling_interval.
            spiketime_indices = np.array(
                np.around((np.array(trial_spikes) - t0) / sampling_interval), dtype=int)
            rel_spikes, eod_durs = hF.eods_around_spikes(self.time, self.eod, spiketime_indices)
            phase_list.extend((rel_spikes / eod_durs) * 2 * np.pi)
        return phase_list

    def plot_baseline(self, save_path=None, position=0.5, time_length=0.2):
        """Plot stimulus, V1 trace, spikes and ISI frequency of the first simulated trial."""
        self._plot_baseline_given_data(self.time, self.eod, self.v1_traces[0], self.spiketimes[0],
                                       self.model.get_sampling_interval(), "{:.0f}".format(self.eod_frequency),
                                       save_path, position, time_length)
def get_baseline_class(data, eod_freq=None, trials=1) -> Baseline:
    """Build the Baseline implementation matching the type of ``data``.

    :param data: a ``CellData`` recording or a ``LifacNoiseModel``.
    :param eod_freq: stimulus frequency; required for models, ignored for cell data.
    :param trials: number of simulations for models; ignored for cell data.
    :raises ValueError: if ``eod_freq`` is missing for a model, or ``data`` has an unsupported type.
    """
    if isinstance(data, CellData):
        baseline = BaselineCellData(data)
    elif isinstance(data, LifacNoiseModel):
        if eod_freq is None:
            raise ValueError("The EOD frequency is needed for the BaselineModel Class.")
        baseline = BaselineModel(data, eod_freq, trials=trials)
    else:
        raise ValueError("Unknown type: Cannot find corresponding Baseline class. data was type:" + str(type(data)))
    return baseline