add plot_polar for vector strength

This commit is contained in:
a.ott 2020-05-20 15:19:38 +02:00
parent 676a0e4945
commit 869923a2be

View File

@ -30,6 +30,9 @@ class Baseline:
def get_interspike_intervals(self):
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
def get_spiketime_phases(self):
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
def plot_baseline(self, save_path=None, time_length=0.2):
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
@ -80,16 +83,16 @@ class Baseline:
return isis
@staticmethod
def _plot_baseline_given_data(time, eod, v1, spiketimes, sampling_interval, save_path=None, time_length=0.2):
def _plot_baseline_given_data(time, eod, v1, spiketimes, sampling_interval, eod_freq="", save_path=None, position=0.5, time_length=0.2):
"""
plots the stimulus / eod, together with the v1, spiketimes and frequency
:return:
"""
length_data_points = int(time_length / sampling_interval)
start_idx = int(len(time) * 0.5 - length_data_points * 0.5)
start_idx = int(len(time) * position)
start_idx = start_idx if start_idx >= 0 else 0
end_idx = int(len(time) * 0.5 + length_data_points * 0.5) + 1
end_idx = int(len(time) * position + length_data_points) + 1
end_idx = end_idx if end_idx <= len(time) else len(time)
spiketimes = np.array(spiketimes)
@ -98,7 +101,7 @@ class Baseline:
fig, axes = plt.subplots(3, 1, sharex="col", figsize=(12, 8))
fig.suptitle("Baseline middle part ({:.2f} seconds)".format(time_length))
axes[0].plot(time[start_idx:end_idx], eod[start_idx:end_idx])
axes[0].set_ylabel("Stimulus [mV]")
axes[0].set_ylabel("Stimulus [mV] - Freq:" + eod_freq)
max_v1 = max(v1[start_idx:end_idx])
axes[1].plot(time[start_idx:end_idx], v1[start_idx:end_idx])
@ -118,6 +121,23 @@ class Baseline:
plt.close()
def plot_polar_vector_strength(self, save_path=None):
phases = self.get_spiketime_phases()
fig = plt.figure()
ax = fig.add_subplot(111, polar=True)
# r = np.arange(0, 1, 0.001)
# theta = 2 * 2 * np.pi * r
# line, = ax.plot(theta, r, color='#ee8d18', lw=3)
bins = np.arange(0, np.pi * 2, 0.1)
ax.hist(phases, bins=bins)
if save_path is not None:
plt.savefig(save_path + "isi-histogram.png")
else:
plt.show()
plt.close()
def plot_interspike_interval_histogram(self, save_path=None):
isi = np.array(self.get_interspike_intervals()) * 1000 # change unit to milliseconds
maximum = max(isi)
@ -185,7 +205,23 @@ class BaselineCellData(Baseline):
def get_interspike_intervals(self):
return self._get_interspike_intervals_given_data(self.data.get_base_spikes())
def plot_baseline(self, save_path=None, time_length=0.2):
def get_spiketime_phases(self):
times = self.data.get_base_traces(self.data.TIME)
spiketimes = self.data.get_base_spikes()
eods = self.data.get_base_traces(self.data.EOD)
sampling_interval = self.data.get_sampling_interval()
phase_list = []
for i in range(len(times)):
spiketime_indices = np.array(np.around((np.array(spiketimes[i]) + times[i][0]) / sampling_interval), dtype=int)
rel_spikes, eod_durs = hF.eods_around_spikes(times[i], eods[i], spiketime_indices)
phase_times = (rel_spikes / eod_durs) * 2 * np.pi
phase_list.extend(phase_times)
return phase_list
def plot_baseline(self, save_path=None, position=0.5, time_length=0.2):
# eod, v1, spiketimes, frequency
time = self.data.get_base_traces(self.data.TIME)[0]
@ -194,7 +230,7 @@ class BaselineCellData(Baseline):
spiketimes = self.data.get_base_spikes()[0]
self._plot_baseline_given_data(time, eod, v1_trace, spiketimes,
self.data.get_sampling_interval(), save_path, time_length)
self.data.get_sampling_interval(), "{:.0f}".format(self.data.get_eod_frequency()), save_path, position, time_length)
class BaselineModel(Baseline):
@ -243,9 +279,23 @@ class BaselineModel(Baseline):
def get_interspike_intervals(self):
return self._get_interspike_intervals_given_data(self.spiketimes)
def plot_baseline(self, save_path=None, time_length=0.2):
def get_spiketime_phases(self):
sampling_interval = self.model.get_sampling_interval()
phase_list = []
for i in range(len(self.spiketimes)):
spiketime_indices = np.array(np.around((np.array(self.spiketimes[i]) + self.time[0]) / sampling_interval), dtype=int)
rel_spikes, eod_durs = hF.eods_around_spikes(self.time, self.eod, spiketime_indices)
phase_times = (rel_spikes / eod_durs) * 2 * np.pi
phase_list.extend(phase_times)
return phase_list
def plot_baseline(self, save_path=None, position=0.5, time_length=0.2):
self._plot_baseline_given_data(self.time, self.eod, self.v1_traces[0], self.spiketimes[0],
self.model.get_sampling_interval(), save_path, time_length)
self.model.get_sampling_interval(), "{:.0f}".format(self.eod_frequency),
save_path, position, time_length)
def get_baseline_class(data, eod_freq=None) -> Baseline: