highbeats_pdf/myfunctions.py
2020-12-01 12:00:50 +01:00

766 lines
34 KiB
Python

import numpy as np
import matplotlib.pyplot as plt
from IPython import embed
from scipy import signal
import time
from scipy.stats import alpha
from scipy.stats import norm
from random import random
import os
import pandas as pd
def load_cell(set, fname='singlecellexample5', big_file='data_beat', new=False):
    """Return the rows of one dataset, caching them in a small pickle.

    Parameters
    ----------
    set : object
        Value matched against the ``'dataset'`` column of the big table
        (name kept for call compatibility although it shadows the builtin).
    fname : str
        Path of the per-cell cache file, without the ``.pkl`` suffix.
    big_file : str
        Path of the full table pickle, without the ``.pkl`` suffix.
    new : bool
        When ``== True`` the cache is rebuilt even if it already exists.

    Returns
    -------
    pandas.DataFrame
        The freshly filtered rows, or the content of the existing cache.
        An existing cache is returned as-is; ``set`` is not re-checked against it.

    NOTE: pickle files can execute code when loaded; only use trusted files.
    """
    cache_path = fname + '.pkl'
    rebuild = (new == True) or not os.path.exists(cache_path)
    if not rebuild:
        return pd.read_pickle(cache_path)
    everything = pd.read_pickle(big_file + '.pkl')
    subset = everything[everything['dataset'] == set]
    pd.DataFrame(subset).to_pickle(cache_path)
    return subset
def default_settings(data, lw=1, intermediate_width=2.6 * 3, intermediate_length=3, ts=10, ls=10, fs=10):
    """Apply this module's default matplotlib rcParams.

    ``data`` is accepted for call compatibility but not used.  ``lw`` sets the
    line width, ``fs`` the base font size, ``ts`` the axes-title size and
    ``ls`` the axis-label size.  The figure size is
    ``(intermediate_width, intermediate_length)`` in matplotlib's units (inches).
    Legends go to the upper right, without a frame.  Markers have size 8.
    """
    plt.rcParams.update({
        'figure.figsize': (intermediate_width, intermediate_length),
        'font.size': fs,
        'axes.titlesize': ts,
        'axes.labelsize': ls,
        'lines.linewidth': lw,
        'lines.markersize': 8,
        'legend.loc': 'upper right',
        'legend.frameon': False,
    })
def remove_tick_marks(ax):
    """Blank every x tick label of ``ax`` (tick positions are kept) and return ``ax``."""
    n_labels = len(ax.get_xticklabels())
    ax.set_xticklabels([''] * n_labels)
    return ax
def remove_tick_ymarks(ax):
    """Blank every y tick label of ``ax`` (tick positions are kept) and return ``ax``."""
    n_labels = len(ax.get_yticklabels())
    ax.set_yticklabels([''] * n_labels)
    return ax
def euc_d(conv_beat, conv_chirp, a):
    """Euclidean distance between two signals along axis ``a``.

    NaN entries of the squared difference are skipped (``np.nansum``);
    ``a=None`` sums over all elements.
    """
    squared_diff = np.square(conv_beat - conv_chirp)
    return np.sqrt(np.nansum(squared_diff, axis=a))
def mean_euc_d(conv_beat, conv_chirp, a):
    """Root-mean-square difference along axis ``a``; NaNs are left out of the mean."""
    squared_diff = np.square(conv_beat - conv_chirp)
    return np.sqrt(np.nanmean(squared_diff, axis=a))
def ten_per_to_std(width):
    """Convert a width measured at 10 % of a Gaussian's peak into its standard deviation.

    The divisor 4.29 matches 2 * sqrt(2 * ln 10) ~= 4.292, the ratio between a
    Gaussian's full width at 10 % of its maximum and sigma.
    """
    width_per_sigma = 4.29
    return width / width_per_sigma
def gauss(mult, sampling, width, x, conv_fact=100):
    """Build a normalised kernel whose size is given by ``width`` (in seconds).

    ``width`` is converted to a standard deviation with ``ten_per_to_std``.
    It is multiplied by ``sampling`` and truncated to whole samples.  The result
    goes to ``just_gauss`` (default ``func``) together with ``conv_fact``,
    ``mult`` and ``x``.  The kernel returned by ``just_gauss`` is returned unchanged.
    """
    std_in_samples = int(sampling * ten_per_to_std(width))
    return just_gauss(conv_fact, mult, std_in_samples, x)
def just_gauss(conv_fact, mult, deviation, x, func='alpha', test=False):
    """Return a kernel of the requested shape whose samples sum to ``x``.

    Parameters
    ----------
    conv_fact, mult : number
        Forwarded to ``points_cut``, which sets the kernel length from
        ``deviation`` and ``mult``.
    deviation : number
        Width parameter in samples.  It is the std for ``'gaussian'`` and the
        decay constant for ``'exp'`` / ``'exp_self'``.  For ``'alpha'`` it is the
        scale of ``t * exp(-2.45 * t / deviation)``, which peaks at
        ``t = deviation / 2.45``.
    x : number
        Target sum: the kernel is divided by its own sum and multiplied by ``x``.
    func : {'alpha', 'gaussian', 'exp', 'exp_self'}
        Kernel shape.
    test : bool
        Only with ``func='exp'``: additionally run the diagnostic plots.

    Raises
    ------
    ValueError
        For an unknown ``func``.  Previously this surfaced as UnboundLocalError.
    """
    points, _ = points_cut(deviation, mult, conv_fact)
    if func == 'gaussian':
        # scipy.signal.gaussian / scipy.signal.exponential were removed in
        # SciPy 1.13; the window functions live in scipy.signal.windows.
        kernel = signal.windows.gaussian(points, std=deviation, sym=True)
    elif func == 'exp':
        # one-sided decay starting at sample 0 (center=0, sym=False)
        kernel = signal.windows.exponential(points, 0, deviation, False)
        if test == True:
            check_alpha()
            check_exponentials()
    elif func == 'exp_self':
        t = np.arange(0, points, 1)
        kernel = np.exp(-t / deviation)
    elif func == 'alpha':
        t = np.arange(0, points, 1)
        kernel = t * np.exp(-2.45 * t / deviation)
    else:
        raise ValueError(
            "func must be 'alpha', 'gaussian', 'exp' or 'exp_self', got %r" % (func,))
    return (kernel * x) / np.sum(kernel)
def points_cut(deviation, mult, conv_fact):
    """Return the kernel length in samples and half of it divided by ``conv_fact``.

    ``points`` is ``deviation * mult`` when ``deviation * 10`` leaves a non-zero
    remainder modulo 2, otherwise ``deviation * mult - 1``.  For an integer
    ``deviation`` the first case never happens.
    ``to_cut`` is ``int(points / 2) / conv_fact``.
    """
    has_remainder = deviation * 10 % 2
    points = deviation * mult if has_remainder else deviation * mult - 1
    to_cut = int(points / 2) / conv_fact
    return points, to_cut
def check_exponentials(mult=10, devs=None):
    """Diagnostic plot of one-sided exponential kernels for several decay constants.

    For each ``dev`` in ``devs``, a window of length ``points_cut(dev, mult, 1)[0]``
    is plotted.  The window is ``scipy.signal.windows.exponential`` with center 0,
    tau ``dev`` and ``sym=False``, i.e. the same length rule ``just_gauss`` uses.
    All curves are shown in one figure.

    The original loop referenced the undefined names ``deviation`` and ``mult``
    (NameError on the first iteration).  It also called the removed
    ``scipy.signal.exponential``.  ``mult`` now defaults to 10, the value
    ``post_window`` passes to ``points_cut``.  ``devs`` defaults to the original
    list of decay constants.
    """
    if devs is None:
        devs = [1, 5, 25, 50, 75, 100, 150, 250, 350, 500, 700]
    for dev in devs:
        points, _ = points_cut(dev, mult, 1)
        plt.plot(signal.windows.exponential(points, 0, dev, False))
    plt.show()
def check_alpha(points=1000):
    """Diagnostic plot of ``scipy.stats.alpha`` densities for several shape values.

    For each shape ``a`` in (2, 3, 4, 5, 6, 7, 10):
      - the pdf on 1000 points between the 1 % and 99 % quantiles is plotted
        against ``sample index / 100`` (labelled 'ms');
      - the 90 % quantile is printed;
      - a marker is drawn.

    The frozen pdf of the last shape, on ``np.linspace(0, 10, points)``, is
    overlaid before the figure is shown.

    The original tail referenced the undefined names ``points`` and ``deviation``
    (NameError).  It also computed a Gaussian window that was never used.
    ``points`` now defaults to 1000, the resolution of the curves above, and the
    unused window is gone.

    NOTE(review): the marker is drawn at (pdf(0.90), ppf(0.90)).  A commented-out
    variant had the coordinates the other way round; confirm which is intended.
    """
    conv_fact = 100
    for a in [2, 3, 4, 5, 6, 7, 10]:
        x = np.linspace(alpha.ppf(0.01, a), alpha.ppf(0.99, a), 1000)
        density = alpha.pdf(x, a)
        print(alpha.ppf(0.90, a))
        plt.scatter(alpha.pdf(0.90, a), alpha.ppf(0.90, a))
        plt.plot(np.arange(0, 1000, 1) / conv_fact, density, label=str(a))
    plt.xlabel('ms')
    frozen = alpha(a)
    x = np.linspace(0, 10, points)
    plt.plot(x, frozen.pdf(x), 'k-', lw=2, label='frozen pdf')
    plt.legend()
    plt.show()
def shift_gauss(step_ms, sampling, C, gaussian):
    """Place ``gaussian`` at successive offsets inside zero arrays of length ``len(C)``.

    Offsets start at 0 and advance by ``int(step_ms * sampling / 1000)``
    samples.  ``sampling`` is presumably in Hz, making the hop ``step_ms``
    milliseconds.  Offsets stop at the last one where the kernel still fits
    completely.

    Returns ``(list_of_arrays, step_ms)``; ``step_ms`` is passed through unchanged.
    """
    width = len(gaussian)
    hop = int(step_ms * sampling / 1000)
    placed = []
    for start in np.arange(0, len(C) - width + 1, hop):
        canvas = np.zeros(len(C))
        canvas[start:start + width] = gaussian
        placed.append(canvas)
    return placed, step_ms
def shift_chirp(step_ms, sampling, C, gaussian):
    """Cut the 1-D signal ``C`` into windows of ``len(gaussian)`` samples.

    Window starts advance by ``int(step_ms * sampling / 1000)`` samples, and
    only windows that fit completely are kept.

    Returns ``(windows, step_ms)``, where ``windows`` is ``np.asarray`` of the
    list of slices.
    """
    hop = int(step_ms * sampling / 1000)
    width = len(gaussian)
    starts = np.arange(0, len(C) - width + 1, hop)
    windows = np.asarray([C[start:start + width] for start in starts])
    return windows, step_ms
def shift_beat(step_ms, sampling, C, gaussian):
    """Cut every row of the 2-D array ``C`` into windows of ``len(gaussian)`` samples.

    Window starts advance by ``int(step_ms * sampling / 1000)`` samples along
    the second axis.  Only windows that fit completely are kept.

    Returns ``(windows, step_ms)``, where ``windows`` has shape
    ``(n_windows, C.shape[0], len(gaussian))``.
    """
    hop = int(step_ms * sampling / 1000)
    width = len(gaussian)
    starts = np.arange(0, len(C[0]) - width + 1, hop)
    windows = np.asarray([C[:, start:start + width] for start in starts])
    return windows, step_ms
def consp_create(type, G_all, C, B):
    """Normalised, window-weighted correlation of a chirp with a set of beats.

    Parameters
    ----------
    type : {'corr_p', 'corr_a'}
        Normalisation of the weighted dot product ``sum(w * C * B)``.
        ``'corr_p'`` divides by ``sqrt(sum(w * B**2)) * sqrt(nansum(w * C**2))``.
        ``'corr_a'`` divides by ``0.5 * sum(w * B**2) + 0.5 * nansum(w * C**2)``.
    G_all : array-like
        One weighting window per row, e.g. from ``shift_gauss``.
        The dot products imply shape ``(n_windows, n_samples)``.
    C : ndarray
        Chirp trace (presumably ``(n_samples,)``, mean-removed).
    B : array-like
        Beat traces (presumably ``(n_beats, n_samples)``, mean-removed).

    Returns
    -------
    ConSpi2 : ndarray, ``(n_windows,)``
        ``1 - |max over beats of the normalised correlation|`` per window.
    all : ndarray, ``(n_windows, n_beats)``
        ``1 - normalised correlation`` for each window/beat pair.

    Raises
    ------
    ValueError
        For an unknown ``type``.  Previously this surfaced as UnboundLocalError.
    """
    if type not in ('corr_p', 'corr_a'):
        raise ValueError("type must be 'corr_p' or 'corr_a', got %r" % (type,))
    weights = np.asarray(G_all)
    numerator = np.dot(weights, np.transpose(C * B))
    beat_energy = np.dot(weights, np.transpose(np.array(B) * np.array(B)))
    chirp_energy = np.nansum(weights * C * C, axis=1)
    # chirp_energy is per window (rows); broadcast it across the beat columns
    if type == 'corr_p':
        denominator = np.sqrt(beat_energy) * np.sqrt(chirp_energy)[:, np.newaxis]
    else:
        denominator = 0.5 * beat_energy + (0.5 * chirp_energy)[:, np.newaxis]
    index = numerator / denominator
    best_per_window = np.nanmax(index, axis=1)
    ConSpi2 = 1 - np.abs(np.asarray(best_per_window))
    all_values = 1 - index
    return ConSpi2, all_values
def consp_max(chir, step_ms, ConSpi2, left_c):
    """Summarise a conspicuity trace.

    Parameters
    ----------
    chir : {'stim', 'chirp'}
        With ``'stim'`` only a centred section of ``ConSpi2`` is used, extending
        ``17 / step_ms`` entries to each side of the middle.  With ``'chirp'``
        the whole trace is used.
    step_ms : number
        Step between successive entries of ``ConSpi2``.
    ConSpi2 : ndarray
        Conspicuity per window.
    left_c : number
        Subtracted from the argmax position.

    Returns
    -------
    (Integral, max_position, max_value)
        ``Integral`` is the sum over the selected section and ``max_value`` is
        its NaN-ignoring maximum.  ``max_position`` is ``argmax(ConSpi2) - left_c``
        taken over the *whole* trace.

    Raises
    ------
    ValueError
        For an unknown ``chir``.  Previously this surfaced as UnboundLocalError.
    """
    if chir == 'stim':
        side = 17 / step_ms
        middle = 0.5 * len(ConSpi2)
        I = ConSpi2[int(middle - side):int(middle + side)]  # FIXME check if the middle is really the middle
    elif chir == 'chirp':
        I = ConSpi2
    else:
        raise ValueError("chir must be 'stim' or 'chirp', got %r" % (chir,))
    Integral = np.sum(I)
    max_position = np.argmax(ConSpi2) - left_c
    max_value = np.nanmax(I)
    return Integral, max_position, max_value
def remove_mean(chirp, beat):
    """Subtract the mean from the chirp and from every row of ``beat``.

    Returns ``(centered_chirp, centered_beat)``.  ``beat`` must be 2-D; each
    row is centred on its own mean.
    """
    centered_chirp = chirp - np.mean(chirp)
    beat_arr = np.asarray(beat)
    centered_beat = (beat_arr.T - np.mean(beat_arr, axis=1)).T
    return centered_chirp, centered_beat
def single_dist(left_c, beat_window, eod_fe, eod_fr, col, filter, beat, chirp, sampling, chir, plot=False, show=False, stages='two', test=False):
    """Distance between a chirp and a set of beat traces.

    Both signals are first mean-removed with ``remove_mean``.

    ``stages='two'``:
      - chirp and beats are cut into windows of ``len(filter)`` samples in
        1 ms steps (``shift_chirp`` / ``shift_beat``);
      - per window, the RMS difference to each beat is computed (``mean_euc_d``);
      - the smallest beat distance per window is kept;
      - ``max_value`` is the largest of these per-window minima.

    Any other ``stages``: the whole chirp is compared with each beat row, and
    ``max_value`` is the smallest of those distances.

    ``test=True`` plots every candidate against the chirp for inspection.
    ``plot=True`` draws a ``plot_cons`` overview; ``show=True`` then shows it.
    The remaining arguments are only forwarded to ``plot_cons``.
    ``eod_fr`` is unused.

    Returns
    -------
    max_value : float
    distances : list
        Per-window arrays of beat distances (two stages), or per-beat distances.
    position : int
        Two stages: index of the window holding ``max_value``.  Otherwise:
        index of the closest beat.  (The two-stage path previously never
        assigned ``position`` and always raised at the return.)
    """
    C, B = remove_mean(chirp, beat)
    step_ms = 1
    if stages == 'two':
        C_all, step_ms = shift_chirp(step_ms, sampling, C, filter)
        B_all, step_ms = shift_beat(step_ms, sampling, B, filter)
        n_windows, _, _ = np.shape(B_all)
        distances = [mean_euc_d(B_all[i], C_all[i], 1) for i in range(n_windows)]
        optimal_phase = np.nanmin(distances, axis=1)
        max_value = np.nanmax(optimal_phase)
        position = np.nanargmax(optimal_phase)
    else:
        C_all = C
        B_all = B
        n_beats, _ = np.shape(B_all)
        distances = [mean_euc_d(B_all[i], C_all, None) for i in range(n_beats)]
        optimal_phase = np.nanmin(distances)
        position = np.nanargmin(distances)
        max_value = optimal_phase
    if test == True:
        for i in range(len(B_all)):
            plt.title(str(i) + ' ' + str(distances[i]))
            plt.plot(B_all[int(len(B_all) / 2)], color='red')
            plt.plot(B_all[i], color='blue')
            plt.plot(C_all, color='green')
            plt.show()
    if plot:
        plot_cons(filter, left_c, beat_window, optimal_phase, eod_fe, col, sampling, B, C, 1, chir)
        if show:
            plt.show()
    return max_value, distances, position
def single_consp(type, left_c, beat_window, eod_fe, eod_fr, col, filter, beat, chirp, sampling, chir, plot=False, show=False):
    """Conspicuity of one chirp against a set of beats.

    Steps:
      1. chirp and beats are mean-removed (``remove_mean``);
      2. ``filter`` is slid along the chirp in 1 ms steps (``shift_gauss``);
      3. the windowed correlation is computed with normalisation ``type``
         (``consp_create``);
      4. the result is summarised by ``consp_max``.

    ``plot=True`` draws a ``plot_cons`` overview; ``show=True`` then shows it.
    ``eod_fr`` is unused.

    Returns ``(max_position, max_value, Integral, ConSpi2, all)``.
    """
    centered_chirp, centered_beat = remove_mean(chirp, beat)
    windows, step_ms = shift_gauss(1, sampling, centered_chirp, filter)
    conspicuity, per_beat = consp_create(type, windows, centered_chirp, centered_beat)
    integral, max_position, max_value = consp_max(chir, step_ms, conspicuity, left_c)
    if plot:
        plot_cons(filter, left_c, beat_window, conspicuity, eod_fe, col, sampling,
                  centered_beat, centered_chirp, 0, chir)
        if show:
            plt.show()
    return max_position, max_value, integral, conspicuity, per_beat
def plot_cons(gaussian, left_c,beat_window,ConSpi2,beat,col, sampling, B, C, side, var):
# Plot a conspicuity/distance trace together with the chirp and beat traces it came from.
# Layout: one 4 x 3 inch figure, three stacked subplots sharing the x axis
# (create_subpl(0.4, fig, 3, 3, 1, share2=True)):
#   ax[0] - ConSpi2 (title: `beat`, y limits 0..1)
#   ax[1] - chirp trace C ('Chirp time window')
#   ax[2] - beat traces B, one line per row of B ('10 shifted Beats')
# Parameters:
#   gaussian    - filter kernel; only its length is used (x offsets in the 'chirp' view)
#   left_c      - left border of the chirp window; shifts the x axis in the 'chirp' view
#   beat_window - red vertical markers at +/-beat_window on ax[0] and at
#                 +/-(beat_window + 13.5) on ax[1]
#   ConSpi2     - trace to plot (single_consp passes the conspicuity, single_dist the
#                 per-window minimum distances)
#   beat        - only used as the title of ax[0] (callers in this file pass eod_fe)
#   col         - colour of the traces and of the `side` markers
#   sampling    - sampling rate; the 'stim' x axis advances in steps of 1000 / sampling
#   side        - position of extra vertical markers on ax[0] (callers pass 0 or 1)
#   var         - 'chirp' or 'stim'; selects how the x axes are built
# The figure is not shown here; callers in this file call plt.show() when show=True.
# NOTE(review): the divisor 100 in len(B[0]) / 100 presumably converts samples to ms
# (conv_fact = 100 elsewhere in this module) - confirm.
# NOTE(review): the indentation of this block is lost in this copy, so it is unclear whether
# the ax[1] markers at -/+(beat_window + 13.5) and 0 are drawn only for 'stim' or always.
# NOTE(review): the y label string 'Conspicousy' is misspelled (left unchanged here).
fig = plt.figure(figsize=[4, 3])
ax = create_subpl(0.4, fig, 3, 3, 1, share2=True)
if var == 'chirp':
# x axis centred on the window middle, then shifted by half the trace length and left_c
ax[0].plot((np.arange(0, len(ConSpi2), 1) - 0.5 * len(ConSpi2))+len(B[0])/(100*2) -left_c, ConSpi2, color=col)
ax[0].axvline(x = len(B[0])/(100*2)-3.5)
# second rendering of the same trace, spread between the kernel-centre positions
ax[0].plot(np.linspace(-left_c+(len(gaussian)/100)/2, len(B[0]) / 100 - left_c-(len(gaussian)/100)/2, len(ConSpi2)) , ConSpi2, color='red')
if var == 'stim':
ax[0].plot((np.arange(0, len(ConSpi2), 1) - 0.5 * len(ConSpi2)), ConSpi2, color=col)
ax[0].axvline(x=beat_window, color = 'red')
ax[0].axvline(x=-beat_window, color = 'red')
ax[1].axvline(x=-beat_window-13.5, color = 'red')
ax[1].axvline(x=+beat_window+13.5, color = 'red')
ax[1].axvline(x=0, color='red')
ax[0].set_title(beat)
ax[0].set_ylabel('Conspicousy [0-1]')
if var == 'stim':
ax[0].axvline(x=int(side), color=col)
ax[0].axvline(x=-int(side), color=col)
if var == 'chirp':
ax[0].axvline(x=int(side/2 +len(B[0]) / (100 * 2)-left_c), color=col)
ax[0].axvline(x=int(-side/2 +len(B[0]) / (100 * 2)-left_c), color=col)
ax[0].set_ylim([0, 1])
#if chirp == 'chirp':
# B = np.transpose(B)
# beats: B[0:len(B):1] is all rows; transposing makes matplotlib draw one line per row
if var == 'chirp':
ax[2].plot(np.linspace(-left_c,len(B[0])/100 -left_c,len(B[0])),
np.transpose(B[0:int(len(B)):1]))
if var == 'stim':
# symmetric time axis around 0 in steps of 1000 / sampling
ax[2].plot(np.arange(-len(np.transpose(B[0:int(len(B)):1])) * 0.5 * 1000 / sampling,
len(np.transpose(B[0:int(len(B)):1])) * 0.5 * 1000 / sampling, 1000 / sampling),
np.transpose(B[0:int(len(B)):1]))
ax[2].set_title('10 shifted Beats')
ax[2].set_ylabel('AM')
# chirp trace on the same time axis as the beats
if var == 'chirp':
ax[1].plot(np.linspace(-left_c,len(B[0])/100 -left_c,len(B[0])), C, color=col)
if var == 'stim':
ax[1].plot(np.arange(-len(np.transpose(B[0:int(len(B)):1])) * 0.5 * 1000 / sampling,
len(np.transpose(B[0:int(len(B)):1])) * 0.5 * 1000 / sampling, 1000 / sampling), C,
color=col)
ax[1].set_ylabel('AM')
ax[1].set_title('Chirp time window')
def mean_euc_d3(conv_beat, conv_chirp, a):
    """Distance variant normalised by ``len(conv_beat) ** 2``.

    Computes ``sqrt(nansum((conv_beat - conv_chirp) / len(conv_beat) ** 2, a))``.

    NOTE(review): by operator precedence the difference is divided by
    ``len ** 2`` and *not* squared.  Positive and negative deviations therefore
    cancel, and a negative sum gives NaN.  If a squared distance was intended,
    the expression needs explicit parentheses.  Behaviour is kept unchanged
    pending confirmation of the intended formula.
    """
    scale = len(conv_beat) ** 2
    scaled_diff = (conv_beat - conv_chirp) / scale
    return np.sqrt(np.nansum(scaled_diff, a))
def mean_euc_d2(conv_beat, conv_chirp, a):
    """RMS difference along axis ``a`` (NaN-ignoring), divided by ``len(conv_beat)``."""
    rms = np.sqrt(np.nanmean(np.square(conv_beat - conv_chirp), a))
    return rms / len(conv_beat)
def frame_subplots(w, h, xlabel, ylabel, title):
    """Create a ``w`` x ``h`` figure with a frameless, tick-free axes carrying the labels.

    The x label is padded by 30 and the y label by 50.  The title is placed
    slightly above the axes, at (0.5, 1.05).  Returns the new figure.
    """
    figure = plt.figure(figsize=[w, h])
    plt.axes(frameon=False)
    plt.xticks([])
    plt.yticks([])
    x_label = plt.xlabel(xlabel, labelpad=30)
    x_label.set_visible(True)
    y_label = plt.ylabel(ylabel, labelpad=50)
    y_label.set_visible(True)
    heading = plt.title(title)
    heading.set_position([.5, 1.05])
    return figure
def create_subpl(h, fig, subplots, nrow, ncol, share2=False):
    """Add ``subplots`` axes on an ``nrow`` x ``ncol`` grid of ``fig``.

    Parameters
    ----------
    h : float
        Vertical spacing passed as ``hspace`` to ``fig.subplots_adjust``.
    fig : matplotlib Figure
    subplots : int
        Number of axes, placed at grid positions ``1 .. subplots``.
    nrow, ncol : int
        Grid shape.
    share2 : bool
        If true, all axes share their x axis with the first one.  The first
        axes is created even when ``subplots`` is 0.

    Returns
    -------
    dict
        Mapping ``index -> axes``.

    Fixes:
      - With ``share2`` the first grid position used to be added twice.  The
        second call created an extra axes that shared x with a hidden duplicate.
      - The spacing is now applied to ``fig`` itself rather than to whichever
        figure happens to be current (``plt.subplots_adjust``).
    """
    ax = {}
    if share2:
        ax[0] = fig.add_subplot(nrow, ncol, 1)
        for i in range(1, subplots):
            ax[i] = fig.add_subplot(nrow, ncol, i + 1, sharex=ax[0])
    else:
        for i in range(subplots):
            ax[i] = fig.add_subplot(nrow, ncol, i + 1)
    fig.subplots_adjust(hspace=h)
    return ax
def windows_periods2(bef_c_t, aft_c_t, beat_window, min_bw, win, period, time_transform, period_shift, period_distance, perc_bef, perc_aft):
    """Alternative window layout for the 'w2' / 'w4' modes.

    ``period_shift`` and ``period_distance`` are filled per period index (in place).

    'w2':
      - distance = 0.020 * time_transform;
      - shift = the smallest multiple of the period covering
        (0.020 + 0.012) * time_transform;
      - fallback 1: if shift plus the leading part of the window exceeds
        ``beat_window[p]``, use the smallest multiple covering 0.020 * time_transform;
      - fallback 2: if that still does not fit, use a plain 0.020 * time_transform.
    'w4':
      - shift and distance are both ``length``;
      - ``length = min_bw / (1 + perc_bef)``, capped at 2.5 times the chirp
        span ``bef_c_t + aft_c_t``.

    Returns ``(period_shift, period_distance * perc_bef, period_distance * perc_aft)``.
    In this file it is only referenced from a commented-out call in ``period_calc``.
    """
    chirp_span = bef_c_t + aft_c_t
    length = ((min_bw / time_transform) / (1 + perc_bef)) * time_transform
    if length < chirp_span:
        print('length in win2 is smaller as the chirp')
    elif length > chirp_span * 2.5:
        length = chirp_span * 2.5
    length_exp = 0.020
    for p, per in enumerate(period):
        if win == 'w2':
            # 0.012 is extra room on top of length_exp (original note: "length of the deviations")
            period_shift[p] = np.ceil(((length_exp + 0.012) * time_transform) / per) * per
            period_distance[p] = length_exp * time_transform
            if (period_shift[p] + period_distance[p] * perc_bef) > beat_window[p]:
                print('period' + str(p) + ' uncool 1')
                period_shift[p] = np.ceil((length_exp * time_transform) / per) * per
                if (period_shift[p] + period_distance[p] * perc_bef) > beat_window[p]:
                    print('period' + str(p) + ' uncool 2')
                    period_shift[p] = 0.020 * time_transform
        elif win == 'w4':
            period_shift[p] = length
            period_distance[p] = length
    return period_shift, period_distance * perc_bef, period_distance * perc_aft
def auto_rows(data):
    """Choose a near-square ``(nrow, ncol)`` grid with at least ``len(data)`` cells.

    Starts from ``floor(sqrt(n))`` in both directions.  Then adds a row, and if
    still too small a column, stopping as soon as the grid is big enough.
    """
    n_items = len(data)
    root = np.sqrt(n_items)
    nrow = ncol = int(root)
    if nrow * ncol < n_items:
        nrow = int(root + 1)
        if nrow * ncol < n_items:
            ncol = int(root + 1)
    return nrow, ncol
def period_calc(beat_window, m, win, deviation_s, time_transform, beat_corr, bef_c, aft_c, chir, ver='', conv_fact=100, shuffled=''):
    """Compute chirp and beat window borders for every beat period.

    Chains three helpers:
      - ``pre_window``: periods and chirp timing;
      - ``windows_periods``: shift and extent of the comparison window per period;
      - ``post_window``: rounded borders including the kernel margin ``to_cut``.

    Returns
    -------
    tuple
        ``(left_c, right_c, left_b, right_b, period_distance_c, period_distance_b,
        period_shift, period, to_cut, exclude, consp_needed, dels, interval)``.
        ``dels`` bundles the four borders, ``consp_needed`` and ``to_cut`` in a
        dict.  ``interval`` is the one computed by ``pre_window``.
    """
    (period, bef_c_t, aft_c_t, deviation_t, period_shift, period_distance,
     perc_bef, perc_aft, nr_of_cycles, interval) = pre_window(
        beat_corr, time_transform, deviation_s, bef_c, aft_c, chir)
    period_shift, period_left, period_right, exclude, consp_needed = windows_periods(
        ver, m, time_transform, win, beat_window, perc_bef, perc_aft,
        period_shift, period_distance, period, interval, nr_of_cycles,
        bef_c_t, aft_c_t, shuffled=shuffled)
    (period_distance_c, period_distance_b, left_b, right_c,
     left_c, right_b, to_cut) = post_window(
        conv_fact, period_left, period_right, deviation_t, period_shift)
    dels = {
        'left_c': left_c,
        'right_c': right_c,
        'left_b': left_b,
        'right_b': right_b,
        'consp_needed': consp_needed,
        'to_cut': to_cut,
    }
    return (left_c, right_c, left_b, right_b, period_distance_c, period_distance_b,
            period_shift, period, to_cut, exclude, consp_needed, dels, interval)
def choose_randomly(beat_window, interval, l_all=400, r_all=800):
    """Place a window of length ``interval`` at a uniformly random position.

    The left border is drawn uniformly from ``[-|l_all|, r_all - interval)``,
    so the window stays inside ``[-|l_all|, r_all]``.  ``beat_window`` is
    accepted but unused.  Draws one number from the global ``random.random``
    generator.

    Returns ``(left_c, right_c)`` with ``right_c = left_c + interval``.
    """
    offset = np.abs(l_all)
    usable_span = r_all + offset - interval
    left_c = random() * usable_span - offset
    return left_c, left_c + interval
def windows_periods3(beat_window,perc_bef, perc_aft, period_shift, period, interval, nr_of_cycles, bef_c_t, aft_c_t):
# Alternative window layout: per period index p choose the shift of the comparison window
# (period_shift[p]) and its left/right extent (period_left[p] / period_right[p]).
#  - If the period is shorter than the chirp interval, or one period plus its perc_bef share
#    does not fit into beat_window[p]: shift by nr_of_cycles[p] whole periods and use the
#    chirp borders bef_c_t / aft_c_t as the extent.
#  - Otherwise: shift by one period and split one period by perc_bef / perc_aft.
#  - If period[p] + bef_c_t exceeds beat_window[p], fall back to
#    shift = bef_c_t + aft_c_t with the chirp borders as the extent.
# period_shift is modified in place and returned; period_left / period_right are new arrays.
# Returns (period_shift, period_left, period_right).
# In this file it is only referenced from a commented-out call in period_calc.
# NOTE(review): the indentation of this block is lost in this copy, so it is unclear whether
# the last check (period[p] + bef_c_t > beat_window[p]) applies after both branches or only
# inside the else branch - confirm against the original source.
period_left = np.zeros(len(period))
period_right = np.zeros(len(period))
for p in range(len(period)):
# & period[p] * perc_bef+
if (period[p] < interval) or ((period[p]+ period[p] * perc_bef)>beat_window[p]):
period_shift[p] = (period[p] * nr_of_cycles[p])
period_left[p] = bef_c_t
period_right[p] = aft_c_t
else:
period_left[p] = period[p] * perc_bef
period_right[p] = period[p] * perc_aft
period_shift[p] = period[p]
if (period[p]+ bef_c_t)>beat_window[p]:
period_shift[p] = (bef_c_t+aft_c_t)
period_left[p] = bef_c_t
period_right[p] = aft_c_t
return period_shift,period_left, period_right
def windows_periods(ver, min_bw, time_transform, win, beat_window,perc_bef, perc_aft,period_shift, period_distance, period, interval, nr_of_cycles, bef_c_t, aft_c_t, make_interval_shorter = False,closest_to_2m = True,shuffled = ''):
# For every beat period p, choose how far the comparison ("beat") window is shifted away
# from the chirp window (period_shift[p]) and how long the window is (period_distance[p]).
# Parameters, as used by the active code below:
#   ver            - mode string.
#                    'consp' in ver reserves half a period (consp_needed[p]) in addition
#                    to the window.  'consp' / 'dist' / 'conspdist' select the fallback
#                    when the shift becomes too short.
#   win            - 'w1', 'w2', 'w3' or 'w4' window layout.
#   beat_window    - per-period limit that the shift plus the leading part of the window
#                    has to respect (for w1-w3 extended by shorter_bw_aft).
#   perc_bef/aft   - fractions of the window placed before/after the chirp.
#   period_shift, period_distance - per-period arrays, modified in place.
#   period         - per-period length, in the same units as bef_c_t / aft_c_t.
#   bef_c_t, aft_c_t - chirp borders.
#   make_interval_shorter - retry via exclude_interval with the chirp interval shortened by
#                    5 and then by 7.5.
#   closest_to_2m  - prefer a shift of twice mult_interval (the whole number of periods
#                    that covers the chirp interval).
# min_bw, time_transform and shuffled are not used by the active code.  The interval and
# nr_of_cycles arguments are recomputed below from bef_c_t / aft_c_t.
# Returns (period_shift, period_left, period_right, exclude, consp_needed) with
# period_left = period_distance * perc_bef and period_right = period_distance * perc_aft.
# NOTE: despite its name, exclude[p] is True when a usable shift was found and False when
# the period should be dropped (cf. the print('exclude') next to exclude[p] = False).
# Some branches leave exclude[p] at its [] placeholder.  For a win value other than
# 'w1'..'w4' no period is processed at all.
# NOTE(review): the indentation of this block is lost in this copy; the nesting of the
# branches below has to be checked against the original source.
exclude = np.zeros(len(period))
# the zeros above are immediately replaced by a list of [] placeholders
exclude = [[]]*len(period)
consp_needed = [[]] * len(period)
# recomputed from the chirp borders; the interval / nr_of_cycles arguments are ignored
interval = abs(aft_c_t) + abs(bef_c_t)
nr_of_cycles = ((interval) / period).astype(int) + 1 # nr of cycles in the length of the chirp
# consp_puffered below = consp_needed[p] reduced by this amount, clipped at 0
puffer = abs(bef_c_t) -4
if win == 'w1' or win == 'w2' or win == 'w3':
for p in range(len(period)):
if 'consp' in ver:
consp_needed[p] = period[p] * 0.5
consp_puffered = consp_needed[p] - puffer
if consp_puffered <0:
consp_puffered = 0
else:
consp_needed[p] = 0
consp_puffered = 0
# smallest whole number of periods covering the chirp interval
mult_interval = period[p] * nr_of_cycles[p]
# extra room granted beyond beat_window: the post-chirp part of the interval minus 13.5
shorter_bw_aft = interval * perc_aft - 13.5
if period[p] < interval:
if period[p] == 0:
period_shift[p] = (interval) * 2
period_distance[p] = (interval)
exclude[p] = True
else:
if win == 'w1':
period_distance[p] = mult_interval
elif win == 'w3':
period_distance[p] = interval
elif win == 'w2':
period_distance[p] = interval
if closest_to_2m == True:
if ((mult_interval*2) <= beat_window[p]+shorter_bw_aft-period_distance[p]* perc_bef -consp_needed[p]) and (mult_interval * 2 >= period_distance[p] +consp_puffered):
period_shift[p] = mult_interval * 2
exclude[p] = True
else:
# largest whole number of periods (truncated via astype(int)) that still fits
period_shift[p] = (((beat_window[p]+shorter_bw_aft- period_distance[p] * perc_bef-consp_needed[p]) / period[p]).astype(int)) * period[p]
exclude[p] = True
if period_shift[p] < mult_interval+consp_puffered:
if make_interval_shorter == False:
if ver == 'consp':
#embed()
exclude[p] = False
#period_shift[p] = []
elif ver == 'dist':
exclude[p] = False
#period_shift[p] = []
elif ver == 'conspdist':
consp_needed[p] = 0
# NOTE(review): the second condition adds the whole list consp_needed instead of
# consp_needed[p]; numpy broadcasting turns the comparison into an array, which makes
# the `and` ambiguous (or compares the wrong values).  consp_needed[p] was presumably
# intended - confirm.
if ((mult_interval * 2) <= beat_window[p]+shorter_bw_aft - period_distance[p] * perc_bef - consp_needed[p]) and (mult_interval * 2 >= period_distance[p] + consp_needed):
period_shift[p] = mult_interval * 2
exclude[p] = True
else:
period_shift[p] = (((beat_window[p]+shorter_bw_aft - period_distance[p] * perc_bef - consp_needed[p]) /
period[p]).astype(int)) * period[p]
if period_shift[p] < mult_interval + consp_needed[p]:
exclude[p] = False
#period_shift[p] = []
if make_interval_shorter == True:
short_interval = 5
#bef_c_t, aft_c_t, perc_bef, perc_aft = percentage(bef_c, aft_c, time_transform)
# FIXME a period distance adapt for w2 (and also for the backwards calculation
# FIXME in beat window adapt 0.8 times period plus minus
# FIXME look that it doesnt has to be 2 periods in long periods
shorter_bw_aft = (interval)*perc_aft-13.5
period_distance[p], period_shift[p] = exclude_interval(consp_needed[p],shorter_bw_aft, p,interval,short_interval,period, win, period_distance, period_shift, beat_window, perc_bef)
exclude[p] = True
if period_shift[p] <= mult_interval+consp_needed[p]:
short_interval = 7.5
#bef_c_t, aft_c_t, perc_bef, perc_aft = percentage(bef_c, aft_c, time_transform)
shorter_bw_aft = (interval) * perc_aft - 13.5
period_distance[p], period_shift[p] = exclude_interval(consp_needed[p],shorter_bw_aft, p,interval, short_interval, period, win,
period_distance, period_shift,
beat_window, perc_bef)
exclude[p] = True
if period_shift[p] < mult_interval+consp_needed[p]:
exclude[p] = False
#period_shift[p] = []
print('exclude')
else:
exclude[p] = True
elif closest_to_2m == False:
if (mult_interval * 2) <= beat_window[p] - period_distance[p] * perc_bef - consp_needed[p]:
period_shift[p] = mult_interval * 2
exclude[p] = True
if ((period[p] * nr_of_cycles[p] + np.round((12.5/period[p]))*period[p]+ period[p] * nr_of_cycles[p] * perc_bef)>beat_window[p]):
period_shift[p] = period[p] * nr_of_cycles[p]
else:
period_shift[p] = period[p] * nr_of_cycles[p] + np.round((12.5/period[p]))*period[p]
#elif ((period[p] + period[p] * perc_bef) > beat_window[p]+shorter_bw_aft):
# #embed()
# if (win == 'w3') or (win == 'w1'):
# exclude[p] = False
# print('should be excluded, not at least a period!!')
# else:
# period_distance[p] = interval
# if ((period[p] + period_distance[p] * perc_bef) > beat_window[p] + shorter_bw_aft):
# exclude[p] = False
# print('should be excluded, not at least a period!!')
# else:
# exclude[p] = True
# if ((bef_c_t + aft_c_t) * 2 + bef_c_t) > beat_window[p]:
# period_shift[p] = (interval)
# period_distance[p] = (interval)
# else:
# period_shift[p] = (interval) * 2
# period_distance[p] = (interval)
else:
# period at least as long as the chirp interval: try a shift of two periods, then one
if win == 'w1':
period_distance[p] = period[p]
elif win == 'w3':
period_distance[p] = period[p]
elif win == 'w2':
period_distance[p] = interval
if ((period[p]*2 + period_distance[p] * perc_bef) <= beat_window[p]-consp_needed[p]+shorter_bw_aft) and (period[p]*2 >= period_distance[p] + consp_puffered):
exclude[p] = True
period_shift[p] = period[p]*2
elif((period[p] + period_distance[p] * perc_bef) <= beat_window[p]-consp_needed[p]+shorter_bw_aft) and (period[p] >= period_distance[p] + consp_puffered):
exclude[p] = True
period_shift[p] = period[p]
else:
if ver == 'conspdist':
consp_needed[p] = 0
if ((period[p] * 2 + period_distance[p] * perc_bef) <= beat_window[
p] - consp_needed[p] + shorter_bw_aft) and (
period[p] * 2 >= period_distance[p] +consp_puffered):
exclude[p] = True
period_shift[p] = period[p] * 2
elif ((period[p] + period_distance[p] * perc_bef) <= beat_window[
p] - consp_needed[p] + shorter_bw_aft) and (period[p] >= period_distance[p] +consp_puffered):
exclude[p] = True
period_shift[p] = period[p]
else:
exclude[p] = False
#period_shift[p] = []
else:
exclude[p] = False
#period_shift[p] = []
#elif ((period[p] + period_distance[p] * perc_bef) < beat_window[p]+shorter_bw_aft):
# if ver == 'consp':
# exclude[p] = True
# else:
# exclude[p] = False
# period_shift[p] = period[p]
#if period_shift[p]<period_distance[p]:
#embed()
#if (period_distance[p]*perc_bef +period_shift[p])>beat_window[p]:
# print('error')
#if win == 'w2':
# period_distance[p] = interval
#embed()
#break
#if win == 'w2' or win == 'w4':
# for p in range(len(period)):
# if win == 'w2':
# #embed()
# length_exp = 0.020
# period_shift[p] = np.ceil(((length_exp+0.012)* time_transform)/period[p])*period[p] #0.12 is the length of the deviations i guess
# #period_shift[p] = np.ceil((length_exp + 0.04 * time_transform) / period[p]) * period[
# # p] # 0.12 is the length of the deviations i guess
# #embed()
# #period_shift[p] = 2*np.ceil((length_exp * time_transform) / period[p]) * period[p]
# period_distance[p] = length_exp * time_transform
# if (period_shift[p]+period_distance[p] *perc_bef)>beat_window[p]:
# print('period'+str(p)+' uncool 1')
# period_shift[p] = np.ceil((length_exp * time_transform) / period[p]) * period[p]
# if (period_shift[p]+period_distance[p] *perc_bef) > beat_window[p]:
# print('period' + str(p) + ' uncool 2')
# period_shift[p] = 0.020 * time_transform
# 'w4': window = chirp interval; shift by two intervals if that fits, else by one, else drop
elif win == 'w4':
for p in range(len(period)):
if interval*perc_bef+interval*2<=beat_window[p]:
period_shift[p] = interval*2
period_distance[p] = interval
exclude[p] = True
# FIXME add the consp handling here as well
elif interval*perc_bef+interval<=beat_window[p]:
period_shift[p] = interval
period_distance[p] = interval
exclude[p] = True
else:
exclude[p] = False
#period_shift[p] = []
#length = ((min_bw / time_transform) / (1 + perc_bef))* time_transform
#embed()
#exclude[p] = False
#if length < bef_c_t+aft_c_t:
# exclude[p] = True
# print('length in win2 is smaller as the chirp')
#elif length > (bef_c_t+aft_c_t)*2.5:
# exclude[p] = False
# length = (bef_c_t+aft_c_t)*2.5
#else:
# exclude[p] = False
#period_shift[p] = length*2
#period_distance[p] = length
period_left = period_distance *perc_bef
period_right = period_distance * perc_aft
#embed()
return period_shift, period_left, period_right,exclude,consp_needed
def exclude_interval(consp_needed, bw, p, interval, short_interval, period, win, period_distance, period_shift, beat_window, perc_bef):
    """Recompute window length and shift for period ``p`` with a shortened chirp interval.

    The interval is reduced by ``short_interval``.  Window length
    (``period_distance[p]``):
      - 'w1': the whole number of periods covering the shortened interval;
      - 'w2' / 'w3': the shortened interval itself;
      - any other ``win``: left unchanged.

    The shift (``period_shift[p]``) is the largest whole number of periods that
    fits into ``beat_window[p] + bw``.  The leading part
    ``period[p] * cycles * perc_bef`` and ``consp_needed`` are subtracted first.

    Both arrays are modified in place; returns
    ``(period_distance[p], period_shift[p])``.
    """
    shortened = interval - short_interval
    cycles = (shortened / period).astype(int) + 1
    if win == 'w1':
        period_distance[p] = period[p] * cycles[p]
    elif win in ('w2', 'w3'):
        period_distance[p] = shortened
    budget = beat_window[p] + bw - period[p] * cycles[p] * perc_bef - consp_needed
    period_shift[p] = (budget / period[p]).astype(int) * period[p]
    return period_distance[p], period_shift[p]
def post_window(conv_fact, period_left, period_right, deviation_t, period_shift):
    """Turn window extents into rounded chirp (``_c``) and beat (``_b``) borders.

    A margin ``to_cut`` is added on both sides.  It is computed by
    ``points_cut(deviation_t * conv_fact, 10, conv_fact)``: half the kernel
    length in samples, divided by ``conv_fact``.  The beat window is the chirp
    window moved by ``-period_shift``.  All borders are rounded to multiples of
    ``1 / conv_fact``.

    Returns
    -------
    tuple
        ``(period_distance_c, period_distance_b, left_b, right_c, left_c,
        right_b, to_cut)`` (note the mixed order).  The distances are the
        window widths without the two margins.
    """
    _, to_cut = points_cut(deviation_t * conv_fact, 10, conv_fact)

    def snap(value):
        return np.round(value * conv_fact) / conv_fact

    left_c = snap(-period_left - to_cut)
    right_c = snap(period_right + to_cut)
    left_b = snap(-period_left - to_cut - period_shift)
    right_b = snap(period_right + to_cut - period_shift)
    margins = 2 * to_cut
    period_distance_c = abs(right_c - left_c) - margins
    period_distance_b = abs(right_b - left_b) - margins
    return period_distance_c, period_distance_b, left_b, right_c, left_c, right_b, to_cut
def one_dict_str(beats, spike_phases):
    """Map each beat to a list of independent empty lists, one per entry of ``spike_phases[beat]``.

    Fix: the previous ``[[]] * n`` filled the list with ``n`` references to the
    *same* empty list.  Appending to one slot therefore showed up in all slots.
    """
    return {beat: [[] for _ in range(len(spike_phases[beat]))] for beat in beats}
def one_array_str(beats, spike_phases):
    """List version of ``one_dict_str``: one list of independent empty lists per beat.

    Entry ``b`` holds ``len(spike_phases[beats[b]])`` separate empty lists.
    Fix: ``[[]] * n`` previously made all of them the same list object.
    """
    return [[[] for _ in range(len(spike_phases[beat]))] for beat in beats]
def only_one_dict_str(beats, spike_phases):
    """Map each beat to ``len(spike_phases)`` independent empty lists.

    Fix: ``[[]] * n`` previously made all slots of one beat the same list object.
    """
    n_slots = len(spike_phases)
    return {beat: [[] for _ in range(n_slots)] for beat in beats}
def three_dict_str(deviation_list, data):
    """Nested dict ``{deviation: {dataset: {}}}`` with a fresh empty dict in every leaf."""
    return {dev: {name: {} for name in data} for dev in deviation_list}
def four_dict_str(deviation_list, data, beats, bin_array):
    """Nested container ``AUCI[deviation][dataset][beat_index][bin_index] = []``.

    For dataset ``data[d]`` there are ``len(beats[d])`` beat slots.  Each slot
    holds ``len(bin_array)`` independent empty lists.
    """
    n_bins = len(bin_array)
    return {
        dev: {
            name: [[[] for _ in range(n_bins)] for _ in range(len(beats[d]))]
            for d, name in enumerate(data)
        }
        for dev in deviation_list
    }
def pre_window(beat_corr, time_transform, deviation_s, bef_c, aft_c, chir):
    """Prepare the per-period quantities needed by the window functions.

    Returns
    -------
    tuple
        ``(period, bef_c_t, aft_c_t, deviation_t, period_shift, period_distance,
        perc_bef, perc_aft, nr_of_cycles, interval)``, where:
          - ``period`` comes from ``period_func``;
          - the chirp borders and their fractions come from ``percentage``;
          - ``deviation_t = deviation_s * time_transform``;
          - ``period_shift`` / ``period_distance`` are zero arrays, one entry per period;
          - ``nr_of_cycles = int((bef_c_t + aft_c_t) / period) + 1``;
          - ``interval = |aft_c_t| + |bef_c_t|``.
    """
    period = period_func(beat_corr, time_transform, chir)
    bef_c_t, aft_c_t, perc_bef, perc_aft = percentage(bef_c, aft_c, time_transform)
    n_periods = len(period)
    nr_of_cycles = ((bef_c_t + aft_c_t) / period).astype(int) + 1  # nr of cycles in the length of the chirp
    return (
        period,
        bef_c_t,
        aft_c_t,
        deviation_s * time_transform,
        np.zeros(n_periods),
        np.zeros(n_periods),
        perc_bef,
        perc_aft,
        nr_of_cycles,
        abs(aft_c_t) + abs(bef_c_t),
    )
def percentage(bef_c, aft_c, time_transform):
    """Scale the chirp borders by ``time_transform`` and compute their shares.

    Returns ``(bef_c_t, aft_c_t, perc_bef, perc_aft)``.  The two fractions are
    ``|bef_c_t|`` and ``|aft_c_t|`` divided by ``|bef_c_t| + |aft_c_t|``.
    """
    before = bef_c * time_transform
    after = aft_c * time_transform
    total = np.abs(before) + np.abs(after)
    return before, after, np.abs(before) / total, np.abs(after) / total
def period_func(beat_corr, time_transform, chir):
    """Convert beat frequencies into beat periods.

    ``period = |1 / beat_corr| * time_transform``.  A zero beat frequency gives
    an infinite period, which is replaced by the largest finite period.

    Parameters
    ----------
    beat_corr : array-like
        Beat frequencies (presumably in Hz - confirm with callers).
    time_transform : number
        Factor converting ``1 / frequency`` into the time unit of the result.
    chir : {'chir', 'chirp', 'stim'}
        ``'stim'`` rounds to integer periods.  ``'chir'`` keeps floats; the
        spelling ``'chirp'`` used elsewhere in this module is accepted as well.

    Raises
    ------
    ValueError
        For an unknown ``chir`` (previously an UnboundLocalError).

    Fix: for ``'stim'`` the infinite periods used to be cast to ``int`` before
    the replacement.  The cast turned them into a huge negative integer, so
    they were never replaced.  Now the replacement happens on floats first.
    """
    if chir not in ('chir', 'chirp', 'stim'):
        raise ValueError("chir must be 'chir', 'chirp' or 'stim', got %r" % (chir,))
    # division by zero is expected here (zero beat frequency) and handled below
    with np.errstate(divide='ignore'):
        period = np.abs(1 / np.asarray(beat_corr)) * time_transform
    is_inf = np.isinf(period)
    if np.any(is_inf):
        period[is_inf] = np.max(period[~is_inf])
    if chir == 'stim':
        period = np.round(period).astype(int)
    return period
def two_ylabels(y_pad, x, y, z, fig, left, right):
    """Add a tick-free subplot at grid position ``(x, y, z)`` of ``fig`` with y labels on both sides.

    ``left`` becomes the subplot's own y label (label pad ``y_pad``).
    ``right`` becomes the label of its ``twinx`` partner, rotated by 270 degrees
    with label pad 40.  Nothing is returned.
    """
    host = fig.add_subplot(x, y, z)
    host.set_xticks([])
    host.set_yticks([])
    twin = host.twinx()
    twin.set_xticks([])
    twin.set_yticks([])
    host.set_ylabel(left, labelpad=y_pad)
    twin.set_ylabel(right, rotation=270, labelpad=40)
def del_ticks(data, ncol, ax):
    """Strip inner tick marks from a grid of ``len(data)`` subplots with ``ncol`` columns.

    Assumes a row-major grid:
      - x ticks are removed from subplots ``0 .. len(data) - ncol - 1``, i.e.
        those with another subplot ``ncol`` positions below them;
      - y ticks are removed from every subplot outside the first column.
    """
    n_plots = len(data)
    for idx in range(n_plots - ncol):
        ax[idx].set_xticks([])
    for idx in range(n_plots):
        if idx % ncol:
            ax[idx].set_yticks([])
def axvline_everywhere(ax, subplots, eod_fe, eod_fr):
    """Draw thin black vertical lines at multiples of ``eod_fr`` on the first ``subplots`` axes.

    Lines sit at ``eod_fr * i`` for ``i = 0 .. int(max(eod_fe) / max(eod_fr)) - 1``.
    """
    for e in range(subplots):
        n_lines = int(np.max(eod_fe) / np.max(eod_fr))
        target = ax[e]
        for i in range(n_lines):
            target.axvline(x=eod_fr * i, color='black', linewidth=1, linestyle='-')