spikeRedetector/redetector.py
2021-03-07 18:33:21 +01:00

72 lines
2.2 KiB
Python

import numpy as np
from thunderfish.eventdetection import detect_peaks
def detect_spiketimes(time: np.ndarray, v1, threshold=2.0, min_length=5000, split_step=1000):
    """Return the time stamps of the spikes detected in ``v1``.

    Spike indices are obtained from ``detect_spike_indices_automatic_split``
    and then used to index ``time``.

    Parameters
    ----------
    time : np.ndarray
        Time axis indexed with the detected spike indices; presumably aligned
        sample-by-sample with ``v1`` (confirm with callers).
    v1 : array_like
        Signal in which spikes are detected.
    threshold, min_length, split_step
        Forwarded unchanged to ``detect_spike_indices_automatic_split``.

    Returns
    -------
    np.ndarray
        The entries of ``time`` at the detected spike indices.
    """
    spike_indices = detect_spike_indices_automatic_split(
        v1, threshold, min_length=min_length, split_step=split_step
    )
    return time[spike_indices]
def detect_spike_indices_automatic_split(v1, threshold, min_length=5000, split_step=1000,
                                         break_threshold=0.25):
    """Detect peak indices in ``v1`` with a per-segment adaptive threshold.

    The signal is first cut into segments wherever the amplitude range changes
    noticeably (see ``_find_splits``). Within each segment, peaks are detected
    with thunderfish's ``detect_peaks`` using ``threshold`` times that
    segment's standard deviation. Each segment is therefore thresholded on its
    own amplitude scale.

    Parameters
    ----------
    v1 : array_like
        1-D signal to search for peaks.
    threshold : float
        Multiplier applied to each segment's standard deviation to obtain the
        ``detect_peaks`` threshold.
    min_length : int
        Number of samples used to establish a segment's reference max/min.
        It is also the distance from a segment start to its first split check.
    split_step : int
        Size in samples of the windows compared against the reference max/min.
    break_threshold : float
        Relative change of a window's max or min (with respect to the
        reference) at or above which a new segment is started.

    Returns
    -------
    list
        Indices into ``v1`` of the detected peaks, concatenated segment by
        segment. Empty for an empty ``v1``.
    """
    v1 = np.asarray(v1)
    all_peaks = []
    for first_index, last_index in _find_splits(v1, min_length, split_step, break_threshold):
        segment = v1[first_index:last_index]
        std = np.std(segment)
        if std == 0:
            # A flat segment has no peaks, and detect_peaks needs a positive
            # threshold, so skip it rather than passing a threshold of 0.
            continue
        peaks, _ = detect_peaks(segment, std * threshold)
        # detect_peaks reports indices relative to the segment; shift them
        # back into v1 coordinates.
        all_peaks.extend(peaks + first_index)
    return all_peaks


def _find_splits(v1, min_length, split_step, break_threshold, search_radius=20):
    """Return ``(start, stop)`` index pairs that partition ``v1`` into segments.

    A segment's reference max/min comes from its first ``min_length`` samples.
    Following windows of ``split_step`` samples are compared to that reference.
    When a window's max or min deviates by ``break_threshold`` or more
    (relatively), the segment is closed at the lowest sample within
    ``search_radius`` of the window start, and a new segment begins there.
    """
    n = len(v1)
    if n == 0:
        return []
    if n <= min_length:
        return [(0, n)]

    splits = []
    split_start = 0
    ref_max = np.max(v1[:min_length])
    ref_min = np.min(v1[:min_length])
    idx = min_length
    while idx + split_step <= n:
        window = v1[idx:idx + split_step]
        if (_is_similar(np.max(window), ref_max, break_threshold)
                and _is_similar(np.min(window), ref_min, break_threshold)):
            idx += split_step
            continue
        # Cut at the lowest sample near idx so the boundary falls between
        # spikes rather than through one. The search window never reaches
        # back to split_start, so every segment is non-empty and split_start
        # strictly increases. That guarantees the loop terminates.
        lo = max(idx - search_radius, split_start + 1)
        cut = lo + int(np.argmin(v1[lo:idx + search_radius + 1]))
        splits.append((split_start, cut))
        split_start = cut
        ref_max = np.max(v1[split_start:split_start + min_length])
        ref_min = np.min(v1[split_start:split_start + min_length])
        idx = split_start + min_length
    # The final segment always runs to the end of the signal.
    splits.append((split_start, n))
    return splits


def _is_similar(value, reference, tolerance):
    """Return True if ``value`` differs from ``reference`` by less than ``tolerance`` relatively."""
    # Compare as floats so integer-typed signals (e.g. int16) cannot overflow
    # in the subtraction.
    value = float(value)
    reference = float(reference)
    if reference == 0:
        # Relative change is undefined for a zero reference; only an equal
        # (zero) value is treated as similar.
        return value == 0
    return abs((value - reference) / reference) < tolerance