# P-unit_model/Baseline.py
#
# 357 lines
# 13 KiB
# Python

from CellData import CellData
from models.LIFACnoise import LifacNoiseModel
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
import helperFunctions as hF
import numpy as np
import matplotlib.pyplot as plt
class Baseline:
    """Abstract base for baseline (spontaneous activity) statistics.

    Subclasses provide the spike data by overriding the ``get_*`` accessors.
    This class supplies the shared computations (burstiness and per-trial
    aggregation helpers) and the plotting routines. Computed statistics are
    cached on the instance; ``-1`` (or an empty list for the serial
    correlation) means "not computed yet".
    """

    def __init__(self):
        # Cache slots: -1 / [] mark values that have not been computed yet.
        self.baseline_frequency = -1
        self.serial_correlation = []
        self.vector_strength = -1
        self.coefficient_of_variation = -1

    def get_baseline_frequency(self):
        """Return the baseline firing frequency (implemented by subclasses)."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_serial_correlation(self, max_lag):
        """Return the serial correlation for lags 1..max_lag (implemented by subclasses)."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_vector_strength(self):
        """Return the vector strength (implemented by subclasses)."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_coefficient_of_variation(self):
        """Return the ISI coefficient of variation (implemented by subclasses)."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_burstiness(self, bin_width=0.1, min_peak_count=10):
        """Return the fraction of ISIs that lie in the first peak of the ISI histogram.

        The ISIs are converted from seconds to ms and binned into half-open
        bins ``[start, start + bin_width)`` covering ``[0, 3 * min(ISI))``.
        The peak is the first bin that holds more than ``min_peak_count``
        ISIs and is not followed by a larger bin. The peak ends at the first
        later bin where the counts start to rise again. The result is the
        number of ISIs in all bins before that point divided by the total
        number of ISIs.

        :param bin_width: histogram bin width in ms (default 0.1).
        :param min_peak_count: a peak bin must contain more ISIs than this.
        :return: burstiness in [0, 1]. Returns 0 when there are 10 or fewer
            ISIs, when no bin qualifies as a peak, or when the peak does not
            end inside the binned range.
        """
        isis = np.asarray(self.get_interspike_intervals(), dtype=float) * 1000  # s -> ms
        if len(isis) <= 10:
            return 0

        bin_starts = np.arange(0, isis.min() * 3, bin_width)
        # Count (not sum) the ISIs that fall into each bin.
        counts = np.array([np.count_nonzero((isis >= start) & (isis < start + bin_width))
                           for start in bin_starts], dtype=float)
        # No bins at all when min(ISI) == 0; guard before calling max().
        if counts.size == 0 or counts.max() <= min_peak_count:
            return 0

        end_of_peak = self._end_of_first_peak(counts, min_peak_count)
        if end_of_peak is None:
            return 0
        return counts[:end_of_peak].sum() / len(isis)

    @staticmethod
    def _end_of_first_peak(counts, min_peak_count):
        """Return the index just after the first histogram peak, or None.

        A peak is the first bin whose count exceeds ``min_peak_count`` and is
        not smaller than the following bin. Its end is the index of the first
        later bin whose count rises above its predecessor. Returns None when
        no peak is found or the peak never ends before the last bin.
        """
        peak_found = False
        for i in range(len(counts) - 1):
            current, following = counts[i], counts[i + 1]
            if not peak_found:
                if following <= current and current > min_peak_count:
                    peak_found = True
            elif following > current:
                return i + 1
        return None

    def get_interspike_intervals(self):
        """Return all interspike intervals (seconds) pooled over trials (implemented by subclasses)."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_spiketime_phases(self):
        """Return the spike phases relative to the EOD (implemented by subclasses)."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def plot_baseline(self, save_path=None, time_length=0.2):
        """Plot a window of the baseline recording (implemented by subclasses)."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    @staticmethod
    def _get_baseline_frequency_given_data(spiketimes):
        """Return the median over trials of ``hF.calculate_mean_isi_freq``.

        :param spiketimes: one sequence of spike times per trial.
        """
        return np.median([hF.calculate_mean_isi_freq(st) for st in spiketimes])

    @staticmethod
    def _get_serial_correlation_given_data(max_lag, spikestimes):
        """Return the trial-averaged serial correlation (element-wise mean over trials).

        :param max_lag: largest lag passed to ``hF.calculate_serial_correlation``.
        :param spikestimes: one sequence of spike times per trial.
        """
        serial_cors = np.array([hF.calculate_serial_correlation(st, max_lag) for st in spikestimes])
        return np.mean(serial_cors, axis=0)

    @staticmethod
    def _get_vector_strength_given_data(times, eods, spiketimes, sampling_interval):
        """Return the mean over trials of the per-trial vector strength.

        ``times[i]`` and ``eods[i]`` must belong to the trial ``spiketimes[i]``.
        """
        vs_per_trial = [
            hF.calculate_vector_strength_from_spiketimes(times[i], eods[i], trial_spikes, sampling_interval)
            for i, trial_spikes in enumerate(spiketimes)
        ]
        return np.mean(vs_per_trial)

    @staticmethod
    def _get_coefficient_of_variation_given_data(spiketimes):
        """Return the mean over trials of the ISI coefficient of variation.

        The per-trial CV (std of the ISIs divided by their mean) is computed by
        ``hF.calculate_coefficient_of_variation``.
        """
        return np.mean([hF.calculate_coefficient_of_variation(np.array(st)) for st in spiketimes])

    @staticmethod
    def _get_interspike_intervals_given_data(spiketimes):
        """Return the ISIs of all trials concatenated into one flat list."""
        isis = []
        for st in spiketimes:
            isis.extend(np.diff(np.asarray(st)))
        return isis

    @staticmethod
    def _plot_baseline_given_data(time, eod, v1, spiketimes, sampling_interval, eod_freq="", save_path=None,
                                  position=0.5, time_length=0.2):
        """Plot the stimulus / EOD, the V1 trace with spikes, and the ISI frequency for one window.

        :param time: sample times of the traces.
        :param eod: stimulus / EOD trace sampled at ``time``.
        :param v1: V1 trace sampled at ``time``.
        :param spiketimes: spike times in the same time base as ``time``.
        :param sampling_interval: sampling interval of the traces in seconds.
        :param eod_freq: text appended to the stimulus axis label.
        :param save_path: prefix for "baseline.png"; the figure is shown when None.
        :param position: window start as a fraction of the trace length.
        :param time_length: window length in seconds.
        """
        n_samples = len(time)
        length_data_points = int(time_length / sampling_interval)
        # Clamp the window into the trace so that every index used below is valid.
        start_idx = min(max(int(n_samples * position), 0), n_samples - 1)
        end_idx = min(start_idx + length_data_points + 1, n_samples)

        spiketimes = np.asarray(spiketimes)
        # The upper bound is the first sample after the window; at the end of
        # the trace there is no such sample, so keep all remaining spikes.
        t_stop = time[end_idx] if end_idx < n_samples else np.inf
        spiketimes_part = spiketimes[(spiketimes >= time[start_idx]) & (spiketimes < t_stop)]

        window = slice(start_idx, end_idx)
        fig, axes = plt.subplots(3, 1, sharex="col", figsize=(12, 8))
        fig.suptitle("Baseline middle part ({:.2f} seconds)".format(time_length))

        axes[0].plot(time[window], eod[window])
        axes[0].set_ylabel("Stimulus [mV] - Freq:" + eod_freq)

        axes[1].plot(time[window], v1[window])
        # Mark every spike at the top of the visible V1 trace.
        axes[1].plot(spiketimes_part, np.full(len(spiketimes_part), np.max(v1[window])),
                     'o', color='orange')
        axes[1].set_ylabel("V1-Trace [mV]")

        t, f = hF.calculate_time_and_frequency_trace(spiketimes_part, sampling_interval)
        axes[2].plot(t, f)
        axes[2].set_ylabel("ISI-Frequency [Hz]")
        axes[2].set_xlabel("Time [s]")

        if save_path is not None:
            plt.savefig(save_path + "baseline.png")
        else:
            plt.show()
        plt.close()

    def plot_polar_vector_strength(self, save_path=None):
        """Plot a polar histogram of the spike phases relative to the EOD.

        :param save_path: prefix for "vector_strength_polar.png"; the figure is shown when None.
        """
        phases = self.get_spiketime_phases()
        fig = plt.figure()
        ax = fig.add_subplot(111, polar=True)
        # 63 equal-width bins that cover the whole circle [0, 2*pi].
        ax.hist(phases, bins=np.linspace(0, 2 * np.pi, 64))
        if save_path is not None:
            # Use a distinct file name so this plot does not overwrite the ISI histogram.
            plt.savefig(save_path + "vector_strength_polar.png")
        else:
            plt.show()
        plt.close()

    def plot_interspike_interval_histogram(self, save_path=None):
        """Plot a histogram of all ISIs in ms with 0.1 ms bins.

        :param save_path: prefix for "isi-histogram.png"; the figure is shown when None.
        """
        isi = np.array(self.get_interspike_intervals()) * 1000  # s -> ms
        bins = np.arange(0, np.max(isi) * 1.01, 0.1)
        plt.title('Baseline ISIs')
        plt.xlabel('ISI in ms')
        plt.ylabel('Count')
        plt.hist(isi, bins=bins)
        if save_path is not None:
            plt.savefig(save_path + "isi-histogram.png")
        else:
            plt.show()
        plt.close()

    def plot_serial_correlation(self, max_lag, save_path=None):
        """Plot the serial correlation for lags 1..max_lag.

        :param save_path: prefix for "serial_correlation.png"; the figure is shown when None.
        """
        plt.title("Baseline Serial correlation")
        plt.xlabel("Lag")
        plt.ylabel("Correlation")
        plt.ylim((-1, 1))
        plt.plot(np.arange(1, max_lag + 1, 1), self.get_serial_correlation(max_lag))
        if save_path is not None:
            plt.savefig(save_path + "serial_correlation.png")
        else:
            plt.show()
        plt.close()
class BaselineCellData(Baseline):
    """Baseline statistics backed by a recorded cell (``CellData``).

    Spike times and traces come from ``get_base_spikes()`` and
    ``get_base_traces(...)`` of the wrapped data object. Frequency, vector
    strength and CV are cached after their first computation. The serial
    correlation is cached per requested lag count.
    """

    def __init__(self, cell_data: CellData):
        super().__init__()
        self.data = cell_data

    def get_baseline_frequency(self):
        """Return the baseline frequency, computing it on first use."""
        if self.baseline_frequency != -1:
            return self.baseline_frequency
        self.baseline_frequency = self._get_baseline_frequency_given_data(self.data.get_base_spikes())
        return self.baseline_frequency

    def get_vector_strength(self):
        """Return the trial-averaged vector strength, computing it on first use."""
        if self.vector_strength != -1:
            return self.vector_strength
        data = self.data
        time_traces = data.get_base_traces(data.TIME)
        eod_traces = data.get_base_traces(data.EOD)
        base_spikes = data.get_base_spikes()
        dt = data.get_sampling_interval()
        self.vector_strength = self._get_vector_strength_given_data(time_traces, eod_traces, base_spikes, dt)
        return self.vector_strength

    def get_serial_correlation(self, max_lag):
        """Return the serial correlation; recomputed when the cached length differs from ``max_lag``."""
        cached = self.serial_correlation
        if len(cached) == max_lag:
            return cached
        self.serial_correlation = self._get_serial_correlation_given_data(max_lag, self.data.get_base_spikes())
        return self.serial_correlation

    def get_coefficient_of_variation(self):
        """Return the trial-averaged ISI coefficient of variation, computing it on first use."""
        if self.coefficient_of_variation != -1:
            return self.coefficient_of_variation
        self.coefficient_of_variation = self._get_coefficient_of_variation_given_data(self.data.get_base_spikes())
        return self.coefficient_of_variation

    def get_interspike_intervals(self):
        """Return the ISIs of all baseline trials as one flat list (not cached)."""
        base_spikes = self.data.get_base_spikes()
        return self._get_interspike_intervals_given_data(base_spikes)

    def get_spiketime_phases(self):
        """Return the EOD phase (radians) of every baseline spike, pooled over trials."""
        data = self.data
        time_traces = data.get_base_traces(data.TIME)
        base_spikes = data.get_base_spikes()
        eod_traces = data.get_base_traces(data.EOD)
        dt = data.get_sampling_interval()

        phases = []
        for trial, trial_time in enumerate(time_traces):
            # NOTE(review): the trace start time is *added* here. If spike times
            # share the time base of trial_time, as the plotting code assumes,
            # the sample index would be (t - t0) / dt. Both agree when traces
            # start at 0 -- confirm.
            shifted_spikes = np.array(base_spikes[trial]) + trial_time[0]
            spike_indices = np.array(np.around(shifted_spikes / dt), dtype=int)
            rel_spikes, eod_durations = hF.eods_around_spikes(trial_time, eod_traces[trial], spike_indices)
            phases.extend(rel_spikes / eod_durations * 2 * np.pi)
        return phases

    def plot_baseline(self, save_path=None, position=0.5, time_length=0.2):
        """Plot the first baseline trial: EOD, V1 trace with spikes, and ISI frequency."""
        data = self.data
        first_time = data.get_base_traces(data.TIME)[0]
        first_eod = data.get_base_traces(data.EOD)[0]
        first_v1 = data.get_base_traces(data.V1)[0]
        first_spikes = data.get_base_spikes()[0]
        dt = data.get_sampling_interval()
        freq_label = "{:.0f}".format(data.get_eod_frequency())
        self._plot_baseline_given_data(first_time, first_eod, first_v1, first_spikes,
                                       dt, freq_label, save_path, position, time_length)
class BaselineModel(Baseline):
    """Baseline statistics obtained by simulating a ``LifacNoiseModel``.

    On construction the model is driven ``trials`` times with
    ``SinusoidalStepStimulus(eod_frequency, 0)`` for ``simulation_time``
    seconds. The V1 trace and spike times of every trial are stored.
    """

    # Duration of each simulated baseline trial (seconds).
    simulation_time = 30

    def __init__(self, model: LifacNoiseModel, eod_frequency, trials=1):
        super().__init__()
        self.model = model
        self.eod_frequency = eod_frequency
        self.stimulus = SinusoidalStepStimulus(eod_frequency, 0)
        self.eod = self.stimulus.as_array(0, self.simulation_time, model.get_sampling_interval())
        self.time = np.arange(0, self.simulation_time, model.get_sampling_interval())

        # Each simulate_fast call yields (v1 trace, spike times) for one trial.
        runs = [model.simulate_fast(self.stimulus, self.simulation_time) for _ in range(trials)]
        self.v1_traces = [v1 for v1, _ in runs]
        self.spiketimes = [spikes for _, spikes in runs]

    def get_baseline_frequency(self):
        """Return the baseline frequency of the simulated trials, computing it on first use."""
        if self.baseline_frequency != -1:
            return self.baseline_frequency
        self.baseline_frequency = self._get_baseline_frequency_given_data(self.spiketimes)
        return self.baseline_frequency

    def get_vector_strength(self):
        """Return the trial-averaged vector strength, computing it on first use."""
        if self.vector_strength != -1:
            return self.vector_strength
        # Every trial shares the same time axis and stimulus.
        n_trials = len(self.spiketimes)
        self.vector_strength = self._get_vector_strength_given_data(
            [self.time] * n_trials, [self.eod] * n_trials, self.spiketimes, self.model.get_sampling_interval())
        return self.vector_strength

    def get_serial_correlation(self, max_lag):
        """Return the serial correlation; recomputed when the cached length differs from ``max_lag``."""
        cached = self.serial_correlation
        if len(cached) == max_lag:
            return cached
        self.serial_correlation = self._get_serial_correlation_given_data(max_lag, self.spiketimes)
        return self.serial_correlation

    def get_coefficient_of_variation(self):
        """Return the trial-averaged ISI coefficient of variation, computing it on first use."""
        if self.coefficient_of_variation != -1:
            return self.coefficient_of_variation
        self.coefficient_of_variation = self._get_coefficient_of_variation_given_data(self.spiketimes)
        return self.coefficient_of_variation

    def get_interspike_intervals(self):
        """Return the ISIs of all simulated trials as one flat list."""
        return self._get_interspike_intervals_given_data(self.spiketimes)

    def get_spiketime_phases(self):
        """Return the EOD phase (radians) of every simulated spike, pooled over trials."""
        dt = self.model.get_sampling_interval()
        phases = []
        for trial_spikes in self.spiketimes:
            # self.time starts at 0 (np.arange(0, ...)), so this offset adds nothing here.
            shifted_spikes = np.array(trial_spikes) + self.time[0]
            spike_indices = np.array(np.around(shifted_spikes / dt), dtype=int)
            rel_spikes, eod_durations = hF.eods_around_spikes(self.time, self.eod, spike_indices)
            phases.extend(rel_spikes / eod_durations * 2 * np.pi)
        return phases

    def plot_baseline(self, save_path=None, position=0.5, time_length=0.2):
        """Plot the first simulated trial: EOD, V1 trace with spikes, and ISI frequency."""
        freq_label = "{:.0f}".format(self.eod_frequency)
        self._plot_baseline_given_data(self.time, self.eod, self.v1_traces[0], self.spiketimes[0],
                                       self.model.get_sampling_interval(), freq_label,
                                       save_path, position, time_length)
def get_baseline_class(data, eod_freq=None, trials=1) -> Baseline:
    """Return the Baseline implementation that matches the type of ``data``.

    :param data: a ``CellData`` (recorded cell) or a ``LifacNoiseModel``.
    :param eod_freq: EOD frequency; required when ``data`` is a model.
    :param trials: number of simulated trials (used for models only).
    :raises ValueError: if ``data`` is a model and ``eod_freq`` is None, or
        if ``data`` has an unsupported type.
    """
    if isinstance(data, CellData):
        return BaselineCellData(data)

    if not isinstance(data, LifacNoiseModel):
        raise ValueError("Unknown type: Cannot find corresponding Baseline class. data was type:" + str(type(data)))

    if eod_freq is None:
        raise ValueError("The EOD frequency is needed for the BaselineModel Class.")
    return BaselineModel(data, eod_freq, trials=trials)