# susceptibility1/nonlin_regime.py
# 2024-06-13 20:46:52 +02:00
#
# 506 lines
# 23 KiB
# Python

import os
import sys
import numpy as np
import pandas as pd
from IPython import embed
from matplotlib import gridspec, pyplot as plt
from plotstyle import plot_style
from threefish.calc_time import extract_am
from threefish.core import find_folder_name, info
from threefish.defaults import default_figsize, default_ticks_talks
from threefish.load import load_savedir, save_visualization
from threefish.plot.limits import join_x, join_y, set_same_ylim
from threefish.RAM.calc_fft import log_calc_psd
from threefish.RAM.calc_model import chose_old_vs_new_model
from threefish.RAM.plot_labels import label_deltaf1, label_deltaf2, label_diff, label_f_eod_name_core_rm, \
label_fbasename_small, label_sum, \
onebeat_cond, \
remove_yticks
from threefish.RAM.plot_subplots import colors_suscept_paper_dots, plt_spikes_ROC, recalc_fr_to_DF1
from threefish.RAM.values import val_cm_to_inch, vals_model_full
from threefish.twobeat.calc_model import calc_roc_amp_core_cocktail
from threefish.twobeat.colors import colors_susept, twobeat_cond
from threefish.twobeat.labels import f_stable_name, f_vary_name
from threefish.twobeat.reformat import c_dist_recalc_func, c_dist_recalc_here, dist_recalc_phaselockingchapter, \
find_dfs, \
get_frame_cell_params
from threefish.twobeat.subplots import plt_psd_saturation, plt_single_trace, plt_stim_saturation, plt_vmem_saturation, \
power_spectrum_name
from threefish.values import values_nfft_full_model, values_stimuluslength_model_full
def nonlin_regime(yposs=[450, 450, 450], freqs=None, printing=False, beat='',
                  nfft_for_morph=4096 * 4,
                  gain=1,
                  cells_here=["2013-01-08-aa-invivo-1"], fish_jammer='Alepto', us_name='',
                  show=True):
    """Plot the nonlinear-regime figure of the two-beat model.

    For every cell in *cells_here* and every ``(freq1, freq2)`` pair in
    *freqs* this builds one figure with

    * a bottom panel showing the scores ``c_B1_012_original``,
      ``c_B2_012_original``, ``c_B1+B2_012_original`` and
      ``c_B1-B2_012_original`` of a precomputed ``calc_cocktailparty`` CSV
      (via ``plt_single_trace``), with a lettered triangle marker (A, B, ...)
      at each selected contrast;
    * a top row with one column per contrast in ``c_nrs_orig``: the AM
      extracted from the stimulus, the spike trains returned by
      ``calc_roc_amp_core_cocktail``, and the power spectrum with peaks marked
      at ``|freq1|``, ``|freq2|``, their sum and their difference.

    Parameters
    ----------
    yposs : list
        Heights of the contrast markers in the bottom panel, indexed by the
        position in ``full_names``. Each entry is either a sequence with one
        height per contrast or a scalar used for every contrast.
    freqs : list of (float, float) or None
        ``(freq1, freq2)`` difference-frequency pairs to plot. ``None``
        (default) derives a single pair from ``vals_model_full`` converted
        with ``recalc_fr_to_DF1`` and the ``fr_data`` value of cell
        ``2012-07-03-ak-invivo-1``. (Previously this argument was silently
        ignored and that pair was always used.)
    printing, beat, nfft_for_morph, gain, fish_jammer, us_name
        Forwarded unchanged to ``calc_roc_amp_core_cocktail``.
    cells_here : list of str
        Cell identifiers to plot. (Previously overwritten internally with
        the same value as the default.)
    show : bool
        Forwarded to ``save_visualization``.
    """
    # Fixed model settings forwarded to calc_roc_amp_core_cocktail.
    runs = 1
    n = 1
    dev = 0.001
    min_amps = '_minamps_'
    dev_name = ['05']
    a_fr = 1
    datapoints = 1000
    results_diff = pd.DataFrame()
    position_diff = 0
    plot_style()
    default_figsize(column=2, length=5.5)

    if freqs is None:
        # Frequencies of the full model. vals_model_full returns them as
        # "_frmult" values that recalc_fr_to_DF1 converts using the fr_data
        # entry of the reference cell below.
        DF1_frmult, DF2_frmult = vals_model_full(val=0.30833333333333335)
        frame_csv = pd.read_csv(find_folder_name('calc_base') + '/csv_model_data.csv')
        fr_ref = frame_csv[frame_csv.cell == '2012-07-03-ak-invivo-1'].fr_data.iloc[0]
        for d in range(len(DF1_frmult)):
            DF2_frmult[d] = recalc_fr_to_DF1(DF2_frmult, d, fr_ref)
            DF1_frmult[d] = recalc_fr_to_DF1(DF1_frmult, d, fr_ref)
        freqs = [(DF1_frmult[3], DF2_frmult[3])]

    # Accumulators threaded through every calc_roc_amp_core_cocktail call.
    auci_wo = []
    auci_w = []
    for cell_here in cells_here:
        # Loop over the requested frequency pairs.
        for freq1, freq2 in freqs:
            # Precomputed contrast sweep drawn in the bottom panel.
            full_names = [
                'calc_model_amp_freqs-_F1_0.22833333333333333Fr_F2_1Fr_af_coupled__C1Len_100_FirstC1_0.0001_LastC1_0.3_FRrelavtiv__start_0.0001_end_0.3_StimLen_100_nfft_20000_trialsnr_1_mult_minimum_1_power_1_minamps__dev_original_05_not_log__point_1_fft2_temporal']
            c_grouped = ['c1']
            # Contrasts of the example columns in the top row.
            c_nrs_orig = [0.01, 0.04, 0.1, 0.2]
            log = 'log'
            grid0 = gridspec.GridSpec(1, 1, bottom=0.13, top=0.88, left=0.11,
                                      right=0.95, wspace=0.04)
            grid00 = gridspec.GridSpecFromSubplotSpec(2, 1,
                                                      wspace=0.04, hspace=0.75,
                                                      subplot_spec=grid0[0], height_ratios=[1.5, 1], )
            grid_up = gridspec.GridSpecFromSubplotSpec(1, len(c_nrs_orig),
                                                       hspace=0.75,
                                                       wspace=0.25,
                                                       subplot_spec=grid00[0])
            grid_down = gridspec.GridSpecFromSubplotSpec(1, 1,
                                                         hspace=0.75,
                                                         wspace=0.1,
                                                         subplot_spec=grid00[1])
            for i, full_name in enumerate(full_names):
                frame = pd.read_csv(find_folder_name('calc_cocktailparty') + '/' + full_name + '.csv')
                frame_cell_orig = frame[(frame.cell == cell_here)]
                if len(frame_cell_orig) < 1:
                    continue

                get_frame_cell_params(c_grouped, cell_here, frame, frame_cell_orig)
                f_counter = 0
                frame_cell_orig, _, _, _, _ = find_dfs(frame_cell_orig)
                eodf = frame_cell_orig.f0.unique()[0]
                f = 0

                # Overview: rows of the requested (df1, df2) pair. If the pair
                # is missing, df1 and df2 are each snapped independently to the
                # nearest value present in the table.
                frame_cell = frame_cell_orig[
                    (frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)]
                if len(frame_cell) < 1:
                    freq1 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df1 - freq1)))].df1
                    freq2 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df2 - freq2)))].df2
                    frame_cell = frame_cell_orig[
                        (frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)]
                    print('Tuning curve needed for F1' + str(frame_cell.f1.unique()) + ' F2' + str(
                        frame_cell.f2.unique()) + ' for cell ' + str(cell_here))

                # Only alpha and linewidths are taken from this palette; the
                # labels, colours, linestyles and scores are replaced below.
                (_, alpha, _, _, _, _, _, _, _, _,
                 linewidths) = colors_susept(add='_mean_original', nr=4)
                sampling = 20000
                c_dist_recalc = dist_recalc_phaselockingchapter()
                c_nrs = c_dist_recalc_func(frame_cell, c_nrs=c_nrs_orig, cell=cell_here,
                                           c_dist_recalc=c_dist_recalc)
                if not c_dist_recalc:
                    c_nrs = np.array(c_nrs) * 100
                # One panel letter per contrast: A, B, C, ...
                letters = [chr(ord('A') + k) for k in range(len(c_nrs))]
                scores = ['c_B1_012_original', 'c_B2_012_original', 'c_B1+B2_012_original',
                          'c_B1-B2_012_original']
                color01, color012, color01_2, color02, color0_burst, color0 = colors_suscept_paper_dots()
                colors = [color01, color02, color012, color01_2]
                linestyles = ['-', '-', '-', '-']
                index = [0, 1, 2, 3]

                # Bottom panel: scores as a function of contrast.
                ax_u1 = plt.subplot(grid_down[0, i])
                labels = [label_deltaf1(), label_deltaf2(),
                          label_sum(), label_diff(), label_fbasename_small()]
                plt_single_trace([], ax_u1, frame_cell_orig, freq1, freq2,
                                 scores=np.array(scores)[index], labels=np.array(labels)[index],
                                 colors=np.array(colors)[index],
                                 linestyles=np.array(linestyles)[index],
                                 linewidths=np.array(linewidths)[index],
                                 alpha=np.array(alpha)[index],
                                 thesum=False, B_replace='F', default_colors=False,
                                 c_dist_recalc=c_dist_recalc)
                frame_cell = frame_cell_orig[
                    (frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)]
                ax_u1.set_xlim(0, 25)
                ax_u1.legend(ncol=2, loc=(0, 1.2))
                if i != 0:
                    ax_u1.set_ylabel('')
                    remove_yticks(ax_u1)
                # Lettered markers at the contrasts shown in the top row. A
                # scalar yposs[i] is broadcast to every contrast, so the default
                # yposs works as well as a per-contrast list.
                ypos_row = np.array(yposs[i]) * np.ones(len(c_nrs))
                ax_u1.scatter(c_nrs, ypos_row, color='black',
                              marker='v',
                              clip_on=False)
                for c_nn, c_nr in enumerate(c_nrs):
                    ax_u1.text(c_nr, ypos_row[c_nn] + 30, letters[c_nn], color='black', ha='center',
                               va='top')
                ylim = ax_u1.get_ylim()
                ax_u1.set_ylim(0, ylim[-1])

                # Time window in ms: three periods of the lower of |freq1|,
                # |freq2|, starting at `start`.
                start = 200
                mults_period = 3
                xlim = [start, start + (mults_period * 1000 / np.min([np.abs(freq1), np.abs(freq2)]))]

                axps = []
                axss = []
                p_arrays_all = []
                xlimp = (0, 300)  # frequency range (Hz) kept for the spectra
                model_fit = ''  # forwarded to chose_old_vs_new_model ('' or '_old_fit_')
                model_cells, reshuffled = chose_old_vs_new_model(model_fit=model_fit)

                # Loop over the contrasts: one column of the top row each.
                for c_nn, c_nr in enumerate(c_nrs):
                    # PSD generation (long stimulus, single trial). The PSD is
                    # saved to name_psd / name_psd_f but always recomputed: the
                    # traces used below (arrays, arrays_stim, names) are only
                    # available from this call, so a PSD-only cache hit could
                    # not build the figure.
                    trials_nr = 1
                    stimulus_length_here = 100
                    nfft_here = 20 * 20000
                    a_f2s = [c_nrs_orig[c_nn]]
                    save_dir = load_savedir(level=0).split('/')[0] + '_afe_' + str(a_f2s[0]) + '_nfft_' + str(
                        nfft_here) + '_len_' + str(stimulus_length_here)
                    name_psd = save_dir + '_psd.npy'
                    name_psd_f = save_dir + '_psdf.npy'
                    _, _, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays, names, p_arrays_p, ff_p = calc_roc_amp_core_cocktail(
                        [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff,
                        a_f2s,
                        fish_jammer, trials_nr, nfft_here, us_name, gain, runs, a_fr, nfft_for_morph,
                        beat,
                        printing,
                        stimulus_length_here,
                        model_cells, position_diff, dev, cell_here, dev_name=dev_name,
                        a_f1s=[c_nrs_orig[c_nn]],
                        n=n,
                        reshuffled=reshuffled, min_amps=min_amps, mean_choice='first')
                    np.save(name_psd, p_arrays_p)
                    np.save(name_psd_f, ff_p)

                    # Spike generation (short stimulus, several trials).
                    trials_nr = 10
                    stimulus_length_here = 1
                    nfft_here = stimulus_length_here * 20000
                    a_f2s = [c_nrs_orig[c_nn]]
                    _, arrays_spikes, _, _, _, _, _, _, _, _, _ = calc_roc_amp_core_cocktail(
                        [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff,
                        a_f2s,
                        fish_jammer, trials_nr, nfft_here, us_name, gain, runs, a_fr, nfft_for_morph,
                        beat,
                        printing,
                        stimulus_length_here,
                        model_cells, position_diff, dev, cell_here, dev_name=dev_name,
                        a_f1s=[c_nrs_orig[c_nn]],
                        n=n,
                        reshuffled=reshuffled, min_amps=min_amps, mean_choice='')

                    # Keep the fourth PSD and crop it (and the frequency axis)
                    # to xlimp.
                    p_arrays_here = [p_arrays_p[3]]
                    keep = ff_p < xlimp[1]
                    for p_array in p_arrays_here:
                        p_array[0] = p_array[0][keep]
                    ff_p = ff_p[keep]
                    # Time axis in ms; the original code indexed arrays[a] with
                    # a == 0 at this point, made explicit here.
                    time = np.arange(0, len(arrays[0][0]) / sampling, 1 / sampling)
                    time = time * 1000
                    arrays_here = [arrays[3]]
                    arrays_st = [arrays_stim[3]]
                    arrays_sp = [arrays_spikes[3]]
                    colors_array_here = ['grey', 'grey', 'grey']
                    p_arrays_all.append(p_arrays_here)
                    for a in range(len(arrays_here)):
                        print('a' + str(a))
                        grid_pt = gridspec.GridSpecFromSubplotSpec(3, 1,
                                                                   hspace=0.2,
                                                                   wspace=0.2,
                                                                   subplot_spec=grid_up[a, c_nn],
                                                                   height_ratios=[1, 1, 2])
                        # Row 1: AM extracted from the stimulus.
                        axe = plt.subplot(grid_pt[0])
                        am, _ = extract_am(arrays_st, time, extract='', norm=False)
                        plt_stim_saturation(a, [], am, axe, colors_array_here, f,
                                            f_counter, names, time,
                                            xlim=xlim)
                        a_f2_cm = c_dist_recalc_func(frame_cell, c_nrs=[a_f2s[0]], cell=cell_here,
                                                     c_dist_recalc=c_dist_recalc)
                        if not c_dist_recalc:
                            a_f2_cm = np.array(a_f2_cm) * 100
                        # Raw strings: the LaTeX escapes "\%" and "\," are
                        # invalid Python escapes; the runtime value is unchanged.
                        beat_here = '$c_{1}=%s$' % (
                            int(np.round(c_nrs[c_nn]))) + r'$\%$' + r',\,$c_{2}=%s$' % (
                            int(np.round(a_f2_cm[0]))) + r'$\%$'
                        plt.suptitle(f_vary_name(freq=int(freq1), delta=True) + ', ' + f_stable_name(
                            freq=int(freq2), delta=True))
                        # Row 2: spike trains (times scaled by 1000).
                        axs = plt.subplot(grid_pt[1])
                        plt_spikes_ROC(axs, 'grey', np.array(arrays_sp[a]) * 1000, xlim, lw=1)
                        axs.text(1, 1.1, beat_here, va='bottom', ha='right',
                                 transform=axs.transAxes)
                        axss.append(axs)
                        # Row 3: power spectrum, filled in after all contrasts
                        # are computed (the normalisation uses all of them).
                        axp = plt.subplot(grid_pt[-1])
                        axps.append(axp)
                        f_counter += 1

                # Power spectra, in the same (contrast, a) order in which the
                # axes were appended above.
                axp_iter = iter(axps)
                for c_nn, c_nr in enumerate(c_nrs):
                    for a in range(len(arrays_here)):
                        axp = next(axp_iter)
                        pp = log_calc_psd(log, p_arrays_all[c_nn][a][0],
                                          np.nanmax(p_arrays_all))
                        colors_peaks = [color01, color02, color012, color01_2]
                        markeredgecolors = [color01, color02, color012, color01_2]
                        # Named peak_freqs so the outer `freqs` iterable is not
                        # clobbered for later cells.
                        peak_freqs = [np.abs(freq1), np.abs(freq2), np.abs(freq1) + np.abs(freq2),
                                      np.abs(np.abs(freq1) - np.abs(freq2))]
                        plt_psd_saturation(pp, ff_p, a, axp, colors_array_here, freqs=peak_freqs,
                                           colors_peaks=colors_peaks, xlim=xlimp,
                                           markeredgecolor=markeredgecolors, )
                        axp.show_spines('lb')
                        if c_nn == 0:
                            axp.set_ylabel('dB' if log else power_spectrum_name())
                        else:
                            remove_yticks(axp)
                        if log:
                            axp.set_ylim(-39, 5)
                        axp.set_xlabel('Frequency [Hz]')

                # From matplotlib 3.6 on, get_shared_{x,y}_axes() returns a
                # read-only GrouperView whose join() was deprecated and later
                # removed. Join only where it still exists; set_same_ylim is
                # applied either way.
                for get_shared in (axps[0].get_shared_y_axes, axps[0].get_shared_x_axes):
                    shared = get_shared()
                    if hasattr(shared, 'join'):
                        shared.join(*axps)
                set_same_ylim(axps)
                fig = plt.gcf()
                fig.tag(axss + [ax_u1], xoffs=-2.3, yoffs=1.4)
                save_visualization(cell_here, show)
if __name__ == '__main__':
    # Route uncaught exceptions through threefish's `info` hook.
    sys.excepthook = info
    # One row of marker heights: the same height for each of the four contrasts.
    marker_heights = [[220] * 4]
    nonlin_regime(yposs=marker_heights)