add plot_polar for vector strength
This commit is contained in:
parent
676a0e4945
commit
869923a2be
66
Baseline.py
66
Baseline.py
@ -30,6 +30,9 @@ class Baseline:
|
||||
def get_interspike_intervals(self):
|
||||
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
|
||||
|
||||
def get_spiketime_phases(self):
|
||||
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
|
||||
|
||||
def plot_baseline(self, save_path=None, time_length=0.2):
|
||||
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
|
||||
|
||||
@ -80,16 +83,16 @@ class Baseline:
|
||||
return isis
|
||||
|
||||
@staticmethod
|
||||
def _plot_baseline_given_data(time, eod, v1, spiketimes, sampling_interval, save_path=None, time_length=0.2):
|
||||
def _plot_baseline_given_data(time, eod, v1, spiketimes, sampling_interval, eod_freq="", save_path=None, position=0.5, time_length=0.2):
|
||||
"""
|
||||
plots the stimulus / eod, together with the v1, spiketimes and frequency
|
||||
:return:
|
||||
"""
|
||||
length_data_points = int(time_length / sampling_interval)
|
||||
|
||||
start_idx = int(len(time) * 0.5 - length_data_points * 0.5)
|
||||
start_idx = int(len(time) * position)
|
||||
start_idx = start_idx if start_idx >= 0 else 0
|
||||
end_idx = int(len(time) * 0.5 + length_data_points * 0.5) + 1
|
||||
end_idx = int(len(time) * position + length_data_points) + 1
|
||||
end_idx = end_idx if end_idx <= len(time) else len(time)
|
||||
|
||||
spiketimes = np.array(spiketimes)
|
||||
@ -98,7 +101,7 @@ class Baseline:
|
||||
fig, axes = plt.subplots(3, 1, sharex="col", figsize=(12, 8))
|
||||
fig.suptitle("Baseline middle part ({:.2f} seconds)".format(time_length))
|
||||
axes[0].plot(time[start_idx:end_idx], eod[start_idx:end_idx])
|
||||
axes[0].set_ylabel("Stimulus [mV]")
|
||||
axes[0].set_ylabel("Stimulus [mV] - Freq:" + eod_freq)
|
||||
|
||||
max_v1 = max(v1[start_idx:end_idx])
|
||||
axes[1].plot(time[start_idx:end_idx], v1[start_idx:end_idx])
|
||||
@ -118,6 +121,23 @@ class Baseline:
|
||||
|
||||
plt.close()
|
||||
|
||||
def plot_polar_vector_strength(self, save_path=None):
|
||||
phases = self.get_spiketime_phases()
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111, polar=True)
|
||||
# r = np.arange(0, 1, 0.001)
|
||||
# theta = 2 * 2 * np.pi * r
|
||||
# line, = ax.plot(theta, r, color='#ee8d18', lw=3)
|
||||
bins = np.arange(0, np.pi * 2, 0.1)
|
||||
ax.hist(phases, bins=bins)
|
||||
|
||||
if save_path is not None:
|
||||
plt.savefig(save_path + "isi-histogram.png")
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
plt.close()
|
||||
|
||||
def plot_interspike_interval_histogram(self, save_path=None):
|
||||
isi = np.array(self.get_interspike_intervals()) * 1000 # change unit to milliseconds
|
||||
maximum = max(isi)
|
||||
@ -185,7 +205,23 @@ class BaselineCellData(Baseline):
|
||||
def get_interspike_intervals(self):
|
||||
return self._get_interspike_intervals_given_data(self.data.get_base_spikes())
|
||||
|
||||
def plot_baseline(self, save_path=None, time_length=0.2):
|
||||
def get_spiketime_phases(self):
|
||||
times = self.data.get_base_traces(self.data.TIME)
|
||||
spiketimes = self.data.get_base_spikes()
|
||||
eods = self.data.get_base_traces(self.data.EOD)
|
||||
sampling_interval = self.data.get_sampling_interval()
|
||||
|
||||
phase_list = []
|
||||
for i in range(len(times)):
|
||||
spiketime_indices = np.array(np.around((np.array(spiketimes[i]) + times[i][0]) / sampling_interval), dtype=int)
|
||||
rel_spikes, eod_durs = hF.eods_around_spikes(times[i], eods[i], spiketime_indices)
|
||||
|
||||
phase_times = (rel_spikes / eod_durs) * 2 * np.pi
|
||||
phase_list.extend(phase_times)
|
||||
|
||||
return phase_list
|
||||
|
||||
def plot_baseline(self, save_path=None, position=0.5, time_length=0.2):
|
||||
# eod, v1, spiketimes, frequency
|
||||
|
||||
time = self.data.get_base_traces(self.data.TIME)[0]
|
||||
@ -194,7 +230,7 @@ class BaselineCellData(Baseline):
|
||||
spiketimes = self.data.get_base_spikes()[0]
|
||||
|
||||
self._plot_baseline_given_data(time, eod, v1_trace, spiketimes,
|
||||
self.data.get_sampling_interval(), save_path, time_length)
|
||||
self.data.get_sampling_interval(), "{:.0f}".format(self.data.get_eod_frequency()), save_path, position, time_length)
|
||||
|
||||
|
||||
class BaselineModel(Baseline):
|
||||
@ -243,9 +279,23 @@ class BaselineModel(Baseline):
|
||||
def get_interspike_intervals(self):
|
||||
return self._get_interspike_intervals_given_data(self.spiketimes)
|
||||
|
||||
def plot_baseline(self, save_path=None, time_length=0.2):
|
||||
def get_spiketime_phases(self):
|
||||
sampling_interval = self.model.get_sampling_interval()
|
||||
|
||||
phase_list = []
|
||||
for i in range(len(self.spiketimes)):
|
||||
spiketime_indices = np.array(np.around((np.array(self.spiketimes[i]) + self.time[0]) / sampling_interval), dtype=int)
|
||||
rel_spikes, eod_durs = hF.eods_around_spikes(self.time, self.eod, spiketime_indices)
|
||||
|
||||
phase_times = (rel_spikes / eod_durs) * 2 * np.pi
|
||||
phase_list.extend(phase_times)
|
||||
|
||||
return phase_list
|
||||
|
||||
def plot_baseline(self, save_path=None, position=0.5, time_length=0.2):
|
||||
self._plot_baseline_given_data(self.time, self.eod, self.v1_traces[0], self.spiketimes[0],
|
||||
self.model.get_sampling_interval(), save_path, time_length)
|
||||
self.model.get_sampling_interval(), "{:.0f}".format(self.eod_frequency),
|
||||
save_path, position, time_length)
|
||||
|
||||
|
||||
def get_baseline_class(data, eod_freq=None) -> Baseline:
|
||||
|
Loading…
Reference in New Issue
Block a user