diff --git a/Baseline.py b/Baseline.py index b13a089..8df19d2 100644 --- a/Baseline.py +++ b/Baseline.py @@ -27,58 +27,15 @@ class Baseline: def get_coefficient_of_variation(self): raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS") - def get_burstiness(self, test_plot=False): - isis = np.array(self.get_interspike_intervals()) * 1000 # change unit to ms - - if len(isis) <= 10: - return 0 - - step = 0.1 - bins = np.arange(0, max(isis), step) - num_spikes_per_bin = np.zeros(bins.shape) - - for i, bin in enumerate(bins): - num_of_spikes = len((isis[(isis >= bin) & (isis < bin + step)])) - num_spikes_per_bin[i] = num_of_spikes - - max_found = -1 - end_of_peak = -1 - - if max(num_spikes_per_bin) < 10: - return 0 - - for i, num in enumerate(num_spikes_per_bin): - if i + 1 >= len(num_spikes_per_bin): - return 0 - if max_found == -1: - if num_spikes_per_bin[i+1] > num: - continue - elif num > 10: - max_found = i - else: - - if num_spikes_per_bin[i + 1] > num: - end_of_peak = i + 1 - break - - burstiness = sum(num_spikes_per_bin[:end_of_peak]) / len(isis) - - if test_plot: - print("burst peak:", sum(num_spikes_per_bin[:end_of_peak])) - print("sum num per bin:", sum(num_spikes_per_bin)) - print("len isis:", len(isis)) - bins = np.arange(0, max(isis) * 1.01, 0.1) - - plt.title('Baseline ISIs - burstiness {:.2f}'.format(burstiness)) - plt.xlabel('ISI in ms') - plt.ylabel('Count') - plt.hist(isis, bins=bins) - - plt.plot([step*(i+0.5) for i in range(len(num_spikes_per_bin))], num_spikes_per_bin, 'o', alpha=0.5) - plt.plot((0.5 * step, bins[end_of_peak - 1] + 0.5 * step,), (0, 0), 'o') - plt.show() + def get_burstiness(self): + raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS") + + def __get_burstiness__(self, eod_freq): + isis = np.array(self.get_interspike_intervals()) - return burstiness + bursts = isis[isis < 1.5 * (1.0/eod_freq)] + + return len(bursts) / float(len(isis)) def get_interspike_intervals(self): raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS") @@ -274,6 +231,9 @@ class BaselineCellData(Baseline): return phase_list + def get_burstiness(self): + return self.__get_burstiness__(self.data.get_eod_frequency()) + def plot_baseline(self, save_path=None, position=0.5, time_length=0.2): # eod, v1, spiketimes, frequency @@ -332,6 +292,9 @@ class BaselineModel(Baseline): def get_interspike_intervals(self): return self._get_interspike_intervals_given_data(self.spiketimes) + def get_burstiness(self): + return self.__get_burstiness__(self.eod_frequency) + def get_spiketime_phases(self): sampling_interval = self.model.get_sampling_interval()