from CellData import CellData
from models.LIFACnoise import LifacNoiseModel
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
import helperFunctions as hF
import numpy as np
from warnings import warn
import matplotlib.pyplot as plt


class Baseline:
    """Common interface and shared plots for baseline-activity measures.

    The ``get_*`` methods and ``plot_baseline`` raise ``NotImplementedError``
    here and are supplied by the subclasses. Subclasses cache their results in
    the attributes set up in ``__init__``; ``-1`` (or an empty list for the
    serial correlation) marks a value that has not been computed yet.
    """

    def __init__(self):
        self.baseline_frequency = -1
        self.serial_correlation = []
        self.vector_strength = -1
        self.coefficient_of_variation = -1

    def get_baseline_frequency(self):
        """Return the baseline firing frequency (implemented by subclasses)."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_serial_correlation(self, max_lag):
        """Return the ISI serial correlation up to ``max_lag`` (implemented by subclasses)."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_vector_strength(self):
        """Return the vector strength (implemented by subclasses)."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_coefficient_of_variation(self):
        """Return the coefficient of variation (implemented by subclasses)."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_interspike_intervals(self):
        """Return the baseline interspike intervals (implemented by subclasses)."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def plot_baseline(self, save_path=None):
        """
        plots the stimulus / eod, together with the v1, spiketimes and frequency
        :return:
        """
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def plot_inter_spike_interval_histogram(self, save_path=None):
        """Plot a histogram of the baseline ISIs with 0.1 ms wide bins.

        :param save_path: if given, the figure is saved there; otherwise it
            is shown interactively.
        :raises ValueError: if there are no interspike intervals to plot.
        """
        # Convert before scaling: a subclass may return a plain list, and
        # ``list * 1000`` would repeat the list instead of multiplying values.
        isi = np.asarray(self.get_interspike_intervals(), dtype=float) * 1000  # change unit to milliseconds
        if isi.size == 0:
            raise ValueError("Cannot plot an ISI histogram without any interspike intervals.")

        maximum = np.max(isi)
        bins = np.arange(0, maximum * 1.01, 0.1)

        plt.title('Baseline ISIs')
        plt.xlabel('ISI in ms')
        plt.ylabel('Count')
        plt.hist(isi, bins=bins)

        if save_path is not None:
            plt.savefig(save_path)
        else:
            plt.show()
        plt.close()

    def plot_serial_correlation(self, max_lag, save_path=None):
        """Plot the serial correlation for the lags ``1 .. max_lag``.

        :param max_lag: highest lag to plot.
        :param save_path: if given, the figure is saved there; otherwise it
            is shown interactively.
        """
        plt.title("Baseline Serial correlation")
        plt.xlabel("Lag")
        plt.ylabel("Correlation")
        plt.plot(np.arange(1, max_lag + 1, 1), self.get_serial_correlation(max_lag))

        if save_path is not None:
            plt.savefig(save_path)
        else:
            plt.show()
        plt.close()


class BaselineCellData(Baseline):
    """Baseline measures computed from the recorded trials of a ``CellData``."""

    def __init__(self, cell_data: CellData):
        super().__init__()
        self.data = cell_data

    def get_baseline_frequency(self):
        """Return the median over trials of the mean frequency before stimulus onset.

        For every frequency trace the samples between 25 ms after the start
        and 25 ms before the end of the delay are averaged. The result is cached.
        """
        if self.baseline_frequency == -1:
            # These values do not depend on the individual trace, so they are
            # computed (and the warning is issued) only once.
            delay = self.data.get_delay()
            sampling_interval = self.data.get_sampling_interval()
            if delay < 0.1:
                warn("BaselineCellData:get_baseline_frequency(): Quite short delay at the start.")

            idx_start = int(0.025 / sampling_interval)
            idx_end = int((delay - 0.025) / sampling_interval)

            base_freqs = [
                np.mean(freq[idx_start:idx_end])
                for freq in self.data.get_mean_isi_frequencies()
            ]
            self.baseline_frequency = np.median(base_freqs)

        return self.baseline_frequency

    def get_vector_strength(self):
        """Return the vector strength computed from the baseline traces (cached)."""
        if self.vector_strength == -1:
            times = self.data.get_base_traces(self.data.TIME)
            eods = self.data.get_base_traces(self.data.EOD)
            v1_traces = self.data.get_base_traces(self.data.V1)
            self.vector_strength = hF.calculate_vector_strength_from_v1_trace(times, eods, v1_traces)

        return self.vector_strength

    def get_serial_correlation(self, max_lag):
        """Return the serial correlation averaged over all baseline trials.

        The cached value is reused as long as its length equals ``max_lag``.
        """
        if len(self.serial_correlation) != max_lag:
            serial_cors = [
                hF.calculate_serial_correlation(spiketimes, max_lag)
                for spiketimes in self.data.get_base_spikes()
            ]
            self.serial_correlation = np.mean(np.array(serial_cors), axis=0)

        return self.serial_correlation

    def get_coefficient_of_variation(self):
        """Return the mean over trials of the per-trial coefficient of variation (cached)."""
        if self.coefficient_of_variation == -1:
            cvs = [
                hF.calculate_coefficient_of_variation(np.array(st))
                for st in self.data.get_base_spikes()
            ]
            self.coefficient_of_variation = np.mean(cvs)

        return self.coefficient_of_variation

    def get_interspike_intervals(self):
        """Return the ISIs of all baseline trials concatenated into one list."""
        isis = []
        for st in self.data.get_base_spikes():
            isis.extend(np.diff(np.array(st)))
        return isis

    def plot_baseline(self, save_path=None):
        """Plot EOD, V1, spike times and frequency trace of every baseline trial.

        :param save_path: if given, the figure is saved there; otherwise it
            is shown interactively.
        """
        # eod, v1, spiketimes, frequency
        times = self.data.get_base_traces(self.data.TIME)
        eods = self.data.get_base_traces(self.data.EOD)
        v1_traces = self.data.get_base_traces(self.data.V1)
        spiketimes = self.data.get_base_spikes()

        # sharex must be the boolean True; the string "True" is not an accepted value.
        fig, axes = plt.subplots(4, 1, sharex=True)

        for i in range(len(times)):
            axes[0].plot(times[i], eods[i])
            axes[1].plot(times[i], v1_traces[i])
            # Plot the spikes of the current trial only, not the list of all trials.
            axes[2].plot(spiketimes[i], [1] * len(spiketimes[i]), 'o')

            t, f = hF.calculate_time_and_frequency_trace(spiketimes[i], self.data.get_sampling_interval())
            axes[3].plot(t, f)

        if save_path is not None:
            plt.savefig(save_path)
        else:
            plt.show()
        plt.close()


class BaselineModel(Baseline):
    """Baseline measures computed from a simulation of a ``LifacNoiseModel``.

    The model is simulated once in ``__init__`` for ``simulation_time`` with a
    ``SinusoidalStepStimulus(eod_frequency, 0)``.
    """

    simulation_time = 30

    def __init__(self, model: LifacNoiseModel, eod_frequency):
        super().__init__()
        self.model = model
        self.eod_frequency = eod_frequency
        self.stimulus = SinusoidalStepStimulus(eod_frequency, 0)
        self.v1, self.spiketimes = model.simulate_fast(self.stimulus, self.simulation_time)
        self.eod = self.stimulus.as_array(0, self.simulation_time, model.get_sampling_interval())
        self.time = np.arange(0, self.simulation_time, model.get_sampling_interval())

    def get_baseline_frequency(self):
        """Return the mean ISI frequency of the simulated spikes (cached)."""
        if self.baseline_frequency == -1:
            self.baseline_frequency = hF.calculate_mean_isi_freq(self.spiketimes)

        return self.baseline_frequency

    def get_vector_strength(self):
        """Return the vector strength of the simulated spikes relative to the EOD (cached)."""
        if self.vector_strength == -1:
            self.vector_strength = hF.calculate_vector_strength_from_spiketimes(
                self.time, self.eod, self.spiketimes, self.model.get_sampling_interval())

        return self.vector_strength

    def get_serial_correlation(self, max_lag):
        """Return the serial correlation of the simulated spikes.

        The cached value is reused as long as its length equals ``max_lag``.
        """
        if len(self.serial_correlation) != max_lag:
            self.serial_correlation = hF.calculate_serial_correlation(self.spiketimes, max_lag)

        return self.serial_correlation

    def get_coefficient_of_variation(self):
        """Return the coefficient of variation of the simulated spikes (cached)."""
        if self.coefficient_of_variation == -1:
            self.coefficient_of_variation = hF.calculate_coefficient_of_variation(self.spiketimes)

        return self.coefficient_of_variation

    def get_interspike_intervals(self):
        """Return the differences between consecutive simulated spike times."""
        return np.diff(self.spiketimes)

    def plot_baseline(self, save_path=None):
        """Plot EOD, V1, spike times and frequency trace of the simulation.

        :param save_path: if given, the figure is saved there; otherwise it
            is shown interactively.
        """
        # eod, v1, spiketimes, frequency
        # sharex must be the boolean True; the string "True" is not an accepted value.
        fig, axes = plt.subplots(4, 1, sharex=True)

        axes[0].plot(self.time, self.eod)
        axes[1].plot(self.time, self.v1)
        axes[2].plot(self.spiketimes, [1] * len(self.spiketimes), 'o')

        t, f = hF.calculate_time_and_frequency_trace(self.spiketimes, self.model.get_sampling_interval())
        axes[3].plot(t, f)

        if save_path is not None:
            plt.savefig(save_path)
        else:
            plt.show()
        plt.close()


def get_baseline_class(data, eod_freq=None) -> Baseline:
    """Return the ``Baseline`` implementation matching the type of ``data``.

    :param data: a ``CellData`` or a ``LifacNoiseModel``.
    :param eod_freq: EOD frequency; required when ``data`` is a model.
    :raises ValueError: if ``data`` is a model and ``eod_freq`` is ``None``,
        or if ``data`` is of neither supported type.
    """
    if isinstance(data, CellData):
        return BaselineCellData(data)

    if isinstance(data, LifacNoiseModel):
        if eod_freq is None:
            raise ValueError("The EOD frequency is needed for the BaselineModel Class.")
        return BaselineModel(data, eod_freq)

    raise ValueError("Unknown type: Cannot find corresponding Baseline class. data was type:" + str(type(data)))