from CellData import CellData
from models.LIFACnoise import LifacNoiseModel
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
import helperFunctions as hF
import numpy as np
import matplotlib.pyplot as plt
import pickle
from os.path import join, exists


class Baseline:
    """Abstract base class for baseline-firing statistics of a cell or a model.

    Subclasses supply the spike data (recorded or simulated) by overriding the
    ``get_*`` methods; the static ``_*_given_data`` helpers hold the shared
    computations. Computed values are cached on the instance, where -1 (scalars)
    or an empty list (serial correlation) means "not computed yet".
    """

    def __init__(self):
        self.save_file_name = "baseline_values.pkl"
        self.baseline_frequency = -1
        self.serial_correlation = []
        self.vector_strength = -1
        self.coefficient_of_variation = -1
        self.burstiness = -1

    def get_baseline_frequency(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_serial_correlation(self, max_lag):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_vector_strength(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_coefficient_of_variation(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_burstiness(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def __get_burstiness__(self, eod_freq):
        """Return a burstiness score for the baseline spike train.

        The score is the fraction of ISIs shorter than 2.5 EOD periods
        (``2.5 / eod_freq``), multiplied by the mean ISI converted to
        milliseconds. Returns 0 when there are no ISIs.
        """
        isis = np.array(self.get_interspike_intervals())
        if len(isis) == 0:
            return 0
        # An ISI counts towards a burst when it spans fewer than 2.5 EOD cycles.
        fullfilled = isis < (2.5 / eod_freq)
        perc_bursts = np.sum(fullfilled) / len(fullfilled)
        return perc_bursts * (np.mean(isis) * 1000)

    def get_interspike_intervals(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_spiketime_phases(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def plot_baseline(self, save_path=None, position=0.5, time_length=0.2):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    @staticmethod
    def _get_baseline_frequency_given_data(spiketimes):
        """Return the median, over trials, of ``hF.calculate_mean_isi_freq``."""
        base_freqs = [hF.calculate_mean_isi_freq(st) for st in spiketimes]
        return np.median(base_freqs)

    @staticmethod
    def _get_serial_correlation_given_data(max_lag, spikestimes):
        """Return the per-lag mean, over trials, of ``hF.calculate_serial_correlation``."""
        serial_cors = np.array([hF.calculate_serial_correlation(st, max_lag) for st in spikestimes])
        return np.mean(serial_cors, axis=0)

    @staticmethod
    def _get_vector_strength_given_data(times, eods, spiketimes, sampling_interval):
        """Return the mean, over trials, of ``hF.calculate_vector_strength_from_spiketimes``.

        ``times`` and ``eods`` are indexed in parallel with ``spiketimes`` (one entry per trial).
        """
        vs_per_trial = [
            hF.calculate_vector_strength_from_spiketimes(times[i], eods[i], trial_spikes, sampling_interval)
            for i, trial_spikes in enumerate(spiketimes)
        ]
        return np.mean(vs_per_trial)

    @staticmethod
    def _get_coefficient_of_variation_given_data(spiketimes):
        """Return the mean, over trials, of the ISI coefficient of variation.

        The CV of a trial (stddev of ISI divided by mean ISI) is computed by
        ``hF.calculate_coefficient_of_variation``.
        """
        cvs = [hF.calculate_coefficient_of_variation(np.array(st)) for st in spiketimes]
        return np.mean(cvs)

    @staticmethod
    def _get_interspike_intervals_given_data(spiketimes):
        """Return the ISIs (``np.diff`` of each trial's spike times) of all trials pooled into one list."""
        isis = []
        for st in spiketimes:
            isis.extend(np.diff(np.array(st)))
        return isis

    @staticmethod
    def _get_spiketime_phases_given_data(times, eods, spiketimes, sampling_interval):
        """Return the spike phases relative to the EOD (in radians), pooled over all trials.

        ``times`` and ``eods`` are indexed in parallel with ``spiketimes`` (one entry per trial).
        For each spike, ``hF.eods_around_spikes`` returns two values whose ratio is scaled by 2*pi.
        NOTE(review): these are presumably the spike's offset into its EOD cycle and that
        cycle's duration — confirm against helperFunctions.
        """
        phase_list = []
        for i, trial_spikes in enumerate(spiketimes):
            trial_time = times[i]
            # Spike times share the time axis of the trace (they are compared directly with
            # ``time`` values when plotting), so the sample index is (t_spike - t_start) / dt.
            spiketime_indices = np.array(
                np.around((np.array(trial_spikes) - trial_time[0]) / sampling_interval), dtype=int)
            rel_spikes, eod_durs = hF.eods_around_spikes(trial_time, eods[i], spiketime_indices)
            phase_list.extend((rel_spikes / eod_durs) * 2 * np.pi)
        return phase_list

    @staticmethod
    def _plot_baseline_given_data(time, eod, v1, spiketimes, sampling_interval, eod_freq="", save_path=None,
                                  position=0.5, time_length=0.2):
        """Plot a window of the baseline: stimulus/EOD, V1 trace with spikes, and ISI frequency.

        :param time: time trace; ``spiketimes`` use the same time axis
        :param eod: stimulus / EOD trace, sampled like ``time``
        :param v1: V1 trace, sampled like ``time``
        :param spiketimes: spike times of one trial
        :param sampling_interval: time step between consecutive samples of ``time``
        :param eod_freq: EOD frequency as a string, shown in the stimulus axis label
        :param save_path: prefix for "baseline.png"; the figure is shown instead when None
        :param position: start of the window as a fraction of the trace length
        :param time_length: window length in seconds (as stated in the figure title)
        """
        n_samples = len(time)
        length_data_points = int(time_length / sampling_interval)
        # Keep the window start a valid sample index (position may be <0 or >=1).
        start_idx = min(max(int(n_samples * position), 0), n_samples - 1)
        end_idx = min(int(n_samples * position + length_data_points) + 1, n_samples)
        # Spikes are kept up to (exclusive) the first sample after the window. When the window
        # reaches the end of the trace that sample does not exist, so use one step past the
        # last sample instead of indexing out of range.
        end_time = time[end_idx] if end_idx < n_samples else time[-1] + sampling_interval

        spiketimes = np.array(spiketimes)
        spiketimes_part = spiketimes[(spiketimes >= time[start_idx]) & (spiketimes < end_time)]

        fig, axes = plt.subplots(3, 1, sharex="col", figsize=(12, 8))
        fig.suptitle("Baseline middle part ({:.2f} seconds)".format(time_length))

        axes[0].plot(time[start_idx:end_idx], eod[start_idx:end_idx])
        axes[0].set_ylabel("Stimulus [mV] - Freq:" + eod_freq)

        # Spikes are drawn as markers at the height of the V1 maximum within the window.
        max_v1 = max(v1[start_idx:end_idx])
        axes[1].plot(time[start_idx:end_idx], v1[start_idx:end_idx])
        axes[1].plot(spiketimes_part, [max_v1] * len(spiketimes_part), 'o', color='orange')
        axes[1].set_ylabel("V1-Trace [mV]")

        t, f = hF.calculate_time_and_frequency_trace(spiketimes_part, sampling_interval)
        axes[2].plot(t, f)
        axes[2].set_ylabel("ISI-Frequency [Hz]")
        axes[2].set_xlabel("Time [s]")

        if save_path is not None:
            plt.savefig(save_path + "baseline.png")
        else:
            plt.show()
        plt.close()

    @staticmethod
    def plot_isi_histogram_comparision(cell_isis, model_isis, save_path=None):
        """Plot density-normalised ISI histograms (in ms) of a cell and a model on shared bins.

        :param cell_isis: ISIs of the cell (seconds; converted to ms for plotting)
        :param model_isis: ISIs of the model (seconds; converted to ms for plotting)
        :param save_path: prefix for "isi-histogram_comparision.png"; shown instead when None
        """
        cell_isis = np.array(cell_isis) * 1000
        model_isis = np.array(model_isis) * 1000
        all_isis = np.concatenate((cell_isis, model_isis))
        # Fall back to a 0-1 ms range when neither side has ISIs (max() of nothing would raise).
        maximum = np.max(all_isis) if len(all_isis) > 0 else 1
        bins = np.arange(0, maximum * 1.01, 0.1)

        plt.title('Baseline ISIs')
        plt.xlabel('ISI in ms')
        plt.ylabel('Count')
        plt.hist(cell_isis, bins=bins, label="cell", alpha=0.5, density=True)
        plt.hist(model_isis, bins=bins, label="model", alpha=0.5, density=True)
        plt.legend()
        if save_path is not None:
            plt.savefig(save_path + "isi-histogram_comparision.png")
        else:
            plt.show()
        plt.close()

    def plot_polar_vector_strength(self, save_path=None):
        """Plot a polar histogram of the spike phases from ``get_spiketime_phases``."""
        phases = self.get_spiketime_phases()
        fig = plt.figure()
        ax = fig.add_subplot(111, polar=True)
        bins = np.arange(0, np.pi * 2, 0.1)
        ax.hist(phases, bins=bins)

        if save_path is not None:
            plt.savefig(save_path + "vector_strength_polar_plot.png")
        else:
            plt.show()
        plt.close()

    def plot_interspike_interval_histogram(self, save_path=None):
        """Plot a histogram of the baseline ISIs in milliseconds (0.1 ms bins).

        When there are no ISIs, a warning is printed and an empty 0-1 ms histogram is drawn.
        """
        isi = np.array(self.get_interspike_intervals()) * 1000  # change unit to milliseconds
        if len(isi) == 0:
            print("NON SPIKES IN BASELINE OF CELL/MODEL")
            title = 'Baseline ISIs - NO SPIKES!'
            bins = np.arange(0, 1, 0.1)
        else:
            title = 'Baseline ISIs'
            bins = np.arange(0, max(isi) * 1.01, 0.1)

        plt.title(title)
        plt.xlabel('ISI in ms')
        plt.ylabel('Count')
        plt.hist(isi, bins=bins)
        if save_path is not None:
            plt.savefig(save_path + "isi-histogram.png")
        else:
            plt.show()
        plt.close()

    def plot_serial_correlation(self, max_lag, save_path=None):
        """Plot the serial correlation for lags 1..max_lag."""
        plt.title("Baseline Serial correlation")
        plt.xlabel("Lag")
        plt.ylabel("Correlation")
        plt.ylim((-1, 1))
        plt.plot(np.arange(1, max_lag + 1, 1), self.get_serial_correlation(max_lag))
        if save_path is not None:
            plt.savefig(save_path + "serial_correlation.png")
        else:
            plt.show()
        plt.close()

    def save_values(self, save_directory):
        """Compute all baseline statistics and pickle them into ``save_directory``.

        The serial correlation is stored for lags up to 10.
        """
        values = {
            "baseline_frequency": self.get_baseline_frequency(),
            "serial correlation": self.get_serial_correlation(max_lag=10),
            "vector strength": self.get_vector_strength(),
            "coefficient of variation": self.get_coefficient_of_variation(),
            "burstiness": self.get_burstiness(),
        }

        with open(join(save_directory, self.save_file_name), "wb") as file:
            pickle.dump(values, file)
        print("Baseline: Values saved!")

    def load_values(self, save_directory):
        """Load statistics written by ``save_values`` into the cache.

        :return: True if the file existed and was loaded, False otherwise
        """
        file_path = join(save_directory, self.save_file_name)
        if not exists(file_path):
            print("Baseline: No file to load")
            return False

        # NOTE: pickle can execute arbitrary code when loading; only load files this code wrote.
        with open(file_path, "rb") as file:
            values = pickle.load(file)

        self.baseline_frequency = values["baseline_frequency"]
        self.serial_correlation = values["serial correlation"]
        self.vector_strength = values["vector strength"]
        self.coefficient_of_variation = values["coefficient of variation"]
        self.burstiness = values["burstiness"]
        print("Baseline: Values loaded!")
        return True


class BaselineCellData(Baseline):
    """Baseline statistics computed from the baseline recordings of a ``CellData`` object."""

    def __init__(self, cell_data: CellData):
        super().__init__()
        self.data = cell_data

    def get_baseline_frequency(self):
        """Return (and cache) the median over trials of the mean-ISI frequency."""
        if self.baseline_frequency == -1:
            spiketimes = self.data.get_base_spikes()
            self.baseline_frequency = self._get_baseline_frequency_given_data(spiketimes)

        return self.baseline_frequency

    def get_vector_strength(self):
        """Return (and cache) the mean vector strength over the baseline trials."""
        if self.vector_strength == -1:
            times = self.data.get_base_traces(self.data.TIME)
            eods = self.data.get_base_traces(self.data.EOD)
            spiketimes = self.data.get_base_spikes()
            sampling_interval = self.data.get_sampling_interval()
            self.vector_strength = self._get_vector_strength_given_data(times, eods, spiketimes, sampling_interval)

        return self.vector_strength

    def get_serial_correlation(self, max_lag):
        """Return the serial correlation for the first ``max_lag`` lags.

        Recomputed only when the cache holds fewer than ``max_lag`` lags.
        """
        if len(self.serial_correlation) < max_lag:
            self.serial_correlation = self._get_serial_correlation_given_data(max_lag, self.data.get_base_spikes())
        return self.serial_correlation[:max_lag]

    def get_coefficient_of_variation(self):
        """Return (and cache) the mean ISI coefficient of variation over the baseline trials."""
        if self.coefficient_of_variation == -1:
            self.coefficient_of_variation = self._get_coefficient_of_variation_given_data(self.data.get_base_spikes())
        return self.coefficient_of_variation

    def get_interspike_intervals(self):
        """Return the ISIs of all baseline trials pooled into one list."""
        return self._get_interspike_intervals_given_data(self.data.get_base_spikes())

    def get_spiketime_phases(self):
        """Return the spike phases relative to the EOD over all baseline trials."""
        times = self.data.get_base_traces(self.data.TIME)
        spiketimes = self.data.get_base_spikes()
        eods = self.data.get_base_traces(self.data.EOD)
        sampling_interval = self.data.get_sampling_interval()
        return self._get_spiketime_phases_given_data(times, eods, spiketimes, sampling_interval)

    def get_burstiness(self):
        """Return (and cache) the burstiness score using the cell's EOD frequency."""
        if self.burstiness == -1:
            self.burstiness = self.__get_burstiness__(self.data.get_eod_frequency())
        return self.burstiness

    def plot_baseline(self, save_path=None, position=0.5, time_length=0.2):
        """Plot eod, v1, spiketimes and frequency for a window of the first baseline trial."""
        time = self.data.get_base_traces(self.data.TIME)[0]
        eod = self.data.get_base_traces(self.data.EOD)[0]
        v1_trace = self.data.get_base_traces(self.data.V1)[0]
        spiketimes = self.data.get_base_spikes()[0]

        self._plot_baseline_given_data(time, eod, v1_trace, spiketimes, self.data.get_sampling_interval(),
                                       "{:.0f}".format(self.data.get_eod_frequency()), save_path, position,
                                       time_length)


class BaselineModel(Baseline):
    """Baseline statistics computed by simulating a ``LifacNoiseModel`` under a baseline stimulus."""

    # Duration passed to model.simulate_fast and used for the time/EOD traces
    # (presumably seconds — confirm against LifacNoiseModel).
    simulation_time = 30

    def __init__(self, model: LifacNoiseModel, eod_frequency, trials=1):
        """Adapt ``model`` to the baseline and simulate ``trials`` baseline runs.

        Note: this mutates ``model`` (its "a_zero" variable is overwritten).
        """
        super().__init__()
        self.model = model
        self.eod_frequency = eod_frequency
        self.set_model_adaption_to_baseline()

        self.stimulus = SinusoidalStepStimulus(eod_frequency, 0)
        self.eod = self.stimulus.as_array(0, self.simulation_time, model.get_sampling_interval())
        self.time = np.arange(0, self.simulation_time, model.get_sampling_interval())

        self.v1_traces = []
        self.spiketimes = []
        for _ in range(trials):
            v, st = model.simulate_fast(self.stimulus, self.simulation_time)
            self.v1_traces.append(v)
            self.spiketimes.append(st)

    def set_model_adaption_to_baseline(self):
        """Simulate the model for 1 time unit and set "a_zero" to the final adaption value."""
        stimulus = SinusoidalStepStimulus(self.eod_frequency, 0, 0, 0)
        self.model.simulate_fast(stimulus, 1)
        adaption = self.model.get_adaption_trace()
        self.model.set_variable("a_zero", adaption[-1])

    def get_baseline_frequency(self):
        """Return (and cache) the median over trials of the mean-ISI frequency."""
        if self.baseline_frequency == -1:
            self.baseline_frequency = self._get_baseline_frequency_given_data(self.spiketimes)

        return self.baseline_frequency

    def get_vector_strength(self):
        """Return (and cache) the mean vector strength over the simulated trials."""
        if self.vector_strength == -1:
            # Every trial shares the same time axis and EOD trace.
            times = [self.time] * len(self.spiketimes)
            eods = [self.eod] * len(self.spiketimes)
            sampling_interval = self.model.get_sampling_interval()
            self.vector_strength = self._get_vector_strength_given_data(times, eods, self.spiketimes,
                                                                        sampling_interval)

        return self.vector_strength

    def get_serial_correlation(self, max_lag):
        """Return the serial correlation for lags 1..max_lag (recomputed unless exactly that many are cached)."""
        if len(self.serial_correlation) != max_lag:
            self.serial_correlation = self._get_serial_correlation_given_data(max_lag, self.spiketimes)
        return self.serial_correlation

    def get_coefficient_of_variation(self):
        """Return (and cache) the mean ISI coefficient of variation over the simulated trials."""
        if self.coefficient_of_variation == -1:
            self.coefficient_of_variation = self._get_coefficient_of_variation_given_data(self.spiketimes)
        return self.coefficient_of_variation

    def get_interspike_intervals(self):
        """Return the ISIs of all simulated trials pooled into one list."""
        return self._get_interspike_intervals_given_data(self.spiketimes)

    def get_burstiness(self):
        """Return (and cache) the burstiness score using the configured EOD frequency."""
        if self.burstiness == -1:
            self.burstiness = self.__get_burstiness__(self.eod_frequency)
        return self.burstiness

    def get_spiketime_phases(self):
        """Return the spike phases relative to the EOD over all simulated trials."""
        n_trials = len(self.spiketimes)
        return self._get_spiketime_phases_given_data([self.time] * n_trials, [self.eod] * n_trials,
                                                     self.spiketimes, self.model.get_sampling_interval())

    def plot_baseline(self, save_path=None, position=0.5, time_length=0.2):
        """Plot eod, v1, spiketimes and frequency for a window of the first simulated trial."""
        self._plot_baseline_given_data(self.time, self.eod, self.v1_traces[0], self.spiketimes[0],
                                       self.model.get_sampling_interval(), "{:.0f}".format(self.eod_frequency),
                                       save_path, position, time_length)


def get_baseline_class(data, eod_freq=None, trials=1) -> Baseline:
    """Return the Baseline implementation matching the type of ``data``.

    :param data: a ``CellData`` or a ``LifacNoiseModel``
    :param eod_freq: EOD frequency; required for a ``LifacNoiseModel``
    :param trials: number of simulated trials (model only)
    :raises ValueError: if ``eod_freq`` is missing for a model, or ``data`` has an unsupported type
    """
    if isinstance(data, CellData):
        return BaselineCellData(data)
    if isinstance(data, LifacNoiseModel):
        if eod_freq is None:
            raise ValueError("The EOD frequency is needed for the BaselineModel Class.")
        return BaselineModel(data, eod_freq, trials=trials)

    raise ValueError("Unknown type: Cannot find corresponding Baseline class. data was type:" + str(type(data)))