import os
import sys

import numpy as np
import pandas as pd
from IPython import embed
from matplotlib import gridspec, pyplot as plt
from plotstyle import plot_style
from threefish.core import find_folder_name, info
from threefish.defaults import default_figsize, default_ticks_talks
from threefish.load import load_savedir, save_visualization
from threefish.plot.limits import join_x, join_y, set_same_ylim
from threefish.RAM.calc_fft import log_calc_psd
from threefish.RAM.calc_model import chose_old_vs_new_model
from threefish.RAM.plot_labels import label_f_eod_name_core_rm, onebeat_cond, remove_yticks
from threefish.RAM.plot_subplots import colors_suscept_paper_dots, plt_spikes_ROC, recalc_fr_to_DF1
from threefish.RAM.values import val_cm_to_inch, vals_model_full
from threefish.twobeat.calc_model import calc_roc_amp_core_cocktail
from threefish.twobeat.colors import colors_susept, twobeat_cond
from threefish.twobeat.labels import f_stable_name, f_vary_name
from threefish.twobeat.reformat import (c_dist_recalc_func, c_dist_recalc_here, dist_recalc_phaselockingchapter,
                                        find_dfs, get_frame_cell_params)
from threefish.twobeat.subplots import (plt_psd_saturation, plt_single_trace, plt_stim_saturation,
                                        plt_vmem_saturation, power_spectrum_name)
from threefish.values import values_nfft_full_model, values_stimuluslength_model_full


def _share_axes(axes_list, which):
    """Link the ``which`` ('x' or 'y') axis of all Axes in ``axes_list``.

    Equivalent to the former ``axes_list[0].get_shared_<which>_axes().join(*axes_list)``.
    That public ``join`` was deprecated in matplotlib 3.6 and removed in 3.8
    (``get_shared_*_axes`` now returns a read-only view), so the underlying
    grouper is joined directly when it is available. Unlike
    ``Axes.sharex``/``Axes.sharey`` this links limits only and does not share
    tick locators/formatters, so per-panel tick removal is preserved.
    """
    if not axes_list:
        return
    shared = getattr(axes_list[0], '_shared_axes', None)
    if shared is not None:
        # NOTE: private matplotlib attribute (>= 3.5): the Grouper behind get_shared_*_axes().
        shared[which].join(*axes_list)
    else:
        # Older matplotlib: get_shared_*_axes() returns the mutable Grouper itself.
        getattr(axes_list[0], 'get_shared_%s_axes' % which)().join(*axes_list)


def nonlin_regime(yposs=None, freqs=[(39.5, -210.5)], printing=False, beat='', nfft_for_morph=4096 * 4,
                  gain=1, cells_here=["2013-01-08-aa-invivo-1"], fish_jammer='Alepto', us_name='', show=True):
    """Plot model responses for one frequency pair (df1, df2) at several contrasts.

    Figure layout (built with the GridSpecs below):

    - Top row: one column per entry of ``c_nrs_orig``. Each column shows a trace
      panel (``plt_vmem_saturation``) above a power-spectrum panel
      (``plt_psd_saturation``), both taken from ``calc_roc_amp_core_cocktail``.
    - Bottom panel: ``plt_single_trace`` of the ``c_B*_012_original`` columns of
      a precomputed ``calc_cocktailparty`` CSV, with black triangles and letters
      A-D marking the simulated contrasts.

    Parameters
    ----------
    yposs : sequence, optional
        Marker heights in the bottom panel, one entry per CSV in ``full_names``.
        Each entry may be a scalar or a sequence with one value per contrast.
        Defaults to ``[450, 450, 450]``.
    freqs, cells_here :
        NOTE(review): currently ignored. Both are overwritten inside the function:
        the pair comes from ``vals_model_full`` and the cell is
        "2013-01-08-aa-invivo-1". They are kept for interface compatibility.
    printing, beat, nfft_for_morph, gain, fish_jammer, us_name :
        Forwarded unchanged to ``calc_roc_amp_core_cocktail``.
    show : bool
        Forwarded to ``save_visualization``.
    """
    if yposs is None:
        yposs = [450, 450, 450]

    # Fixed settings forwarded to calc_roc_amp_core_cocktail.
    runs = 1
    n = 1
    dev = 0.001
    min_amps = '_minamps_'  # standard combination with a small intruder
    dev_name = ['05']
    a_fr = 1
    datapoints = 1000
    results_diff = pd.DataFrame()
    position_diff = 0
    plot_style()
    default_figsize(column=2, length=5.5)

    # Frequencies for model_full: each DF value is recomputed with recalc_fr_to_DF1
    # using the fr_data of cell 2012-07-03-ak-invivo-1.
    DF1_frmult, DF2_frmult = vals_model_full(val=0.30833333333333335)
    frame_cvs = pd.read_csv(find_folder_name('calc_base') + '/csv_model_data.csv')
    frame_cell = frame_cvs[frame_cvs.cell == '2012-07-03-ak-invivo-1']
    for d in range(len(DF1_frmult)):
        DF2_frmult[d] = recalc_fr_to_DF1(DF2_frmult, d, frame_cell.fr_data.iloc[0])
        DF1_frmult[d] = recalc_fr_to_DF1(DF1_frmult, d, frame_cell.fr_data.iloc[0])
    # NOTE(review): overrides the ``freqs`` argument.
    freqs = [(DF1_frmult[3], DF2_frmult[3])]

    auci_wo = []
    auci_w = []
    # NOTE(review): overrides the ``cells_here`` argument.
    cells_here = ["2013-01-08-aa-invivo-1"]
    for cell_here in cells_here:
        # Loop over the frequency pairs.
        for freq1, freq2 in freqs:
            full_names = [
                'calc_model_amp_freqs-_F1_0.22833333333333333Fr_F2_1Fr_af_coupled__C1Len_50_FirstC1_0.0001_LastC1_1.0_FRrelavtiv__start_0.0001_end_1_StimLen_100_nfft_20000_trialsnr_1_mult_minimum_1_power_1_minamps__dev_original_05_point_1_fft2_temporal']
            c_grouped = ['c1']
            c_nrs_orig = [0.01, 0.025, 0.04, 0.1]  # contrasts, one top-row column each
            trials_nr = 20
            redo = False
            log = 'log'
            grid0 = gridspec.GridSpec(1, 1, bottom=0.15, top=0.88, left=0.11, right=0.95, wspace=0.04)
            grid00 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.04, hspace=0.45, subplot_spec=grid0[0],
                                                      height_ratios=[1.5, 1], )
            grid_up = gridspec.GridSpecFromSubplotSpec(1, len(c_nrs_orig), hspace=0.75, wspace=0.1,
                                                       subplot_spec=grid00[0])
            grid_down = gridspec.GridSpecFromSubplotSpec(1, 1, hspace=0.75, wspace=0.1, subplot_spec=grid00[1])
            for i, full_name in enumerate(full_names):
                frame = pd.read_csv(find_folder_name('calc_cocktailparty') + '/' + full_name + '.csv')
                frame_cell_orig = frame[(frame.cell == cell_here)]
                if len(frame_cell_orig) == 0:
                    continue
                get_frame_cell_params(c_grouped, cell_here, frame, frame_cell_orig)

                # Implemented for a single cell, varying the single point and the contrasts.
                f_counter = 0
                frame_cell_orig, _, _, _, _ = find_dfs(frame_cell_orig)
                eodf = frame_cell_orig.f0.unique()[0]
                f = 0

                # Select the requested (df1, df2) pair; if it is missing, snap to the closest values.
                # NOTE(review): df1 and df2 are snapped independently, so the combination may still be absent.
                frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)]
                if len(frame_cell) < 1:
                    freq1 = frame_cell_orig.iloc[np.argmin(np.abs(frame_cell_orig.df1 - freq1))].df1
                    freq2 = frame_cell_orig.iloc[np.argmin(np.abs(frame_cell_orig.df2 - freq2))].df2
                    frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)]
                    print('Tuning curve needed for F1' + str(frame_cell.f1.unique()) + ' F2' + str(
                        frame_cell.f2.unique()) + ' for cell ' + str(cell_here))

                (labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles,
                 scores, linewidths) = colors_susept(add='_mean_original', nr=4)
                sampling = 20000
                c_dist_recalc = dist_recalc_phaselockingchapter()
                c_nrs = c_dist_recalc_func(frame_cell, c_nrs=c_nrs_orig, cell=cell_here, c_dist_recalc=c_dist_recalc)
                if not c_dist_recalc:
                    c_nrs = np.array(c_nrs) * 100  # the titles below label these values in %

                letters = ['A', 'B', 'C', 'D']
                scores = ['c_B1_012_original', 'c_B2_012_original', 'c_B1+B2_012_original', 'c_B1-B2_012_original']
                color01, color012, color01_2, color02, color0_burst, color0 = colors_suscept_paper_dots()
                colors = [color01, color02, color012, color01_2]
                linestyles = ['-', '-', '-', '-']
                index = [0, 1, 2, 3]

                # Bottom panel: the four score columns from the CSV.
                try:
                    ax_u1 = plt.subplot(grid_down[0, i])
                except IndexError:
                    # grid_down has a single cell; more than one CSV in full_names does not fit.
                    print('grid search problem4')
                    raise
                plt_single_trace([], ax_u1, frame_cell_orig, freq1, freq2, scores=np.array(scores)[index],
                                 labels=np.array(labels)[index], colors=np.array(colors)[index],
                                 linestyles=np.array(linestyles)[index], linewidths=np.array(linewidths)[index],
                                 alpha=np.array(alpha)[index], thesum=False, B_replace='F', default_colors=False,
                                 c_dist_recalc=c_dist_recalc)
                frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)]
                ax_u1.set_xlim(0, 25)
                if i != 0:
                    ax_u1.set_ylabel('')
                    remove_yticks(ax_u1)
                # Mark the simulated contrasts: yposs[i] may be one height or one height per contrast.
                marker_y = np.broadcast_to(np.asarray(yposs[i], dtype=float), (len(c_nrs),))
                ax_u1.scatter(c_nrs, marker_y - 35, color='black', marker='v', clip_on=False)
                for c_nn, c_nr in enumerate(c_nrs):
                    ax_u1.text(c_nr, marker_y[c_nn] + 15, letters[c_nn], color='black', ha='center', va='top')

                # Time window of the traces: 3 periods of the lower of |freq1|, |freq2|, starting at 200 ms.
                start = 200
                mults_period = 3
                xlim = [start, start + (mults_period * 1000 / np.min([np.abs(freq1), np.abs(freq2)]))]

                axts = []
                axps = []
                axes = []
                p_arrays_all = []
                model_fit = ''
                model_cells, reshuffled = chose_old_vs_new_model(model_fit=model_fit)

                # Top row: simulate and plot each contrast.
                for c_nn, c_nr in enumerate(c_nrs):
                    save_dir = load_savedir(level=0).split('/')[0]
                    name_psd = save_dir + '_psd.npy'
                    do = False  # switch for the recompute branch below (currently off)
                    recompute_psd = do and (redo or not os.path.exists(name_psd))
                    stimulus_length_here = values_stimuluslength_model_full()
                    nfft_here = values_nfft_full_model()
                    # trials_nr keeps its configured value only when recomputing linear-scale spectra.
                    if not (recompute_psd and log != 'log'):
                        trials_nr = 1

                    a_f2s = [c_nrs_orig[c_nn]]
                    (_, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays, names,
                     p_arrays_p, ff_p) = calc_roc_amp_core_cocktail(
                        [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff, a_f2s,
                        fish_jammer, trials_nr, nfft_here, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing,
                        stimulus_length_here, model_cells, position_diff, dev, cell_here, dev_name=dev_name,
                        a_f1s=[c_nrs_orig[c_nn]], n=n, reshuffled=reshuffled, min_amps=min_amps)

                    # Only entry 3 of the returned spectra/traces is plotted; crop spectra to xlimp.
                    p_arrays_here = [p_arrays_p[3]]
                    xlimp = (0, 300)
                    in_band = ff_p < xlimp[1]
                    for p_array in p_arrays_here:
                        p_array[0] = p_array[0][in_band]
                    ff_p = ff_p[in_band]

                    time = np.arange(0, len(arrays[0][0]) / sampling, 1 / sampling) * 1000  # ms
                    arrays_time = [arrays[3]]
                    arrays_here = [arrays[3]]
                    arrays_st = [arrays_stim[3]]
                    arrays_sp = [arrays_spikes[3]]
                    colors_array_here = ['grey', 'grey', 'grey']
                    p_arrays_all.append(p_arrays_here)

                    for a in range(len(arrays_here)):
                        print('a' + str(a))
                        grid_pt = gridspec.GridSpecFromSubplotSpec(2, 1, hspace=0.3, wspace=0.2,
                                                                   subplot_spec=grid_up[a, c_nn], )
                        stim = False
                        if stim:
                            axe = plt.subplot(grid_pt[0])
                            axes.append(axe)
                            plt_stim_saturation(a, [], arrays_st, axe, colors_array_here, f, f_counter, names, time,
                                                xlim=xlim)
                        a_f2_cm = c_dist_recalc_func(frame_cell, c_nrs=[a_f2s[0]], cell=cell_here,
                                                     c_dist_recalc=c_dist_recalc)
                        if not c_dist_recalc:
                            a_f2_cm = np.array(a_f2_cm) * 100
                        beat_here = ('$c_{1}=%s$' % int(np.round(c_nrs[c_nn])) + r'$\%$'
                                     + r',\,$c_{2}=%s$' % int(np.round(a_f2_cm[0])) + r'$\%$')
                        plt.suptitle(f_vary_name(freq=int(freq1), delta=True)
                                     + f_stable_name(freq=int(freq2), delta=True))
                        title_name = beat_here

                        spikes = False
                        if spikes:
                            axs = plt.subplot(grid_pt[1])
                            plt_spikes_ROC(axs, 'grey', np.array(arrays_sp[a]) * 1000, xlim, lw=1)

                        axt = plt.subplot(grid_pt[0])
                        axts.append(axt)
                        plt_vmem_saturation(a, arrays_sp, arrays_time, axt, colors_array_here, f, time, xlim=xlim)
                        axt.text(1, 1.1, title_name, va='bottom', ha='right', transform=axt.transAxes)
                        axp = plt.subplot(grid_pt[-1])
                        axps.append(axp)
                        if c_nn == 0:
                            axt.show_spines('')
                            axt.xscalebar(0.1, -0.1, 30, 'ms', va='right', ha='bottom')
                            axt.yscalebar(-0.02, 0.35, 200, 'Hz', va='left', ha='top')
                        f_counter += 1

                # Spectra: convert with log_calc_psd and mark |df1|, |df2|, their sum and difference.
                for c_nn, c_nr in enumerate(c_nrs):
                    for a in range(len(arrays_here)):
                        # axps was filled contrast by contrast, trace by trace.
                        axp = axps[c_nn * len(arrays_here) + a]
                        pp = log_calc_psd(log, p_arrays_all[c_nn][a][0], np.nanmax(p_arrays_all))
                        colors_peaks = [color01, color02, color012, color01_2]
                        markeredgecolors = [color01, color02, color012, color01_2]
                        peak_freqs = [np.abs(freq1), np.abs(freq2), np.abs(freq1) + np.abs(freq2),
                                      np.abs(np.abs(freq1) - np.abs(freq2))]
                        plt_psd_saturation(pp, ff_p, a, axp, colors_array_here, freqs=peak_freqs,
                                           colors_peaks=colors_peaks, xlim=xlimp, markeredgecolor=markeredgecolors, )
                        if log:
                            scalebar = False
                            if scalebar:
                                axp.show_spines('b')
                                if c_nn == 0:
                                    axp.yscalebar(-0.05, 0.5, 20, 'dB', va='center', ha='left')
                                axp.set_ylim(-33, 5)
                            else:
                                axp.show_spines('lb')
                                if c_nn == 0:
                                    axp.set_ylabel('dB')
                                else:
                                    remove_yticks(axp)
                                axp.set_ylim(-39, 5)
                        else:
                            axp.show_spines('lb')
                            if c_nn != 0:
                                remove_yticks(axp)
                            else:
                                axp.set_ylabel(power_spectrum_name())
                        axp.set_xlabel('Frequency [Hz]')

                # Link the axes within the trace row and within the spectrum row.
                _share_axes(axts, 'y')
                _share_axes(axts, 'x')
                _share_axes(axps, 'y')
                _share_axes(axps, 'x')
                join_y(axts)
                set_same_ylim(axts)
                set_same_ylim(axps)
                join_x(axts)
        save_visualization(cell_here, show)


if __name__ == '__main__':
    sys.excepthook = info
    nonlin_regime(yposs=[[270, 270, 270, 270]], )