from parser.CellData import CellData
from models.LIFACnoise import LifacNoiseModel
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
from my_util import helperFunctions as hF
import numpy as np
import matplotlib.pyplot as plt
import pickle
from os.path import join, exists


class Baseline:
    """Abstract base for the baseline-activity statistics of a cell or a model.

    Subclasses provide the data-specific ``get_*`` methods. This class holds the
    cached statistic values, the shared computations over lists of spike-time
    trials, the plotting helpers and the pickling of the values.

    A statistic equal to -1 (or an empty serial-correlation list) has not been
    computed or loaded yet.
    """

    def __init__(self):
        self.save_file_name = "baseline_values.pkl"
        self.baseline_frequency = -1
        self.serial_correlation = []
        self.vector_strength = -1
        self.coefficient_of_variation = -1
        self.burstiness = -1

    # ------------------------------------------------------------------
    # Interface that subclasses must implement
    # ------------------------------------------------------------------

    def get_baseline_frequency(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_serial_correlation(self, max_lag):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_vector_strength(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_coefficient_of_variation(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_burstiness(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_interspike_intervals(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_spiketime_phases(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def plot_baseline(self, save_path=None, position=0.5, time_length=0.2):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    # ------------------------------------------------------------------
    # Shared computations
    # ------------------------------------------------------------------

    def __get_burstiness__(self, eod_freq):
        """Return the fraction of "burst" ISIs multiplied by the mean ISI in ms.

        An ISI counts as a burst ISI when it is shorter than ``2.5 / eod_freq``.

        :param eod_freq: EOD frequency used for the burst threshold.
        :return: 0 when there are no ISIs, otherwise
            ``fraction_of_burst_isis * mean_isi * 1000``.
        """
        isis = np.array(self.get_interspike_intervals())
        if len(isis) == 0:
            return 0
        is_burst = isis < (2.5 / eod_freq)
        perc_bursts = np.sum(is_burst) / len(is_burst)
        return perc_bursts * (np.mean(isis) * 1000)

    @staticmethod
    def _get_baseline_frequency_given_data(spiketimes):
        """Return the median over trials of ``hF.calculate_mean_isi_freq``.

        :param spiketimes: list of per-trial spike-time sequences.
        """
        base_freqs = [hF.calculate_mean_isi_freq(st) for st in spiketimes]
        return np.median(base_freqs)

    @staticmethod
    def _get_serial_correlation_given_data(max_lag, spikestimes):
        """Return the trial-averaged serial correlation up to ``max_lag``.

        :param max_lag: number of lags passed to ``hF.calculate_serial_correlation``.
        :param spikestimes: list of per-trial spike-time sequences.
        :return: element-wise mean over trials (``axis=0``).
        """
        serial_cors = np.array([hF.calculate_serial_correlation(st, max_lag) for st in spikestimes])
        return np.mean(serial_cors, axis=0)

    @staticmethod
    def _get_vector_strength_given_data(times, eods, spiketimes, sampling_interval):
        """Return the mean over trials of the per-trial vector strength.

        ``times``, ``eods`` and ``spiketimes`` are parallel per-trial lists.
        """
        vs_per_trial = [
            hF.calculate_vector_strength_from_spiketimes(time, eod, st, sampling_interval)
            for time, eod, st in zip(times, eods, spiketimes)
        ]
        return np.mean(vs_per_trial)

    @staticmethod
    def _get_coefficient_of_variation_given_data(spiketimes):
        """Return the mean over trials of the ISI coefficient of variation.

        CV = std(ISI) / mean(ISI), computed per trial by
        ``hF.calculate_coefficient_of_variation``.
        """
        cvs = [hF.calculate_coefficient_of_variation(np.array(st)) for st in spiketimes]
        return np.mean(cvs)

    @staticmethod
    def _get_interspike_intervals_given_data(spiketimes):
        """Return the ISIs (``np.diff`` of each trial) of all trials pooled into one list."""
        isis = []
        for st in spiketimes:
            isis.extend(np.diff(np.array(st)))
        return isis

    # ------------------------------------------------------------------
    # Plotting
    # ------------------------------------------------------------------

    @staticmethod
    def _save_or_show(save_path, file_name):
        """Save the current figure to ``save_path + file_name`` (or show it), then close it.

        ``save_path`` is used as a plain string prefix, not joined as a directory.
        """
        if save_path is not None:
            plt.savefig(save_path + file_name)
        else:
            plt.show()
        plt.close()

    @staticmethod
    def _plot_baseline_given_data(time, eod, v1, spiketimes, sampling_interval, eod_freq="", save_path=None,
                                  position=0.5, time_length=0.2):
        """
        Plot a window of the stimulus / EOD, the V1 trace with spikes, and the ISI frequency.

        :param time: time axis of the trace.
        :param eod: stimulus / EOD samples aligned with ``time``.
        :param v1: V1 trace samples aligned with ``time``.
        :param spiketimes: spike times in the same time frame as ``time``.
        :param sampling_interval: step between consecutive samples of ``time``.
        :param eod_freq: string appended to the stimulus axis label.
        :param save_path: prefix for the saved file ("baseline.png" is appended); shows the plot if None.
        :param position: start of the window as a fraction of the trace length.
        :param time_length: length of the window (in units of ``time``).
        """
        length_data_points = int(time_length / sampling_interval)

        start_idx = max(int(len(time) * position), 0)
        end_idx = min(int(len(time) * position + length_data_points) + 1, len(time))

        # Exclusive upper time bound for the spikes shown in the window. When the
        # window reaches the end of the trace, time[end_idx] does not exist, so use
        # one sampling step past the last sample instead of indexing out of range.
        window_end = time[end_idx] if end_idx < len(time) else time[-1] + sampling_interval

        spiketimes = np.array(spiketimes)
        spiketimes_part = spiketimes[(spiketimes >= time[start_idx]) & (spiketimes < window_end)]

        fig, axes = plt.subplots(3, 1, sharex="col", figsize=(12, 8))
        fig.suptitle("Baseline middle part ({:.2f} seconds)".format(time_length))

        axes[0].plot(time[start_idx:end_idx], eod[start_idx:end_idx])
        axes[0].set_ylabel("Stimulus [mV] - Freq:" + eod_freq)

        max_v1 = max(v1[start_idx:end_idx])
        axes[1].plot(time[start_idx:end_idx], v1[start_idx:end_idx])
        # Mark every spike with a dot at the height of the window's maximum V1 value.
        axes[1].plot(spiketimes_part, [max_v1] * len(spiketimes_part), 'o', color='orange')
        axes[1].set_ylabel("V1-Trace [mV]")

        t, f = hF.calculate_time_and_frequency_trace(spiketimes_part, sampling_interval)
        axes[2].plot(t, f)
        axes[2].set_ylabel("ISI-Frequency [Hz]")
        axes[2].set_xlabel("Time [s]")

        Baseline._save_or_show(save_path, "baseline.png")

    @staticmethod
    def plot_isi_histogram_comparision(cell_isis, model_isis, save_path=None):
        """Overlay density-normalised ISI histograms of a cell and a model.

        ISIs are multiplied by 1000 and labelled in ms, with 0.1 ms bins up to
        1.01 * the largest ISI of either set.

        :param save_path: prefix for the saved file ("isi-histogram_comparision.png"
            is appended); shows the plot if None.
        """
        cell_isis = np.array(cell_isis) * 1000
        model_isis = np.array(model_isis) * 1000

        all_isis = np.concatenate((cell_isis, model_isis))
        # max() of an empty array raises, so fall back to the default 0-1 ms bins
        # (as in plot_interspike_interval_histogram) when there are no ISIs at all.
        if len(all_isis) == 0:
            bins = np.arange(0, 1, 0.1)
        else:
            bins = np.arange(0, np.max(all_isis) * 1.01, 0.1)

        plt.title('Baseline ISIs')
        plt.xlabel('ISI in ms')
        plt.ylabel('Count')
        plt.hist(cell_isis, bins=bins, label="cell", alpha=0.5, density=True)
        plt.hist(model_isis, bins=bins, label="model", alpha=0.5, density=True)
        plt.legend()

        Baseline._save_or_show(save_path, "isi-histogram_comparision.png")

    def plot_polar_vector_strength(self, save_path=None):
        """Plot a polar histogram of the spike phases from ``get_spiketime_phases``."""
        phases = self.get_spiketime_phases()
        fig = plt.figure()
        ax = fig.add_subplot(111, polar=True)

        # Bin edges every 0.1 rad. 2*pi is appended as the final edge: without it
        # the last edge is 6.2 and phases in [6.2, 2*pi) are dropped by the histogram.
        bins = np.append(np.arange(0, np.pi * 2, 0.1), np.pi * 2)
        ax.hist(phases, bins=bins)

        self._save_or_show(save_path, "vector_strength_polar_plot.png")

    def plot_interspike_interval_histogram(self, save_path=None):
        """Plot a histogram of the pooled ISIs in ms (0.1 ms bins).

        With no ISIs, a message is printed and an empty histogram titled
        'Baseline ISIs - NO SPIKES!' is produced.
        """
        isi = np.array(self.get_interspike_intervals()) * 1000  # change unit to milliseconds

        if len(isi) == 0:
            print("NON SPIKES IN BASELINE OF CELL/MODEL")
            plt.title('Baseline ISIs - NO SPIKES!')
            bins = np.arange(0, 1, 0.1)
        else:
            plt.title('Baseline ISIs')
            bins = np.arange(0, max(isi) * 1.01, 0.1)

        plt.xlabel('ISI in ms')
        plt.ylabel('Count')
        plt.hist(isi, bins=bins)

        self._save_or_show(save_path, "isi-histogram.png")

    def plot_serial_correlation(self, max_lag, save_path=None):
        """Plot the serial correlation for lags 1..max_lag (y-axis fixed to [-1, 1])."""
        plt.title("Baseline Serial correlation")
        plt.xlabel("Lag")
        plt.ylabel("Correlation")
        plt.ylim((-1, 1))
        plt.plot(np.arange(1, max_lag + 1, 1), self.get_serial_correlation(max_lag))

        self._save_or_show(save_path, "serial_correlation.png")

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save_values(self, save_directory):
        """Compute all statistics and pickle them to ``save_directory/save_file_name``.

        The serial correlation is stored up to lag 10.
        """
        values = {
            "baseline_frequency": self.get_baseline_frequency(),
            "serial correlation": self.get_serial_correlation(max_lag=10),
            "vector strength": self.get_vector_strength(),
            "coefficient of variation": self.get_coefficient_of_variation(),
            "burstiness": self.get_burstiness(),
        }

        with open(join(save_directory, self.save_file_name), "wb") as file:
            pickle.dump(values, file)
        print("Baseline: Values saved!")

    def load_values(self, save_directory):
        """Load statistics previously written by ``save_values`` into the cache.

        :return: True if the file existed and was loaded, False otherwise.
        """
        file_path = join(save_directory, self.save_file_name)
        if not exists(file_path):
            print("Baseline: No file to load")
            return False

        # NOTE: pickle.load can execute arbitrary code; only load files written
        # by save_values, never files from untrusted sources.
        with open(file_path, "rb") as file:
            values = pickle.load(file)

        self.baseline_frequency = values["baseline_frequency"]
        self.serial_correlation = values["serial correlation"]
        self.vector_strength = values["vector strength"]
        self.coefficient_of_variation = values["coefficient of variation"]
        self.burstiness = values["burstiness"]
        print("Baseline: Values loaded!")
        return True


class BaselineCellData(Baseline):
    """Baseline statistics computed from the baseline recordings of a ``CellData`` object."""

    def __init__(self, cell_data: CellData):
        super().__init__()
        self.data = cell_data

    def get_baseline_frequency(self):
        """Return (and cache) the median over trials of the mean ISI frequency."""
        if self.baseline_frequency == -1:
            spiketimes = self.data.get_base_spikes()
            self.baseline_frequency = self._get_baseline_frequency_given_data(spiketimes)
        return self.baseline_frequency

    def get_vector_strength(self):
        """Return (and cache) the trial-averaged vector strength."""
        if self.vector_strength == -1:
            times = self.data.get_base_traces(self.data.TIME)
            eods = self.data.get_base_traces(self.data.EOD)
            spiketimes = self.data.get_base_spikes()
            sampling_interval = self.data.get_sampling_interval()
            self.vector_strength = self._get_vector_strength_given_data(times, eods, spiketimes, sampling_interval)
        return self.vector_strength

    def get_serial_correlation(self, max_lag):
        """Return the first ``max_lag`` serial-correlation values.

        Recomputes only when the cache holds fewer than ``max_lag`` values.
        """
        if len(self.serial_correlation) < max_lag:
            self.serial_correlation = self._get_serial_correlation_given_data(max_lag, self.data.get_base_spikes())
        return self.serial_correlation[:max_lag]

    def get_coefficient_of_variation(self):
        """Return (and cache) the trial-averaged ISI coefficient of variation."""
        if self.coefficient_of_variation == -1:
            self.coefficient_of_variation = self._get_coefficient_of_variation_given_data(
                self.data.get_base_spikes())
        return self.coefficient_of_variation

    def get_interspike_intervals(self):
        """Return the ISIs of all baseline trials pooled into one list."""
        return self._get_interspike_intervals_given_data(self.data.get_base_spikes())

    def get_spiketime_phases(self):
        """Return the phase of every baseline spike relative to the EOD, pooled over trials.

        Each phase is ``rel_spikes / eod_durs * 2 * pi``, using the values that
        ``hF.eods_around_spikes`` returns for the spike sample indices.
        """
        times = self.data.get_base_traces(self.data.TIME)
        eods = self.data.get_base_traces(self.data.EOD)
        spiketimes = self.data.get_base_spikes()
        sampling_interval = self.data.get_sampling_interval()

        phase_list = []
        for time, eod, st in zip(times, eods, spiketimes):
            # Sample index of each spike: (spike time - trace start) / sampling interval.
            # Spike times share the time frame of `time`, so the start is subtracted.
            spiketime_indices = np.array(np.around((np.array(st) - time[0]) / sampling_interval), dtype=int)
            rel_spikes, eod_durs = hF.eods_around_spikes(time, eod, spiketime_indices)
            phase_list.extend((rel_spikes / eod_durs) * 2 * np.pi)
        return phase_list

    def get_burstiness(self):
        """Return (and cache) the burstiness using the cell's EOD frequency."""
        if self.burstiness == -1:
            self.burstiness = self.__get_burstiness__(self.data.get_eod_frequency())
        return self.burstiness

    def plot_baseline(self, save_path=None, position=0.5, time_length=0.2):
        """Plot a window of the first baseline trial (EOD, V1 with spikes, ISI frequency)."""
        time = self.data.get_base_traces(self.data.TIME)[0]
        eod = self.data.get_base_traces(self.data.EOD)[0]
        v1_trace = self.data.get_base_traces(self.data.V1)[0]
        spiketimes = self.data.get_base_spikes()[0]
        self._plot_baseline_given_data(time, eod, v1_trace, spiketimes, self.data.get_sampling_interval(),
                                       "{:.0f}".format(self.data.get_eod_frequency()), save_path, position,
                                       time_length)


class BaselineModel(Baseline):
    """Baseline statistics of a ``LifacNoiseModel`` driven by ``SinusoidalStepStimulus(eod_frequency, 0)``.

    On construction the model's ``a_zero`` is set to the final value of a short
    adaption run (see ``set_model_adaption_to_baseline``). Then ``trials``
    simulations of length ``simulation_time`` are run and their V1 traces and
    spike times are stored.
    """

    simulation_time = 30

    def __init__(self, model: LifacNoiseModel, eod_frequency, trials=1):
        super().__init__()
        self.model = model
        self.eod_frequency = eod_frequency
        self.set_model_adaption_to_baseline()

        self.stimulus = SinusoidalStepStimulus(eod_frequency, 0)
        self.eod = self.stimulus.as_array(0, self.simulation_time, model.get_sampling_interval())
        self.time = np.arange(0, self.simulation_time, model.get_sampling_interval())

        self.v1_traces = []
        self.spiketimes = []
        for _ in range(trials):
            v, st = model.simulate(self.stimulus, self.simulation_time)
            self.v1_traces.append(v)
            self.spiketimes.append(st)

    def set_model_adaption_to_baseline(self):
        """Simulate for 1 time unit and set the model's ``a_zero`` to the last adaption value."""
        stimulus = SinusoidalStepStimulus(self.eod_frequency, 0, 0, 0)
        self.model.simulate(stimulus, 1)
        adaption = self.model.get_adaption_trace()
        self.model.set_variable("a_zero", adaption[-1])

    def get_baseline_frequency(self):
        """Return (and cache) the median over trials of the mean ISI frequency."""
        if self.baseline_frequency == -1:
            self.baseline_frequency = self._get_baseline_frequency_given_data(self.spiketimes)
        return self.baseline_frequency

    def get_vector_strength(self):
        """Return (and cache) the trial-averaged vector strength (all trials share time/EOD)."""
        if self.vector_strength == -1:
            times = [self.time] * len(self.spiketimes)
            eods = [self.eod] * len(self.spiketimes)
            sampling_interval = self.model.get_sampling_interval()
            self.vector_strength = self._get_vector_strength_given_data(times, eods, self.spiketimes,
                                                                        sampling_interval)
        return self.vector_strength

    def get_serial_correlation(self, max_lag):
        """Return the first ``max_lag`` serial-correlation values.

        Recomputes only when the cache holds fewer than ``max_lag`` values,
        consistent with ``BaselineCellData``.
        """
        if len(self.serial_correlation) < max_lag:
            self.serial_correlation = self._get_serial_correlation_given_data(max_lag, self.spiketimes)
        return self.serial_correlation[:max_lag]

    def get_coefficient_of_variation(self):
        """Return (and cache) the trial-averaged ISI coefficient of variation."""
        if self.coefficient_of_variation == -1:
            self.coefficient_of_variation = self._get_coefficient_of_variation_given_data(self.spiketimes)
        return self.coefficient_of_variation

    def get_interspike_intervals(self):
        """Return the ISIs of all simulated trials pooled into one list."""
        return self._get_interspike_intervals_given_data(self.spiketimes)

    def get_burstiness(self):
        """Return (and cache) the burstiness using the stimulus EOD frequency."""
        if self.burstiness == -1:
            self.burstiness = self.__get_burstiness__(self.eod_frequency)
        return self.burstiness

    def get_spiketime_phases(self):
        """Return the phase of every simulated spike relative to the EOD, pooled over trials.

        Each phase is ``rel_spikes / eod_durs * 2 * pi``, using the values that
        ``hF.eods_around_spikes`` returns for the spike sample indices.
        """
        sampling_interval = self.model.get_sampling_interval()

        phase_list = []
        for st in self.spiketimes:
            # Sample index of each spike: (spike time - trace start) / sampling interval.
            spiketime_indices = np.array(np.around((np.array(st) - self.time[0]) / sampling_interval), dtype=int)
            rel_spikes, eod_durs = hF.eods_around_spikes(self.time, self.eod, spiketime_indices)
            phase_list.extend((rel_spikes / eod_durs) * 2 * np.pi)
        return phase_list

    def plot_baseline(self, save_path=None, position=0.5, time_length=0.2):
        """Plot a window of the first simulated trial (EOD, V1 with spikes, ISI frequency)."""
        self._plot_baseline_given_data(self.time, self.eod, self.v1_traces[0], self.spiketimes[0],
                                       self.model.get_sampling_interval(), "{:.0f}".format(self.eod_frequency),
                                       save_path, position, time_length)


def get_baseline_class(data, eod_freq=None, trials=1) -> Baseline:
    """Return the Baseline subclass instance matching the type of ``data``.

    :param data: a ``CellData`` or a ``LifacNoiseModel``.
    :param eod_freq: EOD frequency; required for ``LifacNoiseModel``.
    :param trials: number of simulated trials for ``LifacNoiseModel``.
    :raises ValueError: if ``eod_freq`` is missing for a model, or ``data`` has an unknown type.
    """
    if isinstance(data, CellData):
        return BaselineCellData(data)
    if isinstance(data, LifacNoiseModel):
        if eod_freq is None:
            raise ValueError("The EOD frequency is needed for the BaselineModel Class.")
        return BaselineModel(data, eod_freq, trials=trials)

    raise ValueError("Unknown type: Cannot find corresponding Baseline class. data was type:" + str(type(data)))