boot implementations using cupy

This commit is contained in:
Till Raab 2023-05-26 10:06:02 +02:00
parent d8d28df2ea
commit 6d7e58ef80

View File

@ -2,11 +2,16 @@ import os
import sys import sys
import argparse import argparse
import numpy as np import numpy as np
try:
import cupy as cp import cupy as cp
except ImportError:
import numpy as cp
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec import matplotlib.gridspec as gridspec
import pandas as pd import pandas as pd
from IPython import embed from IPython import embed
from tqdm import tqdm
def load_and_converete_boris_events(trial_path, recording, sr): def load_and_converete_boris_events(trial_path, recording, sr):
@ -86,30 +91,76 @@ def kde(event_dt, max_dt = 60):
plt.plot(conv_t, conv_array) plt.plot(conv_t, conv_array)
def permulation_kde(event_dt, repetitions = 100, max_dt = 60): def permulation_kde(select_event_dt, repetitions = 2000, max_dt = 60, max_mem_use_GB = 1):
def chunk_permutation(select_event_dt, conv_tt, n_chuck, max_jitter, kernal_w, kernal_h):
# array.shape = (120, 100, 15486) = (len(conv_t), repetitions, len(event_dt))
# event_dt_perm = cp.tile(event_dt, (len(conv_t), repetitions, 1))
event_dt_perm = cp.tile(select_event_dt, (len(conv_tt), n_chuck, 1))
jitter = np.random.uniform(-max_jitter, max_jitter, size=(event_dt_perm.shape[1], event_dt_perm.shape[2]))
jitter = np.expand_dims(jitter, axis=0)
event_dt_perm += jitter
# conv_t_perm = cp.tile(conv_tt, (1, repetitions, len(event_dt)))
gauss_3d = cp.exp(-((conv_tt - event_dt_perm) / kernal_w) ** 2 / 2) * kernal_h
kde_3d = cp.sum(gauss_3d, axis = 2).transpose()
try:
kde_3d_numpy = cp.asnumpy(kde_3d)
del event_dt_perm, gauss_3d, kde_3d
return kde_3d_numpy
except AttributeError:
del event_dt_perm, gauss_3d
return kde_3d
embed() embed()
quit() quit()
kernal_w = 1 kernal_w = 1
kernal_h = 0.2 kernal_h = 0.2
max_jitter = 30
select_event_dt = event_dt[np.abs(event_dt) <= max_dt + max_jitter*2]
conv_t = cp.arange(-max_dt, max_dt, 1) conv_t = cp.arange(-max_dt, max_dt, 1)
conv_tt = cp.reshape(conv_t, (len(conv_t), 1, 1)) conv_tt = cp.reshape(conv_t, (len(conv_t), 1, 1))
# array.shape = (120, 100, 15486) = (len(conv_t), repetitions, len(event_dt)) chunk_size = int(np.floor(max_mem_use_GB / (select_event_dt.nbytes * conv_t.size / 1e9)))
event_dt_perm = cp.tile(event_dt, (len(conv_t), repetitions, 1)) chunk_collector =[]
# conv_t_perm = cp.tile(conv_tt, (1, repetitions, len(event_dt)))
# for _ in range(repetitions // chunk_size):
for _ in range(3):
chunk_boot_KDE = chunk_permutation(select_event_dt, conv_tt, chunk_size, max_jitter, kernal_w, kernal_h)
chunk_collector.extend(chunk_boot_KDE)
# # array.shape = (120, 100, 15486) = (len(conv_t), repetitions, len(event_dt))
# # event_dt_perm = cp.tile(event_dt, (len(conv_t), repetitions, 1))
# event_dt_perm = cp.tile(event_dt, (len(conv_t), chunk_size, 1))
# jitter = np.random.uniform(-max_jitter, max_jitter, size=(event_dt_perm.shape[1], event_dt_perm.shape[2]))
# jitter = np.expand_dims(jitter, axis=0)
#
# event_dt_perm += jitter
# # conv_t_perm = cp.tile(conv_tt, (1, repetitions, len(event_dt)))
#
# gauss_3d = cp.exp(-((conv_tt - event_dt_perm) / kernal_w) ** 2 / 2) * kernal_h
# kde_3d = cp.sum(gauss_3d, axis = 2).transpose()
# try:
# kde_3d_numpy = cp.asnumpy(kde_3d)
# chunk_collector.extend(kde_3d_numpy)
# except AttributeError:
# chunk_collector.extend(kde_3d)
# del event_dt_perm, gauss_3d, kde_3d
chunk_boot_KDE = chunk_permutation(select_event_dt, conv_tt, repetitions % chunk_size, max_jitter, kernal_w, kernal_h)
chunk_collector.extend(chunk_boot_KDE)
gauss_3d = cp.exp(-((conv_tt - event_dt_perm) / kernal_w) ** 2 / 2) * kernal_h
kde_3d = cp.sum(gauss_3d, axis = 2).transpose()
kde_3d_numpy = cp.asnumpy(kde_3d)
def main(base_path): def main(base_path):
trial_summary = pd.read_csv('trial_summary.csv', index_col=0) trial_summary = pd.read_csv('trial_summary.csv', index_col=0)
lose_chrips_centered_on_ag_off_t = [] lose_chrips_centered_on_ag_off_t = []
for index, trial in trial_summary.iterrows(): for index, trial in tqdm(trial_summary.iterrows()):
trial_path = os.path.join(base_path, trial['recording']) trial_path = os.path.join(base_path, trial['recording'])
if trial['group'] < 5: if trial['group'] < 5:
@ -128,12 +179,12 @@ def main(base_path):
load_and_converete_boris_events(trial_path, trial['recording'], sr=20_000) load_and_converete_boris_events(trial_path, trial['recording'], sr=20_000)
### communication ### communication
got_chirps = False if not os.path.exists(os.path.join(trial_path, 'chirp_times_cnn.npy')):
if os.path.exists(os.path.join(trial_path, 'chirp_times_cnn.npy')): continue
chirp_t = np.load(os.path.join(trial_path, 'chirp_times_cnn.npy')) chirp_t = np.load(os.path.join(trial_path, 'chirp_times_cnn.npy'))
chirp_ids = np.load(os.path.join(trial_path, 'chirp_ids_cnn.npy')) chirp_ids = np.load(os.path.join(trial_path, 'chirp_ids_cnn.npy'))
chirp_times = [chirp_t[chirp_ids == trial['win_ID']], chirp_t[chirp_ids == trial['lose_ID']]] chirp_times = [chirp_t[chirp_ids == trial['win_ID']], chirp_t[chirp_ids == trial['lose_ID']]]
got_chirps = True
rise_idx = np.load(os.path.join(trial_path, 'analysis', 'rise_idx.npy'))[::sorter] rise_idx = np.load(os.path.join(trial_path, 'analysis', 'rise_idx.npy'))[::sorter]
rise_idx_int = [np.array(rise_idx[i][~np.isnan(rise_idx[i])], dtype=int) for i in range(len(rise_idx))] rise_idx_int = [np.array(rise_idx[i][~np.isnan(rise_idx[i])], dtype=int) for i in range(len(rise_idx))]