72 lines
2.2 KiB
Python
72 lines
2.2 KiB
Python
|
|
import numpy as np
|
|
|
|
from thunderfish.eventdetection import detect_peaks
|
|
|
|
|
|
def detect_spiketimes(time: np.ndarray, v1, threshold=2.0, min_length=5000, split_step=1000):
    """Return the time stamps of the spikes detected in ``v1``.

    Spike indices come from ``detect_spike_indices_automatic_split`` (which
    splits ``v1`` into amplitude-homogeneous segments and detects peaks per
    segment); those indices are then used to look up the matching entries of
    ``time``.

    Parameters
    ----------
    time : np.ndarray
        Time axis aligned sample-by-sample with ``v1``.
    v1 : array_like
        Signal to search for spikes.
    threshold : float
        Peak threshold in units of each segment's standard deviation.
    min_length : int
        Minimum segment length in samples.
    split_step : int
        Window/step size, in samples, used when scanning for amplitude changes.

    Returns
    -------
    np.ndarray
        ``time`` indexed at the detected spike positions.
    """
    spike_positions = detect_spike_indices_automatic_split(
        v1,
        threshold,
        min_length,
        split_step,
    )
    spike_times = time[spike_positions]
    return spike_times
|
|
|
|
|
|
def detect_spike_indices_automatic_split(v1, threshold, min_length=5000, split_step=1000,
                                         break_threshold=0.25, search_radius=20):
    """Detect peak indices in ``v1``, adapting the threshold to amplitude changes.

    The signal is partitioned into segments wherever the maximum or minimum of
    a ``split_step``-sample window differs (relatively) by ``break_threshold``
    or more from the reference extrema of the current segment. Peaks are then
    detected independently in each segment with ``detect_peaks`` using a
    threshold of ``threshold`` times that segment's standard deviation.

    Parameters
    ----------
    v1 : array_like
        1-D signal to search for peaks.
    threshold : float
        Peak-detection threshold in units of the segment's standard deviation.
    min_length : int, optional
        Minimum segment length in samples; also the length of the window used
        to (re)compute a segment's reference maximum and minimum. Must be >= 1.
    split_step : int, optional
        Size in samples of the window (and of the scan step) used to look for
        amplitude changes. Must be >= 1.
    break_threshold : float, optional
        Relative change of a window's max or min against the segment's
        reference extrema at which a new segment is started.
    search_radius : int, optional
        Half-width of the neighbourhood around a detected change in which the
        split point is moved to the signal's local minimum.

    Returns
    -------
    list
        Indices into ``v1`` of the detected peaks, segment by segment, in the
        order ``detect_peaks`` reports them within each segment.

    Raises
    ------
    ValueError
        If ``min_length`` or ``split_step`` is smaller than 1.
    """
    if min_length < 1:
        raise ValueError("min_length must be at least 1")
    if split_step < 1:
        # A non-positive step would never advance the scan (endless loop).
        raise ValueError("split_step must be at least 1")

    v1 = np.asarray(v1)
    splits = _find_amplitude_splits(v1, min_length, split_step,
                                    break_threshold, search_radius)

    all_peaks = []
    for first_index, last_index in splits:
        segment = v1[first_index:last_index]
        std = np.std(segment) if len(segment) else 0.0
        # An empty or flat segment has no peaks; skipping it also avoids
        # handing detect_peaks a zero (or nan) threshold.
        if std <= 0:
            continue
        peaks, _ = detect_peaks(segment, std * threshold)
        # detect_peaks works on the slice, so shift back to indices into v1.
        all_peaks.extend(peaks + first_index)
    return all_peaks


def _relative_change(new, reference):
    """Return ``|(new - reference) / reference|``, defined for a zero reference.

    A zero reference yields 0.0 if ``new`` is also zero (no change) and
    infinity otherwise, instead of dividing by zero.
    """
    if reference == 0:
        return 0.0 if new == 0 else np.inf
    return abs((new - reference) / reference)


def _find_amplitude_splits(v1, min_length, split_step, break_threshold, search_radius):
    """Partition ``v1`` into ``(start, stop)`` index pairs at amplitude changes."""
    n_samples = len(v1)
    if n_samples <= min_length:
        return [(0, n_samples)]

    splits = []
    split_start = 0
    last_max = v1[:min_length].max()
    last_min = v1[:min_length].min()
    idx = min_length

    while idx <= n_samples:
        if idx + split_step > n_samples:
            # Not enough samples left for a full window: close the last segment.
            splits.append((split_start, n_samples))
            break

        window = v1[idx:idx + split_step]
        max_similar = _relative_change(window.max(), last_max) < break_threshold
        min_similar = _relative_change(window.min(), last_min) < break_threshold
        if max_similar and min_similar:
            idx += split_step
            continue

        # Place the split at the signal minimum near idx (presumably so the
        # cut does not land on a spike). The search never reaches back to
        # split_start, so every split moves strictly forward and the loop
        # terminates even when min_length <= search_radius.
        search_start = max(idx - search_radius, split_start + 1)
        search_stop = idx + search_radius + 1
        split_point = search_start + int(np.argmin(v1[search_start:search_stop]))

        splits.append((split_start, split_point))
        split_start = split_point
        reference = v1[split_start:split_start + min_length]
        last_max = reference.max()
        last_min = reference.min()
        idx = split_start + min_length

    # After a split close to the end, idx can overshoot n_samples and the loop
    # exits without closing the final segment.
    if splits[-1][1] != n_samples:
        splits.append((split_start, n_samples))
    return splits
|