import ast import csv import warnings from random import sample import numpy import seaborn as sns from scipy.optimize import curve_fit from scipy.signal import vectorstrength from scipy.stats import alpha, gaussian_kde from sklearn import metrics from sklearn.linear_model import LinearRegression from thunderfish.eventdetection import hist_threshold # from utils_all import cr_spikes_mat, load_folder_name warnings.filterwarnings("ignore", message="WARNING:root:MultiTag type relacs.stimulus.segment") from scipy import optimize, stats '''try: from utils_all import default_settings, load_folder_name except: # das ist das gleiche wie drüber nur dass es einen anderen Namen hat from utils_all_down import column2,find_mean_period, default_settings, load_folder_name, chose_mat_max_value, \ create_stimulus_SAM, \ default_settings, find_code_vs_not, load_folder_name, plt_peaks, \ plt_peaks_several, resave_small_files, \ restrict_cell_type, thresh_crossings, zenter_and_normalize''' try: from utils_all import * except: # das ist das gleiche wie drüber nur dass es einen anderen Namen hat from utils_all_down import * try: pass except: a = 0 try: import nixio as nix except: print('nixio not there') import numpy as np import pandas as pd import scipy from IPython import embed from matplotlib import gridspec, pyplot as plt, pyplot, ticker as ticker import os import matplotlib.mlab as ml import matplotlib.gridspec as gridspec from scipy.ndimage import gaussian_filter from thunderfish import fakefish try: import rlxnix as rlx except: a = 5 try: from numba import jit except ImportError: def jit(): def decorator_jit(func): return func return decorator_jit import inspect if 'cv_cell_types' not in inspect.stack()[-1][1]: try: from plotstyle import plot_style, plot_style as style, spines_params except: a = 5 import itertools as it def plot_rec_stimulus(grid, transform_fact, stimulus, color1, time, counter, eod_fr, deltat, nfft, xlim=0.05, shift=0, lw=0.5): axt = plt.subplot(grid[0]) 
time_here = (time[0:len(stimulus)] - shift) * transform_fact stim_here = stimulus[time_here < xlim * transform_fact] extracted, _ = extract_am(stimulus, time / 1000, norm=False, extract='globalmax', sampling=1 / deltat, eodf=eod_fr) # time_here_here/1000 extracted_here = extracted[time_here < xlim * transform_fact] time_here_here = time_here[time_here < xlim * transform_fact] axt.plot(time_here_here, stim_here, color=color1, linewidth=lw) axt.plot(time_here_here, extracted_here, color='red', linewidth=1) counter += 1 # am_time*1000 axt.set_xlim(0, xlim * transform_fact) axt.set_ylim(-1.2, 1.7) axt.show_spines('lb') axt.axhline(0, color='black', lw=0.5) axt.set_xticks_blank() axp = plt.subplot(grid[1]) ff, pp = calc_psd(stimulus, deltat, nfft) axp.set_xticks_blank() return counter, axt, ff, pp, axp def plot_lowpass(g_p_t, transform_fact, time, shift, v_dent_output, color1, deltat, fft_type, nfft, eod_fr, extract=True, lw=0.5, xlim=0.05): ff, ff_am, pp, pp_am, time_here, extracted = time_psd_calculation(deltat, eod_fr, extract, fft_type, nfft, shift, time, transform_fact, v_dent_output) axt_p2 = plt_time_arrays(color1, g_p_t, lw, v_dent_output, extracted, xlim, time_here, transform_fact=transform_fact) axp_p2 = plt.subplot(g_p_t[1]) axp_p2.set_xticks_blank() return axt_p2, ff, pp, axp_p2, ff_am, pp_am def plt_time_arrays_here(color1, g_p_t, lw, time_here, transform_fact, v_dent_output, xlim): axt_p2 = plt.subplot(g_p_t[0]) axt_p2.plot(time_here, v_dent_output, color=color1, linewidth=lw) axt_p2.show_spines('lb') # am_time*1000 axt_p2.set_xlim(0.0, xlim * transform_fact) axt_p2.axhline(0, color='black', lw=0.5) return axt_p2 def model_sheme_only(grid_sheme, stimulus_length=5, a_fr=1, a_fe=0.2, v_exp=1, exp_tau=0.1): # need to reduce parameters # parameters = pd.read_csv("models_big_fit_d_right.csv", index_col=0) # load_name = "models_big_fit_d_right.csv"#"models_big_fit.csv" # parameters = pd.read_csv(load_name) # potentiell zellen wo der range nicht zu weit ist : 0, 
9 # problem 0, 10: spikt manchmal # problem 9: presynaptic oscilation nicht so schön # gute Zellen: 5 # ok ich mach Zelle Null weil sie am schönsten aussieht, die spikt manchmal aber das klammern wir jetzt halt aus # good_cells = pd.read_csv("good_model_cells.csv", index_col=0) # 2 ist wohl sehr nah, das geht! # 0,1,3,7 weit entfernt # 4,5,6 sehr weit entfernt # 8 ist wohl am nächsten! # embed() # model_params = parameters.iloc[0] # model_params = parameters[parameters['cell'].isin(good_cells.cell[0:-1]+'-invivo-1')].iloc[2] # model_params = load_model(load_name=load_name, cell_nr = cell_nr) models = resave_small_files("models_big_fit_d_right.csv", load_folder='calc_model_core') model_params = models[models['cell'] == '2012-07-03-ak-invivo-1'].iloc[0] eod_fr = model_params.pop('EODf') # .iloc[0] deltat = model_params.pop("deltat") # .iloc[0] eod_fe = [eod_fr + 50] # eod_fr*1+50,, eod_fr * 2 + 50 # REMAINING rows color_p3 = 'grey' # 'red'#palette['red'] color_p1 = 'grey' # 'blue'#palette['blue'] counter_here = 0 grid_sheme = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=grid_sheme, wspace=0.2, hspace=0.95) counter_g = 0 for mult_nr in range(len(eod_fe)): try: time, stimulus_rec, eod_fish_r, eod_fish_e, stimulus = make_paramters( stimulus_length, deltat, eod_fr, a_fr, a_fe, eod_fe, mult_nr) except: print('parameter thing6') embed() colorful_title = False # einfach eine stimulus schleife zu machen würde mehrere änderungen bedeutetn eod_fish_r_rec = eod_fish_r * 1 eod_fish_r_rec[eod_fish_r_rec < 0] = 0 _, _, _, _, _ = titles_EIF(eod_fish_r, eod_fish_r_rec, color_p1, color_p3, mult_nr, eod_fr, eod_fe, stimulus, stimulus_rec, colorful_title) # for g, stimulus_here in enumerate([stimuli[1]]): # And plot correspoding sheme axsheme = plt.subplot(grid_sheme[0]) plot_sheme_nonlinearity(axsheme, color_p1) # SECOND Row: Dendridic Low pass filter axsheme = plt.subplot(grid_sheme[1]) plot_sheme_lowpass(axsheme) # THIRD /FORTH Row: LIF /EIF axsheme = 
plt.subplot(grid_sheme[2]) exponential = '' plot_sheme_IF(axsheme, exp_tau, v_exp, exponential) counter_g += 1 counter_here += 1 def model_and_data_vertical(nr_clim=10, many=False, width=0.005, row='no', HZ50=True, fs=8, hs=0.39, nffts=['whole'], powers=[1], cells=["2013-01-08-aa-invivo-1"], col_desired=2, var_items=['contrasts'], contrasts=[0], noises_added=[''], fft_i='forward', fft_o='forward', spikes_unit='Hz', mV_unit='mV', D_extraction_method=['additiv_visual_d_4_scaled'], internal_noise=['eRAM'], external_noise=['eRAM'], level_extraction=['_RAMdadjusted'], cut_off2=300, receiver_contrast=[1], dendrids=[''], ref_types=[''], adapt_types=[''], c_noises=[0.1], c_signal=[0.9], cut_offs1=[300], clims='all', restrict='restrict'): plot_style() default_settings(lw=0.5, column=2, length=8.5) stimulus_length = 1 # 20#550 # 30 # 15#45#0.5#1.5 15 45 100 trials_nrs = [1] # [100, 500, 1000, 3000, 10000, 100000, 1000000] # 500 variant = 'sinz' mimick = 'no' cell_recording_save_name = '' trans = 1 # 5 repeats = [30, 100000] # , aa = 0 _, _ = overlap_cells() cells_all = ['2012-07-03-ak-invivo-1', '2018-05-08-ae-invivo-1', '2011-10-25-ad-invivo-1'] # good_data[, good_data = cells_all for _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, in it.product( cells, D_extraction_method, external_noise, internal_noise, powers, nffts, dendrids, cut_offs1, trials_nrs, c_signal, c_noises, ref_types, adapt_types, noises_added, level_extraction, receiver_contrast, contrasts, ): aa += 1 if row == 'no': col, row = find_row_col(np.arange(aa), col=col_desired) # np.arange( else: pass if row == 2: default_settings(column=2, length=7.5) # 2+2.25+2.25 elif row == 1: default_settings(column=2, length=4) row = 5 fig = plt.figure(figsize=(6.8, 7.5)) grid_orig = gridspec.GridSpec(1, 2, wspace=0.15, bottom=0.05, hspace=0.1, left=0.07, width_ratios=[4, 1.3], right=0.99, top=0.88) # , height_ratios = [0.4,3] # plot lower part grid_lower = gridspec.GridSpecFromSubplotSpec(4, 1, grid_orig[0], wspace=0.05, 
hspace=0.53, height_ratios=[0.2, 1, 1, 1]) wr = [1, 1, 1] if row == 2: plt.subplots_adjust(bottom=0.067, wspace=0.45, top=0.81, hspace=hs, right=0.88, left=0.075) # , hspace = 0.6, wspace = 0.5 elif row == 1: plt.subplots_adjust(bottom=0.1, wspace=0.45, top=0.81, hspace=hs, right=0.88, left=0.075) # , hspace = 0.6, wspace = 0.5 else: plt.subplots_adjust(wspace=0.8, bottom=0.067, top=0.86, hspace=hs, right=0.88, left=0.075) # , hspace = 0.6, wspace = 0.5 a = 0 maxs = [] mins = [] ims = [] perc05 = [] perc95 = [] iternames = [D_extraction_method, external_noise, internal_noise, powers, nffts, dendrids, cut_offs1, trials_nrs, c_signal, c_noises, ref_types, adapt_types, noises_added, level_extraction, receiver_contrast, contrasts, ] nr = '2' for all in it.product(*iternames): var_type, stim_type_afe, stim_type_noise, power, nfft, dendrid, cut_off1, trial_nrs, c_sig, c_noise, ref_type, adapt_type, noise_added, extract, a_fr, a_fe = all hs = 0.25 ################################# # model cells ax_model = [] for t, trials_stim in enumerate(repeats): grid_model = gridspec.GridSpecFromSubplotSpec(1, len(good_data), grid_lower[2 + t], hspace=hs, width_ratios=wr) save_name = save_ram_model(stimulus_length, cut_off1, nfft, a_fe, stim_type_noise, mimick, variant, trials_stim, power, cell_recording_save_name, nr=nr, fft_i=fft_i, fft_o=fft_o, Hz=spikes_unit, mV=mV_unit, stim_type_afe=stim_type_afe, extract=extract, noise_added=noise_added, c_noise=c_noise, c_sig=c_sig, ref_type=ref_type, adapt_type=adapt_type, var_type=var_type, cut_off2=cut_off2, dendrid=dendrid, a_fr=a_fr, trials_nr=trial_nrs, trans=trans, zeros='ones') # '../calc_model/noise2__nfft_whole_power_1_eRAM_RAMdadjusted_additiv_visual_d_4_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_30_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV' path = save_name + '.pkl' model = load_model_susept(path, cells_all, save_name) # cells adapt_type_name, ref_type_name, dendrid_name, 
stim_type_noise_name = define_names(var_type, stim_type_noise, dendrid, ref_type, adapt_type) cells_all = model.groupby('cv_stim').first().sort_values(by='cv_stim').cell # ('cv_stim') for c, cell in enumerate(cells_all): print(c) ax = plt.subplot(grid_model[c]) # grid_30_s[1] if len(model) > 0: stim_type_noise_name2, stim_type_afe_name = stim_type_names(a_fe, c_sig, stim_type_afe, stim_type_noise_name) suptitles, titles = titles_susept_names(a_fe, extract, noise_added, stim_type_afe_name, stim_type_noise_name2, trials_stim, var_items, var_type) # find_titles_susept(a_fe, cell, extract, noise_added, stim_type_afe_name, # stim_type_noise_name2, suptitles, titles, trials_stim, # var_items, var_type) model_show = model[ (model.cell == cell)] # & (model_cell.file_name == file)& (model_cell.power == power)] new_keys = model_show.index.unique() # [0:490] try: stack_plot = model_show[list(map(str, new_keys))] except: stack_plot = model_show[new_keys] stack_plot = stack_plot.iloc[np.arange(0, len(new_keys), 1)] stack_plot.columns = list(map(float, stack_plot.columns)) ax.set_xlim(0, 300) ax.set_ylim(0, 300) ax.set_aspect('equal') ax.set_xticks_delta(100) ax.set_yticks_delta(100) model_cells = resave_small_files("models_big_fit_d_right.csv") model_params = model_cells[model_cells['cell'] == cell] if len(model_show) > 0: noise_strength = model_params.noise_strength.iloc[0] # **2/2 D = noise_strength D_derived, var, cut_off = D_derive(model_show, save_name, c_sig, D=D, base='') stack_plot = RAM_norm(stack_plot, trials_stim, D_derived) if many == True: titles = titles + ' Ef=' + str(int(model_params.EODf.iloc[0])) color = title_color(cell) print(color) if t == 0: ax.set_title( titles + ' $fr_{B}$=' + str(int(np.round(model_show.fr.iloc[0]))) + ' $fr_{S}$=' + str( int(np.round(model_show.fr_stim.iloc[0]))) + 'Hz\n $cv_{B}$=' + str( np.round(model_show.cv.iloc[0], 2)) + \ ' $cv_{S}$=' + str( np.round(model_show.cv_stim.iloc[0], 2)) + ' $D_{sig}$=' + str( np.round(D_derived, 5)) + 
' s=' + str( np.round(model_show.ser_sum_stim.iloc[0], 2)), fontsize=fs, color=color) perc = '' # 'perc' im = plt_RAM_perc(ax, perc, stack_plot) ims.append(im) maxs.append(np.max(np.array(stack_plot))) mins.append(np.min(np.array(stack_plot))) perc05.append(np.percentile(stack_plot, 5)) perc95.append(np.percentile(stack_plot, 95)) plt_triangle(ax, model_show.fr.iloc[0], np.round(model_show.fr_stim.iloc[0]), 300, model_show.eod_fr.iloc[0]) if HZ50: plt_50_Hz_noise(ax, 300) ax.set_aspect('equal') cbar, left, bottom, width, height = colorbar_outside(ax, im, fig, add=0, width=width) if c == 0: ax.set_ylabel(F2_xlabel()) else: remove_yticks(ax) if c == 2: cbar.set_label(nonlin_title(), rotation=90, labelpad=10) if t == 1: ax.set_xlabel(F1_xlabel(), labelpad=20) else: remove_xticks(ax) print(c) ax_model.append(ax) a += 1 model_sheme_in_one(grid_orig[1]) # grid_sheme grid_lower[3] ################################# # data cells grid_data = gridspec.GridSpecFromSubplotSpec(1, len(good_data), grid_lower[1], hspace=hs, width_ratios=wr) grid_isi = gridspec.GridSpecFromSubplotSpec(1, len(good_data), grid_lower[0], hspace=hs, width_ratios=wr) frame = load_cv_base_frame(cells_all) ax_isi = [] for f, cell in enumerate(cells_all): ax = plt.subplot(grid_data[f]) if f == 2: plot = True else: plot = False ax_data = plt_data_up(cell, ax, fig, cells_all, cell_type='p-unit', cbar_label=plot, width=width) if f == len(cells) - 1: ax.set_ylabel(F2_xlabel()) # else: remove_yticks(ax) remove_xticks(ax) axi = plt.subplot(grid_isi[f]) # grid_30_d[0] frame_cell = frame[(frame['cell'] == cell)] spikes = frame_cell.spikes.iloc[0] eod_fr = frame_cell.EODf.iloc[0] spikes_all, hists, frs_calc, cont_spikes = load_spikes(spikes, eod_fr) alpha = 1 for hh, h in enumerate(hists): axi.hist(h, bins=100, color='blue', alpha=float(alpha - 0.05 * hh)) ax_isi.append(axi) axi.set_ylabel('Nr') if f == len(cells_all) - 1: axi.set_xlabel('EODf multiple') axi.set_ylabel('Nr') 
ax_isi[0].get_shared_x_axes().join(*ax_isi) end_name = suptitles + ' a_fr=' + str(a_fr) + ' ' + ' dendride=' + str( dendrid_name) + ' refractory period=' + str(ref_type_name) + ' adapt=' + str( adapt_type_name) + ' ' + ' cutoff1=' + str(cut_off1) + '' ' stimulus length=' + str( stimulus_length) + ' ' + ' power=' + str( power) + ' ' + restrict # end_name = cut_title(end_name, datapoints=120) name_title = end_name plt.suptitle(name_title) # +' file ' set_clim_shared(clims, ims, maxs, mins, nr_clim, perc05, perc95) axes = np.array([np.array(ax_data), np.array(ax_model[0:int(len(ax_model) / 2)]), np.array(ax_model[int(len(ax_model) / 2)::]), np.array(ax_isi)]) fig.tag(np.transpose(axes), xoffs=-3, yoffs=1.2, minor_index=2) save_visualization(pdf=True) def model_sheme_in_one(grid_sheme, time_transform=1000, ws=0.1, nfft=4096 * 6, stimulus_length=5, fft_type='mppsd', a_fr=1, a_fe=0.2, v_exp=1, exp_tau=0.1, counter=0, shift=0.25): # need to reduce parameters # parameters = pd.read_csv("models_big_fit_d_right.csv", index_col=0) # load_name = "models_big_fit_d_right.csv"#"models_big_fit.csv" # parameters = pd.read_csv(load_name) # potentiell zellen wo der range nicht zu weit ist : 0, 9 # problem 0, 10: spikt manchmal # problem 9: presynaptic oscilation nicht so schön # gute Zellen: 5 # ok ich mach Zelle Null weil sie am schönsten aussieht, die spikt manchmal aber das klammern wir jetzt halt aus # good_cells = pd.read_csv("good_model_cells.csv", index_col=0) # 2 ist wohl sehr nah, das geht! # 0,1,3,7 weit entfernt # 4,5,6 sehr weit entfernt # 8 ist wohl am nächsten! 
# embed() # model_params = parameters.iloc[0] # model_params = parameters[parameters['cell'].isin(good_cells.cell[0:-1]+'-invivo-1')].iloc[2] load_name = 'models_big_fit_d_right.csv' models = resave_small_files("models_big_fit_d_right.csv", load_folder='calc_model_core') model_params = models[models['cell'] == '2012-07-03-ak-invivo-1'].iloc[0] eod_fr = model_params.pop('EODf') # .iloc[0] deltat = model_params.pop("deltat") # .iloc[0] v_offset = model_params.pop("v_offset") # .iloc[0] eod_fe = [eod_fr + 50] # eod_fr*1+50,, eod_fr * 2 + 50 # REMAINING rows color_p3 = 'grey' # 'red'#palette['red'] color_p1 = 'grey' # 'blue'#palette['blue'] color_diagonal = 'grey' # 'cyan'#palette['cyan'] colors = [color_diagonal, color_p1, color_p1, color_p3] ax_rec = [[]] * 4 ax_low = [[]] * 4 axt_IF2 = [] delta_f = [50] # create_beat_corr(np.array([eod_fe[mult_nr] - eod_fr]), np.array([eod_fr]))[0] counter_here = 0 nrs = [1, 2, 3, 4] # first row for the stimulus, and then three cols for the sheme, and the power 1 and power 3 grid0 = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=grid_sheme, width_ratios=[1, 3], wspace=0.35) # Grid for the sheme try: pass except: print('grid thing5') embed() lw = 0.5 wr = [1, 1.2] xlim = 0.065 axps = [] axps_lowpass = [] axps_stimulus = [] pps = [] pps_lowpass = [] pps_stimulus = [] colors_chosen = [] counter_g = 0 for mult_nr in range(len(eod_fe)): try: time, stimulus_rec, eod_fish_r, eod_fish_e, stimulus = make_paramters( stimulus_length, deltat, eod_fr, a_fr, a_fe, eod_fe, mult_nr) except: print('parameter thing5') embed() colorful_title = False # einfach eine stimulus schleife zu machen würde mehrere änderungen bedeutetn eod_fish_r_rec = eod_fish_r * 1 eod_fish_r_rec[eod_fish_r_rec < 0] = 0 add_pos, color_add_pos, titles, stimuli, eod_fish_rs = titles_EIF(eod_fish_r, eod_fish_r_rec, color_p1, color_p3, mult_nr, eod_fr, eod_fe, stimulus, stimulus_rec, colorful_title) stimulus_here = do_withenoise_stimulus(deltat, eod_fr, stimulus_length) 
titles = [titles[1]] g = 0 color = colors[counter_here] # A grid for a single POWER column grid_power_col = gridspec.GridSpecFromSubplotSpec(6, 1, subplot_spec=grid0[nrs[counter_here]], height_ratios=[0.7, 1, 0.7, 1, 0.7, 1], wspace=0.45, hspace=0.65) # FIRST Row: Rectified stimulus grid_lowpass = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=grid_power_col[1], wspace=ws, hspace=1.3, width_ratios=wr) counter, ax_rec[counter_here], ff, pp, axp = plot_rec_stimulus(grid_lowpass, time_transform, stimulus_here, color, time, counter, eod_fr, deltat, nfft, shift=shift, lw=lw, xlim=xlim) pps_stimulus.append(pp) axps_stimulus.append(axp) colors_chosen.append(color) if counter_here == 0: ax_rec[counter_here].text(-7, 0, '0', color='black', ha='center', va='center') ax_rec[counter_here].text(add_pos[g], 1.1, titles[g], transform=ax_rec[counter_here].transAxes, ) # verticalalignment='right', # And plot correspoding sheme axsheme = plt.subplot(grid_power_col[0]) plot_sheme_nonlinearity(axsheme, color_p1) # REMAINING Rows: dendridic filter / LIF /EIF stimulus exponential = '' # , 'EIF' manual_offset = False if manual_offset: spike_times, v_dent_output, v_mem_output = simulate2(load_name, v_offset, eod_fish_rs[g], deltat=deltat, exponential=exponential, v_exp=v_exp, exp_tau=exp_tau, **model_params) print('Firing rate baseline ' + str(len(spike_times) / stimulus_length)) spike_times, v_dent_output, v_mem_output = simulate2(load_name, v_offset, stimulus_here, deltat=deltat, exponential=exponential, v_exp=v_exp, exp_tau=exp_tau, **model_params) # SECOND Row: Dendridic Low pass filter grid_lowpass = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=grid_power_col[3], width_ratios=wr, wspace=ws, hspace=1.3) ax_low[counter_here], ff, pp, axp_p2, ff_am, pp_am = plot_lowpass(grid_lowpass, time_transform, time, shift, v_dent_output, color, deltat, fft_type, nfft, eod_fr, xlim=xlim, lw=lw) pps_lowpass.append(pp) axps_lowpass.append(axp_p2) colors_chosen.append(color) if 
counter_here == 0: ax_low[counter_here].text(-7, 0, '0', color='black', ha='center', va='center') axsheme = plt.subplot(grid_power_col[2]) plot_sheme_lowpass(axsheme) # THIRD /FORTH Row: LIF /EIF # plot the voltage of the exponentials grid_lowpass = gridspec.GridSpecFromSubplotSpec(1, 2, width_ratios=wr, subplot_spec=grid_power_col[5], wspace=ws, hspace=1.45) axt_IF, axp_IF, ff, pp, axp_s, pp_s = plot_spikes(grid_lowpass, time_transform, v_mem_output, time, color, spike_times, shift, deltat, fft_type, nfft, eod_fr, xlim=xlim, exponential=exponential, counter_here=counter_here) # , add = add axps.append(axp_s) pps.append(pp_s) colors_chosen.append('black') axt_IF2.append(axt_IF) if g == 0: axsheme = plt.subplot(grid_power_col[4]) # grid_sheme[ee + 2] plot_sheme_IF(axsheme, exp_tau, v_exp, exponential) ################################ # here plot the amplitude modulation relationship counter_g += 1 counter_here += 1 # plot psd with shared log lim #################################### # cut first parts # because otherwise there is a dip at the beginning and thats a problem for the range thing ff, pps_stimulus, pps_lowpass, pps = cut_first_parts(ff, pps_stimulus, pps_lowpass, pps, ll=0) # here I calculate the log and do the same range for all power spectra # this is kind of complicated but some cells spike even withouth thresholding and we want to keep their noise floor down # not to see the peaks in the noise pp3_stimulus = create_same_max(np.concatenate([pps_stimulus, pps_lowpass]), same=True) axps_stimulus = np.concatenate([axps_stimulus, axps_lowpass, ]) pps3 = create_same_max(pps, same=True) pp3 = create_same_range(np.concatenate([pp3_stimulus, pps3])) axps_stimulus = np.concatenate([axps_stimulus, axps]) # there are only few cells where the distance is not so high and this cells spike occationally but very randomly still we dont wanna se their power specturm # therefore we dont show it colors = [color_diagonal, color_p1, color_p1, color_p3, color_diagonal, 
color_p1, color_p1, color_p3, color_diagonal, color_p1, color_p1, color_p3, ] plot_points = [[], 'yes', [], 'yes', [], 'yes', [], 'yes', [], 'yes', [], 'yes', [], 'yes', [], 'yes', ] for a, axp in enumerate(axps_stimulus): lw_p = 0.8 plot_power_common_lim(axp, pp3[a], ff / eod_fr, colors[a], lw_p, plot_points[a], delta_f / eod_fr) if a % 4 == 3: axp.yscalebar(1, 0.5, 20, 'dB', va='center', ha='right') axps_stimulus[0].get_shared_y_axes().join(*axps_stimulus) def model_sheme(grid_sheme, time_transform=1000, ws=0.1, nfft=4096 * 6, stimulus_length=5, fft_type='mppsd', a_fr=1, a_fe=0.2, v_exp=1, exp_tau=0.1, counter=0, shift=0.25): # need to reduce parameters # parameters = pd.read_csv("models_big_fit_d_right.csv", index_col=0) # load_name = "models_big_fit_d_right.csv"#"models_big_fit.csv" # parameters = pd.read_csv(load_name) # potentiell zellen wo der range nicht zu weit ist : 0, 9 # problem 0, 10: spikt manchmal # problem 9: presynaptic oscilation nicht so schön # gute Zellen: 5 # ok ich mach Zelle Null weil sie am schönsten aussieht, die spikt manchmal aber das klammern wir jetzt halt aus # good_cells = pd.read_csv("good_model_cells.csv", index_col=0) # 2 ist wohl sehr nah, das geht! # 0,1,3,7 weit entfernt # 4,5,6 sehr weit entfernt # 8 ist wohl am nächsten! 
# embed() # model_params = parameters.iloc[0] # model_params = parameters[parameters['cell'].isin(good_cells.cell[0:-1]+'-invivo-1')].iloc[2] load_name = 'models_big_fit_d_right.csv' models = resave_small_files("models_big_fit_d_right.csv", load_folder='calc_model_core') model_params = models[models['cell'] == '2012-07-03-ak-invivo-1'].iloc[0] eod_fr = model_params.pop('EODf') # .iloc[0] deltat = model_params.pop("deltat") # .iloc[0] v_offset = model_params.pop("v_offset") # .iloc[0] eod_fe = [eod_fr + 50] # eod_fr*1+50,, eod_fr * 2 + 50 # REMAINING rows color_p3 = 'grey' # 'red'#palette['red'] color_p1 = 'grey' # 'blue'#palette['blue'] color_diagonal = 'grey' # 'cyan'#palette['cyan'] colors = [color_diagonal, color_p1, color_p1, color_p3] ax_rec = [[]] * 4 ax_low = [[]] * 4 axt_IF1 = [] axt_IF2 = [] delta_f = [50] # create_beat_corr(np.array([eod_fe[mult_nr] - eod_fr]), np.array([eod_fr]))[0] counter_here = 0 nrs = [1, 2, 3, 4] # first row for the stimulus, and then three cols for the sheme, and the power 1 and power 3 grid0 = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=grid_sheme, width_ratios=[1, 3], wspace=0.35) # Grid for the sheme try: grid_sheme = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=grid0[0], wspace=0.2, hspace=0.95) except: print('grid thing2') embed() lw = 0.5 wr = [1, 1.2] xlim = 0.065 axps = [] axps_lowpass = [] axps_stimulus = [] pps = [] pps_lowpass = [] pps_stimulus = [] colors_chosen = [] counter_g = 0 for mult_nr in range(len(eod_fe)): try: time, stimulus_rec, eod_fish_r, eod_fish_e, stimulus = make_paramters( stimulus_length, deltat, eod_fr, a_fr, a_fe, eod_fe, mult_nr) except: print('parameter thing3') embed() colorful_title = False # einfach eine stimulus schleife zu machen würde mehrere änderungen bedeutetn eod_fish_r_rec = eod_fish_r * 1 eod_fish_r_rec[eod_fish_r_rec < 0] = 0 add_pos, color_add_pos, titles, stimuli, eod_fish_rs = titles_EIF(eod_fish_r, eod_fish_r_rec, color_p1, color_p3, mult_nr, eod_fr, eod_fe, 
stimulus, stimulus_rec, colorful_title) stimulus_here = do_withenoise_stimulus(deltat, eod_fr, stimulus_length) titles = [titles[1]] g = 0 color = colors[counter_here] # A grid for a single POWER column grid_power_col = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=grid0[nrs[counter_here]], height_ratios=[1, 1, 1], wspace=0.45, hspace=0.5) # FIRST Row: Rectified stimulus grid_lowpass = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=grid_power_col[0], wspace=ws, hspace=1.3, width_ratios=wr) counter, ax_rec[counter_here], ff, pp, axp = plot_rec_stimulus(grid_lowpass, time_transform, stimulus_here, color, time, counter, eod_fr, deltat, nfft, shift=shift, lw=lw, xlim=xlim) pps_stimulus.append(pp) axps_stimulus.append(axp) colors_chosen.append(color) if counter_here == 0: ax_rec[counter_here].text(-7, 0, '0', color='black', ha='center', va='center') ax_rec[counter_here].text(add_pos[g], 1.1, titles[g], transform=ax_rec[counter_here].transAxes, ) # verticalalignment='right', # And plot correspoding sheme if g == 0: axsheme = plt.subplot(grid_sheme[0]) plot_sheme_nonlinearity(axsheme, color_p1) # REMAINING Rows: dendridic filter / LIF /EIF stimulus exponentials = [''] # , 'EIF' for ee, exponential in enumerate(exponentials): manual_offset = False if manual_offset: spike_times, v_dent_output, v_mem_output = simulate2(load_name, v_offset, eod_fish_rs[g], deltat=deltat, exponential=exponential, v_exp=v_exp, exp_tau=exp_tau, **model_params) print('Firing rate baseline ' + str(len(spike_times) / stimulus_length)) spike_times, v_dent_output, v_mem_output = simulate2(load_name, v_offset, stimulus_here, deltat=deltat, exponential=exponential, v_exp=v_exp, exp_tau=exp_tau, **model_params) if ee == 0: # SECOND Row: Dendridic Low pass filter grid_lowpass = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=grid_power_col[1], width_ratios=wr, wspace=ws, hspace=1.3) ax_low[counter_here], ff, pp, axp_p2, ff_am, pp_am = plot_lowpass(grid_lowpass, time_transform, time, shift, 
v_dent_output, color, deltat, fft_type, nfft, eod_fr, xlim=xlim, lw=lw) pps_lowpass.append(pp) axps_lowpass.append(axp_p2) colors_chosen.append(color) if counter_here == 0: ax_low[counter_here].text(-7, 0, '0', color='black', ha='center', va='center') if g == 0: axsheme = plt.subplot(grid_sheme[1]) plot_sheme_lowpass(axsheme) # THIRD /FORTH Row: LIF /EIF # plot the voltage of the exponentials grid_lowpass = gridspec.GridSpecFromSubplotSpec(1, 2, width_ratios=wr, subplot_spec=grid_power_col[ee + 2], wspace=ws, hspace=1.45) axt_IF, axp_IF, ff, pp, axp_s, pp_s = plot_spikes(grid_lowpass, time_transform, v_mem_output, time, color, spike_times, shift, deltat, fft_type, nfft, eod_fr, xlim=xlim, exponential=exponential, counter_here=counter_here) # , add = add axps.append(axp_s) pps.append(pp_s) colors_chosen.append('black') if ee == 0: axt_IF1.append(axt_IF) else: axt_IF2.append(axt_IF) if g == 0: axsheme = plt.subplot(grid_sheme[ee + 2]) plot_sheme_IF(axsheme, exp_tau, v_exp, exponential) ################################ # here plot the amplitude modulation relationship counter_g += 1 counter_here += 1 #################################### # cut first parts # because otherwise there is a dip at the beginning and thats a problem for the range thing ff, pps_stimulus, pps_lowpass, pps = cut_first_parts(ff, pps_stimulus, pps_lowpass, pps, ll=0) # here I calculate the log and do the same range for all power spectra # this is kind of complicated but some cells spike even withouth thresholding and we want to keep their noise floor down # not to see the peaks in the noise pp3_stimulus = create_same_max(np.concatenate([pps_stimulus, pps_lowpass]), same=True) axps_stimulus = np.concatenate([axps_stimulus, axps_lowpass, ]) pps3 = create_same_max(pps, same=True) pp3 = create_same_range(np.concatenate([pp3_stimulus, pps3])) axps_stimulus = np.concatenate([axps_stimulus, axps]) # there are only few cells where the distance is not so high and this cells spike occationally but very 
randomly still we dont wanna se their power specturm # therefore we dont show it colors = [color_diagonal, color_p1, color_p1, color_p3, color_diagonal, color_p1, color_p1, color_p3, color_diagonal, color_p1, color_p1, color_p3, ] plot_points = [[], 'yes', [], 'yes', [], 'yes', [], 'yes', [], 'yes', [], 'yes', [], 'yes', [], 'yes', ] for a, axp in enumerate(axps_stimulus): lw_p = 0.8 plot_power_common_lim(axp, pp3[a], ff / eod_fr, colors[a], lw_p, plot_points[a], delta_f / eod_fr) if a % 4 == 3: axp.yscalebar(1, 0.5, 20, 'dB', va='center', ha='right') axps_stimulus[0].get_shared_y_axes().join(*axps_stimulus) def model_sheme_vertical(grid_sheme_orig, time_transform=1000, ws=0.1, nfft=4096 * 6, stimulus_length=5, fft_type='mppsd', a_fr=1, a_fe=0.2, v_exp=1, exp_tau=0.1, counter=0, shift=0.25): # need to reduce parameters # parameters = pd.read_csv("models_big_fit_d_right.csv", index_col=0) # load_name = "models_big_fit_d_right.csv"#"models_big_fit.csv" # parameters = pd.read_csv(load_name) # potentiell zellen wo der range nicht zu weit ist : 0, 9 # problem 0, 10: spikt manchmal # problem 9: presynaptic oscilation nicht so schön # gute Zellen: 5 # ok ich mach Zelle Null weil sie am schönsten aussieht, die spikt manchmal aber das klammern wir jetzt halt aus # good_cells = pd.read_csv("good_model_cells.csv", index_col=0) # 2 ist wohl sehr nah, das geht! # 0,1,3,7 weit entfernt # 4,5,6 sehr weit entfernt # 8 ist wohl am nächsten! 
# embed() # model_params = parameters.iloc[0] load_name = 'models_big_fit_d_right.csv' models = resave_small_files("models_big_fit_d_right.csv", load_folder='calc_model_core') deltat, eod_fr, model_params, v_offset = get_model_params(models, cell='2012-07-03-ak-invivo-1') eod_fe = [eod_fr + 50] # eod_fr*1+50,, eod_fr * 2 + 50 # REMAINING rows color_p3 = 'grey' # 'red'#palette['red'] color_p1 = 'grey' # 'blue'#palette['blue'] color_diagonal = 'grey' # 'cyan'#palette['cyan'] colors = [color_diagonal, color_p1, color_p1, color_p3] ax_rec = [[]] * 4 ax_low = [[]] * 4 axt_IF1 = [] delta_f = [50] # create_beat_corr(np.array([eod_fe[mult_nr] - eod_fr]), np.array([eod_fr]))[0] counter_here = 0 # first row for the stimulus, and then three cols for the sheme, and the power 1 and power 3 grid0 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=grid_sheme_orig, wspace=0.35) # height_ratios=[1, 3] # Grid for the sheme lw = 0.5 wr = [1, 1.2] hr = [1] xlim = 0.065 axps = [] axps_lowpass = [] axps_stimulus = [] pps = [] pps_lowpass = [] pps_stimulus = [] colors_chosen = [] for mult_nr in range(len(eod_fe)): try: time, stimulus_rec, eod_fish_r, eod_fish_e, stimulus = make_paramters( stimulus_length, deltat, eod_fr, a_fr, a_fe, eod_fe, mult_nr) except: print('parameter thing2') embed() colorful_title = False # einfach eine stimulus schleife zu machen würde mehrere änderungen bedeutetn eod_fish_r_rec = eod_fish_r * 1 eod_fish_r_rec[eod_fish_r_rec < 0] = 0 add_pos, color_add_pos, titles, stimuli, eod_fish_rs = titles_EIF(eod_fish_r, eod_fish_r_rec, color_p1, color_p3, mult_nr, eod_fr, eod_fe, stimulus, stimulus_rec, colorful_title) sampling = 1 / deltat time_eod = np.arange(0, stimulus_length, deltat) eod_interp, time_wn_cut, _ = load_noise('gwn300Hz50s0.3') eod_interp = interpolate(time_wn_cut, eod_interp, time_eod, kind='cubic') fake_fish = fakefish.wavefish_eods('Alepto', frequency=eod_fr, samplerate=sampling, duration=len(time_eod) / sampling, phase0=0.0, noise_std=0.00) 
stimulus_here = fake_fish * (1 + eod_interp * 0.2) titles = [titles[1]] g = 0 color = colors[counter_here] hs = 0.2 # 1.3 # A grid for a single POWER column grid_power_col = gridspec.GridSpecFromSubplotSpec(1, 6, subplot_spec=grid0[0], width_ratios=[0.5, 1, 0.5, 1, 0.5, 1], wspace=0.45, hspace=0.5) # FIRST Row: Rectified stimulus grid_lowpass = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=grid_power_col[1], wspace=ws, hspace=hs, width_ratios=wr, height_ratios=hr) counter, ax_rec[counter_here], ff, pp, axp = plot_rec_stimulus(grid_lowpass, time_transform, stimulus_here, color, time, counter, eod_fr, deltat, nfft, shift=shift, lw=lw, xlim=xlim) pps_stimulus.append(pp) axps_stimulus.append(axp) colors_chosen.append(color) if counter_here == 0: ax_rec[counter_here].text(-7, 0, '0', color='black', ha='center', va='center') ax_rec[counter_here].text(add_pos[g], 1.1, titles[g], transform=ax_rec[counter_here].transAxes, ) # verticalalignment='right', # And plot correspoding sheme axsheme = plt.subplot(grid_power_col[0]) plot_sheme_nonlinearity(axsheme, color_p1) # REMAINING Rows: dendridic filter / LIF /EIF stimulus exponential = '' # , 'EIF' manual_offset = False if manual_offset: spike_times, v_dent_output, v_mem_output = simulate2(load_name, v_offset, eod_fish_rs[g], deltat=deltat, exponential=exponential, v_exp=v_exp, exp_tau=exp_tau, **model_params) print('Firing rate baseline ' + str(len(spike_times) / stimulus_length)) spike_times, v_dent_output, v_mem_output = simulate2(load_name, v_offset, stimulus_here, deltat=deltat, exponential=exponential, v_exp=v_exp, exp_tau=exp_tau, **model_params) # SECOND Row: Dendridic Low pass filter grid_lowpass = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=grid_power_col[3], height_ratios=hr, width_ratios=wr, wspace=ws, hspace=hs) ax_low[counter_here], ff, pp, axp_p2, ff_am, pp_am = plot_lowpass(grid_lowpass, time_transform, time, shift, v_dent_output, color, deltat, fft_type, nfft, eod_fr, xlim=xlim, lw=lw) 
pps_lowpass.append(pp) axps_lowpass.append(axp_p2) colors_chosen.append(color) if counter_here == 0: ax_low[counter_here].text(-7, 0, '0', color='black', ha='center', va='center') axsheme = plt.subplot(grid_power_col[2]) plot_sheme_lowpass(axsheme) # THIRD /FORTH Row: LIF /EIF # plot the voltage of the exponentials grid_lowpass = gridspec.GridSpecFromSubplotSpec(1, 2, height_ratios=hr, width_ratios=wr, subplot_spec=grid_power_col[5], wspace=ws, hspace=hs) axt_IF, axp_IF, ff, pp, axp_s, pp_s = plot_spikes(grid_lowpass, time_transform, v_mem_output, time, color, spike_times, shift, deltat, fft_type, nfft, eod_fr, xlim=xlim, exponential=exponential, counter_here=counter_here) # , add = add axps.append(axp_s) pps.append(pp_s) colors_chosen.append('black') axt_IF1.append(axt_IF) axsheme = plt.subplot(grid_power_col[4]) plot_sheme_IF(axsheme, exp_tau, v_exp, exponential) # plot psd with shared log lim #################################### # cut first parts # because otherwise there is a dip at the beginning and thats a problem for the range thing ff, pps_stimulus, pps_lowpass, pps = cut_first_parts(ff, pps_stimulus, pps_lowpass, pps, ll=0) # here I calculate the log and do the same range for all power spectra # this is kind of complicated but some cells spike even withouth thresholding and we want to keep their noise floor down # not to see the peaks in the noise pp3_stimulus = create_same_max(np.concatenate([pps_stimulus, pps_lowpass]), same=True) axps_stimulus = np.concatenate([axps_stimulus, axps_lowpass, ]) pps3 = create_same_max(pps, same=True) pp3 = create_same_range(np.concatenate([pp3_stimulus, pps3])) axps_stimulus = np.concatenate([axps_stimulus, axps]) # there are only few cells where the distance is not so high and this cells spike occationally but very randomly still we dont wanna se their power specturm # therefore we dont show it colors = [color_diagonal, color_p1, color_p1, color_p3, color_diagonal, color_p1, color_p1, color_p3, color_diagonal, color_p1, 
color_p1, color_p3, ] plot_points = [[], 'yes', [], 'yes', [], 'yes', [], 'yes', [], 'yes', [], 'yes', [], 'yes', [], 'yes', ] for a, axp in enumerate(axps_stimulus): lw_p = 0.8 plot_power_common_lim(axp, pp3[a], ff / eod_fr, colors[a], lw_p, plot_points[a], delta_f / eod_fr) if a % 4 == 3: axp.yscalebar(1, 0.5, 20, 'dB', va='center', ha='right') axps_stimulus[0].get_shared_y_axes().join(*axps_stimulus) def get_model_params(models, cell='2012-07-03-ak-invivo-1'): model_params = models[models['cell'] == cell].iloc[0] eod_fr = model_params.pop('EODf') # .iloc[0] deltat = model_params.pop("deltat") # .iloc[0] v_offset = model_params.pop("v_offset") # .iloc[0] return deltat, eod_fr, model_params, v_offset def share_yaxis(axes): for ax in axes: # , axt_IF2 ax[0].get_shared_y_axes().join(*ax) maxs = [] mins = [] for a in ax: maxs.append(np.nanmax(a.get_ylim())) mins.append(np.nanmin(a.get_ylim())) for a in ax: a.set_ylim(np.min(mins), np.max(maxs)) def flowchart(): default_settings(column=2, length=7) cell = "2013-01-08-aa-invivo-1" model_cells = resave_small_files("models_big_fit_d_right.csv", load_folder='calc_model_core') model_params = model_cells[model_cells.cell == cell].iloc[0] noise_strength = model_params.noise_strength # **2/2 a_fr = 1 # ,0]#0,,0,]#,0 ,0 ] # ,0,]#0#1 eod_fr = model_params['EODf'] deltat = model_params.pop("deltat") cut_offs = [eod_fr / 2] var_type = 'additiv_cutoff_scaled' # ]#'additiv_visual_d_4_scaled'] fig, ax = plt.subplots(6, (len(cut_offs) + 1) * len(var_types)) # , constrained_layout = True figsize=(12, 5), colors_title = [['black', 'purple', 'black', 'black', 'black', 'black'], ['black', 'black', 'black', 'black', 'black', 'purple']] d_new_zeros = {} tags = [] c_sig = 0.9 c_noise = 0.1 d_new_zeros[var_type] = [] arrays2 = np.load(load_folder_name('calc_RAM') + '\RAM_extraction_a1_' + var_type + '.npy', allow_pickle=True) arrays = np.load(load_folder_name('calc_RAM') + '\RAM_extraction_a0_' + var_type + '.npy', allow_pickle=True) titles 
= ['Noise', 'RAM', 'RAM*Carrier', 'RAM*Carrier to RAM', 'V_dent', 'V_dent to RAM'] colors = ['grey', 'red', 'grey', 'red', 'grey', 'red'] max_f = 9999.875 # hab ich aus np.fft.fft(noise) var_desired = '$var_{desired} = $' + str(np.round(np.var(arrays[0]) * c_sig * cut_off / max_f, 5)) D_desired = np.round(np.sqrt(noise_strength * 2 * c_sig), 5) plt.suptitle( '$Contrast_{receiver}=$' + str(a_fr) + ', $c_{noise}=$' + str(c_noise) + ', $c_{signal}=$' + str( c_sig) + ', ' + var_desired + r' $\sqrt{2D*c_{signal}}$=' + str( np.round(D_desired, 5))) grid = gridspec.GridSpecFromSubplotSpec(1, 4, grid_orig[0], wspace=1.2, hspace=0.13) for i in range(len(arrays)): sampling = 1 / deltat ax = plt.subplot(grid[i]) if len(np.arange(0, len(arrays[i]) / sampling, 1 / sampling)) > len(arrays[i]): ax[0 + i, 0 + v * 2].plot(np.arange(0, len(arrays[i]) / sampling, 1 / sampling)[0:-1], arrays[i], color=colors[i]) else: ax[0 + i, 0 + v * 2].plot(np.arange(0, len(arrays[i]) / sampling, 1 / sampling), arrays[i], color=colors[i]) if len(arrays2[i]) > 0: if len(np.arange(0, len(arrays2[i]) * deltat, deltat)) > len(arrays2[i]): ax[0 + i, 0 + v * 2].plot(np.arange(0, len(arrays2[i]) * deltat, deltat)[0:-1], arrays2[i], color='red') else: ax[0 + i, 0 + v * 2].plot(np.arange(0, len(arrays2[i]) * deltat, deltat), arrays2[i], color='red') tags.append(ax[0 + i, 0 + v * 2]) ax[0 + i, 0 + v * 2].set_title(titles[i] + ' var=' + str(np.round(np.var(arrays[i]), 5)), color=colors_title[v][i], fontsize=8) # +' var/c='+str(np.round(np.var(arrays[i])/cs[i],5)) ax[0 + i, 0 + v * 2].set_xlim(0, 0.1) p_array_fft = np.fft.fft(arrays[i] - np.mean(arrays[i]), norm='forward') f = np.fft.fftfreq(len(arrays[i]), deltat) f_sorted = np.sort(f) p_sorted = np.abs(p_array_fft)[np.argsort(f)] ax[0 + i, 1 + v * 2].plot(f_sorted, p_sorted, color='grey') # np.log10(p_noise / np.max(p_noise)) left = np.argmin(np.abs(f_sorted) - 0) - 10 left2 = np.argmin(np.abs(f_sorted) - 0) d_new_zero = np.mean(p_sorted[left:left2]) ax[0 + 
i, 1 + v * 2].plot(f_sorted[left:left2], p_sorted[left:left2], color='blue') d_new_zeros[var_type].append(d_new_zero) ax[0 + i, 1 + v * 2].set_title('D close to 0 = ' + str(np.round(d_new_zero, 5)), color=colors_title[v][i]) ax[0 + i, 1 + v * 2].set_xlim(-eod_fr / 2 * 1.2, eod_fr / 2 * 1.2) if i < len(arrays) - 1: remove_yticks(ax[0 + i, 0]) remove_yticks(ax[0 + i, 1]) remove_yticks(ax[0 + i, 0 + v * 2]) remove_yticks(ax[0 + i, 1 + v * 2]) ax[0, 0 + v * 2].text(0, 1.6, var_type + ': $D_{RAM}/V_dent_{RAM} =$' + str( np.round(d_new_zeros[var_type][1] / d_new_zeros[var_type][-1], 3)), transform=ax[0, 0 + v * 2].transAxes, color='purple') ax[-1, 0 + v * 2].set_xlabel('Time [s]') ax[-1, 1 + v * 2].set_xlabel('Frequency [Hz]') ax[-1, 1 + v * 2].set_ylabel('[Hz]') ax[0, 0 + v * 2].get_shared_x_axes().join(*np.concatenate([ax[:, 0], ax[:, 0 + v * 2]])) ax[0, 1 + v * 2].get_shared_x_axes().join(*np.concatenate([ax[:, 1], ax[:, 1 + v * 2]])) ax[2, 1 + v * 2].get_shared_y_axes().join(ax[2, 1], ax[4, 1], ax[2, 1 + v * 2], ax[4, 1 + v * 2]) ax[1, 1 + v * 2].get_shared_y_axes().join(ax[0, 1], ax[1, 1], ax[3, 1], ax[5, 1], ax[0, 1 + v * 2], ax[1, 1 + v * 2], ax[3, 1 + v * 2], ax[5, 1 + v * 2]) ax[2, 0 + v * 2].get_shared_y_axes().join(ax[2, 0], ax[4, 0], ax[2, 0 + v * 2], ax[4, 0 + v * 2]) ax[1, 0 + v * 2].get_shared_y_axes().join(ax[1, 0], ax[3, 0], ax[5, 0], ax[1, 0 + v * 2], ax[3, 0 + v * 2], ax[5, 0 + v * 2]) # ax[0, 0],ax[0, 0+v*2], def find_cells(file_names_exclude, sorting, cells_chosen, cell_type, cell_type_type, cell_type_chosen, load_path): frame_base = load_cv_table() frame_base = unify_cell_names(frame_base, cell_type=cell_type_type) frame_base = frame_base[frame_base[cell_type_type] == cell_type_chosen] cell_base = frame_base.cell.unique() if '.csv' not in load_path: stack = pd.read_csv(load_path + '.csv') # ,index_col = 0 else: stack = pd.read_csv(load_path) stack = stack[~stack['file_name'].isin(file_names_exclude)] stack_files = stack # 
[stack['celltype'].isin(cell_type)]#cell_type_type cells_gwn = stack_files.cell.unique() cell_chose = 'base' if cell_chose == 'base': cells = cell_base else: cells = cells_gwn if 'p-unit' in cell_type: if len(cells_chosen) == 0: stack_cells = stack_files[stack_files['cell'].isin(cells)] cvs = stack_cells[sorting] # .iloc[0] cells = np.array(stack_cells.cell) cvs = np.array(cvs) lengths = stack_cells['stimulus_length'] cv_min = False if cv_min: cells = cells[cvs < 0.3] lengths = lengths[cvs < 0.3] cvs = cvs[cvs < 0.3] cells = cells[lengths > 3] cvs = cvs[lengths > 3] cells, cvs_unique = make_cell_unique(cvs, cells) cells = list(cells) # Zellen mit starken Artefakten cells_rem = ['2010-08-25-ab-invivo-1', '2010-11-08-aa-invivo-1', '2010-11-11-al-invivo-1', '2011-02-18-ab-invivo-1', '2011-09-21-ab-invivo-1', '2011-10-25-ac-invivo-1', '2011-11-10-ab-invivo-1', '2011-11-10-ag-invivo-1', '2012-12-19-aa-invivo-1', '2012-12-19-ab-invivo-1', '2012-12-19-ac-invivo-1', '2013-02-21-aa-invivo-1', '2013-04-09-ab-invivo-1', '2013-04-16-aa-invivo-1', '2013-04-16-ab-invivo-1', '2013-04-16-ac-invivo-1', '2013-04-17-af-invivo-1', '2013-04-18-ac-invivo-1', ] # cells_rem_wo_base = ['2010-11-08-ab-invivo-1', '2010-07-29-ae-invivo-1', '2011-11-10-ah-invivo-1', '2011-11-10-ak-invivo-1', '2012-05-30-aa-invivo-1', '2012-07-12-al-invivo-1', '2012-10-19-aa-invivo-1', '2012-10-19-ad-invivo-1', '2012-12-20-af-invivo-1', '2013-04-16-ad-invivo-1', '2013-04-11-ab-invivo-1', '2014-01-23-ac-invivo-1', '2014-01-16-aj-invivo-1'] cells_rem_wo_base_not_nice = ['2010-11-26-al-invivo-1', '2010-11-11-aj-invivo-1'] for cell_rem in cells_rem: if cell_rem in cells: cells.remove(cell_rem) for cell_rem in cells_rem_wo_base: if cell_rem in cells: cells.remove(cell_rem) for cell_rem in cells_rem_wo_base_not_nice: if cell_rem in cells: cells.remove(cell_rem) cells = cells[0:16] else: cells = cells_chosen elif cell_type == [' A-unit', ' Ampullary']: stack_cells = stack_files[stack_files['cell'].isin(cells)] cvs = 
stack_cells[sorting] # .iloc[0] cells = np.array(stack_cells.cell) cvs = np.array(cvs) lengths = stack_cells['stimulus_length'] cv_min = False if cv_min: cells = cells[cvs < 0.3] lengths = lengths[cvs < 0.3] cvs = cvs[cvs < 0.3] cells = cells[lengths > 3] cvs = cvs[lengths > 3] cells, cvs_unique = make_cell_unique(cvs, cells) cells = list(cells) # Zellen mit starken Artefakten cells_rem = ['2010-08-25-ab-invivo-1', '2010-11-08-aa-invivo-1', '2010-11-11-al-invivo-1', '2011-02-18-ab-invivo-1', '2011-09-21-ab-invivo-1', '2011-10-25-ac-invivo-1', '2011-11-10-ab-invivo-1', '2011-11-10-ag-invivo-1', '2012-12-19-aa-invivo-1', '2012-12-19-ab-invivo-1', '2012-12-19-ac-invivo-1', '2013-02-21-aa-invivo-1', '2013-04-09-ab-invivo-1', '2013-04-16-aa-invivo-1', '2013-04-16-ab-invivo-1', '2013-04-16-ac-invivo-1', '2013-04-17-af-invivo-1', '2013-04-18-ac-invivo-1', ] # cells_rem_wo_base = ['2010-11-08-ab-invivo-1', '2010-07-29-ae-invivo-1', '2011-11-10-ah-invivo-1', '2011-11-10-ak-invivo-1', '2012-05-30-aa-invivo-1', '2012-07-12-al-invivo-1', '2012-10-19-aa-invivo-1', '2012-10-19-ad-invivo-1', '2012-12-20-af-invivo-1', '2013-04-16-ad-invivo-1', '2013-04-11-ab-invivo-1', '2014-01-23-ac-invivo-1', '2014-01-16-aj-invivo-1'] cells_rem_wo_base_not_nice = ['2010-11-26-al-invivo-1', '2010-11-11-aj-invivo-1'] for cell_rem in cells_rem: if cell_rem in cells: cells.remove(cell_rem) for cell_rem in cells_rem_wo_base: if cell_rem in cells: cells.remove(cell_rem) for cell_rem in cells_rem_wo_base_not_nice: if cell_rem in cells: cells.remove(cell_rem) cells = cells[0:16] return cells def load_cell_types(file_name_exclude, sorting, load_path, cell_type, cells_chosen=[], cell_type_chosen=' Ampullary', cell_type_type='cell_type_reclassified'): # das ist jetzt ein Funktion die das selber auswählt für die punit.py und ampullary.py functions if os.path.exists(load_path + '.csv'): # hier finde ich quasi nur die Zellen raus die ich haben will cells = find_cells(file_name_exclude, sorting, cells_chosen, 
cell_type, cell_type_type, cell_type_chosen, load_path) stack = load_data_susept(load_path + '.csv', load_path, cells=cells) else: # wenn das noch nicht abgespeichert ist machen wir das so stack = load_data_susept(load_path + '.csv', load_path) stack_files = stack[stack['celltype'].isin(cell_type)] cells = stack.cell.unique() return stack_files, cells def colorbar_outside_right(ax, fig, im, shrink=0.6, width=0.02, plusx=0.01): pos = ax.get_position() # [[xmin, ymin], [xmax, ymax]]. pos = np.array(pos) xmin = pos[0][0] ymin = pos[0][1] ymax = pos[1][1] left = xmin + plusx bottom = ymax # - 0.076+add#85 height = (ymax - ymin) cbar_ax = fig.add_axes([left, bottom, width, height]) # [left, bottom, width, height cbar_ax.xaxis.set_label_position('bottom') cbar_ax.set_xticklabels(cbar_ax.get_xticklabels(), rotation='vertical') cbar_ax.tick_params(labelsize=6) cbar = fig.colorbar(im, orientation="vertical", cax=cbar_ax, shrink=shrink) return cbar, left, bottom, width, height def plt_cv_part(cell, frame_save, frame, cell_nr, ax, lim_here=[], color_bar='grey', xlim=(0, 17)): cv_isi = frame.iloc[cell_nr].cv # embed()#'\n cv_inst '+str(np.round(cv_inst,2))+' cv_inst_fr '+ str(np.round(cv_inst_fr,2))' cv_mat '+ str(np.round(cv_mat,2))' m_isi '+str(np.round(mean_isi))'\n m_inst_fr '+ str(np.round(mean_inst_fr))+ cv_title = False if cv_title: if cv_isi < 0.2: color = 'red' elif cv_isi < 0.3: color = 'purple' elif cv_isi < 0.4: color = 'orange' elif cv_isi < 0.7: color = 'green' else: color = 'blue' else: color = title_color(cell) frame_here = frame_save[frame_save.cell == cell] try: hist = frame_here['hist'].iloc[0][0] except: print('hist problem') embed() width = (hist[1][1] - hist[1][0]) if lim_here != []: ex = list(hist[1] > lim_here) ex_bars = ex[0:-1] y = hist[0][ex_bars] # [0]+ width / 2 x = hist[1][np.array(ex)][0:-1] # [0] else: x = hist[1][0:-1] + width / 2 y = hist[0] ax.bar(x, height=y, width=width, color=color_bar) if xlim: ax.set_xlim(xlim) return color, cv_isi def 
plt_psd_traces(grid1, grid2, axs, min_lim, max_lim, eod_fr, fr, fr_stim, stack_final1, fr_color='red', fr_stim_color='darkred', peaks=True, db='', rmv_axo=True, stack_isf=[], stack_osf=[], rmv_axi=True, reset_pos=True, eod_fr_color='magenta', eod_fr_half_color='purple'): ax_pos, axi, colors_f, freqs, isf_resaved = plt_trace(axs, db, eod_fr, eod_fr_color, eod_fr_half_color, fr, fr_color, fr_stim, fr_stim_color, grid2, max_lim, min_lim, peaks, reset_pos, rmv_axi, stack_final1, stack_isf, isf_name='isf') # plot output trace ax_pos, axo, colors_f, freqs, isf_resaved = plt_trace(axs, db, eod_fr, eod_fr_color, eod_fr_half_color, fr, fr_color, fr_stim, fr_stim_color, grid1, max_lim, min_lim, peaks, reset_pos, rmv_axo, stack_final1, stack_osf, isf_name='osf') return axo, axi def plt_trace(axs, db, eod_fr, eod_fr_color, eod_fr_half_color, fr, fr_color, fr_stim, fr_stim_color, grid2, max_lim, min_lim, peaks, reset_pos, rmv_axi, stack_final1, stack_isf, isf_name='isf', clip_on=True): ax_pos = np.array(axs.get_position()) # [[xmin, ymin], [xmax, ymax]]. 
if len(stack_isf) == 0: isf = stack_final1[isf_name] isf_resaved = False else: isf = stack_isf isf_resaved = True axi = plt.subplot(grid2) ax_pos2 = np.array(axi.get_position()) # das würde auch gehen:.y0,.y1,.x0,.x1,.width if reset_pos: axi.set_position([ax_pos[0][0], ax_pos2[0][1], ax_pos[1][0] - ax_pos[0][0], ax_pos2[1][1] - ax_pos2[0][1]]) freqs = [fr, fr * 2, fr_stim, fr_stim * 2, eod_fr, eod_fr * 2, eod_fr / 2] colors_f = [fr_color, fr_color, fr_stim_color, fr_stim_color, eod_fr_color, eod_fr_color, eod_fr_half_color] plt_isf_ps_red(stack_final1, isf, 0, axi, freqs=freqs, colors=colors_f, clip_on=clip_on, peaks=peaks, db=db, max_lim=max_lim, osf_resaved=osf_resaved, ) axi.set_xlim(min_lim, max_lim) if rmv_axi: remove_xticks(axi) return ax_pos, axi, colors_f, freqs, isf_resaved def plt_isf_ps_red(stack_final, isf, l, ax_i, color='black', power=1, several=False, maxi=1, peaks=True, freqs=[], db='', max_lim=None, colors=[], osf_resaved=False, clip_on=False): f = find_f(stack_final) try: if osf_resaved: f_axis = f[0:len(isf)] means = np.transpose(isf) means_all = np.mean(np.abs(means) ** power, axis=0) p = np.abs(means.iloc[0]) ** power else: f_axis = f[0:len(isf.iloc[l][0])] means = get_array_from_pandas(isf) means_all = np.mean(np.abs(means) ** power, axis=0) p = np.abs(isf.iloc[l][0]) ** power if db == 'db': p = 10 * np.log10(p / maxi) means_all = 10 * np.log10(means_all / maxi) add = np.percentile(means_all, 90) if max_lim: if several: ax_i.plot(f_axis[f_axis < max_lim], p[f_axis < max_lim], color='grey', zorder=1) ax_i.plot(f_axis[f_axis < max_lim], means_all[f_axis < max_lim], color='black', zorder=1) else: if several: ax_i.plot(f_axis, p, color='grey', zorder=1) ax_i.plot(f_axis, means_all, color=color, zorder=1) ax_i.set_xlim(0, 700) if peaks: for i in range(len(freqs)): plt_peaks(ax_i, p, freqs[i], f_axis, fr_color=colors[i], add=add, clip_on=clip_on) except: print('f axis problem') embed() def find_cells_plot(save_names, amps_desired=[5, 10, 20], 
cell_class=' Ampullary'): # 'noise_data8_nfft1sec_original__LocalEOD_CutatBeginning_0.05_s_NeurDelay_0.005_s.csv' # frame_csv_overview_test = pd.read_csv('../data/Noise/noise_data8_nfft1sec_original__LocalEOD_CutatBeginning_0.05_s_NeurDelay_0.005_s.csv') load_name = load_folder_name('calc_RAM') + '/' + save_names[0] + '.csv' frame_csv_overview = pd.read_csv(load_name, low_memory=False) frame_csv_overview[['cell', 'amp']].sort_values('cell') ######################## # here find the cells that are in the amps unique_combos = frame_csv_overview[['cell', 'amp']].drop_duplicates() if len(amps_desired) > 0: combos_three = unique_combos[unique_combos.amp.isin(amps_desired)] cell_counts = combos_three.cell.value_counts() cell_to_plot = cell_counts[cell_counts == len(amps_desired)].keys() else: cell_counts = unique_combos.cell.value_counts() cell_to_plot = cell_counts[cell_counts > 3].keys() # hier nehmen wir wirklich nur die die auch ein GWN haben, das ist der Unterschied frame = load_cv_table() cell_type_type = 'cell_type_reclassified' frame = unify_cell_names(frame, cell_type=cell_type_type) cell_types = frame[cell_type_type].unique() cells_dict = cluster_cells_by_group_dict(cell_types, frame, cell_type_type) cells = cells_dict[cell_class] cells_plot = cell_to_plot[cell_to_plot.isin(cells)] frame_cv = frame[frame.cell.isin(cells_plot)] frame_cv = frame_cv.sort_values('cv') cells_plot = frame_cv.cell cells_plot = list(cells_plot) if '2012-06-08-ae-invivo-1' in cells_plot: cells_plot.remove('2012-06-08-ae-invivo-1') return amps_desired, cell_type_type, cells_plot, frame, cell_types def load_cells_in_sample(cell_class, save_names, amps_desired, cell_type_type, frame): load_name = load_folder_name('calc_RAM') + '/' + save_names[0] + '.csv' frame_csv_overview = pd.read_csv(load_name, low_memory=False) # dtype={'three': bool,'cell':str,'highest_fr':float} frame_csv_overview[['cell', 'amp']].sort_values('cell') unique_combos = frame_csv_overview[['cell', 
'amp']].drop_duplicates() if len(amps_desired) > 0: combos_three = unique_combos[unique_combos.amp.isin(amps_desired)] cell_counts = combos_three.cell.value_counts() cell_to_plot = cell_counts[cell_counts == len(amps_desired)].keys() else: cell_counts = unique_combos.cell.value_counts() cell_to_plot = cell_counts[cell_counts > 3].keys() # hier nehmen wir wirklich nur die die auch ein GWN haben, das ist der Unterschied cell_types = frame[cell_type_type].unique() cells_dict = cluster_cells_by_group_dict(cell_types, frame, cell_type_type) cells = cells_dict[cell_class] cells_plot = cell_to_plot[cell_to_plot.isin(cells)] frame_cv = frame[frame.cell.isin(cells_plot)] frame_cv = frame_cv.sort_values('cv_min') cells_plot = frame_cv.cell cells_plot = list(cells_plot) if '2012-06-08-ae-invivo-1' in cells_plot: cells_plot.remove('2012-06-08-ae-invivo-1') return cells_plot, cell_types def load_isis(save_names, amps_desired=[5, 10, 20], cells_given=[], cell_class=' Ampullary'): # 'noise_data8_nfft1sec_original__LocalEOD_CutatBeginning_0.05_s_NeurDelay_0.005_s.csv' # frame_csv_overview_test = pd.read_csv(load_folder_name('calc_RAM')+'/calc_RAM_model-2_data8_nfft1sec_original__LocalEOD_CutatBeginning_0.05_s_NeurDelay_0.005_s.csv') # if os.path.exists(): cell_type_type = 'cell_type_reclassified' frame = load_cv_base_frame(cells_given, cell_type_type=cell_type_type) ##################### # rausfinden welche Zellen wir plotten wollen cells_plot, cell_types = load_cells_in_sample(cell_class, save_names, amps_desired, cell_type_type, frame) return amps_desired, cell_type_type, cells_plot, frame, cell_types def remove_tick_marks(ax): ax.xaxis.set_major_formatter(ticker.NullFormatter()) return ax def plt_scatter_two(ax0, ax2, frame, cell_types, cell_type_type, annotate, colors): add = ['', '_burst_corr_individual', ] # ok hier plotten wir nur den scatter der auch ein gwn hat, aber was ist wenn es mehr sind? 
# ok im prinzip sollte das zwar schon stimmen aber für das Bild kann man wirklich mehr machen for c, cell_type_it in enumerate(cell_types): frame_g = frame[ (frame[cell_type_type] == cell_type_it) & ((frame.gwn == True) | (frame.fs == True))] plt_cv_fr(annotate, ax0, add[0], frame_g, colors, cell_type_it) ax2.set_title('burst') for c, cell_type_it in enumerate(cell_types): frame_g = frame[ (frame[cell_type_type] == cell_type_it) & ((frame.gwn == True) | (frame.fs == True))] plt_cv_fr(annotate, ax2, add[1], frame_g, colors, cell_type_it) return ax0, ax2 def square_func(ax, stack_final, perc_min=5, perc_max=95, norm='', s=0): new_keys, stack_plot = retrieve_mat_plot(stack_final) eod_fr = stack_final.eod_fr.unique()[0] fr2 = np.unique(stack_final.fr_stim) fr = stack_final.fr.unique()[0] # todo: hier das normen noch anpassen mat = ram_norm_choice(stack_plot, norm, stack_final) imshow = True if imshow: vmin = np.nanpercentile(mat, perc_min) vmax = np.nanpercentile(mat, perc_max) im = ax[s].imshow(mat, vmin=vmin, extent=[mat.index[0], mat.index[-1], mat.columns[0], mat.columns[-1]], vmax=vmax, origin='lower', cmap='viridis') else: im = ax[s].pcolormesh(mat.index, mat.columns, mat, vmin=0, vmax=np.nanpercentile(mat, 97), cmap='viridis', rasterized=True) # np.nanpercentile(mat, 1) , cmap ='hot''Greens' pcolormesh ax[s].set_aspect('equal') ax[s].set_xlabel(F1_xlabel()) ax[s].set_ylabel(F2_xlabel(), labelpad=0.2) ax[s].set_xlim(mat.index[0], mat.index[-1]) ax[s].set_ylim(mat.columns[0], mat.columns[-1]) plt_triangle(ax[s], fr, np.mean(fr2), new_keys[-1], eod_fr, eod_fr_half_color='purple', fr_color='red', eod_fr_color='magenta', fr_stim_color='darkred') return im, mat.columns[0], mat.columns[-1] def retrieve_mat_plot(stack_final): keys = stack_final.keys() new_keys = stack_final.index stack_plot = stack_final[new_keys] return new_keys, stack_plot def plot_square_core(ax, stack_final, s=0, nr=3, eod_metrice=True, fr=None, cbar_do=True, perc=True, line_length=1 / 4, 
add_nonlin_title=None): new_keys, stack_plot = convert_csv_str_to_float(stack_final) eod_fr = stack_final.eod_fr.unique()[0] fr2 = np.unique(stack_final.fr_stim) if not fr: fr = stack_final.fr.unique()[0] norm_d = False if norm_d: mat = RAM_norm_data(stack_final['d_isf1'].iloc[0], stack_plot, stack_final['snippets'].unique()[0]) else: mat = RAM_norm_data(stack_final['isf'].iloc[0], stack_plot, stack_final['snippets'].unique()[0], stack_here=stack_final) # mat, add_nonlin_title, resize_val = rescale_colorbar_and_values(mat, add_nonlin_title=add_nonlin_title) print(add_nonlin_title) imshow = True if imshow: if perc: im = ax[s].imshow(mat, vmin=np.nanpercentile(mat, 5), extent=[mat.index[0], mat.index[-1], mat.columns[0], mat.columns[-1]], vmax=np.nanpercentile(mat, 95), cmap='viridis', origin='lower') # else: im = ax[s].imshow(mat, extent=[mat.index[0], mat.index[-1], mat.columns[0], mat.columns[-1]], cmap='viridis', origin='lower') else: im = ax[s].pcolormesh(mat.index, mat.columns, mat, vmin=0, vmax=np.nanpercentile(mat, 97), cmap='Greens', rasterized=True) # np.nanpercentile(mat, 1) , cmap ='hot''Greens' pcolormesh ax[s].set_aspect('equal') ax[s].set_xlabel(F1_xlabel()) ax[s].set_ylabel(F2_xlabel(), labelpad=0.2) ax[s].set_xlim(mat.index[0], mat.index[-1]) ax[s].set_ylim(mat.columns[0], mat.columns[-1]) plt_triangle(ax[s], fr, np.mean(fr2), new_keys[-1], eod_fr, line_length=line_length, eod_metrice=eod_metrice, nr=nr) # eod_fr_half_color='purple', power_noise_color='blue', if cbar_do: try: cbar = plt.colorbar(im, ax=ax[s], shrink=0.6) except: print('colorbar problem') cbar = [] else: cbar = [] return cbar, mat, im, add_nonlin_title def cluster_cells_by_group_dict(cell_types, frame, cell_type_type): cells = {} for ct in np.sort(cell_types): fr = frame[frame[cell_type_type] == ct].cell fr = fr.astype('str') cells[ct] = np.sort(np.array(fr.unique())) return cells def plt_cell_body_isf_single_rotate2(axi, grid1, ax0, ax1, ax2, b, cell, frame, colors, amps_desired, 
save_names, cell_type_type, xlim=[0, 13], burst_corr='', predefined_amps2=False, norm=False): print(cell) frame_cell = frame[(frame['cell'] == cell)] frame_cell = unify_cell_names(frame_cell, cell_type=cell_type_type) cell_type = frame_cell[cell_type_type].iloc[0] spikes = frame_cell.spikes.iloc[0] fr = frame_cell.fr.iloc[0] cv = frame_cell.cv.iloc[0] eod_fr = frame_cell.EODf.iloc[0] spikes_all, hists, frs_calc, cont = load_spikes(spikes, eod_fr) # cont heißt dass die spikes nans sind oder leer sind, das heißt da ist gar nichts, auch keine scatter, das sind die weißen bilder die brauchen wir nicht if cont: # die zwei checken ob es mehr als paar spikes gibt,ohne die brauchen wir auch kein Bild if len(hists) > 0: if len(np.concatenate(hists)) > 0: lim_here = find_lim_here(cell, burst_corr=burst_corr) if np.min(np.concatenate(hists)) < lim_here: hists2, spikes_ex, frs_calc2 = correct_burstiness(hists, spikes_all, [eod_fr] * len(spikes_all), [eod_fr] * len(spikes_all), lim=lim_here, burst_corr=burst_corr) hists_both = [hists, hists2] else: hists_both = [hists, hists] # das ist der title fals der square nicht plottet plt.suptitle(cell + ' EODf ' + str(np.round(eod_fr)) + ' Hz ' + ' % ' + ' Base: cv ' + str(np.round(cv, 2)) + ' fr ' + str( np.round(fr)) + ' Hz', fontsize=11, ) # cell[0:13] + color=color+ cell_type load_name = load_folder_name('calc_RAM') + '/' + save_names[0] + '_' + cell if os.path.exists(load_name + '.pkl'): stack = pd.read_pickle(load_name + '.pkl') if 'p-unit' in cell_type: # == ['p-unit', ' P-unit']: file_names_exclude = punit_file_exclude() # else: file_names_exclude = ampullary_file_exclude() # files = stack['file_name'].unique() fexclude = False if fexclude: if len(files) > 1: stack = stack[~stack['file_name'].isin(file_names_exclude)] files = stack['file_name'].unique() amps = stack['amp'].unique() _, _ = find_row_col(np.arange(len(amps) * len(files))) predefined_amp = True if predefined_amps2: for a, amp in enumerate(amps): if amp not in 
amps_desired: predefined_amp = False if predefined_amp: pass else: pass amps_defined = [np.min(amps)] file, cut_offs = find_optimal_files(files) stack_file = stack[stack['file_name'] == file] for a, amp in enumerate(amps_defined): if amp in np.array(stack_file['amp']): axs, axo, axin = square_isf(grid1, norm, b, cell, stack_file, amp, eod_fr, file) ################################ # do the scatter of these cells add = ['', '_burst_corr', ] try: ax0.scatter(frame_cell['cv' + add[0]], frame_cell['fr' + add[0]], zorder=2, alpha=1, label=cell_type, s=15, color=colors[str(cell_type)], facecolor='white') except: print('colors_f problem') embed() if len(ax1) > 0: ax1.scatter(frame_cell['cv' + add[0]], frame_cell['vs' + add[0]], zorder=2, alpha=1, label=cell_type, s=15, color=colors[str(cell_type)], facecolor='white') ax2.scatter(frame_cell['cv' + add[1]], frame_cell['fr' + add[1]], zorder=2, alpha=1, label=cell_type, s=15, color=colors[str(cell_type)], facecolor='white') plt_hists(axi, cell_type, colors, hists_both, xlim, b, alpha=1) return axs, axo, axin def find_optimal_files(files): cut_offs = [] for file in files: cut_offs.append(calc_cut_offs(file)) file = files[np.argmax(cut_offs)] return file, cut_offs def square_isf(grid1, norm, b, cell, stack_file, amp, eod_fr, file): stack_amp = stack_file[stack_file['amp'] == amp] lengths = stack_file['stimulus_length'].unique() length = np.max(lengths) stack_final = stack_amp[stack_amp['stimulus_length'] == length] trial_nr_double = stack_final.trial_nr.unique() # ok das ist glaube ich ein Anzeichen von einem Fehler if len(trial_nr_double) > 1: print('trial_nr_double') embed() # ich hatte mal einen save fehler deswegen habe ich manche doppelt also einfach das letzte nehmen nehme ich an try: stack_final1 = stack_final[stack_final.trial_nr == np.max(trial_nr_double)] except: print('stack_final1 problem') embed() try: axs = plt.subplot(grid1[2]) except: print('grid problem6') embed() im, min_lim, max_lim = square_func([axs], 
stack_final1, norm=norm) cbar = plt.colorbar(im, ax=axs, orientation='vertical') if b != 0: cbar.set_label(nonlin_title(), rotation=270, labelpad=1000) fr = stack_final1.fr.unique()[0] snippets = stack_final1['snippets'].unique()[0] cv = stack_final1.cv.unique()[0] ser = stack_final1.ser.unique()[0] cv_stim = stack_final1.cv_stim.unique()[0] fr_stim = stack_final1.fr_stim.unique()[0] ser_stim = stack_final1.ser_stim.unique()[0] plt.suptitle(cell + ' EODf ' + str(np.round(eod_fr)) + ' Hz ' + '' + 'S.Nr ' + str( snippets) + ' % ' + ' Base: cv ' + str(np.round(cv, 2)) + ' fr ' + str( np.round(fr)) + ' Hz' + ' ser ' + str(np.round(ser)) + ' Stim: cv ' + str(np.round(cv_stim, 2)) + ' fr ' + str( np.round(fr_stim)) + ' Hz' + ' ser ' + str(np.round(ser_stim)) + ' length ' + str( length) , fontsize=11, ) # cell[0:13] + color=color+ cell_type eod_fr_half_color = 'purple' fr_color = 'red' eod_fr_color = 'magenta' fr_stim_color = 'darkred' axo, axin = plt_psd_traces(grid1[0], grid1[1], axs, min_lim, max_lim, eod_fr, fr, fr_stim, stack_final1, fr_color, fr_stim_color, eod_fr_color, eod_fr_half_color) axo.set_title(' std ' + str(amp) + ' ' + file) return axs, axo, axin def plt_hists(axi, cell_type, colors, hists_both, xlim, b, alpha=1): if len(hists_both) > 1: colors_hist = ['grey', colors[str(cell_type)]] else: colors_hist = [colors[str(cell_type)]] for gg in range(len(range(b + 1))): hists_here = hists_both[gg] for hh, h in enumerate(hists_here): try: axi.hist(h, bins=100, color=colors_hist[gg], label='CV ' + str(np.round(np.std(h) / np.mean(h), 3)), alpha=float(alpha - 0.05 * hh)) except: print('alpha problem4') embed() axi.legend(ncol=2) if len(xlim) > 0: axi.set_xlim(xlim) axi.set_xlabel('isi') def plt_cv_fr(annotate, ax0, add, frame_g, colors, cell_type): ax0.scatter(frame_g['cv' + add], frame_g['fr' + add], alpha=0.5, label=cell_type, s=7, color=colors[str(cell_type)]) exclude = np.isnan(frame_g['cv' + add]) | np.isnan(frame_g['fr' + add]) frame_g_ex = frame_g[~exclude] 
if annotate: for f in range(len(frame_g_ex)): ax0.text(frame_g_ex['cv' + add].iloc[f], frame_g_ex['fr' + add].iloc[f], frame_g_ex.cell.iloc[f][2:13], rotation=45, color=colors[str(cell_type)], fontsize=6) ax0.set_xlim(0, 1.5) ax0.set_ylabel('Base Freq [Hz]') ax0.set_xlabel('CV') def plt_cv_part_several(row, col, cell, frame_save, frame, cell_nr, counter, ax): cv_isi = frame.iloc[cell_nr].cv fr = frame.iloc[cell_nr].fr cv_title = False if cv_title: if cv_isi < 0.2: color = 'red' elif cv_isi < 0.3: color = 'purple' elif cv_isi < 0.4: color = 'orange' elif cv_isi < 0.7: color = 'green' else: color = 'blue' else: color = title_color(cell) frame_here = frame_save[frame_save.cell == cell] ax[counter].text(0, 1.26, cell[0:-9] + '\n cv ' + str(np.round(cv_isi, 2)) + ' fr ' + str(np.round(fr)) + ' Hz', transform=ax[counter].transAxes, color=color, fontsize=8) try: hist = frame_here['hist'].iloc[0][0] except: print('hist problem') embed() width = (hist[1][1] - hist[1][0]) ax[counter].bar(hist[1][0:-1] + width / 2, height=hist[0], width=width, ) if counter == row * col - col: ax[counter].set_xlabel('Inter Spike Interval, EODf multiples') ax[counter].set_ylabel('nr') ax[counter].set_xlim(0, 17) counter += 1 return counter def plt_squares_special(params, col_desired=2, var_items=['contrasts'], show=False, contrasts=[0], noises_added=[''], fft_i='forward', fft_o='forward', spikes_unit='Hz', mV_unit='mV', D_extraction_method=['additiv_visual_d_4_scaled'], internal_noise=['RAM'], external_noise=['RAM'], level_extraction=['_RAMdadjusted'], cut_off2=300, repeats=[1000000], receiver_contrast=[1], visualize=True, dendrids=[''], ref_types=[''], adapt_types=[''], c_noises=[0.1], perc='', share=True, c_signal=[0.9], new_plot=True, cut_offs1=[300], clims='all', restrict='restrict', label=r'$\frac{1}{mV^2S}$', width=0.005, cells_given=None, lp=100, ax=[], titles_plot=True): nffts = ['whole'] # ,int(2 ** 16) int(2 ** 16), int(2 ** 15), stimulus_length = 1 # 20#550 # 30 # 15#45#0.5#1.5 15 45 
100 trials_nrs = [1] # [100, 500, 1000, 3000, 10000, 100000, 1000000] # 500 powers = [1] # ,3]#, 3, 1, 1.5, 0.5, ] # ,1,1.5, 0.5] #[1,1.5, 0.5] # 1.5,0.5]3, 1, variant = 'sinz' mimick = 'no' cell_recording_save_name = '' trans = 1 # 5 if new_plot: plot_style() if cells_given: params_len = np.arange(0, len(params) * len(cells_given), 1) else: params_len = params col, row = find_row_col(params_len, col=col_desired) # np.arange( if col == 2: default_settings(column=2, length=7.5) # 2+2.25+2.25 elif col == 1: default_settings(column=2, length=4) elif col > 2: if row == 2: default_settings(column=2, length=4.5) else: default_settings(column=2, length=7.5) else: default_settings(column=2, length=7.5) fig, ax_orig = plt.subplots(row, col, sharex=True, sharey=True) # constrained_layout=True,, figsize=(11, 5) if row != 1: ax = np.concatenate(ax_orig) else: ax = ax_orig if col == 2: plt.subplots_adjust(bottom=0.067, top=0.81, hspace=0.39, right=0.95, left=0.075) # , hspace = 0.6, wspace = 0.5 elif col == 1: plt.subplots_adjust(bottom=0.1, top=0.81, hspace=0.39, right=0.95, left=0.075) # , hspace = 0.6, wspace = 0.5 else: if row == 2: plt.subplots_adjust(bottom=0.07, top=0.76, wspace=0.9, hspace=0.4, right=0.85, left=0.075) # , hspace = 0.6, wspace = 0.5 else: plt.subplots_adjust(bottom=0.05, top=0.81, wspace=0.9, hspace=0.2, right=0.85, left=0.075) # , hspace = 0.6, wspace = 0.5 else: col = col_desired maxs = [] mins = [] ims = [] ####################################################################### # das ist jetzt der core a = 0 aa = 0 for var_type, stim_type_afe, trials_stim, stim_type_noise, power, nfft, dendrid, cut_off1, trial_nrs, c_sig, c_noise, ref_type, adapt_type, noise_added, extract, a_fr, a_fe, in it.product( D_extraction_method, external_noise , repeats, internal_noise, powers, nffts, dendrids, cut_offs1, trials_nrs, c_signal, c_noises, ref_types, adapt_types, noises_added, level_extraction, receiver_contrast, contrasts, ): aa += 1 nr = '2' for p, param in 
enumerate(params): print(a) 'contrasts' a_fe = params[p]['contrasts'][0] var_type = params[p]['D_extraction_method'][0] extract = params[p]['level_extraction'][0] if 'repeats' in params[p]: trials_stim = params[p]['repeats'][0] save_name = save_ram_model(stimulus_length, cut_off1, nfft, a_fe, stim_type_noise, mimick, variant, trials_stim, power, cell_recording_save_name, nr=nr, fft_i=fft_i, fft_o=fft_o, Hz=spikes_unit, mV=mV_unit, stim_type_afe=stim_type_afe, extract=extract, noise_added=noise_added, c_noise=c_noise, c_sig=c_sig, ref_type=ref_type, adapt_type=adapt_type, var_type=var_type, cut_off2=cut_off2, dendrid=dendrid, a_fr=a_fr, trials_nr=trial_nrs, trans=trans, zeros='ones') adapt_type_name, dendrid_name, ref_type_name, stim_type_noise_name = add_ends(adapt_type, dendrid, ref_type, stim_type_noise, var_type) stim_type_noise_name2, stim_type_afe_name = stim_type_names(a_fe, c_sig, stim_type_afe, stim_type_noise_name) path = save_name + '.pkl' # '../'+ cell_add, cells_save = find_cell_add(cells_given) model = load_model_susept(path, cells_save, save_name.split(r'/')[-1] + cell_add) test = False if test: from utils_test import test_model test_model() if len(model) > 0: cells = model.cell.unique() # model = pd.read_pickle(path) if not cells_given: cells = [cells[0]] else: cells = cells_given for c, cell in enumerate(cells): suptitles, titles = titles_susept_names(a_fe, extract, noise_added, stim_type_afe_name, stim_type_noise_name2, trials_stim, var_items, var_type) if len(cells) > 1: titles = cell + ' ' + titles add_nonlin_title, cbar, fig, stack_plot, im = plt_single_square_modl(ax[a], cell, model, perc, titles, width, titles_plot) ims.append(im) maxs.append(np.max(np.array(stack_plot))) mins.append(np.min(np.array(stack_plot))) if a in np.arange(col - 1, 100, col): cbar.set_label(label, labelpad=lp) # rotation=270, if new_plot: if a >= row * col - col: ax[a].set_xlabel(F1_xlabel(), labelpad=20) if len(cells) > 1: a += 1 ax[0].set_ylabel(F2_xlabel()) if a in 
np.arange(0, len(ax), 1) * col: if a < len(ax): try: ax[a].set_ylabel(F2_xlabel()) except: print('ax a thing') embed() if len(cells) == 1: a += 1 if titles_plot: end_name = suptitles + ' a_fr=' + str(a_fr) + ' ' + ' dendride=' + str( dendrid_name) + ' refractory period=' + str(ref_type_name) + ' adapt=' + str( adapt_type_name) + ' ' + ' cutoff1=' + str(cut_off1) + '' ' stimulus length=' + str( stimulus_length) + ' ' + ' power=' + str( power) + ' ' + restrict # end_name = cut_title(end_name, datapoints=120) name_title = end_name plt.suptitle(name_title) # +' file ' if share: set_clim_same(clims=clims, ims=ims, maxs=maxs, mins=mins, lim_type='', nr_clim='10', perc='') improved = False if improved: set_clim_same(ims, lim_type='up') if new_plot: if col < 3: fig.tag(ax, xoffs=-3, yoffs=5.8) else: if row == 2: fig.tag([ax_orig[0, :], ax_orig[1, :]], xoffs=-5.5, yoffs=3.8) else: fig.tag([ax_orig[0, :], ax_orig[1, :], ax_orig[2, :]], xoffs=-3, yoffs=3.8) if visualize: save_visualization(pdf=True) if show: plt.show() def plt_single_square_modl(ax, cell, model, perc, titles, width, bias_factor=1, fr_print=False, eod_metrice=True, nr=3, titles_plot=False, xpos=1.1, resize=False, ls=8): model_show, stack_plot, stack_plot_wo_norm = get_stack(cell, model, bias_factor=bias_factor) print(np.max(np.max(stack_plot))) #embed() if resize: stack_plot, add_nonlin_title, resize_val = rescale_colorbar_and_values(stack_plot) else: add_nonlin_title = '' try: ax.set_xlim(0, 300) except: print('aa thing') embed() ax.set_ylim(0, 300) ax.set_aspect('equal') cbar = [] im = [] if len(model_show) > 0: if fr_print: add_here = '\n fr$_{S}$=' + str(int(np.round(model_show.fr_stim.iloc[0]))) + 'Hz cv$_{S}$=' + str( np.round(model_show.cv_stim.iloc[0], 2)) else: add_here = '' if titles_plot: ax.text(xpos, 1.05, titles + add_here, ha='right', transform=ax.transAxes) # , fontsize7= + cell_type# cell[0:13] + stack_final.celltype.unique()[0] + 'S.Nr ' + str( im = plt_RAM_perc(ax, perc, stack_plot) 
#print(np.max(np.max(stack_plot))) #embed() plt_triangle(ax, model_show.fr.iloc[0], np.round(model_show.fr_stim.iloc[0]), 300, model_show.eod_fr.iloc[0], eod_metrice=eod_metrice, nr=nr) ax.set_aspect('equal') fig = plt.gcf() cbar, left, bottom, width, height = colorbar_outside(ax, im, fig, add=0, ls=ls, shrink=0.6, width=width) # 0.02 return add_nonlin_title, cbar, fig, stack_plot, im def get_stack(cell, model, bias_factor=1): try: model_show = model[(model.cell == cell)] except: print('cell something') embed() stack_plot_wo_norm = change_model_from_csv_to_plots(model_show) stack_plot = RAM_norm(stack_plot_wo_norm, model_show=model_show, bias_factor=bias_factor) return model_show, stack_plot, stack_plot_wo_norm def plt_all_scatter_rotated(ax0, ax1, frame, cell_types, add='', alpha=0.5, s=7, annotate=False, cell_type_type='cell_type_info'): frame_g = ptl_fr_cv(add, alpha, annotate, ax0, cell_type_type, cell_types, frame, s) colors = colors_overview() for c, cell_type in enumerate(cell_types): vs = np.array(list(map(float, np.array(frame_g['vs' + add])))) cv = np.array(list(map(float, np.array(frame_g['cv' + add])))) exclude = np.isnan(cv) | np.isnan(vs) frame_g_ex = frame_g[~exclude] if annotate: for f in range(len(frame_g_ex)): ax1.text(frame_g_ex['vs' + add].iloc[f], frame_g_ex['cv' + add].iloc[f], frame_g_ex.cell.iloc[f][2:13], rotation=45, color=colors[str(cell_type)], fontsize=6) ax0.set_ylim(0, 1.5) ax1.scatter(frame_g['vs' + add], frame_g['cv' + add], alpha=alpha, label=cell_type, s=s, color=colors[str(cell_type)]) # ax1.set_ylim(0, 1.5) ax1.set_xlim(0, 1) if 'burst' in add: ax1.set_ylabel('$CV_{Burst Corr}$') ax0.set_ylabel('$CV_{Burst Corr}$') else: ax0.set_ylabel('CV') ax1.set_ylabel('CV') ax0.set_xlabel('Base Freq [Hz]') ax1.set_xlabel('VS') plt.subplots_adjust(wspace=0.25, bottom=0.1) def ptl_fr_cv(add, alpha, annotate, ax0, cell_type_type, cell_types, frame, s, color_given=None, cv='cv', fr='fr'): colors = colors_overview() for c, cell_type in 
enumerate(cell_types): print(cell_type) frame_g = frame[(frame[cell_type_type] == cell_type) & ((frame.gwn == True) | (frame.fs == True))] if not color_given: color_given = colors[str(cell_type)] try: ax0.scatter(frame_g[fr + add], frame_g[cv + add], alpha=alpha, label=cell_type, s=s, color=color_given, clip_on=True) except: print('scatter thing') embed() print('mean(' + str(fr + add) + str(np.mean(frame_g[fr + add])) + ' ' + 'mean(' + str(cv + add) + str( np.mean(frame_g[cv + add]))) c_axis, x_axis, y_axis, exclude_here = exclude_nans_for_corr(frame_g, cv + add, cv_name=fr + add, score=cv + add) try: legend_wo_dot(ax0, 0.9 - 0.1 * c, x_axis, y_axis, ha='left', color=color_given, x_pos=0) except: print('something') embed() exclude = np.isnan(frame_g['cv' + add]) | np.isnan(frame_g['fr' + add]) frame_g_ex = frame_g[~exclude] if annotate: for f in range(len(frame_g_ex)): ax0.text(frame_g_ex['fr' + add].iloc[f], frame_g_ex['cv' + add].iloc[f], frame_g_ex.cell.iloc[f][2:13], rotation=45, color=colors[str(cell_type)], fontsize=6) test = False if test: pass return frame_g def plt_all_scatter(ax0, ax1, frame, cell_types, colors, add='', alpha=0.5, s=7, annotate=False, cell_type_type='cell_type_info'): for c, cell_type in enumerate(cell_types): frame_g = frame[(frame[cell_type_type] == cell_type) & ((frame.gwn == True) | (frame.fs == True))] ax0.scatter(frame_g['cv' + add], frame_g['fr' + add], alpha=alpha, label=cell_type, s=s, color=colors[str(cell_type)]) exclude = np.isnan(frame_g['cv' + add]) | np.isnan(frame_g['fr' + add]) frame_g_ex = frame_g[~exclude] if annotate: for f in range(len(frame_g_ex)): ax0.text(frame_g_ex['cv' + add].iloc[f], frame_g_ex['fr' + add].iloc[f], frame_g_ex.cell.iloc[f][2:13], rotation=45, color=colors[str(cell_type)], fontsize=6) exclude = np.isnan(frame_g['cv' + add]) | np.isnan(frame_g['vs' + add]) frame_g_ex = frame_g[~exclude] if annotate: for f in range(len(frame_g_ex)): ax1.text(frame_g_ex['cv' + add].iloc[f], frame_g_ex['vs' + 
add].iloc[f], frame_g_ex.cell.iloc[f][2:13], rotation=45, color=colors[str(cell_type)], fontsize=6) ax0.set_xlim(0, 1.5) ax1.scatter(frame_g['cv' + add], frame_g['vs' + add], alpha=alpha, label=cell_type, s=s, color=colors[str(cell_type)]) # ax1.set_xlim(0, 1.5) ax1.set_ylim(0, 1) if 'burst' in add: ax0.set_xlabel('$CV_{Burst Corr}$') ax1.set_xlabel('$CV_{Burst Corr}$') ax1.set_ylabel('$VS_{Burst Corr}$') ax0.set_ylabel('Base Freq [Hz] $_{Burst Corr}$') else: ax0.set_xlabel('CV') ax1.set_xlabel('CV') ax0.set_ylabel('Base Freq [Hz]') ax1.set_ylabel('VS') if add == '': ax0.legend(ncol=5, loc=(0, 1.05)) plt.subplots_adjust(wspace=0.25, bottom=0.1) def plt_all_width_rotated(frame, cell_types, frame_cell, add, gg, cell_type, ax2, annotate=False, alpha=1, xlim=[0, 25], s=15): colors = colors_overview() if 'width_75' + add[gg] in frame_cell.keys(): ax2.scatter(frame_cell['width_75' + add[gg]], frame_cell['width_75' + add[gg]], alpha=1, label=cell_type, s=15, color=colors[str(cell_type)], facecolor='white') for c, cell_type in enumerate(cell_types): # frame_all = frame[(frame['cell_type_info'] == cell_type)] # frame_g = frame[(frame['cell_type_reclassified'] == cell_type) & ((frame.gwn == True) | (frame.fs == True))] ax2.scatter(frame_g['width_75' + add[gg]], frame_g['cv' + add[gg]], alpha=alpha, label=cell_type, s=s, color=colors[str(cell_type)]) exclude = np.isnan(frame_g['cv' + add[gg]]) | np.isnan(frame_g['width_75' + add[gg]]) frame_g_ex = frame_g[~exclude] if annotate: for f in range(len(frame_g_ex)): ax2.text(frame_g_ex['width_75' + add[gg]].iloc[f], frame_g_ex['cv' + add[gg]].iloc[f], frame_g_ex.cell.iloc[f][2:13], rotation=45, color=colors[str(cell_type)], fontsize=6) ax2.set_ylim(0, 1.5) if 'burst' in add[gg]: ax2.set_ylabel('$CV_{Burst Corr}$') else: ax2.set_ylabel('CV') ax2.set_xlabel('Width at 75 %') ax2.set_xlim(xlim) def plt_all_width(frame, cell_types, frame_cell, add, gg, colors, cell_type, ax2, annotate=False, alpha=1, s=15): if 'width_75' + add[gg] in 
frame_cell.keys(): ax2.scatter(frame_cell['width_75' + add[gg]], frame_cell['width_75' + add[gg]], alpha=1, label=cell_type, s=15, color=colors[str(cell_type)], facecolor='white') for c, cell_type in enumerate(cell_types): frame_g = frame[(frame['cell_type_reclassified'] == cell_type) & ((frame.gwn == True) | (frame.fs == True))] ax2.scatter(frame_g['cv' + add[gg]], frame_g['width_75' + add[gg]], alpha=alpha, label=cell_type, s=s, color=colors[str(cell_type)]) exclude = np.isnan(frame_g['cv' + add[gg]]) | np.isnan(frame_g['width_75' + add[gg]]) frame_g_ex = frame_g[~exclude] if annotate: for f in range(len(frame_g_ex)): ax2.text(frame_g_ex['cv' + add[gg]].iloc[f], frame_g_ex['width_75' + add[gg]].iloc[f], frame_g_ex.cell.iloc[f][2:13], rotation=45, color=colors[str(cell_type)], fontsize=6) ax2.set_xlim(0, 0.9) if 'burst' in add[gg]: ax2.set_xlabel('$CV_{Burst Corr}$') ax2.set_ylabel('Width at 75 % $_{Burst Corr}$') else: ax2.set_ylabel('Width at 75 %') ax2.set_xlabel('CV') ax2.set_ylim(0, 25) def plt_scatter_three2(grid2, frame, cell_type_type, annotate, colors, cell_types=[' P-unit', ' Ampullary'], add=['', '_burst_corr_individual']): ax0 = plt.subplot(grid2[0]) # ok hier plotten wir nur den scatter der auch ein gwn hat, aber was ist wenn es mehr sind? 
# ok im prinzip sollte das zwar schon stimmen aber für das Bild kann man wirklich mehr machen for c, cell_type_it in enumerate(cell_types): frame_g = frame[ (frame[cell_type_type] == cell_type_it) & ((frame.gwn == True) | (frame.fs == True))] plt_cv_fr(annotate, ax0, add[0], frame_g, colors, cell_type_it) ax1 = plt.subplot(grid2[1]) for c, cell_type_it in enumerate(cell_types): frame_g = frame[ (frame[cell_type_type] == cell_type_it) & ((frame.gwn == True) | (frame.fs == True))] plt_cv_vs(frame_g, ax1, add[0], annotate, colors, cell_type_it) ax2 = plt.subplot(grid2[2]) for c, cell_type_it in enumerate(cell_types): frame_g = frame[ (frame[cell_type_type] == cell_type_it) & ((frame.gwn == True) | (frame.fs == True))] plt_cv_fr(annotate, ax2, add[1], frame_g, colors, cell_type_it) ax2.set_ylabel('Base Freq [Hz] $_{Burst Corr}$') ax2.set_xlabel('$CV_{Burst Corr}$') return ax0, ax1, ax2 def plt_cell_body_single_amp(grid1, ax0, ax1, ax2, frame, colors, amps_desired, save_names, cells_plot, cell_type_type, ax3=[]): for c, cell in enumerate(cells_plot): print(cell) frame_cell = frame[(frame['cell'] == cell)] frame_cell = unify_cell_names(frame_cell, cell_type=cell_type_type) cell_type = frame_cell[cell_type_type].iloc[0] spikes = frame_cell.spikes.iloc[0] fr = frame_cell.fr.iloc[0] cv = frame_cell.cv.iloc[0] eod_fr = frame_cell.EODf.iloc[0] spikes_all, hists, frs_calc, cont_spikes = load_spikes(spikes, eod_fr) # cont_spikes heißt dass die spikes nans sind oder leer sind, das heißt da ist gar nichts, auch keine scatter, das sind die weißen bilder die brauchen wir nicht # also hier ist das ok das mit dem Cont spikes so zu machen weil wir wollen die ja haben! 
if cont_spikes: # die zwei checken ob es mehr als paar spikes gibt,ohne die brauchen wir auch kein Bild if len(hists) > 0: if len(np.concatenate(hists)) > 0: if np.min(np.concatenate(hists)) < 1.5: hists2, spikes_ex, frs_calc2 = correct_burstiness(hists, spikes_all, [eod_fr] * len(spikes_all), [eod_fr] * len(spikes_all)) hists_both = [hists, hists2] else: hists_both = [hists] # das ist der title fals der square nicht plottet plt.suptitle(cell + ' EODf ' + str(np.round(eod_fr)) + ' Hz ' + ' % ' + ' Base: cv ' + str(np.round(cv, 2)) + ' fr ' + str( np.round(fr)) + ' Hz', fontsize=11, ) # cell[0:13] + color=color+ cell_type load_name = load_folder_name('calc_RAM') + '/' + save_names[0] + '_' + cell if os.path.exists(load_name + '.pkl'): stack = pd.read_pickle(load_name + '.pkl') file_names_exclude = file_names_to_exclude(cell_type) files = stack['file_name'].unique() fexclude = False if fexclude: if len(files) > 1: stack = stack[~stack['file_name'].isin(file_names_exclude)] files = stack['file_name'].unique() amps = stack['amp'].unique() _, _ = find_row_col(np.arange(len(amps) * len(files))) predefined_amp = True if predefined_amp: amps_defined = amps_desired else: amps_defined = amps stack_file = stack[stack['file_name'] == files[0]] amps = stack_file['amp'].unique() for a, amp in enumerate(amps_defined): if amp in np.array(stack_file['amp']): stack_amp = stack_file[stack_file['amp'] == amp] lengths = stack_file['stimulus_length'].unique() length = np.max(lengths) stack_final = stack_amp[stack_amp['stimulus_length'] == length] trial_nr_double = stack_final.trial_nr.unique() # ok das ist glaube ich ein Anzeichen von einem Fehler if len(trial_nr_double) > 1: print('trial_nr_double') embed() # ich hatte mal einen save fehler deswegen habe ich manche doppelt also einfach das letzte nehmen nehme ich an try: stack_final1 = stack_final[stack_final.trial_nr == np.max(trial_nr_double)] except: print('stack_final1 problem') embed() try: grid_s = 
gridspec.GridSpecFromSubplotSpec(5, 1, grid1[c], height_ratios=[1.5, 1.5, 5, 1.5, 1.5, ], hspace=0) axs = plt.subplot(grid_s[2]) except: print('grid problem5') embed() cbar, mat, im = plot_square_core([axs], stack_final1) if a == len(amps) - 1: cbar.set_label(nonlin_title(), rotation=90, labelpad=10) fr = stack_final1.fr.unique()[0] fr_stim = stack_final1.fr_stim.unique()[0] axo, axi = plt_psd_traces(grid_s[0], grid_s[1], axs, np.min(mat.columns), np.max(mat.columns), eod_fr, fr, fr_stim, stack_final, ) if c == 0: axi.set_title(' std = ' + str(amp) + '$\%$') # files[0] + ' l ' + str(length) if a != 0: axi.set_ylabel('') # do the scatter of these cells add = ['', '_burst_corr', ] if type(ax0) != list: ax0.scatter(frame_cell['cv' + add[0]], frame_cell['fr' + add[0]], zorder=2, alpha=1, label=cell_type, s=15, color=colors[str(cell_type)], facecolor='white') if type(ax1) != list: ax1.scatter(frame_cell['cv' + add[0]], frame_cell['vs' + add[0]], zorder=2, alpha=1, label=cell_type, s=15, color=colors[str(cell_type)], facecolor='white') if type(ax2) != list: ax2.scatter(frame_cell['cv' + add[1]], frame_cell['fr' + add[1]], zorder=2, alpha=1, label=cell_type, s=15, color=colors[str(cell_type)], facecolor='white') if ax3 != []: frame_g = base_to_stim(load_name, frame, cell_type_type, cell_type) try: ax3.scatter(frame_g['cv'], frame_g['cv_stim'], zorder=2, alpha=1, label=cell_type, s=15, color=colors[str(cell_type)], facecolor='white') except: print('scatter problem') embed() ################################ # do the hist alpha = 1 axi = plt.subplot(grid_s[-1]) if len(hists_both) > 1: colors_hist = ['grey', colors[str(cell_type)]] else: colors_hist = [colors[str(cell_type)]] try: for gg in range(len(hists_both)): hists_here = hists_both[gg] for hh, h in enumerate(hists_here): try: axi.hist(h, bins=100, color=colors_hist[gg], alpha=float(alpha - 0.05 * hh)) except: print('alpha problem5') axi.set_title( 'CV ' + str(np.round(np.std(h) / np.mean(h), 3)) + ' ' + cell) # +' VS 
'+str(vs) axi.set_xlabel('isi') except: print('hists not there yet') def plt_cell_body3(grid1, ax0, ax1, ax2, frame, colors, amps_desired, save_names, cells_plot, cell_type_type, ax3=[], xlim=[]): for c, cell in enumerate(cells_plot): print(cell) frame_cell = frame[(frame['cell'] == cell)] frame_cell = unify_cell_names(frame_cell, cell_type=cell_type_type) try: cell_type = frame_cell[cell_type_type].iloc[0] except: embed() spikes = frame_cell.spikes.iloc[0] eod_fr = frame_cell.EODf.iloc[0] spikes_all, hists, frs_calc, cont_spikes = load_spikes(spikes, eod_fr) # cont_spikes heißt dass die spikes nans sind oder leer sind, das heißt da ist gar nichts, auch keine scatter, das sind die weißen bilder die brauchen wir nicht # also hier ist das ok das mit dem Cont spikes so zu machen weil wir wollen die ja haben! if cont_spikes: # die zwei checken ob es mehr als paar spikes gibt,ohne die brauchen wir auch kein Bild if len(hists) > 0: if len(np.concatenate(hists)) > 0: if np.min(np.concatenate(hists)) < 1.5: hists2, spikes_ex, frs_calc2 = correct_burstiness(hists, spikes_all, [eod_fr] * len(spikes_all), [eod_fr] * len(spikes_all)) hists_both = [hists, hists2] else: hists_both = [hists] # das ist der title fals der square nicht plottet load_name = load_folder_name('calc_RAM') + '/' + save_names[0] + '_' + cell if os.path.exists(load_name + '.pkl'): stack = pd.read_pickle(load_name + '.pkl') file_names_exclude = file_names_to_exclude(cell_type) files = stack['file_name'].unique() fexclude = False if fexclude: if len(files) > 1: stack = stack[~stack['file_name'].isin(file_names_exclude)] files = stack['file_name'].unique() amps = stack['amp'].unique() _, _ = find_row_col(np.arange(len(amps) * len(files))) predefined_amp = True if predefined_amp: amps_defined = amps_desired else: amps_defined = amps stack_file = stack[stack['file_name'] == files[0]] amps = stack_file['amp'].unique() for a, amp in enumerate(amps_defined): if amp in np.array(stack_file['amp']): stack_amp = 
stack_file[stack_file['amp'] == amp] lengths = stack_file['stimulus_length'].unique() length = np.max(lengths) stack_final = stack_amp[stack_amp['stimulus_length'] == length] trial_nr_double = stack_final.trial_nr.unique() # ok das ist glaube ich ein Anzeichen von einem Fehler if len(trial_nr_double) > 1: print('trial_nr_double') embed() # ich hatte mal einen save fehler deswegen habe ich manche doppelt also einfach das letzte nehmen nehme ich an try: stack_final1 = stack_final[stack_final.trial_nr == np.max(trial_nr_double)] except: print('stack_final1 problem') embed() try: grid_s = gridspec.GridSpecFromSubplotSpec(3, 1, grid1[c, a + 1], height_ratios=[1.5, 1.5, 5], hspace=0) axs = plt.subplot(grid_s[2]) except: print('grid problem4') embed() cbar, mat, im = plot_square_core([axs], stack_final1) if xlim: axs.set_xlim(xlim) axs.set_ylim(xlim) if a == len(amps) - 1: cbar.set_label(nonlin_title(), rotation=90, labelpad=10) fr = stack_final1.fr.unique()[0] fr_stim = stack_final1.fr_stim.unique()[0] axo, axi = plt_psd_traces(grid_s[0], grid_s[1], axs, np.min(mat.columns), np.max(mat.columns), eod_fr, fr, fr_stim, stack_final, ) if c == 0: axo.set_title(' $std=$' + str(amp) + ' %') # files[0] + ' l ' + str(length) if a != 0: axi.set_ylabel('') axo.set_ylabel('') axs.set_ylabel('') if c != 2: axs.set_xlabel('') remove_xticks(axi) if a == 1: axo.set_title(cell) # +' VS '+str(vs) ################################ # do the scatter of these cells add = ['', '_burst_corr', ] ax0.scatter(frame_cell['cv' + add[0]], frame_cell['fr' + add[0]], zorder=2, alpha=1, label=cell_type, s=15, color=colors[str(cell_type)], facecolor='white') ax1.scatter(frame_cell['cv' + add[0]], frame_cell['vs' + add[0]], zorder=2, alpha=1, label=cell_type, s=15, color=colors[str(cell_type)], facecolor='white') ax2.scatter(frame_cell['cv' + add[1]], frame_cell['fr' + add[1]], zorder=2, alpha=1, label=cell_type, s=15, color=colors[str(cell_type)], facecolor='white') if ax3 != []: frame_g = 
base_to_stim(load_name, frame, cell_type_type, cell_type) try: ax3.scatter(frame_g['cv'], frame_g['cv_stim'], zorder=2, alpha=1, label=cell_type, s=15, color=colors[str(cell_type)], facecolor='white') except: print('scatter problem') embed() ################################ # do the hist alpha = 1 axi = plt.subplot(grid1[c, 0]) if len(hists_both) > 1: colors_hist = ['grey', colors[str(cell_type)]] else: colors_hist = [colors[str(cell_type)]] for gg in range(len(hists_both)): hists_here = hists_both[gg] for hh, h in enumerate(hists_here): try: axi.hist(h, bins=100, color=colors_hist[gg], alpha=float(alpha - 0.05 * hh), label='CV ' + str( np.round(np.std(h) / np.mean(h), 3))) except: print('alpha problem6') embed() axi.legend() axi.set_xlim(0, 13) if c != len(cells_plot) - 1: remove_xticks(axi) else: axi.set_xlabel('isi') def plt_cell_body(grid1, ax0, ax1, ax2, frame, colors, amps_desired, save_names, cells_plot, cell_type_type, ax3=[], xlim=[]): for c, cell in enumerate(cells_plot): print(cell) frame_cell = frame[(frame['cell'] == cell)] frame_cell = unify_cell_names(frame_cell, cell_type=cell_type_type) try: cell_type = frame_cell[cell_type_type].iloc[0] except: embed() spikes = frame_cell.spikes.iloc[0] eod_fr = frame_cell.EODf.iloc[0] spikes_all, hists, frs_calc, cont_spikes = load_spikes(spikes, eod_fr) # cont_spikes heißt dass die spikes nans sind oder leer sind, das heißt da ist gar nichts, auch keine scatter, das sind die weißen bilder die brauchen wir nicht # also hier ist das ok das mit dem Cont spikes so zu machen weil wir wollen die ja haben! 
if cont_spikes: # die zwei checken ob es mehr als paar spikes gibt,ohne die brauchen wir auch kein Bild if len(hists) > 0: if len(np.concatenate(hists)) > 0: if np.min(np.concatenate(hists)) < 1.5: _, _, _ = correct_burstiness(hists, spikes_all, [eod_fr] * len(spikes_all), [eod_fr] * len(spikes_all)) else: pass # das ist der title fals der square nicht plottet load_name = load_folder_name('calc_RAM') + '/' + save_names[0] + '_' + cell if os.path.exists(load_name + '.pkl'): stack = pd.read_pickle(load_name + '.pkl') if 'p-unit' in cell_type: # == ['p-unit', ' P-unit']: file_names_exclude = punit_file_exclude() # else: file_names_exclude = ampullary_file_exclude() # files = stack['file_name'].unique() fexclude = False if fexclude: if len(files) > 1: stack = stack[~stack['file_name'].isin(file_names_exclude)] files = stack['file_name'].unique() amps = stack['amp'].unique() _, _ = find_row_col(np.arange(len(amps) * len(files))) predefined_amp = True if predefined_amp: amps_defined = amps_desired else: amps_defined = amps stack_file = stack[stack['file_name'] == files[0]] amps = stack_file['amp'].unique() for a, amp in enumerate(amps_defined): if amp in np.array(stack_file['amp']): stack_amp = stack_file[stack_file['amp'] == amp] lengths = stack_file['stimulus_length'].unique() length = np.max(lengths) stack_final = stack_amp[stack_amp['stimulus_length'] == length] trial_nr_double = stack_final.trial_nr.unique() # ok das ist glaube ich ein Anzeichen von einem Fehler if len(trial_nr_double) > 1: print('trial_nr_double') embed() # ich hatte mal einen save fehler deswegen habe ich manche doppelt also einfach das letzte nehmen nehme ich an try: stack_final1 = stack_final[stack_final.trial_nr == np.max(trial_nr_double)] except: print('stack_final1 problem') embed() try: grid_s = gridspec.GridSpecFromSubplotSpec(3, 1, grid1[c, a], height_ratios=[1.5, 1.5, 5], hspace=0) axs = plt.subplot(grid_s[2]) except: print('grid problem3') embed() cbar, mat, im = plot_square_core([axs], 
stack_final1) if xlim: axs.set_xlim(xlim) axs.set_ylim(xlim) if a == len(amps) - 1: cbar.set_label(nonlin_title(), rotation=90, labelpad=10) fr = stack_final1.fr.unique()[0] snippets = stack_final1['snippets'].unique()[0] fr1 = np.unique(stack_final1.fr) cv = stack_final1.cv.unique()[0] ser = stack_final1.ser.unique()[0] cv_stim = stack_final1.cv_stim.unique()[0] fr_stim = stack_final1.fr_stim.unique()[0] ser_stim = stack_final1.ser_stim.unique()[0] axo, axi = plt_psd_traces(grid_s[0], grid_s[1], axs, np.min(mat.columns), np.max(mat.columns), eod_fr, fr, fr_stim, stack_final, ) if c == 0: axo.set_title(' $std=$' + str(amp) + ' %') # files[0] + ' l ' + str(length) if a != 0: axi.set_ylabel('') axo.set_ylabel('') axs.set_ylabel('') if c != 2: axs.set_xlabel('') remove_xticks(axi) if a == 1: axo.set_title(cell) # +' VS '+str(vs) ################################ # do the scatter of these cells add = ['', '_burst_corr', ] ax0.scatter(frame_cell['cv' + add[0]], frame_cell['fr' + add[0]], zorder=2, alpha=1, label=cell_type, s=15, color=colors[str(cell_type)], facecolor='white') ax1.scatter(frame_cell['cv' + add[0]], frame_cell['vs' + add[0]], zorder=2, alpha=1, label=cell_type, s=15, color=colors[str(cell_type)], facecolor='white') ax2.scatter(frame_cell['cv' + add[1]], frame_cell['fr' + add[1]], zorder=2, alpha=1, label=cell_type, s=15, color=colors[str(cell_type)], facecolor='white') if ax3 != []: frame_g = base_to_stim(load_name, frame, cell_type_type, cell_type) try: ax3.scatter(frame_g['cv'], frame_g['cv_stim'], zorder=2, alpha=1, label=cell_type, s=15, color=colors[str(cell_type)], facecolor='white') except: print('scatter problem') embed() def base_to_stim(load_name, frame, cell_type_type, cell_type_it, stack=[]): if len(stack) == 0: if os.path.exists(load_name): stack_stim = pd.read_csv(load_name, low_memory=False) else: stack_stim = stack cells = frame[frame[cell_type_type] == cell_type_it].cell.unique() frame_gr = stack_stim[stack_stim.cell.isin(cells)] frame1 = 
frame_gr['cell'] frame_g = frame_gr.loc[frame1.drop_duplicates().index] return frame_g def plt_cv_vs(frame_g, ax1, add, annotate, colors, cell_type): ax1.scatter(frame_g['cv' + add], frame_g['vs' + add], alpha=0.5, label=cell_type, s=7, color=colors[str(cell_type)]) exclude = np.isnan(frame_g['cv' + add]) | np.isnan(frame_g['vs' + add]) frame_g_ex = frame_g[~exclude] if annotate: for f in range(len(frame_g_ex)): ax1.text(frame_g_ex['cv' + add].iloc[f], frame_g_ex['vs' + add].iloc[f], frame_g_ex.cell.iloc[f][2:13], rotation=45, color=colors[str(cell_type)], fontsize=6) ax1.set_xlim(0, 1.5) ax1.set_ylim(0, 1) ax1.set_xlabel('CV') ax1.set_ylabel('VS') def plt_data_up(cell, ax, fig, cells_chosen, cell_type='p-unit', width=0.005, cbar_label=True): if 'p-unit' in cell_type: # == ['p-unit', ' P-unit']: file_names_exclude = ['InputArr_350to400hz_30', 'InputArr_250to300hz_30', 'InputArr_150to200hz_30', 'InputArr_50to100hz_30', 'gwn25Hz10s0.3', 'InputArr_50hz_30', 'FileStimulus-file-gaussian50.0', 'gwn50Hz10.3', 'gwn50Hz10s0.3short', 'gwn50Hz50s0.3', 'FileStimulus-file-gaussian25.0', 'gwn50Hz10s0.3', ] # else: file_names_exclude = ['blwn125Hz10s0.3', 'gwn50Hz10s0.3', 'InputArr_350to400hz_30', 'InputArr_250to300hz_30', 'InputArr_150to200hz_30', 'InputArr_50to100hz_30', 'InputArr_50hz_30', 'FileStimulus-file-gaussian50.0', 'FileStimulus-file-gaussian25.0', 'gwn25Hz10s0.3', 'gwn50Hz10.3', 'gwn50Hz10s0.3short', 'gwn50Hz50s0.3', 'gwn25Hz10s0.3', ] # if len(cells_chosen) > 0: cells = cells_chosen col = 4 _, _, = find_row_col(cells, col=col) ax_data = [] if cell == '2012-07-03-ak-invivo-1': save_name = 'noise_data9_nfft1sec_original__LocalEOD_CutatBeginning_0.05_s_NeurDelay_0.005_s' # _burst_corr else: save_name = 'noise_data8_nfft1sec_original__LocalEOD_CutatBeginning_0.05_s_NeurDelay_0.005_s' # _burst_corr load_name = load_folder_name('calc_RAM') + '/' + save_name ax_data.append(ax) ######################################### # also die einzelzellen sind in pkls stack_cell = 
load_data_susept(load_name + '_' + cell + '.pkl', load_name + '_' + cell) try: stack_cell = stack_cell[~stack_cell['file_name'].isin(file_names_exclude)] except: print('stack cell problem') stack_cell = [] if len(stack_cell): file_names = stack_cell.file_name.unique() cut_off_nr = [] for ff, file_name in enumerate(file_names): if 'hz' in file_name.lower(): cut_off_nr = get_cut_off_for_wn(cut_off_nr, file_name) elif 'gaussian' in file_name: cut_off_nr.append(file_name.split('gaussian')[1]) else: cut_off_nr.append(file_name[-5::]) try: maxs = list(map(float, cut_off_nr)) except: print('maxs something') embed() file_names = file_names[np.argmax(maxs)] stack_file = stack_cell[stack_cell['file_name'] == file_names] amps = [np.min(stack_file.amp.unique())] amps = restrict_punits(cell, amps) for amp in amps: stack_amps = stack_file[stack_file['amp'] == amp] lengths = stack_amps.stimulus_length.unique() try: length_max = [np.max(lengths)] except: print('length something') embed() for length in length_max: stack_final = stack_amps[stack_amps['stimulus_length'] == length] if len(stack_final) < 1: embed() snippets = stack_final['snippets'].unique()[0] eod_fr = stack_final.eod_fr.unique()[0] cv = stack_final.cv.unique()[0] fr = stack_final.fr.unique()[0] cv_stim = stack_final.cv_stim.unique()[0] fr_stim = stack_final.fr_stim.unique()[0] ax.set_title( cell[0:13] + stack_final.celltype.unique()[0] + 'S.Nr ' + str( snippets) + ' EODf ' + str(np.round(eod_fr)) + ' Hz ' + ' std ' + str( amp) + ' % ' + '\n $cv_{B}=$' + str(np.round(cv, 2)) + ',$f_{B}=$' + str(np.round(fr)) + 'Hz' + ',$cv_{S}=$' + str(np.round(cv_stim, 2)) + ',$f_{S}=$' + str( np.round(fr_stim)) + 'Hz' , fontsize=7) # + cell_type stack_plot = stack_final keys = stack_plot.keys() new_keys = stack_plot.index try: stack_plot = stack_plot[new_keys] except: new_keys = list(map(str, new_keys)) stack_plot = stack_plot[new_keys] stack_plot = stack_plot.astype(complex) stack_plot.columns = list(map(float, stack_plot.columns)) 
mat = RAM_norm_data(stack_final['d_isf1'].iloc[0], stack_plot, stack_final['snippets'].unique()[0]) plot = True if plot: pcolor = False if pcolor: im = ax.pcolormesh(np.array(mat.index), np.array(list(map(float, mat.columns))), mat, vmin=np.nanpercentile(mat, 5), vmax=np.nanpercentile(mat, 95), cmap='Greens', rasterized=True ) # rasterized = True else: im = ax.imshow(mat, origin='lower', extent=[float(np.min(mat.columns)), float(np.max(mat.columns)), float(np.min(mat.index)), float(np.max(mat.index))], vmin=np.nanpercentile(mat, 5), vmax=np.nanpercentile(mat, 95), cmap='viridis', ) # 'Greens'#vmin=np.percentile(np.abs(stack_plot), 5),vmax=np.percentile(np.abs(stack_plot), 95), plt.suptitle(cell_type) ax.set_xlim(float(np.min(mat.index)), float(np.max(mat.index))) ax.set_ylim(float(np.min(mat.index)), float(np.max(mat.index))) ax.set_xlim(0, 300) ax.set_ylim(0, 300) ax.set_aspect('equal') plt_triangle(ax, fr, fr_stim, new_keys[-1], eod_fr) plt_50_Hz_noise(ax, new_keys[-1]) if plot: cbar, left, bottom, width, height = colorbar_outside(ax, im, fig, add=0, shrink=0.6, width=width) # 0.02 if cbar_label: cbar.set_label(nonlin_title(), rotation=90, labelpad=10) return ax_data def plt_data_susept(fig, grid, cells_chosen, eod_metrice=True, fr_print=False, amp_given=None, nr=3, cell_type='p-unit', xlabel=True, lp=10, title=True, cbar_label=True, xpos=1.1, width=0.005, n_print = True): file_names_exclude = get_file_names_exclude(cell_type) if len(cells_chosen) > 0: cells = cells_chosen ax_data = [] stack_spikes_all = [] eod_frs = [] for f, cell in enumerate(cells): ax = plt.subplot(grid[f]) ax_data.append(ax) eod_fr, stack_spikes = plt_data_suscept_single(ax, cbar_label, cell, cells, f, fig, file_names_exclude, lp, title, width, fr_print=fr_print, nr=nr, eod_metrice=eod_metrice, amp_given=amp_given, n_print = n_print, xpos=xpos, xlabel=xlabel) stack_spikes_all.append(stack_spikes) eod_frs.append(eod_fr) return ax_data, stack_spikes_all, eod_frs def plt_data_suscept_single(ax, 
cbar_label, cell, cells, f, fig, file_names_exclude, lp, title, width, fr_print=False, eod_metrice=True, xpos = 1.1, ypos=1.05, n_print = True, nr=3, xlabel=True, amp_given=None): if cell == '2012-07-03-ak-invivo-1': pass else: pass save_name = version_final() # ] load_name = load_folder_name('calc_RAM') + '/' + save_name ######################################### # also die einzelzellen sind in pkls add = '_cell' + cell # str(f) # + '_amp_' + str(amp) stack_cell = load_data_susept(load_name + '_' + cell + '.pkl', load_name + '_' + cell, add=add, load_version='csv') try: stack_cell = stack_cell[~stack_cell['file_name'].isin(file_names_exclude)] except: print('stack cell problem') stack_cell = [] if len(stack_cell): file_names = stack_cell.file_name.unique() file_names2 = exclude_file_name_short(file_names) cut_off_nr = get_cutoffs_nr(file_names2) try: maxs = list(map(float, cut_off_nr)) except: print('error1') embed() file_names2 = file_names2[np.argmax(maxs)] try: stack_file = stack_cell[stack_cell['file_name'] == file_names2] except: print('stack file something') embed() amps = [np.min(stack_file.amp.unique())] amps = restrict_punits(cell, amps) for amp in amps: stack_amps = stack_file[stack_file['amp'] == amp] lengths = stack_amps.stimulus_length.unique() try: length_max = [np.max(lengths)] except: print('length thing') embed() for length in length_max: trial_nr_double = stack_amps.trial_nr.unique() trial_nr = np.max(trial_nr_double) stack_final = stack_amps[ (stack_amps['stimulus_length'] == length) & (stack_amps.trial_nr == trial_nr)] stack_spikes = load_data_susept(load_name + '_' + cell + '.pkl', load_name, load_version='csv', load_type='spikes', add=add, trial_nr=trial_nr, stimulus_length=length, amp=amp, file_name=file_names2) snippets = stack_final['snippets'].unique()[0] eod_fr = stack_final.eod_fr.unique()[0] fr = stack_final.fr.unique()[0] cv_stim = stack_final.cv_stim.unique()[0] fr_stim = stack_final.fr_stim.unique()[0] if title: if amp_given: amp = 
amp_given if n_print: add = '\n $N = % s$' % snippets else: add = '' if fr_print: add += '\n fr$_{S}$=' + str(int(np.round(fr_stim))) + 'Hz' + ' cv$_{S}$=' + str( np.round(cv_stim, 2)) else: add += '' ax.text(xpos, ypos, 'Recorded P-unit' + add, ha='right', transform=ax.transAxes) # , fontsize7= + cell_type# cell[0:13] + stack_final.celltype.unique()[0] + 'S.Nr ' + str( mat, new_keys = get_mat_susept(stack_final) mat, add_nonlin_title, resize_val = rescale_colorbar_and_values(mat) im, plot = plt_mat_susept(ax, mat) if f == len(cells) - 1: ax.set_xticks_delta(100) if xlabel: set_xlabel_arrow(ax, xpos=xpos) else: remove_xticks(ax) ax.set_xlim(float(np.min(mat.index)), float(np.max(mat.index))) ax.set_ylim(float(np.min(mat.index)), float(np.max(mat.index))) ax.set_xlim(0, 300) ax.set_ylim(0, 300) ax.set_aspect('equal') plt_triangle(ax, fr, fr_stim, new_keys[-1], eod_fr, lines=False, eod_metrice=eod_metrice, nr=nr) set_clim_same([im], mats=[mat], lim_type='up', nr_clim='perc', clims='', percnr=perc_model_full()) if plot: cbar, left, bottom, width, height = colorbar_outside(ax, im, fig, add=0, shrink=0.6, width=width) # 0.02 if cbar_label: cbar.set_label(nonlin_title(' [' + add_nonlin_title), rotation=90, labelpad=lp) else: stack_spikes = [] eod_fr = [] return eod_fr, stack_spikes def set_xlabel_arrow(ax, xpos=1.05, ypos=-0.35, color='black', arrow = False): val = F1_xlabel() set_xlabel_arrow_core(ax, val, xpos, ypos, color=color) if arrow: ax.arrow_spines('b') def exclude_file_name_short(file_names): file_names2 = [] for file in file_names: if 'short' not in file: file_names2.append(file) return file_names2 def plt_mat_susept(ax, mat): plot = True if plot: pcolor = False if pcolor: im = ax.pcolormesh(np.array(mat.index), np.array(list(map(float, mat.columns))), mat, vmin=np.nanpercentile(mat, 5), vmax=np.nanpercentile(mat, 95), cmap='Greens', rasterized=True ) # rasterized = True else: im = ax.imshow(mat, origin='lower', extent=[float(np.min(mat.columns)), 
float(np.max(mat.columns)), float(np.min(mat.index)), float(np.max(mat.index))], vmin=np.nanpercentile(mat, 5), vmax=np.nanpercentile(mat, 95), cmap='viridis', ) # 'Greens'#vmin=np.percentile(np.abs(stack_plot), 5),vmax=np.percentile(np.abs(stack_plot), 95), return im, plot def get_mat_susept(stack_final): new_keys, stack_plot = convert_csv_str_to_float(stack_final) norm_d = False if norm_d: mat = RAM_norm_data(stack_final['d_isf1'].iloc[0], stack_plot, stack_final['snippets'].unique()[0]) else: mat = RAM_norm_data(stack_final['isf'].iloc[0], stack_plot, stack_final['snippets'].unique()[0], stack_here=stack_final) # return mat, new_keys def get_file_names_exclude(cell_type='p-unit'): if 'p-unit' in cell_type: # == ['p-unit', ' P-unit']: file_names_exclude = ['InputArr_350to400hz_30', 'InputArr_250to300hz_30', 'InputArr_150to200hz_30', 'InputArr_50to100hz_30', 'gwn25Hz10s0.3', 'InputArr_50hz_30', 'FileStimulus-file-gaussian50.0', 'gwn50Hz10.3', 'gwn50Hz10s0.3short', 'gwn50Hz50s0.3', 'FileStimulus-file-gaussian25.0', 'gwn50Hz10s0.3', ] # else: file_names_exclude = ['blwn125Hz10s0.3', 'gwn50Hz10s0.3', 'InputArr_350to400hz_30', 'InputArr_250to300hz_30', 'InputArr_150to200hz_30', 'InputArr_50to100hz_30', 'InputArr_50hz_30', 'FileStimulus-file-gaussian50.0', 'FileStimulus-file-gaussian25.0', 'gwn25Hz10s0.3', 'gwn50Hz10.3', 'gwn50Hz10s0.3short', 'gwn50Hz50s0.3', 'gwn25Hz10s0.3', ] # return file_names_exclude def get_cutoffs_nr(file_names): cut_off_nr = [] for ff, file_name in enumerate(file_names): if 'hz' in file_name.lower(): cut_off_nr = get_cut_off_for_wn(cut_off_nr, file_name) elif 'gaussian' in file_name: cut_off_nr.append(file_name.split('gaussian')[1]) else: cut_off_nr.append(file_name[-5::]) return cut_off_nr def find_eod(frame_cell, EOD='EOD', sp=0): if EOD in frame_cell: eods, hists, frs_calc, cont = load_spikes(frame_cell[EOD].iloc[0], frame_cell['EODf'].iloc[0]) try: eod = eods[sp] except: print('eod sp thing') embed() sampling_rate = 
frame_cell.sampling.iloc[0] ds = int(frame_cell.downsample.iloc[0]) time_eod = np.arange(0, len(eod) / sampling_rate, 1 / sampling_rate) # [::ds] if len(time_eod) > len(eod): time_eod = time_eod[0:len(eod)] elif len(time_eod) < len(eod): eod = eod[0:len(time_eod)] return eod, sampling_rate, ds, time_eod def plot_lin_nonlin(aa, add, amp, amps_defined, axds, axos, c, cells_plot, file_name, grid_s1, ims, load_name, stack_file, xlim=[], test_clim=False, power_type=False, permuted=False, peaks_extra=False, zorder=1, alpha=1, extra_input=False, fr=None, title_square='', fr_diag=None, nr=1, line_length=1 / 4, text_scalebar=False, xpos_xlabel=-0.2, add_nonlin_title=None, amp_give=True, color='grey', axo2=None, axd2=None, axi=None, eod_metrice=True, ax_square=None, transfer=True, base_extra=False, color_same=True, snippets = 20, iterate_var=[0, 1], normval=1): if not fr_diag: fr_diag = fr eod_fr, length, stack_final, stack_final1, trial_nr = stack_preprocessing(amp, stack_file, snippets = snippets) stack_osf = load_data_susept(load_name + '.pkl', load_name, load_version='csv', load_type='osf', trial_nr=trial_nr, stimulus_length=length, add=add, amp=amp, file_name=file_name) stack_isf = load_data_susept(load_name + '.pkl', load_name, load_version='csv', load_type='isf', trial_nr=trial_nr, stimulus_length=length, add=add, amp=amp, file_name=file_name) test_limits = False # hier bereinige ich von Duplicates add_nonlin_title, stack_plot = reduce_dubplicates(add_nonlin_title, stack_final) mat = RAM_norm_data(stack_final['isf'].iloc[0], stack_plot, stack_final['snippets'].unique()[0], stack_here=stack_final) # mat, add_nonlin_title, resize_val = rescale_colorbar_and_values(mat, add_nonlin_title=add_nonlin_title) axis_d = axis_projection(mat, axis='') if power_type: ################################## # das ist wenn wir das psd geplottet haben wollen axd, _, axo2 = plt_psds_all(axd2, axo2, mat, stack_final, stack_osf, test_limits, xlim, color=color, alpha=alpha, db='db', fr=fr, 
power_type=power_type, zorder=zorder, eod_fr=eod_fr, peaks_extra=peaks_extra) else: ################################### # das ist jetzt die DEFAULT version # plot the diagonal db_diag = 'db' # try: xmax, xmin, diagonals_prj_l = plt_diagonal(axd2, color, db_diag, fr_diag, mat, alpha, eod_fr, peaks_extra, xlim=xlim, zorder=zorder, normval=normval, color_same=color_same) prob = False except: print('diagonal of mat not working') diagonals_prj_l = [] prob = True if permuted: add_nonlin_title = plt_permuted_diagonal(add_nonlin_title, axd2, axis_d, db_diag, stack_final) # ################################### # plt transferfunction axd = axd2 if transfer: plt_transferfunction(alpha, axo2, color, stack_final, label=title_square, normval=normval, zorder=zorder) xmax_tf = 400 if normval != 1: axo2.set_xlim(xmin, xmax_tf / eod_fr) else: axo2.set_xlim(xmin, xmax / 2) axos.append(axo2) # np.max(mat.columns) axds.append(axd) # np.max(mat.columns) ######################################## # plot input if we need it # NOT DEFAULT if extra_input: axis_d = axis_projection(mat, axis='') xmax, xmin = get_xlim_psd(axis_d, xlim) plt_power_trace(alpha, axi, color, 'db', stack_final, stack_isf, test_limits, xmax, eod_fr=eod_fr) axi.set_xlim(xmin, xmax) ############################# # plot second-order susceptibility if not ax_square: ax_square = plt.subplot(grid_s1[:, 2 + aa]) if (aa == len(iterate_var) - 1) | test_clim: cbar_true = True else: cbar_true = False # embed() mat, test_limits, im, add_nonlin_title = plt_square_here(aa, amp, amps_defined, ax_square, c, cells_plot, ims, stack_final1, [], perc=False, cbar_true=cbar_true, xpos=0, ypos=1.05, color=color, fr=fr, base_extra=base_extra, eod_metrice=eod_metrice, nr=nr, amp_give=amp_give, title_square=title_square, line_length=line_length, ha='left', xpos_xlabel=xpos_xlabel, alpha=alpha, add_nonlin_title=add_nonlin_title) ims.append(im) if text_scalebar: if (aa == len(iterate_var) - 1) | test_clim: fig = plt.gcf() _, _, _, _, _ = 
colorbar_outside(ax_square, im, fig, add=5, width=0.01) ax_square.text(1.45, 0.25, nonlin_title(' [' + add_nonlin_title), ha='center', rotation=90, transform=ax_square.transAxes) if prob: print('prob something') embed() return diagonals_prj_l, axi, eod_fr, fr, stack_final1, axds, axos, ax_square, axo2, axd2, mat, add_nonlin_title def reduce_dubplicates(add_nonlin_title, stack_final): new_keys, stack_plot = convert_csv_str_to_float(stack_final) duplicate_mask = stack_final.duplicated(subset=new_keys) if duplicate_mask.any(): stack_final.drop_duplicates(subset=new_keys, inplace=True) new_keys, stack_plot = convert_csv_str_to_float(stack_final) test = False if test: from utils_test import test_dublicates1, test_dublicates2 add_nonlin_title = test_dublicates1(stack_final) test_dublicates2(stack_final) return add_nonlin_title, stack_plot def plt_permuted_diagonal(add_nonlin_title, axd2, axis_d, db_diag, stack_final): add_nonlin_title, isfs_all, isfs_correct, mats_all, mats_all_correct2 = get_fft_matrices(stack_final, add_nonlin_title) _, _ = get_mat_diagonals(mats_all_correct2) if db_diag == 'db': pass mats_all = np.array(mats_all) diags_permuted = [] isfs_all = np.array(isfs_all) for i in range(300): random_numbers = sample(range(1, len(mats_all)), 20) mean_matrix = np.sum(mats_all[random_numbers], axis=0) mean_matrix2 = norm_suscept_whole(abs, isfs_all[random_numbers], stack_final, mean_matrix, len(isfs_correct)) mean_matrix2, add_nonlin_title, resize_val = rescale_colorbar_and_values(mean_matrix2, add_nonlin_title=add_nonlin_title) diag, diagonals_prj_l_perm = get_mat_diagonals(np.array(mean_matrix2)) if db_diag == 'db': diagonals_prj_l_perm = 10 * np.log10(diagonals_prj_l_perm) diags_permuted.append(diagonals_prj_l_perm) diags_perm = np.transpose(diags_permuted) axd2.plot(axis_d, np.percentile(diags_perm, 95, axis=1), color='darkgrey') axd2.plot(axis_d, np.percentile(diags_perm, 5, axis=1), color='darkgrey') return add_nonlin_title def get_fft_matrices2(stack_final, 
add_nonlin_title=''): isfs = get_isfs(stack_final, isf_name='isf') osfs = get_isfs(stack_final, isf_name='osf') f_range = np.arange(len(stack_final)) mats_all_correct = [] isfs_correct = [] for t in range(len(osfs)): print('t' + str(t)) f_mat1, f_mat2, f_idx_sum, mat_all = fft_matrix(osfs[t], f_range, isfs[t], norm='') # stimulus, mats_all_correct.append(mat_all) isfs_correct.append(isfs[t]) ######################### # the corrected matrices mats_all_correct = np.array(mats_all_correct) mats_all_correct2 = np.sum(np.array(mats_all_correct), axis=0) mats_all_correct2 = norm_suscept_whole(abs, isfs_correct, stack_final, mats_all_correct2, len(isfs_correct)) mats_all_correct2, add_nonlin_title, resize_val = rescale_colorbar_and_values(mats_all_correct2, add_nonlin_title=add_nonlin_title) return add_nonlin_title, isfs_correct, mats_all_correct2 def get_fft_matrices(stack_final, add_nonlin_title=''): isfs = get_isfs(stack_final, isf_name='isf') osfs = get_isfs(stack_final, isf_name='osf') f_range = np.arange(len(stack_final)) mats_all = [] mats_all_correct = [] isfs_correct = [] isfs_all = [] for t in range(len(osfs)): for tt in range(len(osfs)): print('t' + str(t) + ' tt' + str(tt)) f_mat1, f_mat2, f_idx_sum, mat_all = fft_matrix(osfs[t], f_range, isfs[tt], norm='') # stimulus, if t != tt: mats_all.append(mat_all) isfs_all.append(isfs[tt]) else: mats_all_correct.append(mat_all) isfs_correct.append(isfs[tt]) ######################### # the corrected matrices mats_all_correct = np.array(mats_all_correct) mats_all_correct2 = np.sum(np.array(mats_all_correct), axis=0) mats_all_correct2 = norm_suscept_whole(abs, isfs_correct, stack_final, mats_all_correct2, len(isfs_correct)) mats_all_correct2, add_nonlin_title, resize_val = rescale_colorbar_and_values(mats_all_correct2, add_nonlin_title=add_nonlin_title) return add_nonlin_title, isfs_all, isfs_correct, mats_all, mats_all_correct2 def plt_psds_in_one_squares(aa, add, amp, amps_defined, axds, axes, axis, axos, c, cells_plot, 
colors_b, file_name, files, grid_s1, grid_s2, ims, load_name, stack_file, wss, xlim, axo2=None, axd2=None, iterate_var=[0, 1]): if aa == 0: try: axd2 = plt.subplot(grid_s2[1, 0]) # plt.subplot(grid_s[0]) axo2 = plt.subplot(grid_s2[0, 0]) except: print('grid thing3') embed() eod_fr, length, stack_final, stack_final1, trial_nr = stack_preprocessing(amp, stack_file) stack_osf = load_data_susept(load_name + '.pkl', load_name, load_version='csv', load_type='osf', trial_nr=trial_nr, stimulus_length=length, add=add, amp=amp, file_name=file_name) test_limits = False new_keys, stack_plot = convert_csv_str_to_float(stack_final) mat = RAM_norm_data(stack_final['isf'].iloc[0], stack_plot, stack_final['snippets'].unique()[0], stack_here=stack_final) # axd, axi, axo2 = plt_psds_all(axd2, axo2, mat, stack_final, stack_osf, test_limits, xlim, color='grey', db='db') grid_s = grid_s1 axd = None axo = None else: grid_s = grid_s2 axd = axd2 axo = axo2 ax_square, axi, eod_fr, fr, stack_final1, stack_spikes, im, axd, axo = plt_square_with_psds(aa, amp, amps_defined, axes, axis, c, cells_plot, files, grid_s, ims, load_name, stack_file, xlim, cbar_true=False, axd=axd, axo=axo, color= colors_b[aa], add=add, file_name=file_name) axos.append(axo) # np.max(mat.columns) axds.append(axd) # np.max(mat.columns) test_limits = False if test_limits: axo.set_ylabel('Output') axd.set_ylabel('Projection') else: remove_yticks(axo) remove_yticks(axd) if aa == 0: axo.text(-0.45, 0, 'Output', rotation=90, transform=axo.transAxes) axd.text(-0.45, 0, 'Projection', rotation=90, transform=axd.transAxes) axd.yscalebar(-0.1, 0.5, 10, 'dB', va='center', ha='left') axo.yscalebar(-0.1, 0.5, 10, 'dB', va='center', ha='left') axd.show_spines('b') axo.show_spines('b') ims.append(im) if aa == len(iterate_var) - 1: fig = plt.gcf() cbar, left, bottom, width, height = colorbar_outside(ax_square, im, fig, add=5, width=0.01) cbar.set_label(nonlin_title(), rotation=90, labelpad=10) return axi, eod_fr, fr, stack_final1, 
stack_spikes, axds, axos, ax_square, axo2, axd2 def nix_load(cell, stack_final1): data_dir = 'cells/' data_name = cell name_core = load_folder_name('data') + data_dir + data_name nix_name = name_core + '/' + data_name + '.nix' # '/' f = nix.File.open(nix_name, nix.FileMode.ReadOnly) b = f.blocks[0] try: names_mt_gwn = stack_final1['names_mt_gwn'].unique()[0] except: print('names mt') embed() mt = b.multi_tags[names_mt_gwn] features, id, data_between_2017_2018, mt_ids = find_feature_gwn_id(mt) dataset, rlx_problem = load_rlxnix(nix_name) # wir machen das hier für diese rlx only weil ich nur so an den Kontrast komme spikes_loaded = [] if rlx_problem: file_name, file_name_save, cut_off, file, sd = find_file_names(nix_name, mt, names_mt_gwn) file_extra, idx_c, base_properties, id_names = get_contrasts_over_rlx_calc_RAM(dataset) dataset.close() # contrasts_sort_idx = np.argsort(base_properties) try: base_properties = base_properties.sort_values(by='c', ascending=False) except: print('contrast problem sorting') embed() # hier muss ich nochmal nach dem file sortieren! if data_between_2017_2018 != 'all': file_name_sorted = base_properties[base_properties.file_name == file_name] else: file_name_sorted = base_properties if len(file_name_sorted) < 1: print('file_name problem') embed() file_name_sorted = file_name_sorted.sort_values(by='start', ascending=False)[::-1] # ich sollte auf dem level schon nach dem richtigen filename filtern! file_name_sorted = file_name_sorted[file_name_sorted['c_orig'] == stack_final1['c_orig'].unique()[0]] grouped = file_name_sorted.groupby('c') # ok es gibt wohl eine Zelle die erste, Zelle '2010-06-15-af' wo eben das nicht input arr heißt sondern gwn 300, was da passiert ist kann ich # euch jetzt so auch nicht sagen, aber alle anderen Zellen sehen gut aus! 
Scheint die einzige zu sein° data_array_names = get_data_array_names(b) # ,find_indices_to_match_contrats,get_data_array_names if 'eod' in ''.join(data_array_names).lower(): for g, group in enumerate(grouped): # hier erstmal alles nach dem Kontrast sortieren sd, start, end, rep, cut_off, c_len, c_unit, c_orig, c_len, files_load, cc, id_g, amplsel = open_group_gwn( group, file_name, cut_off, sd, data_between_2017_2018) indices, ends_mt = find_indices_to_match_contrats(grouped, group, mt, id_g, mt_ids, data_between_2017_2018) indices = list(map(int, indices)) max_f = cut_off if max_f == 0: print('max f = 0') embed() for mm, m in enumerate(indices): first, minus, second, stimulus_length = find_first_second(b, names_mt_gwn, m, mt, False, mm=mm, ends_mt=ends_mt) spikes_mt = link_arrays_spikes(b, first, second, minus) # spikes_loaded.append(spikes_mt * 1000) eod_mt, sampling = link_arrays_eod(b, first, second, array_name='LocalEOD-1') # hier noch das stimpresaved laden else: print('rlx thing') return eod_mt, sampling, spikes_loaded def burst_data(): plot_style() cells = p_units_to_show(type_here='burst_didactic') save_names = [version_final()] # amps_desired, cell_type_type, cells_plot, frame, cell_types = load_isis(save_names, amps_desired = amp_desired, cell_class = cell_class) cell_type_type = 'cell_type_reclassified' # frame = load_cv_base_frame(cells, cell_type_type=cell_type_type, redo=True) default_settings(column=2, width=12, length=8.5) # ts=10, fs=10, ls=10, frame, frame_spikes = load_cv_vals_susept(cells, EOD_type='synch', names_keep=['spikes', 'gwn', 'fs', 'EODf', 'cv', 'fr', 'width_75', 'vs', 'cv_burst_corr_individual', 'fr_burst_corr_individual', 'width_75_burst_corr_individual', 'vs_burst_corr_individual', 'cell_type_reclassified', 'cell'], path_spikes='/calc_base_data-base_frame_EOD1__overview.pkl', frame_general=False) frame = unify_cell_names(frame, cell_type=cell_type_type) frame_load = frame # [frame['cell'].isin(cells_exclude)] colors = 
colors_overview() tags_cell = [] grid = gridspec.GridSpec(len(cells), 1, wspace=0.1, hspace=0.21, top=0.97, left=0.105, bottom=0.085, right=0.9) for c, cell in enumerate(cells): print(cell) cell_type, eod_fr, fr, frs_calc, isi, spikes, spikes_all = get_base_params(cell, cell_type_type, frame) ims = [] tags = [] add_here = '_cell' + cell # str(c) for s, save_name in enumerate(save_names): load_name = load_folder_name('calc_RAM') + '/' + save_name + '_' + cell stack = load_data_susept(load_name + '.pkl', load_name, add=add_here, load_version='csv') try: grid_base_stim = gridspec.GridSpecFromSubplotSpec(2, 1, grid[c], height_ratios=[3, 5], hspace=0.3) except: print('cell thing3') embed() grid_base = gridspec.GridSpecFromSubplotSpec(2, 2, grid_base_stim[0], hspace=0.3) if len(stack) > 0: files, stack = exclude_cut_filenames(cell_type, stack, fexclude=True) file_name = files[0] stack_file = stack[stack['file_name'] == files[0]] amps = stack_file['amp'].unique() amps_defined = amps grid_stim = gridspec.GridSpecFromSubplotSpec(len(amps), 1, grid_base_stim[1], hspace=0.3) trues = [] for amp in amps_defined: if amp in amps: trues.append(True) cells_amp = ['2017-10-25-am-invivo-1', '2010-11-26-an-invivo-1'] if cell == cells_amp: print('cell thing') embed() ims = [] xlim_e = [0, 200] for aa, amp in enumerate(amps_defined): add_save = '_cell' + str(cell) + '_amp_' + str(amp) alpha_min = (1 - 0.2) / len(np.unique(stack_file['amp'])) # 25 if amp in np.array(stack_file['amp']): grid_stim_aa = gridspec.GridSpecFromSubplotSpec(2, 2, grid_stim[aa], height_ratios=[3, 5], hspace=0.3, width_ratios=[5, 2]) ax_square = plt.subplot(grid_stim_aa[:, -1]) eod_fr, length, stack_final, stack_final1, trial_nr = stack_preprocessing(amp, stack_file) _, _, _, _ = plt_square_here(aa, amp, amps_defined, ax_square, c, cells, ims, stack_final1, [], amp_give=False, cbar_true=False) tags.append(ax_square) ###################################################### spikes_base, isi, frs_calc, cont_spikes = 
load_spikes(spikes, eod_fr) axe = plt.subplot(grid_stim_aa[0, 0]) plt_stimulus(eod_fr, axe, stack_final1, xlim_e, file_name=files[0]) tags.insert(1, axe) ################################ # spikes ax_spikes = plt.subplot(grid_stim_aa[1, 0]) eod_fr, length, stack_final, stack_final1, trial_nr = stack_preprocessing(amp, stack_file) # todo: hier noch mehr trials laden stack_spikes = load_data_susept(load_name + '.pkl', load_name, add=add_save, load_version='csv', load_type='spikes', trial_nr=trial_nr, stimulus_length=length, amp=amp, file_name=file_name) test = False if test: pass if aa == 2: scale = True else: scale = False # embed() # todo: das mit dem hist will ich dohc noch haben plt_spikes(c, cells, colors[str(cell_type)], ax_spikes, stack_final1, stack_spikes, alpha=1 - alpha_min * aa, scale=scale) ax_spikes.text(1, 0.5, str(amp) + '$\%$', transform=ax_spikes.transAxes, ) set_clim_same(ims, clims='all', same='same') ################################ # isi if len(isi) > 0: ax_isi = plt.subplot(grid_base[:, 1]) plt_susept_isi_base(colors[str(cell_type)], ax_isi, isi) tags.insert(0, ax_isi) cell_type_type = 'cell_type_reclassified' frame_spikes_cell = frame_spikes[(frame_spikes['cell'] == cell)] eod, sampling_rate, ds, time_eod = find_eod(frame_spikes_cell, EOD='EOD') eod_period, zero_crossings, zero_crossings_eod2 = find_mean_period( eod, sampling_rate) nrs = 6 spikes_cut, eods_cut, times_cut = cut_spikes_to_certain_lenght_in_period(time_eod, ax_isi, eod, False, nrs, spikes_base[0], xlim_e, zero_crossings) axe = plt.subplot(grid_base[0, 0]) for nr in range(1): axe.plot(times_cut[nr], eods_cut[nr]) ax_isi = plt.subplot(grid_base[1, 0]) ax_isi.eventplot(spikes_cut) ############################### # stimulus tags_cell.append(tags) fig = plt.gcf() fig.tag(tags_cell, xoffs=-4.7, yoffs=1.9) # -1.5diese Offsets sind nicht intuitiv save_visualization() def plt_cellbody_eigen(grid1, frame, amps_desired, save_names, cells_plot, cell_type_type, ax3=[], xlim=[], 
titles=['Baseline \n Susceptibility', 'Half EODf \n Susceptibility'], peaks_extra=[False, False, False], base_extra=False): colors = colors_overview() axis = [] tags_cell = [] lengths = [0.5, 0.25] for c, cell in enumerate(cells_plot): print(cell) cell_type, eod_fr, fr, frs_calc, isi, spikes, spikes_all = get_base_params(cell, cell_type_type, frame) ims = [] tags = [] add_here = '_cell' + cell # str(c) mats = [] xlim_e = [0, 70] zorders = [100, 50] for s, save_name in enumerate(save_names): load_name = load_folder_name('calc_RAM') + '/' + save_name + '_' + cell stack = load_data_susept(load_name + '.pkl', load_name, add=add_here, load_version='csv', cells=cells_plot) axes = [] if len(stack) > 0: files, stack = exclude_cut_filenames(cell_type, stack, fexclude=True) file_name = files[0] stack_here = stack[stack.trial_nr > 1] stack_file = stack_here[stack_here['file_name'] == files[0]] amps = stack_file['amp'].unique() predefined_amp = True if predefined_amp: amps_defined = amps_desired else: amps_defined = amps trues = [] for amp in amps_defined: if amp in amps: trues.append(True) # ok das ist jetzt extra für die Bespiele ausgesucht amps_defined = [20] # [amps[nr_e[c]]] cells_amp = ['2017-10-25-am-invivo-1', '2010-11-26-an-invivo-1'] if cell == cells_amp: print('cell thing') embed() ws = 0.5 first = 1 wr_l = [first, 0, 1.3, 0, 1] # 1] wr_u = [1.15, 0.15, 1.5] # , 1] ws_total = np.sum(wr_u) + len(wr_u) * ws ws_total - 2 * ws - 0.15 - 1.5 grid_cell, grid_upper = grids_upper_susept_pics(c, grid1, ws=ws, hr=[1, 0.4], row=2, col=3, wr_u=wr_u) ws = 0.3 ims = [] axds = [] axos = [] extra_input = False several = False axd2, axi, axo2, grid_lower, grid_s1, grid_s2 = grids_for_psds2(amps_defined, extra_input, grid_cell, several, wr=wr_l, ws=ws, add=1) add_nonlin_title = None xpos_xlabel = -0.25 normval = 1 for aa, amp in enumerate(amps_defined): alpha = find_alpha_val(aa, amps_defined) add_save = '_cell' + str(cell) + '_amp_' + str(amp) wss = ws_for_susept_pic() colors_b = 
['grey', colors[cell_type]] right = False if amp in np.array(stack_file['amp']): print(zorders[aa]) if not several: diagonals_prj_l, axi, eod_fr, fr, stack_final1, axds, axos, ax_square, axo2, axd2, mat, add_nonlin_title = plot_lin_nonlin( aa, add_save, amp, amps_defined, axds, axos, c, cells_plot, file_name, grid_lower, ims, load_name, stack_file, xlim=[], peaks_extra=peaks_extra[c], zorder=zorders[aa], alpha=alpha, extra_input=extra_input, line_length=lengths[c], xpos_xlabel=xpos_xlabel, add_nonlin_title=add_nonlin_title, color=colors[cell_type], axo2=axo2, axd2=axd2, axi=axi, iterate_var=amps_defined, normval=normval) mats.append(mat) else: axi, eod_fr, fr, stack_final1, stack_spikes, axds, axos, ax_square, axo2, axd2 = plt_psds_in_one_squares( aa, add, amp, amps_defined, axds, axes, axis, axos, c, cells_plot, colors_b, file_name, files, grid_s1, grid_s2, ims, load_name, stack_file, wss, xlim, axo2=axo2, axd2=axd2, iterate_var=amps_defined) if aa == 0: if len(axi) < 1: tags.append(axo2) else: tags.append(axi) tags.append(ax_square) tags.append(axd2) ###################################################### spikes_base, isi, frs_calc, cont_spikes = load_spikes(spikes, eod_fr) ################################ # spikes ax_spikes = plt.subplot(grid_upper[1 + aa, -2::]) eod_fr, length, stack_final, stack_final1, trial_nr = stack_preprocessing(amp, stack_file) stack_spikes = load_data_susept(load_name + '.pkl', load_name, add=add_save, load_version='csv', load_type='spikes', trial_nr=trial_nr, stimulus_length=length, amp=amp, file_name=file_name) if aa == 0: scale = True else: scale = False plt_spikes(c, cells_plot, colors[str(cell_type)], ax_spikes, stack_final1, stack_spikes, alpha=alpha, xlim_e=xlim_e, sc=10, scale=scale) # 1 - alpha_min * aa amp_name = round_for_nice_float_strs(amp) ax_spikes.text(1.01, 0.55, str(amp_name) + '$\%$', va='center', transform=ax_spikes.transAxes, color=colors[str(cell_type)], alpha=alpha) labels_for_psds(axd2, axi, axo2, extra_input, 
xpos_xlabel=xpos_xlabel, chi_pos=-0.1, right=right) set_clim_same(ims, mats=mats, lim_type='up') # do the scatter of these cells add = ['', '_burst_corr_individual'] if len(stack) > 0: load_name = load_folder_name('calc_RAM') + '/' + save_names[s] + '_' + cell if ax3 != []: try: frame_g = base_to_stim(load_name, frame, cell_type_type, cell_type, stack=stack) except: print('stim problem') embed() try: ax3.scatter(frame_g['cv'], frame_g['cv_stim'], zorder=2, alpha=1, label=cell_type, s=15, color=colors[str(cell_type)], facecolor='white') except: print('scatter problem') embed() ################################ # isi if len(isi) > 0: if aa == len(amps_defined) - 1: grid_p = gridspec.GridSpecFromSubplotSpec(1, 2, grid_upper[:, 0], width_ratios=[1.4, 2], wspace=0.35, hspace=0.55) ax_isi = plt.subplot(grid_p[0]) ax_p = plt.subplot(grid_p[1]) # embed() ax_isi = base_cells_susept(ax_isi, ax_p, c, cell, cell_type, cells_plot, colors, eod_fr, frame, isi, right, spikes_base, stack, xlim, base_extra=base_extra, texts_left=[90, 0], titles=titles, peaks=True, xlim_i=[0, 4]) # ax_isi = base_cells_susept(ax_isi, ax_p, c, cell, cell_type, colors, frame, # isi, base_extra=base_extra, # titles=titles, xlim_i=[0, 4]) tags.insert(0, ax_isi) ############################### # stimulus if aa == len(amps_defined) - 1: axe = plt.subplot(grid_upper[0, -2::]) plt_stimulus(eod_fr, axe, stack_final1, xlim_e, file_name=files[0]) tags.insert(1, axe) tags_cell.append(tags) try: tags_susept_pictures(tags_cell, xoffs=np.array([-4.7, -3.2, -4.7, -4.3, -6.3]), yoffs=np.array([1.1, 1.1, 2, 2, 2])) # , xoffs=np.array([-5.2, -4.2, -5.2, -5.7, -4.7,-5]) except: print('tags here') embed() def plt_cellbody_singlecell(grid1, frame, amps_desired, save_names, cells_plot, cell_type_type, ax3=[], xlim=[], permuted=False, RAM=True, isi_delta=None, titles=['Low CV P-unit', 'High CV P-unit', 'Ampullary cell'], peaks_extra=[False, False, False], base_extra=False, color_same=True, fr_name='$f_{Base}$', 
eod_metrice=True, tags_individual=False, xlim_p=[0, 1.1], add_texts=[0.25, 0], scale_val=False): colors = colors_overview() tags_cell = [] for c, cell in enumerate(cells_plot): print(cell) cell_type, eod_fr, fr, frs_calc, isi, spikes, spikes_all = get_base_params(cell, cell_type_type, frame) ims = [] tags = [] add_here = '_cell' + cell # str(c) mats = [] zorders = [100, 50] for s, save_name in enumerate(save_names): load_name = load_folder_name('calc_RAM') + '/' + save_name + '_' + cell if cell == '2012-07-03-ak-invivo-1': snippets = 4 else: snippets = 20 stack = load_data_susept(load_name + '.pkl', load_name, add=add_here, load_version='csv', cells=cells_plot) if len(stack) > 0: files, stack = exclude_cut_filenames(cell_type, stack, fexclude=True) file_name = files[0] stack_file = stack[stack['file_name'] == files[0]] amps = stack_file['amp'].unique() predefined_amp = True if predefined_amp: amps_defined = amps_desired else: amps_defined = amps trues = [] for amp in amps_defined: if amp in amps: trues.append(True) amps_defined = [np.min(amps), np.max(amps)] cells_amp = ['2017-10-25-am-invivo-1', '2010-11-26-an-invivo-1'] if cell == cells_amp: print('cell thing') embed() wr_l = wr_l_cells_susept() wr_u = [1.4, 0.1, 1, 1] grid_cell, grid_upper = grids_upper_susept_pics(c, grid1, wr_u=wr_u) ims = [] axds = [] axos = [] extra_input = False several = False axd2, axi, axo2, grid_lower, grid_s1, grid_s2 = grids_for_psds(amps_defined, extra_input, grid_cell, several, wr=wr_l) power_type = False ax_psds = [] add_nonlin_title = None xpos_xlabel = -0.23 diag_vals = [] for aa, amp in enumerate(amps_defined): alpha = find_alpha_val(aa, amps_defined) add_save = '_cell' + str(cell) + '_amp_' + str(amp) right = 'middle' # , normval = 1 if amp in np.array(stack_file['amp']): print(zorders[aa]) diagonals_prj_l, axi, eod_fr, fr, stack_final1, axds, axos, ax_square, axo2, axd2, mat, add_nonlin_title = plot_lin_nonlin( aa, add_save, amp, amps_defined, axds, axos, c, cells_plot, 
file_name, grid_lower, ims, load_name, stack_file, xlim, power_type=power_type, permuted=permuted, peaks_extra=peaks_extra[c], zorder=zorders[aa], alpha=alpha, extra_input=extra_input, fr=fr, xpos_xlabel=xpos_xlabel, add_nonlin_title=add_nonlin_title, color=colors[cell_type], axo2=axo2, axd2=axd2, axi=axi, eod_metrice=eod_metrice, base_extra=base_extra, color_same=color_same, iterate_var=amps_defined, normval=normval, snippets = snippets) diag_vals.append(np.median(diagonals_prj_l)) mats.append(mat) if aa == 0: if len(axi) < 1: tags.append(axo2) else: tags.append(axi) tags.append(ax_square) if aa == 1: tags.append(axd2) spikes_base, isi, frs_calc, cont_spikes = load_spikes(spikes, eod_fr) ################################ # spikes ax_spikes = plt.subplot(grid_upper[1 + aa, -2::]) eod_fr, length, stack_final, stack_final1, trial_nr = stack_preprocessing(amp, stack_file) stack_spikes = load_data_susept(load_name + '.pkl', load_name, add=add_save, load_version='csv', load_type='spikes', trial_nr=trial_nr, stimulus_length=length, amp=amp, file_name=file_name) if (aa == 1) | (scale_val == True): scale = True else: scale = False plt_spikes(c, cells_plot, colors[str(cell_type)], ax_spikes, stack_final1, stack_spikes, alpha=alpha, scale=scale) # 1 - alpha_min * aa amp_name = round_for_nice_float_strs(amp) ax_spikes.text(1.01, 0.55, str(int(amp_name)) + '$\%$', va='center', transform=ax_spikes.transAxes, color=colors[str(cell_type)], alpha=alpha) ax_psds.extend([axd2]) ax_psds.extend([axo2]) axd2.annotate('', ha='center', xy=(1, diag_vals[1]), xytext=(1, diag_vals[0]), arrowprops={"arrowstyle": "<->", "linestyle": "-", "linewidth": 0.7, "color": 'black'}, zorder=1, annotation_clip=False) labels_for_psds(axd2, axi, axo2, extra_input, right=right, xpos_xlabel=xpos_xlabel, normval=normval) set_same_ylimscale(ax_psds) # todo: hier eventuell noch einen percent machen damit das nicht so vebrlendet set_clim_same(ims, mats=mats, lim_type='up', percnr=95) # do the scatter of these 
cells if len(stack) > 0: load_name = load_folder_name('calc_RAM') + '/' + save_names[s] + '_' + cell if ax3 != []: try: frame_g = base_to_stim(load_name, frame, cell_type_type, cell_type, stack=stack) except: print('stim problem') embed() try: ax3.scatter(frame_g['cv'], frame_g['cv_stim'], zorder=2, alpha=1, label=cell_type, s=15, color=colors[str(cell_type)], facecolor='white') except: print('scatter problem') embed() ################################ # isi if len(isi) > 0: if aa == len(amps_defined) - 1: grid_p = gridspec.GridSpecFromSubplotSpec(1, 2, grid_upper[:, 0], width_ratios=[1.4, 2], wspace=0.38, hspace=0.55) ax_isi = plt.subplot(grid_p[0]) ax_p = plt.subplot(grid_p[1]) ax_isi = base_cells_susept(ax_isi, ax_p, c, cell, cell_type, cells_plot, colors, eod_fr, frame, isi, right, spikes_base, stack, xlim_p, base_extra=base_extra, add_texts=add_texts, titles=titles, peaks=True, fr_name=fr_name) if isi_delta: ax_isi.set_xticks_delta(isi_delta) if tags_individual: tags.insert(0, ax_p) tags.insert(0, ax_isi) ############################### # stimulus xlim_e = [0, 100] if aa == len(amps_defined) - 1: axe = plt.subplot(grid_upper[0, -2::]) plt_stimulus(eod_fr, axe, stack_final1, xlim_e, RAM=RAM, file_name=files[0]) if tags_individual: tags.insert(2, axe) else: tags.insert(1, axe) tags_cell.append(tags) try: if len(cells_plot) == 1: if tags_individual: tags_susept_pictures(tags_cell[0], xoffs=np.array([-4.7, -3.2, -3.2, -4.7, -6.3, -2.7, -3.2]), yoffs=np.array([3, 3, 3, 5.5, 5.5, 5.5, 5.5])) else: tags_susept_pictures(tags_cell, yoffs=np.array([3, 3, 5.5, 5.5, 5.5, 5.5])) else: tags_susept_pictures(tags_cell) except: print('tags here') embed() def base_cells_susept(ax_isi, ax_p, c, cell, cell_type, cells_plot, colors, eod_fr, frame, isi, right, spikes_base, stack, xlim, texts_left=(0.25, 0), clip_on=True, add_texts=(0.25, 0), base_extra=False, titles=('', '', '', '', ''), pos=-0.25, peaks=False, fr_name='$f_{Base}$', xlim_i=(0, 16)): # ax_isi.text(-0.2, 0.5, 
'Baseline', rotation=90, ha='center', va='center', transform=ax_isi.transAxes) plt_susept_isi_base('grey', ax_isi, isi, xlim=xlim_i, clip_on=clip_on) # colors[str(cell_type)]c, cell_type, cells_plot, normval = 1 if normval != 1: ax_p.text(1.1, -0.4, f_eod_label_core(), ha='center', va='center', transform=ax_p.transAxes) # transform=ax_isi.transAxes, else: ax_p.text(1.1, -0.4, f_eod_label_core_hz(), ha='center', va='center', transform=ax_p.transAxes) # transform=ax_isi.transAxes, ax_p.arrow_spines('b') # embed() ax_p = plt_susept_psd_base('grey', eod_fr, ax_p, spikes_base, xlim, right=right, add_texts=add_texts, normval=normval, texts_left=texts_left, sampling_rate=stack.sampling.iloc[0], peaks=peaks, fr_name=fr_name) # colors[str(cell_type)] # cvs = True # embed() if cvs: cv = frame[frame.cell == cell].cv.iloc[ 0] # str(np.round(frame[frame.cell == cell].cv.iloc[0], 2))# color=colors[str(cell_type)], fr = frame[frame.cell == cell].fr.iloc[ 0] # str(np.round(frame[frame.cell == cell].cv.iloc[0], 2))# color=colors[str(cell_type)], if base_extra: if titles[c] == '': add_nrs = r'$\mathrm{f'+basename_small()+'}=%.0f$\,Hz,' % fr + r' $\mathrm{CV'+basename_small()+'}=%.2f$' % cv else: add_nrs = r'$\mathrm{f}'+basename_small()+'}=%.0f$\,Hz,' % fr + r' $\mathrm{CV'+basename_small()+'}=%.2f$' % cv ax_isi.text(pos, 1.25, titles[c] + add_nrs, transform=ax_isi.transAxes) # str(np.std(isi) / np.mean(isi)) else: ax_isi.text(pos, 1.25, titles[c] + r' $\rm{CV}=%.2f$' % cv, transform=ax_isi.transAxes) # str(np.std(isi) / np.mean(isi)) else: ax_isi.text(pos, 1.2, titles[c], color=colors[str(cell_type)], transform=ax_isi.transAxes) return ax_isi def f_eod_label_core(): return '$f/'+f_eod_name_core_rm()+'$' def f_eod_label_core_hz(): return '$f$ [Hz]' def wr_l_cells_susept(): wr_l = [0.5, 0, 1, 1, 0.2, 0.5] return wr_l def set_same_ylimscale(ax_psds): ranges = [] for ax in ax_psds: lim = ax.get_ylim() lim_range = lim[1] - lim[0] ranges.append(lim_range) new_lim = np.max(ranges) for ax 
in ax_psds: lim = ax.get_ylim() lim_range = lim[1] - lim[0] add_lim = (new_lim - lim_range) / 2 ax.set_ylim(lim[0] - add_lim, lim[1] + add_lim) def peaks_extra_fillbetween(axd2, eod_fr, fr, mats, normval=1): diags = [] if normval != 1: normval = eod_fr for mat in mats: diag, diagonals_prj_l = get_mat_diagonals(np.array(mat)) diags.extend(diagonals_prj_l) diagonals_prj_l = 10 * np.log10(diags) # / maxd axd2.fill_between([(fr - 5) / normval, (fr + 5) / normval], [np.min(diagonals_prj_l), np.min(diagonals_prj_l)], [np.max(diagonals_prj_l), np.max(diagonals_prj_l)], color='grey', alpha=0.5, zorder=0) def plt_susept_psd_base(colors, eod_fr, ax_p, spikes_base, xlim, add_texts=[0, 0], normval=1, texts_left=[0.22, 0], right='middle', fr_name='$f_{Base}$', sampling_rate=40000, peaks=False): spikes_mat, f_array, p_array = calc_psd_from_spikes(int(sampling_rate / 2), sampling_rate, spikes_base) pp = 10 * np.log10(np.mean(p_array, axis=0)) # [0] if normval != 1: normval = eod_fr if len(xlim) > 0: if normval == 1: xlim = [xlim[0], xlim[1] * eod_fr] ax_p.set_xlim(xlim) ax_p.plot(f_array / normval, pp, color=colors) # , alpha = float(alpha-0.05*s) ax_p.show_spines('b') if right == 'right': ax_p.yscalebar(1, 0.35, 20, 'dB', va='center', ha='right') ax_p.text(1.15, 0, 'Baseline', rotation=90, transform=ax_p.transAxes) elif right == 'left': ax_p.yscalebar(-0.03, 0.5, 20, 'dB', va='center', ha='left') ax_p.text(-0.23, 0.5, 'Baseline', va='center', rotation=90, transform=ax_p.transAxes) else: ax_p.yscalebar(1.05, 0.35, 20, 'dB', va='center', ha='right') if peaks: fr = 1 / np.mean(np.diff(np.array(spikes_base[0]) / 1000)) plt_peaks_several([fr / normval, eod_fr / normval], [pp], ax_p, pp, f_array / normval, [fr_name, f_eod_name_rm()], 0, ['grey', 'grey'], add_texts=add_texts, texts_left=texts_left, add_log=2.5, exact=False, text_extra=True, perc_peaksize=0.08, ms=14, clip_on=False, log='log') # True return ax_p def round_for_nice_float_strs(amp): if amp % 1 > 0: amp_name = 
np.round(amp, 1) else: amp_name = int(amp) return amp_name def ws_for_susept_pic(): wss = [0.4, 0.2] return wss def grids_upper_susept_pics(c, grid1, row=3, hr=[1, 0.4, 0.4], hs=0.65, ws=0.08, col=4, wr_u=[1, 2, 2]): try: grid_cell = gridspec.GridSpecFromSubplotSpec(2, 1, grid1[c], height_ratios=[3, 5], hspace=hs) # 0.15 except: print('cell thing3') grid_cell = [] embed() grid_upper = gridspec.GridSpecFromSubplotSpec(row, col, grid_cell[0], width_ratios=wr_u, hspace=0.0, wspace=ws, height_ratios=hr) # hspace=0.1,wspace=0.39, return grid_cell, grid_upper def tags_susept_pictures(tags_cell, xoffs=np.array([-4.7, -3.2, -4.7, -6.3, -2.7, -3.2]), yoffs=np.array([1.1, 1.1, 2, 2, 2, 2])): fig = plt.gcf() # ok das finde ich jetzt gut dass ich da eine Liste eingeben kann tag2(fig, tags_cell, xoffs=xoffs, yoffs=yoffs) # -1.5diese Offsets sind nicht intuitiv def grids_for_psds2(amps_defined, extra_input, grid_cell, several, ws=0.3, wss=[], add=0, widht_ratios=[], wr=[1, 0.2, 1, 1]): axd2 = [] axi = [] axo2 = [] grid_s1 = [] grid_s2 = [] if several: grid_lower = gridspec.GridSpecFromSubplotSpec(1, len(amps_defined) + 1, grid_cell[1], hspace=0.1, wspace=0.2, width_ratios=widht_ratios) grid_s1 = gridspec.GridSpecFromSubplotSpec(2, 2, grid_lower[0], hspace=0.1, wspace=wss[0], width_ratios=[0.8, 1]) # height_ratios=[1.5, 1.5, 5], # plot the same also to the next plot grid_s2 = gridspec.GridSpecFromSubplotSpec(2, 2, grid_lower[1], hspace=0.1, wspace=wss[1], width_ratios=[0.8, 1]) # height_ratios=[1.5, 1.5, 5], else: if extra_input: row_nrs = 3 else: row_nrs = 2 try: grid_lower = gridspec.GridSpecFromSubplotSpec(row_nrs, len(amps_defined) + 3 + add, grid_cell[1], hspace=0.1, wspace=ws, width_ratios=wr) # , width_ratios=widht_ratios except: print('grid bursts0') embed() if extra_input: axi = plt.subplot(grid_lower[0, 0]) axd2 = plt.subplot(grid_lower[2, 0]) # plt.subplot(grid_s[0]) axo2 = plt.subplot(grid_lower[1, 0]) else: axd2 = plt.subplot(grid_lower[:, -1]) # 
plt.subplot(grid_s[0]) axo2 = plt.subplot(grid_lower[:, 0]) axi = [] return axd2, axi, axo2, grid_lower, grid_s1, grid_s2 def grids_for_psds(amps_defined, extra_input, grid_cell, several, ws=0.25, wss=[], add=0, widht_ratios=[], wr=[1, 0.2, 1, 1]): axd2 = [] axi = [] axo2 = [] grid_s1 = [] grid_s2 = [] if several: grid_lower = gridspec.GridSpecFromSubplotSpec(1, len(amps_defined) + 1, grid_cell[1], hspace=0.1, wspace=0.2, width_ratios=widht_ratios) grid_s1 = gridspec.GridSpecFromSubplotSpec(2, 2, grid_lower[0], hspace=0.1, wspace=wss[0], width_ratios=[0.8, 1]) # height_ratios=[1.5, 1.5, 5], # plot the same also to the next plot grid_s2 = gridspec.GridSpecFromSubplotSpec(2, 2, grid_lower[1], hspace=0.1, wspace=wss[1], width_ratios=[0.8, 1]) # height_ratios=[1.5, 1.5, 5], else: if extra_input: row_nrs = 3 else: row_nrs = 2 try: grid_lower = gridspec.GridSpecFromSubplotSpec(row_nrs, len(amps_defined) + 4 + add, grid_cell[1], hspace=0.1, wspace=ws, width_ratios=wr) # , width_ratios=widht_ratios except: print('grid bursts') embed() if extra_input: axi = plt.subplot(grid_lower[0, 0]) axd2 = plt.subplot(grid_lower[2, 0]) # plt.subplot(grid_s[0]) axo2 = plt.subplot(grid_lower[1, 0]) else: axd2 = plt.subplot(grid_lower[:, -1]) # plt.subplot(grid_s[0]) axo2 = plt.subplot(grid_lower[:, 0]) axi = [] return axd2, axi, axo2, grid_lower, grid_s1, grid_s2 def labels_for_psds(axd2, axi, axo2, extra_input, right='middle', chi_pos=-0.3, normval=1, xpos_xlabel=-0.23, power_label='$|\chi_1|$', log_transfer=False): test_limits = False if test_limits: axo2.set_ylabel(power_label) axd2.set_ylabel('Projection') else: # if aa == 0: if right == 'right': remove_yticks(axo2) remove_yticks(axd2) axo2.text(1.15, 0.5, power_label, rotation=90, va='center', transform=axo2.transAxes) axd2.text(1.15, 0.5, 'Proj.', rotation=90, va='center', transform=axd2.transAxes) axd2.yscalebar(1, 0.5, 10, 'dB', va='center', ha='right') axo2.yscalebar(1, 0.5, 10, trasnfer_ylabel(), va='center', ha='right') 
axd2.show_spines('b') # /mV axo2.show_spines('b') if extra_input: axi.text(1.15, 0.5, 'Input', rotation=90, va='center', transform=axi.transAxes) axi.yscalebar(1, 0.5, 10, 'dB', va='center', ha='right') axi.show_spines('b') elif right == 'left': remove_yticks(axo2) remove_yticks(axd2) axo2.text(-0.23, 0.5, power_label, rotation=90, va='center', transform=axo2.transAxes) axd2.text(-0.23, 0.5, 'Proj.', rotation=90, va='center', transform=axd2.transAxes) axd2.text(-0.23, 0.5, 'Proj.', rotation=90, va='center', transform=axd2.transAxes) axd2.yscalebar(-0.03, 0.5, 10, 'dB', va='center', ha='left') axo2.yscalebar(-0.03, 0.5, 10, 'dB', va='center', ha='left') axd2.show_spines('b') axo2.show_spines('b') if extra_input: axi.text(-0.23, 0.5, 'Input', rotation=90, va='center', transform=axi.transAxes) axi.yscalebar(-0.03, 0.5, 10, 'dB', va='center', ha='left') axi.show_spines('b') else: axd_labels(axd2, chi_pos=chi_pos, normval=normval, xpos_xlabel=xpos_xlabel) if log_transfer == True: axo2.text(-0.13, 0.5, power_label, rotation=90, va='center', transform=axo2.transAxes) axo2.yscalebar(1, 0.5, 10, trasnfer_ylabel(), va='center', ha='right') axo2.show_spines('b') remove_yticks(axo2) else: axo2.set_ylabel(trasnfer_ylabel()) axo2.show_spines('lb') if normval != 1: axo2.text(1.05, xpos_xlabel, tranfer_xlabel(), ha='center', va='center', transform=axo2.transAxes) else: axo2.text(1.05, xpos_xlabel, tranfer_xlabel_hz(), ha='center', va='center', transform=axo2.transAxes) axo2.arrow_spines('b') if extra_input: axi.text(1.15, 0.5, 'Input', rotation=90, va='center', transform=axi.transAxes) axi.yscalebar(1, 0.5, 10, 'dB', va='center', ha='right') axi.show_spines('b') def axd_labels(axd2, chi_pos=-0.3, normval=1, xpos_xlabel=-0.23): # chi_pos = -0.3,normval = 1, xpos_xlabel = -0.23, axd2.text(chi_pos, 0.5, ylabel_projected(), rotation=90, va='center', transform=axd2.transAxes) axd2.yscalebar(1.1, 0.5, 10, 'dB', va='center', ha='right') axd2.show_spines('b') # -0.23 if normval != 1: 
axd2.text(1.05, xpos_xlabel, diagonal_xlabel_nothz(), ha='center', va='center', transform=axd2.transAxes) else: axd2.text(1.05, xpos_xlabel, diagonal_xlabel(), ha='center', va='center', transform=axd2.transAxes) axd2.arrow_spines('b') remove_yticks(axd2) def ylabel_projected(): return r'$|\bar{\chi_2}|$' def trasnfer_ylabel(): return '$|\mathcal{\chi}_{1}|\,$[Hz]' # '$|\chi_{1}|$'#r'$|\mathcal{X}_{1}|$\,[Hz]'#' ' def get_base_params(cell, cell_type_type, frame): try: frame_cell = frame[(frame['cell'] == cell)] except: print('frame thing') embed() frame_cell = unify_cell_names(frame_cell, cell_type=cell_type_type) try: cell_type = frame_cell[cell_type_type].iloc[0] except: print('cell type prob') embed() spikes = frame_cell.spikes.iloc[0] spikes_all = [] isi = [] frs_calc = [] fr = frame_cell.fr.iloc[0] eod_fr = frame_cell.EODf.iloc[0] return cell_type, eod_fr, fr, frs_calc, isi, spikes, spikes_all def plt_susept_isi_base(color, ax_isi, isi, delta=None, ypos=-0.4, xlim=[0, 13], clip_on=False): ax_isi.show_spines('b') # xmin=0, xmax=xmax[c_here], alpha=0.5, step=0 kernel_histogram(ax_isi, color, isi[0], extend=False, clip_on=clip_on, step=0.01) if len(xlim) > 0: ax_isi.set_xlim(xlim) ax_isi.text(1.1, ypos, isi_label_core(), transform=ax_isi.transAxes, ha='center', va='center', ) # ($I_{spikes}/I_{EODf}$) else: ax_isi.text(1.1, ypos, isi_label_core(), transform=ax_isi.transAxes, ha='center', va='center', ) # ($I_{spikes}/I_{EODf}$) if delta: ax_isi.set_xticks_delta(delta) ax_isi.arrow_spines('b') remove_yticks(ax_isi) def isi_label_core(): return '$1/'+f_eod_name_core_rm()+'$' def kernel_histogram(ax_isi, color, isi, norm='no', clip_on=False, step=0.1, label='', orientation='horizontal', alpha=1, xmin='no', xmin_perc=False, perc_min=0.01, xmax='no', height_val=1, extend=True): if len(isi) > 1: isi = np.array(list(map(float, np.array(isi)))) try: isi = isi[~np.isnan(isi)] isi = isi[~np.isinf(isi)] except: print('any problem') embed() try: if step == 0: kernel = 
gaussian_kde(isi) else: kernel = gaussian_kde(isi, step / np.std(isi, ddof=1)) except: print('kernel thing') embed() isi_sorted = np.sort(isi) if xmin == 'no': if xmin_perc: # das mit dem percentile ist keine gute idee weil die verteilung kann ja im mittel durchaus hohe werte haben xmin = np.min(isi_sorted) - np.percentile(isi_sorted, perc_min) else: xmin = np.min(isi_sorted) * 0.8 if xmax == 'no': if xmin_perc: xmax = np.max(isi_sorted) + np.percentile(isi_sorted, perc_min) else: xmax = np.max(isi_sorted) * 1.1 # create points between the min and max try: if extend: x = np.linspace(xmin, xmax, 1000) else: x = isi_sorted except: print('extend thing') embed() kde = kernel(x) if norm == 'density': kde = kde / np.sum(kde) elif norm == 'maximum': kde = height_val * kde / np.max(kde) if orientation == 'horizontal': # isi_sorted ax_isi.plot(x, kde, color=color, label=label, alpha=alpha, clip_on=clip_on) # filllbetween ax_isi.fill_between(x, kde, color=color, alpha=alpha, clip_on=clip_on) # filllbetween else: ax_isi.plot(kde, x, color=color, label=label, alpha=alpha, clip_on=clip_on) # ,clip_on = Falsefilllbetween ax_isi.fill_betweenx(np.sort(x), kde[np.argsort(x)], color=color, alpha=alpha, clip_on=clip_on) # filllbetween,clip_on = False test = False if test: from utils_test import test_isi test_isi() def plt_square_with_psds(aa, amp, amps_defined, axes, axis, c, cells_plot, files, grid_s, ims, load_name, stack_file, xlim, cbar_true=True, axd=None, axo=None, square_plot=True, color='black', add='', file_name=None): eod_fr, length, stack_final, stack_final1, trial_nr = stack_preprocessing(amp, stack_file) if not file_name: file_name = files[0] stack_osf = load_data_susept(load_name + '.pkl', load_name, load_version='csv', load_type='osf', trial_nr=trial_nr, stimulus_length=length, add=add, amp=amp, file_name=file_name) stack_spikes = load_data_susept(load_name + '.pkl', load_name, add=add, load_version='csv', load_type='spikes', trial_nr=trial_nr, stimulus_length=length, 
amp=amp, file_name=file_name) extra = False if extra: try: pass except: pass ############################# # load both amps for common amp db limit _, _, _ = get_max_several_amp_squares(add, amp, amps_defined, files, load_name, stack_file) if aa == 0: pass else: pass try: ax_square = plt.subplot(grid_s[:, 1]) except: print('grid problem4') embed() if square_plot: mat, test_limits, im, add_nonlin_title = plt_square_here(aa, amp, amps_defined, ax_square, c, cells_plot, ims, stack_final1, [], cbar_true=cbar_true) ############################################ # psd part fr = stack_final1.fr.unique()[0] if not axd: try: axd = plt.subplot(grid_s[1, 0]) except: print('axd thing') embed() if not axo: axo = plt.subplot(grid_s[0, 0]) axd, axi, axo = plt_psds_all(axd, axo, mat, stack_final, stack_osf, test_limits, xlim, color=color, db='db') axes.append(axi) # np.min(mat.columns) axes.append(axo) # np.max(mat.columns) axis.append(axi) return ax_square, axi, eod_fr, fr, stack_final1, stack_spikes, im, axd, axo def stack_preprocessing(amp, stack_file, snippets=20): stack_amp2 = stack_file[stack_file['snippets'] == snippets] if len(stack_amp2) < 1: stack_amp2 = stack_file[stack_file['snippets'] == 20] if len(stack_amp2) < 1: stack_amp2 = stack_file[stack_file['snippets'] == 10] if len(stack_amp2) < 1: stack_amp2 = stack_file[stack_file['snippets'] == 9] if len(stack_amp2) < 1: stack_amp2 = stack_file[stack_file['snippets'] == 4] stack_amp = stack_amp2[stack_amp2['amp'] == amp] lengths = stack_file['stimulus_length'].unique() length = np.max(lengths) stack_final = stack_amp[stack_amp['stimulus_length'] == length] #if len(stack_final) <0: # todo: hier 20 Trials auch einbauen trial_nr_double = stack_final.trial_nr.unique() try: eod_fr = stack_final.eod_fr.iloc[0] except: print('trial thing') embed() # ok das ist glaube ich ein Anzeichen von einem Fehler if len(trial_nr_double) > 1: print('trial_nr_double1') embed() # ich hatte mal einen save fehler deswegen habe ich manche doppelt 
also einfach das letzte nehmen nehme ich an try: trial_nr = np.max(trial_nr_double) except: print('trial something') embed() try: stack_final1 = stack_final[stack_final.trial_nr == trial_nr] except: print('stack_final1 problem') embed() return eod_fr, length, stack_final, stack_final1, trial_nr def get_max_several_amp_squares(add, amp, amps_defined, files, load_name, stack_file): maxs_i = [] maxs_o = [] maxs_d = [] for amp_here in amps_defined: stack_amp = stack_file[stack_file['amp'] == amp] lengths = stack_file['stimulus_length'].unique() length = np.max(lengths) stack_final = stack_amp[stack_amp['stimulus_length'] == length] trial_nr_double = stack_final.trial_nr.unique() new_keys, stack_plot = convert_csv_str_to_float(stack_final) norm_d = False # todo: das insowas wie ein übergeordnetes Dict machen if norm_d: mat = RAM_norm_data(stack_final['d_isf1'].iloc[0], stack_plot, stack_final['snippets'].unique()[0]) else: mat = RAM_norm_data(stack_final['isf'].iloc[0], stack_plot, stack_final['snippets'].unique()[0], stack_here=stack_final) # diag, diagonals_prj_l = get_mat_diagonals(np.array(mat)) maxs_d.append(np.max(diagonals_prj_l)) trial_nr = np.max(trial_nr_double) isf = load_data_susept(load_name + '.pkl', load_name, load_version='csv', load_type='isf', trial_nr=trial_nr, stimulus_length=length, add=add, amp=amp_here, file_name=files[0]) f = find_f(stack_final) f_max = stack_final.index[-1] * 2 f_restict = f[f < f_max] maxs_i.append(np.max(isf.iloc[0:len(f_restict)]) ** 2) osf = load_data_susept(load_name + '.pkl', load_name, load_version='csv', load_type='osf', add=add, trial_nr=trial_nr, stimulus_length=length, amp=amp_here, file_name=files[0]) f_max = stack_final.index[-1] * 2 f_restict = f[f < f_max] maxs_o.append(np.max(osf.iloc[0:len(f_restict)]) ** 2) maxo = np.max(maxs_o) maxi = np.max(maxs_i) maxd = np.max(maxs_d) return maxo, maxi, maxd def cv_stim_name_rm(): return '\mathrm{CV_{stim}}' def plt_square_here(aa, amp, amps_defined, ax_square, c, 
cells_plot, ims, stack_final1, xlim, line_length=1 / 4, alpha=1, cbar_true=True, perc=True, amp_give=True, base_extra=False, eod_metrice=True, ha='right', nr=3, fr=None, title_square='', xpos_xlabel=-0.2, ypos=0.05, xpos=0.1, color='white', add_nonlin_title=None): if aa == len(amps_defined) - 1: cbar_do = False else: cbar_do = False print(add_nonlin_title) cbar, mat, im, add_nonlin_title = plot_square_core([ax_square], stack_final1, eod_metrice=eod_metrice, fr=fr, nr=nr, line_length=line_length, cbar_do=cbar_do, perc=perc, add_nonlin_title=add_nonlin_title) ims.append(im) if xlim: ax_square.set_xlim(xlim) ax_square.set_ylim(xlim) else: ax_square.set_xlim(0, np.max(mat.columns)) ax_square.set_ylim(0, np.max(mat.columns)) ax_square.set_title('') test_limits = False if test_limits: fig = plt.gcf() _, _, _, _, _ = colorbar_outside(ax_square, im, fig, add=5, width=0.01) ax_square.text(1.05, 0.5, nonlin_title(), ha='center', rotation=90) else: if cbar_true: fig = plt.gcf() cbar, left, bottom, width, height = colorbar_outside(ax_square, im, fig, add=5, width=0.007) cbar.set_label(nonlin_title(' [' + add_nonlin_title), rotation=90, labelpad=3) amp_name = round_for_nice_float_strs(amp) cvs = True if amp_give: amp_val = '$c=%s$' % (int(amp_name)) + '$\,\%$, ' else: amp_val = '' if cvs: # $\mathcal{X}_{2} if base_extra: ax_square.text(xpos, ypos, title_square + amp_val + r'$'+cv_stim_name_rm()+'=%.2f$' % (stack_final1.cv_stim.iloc[0]), ha=ha, transform=ax_square.transAxes, color=color, alpha=alpha) # (np.round(stack_final1.cv_stim.iloc[0], 2))'white' files[0] + ' l ' + str(length)chi_name()+ else: ax_square.text(xpos, ypos, title_square + amp_val + r'$\rm{CV}=%.2f$' % (stack_final1.cv_stim.iloc[0]), ha=ha, transform=ax_square.transAxes, color=color, alpha=alpha) # (np.round(stack_final1.cv_stim.iloc[0], 2))'white' files[0] + ' l ' + str(length)chi_name()+ else: ax_square.text(xpos, ypos, title_square + chi_name() + '$($' + str(int(amp_name)) + '$\%$)', ha=ha, 
transform=ax_square.transAxes, color=color, alpha=alpha) # 'white' files[0] + ' l ' + str(length) ax_square.set_xticks_delta(100) ax_square.set_yticks_delta(100) if c != len(cells_plot) - 1: ax_square.set_xlabel('') else: ax_square.set_xlabel(F1_xlabel()) ax_square.text(1.05, xpos_xlabel, F1_xlabel(), ha='center', va='center', transform=ax_square.transAxes) if aa != 0: remove_yticks(ax_square) ax_square.set_ylabel('') else: set_ylabel_arrow(ax_square) ax_square.arrow_spines('lb') ax_square.set_xlabel('') ax_square.set_ylabel('') return mat, test_limits, im, add_nonlin_title def set_ylabel_arrow(ax_square, xpos=-0.15, ypos=0.97, color='black', arrow = False): ax_square.text(xpos, ypos, F2_xlabel(), ha='center', va='center', transform=ax_square.transAxes, rotation=90, color=color) if arrow: ax_square.arrow_spines('l') def F2_xlabel(): return '$f_{2}$\,[Hz]' def nonlin_title(add_nonlin_title=''): return chi_name() + add_nonlin_title + 'Hz]' # \frac{Hz}{mV^2} def chi_name(): return r'$|\chi_{2}|$' # r'$|\mathcal{X}_{2}|$'# [' def plt_psds_all(axd, axo, mat, stack_final, stack_osf, test_limits, xlim, alpha=1, color='black', power_type=False, db='', fr=None, peaks_extra=False, zorder=1, eod_fr=1): ############################### # projection diagonal xmax, xmin, diagonals_prj_l = plt_diagonal(axd, color, db, fr, mat, alpha, eod_fr, peaks_extra, xlim, zorder) if power_type: plt_power_trace(alpha, axo, color, db, stack_final, stack_osf, test_limits, xmax, eod_fr=eod_fr, zorder=zorder) else: plt_transferfunction(alpha, axo, color, stack_final, zorder=zorder) axo.set_xlim(xmin, xmax) return axd, axo, axo def plt_diagonal(axd, color, db, fr, mat, alpha=1, eod_fr=750, peaks_extra=True, label='', xlim=[], zorder=1, normval=1, color_same=True): diag, diagonals_prj_l = get_mat_diagonals(np.array(mat)) axis_d = axis_projection(mat, axis='') if normval != 1: normval = eod_fr if db == 'db': diagonals_prj_l = 10 * np.log10(diagonals_prj_l) # / maxd axd.plot(axis_d / normval, 
diagonals_prj_l, color=color, alpha=alpha - 0.05, zorder=zorder, label=label) if peaks_extra: if not color_same: color = 'black' color_dot = 'black' alpha_dot = [alpha] else: color_dot = 'grey' alpha_dot = [1] axd.axhline(np.median(diagonals_prj_l), linewidth=0.9, linestyle='--', color=color, alpha=alpha, zorder=zorder + 1) # 0.45#0.75 plt_peaks_several([fr / normval], [diagonals_prj_l], axd, diagonals_prj_l, axis_d / normval, [''], 0, [color_dot], zorder=zorder + 1, alphas=alpha_dot, ms=5) xmax, xmin = get_xlim_psd(axis_d / normval, xlim) print(xmax) axd.set_xlim(xmin, xmax) axd.arrow_spines('b') return xmax, xmin, diagonals_prj_l def get_xlim_psd(axis_d, xlim): if xlim: # [0] xmin = xlim[0] xmax = xlim[1] else: xmin = 0 # axis_d[0] - 1 # mat.columns[0] xmax = axis_d[-1] # mat.columns[-1] return xmax, xmin def plt_transferfunction(alpha, axo, color, stack_final, zorder=1, label='', normval=1, log_transfer=False): f_axis, vals = get_transferfunction(stack_final) if log_transfer: means_all = 10 * np.log10(vals) else: means_all = 10 * np.log10(vals) max_lim = calc_cut_offs(stack_final.file_name.iloc[0]) / normval axis = f_axis / normval if normval != 1: pass if max_lim: axo.plot(axis[axis < max_lim], means_all[axis < max_lim], color=color, zorder=zorder, alpha=alpha, label=label) else: axo.plot(axis, means_all, color=color, alpha=alpha, zorder=zorder, label=label) def get_transferfunction(stack_final): osf = stack_final['osf'] isf = stack_final['isf'] f = find_f(stack_final) f_axis = f[0:len(isf.iloc[0][0])] # csd pds berechnung counter = 0 for t in range(len(osf)): if type(osf.iloc[t]) == list: if t == 0: vals = osf.iloc[t][0] * np.conj(isf.iloc[t][0]) powers = np.abs(isf.iloc[t][0]) ** 2 else: vals += osf.iloc[t][0] * np.conj(isf.iloc[t][0]) powers += np.abs(isf.iloc[t][0]) ** 2 counter += 1 vals = vals / counter vals = np.abs(vals) / (powers / counter) return f_axis, vals def plt_power_trace(alpha, axo, color, db, stack_final, stack_osf, test_limits, xmax, 
eod_fr=1, zorder=1): if len(stack_osf) == 0: isf = stack_final['osf'] isf_resaved = False else: isf = stack_osf isf_resaved = True f = find_f(stack_final) power = 1 if isf_resaved: f_axis = f[0:len(isf)] means = np.transpose(isf) means_all = np.mean(np.abs(means) ** power, axis=0) else: f_axis = f[0:len(isf.iloc[l][0])] means = get_array_from_pandas(isf) means_all = np.mean(np.abs(means) ** power, axis=0) if db == 'db': means_all = 10 * np.log10(means_all) max_lim = xmax axis = f_axis / eod_fr if max_lim: axo.plot(axis[axis < max_lim], means_all[axis < max_lim], color=color, zorder=zorder, alpha=alpha) else: axo.plot(axis, means_all, color=color, alpha=alpha, zorder=zorder) if not test_limits: remove_xticks(axo) remove_xticks(axo) def plt_spikes(c, cells_plot, color, ax_spikes, stack_final1, stack_spikes, alpha=1, xlim_e=[0, 200], sc=20, scale=True, spikes_max=5): spikes_here = create_spikes(stack_final1, stack_spikes) if len(spikes_here) > 0: if len(spikes_here) > spikes_max: spikes_here = spikes_here[0:spikes_max] ax_spikes.eventplot(spikes_here, color=color, alpha=alpha) ax_spikes.set_xlim(xlim_e) # spikes_both[gg] if c != len(cells_plot) - 1: remove_xticks(ax_spikes) ax_spikes.show_spines('') if scale: ax_spikes.xscalebar(float(1 - sc / xlim_e[-1]), -0.03, sc, 'ms', va='left', ha='bottom') # def plt_stimulus(eod_fr, axe, stack_final1, xlim_e, RAM=True, file_name=None, alpha = 0.5, add=0.07): axe.show_spines('') neuronal_delay = 5 # das hatte Jan G. 
angemerkt, dass wir den Stimulus um den neuronal Delay kompensieren sollten max_here = (xlim_e[1] + neuronal_delay) / 1000 eod_interp, sampling_interp, time_eod_interp = get_stimulus_here(file_name, stack_final1, sampling=40000, max=max_here) fake_fish = fakefish.wavefish_eods('Alepto', frequency=eod_fr, samplerate=sampling_interp, duration=len(time_eod_interp) / sampling_interp, phase0=0.0, noise_std=0.00) size_fake_am = 0.5 if RAM: try: axe.plot(time_eod_interp * 1000 - neuronal_delay, fake_fish * (1 + eod_interp * size_fake_am), color='lightgrey', alpha=alpha, clip_on=True) except: print('axe thing') embed() axe.plot(time_eod_interp * 1000 - neuronal_delay, eod_interp * size_fake_am + 1 + add, color='red', linewidth=1) else: try: axe.plot(time_eod_interp * 1000 - neuronal_delay, fake_fish + eod_interp * size_fake_am, color='grey', alpha=0.5, clip_on=True) except: print('axe thing') embed() axe.plot(time_eod_interp * 1000 - neuronal_delay, eod_interp * size_fake_am + 1 + add, color='red', linewidth=1) axe.set_xlim(xlim_e) ylim_e = axe.get_ylim() ylim_e = np.array(ylim_e) * 1.05 axe.set_ylim(ylim_e) remove_xticks(axe) def get_stimulus_here(file_name, stack_final1, max=0.4, sampling=None): if not sampling: sampling = stack_final1.sampling.iloc[0] time_eod = np.arange(0, max, 1 / sampling) if not file_name: try: eod_interp, time_wn_cut, _ = load_noise(stack_final1.file_name.iloc[0]) except: try: eod_interp, time_wn_cut, _ = load_noise(stack_final1.file_name2.iloc[0]) except: eod_interp, time_wn_cut, _ = load_noise(stack_final1.file_name.iloc[0] + 's') print('open problem thing2') else: try: eod_interp, time_wn_cut, _ = load_noise(file_name) except: eod_interp, time_wn_cut, _ = load_noise(file_name + 's') eod_interp = interpolate(time_wn_cut, eod_interp, time_eod, kind='cubic') return eod_interp, sampling, time_eod def same_lims_susept(axds, axis, axos, ims): set_clim_same(ims, clims='all', same='same', lim_type='up') set_same_ylim(axos) set_same_ylim(axis) 
set_same_ylim(axds) def create_spikes(stack_final1, stack_spikes=[]): spikes_here = [] if len(stack_spikes) > 0: # type(stack_final1.spikes.iloc[0]) == str if type(stack_spikes) == list: for sp in range(len(stack_spikes)): try: spi = stack_spikes[sp].dropna() except: spi = stack_spikes[sp] spikes_here.append(np.array(spi) * 1000) else: for sp in range(np.shape(stack_spikes)[1]): try: spi = stack_spikes[sp].dropna() except: spi = stack_spikes[sp] spikes_here.append(np.array(spi) * 1000) else: for sp in stack_final1.spikes.iloc[0][0]: try: spikes_here.append(np.array(sp) * 1000) except: print('spike thing') embed() return spikes_here def load_mt_data(axi, axs, c, cell, cells_plot, colors_hist, gg, grid_upper, stack_final1, xlim): data_dir = 'cells/' data_name = cell name_core = load_folder_name('data') + data_dir + data_name nix_name = name_core + '/' + data_name + '.nix' # '/' f = nix.File.open(nix_name, nix.FileMode.ReadOnly) b = f.blocks[0] names_mt_gwn = stack_final1['names_mt_gwn'].unique()[0] try: mt = b.multi_tags[names_mt_gwn] except: names_mt_gwns = find_names_gwn(b) mt = b.multi_tags[names_mt_gwns[0]] print('mt thing') embed() features, id, data_between_2017_2018, mt_ids = find_feature_gwn_id(mt) dataset, rlx_problem = load_rlxnix(nix_name) # wir machen das hier für diese rlx only weil ich nur so an den Kontrast komme spikes_loaded = [] if rlx_problem: file_name, file_name_save, cut_off, file, sd = find_file_names(nix_name, mt, names_mt_gwn) file_extra, idx_c, base_properties, id_names = get_contrasts_over_rlx_calc_RAM(dataset) dataset.close() try: base_properties = base_properties.sort_values(by='c', ascending=False) except: print('contrast problem sorting') embed() if data_between_2017_2018 != 'all': file_name_sorted = base_properties[base_properties.file_name == file_name] else: file_name_sorted = base_properties if len(file_name_sorted) < 1: print('file_name problem') embed() file_name_sorted = file_name_sorted.sort_values(by='start', 
ascending=False)[::-1] # ich sollte auf dem level schon nach dem richtigen filename filtern! file_name_sorted = file_name_sorted[file_name_sorted['c_orig'] == stack_final1['c_orig'].unique()[0]] grouped = file_name_sorted.groupby('c') # ok es gibt wohl eine Zelle die erste, Zelle '2010-06-15-af' wo eben das nicht input arr heißt sondern gwn 300, was da passiert ist kann ich # euch jetzt so auch nicht sagen, aber alle anderen Zellen sehen gut aus! Scheint die einzige zu sein° data_array_names = get_data_array_names(b) # ,find_indices_to_match_contrats,get_data_array_names if 'eod' in ''.join(data_array_names).lower(): for g, group in enumerate(grouped): # hier erstmal alles nach dem Kontrast sortieren sd, start, end, rep, cut_off, c_len, c_unit, c_orig, c_len, files_load, cc, id_g, amplsel = open_group_gwn( group, file_name, cut_off, sd, data_between_2017_2018) indices, ends_mt = find_indices_to_match_contrats(grouped, group, mt, id_g, mt_ids, data_between_2017_2018) indices = list(map(int, indices)) max_f = cut_off if max_f == 0: print('max f = 0') embed() for mm, m in enumerate(indices): first, minus, second, stimulus_length = find_first_second(b, names_mt_gwn, m, mt, False, mm=mm, ends_mt=ends_mt) spikes_mt = link_arrays_spikes(b, first, second, minus) # spikes_loaded.append(spikes_mt * 1000) eod_mt, sampling = link_arrays_eod(b, first, second, array_name='LocalEOD-1') # hier noch das stimpresaved laden # todo: das eventuell noch anpassen axi.set_xlim(0, 13) xlim_e = [0, 200] axs = plt.subplot(grid_upper[1, 1::]) axs.eventplot(spikes_loaded, color=colors_hist[gg], ) axs.set_xlim(xlim) # spikes_both[gg] if c != len(cells_plot) - 1: remove_xticks(axi) remove_xticks(axs) else: axi.set_xlabel('ISI') # ($I_{spikes}/I_{EODf}$) axs.set_xlabel('Time [ms]') axi.set_ylabel('Nr') axs.set_ylabel('Nr') else: print('rlx thing') return axs, eod_mt, sampling, xlim_e def exclude_cut_filenames(cell_type, stack, fexclude=False): file_names_exclude = 
create_file_names_exclude(cell_type) files = stack['file_name'].unique() if fexclude: if len(files) > 1: stack = stack[~stack['file_name'].isin(file_names_exclude)] files = stack['file_name'].unique() print('file names excluded') print(files) return files, stack def plt_cellbody_punitsingle(grid1, ax0, ax1, ax2, frame, colors, amps_desired, save_names, cells_plot, cell_type_type, plus=1, ax3=[], xlim=[], burst_corr='_burst_corr_individual'): stack = [] for c, cell in enumerate(cells_plot): print(cell) frame_cell = frame[(frame['cell'] == cell)] frame_cell = unify_cell_names(frame_cell, cell_type=cell_type_type) try: cell_type = frame_cell[cell_type_type].iloc[0] except: print('cell type prob') embed() spikes = frame_cell.spikes.iloc[0] eod, sampling_rate, ds, time_eod = find_eod(frame_cell) eod_fr = frame_cell.EODf.iloc[0] spikes_all, hists, frs_calc, cont_spikes = load_spikes(spikes, eod_fr) # cont_spikes heißt dass die spikes nans sind oder leer sind, das heißt da ist gar nichts, auch keine scatter, das sind die weißen bilder die brauchen wir nicht # also hier ist das ok das mit dem Cont spikes so zu machen weil wir wollen die ja haben! 
if cont_spikes: # die zwei checken ob es mehr als paar spikes gibt,ohne die brauchen wir auch kein Bild if len(hists) > 0: if len(np.concatenate(hists)) > 0: lim_here = find_lim_here(cell, burst_corr) print(lim_here) if np.min(np.concatenate(hists)) < lim_here: hists2, spikes_ex, frs_calc2 = correct_burstiness(hists, spikes_all, [eod_fr] * len(spikes_all), [eod_fr] * len(spikes_all), lim=lim_here, burst_corr=burst_corr) spikes_both = [spikes_all, spikes_ex] hists_both = [hists, hists2] else: spikes_both = [spikes_all] hists_both = [hists] # das ist der title fals der square nicht plottet for s, save_name in enumerate(save_names): load_name = load_folder_name('calc_RAM') + '/' + save_name + '_' + cell axes = [] if os.path.exists(load_name + '.pkl'): stack = pd.read_pickle(load_name + '.pkl') file_names_exclude = create_file_names_exclude(cell_type) files = stack['file_name'].unique() fexclude = False if fexclude: if len(files) > 1: stack = stack[~stack['file_name'].isin(file_names_exclude)] files = stack['file_name'].unique() amps = stack['amp'].unique() _, _ = find_row_col(np.arange(len(amps) * len(files))) predefined_amp = True if predefined_amp: amps_defined = amps_desired else: amps_defined = amps stack_file = stack[stack['file_name'] == files[0]] amps = stack_file['amp'].unique() for aa, amp in enumerate(amps_defined): if amp in np.array(stack_file['amp']): stack_amp = stack_file[stack_file['amp'] == amp] lengths = stack_file['stimulus_length'].unique() length = np.max(lengths) stack_final = stack_amp[stack_amp['stimulus_length'] == length] trial_nr_double = stack_final.trial_nr.unique() # ok das ist glaube ich ein Anzeichen von einem Fehler if len(trial_nr_double) > 1: print('trial_nr_double') embed() # ich hatte mal einen save fehler deswegen habe ich manche doppelt also einfach das letzte nehmen nehme ich an try: stack_final1 = stack_final[stack_final.trial_nr == np.max(trial_nr_double)] except: print('stack_final1 problem') embed() try: grid_s = 
gridspec.GridSpecFromSubplotSpec(3, 1, grid1[c, aa + plus], height_ratios=[1.5, 1.5, 5], hspace=0.1) axs = plt.subplot(grid_s[2]) except: print('grid problem2') embed() cbar, mat, im = plot_square_core([axs], stack_final1) if xlim: axs.set_xlim(xlim) axs.set_ylim(xlim) axs.set_title('') if aa == len(amps) - 1: cbar.set_label(nonlin_title(), rotation=90, labelpad=10) fr = stack_final1.fr.unique()[0] fr_stim = stack_final1.fr_stim.unique()[0] if xlim: # [0] pass else: pass try: axo, axi = plt_psd_traces(grid_s[0], grid_s[1], axs, xlim[0], xlim[-1], eod_fr, fr, fr_stim, stack_final, ) except: print('psd problem') embed() if aa == 1: axo.text(0.5, 1, cell, ha='center', transform=axo.transAxes) axes.append(axi) # np.min(mat.columns) axes.append(axo) # np.max(mat.columns) if aa == 0: axo.set_ylabel('Otp.') axi.set_ylabel('Inp.') if c == 0: axo.set_title(' std = ' + str(amp) + '$\%$') # files[0] + ' l ' + str(length) if aa != 0: axi.set_ylabel('') if c != len(cells_plot) - 1: axs.set_xlabel('') remove_xticks(axs) else: axs.set_xlabel(F1_xlabel()) if aa != 0: remove_yticks(axs) axs.set_ylabel('') # do the scatter of these cells add = ['', '_burst_corr_individual'] ax0.scatter(frame_cell['cv' + add[0]], frame_cell['fr' + add[0]], zorder=2, alpha=1, label=cell_type, s=15, color=colors[str(cell_type)], facecolor='white') ax1.scatter(frame_cell['cv' + add[0]], frame_cell['vs' + add[0]], zorder=2, alpha=1, label=cell_type, s=15, color=colors[str(cell_type)], facecolor='white') ax2.scatter(frame_cell['cv' + add[1]], frame_cell['fr' + add[1]], zorder=2, alpha=1, label=cell_type, s=15, color=colors[str(cell_type)], facecolor='white') if len(stack) > 0: load_name = load_folder_name('calc_RAM') + '/' + save_names[s] + '_' + cell if ax3 != []: try: frame_g = base_to_stim(load_name, frame, cell_type_type, cell_type, stack=stack) except: print('stim problem') embed() try: ax3.scatter(frame_g['cv'], frame_g['cv_stim'], zorder=2, alpha=1, label=cell_type, s=15, 
color=colors[str(cell_type)], facecolor='white') except: print('scatter problem') embed() alpha = 1 grid_s = gridspec.GridSpecFromSubplotSpec(4, 1, grid1[c, 0], height_ratios=[2.5, 1.5, 1.2, 2.5], hspace=0.25) axi = plt.subplot(grid_s[-1]) axs.set_title(' ' + cell) if len(hists_both) > 1: colors_hist = ['grey', colors[str(cell_type)]] else: colors_hist = [colors[str(cell_type)]] for gg in range(len(hists_both)): hists_here = hists_both[gg] for hh, h in enumerate(hists_here): try: axi.hist(h, bins=100, color=colors_hist[gg], alpha=float(alpha - 0.05 * hh)) except: print('alpha problem2') embed() axi.set_xlim(0, 13) axe = plt.subplot(grid_s[0]) axe.plot(time_eod * 1000, eod, color='grey', linewidth=0.5) axe.set_xlim(0, 40) axs = plt.subplot(grid_s[1]) axs.eventplot(spikes_both[gg], color=colors_hist[gg], ) axs.set_xlim(0, 40) if c != len(cells_plot) - 1: remove_xticks(axi) remove_xticks(axs) else: axi.set_xlabel('isi') axs.set_xlabel('Time [ms]') remove_xticks(axe) axi.set_ylabel('Nr') axs.set_ylabel('Nr') axe.set_ylabel('mV') def create_file_names_exclude(cell_type): if 'p-unit' in cell_type: # == ['p-unit', ' P-unit']: file_names_exclude = punit_file_exclude() else: file_names_exclude = ampullary_file_exclude() # return file_names_exclude def punit_file_exclude(): file_names_exclude = ['InputArr_350to400hz_30', 'InputArr_250to300hz_30', 'InputArr_150to200hz_30', 'InputArr_50to100hz_30', 'InputArr_50hz_30', 'gwn50Hz50s0.3', 'gwn50Hz10s0.3', 'gwn50Hz10.3', 'gwn50Hz10s0.3short', 'gwn25Hz10s0.3', 'FileStimulus-file-gaussian50.0', 'FileStimulus-file-gaussian25.0', ] # return file_names_exclude def plt_cell_body2(grid1, frame, colors, cells_plot, cell_type_type, ax3=[], xlim=[]): for c, cell in enumerate(cells_plot): print(cell) frame_cell = frame[(frame['cell'] == cell)] frame_cell = unify_cell_names(frame_cell, cell_type=cell_type_type) try: cell_type = frame_cell[cell_type_type].iloc[0] except: embed() spikes = frame_cell.spikes.iloc[0] fr = frame_cell.fr.iloc[0] cv = 
frame_cell.cv.iloc[0] eod_fr = frame_cell.EODf.iloc[0] spikes_all, hists, frs_calc, cont_spikes = load_spikes(spikes, eod_fr) # cont_spikes heißt dass die spikes nans sind oder leer sind, das heißt da ist gar nichts, auch keine scatter, das sind die weißen bilder die brauchen wir nicht # also hier ist das ok das mit dem Cont spikes so zu machen weil wir wollen die ja haben! if cont_spikes: # die zwei checken ob es mehr als paar spikes gibt,ohne die brauchen wir auch kein Bild if len(hists) > 0: if len(np.concatenate(hists)) > 0: if np.min(np.concatenate(hists)) < 1.5: hists2, spikes_ex, frs_calc2 = correct_burstiness(hists, spikes_all, [eod_fr] * len(spikes_all), [eod_fr] * len(spikes_all)) hists_both = [hists, hists2] else: hists_both = [hists] # das ist der title fals der square nicht plottet plt.suptitle(cell + ' EODf ' + str(np.round(eod_fr)) + ' Hz ' + ' % ' + ' Base: cv ' + str(np.round(cv, 2)) + ' fr ' + str( np.round(fr)) + ' Hz', fontsize=11, ) # cell[0:13] + color=color+ cell_type save_names = [ 'noise_data8_nfft1sec_original__LocalEOD_CutatBeginning_0.05_s_NeurDelay_0_s_burst_corr', 'noise_data8_nfft1sec_original__LocalEOD_CutatBeginning_0.05_s_NeurDelay_0.005_s_burst_corr', 'noise_data8_nfft1sec_original__LocalEOD_mean2__CutatBeginning_0.05_s_NeurDelay_0.005_s_burst_corr', ] for a, save_name in enumerate(save_names): load_name = load_folder_name('calc_RAM') + '/' + save_name + '_' + cell if os.path.exists(load_name + '.pkl'): stack = pd.read_pickle(load_name + '.pkl') if 'p-unit' in cell_type: # == ['p-unit', ' P-unit']: file_names_exclude = punit_file_exclude() # else: file_names_exclude = ampullary_file_exclude() files = stack['file_name'].unique() fexclude = False if fexclude: if len(files) > 1: stack = stack[~stack['file_name'].isin(file_names_exclude)] files = stack['file_name'].unique() amps = stack['amp'].unique() _, _ = find_row_col(np.arange(len(amps) * len(files))) predefined_amp = True if predefined_amp: pass else: pass stack_file = 
stack[stack['file_name'] == files[0]] amps = stack_file['amp'].unique() amps_defined = [np.min(amps)] for aa, amp in enumerate(amps_defined): if amp in np.array(stack_file['amp']): stack_amp = stack_file[stack_file['amp'] == amp] lengths = stack_file['stimulus_length'].unique() length = np.max(lengths) stack_final = stack_amp[stack_amp['stimulus_length'] == length] trial_nr_double = stack_final.trial_nr.unique() # ok das ist glaube ich ein Anzeichen von einem Fehler if len(trial_nr_double) > 1: print('trial_nr_double') embed() # ich hatte mal einen save fehler deswegen habe ich manche doppelt also einfach das letzte nehmen nehme ich an try: stack_final1 = stack_final[stack_final.trial_nr == np.max(trial_nr_double)] except: print('stack_final1 problem') embed() try: grid_s = gridspec.GridSpecFromSubplotSpec(3, 1, grid1[c, a + 1], height_ratios=[1.5, 1.5, 5], hspace=0) axs = plt.subplot(grid_s[2]) except: print('grid problem0') embed() cbar, mat, im = plot_square_core([axs], stack_final1) if xlim: axs.set_xlim(xlim) axs.set_ylim(xlim) if a == len(amps) - 1: cbar.set_label(nonlin_title(), rotation=90, labelpad=10) fr = stack_final1.fr.unique()[0] fr_stim = stack_final1.fr_stim.unique()[0] axo, axi = plt_psd_traces(grid_s[0], grid_s[1], axs, np.min(mat.columns), np.max(mat.columns), eod_fr, fr, fr_stim, stack_final, ) if c == 0: axi.set_title(' std = ' + str(amp) + '$\%$') # files[0] + ' l ' + str(length) if a != 0: axi.set_ylabel('') if c != 2: axs.set_xlabel('') remove_xticks(axi) ################################ # do the scatter of these cells if ax3 != []: frame_g = base_to_stim(load_name, frame, cell_type_type, cell_type) try: ax3.scatter(frame_g['cv'], frame_g['cv_stim'], zorder=2, alpha=1, label=cell_type, s=15, color=colors[str(cell_type)], facecolor='white') except: print('scatter problem') embed() ################################ # do the hist do_hist(grid1, c, colors, cell_type, hists_both, cell, cells_plot) def ampullary_file_exclude(): file_names_exclude = 
['InputArr_350to400hz_30', 'InputArr_250to300hz_30', 'InputArr_150to200hz_30', 'InputArr_50to100hz_30', 'InputArr_50hz_30', 'blwn125Hz10s0.3', 'gwn50Hz10s0.3', 'FileStimulus-file-gaussian50.0', 'FileStimulus-file-gaussian25.0', 'gwn25Hz10s0.3', 'gwn50Hz10.3', 'gwn50Hz10s0.3short', 'gwn50Hz50s0.3', 'gwn25Hz10s0.3', ] # return file_names_exclude def do_hist(grid1, c, colors, cell_type, hists_both, cell, cells_plot): alpha = 1 axi = plt.subplot(grid1[c, 0]) if len(hists_both) > 1: colors_hist = ['grey', colors[str(cell_type)]] else: colors_hist = [colors[str(cell_type)]] for gg in range(len(hists_both)): hists_here = hists_both[gg] for hh, h in enumerate(hists_here): try: axi.hist(h, bins=100, color=colors_hist[gg], alpha=float(alpha - 0.05 * hh)) except: print('alpha problem3') embed() axi.set_title('CV ' + str(np.round(np.std(h) / np.mean(h), 3)) + ' ' + cell) # +' VS '+str(vs) axi.set_xlim(0, 13) if c != len(cells_plot) - 1: remove_xticks(axi) else: axi.set_xlabel('isi') def cells_eigen(base_extra=False, amp_desired=[0.5, 1, 5], xlim=[0, 1.1], cells_plot2=[], titles=['Baseline \n Susceptibility', 'Half EODf \n Susceptibility'], peaks_extra=[False, False, False]): plot_style() # 'noise_data8_nfft1sec_original__LocalEOD_CutatBeginning_0.05_s_NeurDelay_0.005_s',#__burstIndividual_ # ] # save_names = ['noise_data10_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s_spikes_', # 'noise_data10_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s_spikes_', # 'noise_data10_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s_spikes_'] save_names = [version_final()] amps_desired = amp_desired # amps_desired, cell_type_type, cells_plot, frame, cell_types = load_isis(save_names, amps_desired = amp_desired, cell_class = cell_class) cell_type_type = 'cell_type_reclassified' frame, frame_spikes = load_cv_vals_susept(cells_plot2, EOD_type='synch', names_keep=['spikes', 'gwn', 'fs', 'EODf', 'cv', 
'fr', 'width_75', 'vs', 'cv_burst_corr_individual', 'fr_burst_corr_individual', 'width_75_burst_corr_individual', 'vs_burst_corr_individual', 'cell_type_reclassified', 'cell'], path_sp='/calc_base_data-base_frame_overview.pkl', frame_general=False) cells_plot = cells_plot2 default_figsize(column=2, width=12, length=3.25 * len(cells_plot)) # ts=10, fs=10, ls=10, grid1 = big_grid_susept_pics(cells_plot, top=0.94, bottom=0.065) plt_cellbody_eigen(grid1, frame, amps_desired, save_names, cells_plot, cell_type_type, xlim=xlim, base_extra=base_extra, titles=titles, peaks_extra=peaks_extra) save_visualization(pdf=True) def fr_name_rm(): rm_var = rem_variable() if rm_var['rm']: val = r'$f\rm{'+basename_small()+'}$' else: val = r'$f'+basename_small()+'$' return val def ampullary_punit(permuted=False, eod_metrice=True, base_extra=False, color_same=True, fr_name='$f_{Base}$', amp_desired=[5, 20], isi_delta=None, xlim_p=[0, 1.1], tags_individual=False, xlim=[], add_texts=[0.25, 0], cells_plot2=[], RAM=True, scale_val=False, titles=['Low-CV P-unit,', 'High-CV P-unit', 'Ampullary cell,'], peaks_extra=[True, True, True]): # [0, 1.1] plot_style() # save_names = ['noise_data10_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s_spikes_', # 'noise_data10_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s_spikes_', # 'noise_data10_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s_spikes_'] save_names = [version_final()] amps_desired = amp_desired # amps_desired, cell_type_type, cells_plot2, frame, cell_types = load_isis(save_names, amps_desired = amp_desired, cell_class = cell_class) cell_type_type = 'cell_type_reclassified' frame, frame_spikes = load_cv_vals_susept(cells_plot2, EOD_type='synch', names_keep=['spikes', 'gwn', 'fs', 'EODf', 'cv', 'fr', 'width_75', 'vs', 'cv_burst_corr_individual', 'fr_burst_corr_individual', 'width_75_burst_corr_individual', 'vs_burst_corr_individual', 
'cell_type_reclassified', 'cell'], path_sp='/calc_base_data-base_frame_overview.pkl', frame_general=False) default_settings_cells_susept(cells_plot2) if len(cells_plot2) == 1: grid1 = big_grid_susept_pics(cells_plot2, top=0.9, bottom=0.12) else: grid1 = big_grid_susept_pics(cells_plot2, bottom=0.065) plt_cellbody_singlecell(grid1, frame, amps_desired, save_names, cells_plot2, cell_type_type, xlim=xlim, permuted=permuted, base_extra=base_extra, color_same=color_same, fr_name=fr_name, eod_metrice=eod_metrice, isi_delta=isi_delta, tags_individual=tags_individual, RAM=RAM, add_texts=add_texts, titles=titles, xlim_p=xlim_p, peaks_extra=peaks_extra, scale_val=scale_val, ) save_visualization(pdf=True, individual_tag=cells_plot2[0]) def default_settings_cells_susept(cells_plot, l=3.7): default_figsize(column=2, width=12, length=l * len(cells_plot)) # ts=10, fs=10, ls=10, def big_grid_susept_pics(cells_plot, top=0.96, bottom=0.065): grid = gridspec.GridSpec(1, 1, wspace=0.1, hspace=0.5, top=top, left=0.08, bottom=bottom, right=0.95) grid1 = gridspec.GridSpecFromSubplotSpec(len(cells_plot), 1, grid[0], hspace=0.35, wspace=0.35) # , return grid1 def show_func(show=True): if show: if os.path.exists('..\code\calc_model'): plt.show() else: plt.close() else: plt.close() def plt_RAM_overview_all_filename_selected(): save_name = 'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s' frame_load = pd.read_csv(load_folder_name('calc_RAM') + '/' + save_name + '.csv') scores = ['perc95_perc5_fr', 'perc80_perc5_fr', 'entropy_mat_fr', 'entropy_diagonal_fr', ] col = 4 row = 2 cell_types = [' Ampullary', ' P-unit', ] fig, ax = plt.subplots(row, col, sharex=True, figsize=(14, 7.5)) # constrained_layout=True, for c, cell_type_here in enumerate(cell_types): cell_type = frame_load.cell_type p_pos = np.where(np.array(cell_type) == cell_type_here) # ' P-unit' frame = frame_load.loc[p_pos] plt.suptitle(cell_type_here + ' \n ' + save_name) 
for s, score in enumerate(scores): file_names = ['gwn150Hz10s0.3', 'gwn300Hz10s0.3', 'gwn300Hz10s0.3short', 'gwn300Hz50s0.3', 'InputArr_400hz_30' ] cmap = rainbow_cmap(file_names, nrs=len(file_names)) for ff, file in enumerate(file_names): frame_file = frame[frame.file_name == file] try: ax[c, s].scatter(frame_file['cv_wo_burstcorr'], np.array(frame_file[score]), color=cmap[ff], label=file) # colors[' P-unit'] except: print('axs thing1') embed() ax[c, s].set_title(score, fontsize=11) if s < row * col - col: pass else: ax[c, s].set_xlabel('CV') # fr_names[f] ax[c, s].set_xlim(0, 2) if s == len(scores) - 1: ax[c, s].legend(ncol=1, loc=(1.3, 0)) ax[c, 0].set_ylabel(cell_type_here) for a in range(4): set_same_ylim(ax[:, a]) plt.subplots_adjust(left=0.06, right=0.8, top=0.83, wspace=0.45, hspace=0.3) save_visualization(individual_tag='_score_' + str(score) + '_celltype_' + str(cell_type_here)) # def plt_data_noise(time_wn, stimulus_wn, nfft, sampling, mt, b, m): plt.subplot(2, 3, 1) plt.plot(time_wn, stimulus_wn) plt.xlim(0, 0.3) plt.subplot(2, 3, 4) p, f = ml.psd(stimulus_wn - np.mean(stimulus_wn), Fs=sampling, NFFT=nfft, noverlap=nfft // 2, sides='twosided') plt.plot(f, p) plt.xlim(0, 0.3) eod_mt_test, spikes_mt, sampling = link_arrays(b, mt.positions[:][m], mt.extents[:][m]) time_here = np.arange(0, len(eod_mt_test) / sampling, 1 / sampling) plt.subplot(1, 3, 2) plt.plot(time_here - np.min(time_here), eod_mt_test) plt.xlim(0, 0.3) plt.subplot(1, 3, 3) plt.plot(time_here - np.min(time_here), eod_mt_test) plt.plot(time_wn, amp * stimulus_wn + 0.4, color='red') plt.xlim(0, 0.3) plt.show() def plt_data_overview_amps(ax): save_names = [ 'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s', 'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s__burstIndividual_', 
'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s', 'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s__burstIndividual_', ] x_axis = ["cv_base", "cv_base_w_burstcorr", "cv_stim_wo_burstcorr", "cv_stim_w_burstcorr"] save_names_title = ['No burst correction (cv_base)', 'Burst correction (cv_base)', 'No burst correction (cv_stim)', 'Burst correction (cv_stim)'] counter = 0 for cv_n, cv_name in enumerate(x_axis): frame_load = load_overview_susept(save_names[cv_n]) scores = [ 'perc95_perc5_fr', ] cell_types = [' P-unit', ' Ampullary', ] for c, cell_type_here in enumerate(cell_types): cell_type = frame_load['celltype'] # 'cell_type_reclassified' p_pos = np.where(np.array(cell_type) == cell_type_here) # ' P-unit' frame = frame_load.loc[p_pos] for s, score in enumerate(scores): file_names = ['InputArr_50to100hz_30', 'InputArr_150to200hz_30', 'InputArr_250to300hz_30', 'InputArr_350to400hz_30', 'InputArr_50hz_30', 'gwn50Hz10s0.3', 'gwn50Hz50s0.3', 'gwn100Hz10s0.3', 'gwn150Hz10s0.3', 'gwn300Hz10s0.3', 'gwn300Hz10s0.3short', 'gwn300Hz50s0.3', 'InputArr_400hz_30' ] file_names_there = ['gwn150Hz10s0.3', 'gwn300Hz10s0.3', 'gwn300Hz10s0.3short', 'gwn300Hz50s0.3', 'InputArr_400hz_30' ] frame_file_ex = frame[frame.file_name.isin(file_names_there)] frame_file_ex = frame_file_ex[frame_file_ex.snippets == 9] print(len(file_names)) frame_file = frame_file_ex # frame_file_ex[frame_file_ex[var_name] == file] amps = np.array(frame_file.amp.unique()) cmap = rainbow_cmap(amps, nrs=len(amps)) for a, amp in enumerate(amps): frame_amp = frame_file[frame_file.amp == amp] cvs = frame_amp[cv_name] # x_axis = cvs[frame_amp[score] > 0] y_axis = np.array(frame_amp[score])[frame_amp[score] > 0] ax[counter].set_title(save_names_title[cv_n]) max_val = 1.5 if 'P-unit' in cell_type_here: marker = '.' 
else: marker = '*' try: ax[counter].scatter(x_axis[x_axis < max_val], y_axis[x_axis < max_val], color=cmap[a], alpha=0.45, marker=marker, s=10) # colors[' P-unit']label = cell_type_here[1]+':'+' '+str(file), marker = marker, except: print('axs thing2') embed() ax[counter].set_xlabel(cv_name) if counter == 0: ax[counter].set_ylabel(score) # cell_type_here+, transform=ax[counter,l].transAxes ax[counter].set_xlim(0, max_val) counter += 1 return cell_type_here, score def plt_data_overview2(ax, scores=['perc95_perc5_fr']): ########################## # Auswahl: wir nehmen den mean um nicht Stimulus abhängigen Noise rauszumitteln save_names = [ 'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s', 'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s__burstIndividual_', 'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s', 'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s__burstIndividual_', ] x_axis = ["cv_base", "cv_base", "cv_base_w_burstcorr", "cv_base_w_burstcorr"] save_names_title = ['No burst correction', 'Burst correction', 'No burst correction', 'Burst correction'] counter = 0 for cv_n, cv_name in enumerate(x_axis): frame_load = load_overview_susept(save_names[cv_n]) cell_types = [' P-unit', ' Ampullary', ] for c, cell_type_here in enumerate(cell_types): cell_type = frame_load['cell_type_reclassified'] # 'celltype' 'cell_type_reclassified' p_pos = np.where(np.array(cell_type) == cell_type_here) # ' P-unit' frame = frame_load.loc[p_pos] for s, score in enumerate(scores): # todo: hier den Übergang womöglich soft machen mod_limits = mod_lims_modulation(cell_type_here, frame_load, score) if cell_type_here == ' P-unit': cm = 'Blues' else: cm = 'Greens' cmap = rainbow_cmap(np.arange(len(mod_limits) * 1.6), nrs=len(mod_limits) * 1.6, 
cm=cm)[ ::-1] # len(amps) cmap = cmap[0:len(mod_limits)][::-1] for ff, amp in enumerate(range(len(mod_limits) - 1)): file_names = ['InputArr_50to100hz_30', 'InputArr_150to200hz_30', 'InputArr_250to300hz_30', 'InputArr_350to400hz_30', 'InputArr_50hz_30', 'gwn50Hz10s0.3', 'gwn50Hz50s0.3', 'gwn100Hz10s0.3', 'gwn150Hz10s0.3', 'gwn300Hz10s0.3', 'gwn300Hz10s0.3short', 'gwn300Hz50s0.3', 'InputArr_400hz_30' ] ############## # Auschlusskriterium 1, nur RAMs die bei Null anfangen file_names_there = ['gwn150Hz10s0.3', 'gwn300Hz10s0.3', 'gwn300Hz10s0.3short', 'gwn300Hz50s0.3', 'InputArr_400hz_30' ] frame_amp = frame[(frame['response_modulation'] > mod_limits[ff]) & ( frame['response_modulation'] <= mod_limits[ff + 1])] try: frame_file_ex = frame_amp[frame_amp.file_name.isin(file_names_there)] except: print('file thing') embed() ############## # Auschlusskriterium 2, mindestens 9 Sekunden frame_file_ex = frame_file_ex[frame_file_ex.snippets == 9] print(len(file_names)) frame_file = frame_file_ex # frame_file_ex[frame_file_ex[var_name] == file] ############## # Auschlusskriterium 3, kleiner als 10 % Kontrast # oder nicht ausschließen und stattdessen Modulation Farben! 
# frame_amp = frame_file[frame_file.amp < 9] cvs = frame_file[cv_name] # x_axis = cvs[frame_file[score] > 0] y_axis = np.array(frame_file[score])[frame_file[score] > 0] ax[counter].set_title(save_names_title[cv_n]) max_val = 1.5 try: ax[counter].scatter(x_axis[x_axis < max_val], y_axis[x_axis < max_val], alpha=1, s=2.5, color=cmap[ ff], ) ##0.45 colors[' P-unit']label = cell_type_here[1]+':'+' '+str(file), marker = marker, except: print('axs thing3') embed() ax[counter].set_xlabel(cv_name) test = False if test: for ff in range(len(cmap)): plt.plot([1, 2], [1, 2 * ff], color=cmap[ff]) plt.show() # embed() if counter == 0: ax[counter].set_ylabel(score) # cell_type_here+, transform=ax[counter,l].transAxes ax[counter].set_xlim(0, max_val) counter += 1 return cell_type_here, score def plt_data_overview(ax, scores=['perc95_perc5_fr']): save_names = [ 'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s', 'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s__burstIndividual_', 'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s', 'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s__burstIndividual_', ] x_axis = ["cv_base", "cv_base_w_burstcorr", "cv_stim_wo_burstcorr", "cv_stim_w_burstcorr"] save_names_title = ['No burst correction (cv_base)', 'Burst correction (cv_base)', 'No burst correction (cv_stim)', 'Burst correction (cv_stim)'] counter = 0 for cv_n, cv_name in enumerate(x_axis): frame_load = load_overview_susept(save_names[cv_n]) colors = colors_overview() cell_types = [' P-unit', ' Ampullary', ] for c, cell_type_here in enumerate(cell_types): cell_type = frame_load['celltype'] # 'cell_type_reclassified' p_pos = np.where(np.array(cell_type) == cell_type_here) # ' P-unit' frame = frame_load.loc[p_pos] for s, score in 
enumerate(scores): file_names = ['InputArr_50to100hz_30', 'InputArr_150to200hz_30', 'InputArr_250to300hz_30', 'InputArr_350to400hz_30', 'InputArr_50hz_30', 'gwn50Hz10s0.3', 'gwn50Hz50s0.3', 'gwn100Hz10s0.3', 'gwn150Hz10s0.3', 'gwn300Hz10s0.3', 'gwn300Hz10s0.3short', 'gwn300Hz50s0.3', 'InputArr_400hz_30' ] ############## # Auschlusskriterium 1, nur RAMs die bei Null anfangen file_names_there = ['gwn150Hz10s0.3', 'gwn300Hz10s0.3', 'gwn300Hz10s0.3short', 'gwn300Hz50s0.3', 'InputArr_400hz_30' ] frame_file_ex = frame[frame.file_name.isin(file_names_there)] ############## # Auschlusskriterium 2, mindestens 9 Sekunden frame_file_ex = frame_file_ex[frame_file_ex.snippets == 9] print(len(file_names)) frame_file = frame_file_ex # frame_file_ex[frame_file_ex[var_name] == file] ############## # Auschlusskriterium 3, kleiner als 10 % Kontrast frame_amp = frame_file[frame_file.amp < 9] cvs = frame_amp[cv_name] # x_axis = cvs[frame_amp[score] > 0] y_axis = np.array(frame_amp[score])[frame_amp[score] > 0] ax[counter].set_title(save_names_title[cv_n]) max_val = 1.5 try: ax[counter].scatter(x_axis[x_axis < max_val], y_axis[x_axis < max_val], color=colors[cell_type_here], alpha=0.45, s=5) # colors[' P-unit']label = cell_type_here[1]+':'+' '+str(file), marker = marker, except: print('axs thing4') embed() ax[counter].set_xlabel(cv_name) if counter == 0: ax[counter].set_ylabel(score) # cell_type_here+, transform=ax[counter,l].transAxes ax[counter].set_xlim(0, max_val) counter += 1 return cell_type_here, score def plt_power2(spikes_all_here, axp, color='blue'): spikes_mat = [[]] * len(spikes_all_here) sampling_calc = 40000 nfft = 2 ** 14 p_array = [[]] * len(spikes_all_here) alpha = 1 for s, sp in enumerate(spikes_all_here): spikes_mat[s] = cr_spikes_mat(np.array(sp) / 1000, sampling_rate=sampling_calc, length=int(sampling_calc * np.array(sp[-1]) / 1000)) p_array[s], f_array = ml.psd(spikes_mat[s] - np.mean(spikes_mat[s]), Fs=sampling_calc, NFFT=nfft, noverlap=nfft // 2) 
axp.plot(f_array, p_array[s], alpha=float(alpha - 0.05 * s), color=color) # color=colors[str(cell_type)], axp.set_xlim(0, 1000) axp.set_xlabel('Hz') axp.set_ylabel('Hz') return p_array, f_array def find_names_gwn(b): names_mt_gwns = [] for mts in b.multi_tags: if find_gwn(mts): names_mt_gwns.append(mts.name) return names_mt_gwns def find_gwn(trials): return ('file' in trials.name) or ('noise' in trials.name) or ('gwn' in trials.name) or ( 'InputArr' in trials.name) or ( 'FileStimulus-file-gaussian' in trials.name) def model_and_data_isi_power(nr_clim=10, many=False, width=0.005, row='no', HZ50=True, fs=8, nffts=['whole'], cells=["2013-01-08-aa-invivo-1"], col_desired=2, var_items=['contrasts'], contrasts=[0], noises_added=[''], fft_i='forward', fft_o='forward', spikes_unit='Hz', mV_unit='mV', D_extraction_method=['additiv_visual_d_4_scaled'], internal_noise=['eRAM'], external_noise=['eRAM'], level_extraction=['_RAMdadjusted'], cut_off2=300, receiver_contrast=[1], dendrids=[''], ref_types=[''], adapt_types=[''], c_noises=[0.1], c_signal=[0.9], cut_offs1=[300], clims='all', restrict='restrict'): stimulus_length = 1 # 20#550 # 30 # 15#45#0.5#1.5 15 45 100 trials_nrs = [1] # [100, 500, 1000, 3000, 10000, 100000, 1000000] # 500 variant = 'sinz' mimick = 'no' cell_recording_save_name = '' trans = 1 # 5 repeats = [9] # 30 powers = [3] # ,1] aa = 0 good_data, remaining = overlap_cells() cells_all = good_data for _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, in it.product( cells, D_extraction_method, external_noise, internal_noise, powers, nffts, dendrids, cut_offs1, trials_nrs, c_signal, c_noises, ref_types, adapt_types, noises_added, level_extraction, receiver_contrast, contrasts, ): aa += 1 if row == 'no': col, row = find_row_col(np.arange(aa), col=col_desired) # np.arange( else: pass if row == 2: default_settings(column=2, length=7.5) # 2+2.25+2.25 elif row == 1: default_settings(column=2, length=4) grid = gridspec.GridSpec(1, 4, wspace=0.6, bottom=0.075, 
hspace=0.13, left=0.08, right=0.93, top=0.88, width_ratios=[0.7, 1, 1, 1]) a = 0 maxs = [] mins = [] ims = [] perc05 = [] perc95 = [] iternames = [D_extraction_method, external_noise, internal_noise, powers, nffts, dendrids, cut_offs1, trials_nrs, c_signal, c_noises, ref_types, adapt_types, noises_added, level_extraction, receiver_contrast, contrasts, ] nr = '2' print(cells_all) for all in it.product(*iternames): var_type, stim_type_afe, stim_type_noise, power, nfft, dendrid, cut_off1, trial_nrs, c_sig, c_noise, ref_type, adapt_type, noise_added, extract, a_fr, a_fe = all fig = plt.figure() hs = 0.45 ################################# # model cells adapt_type_name, ax_model, cells_all, dendrid_name, ref_type_name, suptitles, width = plt_model_part(HZ50, a, a_fe, a_fr, adapt_type, c_noise, c_sig, cell_recording_save_name, cells_all, cut_off1, cut_off2, dendrid, extract, fft_i, fft_o, fig, fs, grid, hs, ims, mV_unit, many, maxs, mimick, mins, nfft, noise_added, nr, perc05, perc95, power, ref_type, repeats, spikes_unit, stim_type_afe, stim_type_noise, stimulus_length, trans, trial_nrs, var_items, var_type, variant, width) ################################# # data cells grid_data = gridspec.GridSpecFromSubplotSpec(len(cells_all), 1, grid[1], hspace=hs) ax_data, stack_spikes_all, eod_frs = plt_data_susept(fig, grid_data, cells_all, cell_type='p-unit', width=width, cbar_label=False) for ax in ax_data: ax.set_ylabel(F2_xlabel()) ################################# # plt isi of data grid_isi = gridspec.GridSpecFromSubplotSpec(len(cells_all), 1, grid[0], hspace=hs) ax_isi = plt_isi(cells_all, grid_isi, stack_spikes=stack_spikes_all, eod_frs=eod_frs) ax_isi[0].get_shared_x_axes().join(*ax_isi) end_name = suptitles + ' a_fr=' + str(a_fr) + ' ' + ' dendride=' + str( dendrid_name) + ' refractory period=' + str(ref_type_name) + ' adapt=' + str( adapt_type_name) + ' ' + ' cutoff1=' + str(cut_off1) + '' ' stimulus length=' + str( stimulus_length) + ' ' + ' power=' + str( power) + ' ' 
+ restrict # end_name = cut_title(end_name, datapoints=120) name_title = end_name plt.suptitle(name_title) # +' file ' set_clim_same(ims, perc05, perc95, mins, maxs, nr_clim, clims) axes = np.array([np.array(ax_data), np.array(ax_model[0:int(len(ax_model) / 2)]), np.array(ax_model[int(len(ax_model) / 2)::]), np.array(ax_isi)]) fig.tag(np.transpose(axes), xoffs=-3, yoffs=2.9, minor_index=2) save_visualization(pdf=True) def model_and_data(width=0.005, nffts=['whole'], powers=[1], cells=["2013-01-08-aa-invivo-1"], contrasts=[0], noises_added=[''], D_extraction_method=['additiv_cv_adapt_factor_scaled'], internal_noise=['RAM'], external_noise=['RAM'], level_extraction=[''], receiver_contrast=[1], dendrids=[''], ref_types=[''], adapt_types=[''], c_noises=[0.1], c_signal=[0.9], cut_offs1=[300]): # ['eRAM'] stimulus_length = 1 # 20#550 # 30 # 15#45#0.5#1.5 15 45 100 trials_nrs = [1] # [100, 500, 1000, 3000, 10000, 100000, 1000000] # 500 good_data, remaining = overlap_cells() cells_all = [good_data[0]] plot_style() default_settings(column=2, length=4.9) # 0.75 grid = gridspec.GridSpec(3, 4, wspace=0.95, bottom=0.07, hspace=0.23, left=0.09, right=0.9, top=0.92) a = 0 maxs = [] mins = [] mats = [] ims = [] iternames = [D_extraction_method, external_noise, internal_noise, powers, nffts, dendrids, cut_offs1, trials_nrs, c_signal, c_noises, ref_types, adapt_types, noises_added, level_extraction, receiver_contrast, contrasts, ] lp = 2 for all in it.product(*iternames): var_type, stim_type_afe, stim_type_noise, power, nfft, dendrid, cut_off1, trial_nrs, c_sig, c_noise, ref_type, adapt_type, noise_added, extract, a_fr, a_fe = all fig = plt.figure() hs = 0.45 ################################# # data cells grid_data = gridspec.GridSpecFromSubplotSpec(1, 1, grid[0, 1], hspace=hs) ax_data, stack_spikes_all, eod_frs = plt_data_susept(fig, grid_data, cells_all, cell_type='p-unit', width=width, cbar_label=True, lp=lp, title=True) for ax in ax_data: ax.set_xticks_delta(100) 
set_ylabel_arrow(ax, xpos=xpos_y_modelanddata(), ypos=0.87) set_xlabel_arrow(ax) ax.arrow_spines('lb') ################################## # model part cell = '2012-07-03-ak-invivo-1' cells_given = [cell] save_names = [ 'calc_RAM_model-2__nfft_whole_power_1_afe_0.009_RAM_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_11_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV', 'calc_RAM_model-2__nfft_whole_power_1_afe_0.009_RAM_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_500000_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV', 'calc_RAM_model-2__nfft_whole_power_1_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_11_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV', 'calc_RAM_model-2__nfft_whole_power_1_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_500000_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV', 'calc_RAM_model-2__nfft_whole_power_1_afe_0.009_RAM_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_11_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV', 'calc_RAM_model-2__nfft_whole_power_1_afe_0.009_RAM_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_500000_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV', ] nrs_s = [2, 3, 6, 7, 10, 11] titles = ['Trials=11 c=0.01', 'Trials=500000 c=0.01', 'Trials=11 \n Intrinsic split', 'Trials=500000\n Intrinsic split', 'Trials=11 c=0.01\n Intrinsic split', 'Trials=500000 c=0.01\n Intrinsic split'] ax_model = [] for s, sav_name in enumerate(save_names): try: ax = plt.subplot(grid[nrs_s[s]]) except: print('vers something') embed() ax_model.append(ax) save_name = load_folder_name('calc_model') + '/' + sav_name cell_add, cells_save = find_cell_add(cells_given) perc = 'perc' path = save_name + '.pkl' # '../'+ stack = 
load_model_susept(path, cells_save, save_name.split(r'/')[-1] + cell_add) add_nonlin_title, cbar, fig, stack_plot, im = plt_single_square_modl(ax, cell, stack, perc, titles[s], width, titles_plot=True, resize=True) ims.append(im) mats.append(stack_plot) maxs.append(np.max(np.array(stack_plot))) mins.append(np.min(np.array(stack_plot))) col = 2 row = 3 ax.set_xticks_delta(100) ax.set_yticks_delta(100) cbar.set_label(nonlin_title(' [' + add_nonlin_title), labelpad=lp) # rotation=270, if (s in np.arange(col - 1, 100, col)) | (s == 0): remove_yticks(ax) else: set_ylabel_arrow(ax, xpos=xpos_y_modelanddata(), ypos=0.87) if s >= row * col - col: set_xlabel_arrow(ax) else: remove_xticks(ax) if len(cells) > 1: a += 1 set_clim_same(ims, mats=mats, lim_type='up', nr_clim='perc', clims='', percnr=95) ################################################# # Flowcharts var_types = ['', 'additiv_cv_adapt_factor_scaled', 'additiv_cv_adapt_factor_scaled'] a_fes = [0.009, 0, 0.009] eod_fe = [750, 750, 750] ylim = [-0.5, 0.5] c_sigs = [0, 0.9, 0.9] grid_left = [[], grid[1, 0], grid[2, 0]] ax_ams = [] for g, grid_here in enumerate([grid[0, 0], grid[1, 1], grid[2, 1]]): grid_lowpass = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=grid_here, hspace=0.2, height_ratios=[1, 1, 0.1]) models = resave_small_files("models_big_fit_d_right.csv", load_folder='calc_model_core') model_params = models[models['cell'] == '2012-07-03-ak-invivo-1'].iloc[0] cell = model_params.pop('cell') # .iloc[0]# Werte für das Paper nachschauen eod_fr = model_params['EODf'] # .iloc[0] deltat = model_params.pop("deltat") # .iloc[0] v_offset = model_params.pop("v_offset") # .iloc[0] print(var_types[g] + ' a_fe ' + str(a_fes[g])) noise_final_c, spike_times, stimulus, stimulus_here, time, v_dent_output, v_mem_output, frame = get_flowchart_params( a_fes, a_fr, g, c_sigs[g], cell, deltat, eod_fr, model_params, stimulus_length, v_offset, var_types, eod_fe=eod_fe) if (len(np.unique(frame.RAM_afe)) > 1) & 
(len(np.unique(frame.RAM_noise)) > 1): grid_lowpass2 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=grid_left[g], hspace=0.2) axt_p2 = plt_time_arrays('purple', grid_lowpass2, 1, frame.RAM_noise, time=time, nr=0) axt_p2.text(-0.6, 0.5, '$\%$', rotation=90, va='center', transform=axt_p2.transAxes) color_timeseries = 'black' axt_p2.set_xlabel('Time [ms]') axt_p2.text(-0.6, 0.5, '$\%$', rotation=90, va='center', transform=axt_p2.transAxes) ax_ams.append(axt_p2) elif len(np.unique(frame.RAM_afe)) > 1: color_timeseries = 'red' elif len(np.unique(frame.RAM_noise)) > 1: color_timeseries = 'purple' print(str(g) + ' afevar ' + str(np.var(frame.RAM_afe)) + ' afenoise ' + str(np.var(frame.RAM_noise))) try: ax, ff, pp, ff_am, pp_am = plot_lowpass2([grid_lowpass[0]], time, frame.RAM_afe + frame.RAM_noise, deltat, eod_fr, color1=color_timeseries, lw=1, extract=False) except: print('add up thing') embed() ax.text(-0.6, 0.5, '$\%$', va='center', rotation=90, transform=ax.transAxes) ax_ams.append(ax) remove_xticks(ax) ax_n, ff, pp, ff_am, pp_am = plot_lowpass2([grid_lowpass[1]], time, noise_final_c, deltat, eod_fr, extract=False, color1='grey', lw=1) remove_yticks(ax_n) if g == 1: remove_xticks(ax_n) else: ax_n.set_xlabel('Time [ms]') ax_n.set_ylim(ylim) set_same_ylim(ax_ams, up='up') axes = np.concatenate([ax_data, ax_model]) axes = [ax_ams[0], axes[0], axes[1], axes[2], ax_ams[1], axes[3], axes[4], ax_ams[2], ax_ams[3], axes[5], axes[6], ] fig.tag(axes, xoffs=-3, yoffs=2) save_visualization(pdf=True) def xpos_y_modelanddata(): return -0.52 def F1_xlabel(): return '$f_{1}$\,[Hz]' def plt_model_part(HZ50, a, a_fe, a_fr, adapt_type, c_noise, c_sig, cell_recording_save_name, cells_all, cut_off1, cut_off2, dendrid, extract, fft_i, fft_o, fig, fs, grid, hs, ims, mV_unit, many, maxs, mimick, mins, nfft, noise_added, nr, perc05, perc95, power, ref_type, repeats, spikes_unit, stim_type_afe, stim_type_noise, stimulus_length, trans, trial_nrs, var_items, var_type, variant, width, 
xlabels=True, perc='', label=nonlin_title(), rows=2, title=True): ax_model = [] for t, trials_stim in enumerate(repeats): grid_model = gridspec.GridSpecFromSubplotSpec(rows, 1, grid[2 + t], hspace=hs) save_name = save_ram_model(stimulus_length, cut_off1, nfft, a_fe, stim_type_noise, mimick, variant, trials_stim, power, cell_recording_save_name, nr=nr, fft_i=fft_i, fft_o=fft_o, Hz=spikes_unit, mV=mV_unit, stim_type_afe=stim_type_afe, extract=extract, noise_added=noise_added, c_noise=c_noise, c_sig=c_sig, ref_type=ref_type, adapt_type=adapt_type, var_type=var_type, cut_off2=cut_off2, dendrid=dendrid, a_fr=a_fr, trials_nr=trial_nrs, trans=trans, zeros='ones') path = save_name + '.pkl' print(t) print(path) model = load_model_susept(path, cells_all, save_name + 'all') # cells adapt_type_name, ref_type_name, dendrid_name, stim_type_noise_name = define_names(var_type, stim_type_noise, dendrid, ref_type, adapt_type) if len(model) > 0: model = model[model.cell.isin(cells_all)] # ('cv_stim') try: cells_all = model.groupby('cv_stim').first().sort_values(by='cv_stim').cell # ('cv_stim') except: print('model thing') for c, cell in enumerate(cells_all): print(c) try: ax = plt.subplot(grid_model[c]) except: print('something') embed() titles = '' suptitles = '' stim_type_noise_name2, stim_type_afe_name = stim_type_names(a_fe, c_sig, stim_type_afe, stim_type_noise_name) suptitles, titles = find_titles_RAM(a_fe, cell, extract, noise_added, stim_type_afe_name, stim_type_noise_name2, suptitles, titles, trials_stim, var_items, var_type) model_show = model[ (model.cell == cell)] # & (model_cell.file_name == file)& (model_cell.power == power)] new_keys = model_show.index.unique() # [0:490] try: stack_plot = model_show[list(map(str, new_keys))] except: stack_plot = model_show[new_keys] stack_plot = stack_plot.iloc[np.arange(0, len(new_keys), 1)] stack_plot.columns = list(map(float, stack_plot.columns)) ax.set_xlim(0, 300) ax.set_ylim(0, 300) ax.set_aspect('equal') ax.set_xticks_delta(100) 
ax.set_yticks_delta(100) ax.arrow_spines('lb') model_cells = resave_small_files("models_big_fit_d_right.csv") model_params = model_cells[model_cells['cell'] == cell] if len(model_show) > 0: noise_strength = model_params.noise_strength.iloc[0] # **2/2 D = noise_strength # (noise_strength ** 2) / 2 D_derived, var, cut_off = D_derive(model_show, save_name, c_sig, D=D, base='', nr=nr) # var_based stack_plot = RAM_norm(stack_plot, trials_stim, D_derived, model_show=model_show) if many == True: titles = titles + ' Ef=' + str(int(model_params.EODf.iloc[0])) color = title_color(cell) if title: if t == 0: ax.set_title( titles + ' $fr_{B}$=' + str(int(np.round(model_show.fr.iloc[0]))) + ' $fr_{S}$=' + str( int(np.round(model_show.fr_stim.iloc[0]))) + 'Hz\n $cv_{B}$=' + str( np.round(model_show.cv.iloc[0], 2)) + ' $cv_{S}$=' + str( np.round(model_show.cv_stim.iloc[0], 2)) + ' s=' + str( np.round(model_show.ser_sum_stim.iloc[0], 2)), fontsize=fs, color=color) # + ' $D_{sig}$=' + str(np.round(D_derived, 5)) im = plt_RAM_perc(ax, perc, stack_plot) ims.append(im) maxs.append(np.max(np.array(stack_plot))) mins.append(np.min(np.array(stack_plot))) perc05.append(np.percentile(stack_plot, 5)) perc95.append(np.percentile(stack_plot, 95)) plt_triangle(ax, model_show.fr.iloc[0], np.round(model_show.fr_stim.iloc[0]), 300, model_show.eod_fr.iloc[0]) if HZ50: plt_50_Hz_noise(ax, 300) ax.set_aspect('equal') cbar, left, bottom, width, height = colorbar_outside(ax, im, fig, add=0, width=width) if t == 1: ax.text(1.5, 0.5, label, rotation=90, ha='center', va='center', transform=ax.transAxes) remove_yticks(ax) if (c == len(cells_all) - 1) & xlabels: ax.text(1.05, -0.35, F1_xlabel(), ha='center', va='center', transform=ax.transAxes) ax.arrow_spines('lb') else: remove_xticks(ax) print(c) ax_model.append(ax) a += 1 print('model done') return adapt_type_name, ax_model, cells_all, dendrid_name, ref_type_name, suptitles, width def find_titles_RAM(a_fe, cell, extract, noise_added, stim_type_afe_name, 
stim_type_noise_name2, suptitles, titles, trials_stim, var_items, var_type): if 'cells' in var_items: titles += cell[2:13] else: suptitles += cell[2:13] if 'internal_noise' in var_items: titles += ' intrinsic noise=' + stim_type_noise_name2 else: suptitles += ' intrinsic noise=' + stim_type_noise_name2 if 'external_noise' in var_items: titles += ' additive RAM=' + stim_type_afe_name else: suptitles += ' additive RAM=' + stim_type_afe_name if 'repeats' in var_items: titles += ' $N_{repeat}=$' + str(trials_stim) else: suptitles += ' $N_{repeat}=$' + str(trials_stim) if 'contrasts' in var_items: titles += ' contrast=' + str(a_fe) else: suptitles += ' contrast=' + str(a_fe) if 'level_extraction' in var_items: titles += ' Extract Level=' + str(extract) else: suptitles += ' Extract Level=' + str(extract) if 'D_extraction_method' in var_items: titles += str(var_type) else: suptitles += str(var_type) if 'noises_added' in var_items: titles += ' high freq noise=' + str(noise_added) else: suptitles += ' high freq noise=' + str(noise_added) return suptitles, titles def set_clim_same(ims, perc05=[], val_chosen=None, percnr=None, perc95=[], mins=[], maxs=[], mats=[], nr_clim='perc', lim_type='', clims='all', same='', mean_type=False, clim_min=[], clim_max=[]): if clims == 'all': if same == 'same': if len(clim_min) < 1: clim_min = [] clim_max = [] for a, im in enumerate(ims): clim_min.append(im.get_clim()[0]) clim_max.append(im.get_clim()[1]) lim = np.max([np.abs(np.min(clim_min)), np.abs(np.max(clim_max))]) for a, im in enumerate(ims): if lim_type == 'same': ims[a].set_clim(-lim, lim) elif lim_type == 'up': ims[a].set_clim(0, lim) else: ims[a].set_clim(np.min(clim_min), np.max(clim_max)) else: if len(mats) < 1: if nr_clim == 'perc': for im in ims: if lim_type == 'up': im.set_clim(0, np.max(perc95)) else: im.set_clim(np.min(perc05), np.max(perc95)) else: for im in ims: im.set_clim(np.min(np.min(mins)) * nr_clim, np.max(np.max(maxs) / nr_clim)) else: maxs, mins, perc05, perc95 = 
get_perc_vals(mats, percnr) for i, im in enumerate(ims): if nr_clim == 'perc': if lim_type == 'up': if mean_type: im.set_clim(0, np.mean(perc95)) else: im.set_clim(0, np.max(perc95)) else: im.set_clim(np.min(perc05), np.max(perc95)) else: im.set_clim(np.min(mins) * nr_clim, np.max(maxs) / nr_clim) # todo: noch alle clim funkcitonen fusioenier else: if len(mats) < 1: for i, im in enumerate(ims): if nr_clim == 'perc': if lim_type == 'up': im.set_clim(0, perc95[i]) else: im.set_clim(perc05[i], perc95[i]) elif nr_clim == 'None': values = im.get_clim() if lim_type == 'up': if val_chosen: im.set_clim(0, val_chosen) else: im.set_clim(0, values[-1]) else: if lim_type == 'up': im.set_clim(0, maxs[i] / nr_clim) else: im.set_clim(mins[i] * nr_clim, maxs[i] / nr_clim) else: maxs, mins, perc05, perc95 = get_perc_vals(mats, percnr) for i, im in enumerate(ims): if nr_clim == 'perc': if lim_type == 'up': im.set_clim(0, perc95[i]) else: im.set_clim(perc05[i], perc95[i]) else: if lim_type == 'up': im.set_clim(0, maxs[i] / nr_clim) else: im.set_clim(mins[i] * nr_clim, maxs[i] / nr_clim) def get_perc_vals(mats, percnr): perc05 = [] perc95 = [] mins = [] maxs = [] for m in range(len(mats)): mins.append(np.min(mats[m])) if not percnr: perc05.append(np.percentile(mats[m], 5)) perc95.append(np.percentile(mats[m], 95)) else: perc05.append(np.percentile(mats[m], 100 - percnr)) perc95.append(np.percentile(mats[m], percnr)) maxs.append(np.min(mats[m])) return maxs, mins, perc05, perc95 def plt_isi(cells_all, grid_isi, stack_spikes=[], eod_frs=[]): frame = load_cv_base_frame(cells_all) ax_isi = [] for f, cell in enumerate(cells_all): axi = plt.subplot(grid_isi[f]) frame_cell = frame[(frame['cell'] == cell)] # todo: hier mit dem EODfr nochmal schauen if len(stack_spikes) > 0: spikes = [] hists = [] for sp in range(len(stack_spikes[f].keys())): spikes.append(np.array(stack_spikes[f][sp])) hists.append(np.diff(spikes[-1]) / (1 / eod_frs[f])) else: spikes = frame_cell.spikes.iloc[0] spikes_all, 
hists, frs_calc, cont_spikes = load_spikes(spikes, eod_frs[f]) remove_yticks(axi) axi.spines['left'].set_visible(False) alpha = 1 for hh, h in enumerate(hists): try: axi.hist(h, bins=100, color='blue', alpha=float(alpha - 0.05 * hh)) except: print('hist i') embed() ax_isi.append(axi) axi.spines['left'].set_visible(False) remove_yticks(axi) if f == len(cells_all) - 1: axi.set_xlabel('EODf multiple') return ax_isi def group_the_certain_group(grouped, DF2_desired, DF1_desired): mult1 = np.array([a_tuple[2][0] for a_tuple in grouped.groups.keys()]) mult2 = np.array([a_tuple[2][1] for a_tuple in grouped.groups.keys()]) mult_array = np.round(np.abs(mult1 - DF1_desired) + np.abs((mult2 - DF2_desired)), 2) restrict = np.argmin(mult_array) return restrict def extract_waves(variant, cell, stimulus_length, deltat, eod_fr, a_fr, a_fe, eod_fe, e, eod_fj, a_fj, phase_r=0, nfft_for_morph=4068 * 4, phase_e=0): if 'receiver' in variant: time, time_fish_r, eod_fish_r, ff_first, eod_fr_data_first, pp_first_not_log, eod_fish_r_first, p_array_new_first, f_new_first = load_waves( nfft_for_morph, cell, a_fr=a_fr, stimulus_length=stimulus_length, sampling=1 / deltat, eod_fr=eod_fr) else: time = np.arange(0, stimulus_length, deltat) time_fish_r = time * 2 * np.pi * eod_fr eod_fish_r = a_fr * np.sin(time_fish_r + phase_r) if 'emitter' in variant: time, time_fish_e, eod_fish_e, ff_first, eod_fr_data_first, pp_first_not_log, eod_fish_r_first, p_array_new_first, f_new_first = load_waves( nfft_for_morph, cell, a_fr=a_fe, stimulus_length=stimulus_length, sampling=1 / deltat, eod_fr=eod_fe[e]) else: time = np.arange(0, stimulus_length, deltat) time_fish_e = time * 2 * np.pi * eod_fe[e] eod_fish_e = a_fe * np.sin(time_fish_e + phase_e) if 'jammer' in variant: time, time_fish_e, eod_fish_e, ff_first, eod_fr_data_first, pp_first_not_log, eod_fish_r_first, p_array_new_first, f_new_first = load_waves( nfft_for_morph, cell, a_fr=a_fj, stimulus_length=stimulus_length, sampling=1 / deltat, eod_fr=eod_fj) 
else: time = np.arange(0, stimulus_length, deltat) time_fish_j = time * 2 * np.pi * eod_fj eod_fish_j = a_fj * np.sin(time_fish_j + phase_e) time_fish_sam = time * 2 * np.pi * (eod_fe[e] - eod_fr) eod_fish_sam = a_fe * np.sin(time_fish_sam) stimulus_am = eod_fish_e + eod_fish_r + eod_fish_j stimulus_sam = eod_fish_r * (1 + eod_fish_sam) return eod_fish_j, time, time_fish_r, eod_fish_r, time_fish_e, eod_fish_e, time_fish_sam, eod_fish_sam, stimulus_am, stimulus_sam def plot_shemes3(eod_fish_r, eod_fish_e, eod_fish_j, grid0, time, g=0, waves_present=['receiver', 'emitter', 'jammer', 'all'], sheme_shift=0, title=[]): stimulus = np.zeros(len(eod_fish_r)) ax = [] xlim = [0, 0.05] for ww, w in enumerate(waves_present): ax = plt.subplot(grid0[ww + sheme_shift, g]) if w == 'receiver': if title: plt.title(title) plt.plot(time, eod_fish_r, color='grey') stimulus += eod_fish_r plt.ylim(-1.1, 1.1) plt.xlim(xlim) ax.spines['bottom'].set_visible(False) if g == 0: plt.ylabel('f0') elif w == 'emitter': ax.text(0.5, 1.01, '$+$', va='center', ha='center', transform=ax.transAxes, fontsize=20) plt.plot(time, eod_fish_e, color='orange') stimulus += eod_fish_e plt.ylim(-1.1, 1.1) plt.xlim(xlim) ax.spines['bottom'].set_visible(False) if g == 0: plt.ylabel('f1') elif w == 'jammer': ax.text(0.5, 1.01, '$+$', va='center', ha='center', transform=ax.transAxes, fontsize=20) plt.plot(time, eod_fish_j, color='purple') stimulus += eod_fish_j plt.ylim(-1.1, 1.1) plt.xlim(xlim) if g == 0: plt.ylabel('f2') elif w == 'all': ax.text(0.5, 1.25, '$=$', va='center', ha='center', transform=ax.transAxes, fontsize=20) plt.plot(time, stimulus, color='grey') plt.ylim(-1.2, 1.2) plt.xlim(xlim) if g == 0: plt.ylabel('Stimulus') if g == 0: if ww == 0: plt.ylabel('f0') elif ww == 1: plt.ylabel('f1') elif ww == 2: plt.ylabel('f2') elif ww == 3: plt.ylabel('Stimulus') ax.show_spines('') ax.set_xticks([]) ax.set_yticks([]) return ax def plot_shemes4(eod_fish_r, eod_fish_e, eod_fish_j, grid0, time, ylim=[-1.1, 1.1], 
g=0, waves_present=['receiver', 'emitter', 'jammer', 'all'], eod_fr=700, xlim=[0, 0.05], color_am2='purple', extracted=False, extracted2=False, color_am='black', title_top=False, title=[]): stimulus = np.zeros(len(eod_fish_r)) ax = plt.subplot(grid0[0, g]) for ww, w in enumerate(waves_present): if title_top: if ww == 0: # ok das ist alles anders zentriert, wie ich denke # ha bedeutet dass das alignment horizontal ist nicht der anker # und auch rechts und bottom ist genau anders herum # EGOZENTRISCHE und nicht ALOZENTRISCHE Ausrichtung!! # todo: vielleicht eine funktion die das in allozentrisch ändert, das kann sich doch keiner merken ax.text(1, 1, title, va='bottom', ha='right', transform=ax.transAxes) if w == 'receiver': stimulus += eod_fish_r elif w == 'emitter': stimulus += eod_fish_e elif w == 'jammer': stimulus += eod_fish_j elif w == 'all': eod_interp, eod_norm = extract_am(stimulus, time, norm=False, sampling=1 / time[1], eodf=eod_fr, emb=False, extract='') plt.plot(time, stimulus, color='grey', linewidth=0.5) if extracted: # plt.plot(time, eod_interp, color=color_am, linewidth=1) if extracted2: # eod_interp2, eod_norm = extract_am(eod_interp, time, norm=False, sampling=1 / time[1], eodf=eod_fr, emb=False, extract='') test = False if test: nfft = 2 ** 16 _, _ = ml.psd(eod_interp2 - np.mean(eod_interp2), Fs=40000, NFFT=nfft, noverlap=nfft // 2) # _, _ = ml.psd(eod_interp - np.mean(eod_interp), Fs=40000, NFFT=nfft, noverlap=nfft // 2) # plt.plot(time, eod_interp2, color=color_am2, linewidth=1) plt.ylim(-1.2, 1.2) if len(xlim) > 0: plt.xlim(xlim) plt.ylim(ylim) if g == 0: plt.ylabel('stimulus') if g == 0: if ww == 3: plt.ylabel('stimulus') ax.show_spines('') ax.set_xticks([]) ax.set_yticks([]) return ax def motivation_small_roc(ylim=[-1.25, 1.25], c1=10, dfs=['m1', 'm2'], mult_type='_multsorted2_', top=0.94, devs=['2'], figsize=None, end='0', cut_matrix='malefemale', chose_score='mean_nrs', detections=['AllTrialsIndex'], sorted_on='LocalReconst0.2Norm'): 
plot_style() default_settings(column=2, length=3.7) # 3.3ts=12, ls=12, fs=12 show = True # mean_type = '_MeanTrialsIndexPhaseSort_Min0.25sExcluded_' datasets, data_dir = find_all_dir_cells() # '2022-01-27-ab-invivo-1', ] # ,'2022-01-28-ah-invivo-1', '2022-01-28-af-invivo-1', ] autodefine = '_dfchosen_closest_first_' cells = ['2021-08-03-ac-invivo-1'] ##'2021-08-03-ad-invivo-1',,[10, ][5 ] # c1s = [10] # 1, 10, # c2s = [10] c2 = 10 # detections = ['MeanTrialsIndexPhaseSort'] # ['AllTrialsIndex'] # ,'MeanTrialsIndexPhaseSort''DetectionAnalysis''_MeanTrialsPhaseSort' # detections = ['AllTrialsIndex'] # ['_MeanTrialsIndexPhaseSort_Min0.25sExcluded_extended_eod_loc_synch'] # phase_sorting = ''#'PhaseSort' eodftype = '_psdEOD_' indices = ['_allindices_'] chirps = [ ''] # '_ChirpsDelete3_',,'_ChirpsDelete3_'','','',''#'_ChirpsDelete3_'#''#'_ChirpsDelete3_'#'#'_ChirpsDelete2_'#''#'_ChirpsDelete_'#''#'_ChirpsDelete_'#''#'_ChirpsDelete_'#''#'_ChirpsCache_' extract = '' # '_globalmax_' if len(cells) < 1: data_dir, cells = load_cells_three(end, data_dir=data_dir, datasets=datasets) _, _, _ = restrict_cell_type(cells, 'p-units') start = 'min' # cells = ['2022-01-28-ah-invivo-1'] DF2_desired = [-175] DF1_desired = [-99] for c, cell in enumerate(cells): if not c2: contrasts = [10, 5, 3, 1] else: contrasts = [c2] if not c2: contrasts = [10, 5, 3, 1] else: contrasts1 = [c1] for c, contrast in enumerate(contrasts): contrast_small = 'c2' contrast_big = 'c1' for contrast1 in contrasts1: for devname_orig in devs: datapoints = [1000] for _ in datapoints: ################################ # prepare DF1 desired # chose_score = 'auci02_012-auci_base_01' # hier muss das halt stimmen mit der auswahl # hier wollen wir eigntlich kein autodefine # sondern wir wollen so ein diagonal ding haben divergnce, fr, pivot_chosen, max_val, max_x, max_y, mult, DF1_desired, DF2_desired, min_y, min_x, min_val, diff_cut = chose_mat_max_value( DF1_desired, DF2_desired, '', mult_type, eodftype, indices, cell, 
contrast_small, contrast_big, contrast1, dfs, start, devname_orig, contrast, autodefine=autodefine, cut_matrix='cut', chose_score=chose_score) # chose_score = 'auci02_012-auci_base_01' DF1_desired = DF1_desired # [::-1] DF2_desired = DF2_desired # [::-1] # ROC part b = load_b_public(c, cell, data_dir) mt_sorted = predefine_grouping_frame(b, eodftype=eodftype, cell_name=cell) mt_sorted = mt_sorted[(mt_sorted['c2'] == c2) & (mt_sorted['c1'] == c1)] for gg in range(len(DF1_desired)): DF1_desired_ROC = [DF1_desired[gg]] DF2_desired_ROC = [DF2_desired[gg]] t3 = time.time() # all trials in one grouped = mt_sorted.groupby( ['c1', 'c2', 'm1, m2'], as_index=False) grouped_mean = chose_certain_group(DF1_desired[gg], DF2_desired[gg], grouped, several=True, emb=False, concat=True) # groups sorted by repro tag # todo: evnetuell die tuples gleich hier umspeichern vom csv '' grouped = mt_sorted.groupby( ['c1', 'c2', 'm1, m2', 'repro_tag_id'], as_index=False) grouped_orig = chose_certain_group(DF1_desired[gg], DF2_desired[gg], grouped, several=True) group_mean = [grouped_orig[0][0], grouped_mean] for d, detection in enumerate(detections): mean_type = '_' + detection # + '_' + minsetting + '_' + extend_trials + concat if figsize: fig = plt.figure(figsize=figsize) else: fig = plt.figure() grid = gridspec.GridSpec(1, 3, wspace=0.35, hspace=0.5, left=0.05, top=top, bottom=0.14, right=0.95, width_ratios=[4.2, 1, 1]) # height_ratios = [1,6]bottom=0.25, top=0.8, grid0 = gridspec.GridSpecFromSubplotSpec(3, 1, wspace=0.15, hspace=0.06, subplot_spec=grid[0], height_ratios=[0.4, 3, 3]) # height_ratios=hr, grid_sheme = gridspec.GridSpecFromSubplotSpec(1, 4, wspace=0.15, hspace=0.35, subplot_spec=grid0[0]) xlim = [0, 100] fr_end = divergence_title_add_on(group_mean, fr[gg], autodefine) ########################################### stimulus_length = 0.3 deltat = 1 / 40000 eodf = np.mean(group_mean[1].eodf) eod_fr = eodf a_fr = 1 eod_fe = eodf + np.mean( group_mean[1].DF2) # data.eodf.iloc[0] + 10 
# cell_model.eode.iloc[0] a_fe = group_mean[0][1] / 100 eod_fj = eodf + np.mean( group_mean[1].DF1) # data.eodf.iloc[0] + 50 # cell_model.eodj.iloc[0] a_fj = group_mean[0][0] / 100 variant_cell = 'no' # 'receiver_emitter_jammer' eod_fish_j, time_array, time_fish_r, eod_fish_r, time_fish_e, eod_fish_e, time_fish_sam, eod_fish_sam, stimulus_am, stimulus_sam = extract_waves( variant_cell, '', stimulus_length, deltat, eod_fr, a_fr, a_fe, [eod_fe], 0, eod_fj, a_fj) jammer_name = 'female' titles = ['receiver ', '+' + 'intruder ', '+' + jammer_name, '+' + jammer_name + '+intruder', []] ##'receiver + ' + 'receiver + receiver gs = [0, 1, 2, 3, 4] waves_presents = [['receiver', '', '', 'all'], ['receiver', 'emitter', '', 'all'], ['receiver', '', 'jammer', 'all'], ['receiver', 'emitter', 'jammer', 'all'], ] # ['', '', '', ''],['receiver', '', '', 'all'], symbols = ['', '', '', '', ''] time_array = time_array * 1000 color0_burst = 'darkgreen' color01 = 'blue' color02 = 'red' color012 = 'orange' colors_am = ['black', 'black', 'black', 'black'] # color01, color02, color012] extracted = [False, True, True, True] ax_w = [] for i in range(len(waves_presents)): ax = plot_shemes4(eod_fish_r, eod_fish_e, eod_fish_j, grid_sheme, time_array, g=gs[i], title_top=True, eod_fr=eod_fr, waves_present=waves_presents[i], ylim=ylim, xlim=xlim, color_am=colors_am[i], extracted=extracted[i], title=titles[i]) # 'intruder','receiver'#jammer_name ax_w.append(ax) if ax != []: ax.text(1.1, 0.45, symbols[i], fontsize=35, transform=ax.transAxes) bar = False if bar: if i == 0: ax.plot([0, 20], [ylim[0] + 0.01, ylim[0] + 0.01], color='black') ax.text(0, -0.16, '20 ms', va='center', fontsize=10, transform=ax.transAxes) printing = True if printing: print('time of arrays plotting: ' + str(time.time() - t3)) ########################################## # spike response means_here = ['_MeanTrialsIndexPhaseSort', 'AllTrialsIndex'] array_chosen = 1 for m, mean_type in enumerate(means_here): hr = [0.35, 1.2, 0, 3] 
grid_psd = gridspec.GridSpecFromSubplotSpec(4, 4, wspace=0.15, hspace=0.35, subplot_spec=grid0[m + 1], height_ratios=hr, ) if d == 0: # ############################################################## # load plotting arrays arrays, arrays_original, spikes_pure = save_arrays_susept( data_dir, cell, c, chirps, devs, extract, group_mean, mean_type, plot_group=0, rocextra=False, sorted_on=sorted_on) fr_isi, ax_ps, ax_as = plot_arrays_ROC_psd_single3( [arrays[0], arrays[2], arrays[1], arrays[3]], [arrays_original[0], arrays_original[2], arrays_original[1], arrays_original[3]], spikes_pure, cell, grid_psd, mean_type, group_mean, xlim=xlim, row=d * 3, array_chosen=array_chosen, ylim_log=(-50.5, 3), color0_burst=color0_burst, xlim_psd=[0, 550], color01=color01, color02=color02, color012=color012, add_burst_corr=True, log='') ################################################################### nrs = [1, 2] for n, nr in enumerate(nrs): grid2 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.3, hspace=0.4, subplot_spec=grid[nr], height_ratios=[1, 1]) subdevision_nr = 3 dev = '05' datapoints_way = ['absolut'] color = ['red', 'green', 'lightblue', 'pink', ] fig = plt.gcf() plot_group = 0 ranges = [plot_group] _, _, _, _, _ = roc_part( titles, devs, group_mean, ranges, fig, subdevision_nr, datapoints, datapoints_way, color, c, chose_score, cell, DF1_desired_ROC, DF2_desired_ROC, contrast_small, contrast_big, contrast1, dfs, start, dev, contrast, grid2[0], plot_group, autodefine2='_dfchosen_', sorted_on=sorted_on, cut_matrix=cut_matrix, mean_type=means_here[n], extract=extract, mult_type=mult_type, eodftype=eodftype, indices=indices, c1=c1, c2=c2, autodefine=autodefine) ax = plt.gca() ax.set_title(means_here[n]) plot_group = 1 ranges = [plot_group] _, _, _, _, _ = roc_part( titles, devs, group_mean, ranges, fig, subdevision_nr, datapoints, datapoints_way, color, c, chose_score, cell, DF1_desired_ROC, DF2_desired_ROC, contrast_small, contrast_big, contrast1, dfs, start, dev, 
contrast, grid2[1], plot_group, sorted_on=sorted_on, autodefine2='_dfchosen_', cut_matrix=cut_matrix, mean_type=means_here[n], extract=extract, mult_type=mult_type, eodftype=eodftype, indices=indices, c1=c1, c2=c2, autodefine=autodefine) suptitle = cell + ' c1: ' + str(group_mean[0][0]) + '$\%$ m1: ' + str( group_mean[0][2][0]) + ' DF1: ' + str( group_mean[1]['DF1, DF2'].iloc[0][0]) + ' c2: ' + str( group_mean[0][1]) + '$\%$ m2: ' + str( group_mean[0][2][1]) + ' DF2: ' + str( group_mean[1]['DF1, DF2'].iloc[0][1]) + ' Trials nr ' + str( len(group_mean[1])) + fr_end individual_tag = 'DF1' + str(DF1_desired[gg]) + 'DF2' + str( DF2_desired[gg]) + cell + '_c1_' + str(c1) + '_c2_' + str(c2) + mean_type axes = [] axes.append(ax_w) axes.extend(np.transpose(ax_as)) axes.append(np.transpose(ax_ps)) fig.tag(ax_w, xoffs=-1.5, yoffs=1.4) save_visualization(individual_tag=individual_tag, show=show, pdf=True) return suptitle def load_b_public(c, cell, data_dir): version_comp, subfolder, mod_name_slash, mod_name, subfolder_path = find_code_vs_not() if version_comp != 'public': full_path = find_nix_full_path(c, cell, data_dir) if os.path.exists(full_path): # todo: this maybe also has to be fixed print('do ' + cell) file = nix.File.open(full_path, nix.FileMode.ReadOnly) b = file.blocks[0] else: b = [] else: b = [] return b def motivation_all(ylim=[-1.25, 1.25], c1=10, dfs=['m1', 'm2'], mult_type='_multsorted2_', top=0.94, devs=['2'], figsize=None, save=True, end='0', chose_score='mean_nrs', detections=['AllTrialsIndex'], sorted_on='LocalReconst0.2Norm'): plot_style() default_settings(column=2, length=6.7) # 3.3ts=12, ls=12, fs=12 show = True # mean_type = '_MeanTrialsIndexPhaseSort_Min0.25sExcluded_' datasets, data_dir = find_all_dir_cells() DF2_desired = [-33] DF1_desired = [133] autodefine = '_dfchosen_closest_first_' cells = ['2021-08-03-ac-invivo-1'] ##'2021-08-03-ad-invivo-1',,[10, ][5 ] # c1s = [10] # 1, 10, # c2s = [10] c2 = 10 # detections = ['MeanTrialsIndexPhaseSort'] # 
['AllTrialsIndex'] # ,'MeanTrialsIndexPhaseSort''DetectionAnalysis''_MeanTrialsPhaseSort' # detections = ['AllTrialsIndex'] # ['_MeanTrialsIndexPhaseSort_Min0.25sExcluded_extended_eod_loc_synch'] # phase_sorting = ''#'PhaseSort' eodftype = '_psdEOD_' indices = ['_allindices_'] chirps = [ ''] # '_ChirpsDelete3_',,'_ChirpsDelete3_'','','',''#'_ChirpsDelete3_'#''#'_ChirpsDelete3_'#'#'_ChirpsDelete2_'#''#'_ChirpsDelete_'#''#'_ChirpsDelete_'#''#'_ChirpsDelete_'#''#'_ChirpsCache_' extract = '' # '_globalmax_' if len(cells) < 1: data_dir, cells = load_cells_three(end, data_dir=data_dir, datasets=datasets) _, _, _ = restrict_cell_type(cells, 'p-units') start = 'min' # cells = ['2021-08-03-ac-invivo-1'] for c, cell in enumerate(cells): contrasts = [c2] for c, contrast in enumerate(contrasts): contrast_small = 'c2' contrast_big = 'c1' contrasts1 = [c1] for contrast1 in contrasts1: for devname_orig in devs: datapoints = [1000] for _ in datapoints: ################################ # prepare DF1 desired # chose_score = 'auci02_012-auci_base_01' # hier muss das halt stimmen mit der auswahl # hier wollen wir eigntlich kein autodefine # sondern wir wollen so ein diagonal ding haben divergnce, fr, pivot_chosen, max_val, max_x, max_y, mult, DF1_desired, DF2_desired, min_y, min_x, min_val, diff_cut = chose_mat_max_value( DF1_desired, DF2_desired, '', mult_type, eodftype, indices, cell, contrast_small, contrast_big, contrast1, dfs, start, devname_orig, contrast, autodefine=autodefine, cut_matrix='cut', chose_score=chose_score) # chose_score = 'auci02_012-auci_base_01' DF1_desired = DF1_desired # [::-1] DF2_desired = DF2_desired # [::-1] # embed() ####################################### # ROC part _, _, _, _, _ = find_code_vs_not() b = load_b_public(c, cell, data_dir) mt_sorted = predefine_grouping_frame(b, eodftype=eodftype, cell_name=cell) mt_sorted = mt_sorted[(mt_sorted['c2'] == c2) & (mt_sorted['c1'] == c1)] for gg in range(len(DF1_desired)): t3 = time.time() ax_w = [] 
################### # all trials in one grouped = mt_sorted.groupby( ['c1', 'c2', 'm1, m2'], as_index=False) grouped_mean = chose_certain_group(DF1_desired[gg], DF2_desired[gg], grouped, several=True, emb=False, concat=True) # groups sorted by repro tag # todo: evnetuell die tuples gleich hier umspeichern vom csv '' grouped = mt_sorted.groupby( ['c1', 'c2', 'm1, m2', 'repro_tag_id'], as_index=False) grouped_orig = chose_certain_group(DF1_desired[gg], DF2_desired[gg], grouped, several=True) ################### group_mean = [grouped_orig[0][0], grouped_mean] for d, detection in enumerate(detections): mean_type = '_' + detection # + '_' + minsetting + '_' + extend_trials + concat arrays, arrays_original, spikes_pure = save_arrays_susept( data_dir, cell, c, chirps, devs, extract, group_mean, mean_type, plot_group=0, rocextra=False, sorted_on=sorted_on) # hier checken wir ob für diesen einen Punkt das funkioniert mit der standardabweichung try: check_var_substract_method(spikes_pure) except: print('var checking not possible') if figsize: fig = plt.figure(figsize=figsize) else: fig = plt.figure() grid = gridspec.GridSpec(2, 3, wspace=0.7, hspace=0.35, left=0.075, top=top, bottom=0.1, height_ratios=[1, 2], right=0.935) # height_ratios = [1,6]bottom=0.25, top=0.8, hr = [1, 0.35, 1.2, 0, 3, ] # 1 # several coherence plot ax_w, d, data_dir, devs = plt_coherences(ax_w, d, devs, grid) # part with the power spectra grid0 = gridspec.GridSpecFromSubplotSpec(5, 4, wspace=0.15, hspace=0.35, subplot_spec=grid[1, :], height_ratios=hr) xlim = [0, 100] fr_end = divergence_title_add_on(group_mean, fr[gg], autodefine) ########################################### stimulus_length = 0.3 deltat = 1 / 40000 eodf = np.mean(group_mean[1].eodf) eod_fr = eodf a_fr = 1 eod_fe = eodf + np.mean( group_mean[1].DF2) # data.eodf.iloc[0] + 10 # cell_model.eode.iloc[0] a_fe = group_mean[0][1] / 100 eod_fj = eodf + np.mean( group_mean[1].DF1) # data.eodf.iloc[0] + 50 # cell_model.eodj.iloc[0] a_fj = 
group_mean[0][0] / 100 variant_cell = 'no' # 'receiver_emitter_jammer' eod_fish_j, time_array, time_fish_r, eod_fish_r, time_fish_e, eod_fish_e, time_fish_sam, eod_fish_sam, stimulus_am, stimulus_sam = extract_waves( variant_cell, '', stimulus_length, deltat, eod_fr, a_fr, a_fe, [eod_fe], 0, eod_fj, a_fj) jammer_name = 'female' cocktail_names = False if cocktail_names: titles = ['receiver ', '+' + 'intruder ', '+' + jammer_name, '+' + jammer_name + '+intruder', []] ##'receiver + ' + 'receiver + receiver else: titles = title_motivation() ##'receiver + ' + 'receiver + receiver gs = [0, 1, 2, 3, 4] waves_presents = [['receiver', '', '', 'all'], ['receiver', 'emitter', '', 'all'], ['receiver', '', 'jammer', 'all'], ['receiver', 'emitter', 'jammer', 'all'], ] # ['', '', '', ''],['receiver', '', '', 'all'], symbols = ['', '', '', '', ''] time_array = time_array * 1000 color0_burst = 'darkgreen' color01 = 'green' color02 = 'red' color012 = 'orange' colors_am = ['black', 'black', 'black', 'black'] # color01, color02, color012] extracted = [False, True, True, True] for i in range(len(waves_presents)): ax = plot_shemes4(eod_fish_r, eod_fish_e, eod_fish_j, grid0, time_array, g=gs[i], title_top=True, eod_fr=eod_fr, waves_present=waves_presents[i], ylim=ylim, xlim=xlim, color_am=colors_am[i], extracted=extracted[i], title=titles[i]) # 'intruder','receiver'#jammer_name ax_w.append(ax) if ax != []: ax.text(1.1, 0.45, symbols[i], fontsize=35, transform=ax.transAxes) bar = False if bar: if i == 0: ax.plot([0, 20], [ylim[0] + 0.01, ylim[0] + 0.01], color='black') ax.text(0, -0.16, '20 ms', va='center', fontsize=10, transform=ax.transAxes) printing = True if printing: print(time.time() - t3) # spike response array_chosen = 1 if d == 0: # _, _, _ = plot_arrays_ROC_psd_single3( [arrays[0], arrays[2], arrays[1], arrays[3]], [arrays_original[0], arrays_original[2], arrays_original[1], arrays_original[3]], spikes_pure, cell, grid0, mean_type, group_mean, xlim=xlim, row=1 + d * 3, 
array_chosen=array_chosen, color0_burst=color0_burst, color01=color01, color02=color02, color012=color012) suptitle = cell + ' c1: ' + str(group_mean[0][0]) + '$\%$ m1: ' + str( group_mean[0][2][0]) + ' DF1: ' + str( group_mean[1]['DF1, DF2'].iloc[0][0]) + ' c2: ' + str( group_mean[0][1]) + '$\%$ m2: ' + str( group_mean[0][2][1]) + ' DF2: ' + str( group_mean[1]['DF1, DF2'].iloc[0][1]) + ' Trials nr ' + str( len(group_mean[1])) + fr_end individual_tag = 'DF1' + str(DF1_desired[gg]) + 'DF2' + str( DF2_desired[gg]) + cell + '_c1_' + str(c1) + '_c2_' + str(c2) + mean_type axes = [] axes.append(ax_w) fig.tag(ax_w[0:3], xoffs=-2.3, yoffs=1.7) fig.tag(ax_w[3::], xoffs=-1.9, yoffs=1.4) if save: save_visualization(individual_tag=individual_tag, show=show, pdf=True) return suptitle def check_var_substract_method(spikes_pure): vars = {} for k, key in enumerate(spikes_pure.keys()): for j in range(len(spikes_pure[key])): spikes_mat = cr_spikes_mat(spikes_pure[key][j] / 1000, 40000, int(spikes_pure[key][j][-1] / 1000 * 40000)) # len(arrays[k][j]) smoothed = gaussian_filter(spikes_mat, sigma=0.0005 * 40000) if key not in vars: vars[key] = [np.var(smoothed)] else: vars[key].append(np.var(smoothed)) var_vals = [] for j in range(len(spikes_pure[key])): var_vals.append(vars['012'][j] - vars['control_01'][j] - vars['control_02'][j] + vars['base_0'][j]) # ja wenn das stabil wäre wäre das in Ordnung aber so weiß nciht print('single var vals:' + str(var_vals)) print('mean of single var vals:' + str(np.mean(var_vals))) def plt_coherences(ax_w, d, devs, grid): data_names, data_dir = find_all_dir_cells() cell_here = ['2021-08-03-ab-invivo-1'] # cell cell_here.extend(data_names) data_names = ['2021-08-03-ab-invivo-1'] for data_name in data_names: frame = load_coherence_file(data_name, '05') if len(frame) > 0: amps = np.sort(frame.amp.unique())[::-1] file_names = frame.file_name.unique() devs = ['05'] # original for a, amp in enumerate(amps): for file_name in file_names: ax = 
plt.subplot(grid[0, a]) ax.set_ylim(0, 1) for d, dev in enumerate(devs): frame = load_coherence_file(data_name, dev) if len(frame) > 0: frame_cell = frame[ (frame.file_name == file_name) & (frame.amp == amp)] names = ['coherence_s', 'coherence_r', 'coherence_r_exp' ] # 'coherence_r_direct_restrict', labels = ['SR', '$\sqrt{RR}$', 'RR$_{exp}$', ] colors = ['black', 'grey', 'brown'] # 'coherence_r_firstsnippet', linestyles = ['-', '-', '--', '-', '--', '-', '--'] # 'purple','-', for n, name in enumerate(names): if 'coherence_s' in name: ax.plot(frame_cell['f'], frame_cell[name] ** 2, label=labels[n], color=colors[n], linestyle=linestyles[ n]) # , 'MI_r_direct', 'coherence_r_direct_restrict', else: ax.plot(frame_cell['f'], frame_cell[name], label=labels[n], color=colors[n], linestyle=linestyles[ n]) # , 'MI_r_direct', 'coherence_r_direct_restrict', if amp < 1: amp_name = amp else: amp_name = int(amp) ax.set_title('Contrast=' + str(amp_name)) if a == 0: ax.legend(loc=(0.75, 0.75)) ax.set_ylabel('Coherence') ax.set_xlabel('Frequency [Hz]') ax.set_ylabel('Coherence') xlim = ax.get_xlim() ax.set_xlim(0, xlim[-1]) ax_w.append(ax) return ax_w, d, data_dir, devs def load_coherence_file(data_name, dev): save_name = load_folder_name( 'calc_RAM') + '/calc_coherence-coherence__cell_' + data_name + '_dev_' + dev + '.csv' load_function = find_load_function() name1 = load_function + save_name.split('/')[-1] if not os.path.exists(name1): frame = pd.read_csv(save_name, index_col=0) frame.to_csv(name1) frame = pd.read_csv(save_name, index_col=0) else: frame = pd.read_csv(name1, index_col=0) return frame def motivation_small(ylim=[-1.25, 1.25], c1=10, dfs=['m1', 'm2'], mult_type='_multsorted2_', top=0.94, devs=['2'], figsize=None, save=True, end='0', chose_score='mean_nrs', detections=['AllTrialsIndex'], sorted_on='LocalReconst0.2Norm'): plot_style() default_settings(column=2, length=3.5) # 3.3ts=12, ls=12, fs=12 show = True datasets, data_dir = find_all_dir_cells() # 
'2022-01-27-ab-invivo-1', ] # ,'2022-01-28-ah-invivo-1', '2022-01-28-af-invivo-1', ] DF2_desired = [-33] DF1_desired = [133] autodefine = '_dfchosen_closest_first_' cells = ['2021-08-03-ac-invivo-1'] ##'2021-08-03-ad-invivo-1',,[10, ][5 ] c2 = 10 # phase_sorting = ''#'PhaseSort' eodftype = '_psdEOD_' indices = ['_allindices_'] chirps = [ ''] # '_ChirpsDelete3_',,'_ChirpsDelete3_'','','',''#'_ChirpsDelete3_'#''#'_ChirpsDelete3_'#'#'_ChirpsDelete2_'#''#'_ChirpsDelete_'#''#'_ChirpsDelete_'#''#'_ChirpsDelete_'#''#'_ChirpsCache_' extract = '' # '_globalmax_' if len(cells) < 1: data_dir, cells = load_cells_three(end, data_dir=data_dir, datasets=datasets) _, _, _ = restrict_cell_type(cells, 'p-units') start = 'min' # cells = ['2021-08-03-ac-invivo-1'] for c, cell in enumerate(cells): contrasts = [c2] for c, contrast in enumerate(contrasts): contrast_small = 'c2' contrast_big = 'c1' contrasts1 = [c1] for contrast1 in contrasts1: for devname_orig in devs: datapoints = [1000] for _ in datapoints: ################################ # prepare DF1 desired # chose_score = 'auci02_012-auci_base_01' # hier muss das halt stimmen mit der auswahl # hier wollen wir eigntlich kein autodefine # sondern wir wollen so ein diagonal ding haben divergnce, fr, pivot_chosen, max_val, max_x, max_y, mult, DF1_desired, DF2_desired, min_y, min_x, min_val, diff_cut = chose_mat_max_value( DF1_desired, DF2_desired, '', mult_type, eodftype, indices, cell, contrast_small, contrast_big, contrast1, dfs, start, devname_orig, contrast, autodefine=autodefine, cut_matrix='cut', chose_score=chose_score) # chose_score = 'auci02_012-auci_base_01' DF1_desired = DF1_desired # [::-1] DF2_desired = DF2_desired # [::-1] # ROC part _, _, _, _, _ = find_code_vs_not() b = load_b_public(c, cell, data_dir) mt_sorted = predefine_grouping_frame(b, eodftype=eodftype, cell_name=cell) mt_sorted = mt_sorted[(mt_sorted['c2'] == c2) & (mt_sorted['c1'] == c1)] for gg in range(len(DF1_desired)): t3 = time.time() # all trials in one 
grouped = mt_sorted.groupby( ['c1', 'c2', 'm1, m2'], as_index=False) grouped_mean = chose_certain_group(DF1_desired[gg], DF2_desired[gg], grouped, several=True, emb=False, concat=True) ################### # groups sorted by repro tag # todo: evnetuell die tuples gleich hier umspeichern vom csv '' grouped = mt_sorted.groupby( ['c1', 'c2', 'm1, m2', 'repro_tag_id'], as_index=False) grouped_orig = chose_certain_group(DF1_desired[gg], DF2_desired[gg], grouped, several=True) ################### group_mean = [grouped_orig[0][0], grouped_mean] for d, detection in enumerate(detections): mean_type = '_' + detection # + '_' + minsetting + '_' + extend_trials + concat ############################################################## # load plotting arrays arrays, arrays_original, spikes_pure = save_arrays_susept( data_dir, cell, c, chirps, devs, extract, group_mean, mean_type, plot_group=0, rocextra=False, sorted_on=sorted_on) #################################################### if figsize: fig = plt.figure(figsize=figsize) else: fig = plt.figure() grid = gridspec.GridSpec(1, 1, wspace=0.7, hspace=0.5, left=0.05, top=top, bottom=0.14, right=0.95) # height_ratios = [1,6]bottom=0.25, top=0.8, hr = [1, 0.35, 1.2, 0, 3, ] # 1 grid0 = gridspec.GridSpecFromSubplotSpec(5, 4, wspace=0.15, hspace=0.35, subplot_spec=grid[0], height_ratios=hr, ) xlim = [0, 100] fr_end = divergence_title_add_on(group_mean, fr[gg], autodefine) ########################################### stimulus_length = 0.3 deltat = 1 / 40000 eodf = np.mean(group_mean[1].eodf) eod_fr = eodf a_fr = 1 eod_fe = eodf + np.mean( group_mean[1].DF2) # data.eodf.iloc[0] + 10 # cell_model.eode.iloc[0] a_fe = group_mean[0][1] / 100 eod_fj = eodf + np.mean( group_mean[1].DF1) # data.eodf.iloc[0] + 50 # cell_model.eodj.iloc[0] a_fj = group_mean[0][0] / 100 variant_cell = 'no' # 'receiver_emitter_jammer' eod_fish_j, time_array, time_fish_r, eod_fish_r, time_fish_e, eod_fish_e, time_fish_sam, eod_fish_sam, stimulus_am, stimulus_sam = 
extract_waves( variant_cell, '', stimulus_length, deltat, eod_fr, a_fr, a_fe, [eod_fe], 0, eod_fj, a_fj) jammer_name = 'female' titles = ['receiver ', '+' + 'intruder ', '+' + jammer_name, '+' + jammer_name + '+intruder', []] ##'receiver + ' + 'receiver + receiver gs = [0, 1, 2, 3, 4] waves_presents = [['receiver', '', '', 'all'], ['receiver', 'emitter', '', 'all'], ['receiver', '', 'jammer', 'all'], ['receiver', 'emitter', 'jammer', 'all'], ] # ['', '', '', ''],['receiver', '', '', 'all'], symbols = ['', '', '', '', ''] time_array = time_array * 1000 color0_burst = 'darkgreen' color01 = 'green' color02 = 'red' color012 = 'orange' colors_am = ['black', 'black', 'black', 'black'] # color01, color02, color012] extracted = [False, True, True, True] ax_w = [] for i in range(len(waves_presents)): ax = plot_shemes4(eod_fish_r, eod_fish_e, eod_fish_j, grid0, time_array, g=gs[i], title_top=True, eod_fr=eod_fr, waves_present=waves_presents[i], ylim=ylim, xlim=xlim, color_am=colors_am[i], extracted=extracted[i], title=titles[i]) # 'intruder','receiver'#jammer_name ax_w.append(ax) if ax != []: ax.text(1.1, 0.45, symbols[i], fontsize=35, transform=ax.transAxes) bar = False if bar: if i == 0: ax.plot([0, 20], [ylim[0] + 0.01, ylim[0] + 0.01], color='black') ax.text(0, -0.16, '20 ms', va='center', fontsize=10, transform=ax.transAxes) printing = True if printing: print(time.time() - t3) # spike response array_chosen = 1 if d == 0: # fr_isi, ax_ps, ax_as = plot_arrays_ROC_psd_single3( [arrays[0], arrays[2], arrays[1], arrays[3]], [arrays_original[0], arrays_original[2], arrays_original[1], arrays_original[3]], spikes_pure, cell, grid0, mean_type, group_mean, xlim=xlim, row=1 + d * 3, array_chosen=array_chosen, color0_burst=color0_burst, color01=color01, color02=color02, color012=color012) suptitle = cell + ' c1: ' + str(group_mean[0][0]) + '$\%$ m1: ' + str( group_mean[0][2][0]) + ' DF1: ' + str( group_mean[1]['DF1, DF2'].iloc[0][0]) + ' c2: ' + str( group_mean[0][1]) + '$\%$ m2: 
' + str( group_mean[0][2][1]) + ' DF2: ' + str( group_mean[1]['DF1, DF2'].iloc[0][1]) + ' Trials nr ' + str( len(group_mean[1])) + fr_end individual_tag = 'DF1' + str(DF1_desired[gg]) + 'DF2' + str( DF2_desired[gg]) + cell + '_c1_' + str(c1) + '_c2_' + str(c2) + mean_type axes = [] axes.append(ax_w) axes.extend(np.transpose(ax_as)) axes.append(np.transpose(ax_ps)) fig.tag(ax_w, xoffs=-1.5, yoffs=1.4) if save: save_visualization(individual_tag=individual_tag, show=show, pdf=True) return suptitle def csvReader(filename): context = open(filename).read(2048) dialect = csv.Sniffer().sniff(context) return csv.reader(open(filename), dialect) def plot_arrays_ROC_psd_single(arrays, arrays_original, spikes_pure, cell, grid0, mean_type, group_mean, rocextra=False, xlim=[0, 100], row=4, way='absolut', color0='green', color0_burst='darkgreen', color01='blue', ylim_log=(-13.5, 3), add_burst_corr=False, color02='red', array_chosen=1, color012='orange'): arrs = [] for a, arr in enumerate(arrays): time_array = np.arange(0, len(arrays[a][0]) / 40, 1 / 40) if len(xlim) > 0: arrs.append(np.array(arr[0])[(time_array > xlim[0]) & (time_array < xlim[-1])]) else: arrs.append(np.array(arr[0])) ylim = [-2, np.max(arrs) + 30] ps = {} p_means = {} p_means_all = {} ax_ps = [] key_names = ['base_0', 'control_02', 'control_01', '012'] names = ['0', '02', '01', '012'] # color012color0, color02, color01, colors = ['grey', 'grey', 'grey', 'grey', color0_burst, color0_burst, color0, color0] colors_p = [color0, color02, color01, color012, color02, color01, color0_burst, color0_burst, color0, color0] xlim_psd = [0, 1000] ylim_psd = [] # [-40, 10] color_psd = 'black' ax_as = [] for j in range(len(arrays)): ax0 = plt.subplot(grid0[row, j]) ax_a = [] ax_a.append(ax0) for i in range(len(spikes_pure[key_names[j]])): ax0.eventplot(spikes_pure[key_names[j]], color=colors[j]) ax0.show_spines('') ax0.set_xticks([]) ax0.set_yticks([]) if len(xlim) > 0: ax0.set_xlim(xlim) ax00 = plt.subplot(grid0[row + 1, j]) 
ax_a.append(ax00) # hier wird nur der erste Array geplottet time_array = np.arange(0, len(arrays[j][0]) / 40, 1 / 40) # embed() if rocextra: if '_AllTrialsIndex' in mean_type: pass else: pass else: pass try: if '_AllTrialsIndex' in mean_type: ax00.plot(time_array, arrays[j][array_chosen], color=colors[j]) else: ax00.plot(time_array, arrays[j][0], color=colors[j]) except: print('array thing') embed() if 'mult' in way: # 'mult_minimum','mult_env', 'mult_f1', 'mult_f2' pass if len(xlim) > 0: ax00.set_xlim(xlim) ax00.set_ylim(ylim) ax00.show_spines('') ax00.set_xticks([]) ax00.set_yticks([]) if j == 0: length = 20 plus_here = 5 try: ax00.xscalebar(0.1, -0.02, length, 'ms', va='right', ha='bottom') ##ylim[0] ax00.yscalebar(-0.02, 0.35, 500, 'Hz', va='center', ha='left') except: ax00.plot([0, length], [ylim[0] + plus_here, ylim[0] + plus_here], color='black') ax00.text(0, -0.2, str(length) + ' ms', va='center', fontsize=10, transform=ax00.transAxes) if len(xlim) > 0: ax00.plot([xlim[0] + 0.01, xlim[0] + 0.01], [ylim[0], 500], color='black') else: ax00.plot([time_array[0] + 0.01, time_array[0] + 0.01], [ylim[0], 500], color='black') ax00.text(-0.1, 0.4, ' 500 Hz', rotation=90, va='center', fontsize=10, transform=ax00.transAxes) # plot the corresponding psds # hier kann man aussuchen welches power spektrum machen haben will nfft = 2 ** 13 # 2 ** 18 # 17#16 p_mean_all_here = [] if '_AllTrialsIndex' in mean_type: range_here = [array_chosen] else: range_here = range(len(arrays[j])) for i in range_here: p_type = '05' if 'original' in p_type: p_mean_all, f = ml.psd(arrays_original[j][i] - np.mean(arrays_original[j][i]), Fs=40000, NFFT=nfft, noverlap=nfft // 2) # else: p_mean_all, f = ml.psd(arrays[j][i] - np.mean(arrays[j][i]), Fs=40000, NFFT=nfft, noverlap=nfft // 2) # p_mean_all_here.append(p_mean_all) p_means_all[names[j]] = p_mean_all_here ax_as.append(ax_a) # das machen wir nochmal für einen gemeinsamen Ref Wert for j in range(len(arrays)): log = 'log' # '' # 
'log'#''#'log'#''# ref, ax00 = plt_single_pds(nfft, f, p_means, p_means_all[names[j]], ylim_psd, xlim_psd, color_psd, names, ps, arrays, ax_ps, grid0, row + 1, j, p_means_all, psd_type='mean_freq', log=log) if j == 0: if log == 'log': ax00.set_ylabel('dB') else: ax00.set_ylabel('Hz/Hz$^2$') ax00.set_xlim(xlim_psd) DF1 = group_mean[1].DF1.iloc[-1] DF2 = group_mean[1].DF2.iloc[-1] fr_isis = [] if add_burst_corr: frs_burst_corr = [] for i in range(len(spikes_pure['base_0'])): fr_isis.append(1 / np.mean(np.diff(spikes_pure['base_0'][i] / 1000))) # np.mean(fr), fr_calc, lim_here = find_lim_here(cell, 'individual') print(lim_here) eod_fr = group_mean[1].EODf.iloc[i] spikes_all = spikes_pure['base_0'][i] isi = calc_isi(spikes_all, eod_fr) if np.min(isi) < lim_here: hists2, spikes_ex, fr_burst_corr = correct_burstiness(isi, spikes_all, [eod_fr] * len(spikes_all), [eod_fr] * len(spikes_all), lim=lim_here, burst_corr='individual') frs_burst_corr.append(fr_burst_corr) else: frs_burst_corr.append(fr_isis[-1]) else: for i in range(len(spikes_pure['base_0'])): fr_isis.append(1 / np.mean(np.diff(spikes_pure['base_0'][i] / 1000))) # np.mean(fr), fr_calc, fr_isi = np.nanmean(fr_isis) freqs = [fr_isi, np.abs(DF2), np.abs(DF1), np.abs(DF1) + np.abs(DF2), 2 * np.abs(DF2), 2 * np.abs(DF1), ] try: labels = ['Baseline=' + str(int(np.round(fr_isi))) + 'Hz', 'DF1=' + str(DF2) + 'Hz', 'DF2=' + str(DF1) + 'Hz', '$|$DF1+DF2$|$=' + str(np.abs(DF1) + np.abs(DF2)) + 'Hz', 'DF1$_{H}$=' + str(DF2 * 2) + 'Hz', 'DF2$_{H}$=' + str(DF1 * 2) + 'Hz', 'fr_burst_corr_individual', 'fr_given_burst_corr_individual', 'highest_fr_burst_corr_individual', 'fr', 'fr_given', 'highest_fr'] # '$|$DF1-DF2$|$=' + str(np.abs(np.abs(DF1) - np.abs(DF2))) + 'Hz', except: print('label thing') embed() if add_burst_corr: frs_burst_corr_mean = np.nanmean(frs_burst_corr) freqs.extend([ frs_burst_corr_mean]) # np.abs(np.abs(DF1) - 
np.abs(DF2)),,np.array(np.nanmax(frame_spikes['highest_fr'])),np.array(np.nanmax(frame_spikes['highest_fr_burst_corr_individual'])) labels.extend(['Baseline_Burstcorr']) colors_p.extend(['pink']) choice = [[0], [1, 4], [2], [0, 1, 2, 3, 6]] else: choice = [[0], [1, 4], [2], [0, 1, 2, 3]] if log == 'log': pp = 10 * np.log10(p_means_all[names[j]] / ref) pp_mean = 10 * np.log10(np.mean(p_means_all[names[j]], axis=0) / ref) else: pp = p_means_all[names[j]] pp_mean = np.mean(p_means_all[names[j]], axis=0) try: # todo: if log müsste hier was anderes rein, das log veränderte nämlich! plt_peaks_several(np.array(freqs)[choice[j]], pp, ax00, pp_mean, f, np.array(labels)[choice[j]], j, np.array(colors_p)[choice[j]], add_log=2.5, exact=False, text_extra=True, perc_peaksize=0.08, ms=14, clip_on=True, log=log) # True except: print('peaks thing0') embed() if log == 'log': ax00.set_ylim(ylim_log) ax00.show_spines('b') if j == 0: ax00.yscalebar(-0.02, 0.5, 10, 'dB', va='center', ha='left') ax00.get_shared_y_axes().join(*ax_ps) return fr_isi, ax_ps, ax_as def plot_arrays_ROC_psd_single4(arrays, arrays_original, spikes_pure, cell, grid0, mean_type, group_mean, names=['0', '02', '01', '012'], xlim=[0, 100], row=4, way='absolut', datapoints=1000, xlim_psd=[0, 235], color0='blue', color0_burst='darkgreen', color01='green', ylim_log=(-15, 3), add_burst_corr=False, color02='red', array_chosen=1, text_extra=True, color012_minus='purple', color012='orange', log='log'): arrs = [] for a, arr in enumerate(arrays): time_array = np.arange(0, len(arrays[a][0]) / 40, 1 / 40) if len(xlim) > 0: arrs.append(np.array(arr[0])[(time_array > xlim[0]) & (time_array < xlim[-1])]) else: arrs.append(np.array(arr[0])) ylim = [-2, np.max(arrs) + 30] ax_ps = [] key_names = ['base_0', 'control_02', 'control_01', '012'] colors = ['grey', 'grey', 'grey', 'grey', color0_burst, color0_burst, color0, color0] colors_p = [color0, color02, color01, color012, color02, color01, color012_minus, color0_burst, color0_burst, 
color0, color0] ylim_psd = [] # [-40, 10] color_psd = 'black' ax_as = [] for j in range(len(arrays)): ################################### # plt spikes try: ax0 = plt.subplot(grid0[row, j]) plt_spikes_ROC(ax0, colors[j], spikes_pure[key_names[j]], xlim) ax_a = [] ax_a.append(ax0) except: print('ax something') embed() ######################################### ax00 = plt.subplot(grid0[row + 1, j]) ax_a.append(ax00) time_array = plt_traces_ROC(array_chosen, arrays, ax00, colors, group_mean, j, mean_type, way, xlim, ylim) var_val = np.var(arrays[3]) - np.var(arrays[2]) - np.var(arrays[1]) + np.var(arrays[0]) print('mean var val:' + str(var_val)) p_means_all = {} for j in range(len(arrays)): ######################################## # get the corresponding psds # hier kann man aussuchen welches power spektrum machen haben will f, nfft = get_psds_ROC(array_chosen, arrays, arrays_original, j, mean_type, names, p_means_all) ax_as.append(ax_a) # plot the psds ps = {} p_means = {} ax00, fr_isi = plt_psds_ROC(arrays, ax00, ax_ps, cell, colors_p, f, grid0, group_mean, nfft, p_means, p_means_all, ps, row, spikes_pure, time_array, names=names, color_psd=color_psd, add_burst_corr=add_burst_corr, xlim_psd=xlim_psd, ylim_log=ylim_log, ylim_psd=ylim_psd, log=log, text_extra=text_extra) ax00.get_shared_y_axes().join(*ax_ps) return fr_isi, ax_ps, ax_as def plot_arrays_ROC_psd_single3(arrays, arrays_original, spikes_pure, cell, grid0, mean_type, group_mean, names=['0', '02', '01', '012'], xlim=[0, 100], row=4, way='absolut', datapoints=1000, xlim_psd=[0, 235], color0='blue', color0_burst='darkgreen', color01='green', ylim_log=(-15, 3), add_burst_corr=False, color02='red', array_chosen=1, text_extra=True, color012_minus='purple', color012='orange', log='log'): arrs = [] for a, arr in enumerate(arrays): time_array = np.arange(0, len(arrays[a][0]) / 40, 1 / 40) if len(xlim) > 0: arrs.append(np.array(arr[0])[(time_array > xlim[0]) & (time_array < xlim[-1])]) else: 
arrs.append(np.array(arr[0])) ylim = [-2, np.max(arrs) + 30] ax_ps = [] key_names = ['base_0', 'control_02', 'control_01', '012'] colors = ['grey', 'grey', 'grey', 'grey', color0_burst, color0_burst, color0, color0] colors_p = [color0, color02, color01, color012, color02, color01, color012_minus, color0_burst, color0_burst, color0, color0] ylim_psd = [] # [-40, 10] color_psd = 'black' ax_as = [] for j in range(len(arrays)): ################################### # plt spikes ax0 = plt.subplot(grid0[row, j]) plt_spikes_ROC(ax0, colors[j], spikes_pure[key_names[j]], xlim) ax_a = [] ax_a.append(ax0) ######################################### ax00 = plt.subplot(grid0[row + 1, j]) ax_a.append(ax00) time_array = plt_traces_ROC(array_chosen, arrays, ax00, colors, group_mean, j, mean_type, way, xlim, ylim) var_val = np.var(arrays[3]) - np.var(arrays[2]) - np.var(arrays[1]) + np.var(arrays[0]) print('mean var val:' + str(var_val)) p_means_all = {} for j in range(len(arrays)): ######################################## # get the corresponding psds # hier kann man aussuchen welches power spektrum machen haben will f, nfft = get_psds_ROC(array_chosen, arrays, arrays_original, j, mean_type, names, p_means_all) ax_as.append(ax_a) # plot the psds ps = {} p_means = {} ax00, fr_isi = plt_psds_ROC(arrays, ax00, ax_ps, cell, colors_p, f, grid0, group_mean, nfft, p_means, p_means_all, ps, row, spikes_pure, time_array, names=names, color_psd=color_psd, add_burst_corr=add_burst_corr, xlim_psd=xlim_psd, ylim_log=ylim_log, ylim_psd=ylim_psd, log=log, text_extra=text_extra) ax00.get_shared_y_axes().join(*ax_ps) return fr_isi, ax_ps, ax_as def motivation_all_small_stim(dev_desired = '1',ylim=[-1.25, 1.25], c1=10, dfs=['m1', 'm2'], mult_type='_multsorted2_', top=0.94, devs=['2'], figsize=None, redo=False, save=True, end='0', cut_matrix='malefemale', chose_score='mean_nrs', a_fr=1, restrict='modulation', adapt='adaptoffsetallall2', step=str(30), detections=['AllTrialsIndex'], variant='no', 
sorted_on='LocalReconst0.2Norm'): autodefines = [ 'triangle_diagonal_fr'] # ['triangle_fr', 'triangle_diagonal_fr', 'triangle_df2_fr','triangle_df2_eodf''triangle_df1_eodf', ] # ,'triangle_df2_fr''triangle_df1_fr','_triangle_diagonal__fr',] cells = ['2021-08-03-ac-invivo-1'] ##'2021-08-03-ad-invivo-1',,[10, ][5 ] c1s = [10] # 1, 10, c2s = [10] plot_style() default_figsize(column=2, length=3.3) #6.7 ts=12, ls=12, fs=12 show = True DF2_desired = [0.8] DF1_desired = [0.87] DF2_desired = [-0.23] DF1_desired = [0.94] # mean_type = '_MeanTrialsIndexPhaseSort_Min0.25sExcluded_' extract = '' datasets, data_dir = find_all_dir_cells() cells = ['2022-01-28-ah-invivo-1'] # , '2022-01-28-af-invivo-1', '2022-01-28-ab-invivo-1', # '2022-01-27-ab-invivo-1', ] # ,'2022-01-28-ah-invivo-1', '2022-01-28-af-invivo-1', ] append_others = 'apend_others' # '#'apend_others'#'apend_others'#'apend_others'##'apend_others' autodefine = '_DFdesired_' autodefine = 'triangle_diagonal_fr' # ['triangle_fr', 'triangle_diagonal_fr', 'triangle_df2_fr','triangle_df2_eodf''triangle_df1_eodf', ] # ,'triangle_df2_fr''triangle_df1_fr','_triangle_diagonal__fr',] DF2_desired = [-33] DF1_desired = [133] autodefine = '_dfchosen_closest_' autodefine = '_dfchosen_closest_first_' cells = ['2021-08-03-ac-invivo-1'] ##'2021-08-03-ad-invivo-1',,[10, ][5 ] # c1s = [10] # 1, 10, # c2s = [10] minsetting = 'Min0.25sExcluded' c2 = 10 # detections = ['MeanTrialsIndexPhaseSort'] # ['AllTrialsIndex'] # ,'MeanTrialsIndexPhaseSort''DetectionAnalysis''_MeanTrialsPhaseSort' # detections = ['AllTrialsIndex'] # ['_MeanTrialsIndexPhaseSort_Min0.25sExcluded_extended_eod_loc_synch'] extend_trials = '' # 'extended'#''#'extended'#''#'extended'#''#'extended'#''#'extended'#''#'extended'# ok kein Plan was das hier ist # phase_sorting = ''#'PhaseSort' eodftype = '_psdEOD_' concat = '' # 'TrialsConcat' indices = ['_allindices_'] chirps = [ ''] # 
'_ChirpsDelete3_',,'_ChirpsDelete3_'','','',''#'_ChirpsDelete3_'#''#'_ChirpsDelete3_'#'#'_ChirpsDelete2_'#''#'_ChirpsDelete_'#''#'_ChirpsDelete_'#''#'_ChirpsDelete_'#''#'_ChirpsCache_' extract = '' # '_globalmax_' devs_savename = ['original', '05'] # ['05']##################### # control = pd.read_pickle( # load_folder_name( # 'calc_model') + '/modell_all_cell_no_sinz3_afe0.1__afr1__afj0.1__length1.5_adaptoffsetallall2___stepefish' + step + 'Hz_ratecorrrisidual35__modelbigfit_nfft4096.pkl') if len(cells) < 1: data_dir, cells = load_cells_three(end, data_dir=data_dir, datasets=datasets) cells, p_units_cells, pyramidals = restrict_cell_type(cells, 'p-units') # default_settings(fs=8) start = 'min' # cells = ['2021-08-03-ac-invivo-1'] tag_cells = [] for c, cell in enumerate(cells): counter_pic = 0 contrasts = [c2] tag_cell = [] for c, contrast in enumerate(contrasts): contrast_small = 'c2' contrast_big = 'c1' contrasts1 = [c1] for contrast1 in contrasts1: for devname_orig in devs: datapoints = [1000] for d in datapoints: ################################ # prepare DF1 desired # chose_score = 'auci02_012-auci_base_01' # hier muss das halt stimmen mit der auswahl # hier wollen wir eigntlich kein autodefine # sondern wir wollen so ein diagonal ding haben extra_f_calculatoin = False if extra_f_calculatoin: divergnce, fr, pivot_chosen, max_val, max_x, max_y, mult, DF1_desired, DF2_desired, min_y, min_x, min_val, diff_cut = chose_mat_max_value( DF1_desired, DF2_desired, '', mult_type, eodftype, indices, cell, contrast_small, contrast_big, contrast1, dfs, start, devname_orig, contrast, autodefine=autodefine, cut_matrix='cut', chose_score=chose_score) # chose_score = 'auci02_012-auci_base_01' DF1_desired = [1.2]#DF1_desired # [::-1] DF2_desired = [0.95]#DF2_desired # [::-1] #embed() ####################################### # ROC part # fr, celltype = get_fr_from_info(cell, data_dir[c]) version_comp, subfolder, mod_name_slash, mod_name, subfolder_path = find_code_vs_not() b = 
load_b_public(c, cell, data_dir) mt_sorted = predefine_grouping_frame(b, eodftype=eodftype, cell_name=cell) counter_waves = 0 mt_sorted = mt_sorted[(mt_sorted['c2'] == c2) & (mt_sorted['c1'] == c1)] for gg in range(len(DF1_desired)): # try: t3 = time.time() # except: # print('time thing') # embed() ax_w = [] ################### # all trials in one grouped = mt_sorted.groupby( ['c1', 'c2', 'm1, m2'], as_index=False) # try: grouped_mean = chose_certain_group(DF1_desired[gg], DF2_desired[gg], grouped, several=True, emb=False, concat=True) # except: # print('grouped thing') # embed() ################### # groups sorted by repro tag # todo: evnetuell die tuples gleich hier umspeichern vom csv '' grouped = mt_sorted.groupby( ['c1', 'c2', 'm1, m2', 'repro_tag_id'], as_index=False) grouped_orig = chose_certain_group(DF1_desired[gg], DF2_desired[gg], grouped, several=True) gr_trials = len(grouped_orig) ################### groups_variants = [grouped_mean] group_mean = [grouped_orig[0][0], grouped_mean] for d, detection in enumerate(detections): mean_type = '_' + detection # + '_' + minsetting + '_' + extend_trials + concat ############################################################## # load plotting arrays arrays, arrays_original, spikes_pure = save_arrays_susept( data_dir, cell, c, chirps, devs, extract, group_mean, mean_type, plot_group=0, rocextra=False, sorted_on=sorted_on, dev_desired = dev_desired) #################################################### #################################################### # hier checken wir ob für diesen einen Punkt das funkioniert mit der standardabweichung try: check_var_substract_method(spikes_pure) except: print('var checking not possible') # fig = plt.figure() # grid = gridspec.GridSpec(2, 1, wspace=0.7, left=0.05, top=0.95, bottom=0.15, # right=0.98) if figsize: fig = plt.figure(figsize=figsize) else: fig = plt.figure() grid = gridspec.GridSpec(1, 1, wspace=0.7, hspace=0.35, left=0.055, top=top, bottom=0.15, right=0.935) # 
height_ratios=[1, 2], height_ratios = [1,6]bottom=0.25, top=0.8, hr = [1, 0.35, 1.2, 0, 3, ] # 1 ########################################################################## # several coherence plot # frame_psd = pd.read_pickle(load_folder_name('calc_RAM')+'/noise_data11_nfft1sec_original__StimPreSaved4__first__CutatBeginning_0.05_s_NeurDelay_0.005_s_2021-08-03-ab-invivo-1.pkl') # frame_psd = pd.read_pickle(load_folder_name('calc_RAM') + '/noise_data11_nfft1sec_original__StimPreSaved4__first__CutatBeginning_0.05_s_NeurDelay_0.005_s_2021-08-03-ab-invivo-1.pkl') coh = False if coh: ax_w, d, data_dir, devs = plt_coherences(ax_w, d, devs, grid) # ax_cohs = plt.subplot(grid[0,1]) ########################################################################## # part with the power spectra grid0 = gridspec.GridSpecFromSubplotSpec(5, 4, wspace=0.15, hspace=0.35, subplot_spec=grid[:, :], height_ratios=hr) xlim = [0, 100] ########################################### stimulus_length = 0.3 deltat = 1 / 40000 eodf = np.mean(group_mean[1].eodf) eod_fr = eodf a_fr = 1 eod_fe = eodf + np.mean( group_mean[1].DF2) # data.eodf.iloc[0] + 10 # cell_model.eode.iloc[0] a_fe = group_mean[0][1] / 100 eod_fj = eodf + np.mean( group_mean[1].DF1) # data.eodf.iloc[0] + 50 # cell_model.eodj.iloc[0] a_fj = group_mean[0][0] / 100 variant_cell = 'no' # 'receiver_emitter_jammer' print('f0' + str(eod_fr)) print('f1'+str(eod_fe)) print('f2' + str(eod_fj)) eod_fish_j, time_array, time_fish_r, eod_fish_r, time_fish_e, eod_fish_e, time_fish_sam, eod_fish_sam, stimulus_am, stimulus_sam = extract_waves( variant_cell, '', stimulus_length, deltat, eod_fr, a_fr, a_fe, [eod_fe], 0, eod_fj, a_fj) jammer_name = 'female' cocktail_names = False if cocktail_names: titles = ['receiver ', '+' + 'intruder ', '+' + jammer_name, '+' + jammer_name + '+intruder', []] ##'receiver + ' + 'receiver + receiver else: titles = title_motivation() gs = [0, 1, 2, 3, 4] waves_presents = [['receiver', '', '', 'all'], ['receiver', 'emitter', 
'', 'all'], ['receiver', '', 'jammer', 'all'], ['receiver', 'emitter', 'jammer', 'all'], ] # ['', '', '', ''],['receiver', '', '', 'all'], # ['receiver', '', 'jammer', 'all'], # ['receiver', 'emitter', '', 'all'],'receiver', 'emitter', 'jammer', symbols = [''] # '$+$', '$-$', '$-$', '$=$', symbols = ['', '', '', '', ''] time_array = time_array * 1000 color01, color012, color01_2, color02, color0_burst, color0 = colors_suscept_paper_dots() colors_am = ['black', 'black', 'black', 'black'] # color01, color02, color012] extracted = [False, True, True, True] extracted2 = [False, False, False, False] for i in range(len(waves_presents)): ax = plot_shemes4(eod_fish_r, eod_fish_e, eod_fish_j, grid0, time_array, g=gs[i], title_top=True, eod_fr=eod_fr, waves_present=waves_presents[i], ylim=ylim, xlim=xlim, color_am=colors_am[i], color_am2 = color01_2, extracted=extracted[i], extracted2=extracted2[i], title=titles[i]) # 'intruder','receiver'#jammer_name ax_w.append(ax) if ax != []: ax.text(1.1, 0.45, symbols[i], fontsize=35, transform=ax.transAxes) bar = False if bar: if i == 0: ax.plot([0, 20], [ylim[0] + 0.01, ylim[0] + 0.01], color='black') ax.text(0, -0.16, '20 ms', va='center', fontsize=10, transform=ax.transAxes) printing = True if printing: print(time.time() - t3) ########################################## # spike response array_chosen = 1 if d == 0: # #embed() frs = [] for i in range(len(spikes_pure['base_0'])): #duration = spikes_pure['base_0'][i][-1] / 1000 duration = 0.5 fr = len(spikes_pure['base_0'][i])/duration frs.append(fr) fr = np.mean(frs) #embed() base_several = False if base_several: spikes_new = [] for i in range(len(spikes_pure['base_0'])): duration = 100 duration_full = 101#501 dur = np.arange(0, duration_full, duration) for d_nr in range(len(dur) - 1): #embed() spikes_new.append(np.array(spikes_pure['base_0'][i][ (spikes_pure['base_0'][i] > dur[d_nr]) & ( spikes_pure['base_0'][i] < dur[ d_nr + 1])])/1000-dur[d_nr]/1000) # spikes_pure['base_0'] = 
spikes_new sampling_rate = 1/np.diff(time_array) sampling_rate = int(sampling_rate[0]*1000) spikes_mats = [] smoothed05 = [] for i in range(len(spikes_new)): spikes_mat = cr_spikes_mat(spikes_new[i], sampling_rate, int(sampling_rate*duration/1000)) spikes_mats.append(spikes_mat) smoothed05.append(gaussian_filter(spikes_mat, sigma=(float(dev_desired)/1000) * sampling_rate)) smoothed_base = np.mean(smoothed05, axis=0) mat_base = np.mean(spikes_mats, axis=0) else: smoothed_base = arrays[0][0] mat_base = arrays_original[0][0] #embed()#arrays[0]v fr_isi, ax_ps, ax_as = plot_arrays_ROC_psd_single3( [[smoothed_base], arrays[2], arrays[1], arrays[3]], [[mat_base], arrays_original[2], arrays_original[1], arrays_original[3]], spikes_pure, cell, grid0, mean_type, group_mean, xlim=xlim, row=1 + d * 3, array_chosen=array_chosen, color0_burst=color0_burst, color01=color01, color02=color02,ylim_log=(-15, 3), color012=color012,color012_minus = color01_2,color0=color0) ########################################################################## individual_tag = 'DF1' + str(DF1_desired[gg]) + 'DF2' + str( DF2_desired[gg]) + cell + '_c1_' + str(c1) + '_c2_' + str(c2) + mean_type # save_all(individual_tag, show, counter_contrast=0, savename='') # print('individual_tag') axes = [] axes.append(ax_w) # axes.extend(np.transpose(ax_as)) # axes.append(np.transpose(ax_ps)) # np.transpose(axes) #fig.tag(ax_w[0:3], xoffs=-2.3, yoffs=1.7) #fig.tag(ax_w[3::], xoffs=-1.9, yoffs=1.4) fig.tag(ax_w, xoffs=-1.9, yoffs=1.4) # ax_w, np.transpose(ax_as), ax_ps if save: save_visualization(individual_tag=individual_tag, show=show, pdf=True) # fig = plt.gcf() # fig.savefig # plt.show() def plt_spikes_ROC(ax0, colors, spikes_pure, xlim, lw=None): if lw: ax0.eventplot(spikes_pure, color=colors, linewidth=lw) else: ax0.eventplot(spikes_pure, color=colors) ax0.show_spines('') ax0.set_xticks([]) ax0.set_yticks([]) if len(xlim) > 0: ax0.set_xlim(xlim) def plt_psds_ROC(arrays, ax00, ax_ps, cell, colors_p, f, grid0, 
group_mean, nfft, p_means, p_means_all, ps, row, spikes_pure, time_array, names=['0', '02', '01', '012'], color_psd='black', add_burst_corr=False, xlim_psd=[0, 235], clip_on=True, ms=14, labels=[], ax00s=[], choice=[], marker='o', text_extra=True, alphas=[], range_plot=[], ylim_log=(-15, 3), ylim_psd=[], log='log', ax01=None): psd_type = 'mean_freq' if not range_plot: range_plot = range(len(arrays)) for j in range_plot: try: ref, ax00 = plt_single_pds(nfft, f, p_means, p_means_all[names[j]], ylim_psd, xlim_psd, color_psd, names, ps, arrays, ax_ps, grid0, row + 1, j, p_means_all, ax00s=ax00s, ax00=ax01, psd_type='mean_freq', log=log) except: print('ref not working') embed() if j == 0: if log == 'log': ax00.set_ylabel('dB') else: ax00.set_ylabel('Hz/Hz$^2$') ax00.set_xlim(xlim_psd) DF1 = group_mean[1].DF1.iloc[-1] DF2 = group_mean[1].DF2.iloc[-1] fr_isis = [] if add_burst_corr: frs_burst_corr = get_burst_corr_peak(cell, fr_isis, group_mean, spikes_pure) else: for i in range(len(spikes_pure['base_0'])): fr_isis.append(1 / np.mean(np.diff(spikes_pure['base_0'][i] / 1000))) # np.mean(fr), fr_calc, fr_isi = np.nanmean(fr_isis) freqs = [fr_isi, np.abs(DF2), np.abs(DF1), np.abs(DF1) + np.abs(DF2), 2 * np.abs(DF2), 2 * np.abs(DF1), np.abs(np.abs(DF1) - np.abs(DF2)), ] try: if not labels: if j == 3: labels_inside = ['', '', '', fsum_core(DF1, DF2), '', '', fdiff_core(DF1, DF2), 'fr_bc', 'fr_given_burst_corr_individual', 'highest_fr_burst_corr_individual', 'fr', 'fr_given', 'highest_fr'] # '$|$DF1-DF2$|$=' + str(np.abs(np.abs(DF1) - np.abs(DF2))) + 'Hz', elif j == 2: labels_inside = ['', df1_core(DF2), df2_core(DF1), fsum_core(DF1, DF2), f1_core(DF2), f2_core(DF1), fdiff_core(DF1, DF2), 'fr_bc', 'fr_given_burst_corr_individual', 'highest_fr_burst_corr_individual', 'fr', 'fr_given', 'highest_fr'] # '$|$DF1-DF2$|$=' + str(np.abs(np.abs(DF1) - np.abs(DF2))) + 'Hz', else: labels_inside = labels_all_motivation(DF1, DF2, fr_isi) else: labels_inside = labels except: print('label 
thing2') embed() #embed() if add_burst_corr: if not choice: choice = update_burst_corr_peaks(colors_p, freqs, frs_burst_corr, labels_inside) rots = [[45], [45, 45], [45], [45, 45, 45, 45, 45, 45]] extra = [] left = 40 else: if not choice: choice = [[0], [1, 4], [0, 2], [0, 1, 2, 3, 4, 6]] rots = [[0], [0, 0], [0, 0], [55, 55, 57, 45, 45, 45]] # 45 lefts = [[10], [25, 3], [0, 105], [10, 10, 10, 13, 12, 40]] # 40 extras = [[1], [1, 1], [1, 1], [1, 1, 2.5, 1.7, 4, 4]] # 4,1 extra = extras[j] try: left = np.array(lefts)[j] except: print('left something') embed() pp, pp_mean = decide_log_ROCs(j, log, names, p_means_all, ref) try: # todo: if log müsste hier was anderes rein, das log veränderte nämlich!#2.5 plt_peaks_several(np.array(freqs)[choice[j]], pp, ax00, pp_mean, f, np.array(labels_inside)[choice[j]], j, np.array(colors_p)[choice[j]], marker=marker, add_texts=extra, texts_left=left, add_log=1.5, rots=np.array(rots)[j], exact=False, text_extra=text_extra, perc_peaksize=0.08, alphas=alphas, ms=ms, clip_on=clip_on, log=log) # True except: # freqs, p_arrays, axs_p, p0_means, fs, labels=None, j=1, colors=None, print('peaks thing2') embed() if log == 'log': ax00.set_ylim(ylim_log) ax00.show_spines('b') if log == 'log': if j == 0: ax00.yscalebar(-0.02, 0.5, 10, 'dB', va='center', ha='left') return ax00, fr_isi def labels_all_motivation(DF1, DF2, fr_isi): labels = [r'$f'+basename_small()+'=%s$' % (int(np.round(fr_isi))) + '\,Hz', df1_core(DF2), df2_core(DF1), fsum_core(DF1, DF2), f1_core(DF2), f2_core(DF1), fdiff_core(DF1, DF2), 'fr_bc', 'fr_given_burst_corr_individual', 'highest_fr_burst_corr_individual', 'fr', 'fr_given', 'highest_fr'] # '$|$DF1-DF2$|$=' + str(np.abs(np.abs(DF1) - np.abs(DF2))) + 'Hz', return labels def df2_core(DF1): return '$|\Delta f_{2}|=|f_{2}-$' + f_eod_name_rm() + '$|=%s$' % (np.abs(DF1)) + '\,Hz' def df1_core(DF2): return '$|\Delta f_{1}|=|f_{1}-$' + f_eod_name_rm() + '$|=%s$' % (np.abs(DF2)) + '\,Hz' def f2_core(DF1): return '$2 |\Delta 
f_{2}|=%s$' % (DF1 * 2) + '\,Hz' def f1_core(DF2): return '$2 |\Delta f_{1}|=%s$' % (np.abs(DF2) * 2) + '\,Hz' def fdiff_core(DF1, DF2): return '$||\Delta f_{1}|-|\Delta f_{2}||=%s$' % (np.abs(np.abs(DF1) - np.abs(DF2))) + '\,Hz' def fsum_core(DF1, DF2): return '$||\Delta f_{1}| + |\Delta f_{2}||=%s$' % (np.abs(DF1) + np.abs(DF2)) + '\,Hz' # ) def decide_log_ROCs(j, log, names, p_means_all, ref): if log == 'log': pp = 10 * np.log10(p_means_all[names[j]] / ref) pp_mean = 10 * np.log10(np.mean(p_means_all[names[j]], axis=0) / ref) else: pp = p_means_all[names[j]] pp_mean = np.mean(p_means_all[names[j]], axis=0) return pp, pp_mean def update_burst_corr_peaks(colors_p, freqs, frs_burst_corr, labels): frs_burst_corr_mean = np.nanmean(frs_burst_corr) freqs.extend([ frs_burst_corr_mean]) # np.abs(np.abs(DF1) - np.abs(DF2)),,np.array(np.nanmax(frame_spikes['highest_fr'])),np.array(np.nanmax(frame_spikes['highest_fr_burst_corr_individual'])) labels.extend(['Baseline_Burstcorr']) colors_p.extend(['pink']) choice = [[0, 7], [1, 4], [2], [0, 1, 2, 3, 7]] return choice def get_burst_corr_peak(cell, fr_isis, group_mean, spikes_pure): frs_burst_corr = [] for i in range(len(spikes_pure['base_0'])): fr_isis.append(1 / np.mean(np.diff(spikes_pure['base_0'][i] / 1000))) # np.mean(fr), fr_calc, lim_here = find_lim_here(cell, 'individual') eod_fr = group_mean[1].EODf.iloc[i] spikes_all = spikes_pure['base_0'][i] isi = calc_isi(np.array(spikes_all) / 1000, eod_fr) if np.min(isi) < lim_here: hists2, spikes_ex, fr_burst_corr = correct_burstiness([isi], [spikes_all], [eod_fr], [eod_fr], lim=lim_here, burst_corr='individual') frs_burst_corr.append(fr_burst_corr) else: frs_burst_corr.append(fr_isis[-1]) return frs_burst_corr def get_psds_ROC(array_chosen, arrays, arrays_original, j, mean_type, names, p_means_all, nfft=2 ** 13): p_mean_all_here = [] if 'AllTrialsIndex' in mean_type: # AllTrialsIndex range_here = [array_chosen] print('alltrials choice') else: range_here = range(len(arrays[j])) 
for i in range_here: p_type = '05' if 'original' in p_type: p_mean_all, f = ml.psd(arrays_original[j][i] - np.mean(arrays_original[j][i]), Fs=40000, NFFT=nfft, noverlap=nfft // 2) # else: p_mean_all, f = ml.psd(arrays[j][i] - np.mean(arrays[j][i]), Fs=40000, NFFT=nfft, noverlap=nfft // 2) # p_mean_all_here.append(p_mean_all) try: p_means_all[names[j]] = p_mean_all_here except: print('assign p problem') embed() return f, nfft def plt_traces_ROC(array_chosen, arrays, ax00, colors, group_mean, j, mean_type, way, xlim, ylim): # hier wird nur der erste Array geplottet time_array = np.arange(0, len(arrays[j][0]) / 40, 1 / 40) try: if '_AllTrialsIndex' in mean_type: ax00.plot(time_array, arrays[j][array_chosen], color=colors[j]) else: ax00.plot(time_array, arrays[j][0], color=colors[j]) except: print('array thing') embed() if 'mult' in way: # 'mult_minimum','mult_env', 'mult_f1', 'mult_f2' pass if len(xlim) > 0: ax00.set_xlim(xlim) ax00.set_ylim(ylim) ax00.show_spines('') ax00.set_xticks([]) ax00.set_yticks([]) if j == 0: length = 20 plus_here = 5 try: ax00.xscalebar(0.1, -0.02, length, 'ms', va='right', ha='bottom') ##ylim[0] ax00.yscalebar(-0.02, 0.35, 500, 'Hz', va='center', ha='left') except: ax00.plot([0, length], [ylim[0] + plus_here, ylim[0] + plus_here], color='black') ax00.text(0, -0.2, str(length) + ' ms', va='center', fontsize=10, transform=ax00.transAxes) if len(xlim) > 0: ax00.plot([xlim[0] + 0.01, xlim[0] + 0.01], [ylim[0], 500], color='black') else: ax00.plot([time_array[0] + 0.01, time_array[0] + 0.01], [ylim[0], 500], color='black') ax00.text(-0.1, 0.4, ' 500 Hz', rotation=90, va='center', fontsize=10, transform=ax00.transAxes) return time_array def save_arrays_susept(data_dir, cell, c, chirps, devs, extract, group_mean, mean_type, plot_group, rocextra, sorted_on='LocalReconst0.2Norm', dev_desired='1', mean_type0=''): # '_MeanTrialsIndex' version_comp, subfolder, mod_name_slash, mod_name, subfolder_path = find_code_vs_not() load_name = 
find_load_function() if (version_comp == 'develop') | (version_comp == 'code'): full_path = find_nix_full_path(c, cell, data_dir) if os.path.exists(full_path): # todo: this maybe also has to be fixed print('do ' + cell) file = nix.File.open(full_path, nix.FileMode.ReadOnly) b = file.blocks[0] all_mt_names, mt_names, t_names = get_all_nix_names(b, what='Three') if mt_names: nix_there = check_nix_fish(b) if nix_there: ## printing = True t3 = time.time() spikes_pure, fish_number_base, chirp, fish_cuts, time_array, fish_number, smoothened2, smoothed05, eod_mt, eod_interp, effective_duration, cut, devname, frame = cut_spikes_and_eod_three( group_mean, b, extract, chirps=chirps, emb=False, mean_type=mean_type, sorted_on=sorted_on, devname_orig=[dev_desired, '05', 'original']) # times_sort=times_sort, if printing: # todo: also das dauert lange das könnte man optimizieren print('arrays1 ' + str(time.time() - t3)) try: pass except: print('dev nr thing') embed() delays_length = define_delays_trials(mean_type, frame, sorted_on=sorted_on) printing = True t3 = time.time() if ('Phase' not in mean_type0) & (mean_type0 != ''): for i in range(len(delays_length['base_0'])): delays_length['base_0'][i] = np.arange(0, delays_length['base_0'][i][-1], 1) try: array0, array01, array02, array012, mean_nrs, array012_all, array01_all, array02_all, array0_all = assign_trials( frame, dev_desired, delays_length, mean_type) except: print('sorting thing') embed() array0_original, array01_original, array02_original, array012_original, mean_nrs, array012_all, array01_all, array02_all, array0_all = assign_trials( frame, 'original', delays_length, mean_type) if printing: print('arrays2 ' + str(time.time() - t3)) if 'TrialsConcat' in mean_type: array0 = array0[0] array01 = array01[0] array02 = array02[0] array012 = array012[0] test = False if test: from utils_test import test_arrays test_arrays() # array0, array01, array02, array012 printing = True t3 = time.time() if 'Mean' not in mean_type: 
delays_length_m = define_delays_trials('_MeanTrialsIndexPhaseSort_Min0.25sExcluded_', frame, sorted_on=sorted_on) if printing: print('arrays3 ' + str(time.time() - t3)) if rocextra: arrays = [[array0, array01], [array02, array012]] arrays = arrays[plot_group] else: arrays = [array0, array01, array02, array012] arrays_original = [array0_original, array01_original, array02_original, array012_original] names = ['base_0', 'control_01', 'control_02', '012'] spikes_all_out = {} for n, name in enumerate(names): spikes_all = [] for s, sp in enumerate(np.array(spikes_pure[name])): spikes = spikes_pure[name].iloc[s] * 1000 if name != 'base_0': if s != 0: if len(delays_length) > 0: cut = delays_length[name][s - 1][0] / 40 spikes = spikes[spikes > cut] - cut else: try: cut = delays_length_m[name][s - 1][0] / 40 except: print('delay thing') embed() spikes = spikes[spikes > cut] - cut spikes_all.append(np.array(spikes)) spikes_all_out[name] = spikes_all if version_comp == 'develop': for n, name in enumerate(names): save_here = save_csv_to_pandas(arrays[n]) save_here.to_csv(load_name + '_05_' + name + '.csv') save_here = save_csv_to_pandas(arrays_original[n]) save_here.to_csv(load_name + '_original_' + name + '.csv') save_here = save_csv_to_pandas(spikes_all_out[name]) save_here.to_csv(load_name + '_spikes_' + name + '.csv') # todo: das noch mit den normalen spike resave funktionen machen test = False if test: from utils_test import test_arrays test_arrays() # todo: hier arrays, arrays_original und spikes_pure speichern elif version_comp == 'public': spikes_all_out = {} arrays_original = [] arrays = [] names = ['base_0', 'control_01', 'control_02', '012'] for n, name in enumerate(names): spikes = pd.read_csv(load_name + '_spikes_' + name + '.csv', index_col=0) array_o = pd.read_csv(load_name + '_original_' + name + '.csv', index_col=0) array_05 = pd.read_csv(load_name + '_05_' + name + '.csv', index_col=0) spikes_all_out[name] = np.array(np.transpose(spikes)) 
arrays_original.append(np.array(np.transpose(array_o))) arrays.append(np.array(np.transpose(array_05))) return arrays, arrays_original, spikes_all_out def find_nix_full_path(c, cell, data_dir): base = cell.split(os.path.sep)[-1] + ".nix" if data_dir == '': path = '../data/ThreeFish/' + cell else: path = load_folder_name('data') + data_dir[c] + cell full_path = path + '/' + base return full_path def plt_single_pds(nfft, f, p_means, p_mean_all_here, ylim_psd, xlim_psd, color_psd, names, ps, arrays, ax_ps, grid0, row, j, p_means_all, psd_type='mean_freq', ax00s=[], log='log', ax00=None): if psd_type == 'single': ref = np.max([p_means_all['012'][0], p_means_all['01'][0], p_means_all['02'][0], p_means_all['0'][0]]) if not ax00: if not ax00s: ax00 = ax00s[j] else: ax00 = plt.subplot(grid0[row + 2, j]) ax_ps.append(ax00) nfft = 2 ** 16 p, f = ml.psd(arrays[j][0] - np.mean(arrays[j][0]), Fs=40000, NFFT=nfft, noverlap=nfft // 2) # ps[names[j]] = p p = log_calc_psd(log, p, ref) ax00.plot(f, p, color=color_psd) ax00.set_xlim(xlim_psd) if len(ylim_psd) > 0: ax00.set_ylim(ylim_psd) remove_xticks(ax00) if j != 0: remove_yticks(ax00) elif psd_type == 'mean_temporal': if not ax00: ax00 = plt.subplot(grid0[row + 2, j]) ax_ps.append(ax00) # hier mache ich noch den temporal mean p_mean, f_mean = ml.psd(np.mean(arrays[j], axis=0) - np.mean(np.mean(arrays[j], axis=0)), Fs=40000, NFFT=nfft, noverlap=nfft // 2) # p = log_calc_psd(log, p_mean, ref) ax00.plot(f, p, color=color_psd) ax00.set_xlim(xlim_psd) if len(ylim_psd) > 0: ax00.set_ylim(ylim_psd) p_means[names[j]] = p_mean remove_xticks(ax00) if j != 0: remove_yticks(ax00) # hier mache ich einen mean über die differenz elif psd_type == 'all': if not ax00: ax00 = plt.subplot(grid0[row + 2, j]) ax_ps.append(ax00) for p in p_mean_all_here: p = log_calc_psd(log, p, ref) ax00.plot(f, p, color='grey') ax00.set_xlim(xlim_psd) if len(ylim_psd) > 0: ax00.set_ylim(ylim_psd) remove_xticks(ax00) if j != 0: remove_yticks(ax00) elif psd_type == 
'mean_freq': array_here = [] for name in names: array_here.append(np.mean(p_means_all[name], axis=0)) ref = np.max(array_here) if not ax00: ax00 = plt.subplot(grid0[row + 2, j]) ax_ps.append(ax00) if log == 'log': ax00.plot(f, 10 * np.log10(np.mean(p_mean_all_here, axis=0) / ref), color=color_psd) if j == 0: ax00.set_ylabel('dB') else: ax00.plot(f, np.mean(p_mean_all_here, axis=0), color=color_psd) ax00.set_xlim(xlim_psd) if len(ylim_psd) > 0: ax00.set_ylim(ylim_psd) try: ax00.set_xlabel('Frequency [Hz]') except: print('freq') embed() if j != 0: remove_yticks(ax00) return ref, ax00 def log_calc_psd(log, p, ref): if log == 'log': p = 10 * np.log10(p / ref) return p def chose_certain_group(DF1_desired, DF2_desired, grouped, concat=False, several=False, emb=False): if DF1_desired == 'all' or DF2_desired == 'all': return if several: if (type(DF1_desired) == float) & (type(DF2_desired) == float): restrict = group_the_certain_group_several(grouped, DF2_desired, DF1_desired, emb=False) if concat: key_list = list(map(tuple, grouped.groups.keys())) keys_r = np.array(key_list, dtype=object)[np.array(restrict)] groups = [] for r in keys_r: if len(groups) < 1: groups = grouped.get_group(tuple(r)) else: groups = pd.concat([groups, grouped.get_group(tuple(r))]) print(len(grouped.get_group(tuple(r)))) final_grouped = groups else: try: final_grouped = np.array(list(grouped), dtype=object)[restrict] except: keys_r = np.array(list(map(tuple, grouped.groups.keys())))[restrict] groups = [] for r in keys_r: if len(groups) < 1: groups = [tuple(r), grouped.get_group(tuple(r))] else: groups = pd.concat([groups, grouped.get_group(tuple(r))]) print(len(grouped.get_group(tuple(r)))) final_grouped = groups else: restricts = [] final_grouped = [] for d in range(len(DF1_desired)): restrict = group_the_certain_group_several(grouped, DF2_desired[d], DF1_desired[d], emb=False) print(restrict) if concat: keys_r = np.array(list(map(tuple, grouped.groups.keys())))[restrict] groups = [] for r in 
keys_r: if len(groups) < 1: groups = grouped.get_group(tuple(r)) else: groups = pd.concat([groups, grouped.get_group(tuple(r))]) final_grouped.append(groups) else: final_grouped.append(np.array(list(grouped))[restrict]) restricts.append(restrict) else: if (type(DF1_desired) == float) & (type(DF2_desired) == float): restrict = group_the_certain_group(grouped, DF2_desired, DF1_desired) grouped = list(grouped) # [::-1] final_grouped = grouped[restrict] else: restricts = [] final_grouped = [] for d in range(len(DF1_desired)): restrict = group_the_certain_group(grouped, DF2_desired[d], DF1_desired[d]) print(restrict) final_grouped.append(list(grouped)[restrict]) restricts.append(restrict) if emb: embed() return final_grouped def phase_sort_arrays(f, delays_length, frame_dev): if f != 0: if delays_length['012'][f - 1] != []: frame_dev.iloc[f]['012'] = frame_dev['012'].iloc[f][ np.arange(delays_length['012'][f - 1][0], len(frame_dev['012'].iloc[f]), 1)] # np.array(frame_dev['012'].iloc[f])[np.array(list(map(int, delays_length['012'][f - 1])))] if delays_length['control_01'][f - 1] != []: frame_dev.iloc[f]['control_01'] = frame_dev['control_01'].iloc[f][ np.arange(delays_length['control_01'][f - 1][0], len(frame_dev['control_01'].iloc[f]), 1)] # frame_dev['control_01'].iloc[f][delays_length['control_01'][f - 1]] if delays_length['control_02'][f - 1] != []: frame_dev.iloc[f]['control_02'] = frame_dev['control_02'].iloc[f][ np.arange(delays_length['control_02'][f - 1][0], len(frame_dev['control_02'].iloc[f]), 1)] if 'base_0' in delays_length.keys(): if delays_length['base_0'][f - 1] != []: frame_dev.iloc[f]['base_0'] = frame_dev['base_0'].iloc[f][ np.arange(delays_length['base_0'][f - 1][0], len(frame_dev['base_0'].iloc[f]), 1)] return frame_dev def cut_uneven_trials(frame, devname, mean_type, delays_length, sampling=40000): frame_dev = frame[frame['dev'] == devname] # wenn das alles phase sorted sein soll, werden die davor ausgerichtet # für das Threewave nicht notwenig 
length = [] for f in range(len(frame_dev['012'])): if 'Phase' in mean_type: frame_dev = phase_sort_arrays(f, delays_length, frame_dev) length.append([len(frame_dev['012'].iloc[f]), len(frame_dev['control_01'].iloc[f]), len(frame_dev['control_02'].iloc[f]), len(frame_dev['base_0'].iloc[f])]) ####################### # hier werden alle trials auf die gleiche Länge geschnitten array0_all, array01_all, array02_all, array012_all = cut_even_arrays(length, sampling, mean_type, frame_dev) return array012_all, array01_all, array02_all, array0_all def cut_even_arrays(length, sampling, mean_type, frame_dev): # hier sagen wir mindestens z.B. 0.25 S! if 'Min' in mean_type: # 0.25sExcluded # DEFAULT ms_exclude = float(mean_type.split('Min')[1].split('sExcluded')[0]) exclude_array = np.array(length) > ms_exclude * sampling length_min = np.min(np.array(length)[np.array(length) > 0.25 * sampling]) else: exclude_array = np.ones_like(length) length_min = np.min(length) array012_all = [] # [[]] * len(frame_dev['012']) array01_all = [] # [[]] * len(frame_dev['012']) array02_all = [] # [[]] * len(frame_dev['012']) array0_all = [] # [[]] * len(frame_dev['012']) for f in range(len(frame_dev['012'])): if exclude_array[f][0]: array012_all.append(frame_dev['012'].iloc[f][0:length_min]) if exclude_array[f][1]: array01_all.append(frame_dev['control_01'].iloc[f][0:length_min]) if exclude_array[f][2]: array02_all.append(frame_dev['control_02'].iloc[f][0:length_min]) if exclude_array[f][3]: array0_all.append(frame_dev['base_0'].iloc[f][0:length_min]) return array0_all, array01_all, array02_all, array012_all def plt_phase_sorted_trials(frame, devname, array0_all, array0, array01_all, array01, array02_all, array02, array012_all, array012, ): fig, ax = plt.subplots(4, 2, sharey=True, sharex=True) lmax = np.nanmax([np.nanmax(np.transpose(array02_all)), np.nanmax(np.transpose(array01_all)), np.nanmax(np.transpose(array0_all)), np.nanmax(np.transpose(array012_all))]) lmin = 
np.nanmin([np.nanmin(np.transpose(array02_all)), np.nanmin(np.transpose(array01_all)), np.nanmin(np.transpose(array0_all)), np.nanmin(np.transpose(array012_all))]) names = ['base_0', 'control_01', 'control_02', '012'] xlim = 4000 x = 1000 for nn, n in enumerate(names): frame_dev = frame[frame.dev == devname] for i in range(len(frame_dev[n])): if nn == 0: ax[nn, 1].set_title('not sorted') ax[nn, 1].plot(frame_dev[n].iloc[i]) ax[nn, 0].set_xlim(0, xlim) ax[nn, 0].axvline(x=x) _, _ = ml.psd(frame_dev[n].iloc[i] - np.mean(frame_dev[n].iloc[i]), Fs=40000, NFFT=8000, noverlap=8000 / 2) plt.suptitle(devname) ax[3, 0].set_ylabel('012') ax[3, 0].plot(np.transpose(array012_all)) ax[3, 0].plot(array012[0], color='red') ax[3, 0].set_ylim(lmin, lmax) ax[3, 0].set_xlim(0, xlim) ax[3, 0].axvline(x=x) ax[1, 0].set_ylabel('01') ax[1, 0].plot(np.transpose(array01_all)) ax[1, 0].plot(array01[0], color='red') ax[1, 0].set_ylim(lmin, lmax) ax[1, 0].set_xlim(0, xlim) ax[1, 0].axvline(x=x) ax[2, 0].set_ylabel('02') ax[2, 0].plot(np.transpose(array02_all)) ax[2, 0].plot(array02[0], color='red') ax[2, 0].set_ylim(lmin, lmax) ax[2, 0].set_xlim(0, xlim) ax[2, 0].axvline(x=x) ax[0, 0].set_ylabel('0') ax[0, 0].set_title('sorted') ax[0, 0].plot(np.transpose(array0_all)) ax[0, 0].plot(array0[0], color='red') ax[0, 0].set_ylim(lmin, lmax) ax[0, 0].set_xlim(0, xlim) ax[0, 0].axvline(x=x) plt.subplots_adjust(hspace=0.4, wspace=0.35) save_visualization(show=False) def assign_trials(frame, devname, delays_length, mean_type, sampling=40000): # get all trails together array012_all, array01_all, array02_all, array0_all = cut_uneven_trials(frame, devname, mean_type, delays_length, sampling=sampling) # calculate mean or also not array0, array012, array01, array02, mean_nrs = assign_trials_mean(devname, frame, array012_all, mean_type, array01_all, array02_all, array0_all) return array0, array01, array02, array012, mean_nrs, array012_all, array01_all, array02_all, array0_all def assign_trials_mean(devname, 
frame, array012_all, mean_type, array01_all, array02_all, array0_all): if '_MeanTrials' in mean_type: if 'TrialsConcat' in mean_type: trial_concats = int(mean_type.split('TrialsConcat_')[1][0]) length = len(array012_all) array012 = [] array01 = [] array02 = [] array0 = [] for trial_concat in range(trial_concats): array012.append([np.mean(np.array(array012_all[int(length * (trial_concat / trial_concats)):int( length * ((trial_concat + 1) / trial_concats))]), axis=0)]) array01.append([np.mean(np.array(array01_all[int(length * (trial_concat / trial_concats)):int( length * ((trial_concat + 1) / trial_concats))]), axis=0)]) array02.append([np.mean(np.array(array02_all[int(length * (trial_concat / trial_concats)):int( length * ((trial_concat + 1) / trial_concats))]), axis=0)]) array0.append([np.mean(np.array(array0_all[int(length * (trial_concat / trial_concats)):int( length * ((trial_concat + 1) / trial_concats))]), axis=0)]) else: if 'snippets' in mean_type: snippets = int(mean_type.split('snippets')[0][-1]) array012 = [np.mean(np.array(array012_all)[0:int(len(array012_all) / snippets)], axis=0), np.mean(np.array(array012_all)[int(len(array012_all) / snippets):-1], axis=0)] array01 = [np.mean(np.array(array01_all)[0:int(len(array012_all) / snippets)], axis=0), np.mean(np.array(array01_all)[int(len(array01_all) / snippets):-1], axis=0)] array02 = [np.mean(np.array(array02_all)[0:int(len(array012_all) / snippets)], axis=0), np.mean(np.array(array02_all)[int(len(array02_all) / snippets):-1], axis=0)] array0 = [np.mean(np.array(array0_all)[0:int(len(array012_all) / snippets)], axis=0), np.mean(np.array(array0_all)[int(len(array0_all) / snippets):-1], axis=0)] else: array012 = [np.mean(np.array(array012_all), axis=0)] array01 = [np.mean(np.array(array01_all), axis=0)] array02 = [np.mean(np.array(array02_all), axis=0)] array0 = [np.mean(np.array(array0_all), axis=0)] mean_nrs = len(array012_all) test = False if test: if devname == 'eod': plt_phase_sorted_trials(frame, 
devname, array0_all, array0, array01_all, array01, array02_all, array02, array012_all, array012, ) plt.show() test = False if test == True: plot_traces_frame_three_roc(frame, [], show=True) else: mean_nrs = 1 array012 = array012_all # np.array(frame_dev['012']) array01 = array01_all # np.array(frame_dev['control_01']) array02 = array02_all # np.array(frame_dev['control_02']) array0 = array0_all # np.array(frame_dev['base_0']) test = False if test == True: from utils_test import test_assign_trials test_assign_trials(devname, array01, array02, array0, array012) return array0, array012, array01, array02, mean_nrs def plot_traces_frame_three_roc(frame, id_group, show=True, names=['base_0', 'control_01', 'control_02', '012']): ######################################### # plot thre arrays of the data counter = 0 fig, axis = plt.subplots(len(frame.dev.unique()), len(names), sharex=True) if len(id_group) > 0: plt.suptitle('DF1 ' + str(np.mean(id_group[1]['DF1'])) + ' DF2' + str( np.mean(id_group[1]['DF2']))) for ff, f in enumerate(frame.dev.unique()): for nn, n in enumerate(names): frame_dev = frame[frame.dev == f] for i in range(len(frame_dev[n])): axis[ff, nn].plot(frame_dev[n].iloc[i]) p, freq = ml.psd(frame_dev[n].iloc[i] - np.mean(frame_dev[n].iloc[i]), Fs=40000, NFFT=8000, noverlap=8000 / 2) if len(id_group) > 0: max_f = freq[np.argmax(p[freq < np.mean(id_group[1]['eodf'])])] axis[ff, nn].set_title(f + ' ' + n + ' ' + str(max_f)) counter += 1 plt.subplots_adjust(hspace=0.6, wspace=0.6) save_visualization(show=False) if show: plt.show() def divergence_title_add_on(group_mean, fr, autodefine): if 'triangle' in autodefine: if 'df1' in autodefine: try: divergence = np.abs(np.abs(group_mean[1]['DF1, DF2'].iloc[0][0]) - fr) except: print('df1 divergence problems') embed() elif 'df2' in autodefine: divergence = np.abs(np.abs(group_mean[1]['DF1, DF2'].iloc[0][1]) - fr) elif 'diagonal' in autodefine: try: divergence = np.abs(np.abs( np.abs(group_mean[1]['DF1, DF2'].iloc[0][0]) 
+ np.abs(group_mean[1]['DF1, DF2'].iloc[0][1])) - fr) except: print('diagonal divergence problems') embed() else: divergence = '' fr_end = '\n fr ' + str(fr) + ' Hz ' + ' fr_m ' + str( np.round(np.mean(fr / group_mean[1].EODf + 1), 2)) + ' Hz ' + 'diverge from Fr by ' + str( divergence) + ' Hz ' + autodefine else: fr_end = '' return fr_end def find_env(way, results_diff, position_diff, sampling, f0='f0'): beat1 = np.abs(results_diff.loc[position_diff, 'f1'] - results_diff.loc[position_diff, f0]) beat2 = np.abs(results_diff.loc[position_diff, 'f2'] - results_diff.loc[position_diff, f0]) if 'mult_minimum' in way: env_f = np.min([np.min(beat1), np.min(beat2)]) elif 'mult_env' in way: env_f = np.abs(beat1 - beat2) if env_f == 0: env_f = np.min([np.min(beat1), np.min(beat2)]) elif 'mult_f1' in way: env_f = beat1 elif 'mult_f2' in way: env_f = beat2 else: if 'f1' in results_diff.keys(): env_f = np.abs(results_diff.loc[position_diff, 'f1'] - results_diff.loc[position_diff, 'f2']) 'mult_minimum', 'mult_env', 'mult_f1', 'mult_f2' datapoints = int((1 / env_f) * int(way[-1]) * sampling) return datapoints def check_nix_fish(b, with_fish2='with_fish2'): nix_there = False names_mt = [] names_f = [] for t_nr, mt in enumerate( b.multi_tags): # todo: hier kann man immer noch die Daten anschaeun die ohne Nix sind aber die waren glaube ich nciht so gut names_mt.append(mt.name) if ('Three' in mt.name) and not nix_there: # ok man braucht das hier wenn man nicht erst über alle mts gehen will! 
if with_fish2: for ff, f in enumerate(mt.features): if 'id' not in f.data.name and not nix_there: names_f.append(f.data.name) if 'fish2' in f.data.name: nix_there = True else: nix_there = False else: nix_there = True return nix_there def load_data_arrays(extract, mt_group, sampling_rate, sorted_on, b, mt, mt_nr, delay, printing=False): array_eod = {} ############################################ t1 = time.time() # global + Efield # die brauchen wir weil wir die als stimulus mit abspeichern wollen eod_globalEfield, sampling = link_arrays_eod(b, mt.positions[:][mt_nr] - delay, mt.extents[:][mt_nr] + delay, array_name='GlobalEFieldStimulus') eod_global, sampling = link_arrays_eod(b, mt.positions[:][mt_nr] - delay, mt.extents[:][mt_nr] + delay, array_name='EOD') array_eod['Global'] = eod_global array_eod['EField'] = eod_globalEfield if printing: print('second0 ' + str(time.time() - t1)) ####################################################### # das brauchen wir auf jeden Fall halt womöglich zum plotten # 'LocalReconst', 'Global', 'EField' t1 = time.time() time_eod = np.arange(0, len(eod_global) / 40000, 1 / 40000) - delay spikes_mt = link_arrays_spikes(b, first=mt.positions[:][mt_nr] - delay, second=mt.extents[:][mt_nr] + delay, minus_spikes=mt.positions[:][mt_nr]) array_eod['spikes_mt'] = spikes_mt array_eod['time_eod'] = time_eod if printing: print('second1 ' + str(time.time() - t1)) t1 = time.time() nrs = ['', '0.2', '0.4'] norms = ['', 'Norm'] # hier define ich die ich minimal brauche sorted_ons = ['LocalReconst', sorted_on] for sorted_on_here in sorted_ons: for norm in norms: if norm == 'Norm': norm_ex = True else: norm_ex = False for nr in nrs: # das Ganze jetzt auch nochmal groß machen für das sorting name = 'LocalReconst' + str(nr) + norm name_am = 'LocalReconst' + str(nr) + norm + 'Am' if nr != '': if name in sorted_on_here: t1 = time.time() eod_global_norm = zenter_and_normalize(eod_global, 1) eod_globalEfield_norm02 = zenter_and_normalize(eod_globalEfield, 
float(nr)) eod_local = cut_ends(eod_global_norm, eod_globalEfield_norm02) if printing: print('second20 ' + str(time.time() - t1)) else: eod_local = [] else: eod_local = cut_ends(eod_global, eod_globalEfield) ######################################## # ich glaube wir brauchen das jetzt nicht für alle # weil hier will ich dass nur die ams rauskommen die ich brauche ich muss das ja nicht üfr alle local global etcs. machen # todo: aber das muss man noch shcauen ob das nicht hier crashed if name_am in sorted_on_here: if len(eod_local) > 0: t1 = time.time() eod_final_am, eod_final = extract_am( eod_local, array_eod['time_eod'], sampling=sampling_rate, eodf=mt_group[1].eodf[ mt_nr], emb=False, norm=norm_ex, extract=extract) if printing: print('second21 ' + str(time.time() - t1)) if sorted_on_here == name_am: eod_final = [] else: eod_final_am = [] else: eod_final_am = [] eod_final = [] else: if len(eod_local) > 0: if norm_ex: pass else: eod_final = eod_local eod_final_am = [] else: eod_final_am = [] eod_final = [] t1 = time.time() array_eod = update_array_matrix(array_eod, eod_final, name) array_eod = update_array_matrix(array_eod, eod_final_am, name_am) if printing: print('second22 ' + str(time.time() - t1)) if printing: print('second2 ' + str(time.time() - t1)) t1 = time.time() if 'LocalEOD' in sorted_on: eod_local, sampling = link_arrays_eod(b, mt.positions[:][mt_nr] - delay, mt.extents[:][mt_nr] + delay, array_name='LocalEOD-1') else: eod_local = [] if sorted_on == 'LocalEOD': array_eod['LocalEOD'] = eod_local else: array_eod['LocalEOD'] = [] for norm in norms: if norm == 'Norm': norm_ex = True else: norm_ex = False if 'LocalEOD' + norm in sorted_on: eod_final_am, eod_final = extract_am( eod_local, array_eod['time_eod'], sampling=sampling_rate, eodf=mt_group[1].eodf[ mt_nr], emb=False, norm=norm_ex, extract=extract) if 'LocalEOD' + norm + 'Am' in sorted_on: array_eod['LocalEOD' + norm + 'Am'] = eod_final_am else: array_eod['LocalEOD' + norm + 'Am'] = [] if 'LocalEOD' + 
norm in sorted_on: array_eod['LocalEOD' + norm] = eod_final else: array_eod['LocalEOD' + norm] = [] if printing: print('second3 ' + str(time.time() - t1)) return array_eod def update_array_matrix(array_eod, eod_local_reconstruct_big_norm, name): if name not in array_eod.keys(): array_eod[name] = eod_local_reconstruct_big_norm else: if len(array_eod[name]) == 0: array_eod[name] = eod_local_reconstruct_big_norm return array_eod def chirps_delete_analysis(eod_local_interp, eod_norm, fish_cuts, time_eod, cut, fish_number, fish_number_base): eods, _ = cut_eod_sequences(eod_norm, fish_cuts, time_eod, cut=cut, rec=False, fish_number=fish_number, fillup=False, fish_number_base=fish_number_base) eods_int, _ = cut_eod_sequences(eod_local_interp, fish_cuts, time_eod, cut=cut, rec=False, fish_number=fish_number, fillup=False, fish_number_base=fish_number_base) keys = [k for k in eods] fish_number_final = fish_number * 1 for e in range(len(eods)): if len(eods[keys[e]]) > 0: try: freq, freq1, freq2, freq3, freq4 = calc_power( eods[keys[e]], nfft=2 ** 9, sampling_rate=40000, shift_by=0.001) except: print('freq problem') embed() eods_time = np.arange(0, len(eods[keys[e]]) / 40000, 1 / 40000) time_freq = np.linspace(eods_time[0] + 2 ** 9 / (40000 * 2), eods_time[-1] - 2 ** 9 / (40000 * 2), len(freq)) detection = 'No_Chirp_detected' time_detected = [] chirp_size = 35 random_data_std = np.std(eods_int[keys[e]]) random_data_mean = np.mean(eods_int[keys[e]]) anomaly_cut_off = random_data_std * 3 lower_limit = random_data_mean - anomaly_cut_off upper_limit = random_data_mean + anomaly_cut_off range_exeed = upper_limit - lower_limit if (np.ptp(freq) > chirp_size) & ( np.ptp(eods_int[keys[e]]) > range_exeed): # > 0.3 perc95 = np.median(freq) + chirp_size pos = np.diff(np.where(freq > perc95)) if 1 in pos: if 1 in np.diff(np.where(pos == 1)): lim_w = 0.04 lower_window = time_freq[ np.where(np.diff(freq))] - lim_w upper_window = time_freq[ np.where(np.diff(freq))] + lim_w time_detected = [] 
for nr_w in range(len(lower_window)): eods_cut = eods_int[keys[e]][ (eods_time > lower_window[nr_w]) & ( eods_time < upper_window[nr_w])] diverge = np.median(eods_int[keys[e]]) - np.min( eods_cut) if (diverge > 0.25) & ( detection != 'Chirp_detected'): # time_detected.append( lower_window[nr_w] + lim_w) if keys[e] not in chirp.keys(): chirp[keys[e]] = [mt_idx] else: chirp[keys[e]].append(mt_idx) fish_number_final[fish_number_final.index( keys[e])] = 'interspace' if len(np.unique(fish_number_final)) == 1: if np.unique(fish_number_final)[ 0] == 'interspace': print('to many chirps') detection = 'Chirp_detected' elif np.ptp(eods_cut) > range_exeed: # > 0.55 time_detected.append( lower_window[nr_w] + lim_w) test = False if test: from utils_test import check_chirp_del_directly check_chirp_del_directly(time_detected, eods_time, eods, range_exeed, eods_int, keys, e, freq, detection, time_freq, freq1, freq2, freq3, freq4) return fish_number_final def cut_spikes_and_eod_three(mt_group, b, extract, cut_nr=0, chirps='', devname_orig=['05'], emb=False, test=False, mean_type='', sorted_on='LocalReconst', sampling_rate=40000, devname=[], done=False, counter=0, printing=False, printing_all=False): # todo: das könnte man noch vereinfachen dass nur die wichtigen Sachen rauskommen t1 = time.time() mt_list = mt_group[1]['mt'] frame = [] chirp = {} spikes_pure = [] print('cut_spikes_and_eod_three is running') for mt_idx, mt_nr in enumerate(list(map(int, mt_list))): # range(start_l, len(mt.positions[:])) features, mt, name_here, l = get_mt_features3(b, mt_group, mt_idx) # somehow we have mts with negative extend, we exclude these t0 = time.time() if (mt.extents[:][mt_nr] > 0).any(): t1 = time.time() _, _, _, _, fish_number, fish_cuts, whole_duration, cont = load_durations(mt_nr, mt, mt_group[1], mt_idx, mean_type=mean_type, emb=False) delay = np.abs(fish_cuts[0]) if printing: print('first ' + str(time.time() - t1)) if cont: # embed() ######################################## # load the 
according EOD arrays # basics und reconstructs t1 = time.time() array_eod = load_data_arrays(extract, mt_group, sampling_rate, sorted_on, b, mt, mt_nr, delay) if printing: # todo:das dauert ewig! print('second ' + str(time.time() - t1)) if (len(array_eod['LocalReconst']) > 0) & (len(array_eod['spikes_mt']) > 0) & ( not ((len(array_eod['LocalReconst']) < 1) or (len(array_eod['spikes_mt']) < 1))): ######################################## # extract the am of the loaded arrays # das phase sorting sollte nicht anhand dieser AMs passieren, sondern anhand des gesamten Stimulus # if 'PhaseSort' in mean_type: eod_local_am = [] cut_edge = [cut_nr] # 0.02 for cut in cut_edge: if (len(array_eod['time_eod']) > 0) & (len(array_eod['spikes_mt']) > 0) & ( array_eod['time_eod'][-1] + delay > whole_duration * 0.9) & ( array_eod['spikes_mt'][-1] + delay > whole_duration * 0.6) & any_spikes( array_eod['spikes_mt'], minimal=fish_cuts[0] + cut, maximal=fish_cuts[-1] - cut): t1 = time.time() fish_number_base = remove_interspace_fish_nr(fish_number) if 'ChirpsDelete' in chirps: # die Snippets ausschließen wo der Fisch gechirpt hat fish_number_final = chirps_delete_analysis(eod_local_recondstruct_am, eods_local_reconstruct_norm, fish_cuts, array_eod['time_eod'], cut, fish_number, fish_number_base) else: fish_number_final = fish_number devname, smoothened2, smoothed05, mat, time_array, arrays_calc, effective_duration, spikes_cut = cut_spikes_sequences( delay, array_eod['spikes_mt'], sampling_rate, fish_cuts, cut=cut, fish_number_base=fish_number_base, fish_number=fish_number_final, devname_orig=devname_orig * 1, mean_type=mean_type) lengths = [] if printing: print('Forth ' + str(time.time() - t1)) for name in fish_number_final[::-1]: if 'interspace' not in name: lengths.append(len( arrays_calc[0][name])) # lengths.append(len(np.unique(arrays_calc[0][name]))) if np.min(lengths) > 2: # das sind die verschiedenen EOD versions die man zum sortieren brauchen könnte # if 'PhaseSort' in mean_type: 
# eod_arrays = [eod_global, eod_local_am, eods_local_norm, eod_local_recondstruct_am, eod_local_reconstruct, eods_local_reconstruct_norm, eod_local_reconstruct_big_am, eod_local_reconstruct_big_norm] # names = ['global','local_am','local_norm','local_reconst_am','local_reconst_norm','local_reconst','local_reconst_big_am','local_reconst_big_norm'] # eod_arrays = [eod_local_reconstruct_norm_huge,eod_global,eod_local_reconstruct, eod_local_reconstruct_big_norm] # names = ['local','global','local_reconst','local_reconst_big_norm'] # todo da das nehmen was wir für das sort on so brauchen # die braucen wir später fürs plotten einmal den stimulus mit machen ist immer gut # das ist einmal reconstruiert und einmal das auf die richtige contrast größe gebracht # todo das jetzt nochmal richtig machen und das so machen das man nur das macht was man braucht # das sind die basics, das sind die die wir später plotten ############################################################### t1 = time.time() names = ['LocalReconst', 'Global', 'EField'] eod_arrays = [array_eod['LocalReconst'], array_eod['Global'], array_eod['EField']] ############################################################### # und das ist fürs sorting, da nehmen wir jetzt auch noch das was wir eignetlich wollen # todo: das muss man noch systematisch machen und bequemer implementieren # eod_arrays_possible = [array_eod['Local'], # array_eod['LocalReconst0.4Norm'], # eod_local_recondstruct_am, eods_local_reconstruct_norm, # eod_local_reconstruct_big_norm, # eod_local_reconstruct_big_am] # names_possible = ['Local', # 'LocalReconst0.4Norm', # 'LocalReconstAm', 'LocalReconst', # 'LocalReconst0.2Norm', 'LocalReconst0.2Am'] # where_pos = np.where(np.array(names_possible) == sorted_on)[0] if sorted_on in array_eod.keys(): # len(where_pos) > 0: eod_arrays.append(array_eod[sorted_on]) # eod_arrays_possible[where_pos[0]]) names.append(sorted_on) # names_possible[where_pos[0]] for e, eod_array in enumerate(eod_arrays): try: 
eods_cut, _ = cut_eod_sequences(eod_array, fish_cuts, cut=cut, rec=False, fish_number=fish_number_final, fillup=True, fish_number_base=fish_number_base) except: print('eod problem0') embed() arrays_calc.append(eods_cut) time_array.append(array_eod['time_eod']) devname.append(names[e]) if names[e] == 'Global': idx = len(arrays_calc) - 1 + e if printing: print('Fifth ' + str(time.time() - t1)) t1 = time.time() names_synch = ['EodLocSynch', 'EodAmSynch'] if 'Synch' in sorted_on: ##################### # synthetisiere den stimulus aus dem global und dem idealen stimulus # das ist gar nicht so schlecht # das ist das gleiche wie das globale und das Efield eods_loc_synch, eods_am_synch = synthetise_eod(mt_nr, extract, sampling_rate, sampling_rate, mt_idx, idx, arrays_calc, mt_group) eod_arrays = [eods_loc_synch, eods_am_synch] # names = ['eod_loc_synch', 'eod_am_synch'] for e, eod_array in enumerate(eod_arrays): arrays_calc.append(eods_cut) time_array.append(array_eod['time_eod']) devname.append(names_synch[e]) if test: # eod_local_am, eods_local_norm,'local_am','local_norm', from utils_test import test_EOD_arrays arrays_calc, devname, eods_cut, idx, time_array = test_EOD_arrays(cut, e, eods_cut, eods_loc_synch, fish_cuts, fish_number_base, fish_number_final, idx) else: array_eod[names_synch[0]] = [] array_eod[names_synch[1]] = [] if printing: print('six ' + str(time.time() - t1)) ################################## # das in dataframe speichern t1 = time.time() frame, spikes_cut, spikes_pure, done = transform_dataframe(frame, spikes_pure, done, arrays_calc, devname, spikes_cut) if printing: print('seventh ' + str(time.time() - t1)) counter += 1 test = False if test: from utils_test import compare_chirp_nfft, compare_chirp_nfft_traces compare_chirp_nfft() # time_eod, eod_local compare_chirp_nfft_traces() # chirp, eod_local, time_eod, time_array[0],smoothed05, fish_cuts else: print('negative mt') if done == False: devname = [] frame = [] if printing_all: print('all ' + 
str(mt_idx) + ' ' + str(time.time() - t0)) if emb: embed() if test == True: from utils_test import test_eod_arrays2 test_eod_arrays2(frame) if printing: print('a_all ' + str(time.time() - t1)) if test == True: names = [] for f in fish_number_base: if not 'interspace' in f: names.append(f) overview_of_mt_group(frame, names=names) if len(frame) < 1: print('devname to short!') return [[]] * 22 else: return spikes_pure, fish_number_base, chirp, fish_cuts, time_array, fish_number_final, smoothened2, smoothed05, \ array_eod['LocalEOD'], eod_local_am, effective_duration, cut, devname, frame def get_mt_features3(b, mt_group, mt_idx=-1): name_here = mt_group[1]['mt_names'].iloc[mt_idx] mt = b.multi_tags[name_here] features, delay_name = feature_extract(mt) l = mt_group[1]['mt'].iloc[-1] return features, mt, name_here, l def transform_dataframe(frame, spikes_pure, done, arrays_calc, devname, spikes_cut): if done == False: frame = pd.DataFrame(arrays_calc) frame['dev'] = devname spikes_pure = pd.DataFrame(spikes_cut) done = True else: frame_new = pd.DataFrame(arrays_calc) frame_new['dev'] = devname frame = pd.concat([frame, frame_new]) spikes_cut = pd.DataFrame(spikes_cut) spikes_pure = pd.concat([spikes_pure, spikes_cut]) return frame, spikes_cut, spikes_pure, done def overview_of_mt_group(frame, names=['012', 'control_01', 'control_02', 'base_0']): trial_nr = len(frame) / len(frame.dev.unique()) for i in range(int(trial_nr)): fig, ax = plt.subplots(len(frame.dev.unique()), len(names), sharex=True) for nn, name in enumerate(names): for dd, dev in enumerate(frame.dev.unique()): dev10 = frame[frame.dev == dev] ax[dd, nn].plot(dev10[name].iloc[i]) ax[dd, nn].set_title(name) save_visualization() plt.show() def calc_power(arr, nfft=2 ** 17, sampling_rate=40000, time_show=False, shift_by=0.01): t1 = time.time() shifts = np.arange(0, len(arr) - nfft, shift_by * sampling_rate) np.arange(0, len(arr) - nfft, shift_by * sampling_rate) freq = np.zeros(len(shifts)) freq1 = 
np.zeros(len(shifts)) freq2 = np.zeros(len(shifts)) freq3 = np.zeros(len(shifts)) freq4 = np.zeros(len(shifts)) pps = [[]] * len(shifts) for s, start in enumerate(shifts): pps[s], freq[s], freq1[s], freq2[s], freq3[s], freq4[s] = get_mult_freqs(arr[int(start):int(start + nfft)], sampling_rate, nfft) if time_show: print('calc power' + str(time.time() - t1)) return freq, freq1, freq2, freq3, freq4 def get_mult_freqs(arr, sampling_rate, nfft, ): p, f = ml.psd( arr - np.mean(arr), Fs=sampling_rate, NFFT=nfft, noverlap=nfft // 2) pps = p freq = f[np.argmax(p)] freq1, freq2, freq3, freq4 = find_mult_freqs(p, freq, f) return pps, freq, freq1, freq2, freq3, freq4 def find_mult_freqs(p, freq, f): first_harm = (f > freq * 1.8) & (f < freq * 2.2) freq1 = f[first_harm][np.argmax(p[first_harm])] / 2 second_harm = (f > freq * 2.8) & (f < freq * 3.2) freq2 = f[second_harm][np.argmax(p[second_harm])] / 3 third_harm = (f > freq * 3.8) & (f < freq * 4.2) freq3 = f[third_harm][np.argmax(p[third_harm])] / 4 forth_harm = (f > freq * 4.8) & (f < freq * 5.2) freq4 = f[forth_harm][np.argmax(p[forth_harm])] / 5 return freq1, freq2, freq3, freq4 def find_corr_time(corr1): corr_time_negative = corr1 * 1 corr_time = np.arange(0, len(corr_time_negative), 1) corr_time_negative[corr_time > len(corr1) / 2] = 0 corr_time_neg_zentered = (corr_time - len(corr_time_negative) / 2) / 40000 return corr_time_neg_zentered, corr_time_negative, corr_time def plt_frame_traces_original(frame_dev_eod, n): fig, ax = plt.subplots(len(frame_dev_eod[n]), 1, sharex=True) plt.title('Original Traces') for i in range(len(frame_dev_eod[n])): ax[i].plot(frame_dev_eod[n].iloc[i]) save_visualization() plt.show() def plt_delays_pair(i, inputs, outputs, titles, n, autocorr1, frame_dev_eod, shifted_eod, i_nr, delay, corr_time_neg_zent, corr1): grid1 = gridspec.GridSpec(4, 1, hspace=0.8, wspace=1.2) # test = False if test == True: from utils_test import plot_crosscorrelations plot_crosscorrelations() plt.subplot(grid1[0]) 
plt.title(titles + ' ' + n) corr_time_neg_zent_a, corr_time_negative, corr_time = find_corr_time( autocorr1) plt.plot(corr_time_neg_zent_a, autocorr1, label='autocorr', color='blue') plt.plot(corr_time_neg_zent, corr1, label='corr initial', color='red') plt.axvline(x=0, color='grey', label='mitte') plt.xlim(-0.1, 0.1) plt.subplot(grid1[1]) plt.plot(corr_time_neg_zent_a, autocorr1, label='autocorr', color='blue') corr_test1 = scipy.signal.correlate(inputs[0:-delay], outputs[delay::]) corr_time_neg_zent_1, corr_time_negative, corr_time = find_corr_time( corr_test1) plt.plot(corr_time_neg_zent_1, corr_test1, label='corr after', color='red') plt.axvline(x=0, color='grey', label='mitte') plt.legend(loc=(0, 4), ncol=2) plt.xlim(-0.1, 0.1) plt.subplot(grid1[2]) plt.title('eod before') plt.plot( np.arange(0, len(frame_dev_eod[n].iloc[i_nr]) / 40000, 1 / 40000), frame_dev_eod[n].iloc[i_nr], label='i') plt.plot(np.arange(0, len(frame_dev_eod[n].iloc[i + 1]) / 40000, 1 / 40000), frame_dev_eod[n].iloc[i + 1], label='i+1', color='red') plt.xlim(0, 0.3) plt.subplot(grid1[3]) plt.title('eod after') plt.plot( np.arange(0, len(frame_dev_eod[n].iloc[i_nr]) / 40000, 1 / 40000), frame_dev_eod[n].iloc[i_nr], label='i') plt.plot(np.arange(0, len(shifted_eod) / 40000, 1 / 40000), shifted_eod, label='after2', color='red') plt.legend() save_visualization() plt.show() def plt_shifted_input(inputs, delay, outputs): plt.subplot(3, 1, 1) plt.plot(inputs[0:-delay]) plt.subplot(3, 1, 2) plt.plot(outputs[delay::]) plt.subplot(3, 1, 3) plt.plot(inputs[0:-delay]) plt.plot(outputs[delay::]) save_visualization() plt.show() def find_delays_length(frame_dev_eod, n, name_orig, i, mean_type, delays_length, i_nr=0): inputs = frame_dev_eod[name_orig].iloc[i_nr] - np.nanmean( frame_dev_eod[name_orig].iloc[i_nr]) # , input_eod, if inputs != []: outputs = frame_dev_eod[name_orig].iloc[i + 1] - np.nanmean( frame_dev_eod[name_orig].iloc[i + 1]) # , output_eod, titles = 'eod' # 'eod smoothed', '05', '2'] if 
outputs != []: try: autocorr1 = scipy.signal.correlate(inputs, inputs) corr1 = scipy.signal.correlate(inputs, outputs) except: print('corr1 in utils function') embed() corr_time_neg_zent, corr_time_negative, corr_time = find_corr_time( corr1) if 'Min' in mean_type: minimum = float(mean_type.split('Min')[1].split('sExcluded_')[0]) corr_time_negative[corr_time < len(corr_time_negative) / 2 - minimum * 40000] = 0 delay = np.abs(np.argmax(corr_time_negative) - int(len(corr_time_negative) / 2)) shifted_eod = frame_dev_eod[name_orig].iloc[i + 1][delay::] array_length = np.arange(0, len(outputs), 1) delays_length[n].append(array_length[delay::]) plot = False if plot == True: ################### # plot traces plt_frame_traces_original(frame_dev_eod, name_orig) ############ # plot crosscorrelation plt_delays_pair(i, inputs, outputs, titles, name_orig, autocorr1, frame_dev_eod, shifted_eod, i_nr, delay, corr_time_neg_zent, corr1) ############ # plot shifted input plt_shifted_input(inputs, delay, outputs) else: delays_length[n].append([]) else: array_length = np.arange(0, len(frame_dev_eod[n].iloc[i_nr + 1]), 1) delays_length[n].append(array_length) i_nr += 1 return delays_length, i_nr def create_arrays(df1, i, j, sampling_rate=10000): time = np.arange(0, 30, 1 / sampling_rate) # period[-1] time_fish_r = time * 2 * np.pi * df1[i] eod_fish_r = 1 * np.sin(time_fish_r) time_fish_e = time * 2 * np.pi * df1[j] eod_fish_e = 1 * np.sin(time_fish_e) stimulus = eod_fish_e + eod_fish_r return stimulus, eod_fish_e, eod_fish_r, time def exclude_ratios(f3, df, diff_max, integers, i, j, ratio_f, f_max, f_max2, diff_mean, diff_min, bigger, ratio, df1): self = True integers[i, j] = False if (ratio % 1 == 0) or (ratio_f % 1 == 0): if self == True: f_max[i, j] = 1 / bigger df[i, j] = 1 / bigger f3[i, j] = 1 / bigger f_max2[i, j] = 1 / bigger diff_mean[i, j] = 1 / bigger diff_min[i, j] = 1 / bigger diff_max[i, j] = 1 / bigger else: f_max[i, j] = diff_mean[i, j] integers[i, j] = True if df1[i] 
== df1[j]: if self == True: f_max[i, j] = df1[j] f_max2[i, j] = df1[j] diff_mean[i, j] = df1[j] diff_min[i, j] = df1[j] diff_max[i, j] = df1[j] df[i, j] = df1[j] f3[i, j] = df1[j] else: f_max[i, j] = diff_mean[i, j] integers[i, j] = True return integers, f_max, diff_max, diff_min, diff_mean, f_max2 def do_splits(period_cut, sampling_rate, stimulus, length=0.4): splits = period_cut * sampling_rate if length != 'no': stim0 = stimulus[int(splits[0]):int(splits[0] + length * sampling_rate)] # [int(splits[0]):int(splits[1])] stim1 = stimulus[int(splits[1]):int(splits[1] + length * sampling_rate)] # [int(splits[1]):int(splits[2])] stim2 = stimulus[int(splits[2]):int(splits[2] + length * sampling_rate)] # [int(splits[2]):int(splits[3])] stim3 = stimulus[int(splits[3]):int(splits[3] + length * sampling_rate)] # [int(splits[3]):int(splits[4])] else: stim0 = stimulus[int(splits[0]):int(splits[1])] stim1 = stimulus[int(splits[1]):int(splits[2])] stim2 = stimulus[int(splits[2]):int(splits[3])] stim3 = stimulus[int(splits[3]):int(splits[4])] return stim0, stim1, stim2, stim3, splits def calc_dist(stim0, stim1): stim01 = stim0 * 1 stim02 = stim1 * 1 if len(stim0) > len(stim1): stim01 = stim0[0:len(stim1)] elif len(stim0) < len(stim1): stim02 = stim1[0:len(stim0)] dist = np.mean(np.sqrt((stim01 - stim02) ** 2)) return dist, stim01, stim02 def get_different_periods(df1, df2): f_max = np.zeros([len(df1), len(df2)]) df = np.zeros([len(df1), len(df2)]) f3 = np.zeros([len(df1), len(df2)]) f_max2 = np.zeros([len(df1), len(df2)]) diff_mean = np.zeros([len(df1), len(df2)]) diff_min = np.zeros([len(df1), len(df2)]) diff_max = np.zeros([len(df1), len(df2)]) var = np.zeros([len(df1), len(df2)]) size_diffs = np.zeros([len(df1), len(df2)]) dist_variable = np.zeros([len(df1), len(df2)]) dist_max = np.zeros([len(df1), len(df2)]) dist_max2 = np.zeros([len(df1), len(df2)]) dist_min = np.zeros([len(df1), len(df2)]) dist_mean = np.zeros([len(df1), len(df2)]) dist_fmax = np.zeros([len(df1), 
len(df2)]) dist_f3 = np.zeros([len(df1), len(df2)]) dist_df = np.zeros([len(df1), len(df2)]) ratios = np.zeros([len(df1), len(df2)]) integers = np.zeros([len(df1), len(df2)]) limit = 0.1 # 0.09 # 0.05 plot_type = '' # 'dist'#''#'True'#'dist'#''#'dist'#'period'# for i in range(len(df1)): for j in range(len(df2)): print('i ' + str(df1[i]) + ' j ' + str(df2[j])) DF1_per = 1 / df1[i] DF2_per = 1 / df2[j] if not (np.isinf(DF1_per) | np.isinf(DF2_per)): bigger = np.max([DF2_per, DF1_per]) smaller = np.min([DF2_per, DF1_per]) bigger_f = np.max([df1[j], df2[i]]) smaller_f = np.min([df1[j], df2[i]]) ratio_f = bigger_f / smaller_f ratio = bigger / smaller ratios[i, j] = ratio dim = 4000 period = np.arange(0, dim, 1) * bigger # this is the window we are ready to sacrify t[-1] time_bigger_f = (np.arange(0, dim, 1) * ratio) rests_final = time_bigger_f % 1 period_interp = np.arange(0, period[-1], 1 / 1000) interpolated = interpolate(period, rests_final, period_interp, kind='linear') _, _ = ml.psd(interpolated - np.mean(interpolated), Fs=1 / np.diff(period_interp)[0], NFFT=5000, noverlap=5000 / 2) test = False if test == True: from utils_test import plot_psd plot_psd() ##################################### # find the right euclidean distance sampling_rate = 10000 stimulus, eod_fish_e, eod_fish_r, time = create_arrays(df1, i, j, sampling_rate=sampling_rate) p, f = ml.psd(rests_final - np.mean(rests_final), Fs=1 / np.diff(period)[0], NFFT=5000, noverlap=5000 / 2) f_max[i, j] = f[np.argmax(p)] f3[i, j] = 1 / np.abs((1 / df1[i] + 1 / df1[j])) df[i, j] = np.abs(df1[i] - df1[j]) one_zero = (rests_final < limit) | (rests_final - 1 > -limit) period_cut = period[one_zero] diff_mean[i, j] = 1 / np.mean(np.diff(period_cut)) diff_min[i, j] = 1 / np.min(np.diff(period_cut)) diff_max[i, j] = 1 / np.max(np.diff(period_cut)) p2, f2 = ml.psd(one_zero - np.mean(one_zero), Fs=1 / np.diff(period)[0], NFFT=20000, noverlap=20000 / 2) f_max2[i, j] = f2[np.argmax(p2)] ##################### # period_cut 
paramteres size_diff = np.max(np.diff(period_cut)) - np.min(np.diff(period_cut)) size_diffs[i, j] = size_diff integers, f_max, diff_max, diff_min, diff_mean, f_max2 = exclude_ratios(f3, df, diff_max, integers, i, j, ratio_f, f_max, f_max2, diff_mean, diff_min, bigger, ratio, df1) dist_f3[i, j] = find_dist_pure(1 / f3[i, j], sampling_rate, stimulus) dist_df[i, j] = find_dist_pure(1 / df[i, j], sampling_rate, stimulus) dist_fmax[i, j] = find_dist_pure(1 / f_max[i, j], sampling_rate, stimulus) dist_mean[i, j] = find_dist_pure(1 / diff_mean[i, j], sampling_rate, stimulus) dist_min[i, j] = find_dist_pure(1 / diff_min[i, j], sampling_rate, stimulus) dist_max[i, j] = find_dist_pure(1 / diff_max[i, j], sampling_rate, stimulus) dist_max2[i, j] = find_dist_pure(1 / f_max2[i, j], sampling_rate, stimulus) var[i, j] = np.std(np.diff(period_cut)) dist_variable[i, j] = find_dist_pure(period_cut, sampling_rate, stimulus) print(dist_min[i, j]) if plot_type == 'True': test = True elif plot_type != '': if plot_type == 'dist': if dist_min[i, j] > 0.2: # (: test = True else: test = False elif plot_type == 'period': # if 1 / dist_variable[i, j] > 0.1: # (:1/diff_min[i,j] > 0.2 test = True else: test = False if test: from utils_test import plt_period plt_period() else: print('inf') f_max[i, j] = float('nan') f_max2[i, j] = float('nan') df[i, j] = float('nan') f3[i, j] = float('nan') diff_mean[i, j] = float('nan') diff_min[i, j] = float('nan') diff_max[i, j] = float('nan') size_diffs[i, j] = float('nan') ratios[i, j] = float('nan') integers[i, j] = float('nan') return dist_f3, dist_df, dist_fmax, dist_max2, dist_mean, dist_min, dist_max, dist_variable, var, diff_min, integers, ratios, size_diffs, diff_max, diff_mean, f3, df, f_max2, f_max def find_dist_pure(f3, sampling_rate, stimulus): if type(f3) == np.float64: period_cut = np.arange(0, 20, f3) else: period_cut = f3 stim0, stim1, stim2, stim3, splits = do_splits(period_cut, sampling_rate, stimulus) dist_f3, stim01, stim02 = 
calc_dist(stim0, stim1) return dist_f3 def define_delays_trials(mean_type, frame, sorted_on='local_reconst_big_norm'): if 'PhaseSort' in mean_type: ############################################## # try the cross spektrum test = False frame_dev_eod = frame[frame.dev == sorted_on] names = frame_dev_eod.keys()[0:-1][::-1] delays_length = {} if test: from utils_test import test_delays test_delays(frame) if 'Same' in mean_type: names_orig = ['control_02', 'control_02', 'control_02', 'control_02'] else: names_orig = names for nn, n in enumerate(names): name_orig = names_orig[nn] if 'base' not in n: delays_length[n] = [] i_nr = 0 for i in range(len(frame_dev_eod) - 1): delays_length, i_nr = find_delays_length(frame_dev_eod, n, name_orig, i, mean_type, delays_length, i_nr=i_nr) if test: from utils_test import test_delay2 test_delay2() else: delays_length[n] = [] i_nr = 0 for i in range(len(frame_dev_eod) - 1): delays_length, i_nr = find_delays_length(frame_dev_eod, n, name_orig, i, mean_type, delays_length, i_nr=i_nr) if 'Same' in mean_type: delays_length[n] = [] i_nr = 0 for i in range(len(frame_dev_eod) - 1): delays_length, i_nr = find_delays_length(frame_dev_eod, n, name_orig, i, mean_type, delays_length, i_nr=i_nr) test = False if test == True: from utils_test import test_delays3 test_delays3() else: delays_length = [] return delays_length def group_the_certain_group_several(grouped, DF2_desired, DF1_desired, emb=False): try: mult1 = np.array([a_tuple[2][0] for a_tuple in grouped.groups.keys()]) mult2 = np.array([a_tuple[2][1] for a_tuple in grouped.groups.keys()]) except: print('tuple problem') embed() if str(mult1[0]) == '(': tuples = np.array([a_tuple[2] for a_tuple in grouped.groups.keys()]) try: tuples_convert = np.array([ast.literal_eval(a_tuple) for a_tuple in tuples]) except: print('tuple thing') embed() mult1 = np.array([a_tuple[0] for a_tuple in tuples_convert]) mult2 = np.array([a_tuple[1] for a_tuple in tuples_convert]) try: mult_array = 
np.round(np.abs(mult1 - DF1_desired) + np.abs((mult2 - DF2_desired)), 2) except: print('mult tuple problem') embed() restrict = np.argmin(mult_array) min_val = mult_array[restrict] restrict = mult_array == min_val if emb: embed() return restrict def calc_mult(freq1, eodf, freq2): DeltaF1 = freq1 - eodf DeltaF2 = freq2 - eodf mult1 = DeltaF1 / eodf + 1 mult2 = DeltaF2 / eodf + 1 return mult1, mult2, DeltaF2, DeltaF1 def save_features(features, mt, mt_sorted): for f in range(len(features)): name = features[f][len(mt.name) + 1::] mt_feature = np.concatenate(mt.features[features[f]].data[:]) if len(mt_feature) == len(mt.positions[:]): mt_sorted[name] = mt_feature else: # es gibt diese alle ersten Zellen wo wir nur die Daten hatten und dann später in Nix files konvertiert hatten # die mit sehr niedriegen Kontrasten if (name == 'Frequency') | (name == 'DeltaF'): mt_sorted[name + '1'] = mt_feature[0:len(mt_feature):2] mt_sorted[name + '2'] = mt_feature[1:len(mt_feature):2] else: print('mt problems') embed() return mt_sorted def load_metadata_infos_three(mt_sorted, mt, ver_here='new'): if ('fish1.DeltaF' in mt_sorted) and (ver_here == 'new'): phase = mt.metadata.sections[0]['fish1']['fish2']['Phase'] ####################################### # contrasts contrast1 = mt.metadata.sections[0]['fish1alone']['Contrast'] if 'fish2alone' in mt.metadata.sections[0].sections: contrast2 = mt.metadata.sections[0]['fish2alone']['Contrast'] else: # das ist im Fall wenn der eine Kontrast Null ist, also solche Zellen sollten wir eigentlich nicht haben # aber der Vollständigkeithalber ist das hier jetzt drin! 
contrast2 = mt.metadata.sections[0]['fish1']['fish2']['Contrast'] freq1_orig = mt_sorted['fish1.Frequency'] freq2_orig = mt_sorted['fish2.Frequency'] eodf_orig = mt_sorted['EODf'] # .iloc[0] # das ist für die älteren Zellen, die haben ein bisschen eine andere Namens gebeung elif 'fish1.Frequency' in mt_sorted: phase = mt.metadata.sections[0]['fish2']['Phase'] ver_here = 'code_old' contrast1 = mt.metadata.sections[0]['Contrast'] contrast2 = mt.metadata.sections[0]['fish2']['Contrast'] freq1_orig = mt_sorted['fish1.Frequency'] # todo das eventuell noch ändern freq2_orig = mt_sorted['fish2.Frequency'] # das ist für die sehr alteren Zellen, die haben ein bisschen eine andere Namens gebeung else: # für z.B. Zelle ['2021-06-23-ab-invivo-1'] phase = mt.metadata.sections[0].sections[0]['Phase'] ver_here = 'code_very_old' contrast1 = mt.metadata.sections[0]['Contrast'] contrast2 = mt.metadata.sections[0].sections[0]['Contrast'] freq1_orig = np.array(mt_sorted['Frequency']) # todo das eventuell noch ändern try: mt_sorted['fish1.Frequency'] = freq1_orig except: print('freq problem1') embed() freq2_orig = np.array(mt_sorted['fish2.Frequency']) mt_sorted['fish2.Frequency'] = freq2_orig # elif 'pureEODf' in eodftype: # wir machen im Prinzip immer pure EODf das macht einfach Sinn eodf_orig = mt_sorted['fish2.Frequency'] * float('nan') # .iloc[0] return phase, ver_here, contrast1, contrast2, eodf_orig, freq1_orig, freq2_orig def feautures_in_mtframe(mt): mt_range = np.arange(0, len(mt.positions[:])) mt_sorted = pd.DataFrame(mt_range, columns=['mt']) mt_sorted['mt_names'] = mt.name ############################## # features in nix features, delay_name = feature_extract(mt) mt_sorted = save_features(features, mt, mt_sorted) return mt_sorted def find_eodf(times_final, eodf_orig, eodftype, b, mt, mt_idx=[]): if eodftype == '_psdEOD_': # DEFAULT # ok das andere das kann ich auf einem mehrere MT Level extrahieren, aber diese Analyse muss ich hier einzeln machen # also viele einzelne 
psds. Deswegen speichern wir das alles ab um das nicht jedes Mal neu zu machen, weil ich ja später nochmal über die Zelle # iteriere # und ich mache das nicht erst später, weil ich ja nach den Mehrfachen gruppiere und den Frequenzen die sich ja # ändern können, deswegen ist das schon gut wenn das hier schon passieren kann eodf, eodf_orig, freq_steps = find_eodf_three(b, mt, eodf_orig, mt_idx=mt_idx) if np.isnan(eodf).any(): eodf = np.ones(len(eodf)) * times_final['EODf'].iloc[0] else: eodf = eodf_orig return eodf def predefine_grouping_frame(b, redo=False, load=True, eodftype='_psdEOD_', freqtype='', printing=True, ver_here='new', intial=False, cell_name=[]): name = 'calc_auc_three_core-spikes_core_AUCI_multsorted2__psdEOD_all.pkl' path = load_folder_name('threefish') + '/' + name version_comp, subfolder, mod_name_slash, mod_name, subfolder_path = find_code_vs_not() load_function = find_load_function() path_local = load_function + '-' + name.replace('.pkl', '.csv') path_local_pkl = load_function + '-' + name if (version_comp == 'public') & (redo == False): # todo: ja das könnte man noch ausbauen mt_sorted = pd.read_csv(path_local, index_col=0) # if len(cell_name) > 0: times_final = mt_sorted[mt_sorted['cell'] == cell_name] else: times_final = mt_sorted[mt_sorted['cell'] == b.name] elif os.path.exists(path) & (load == True): mt_sorted = pd.read_pickle(path) times_final = mt_sorted[mt_sorted['cell'] == b.name] if version_comp == 'develop': mt_sorted.to_pickle(path_local_pkl) print('reloaded mt') else: t1 = time.time() for mt_nr, mt in enumerate(b.multi_tags): # todo: dieses Predefined vielleicht abspeichern damit man nicht so lange plotten muss! 
t3 = time.time() if 'Three' in mt.name: t2 = time.time() mt_sorted = feautures_in_mtframe(mt) phase, ver_here, contrast1, contrast2, eodf_orig, freq1_orig, freq2_orig = load_metadata_infos_three( mt_sorted, mt, ver_here=ver_here) if printing: print('metadata: ' + str(time.time() - t2)) t4 = time.time() # das hier machen wir einmal für alle mts! try: eodf = find_eodf(mt_sorted, eodf_orig, eodftype, b, mt) except: print('problem stuff') embed() if printing: print('load_eod: ' + str(time.time() - t4)) if freqtype == '_psd_': # die Frequenzen bestimmen wir so wie sie abgespeichert sind das ist sonst ein Problem mit den niedriegen Kotnrasten freq1, freq1_orig, freq2, freq2_orig = find_freqs_three(b, mt, freq1_orig, freq2_orig, mt_sorted) else: # DEFAULT # das ist für die Bilder gut, da muss das glaube ich nicht so genau sein freq1 = freq1_orig freq2 = freq2_orig t5 = time.time() mult1, mult2, DeltaF2, DeltaF1 = calc_mult(freq1, eodf, freq2) if printing: print('calc_mult: ' + str(time.time() - t5)) # das runden wir auch zum gruppieren mt_sorted['EODmult1'] = np.round(mult1, 2) mt_sorted['EODmult2'] = np.round(mult2, 2) mt_sorted['f1'] = freq1 mt_sorted['f2'] = freq2 mt_sorted['f1_orig'] = freq1_orig mt_sorted['f2_orig'] = freq2_orig # wir runden das DeltaF1 und DeltaF2, weil wir die dann ja groupieren wollen try: DeltaF1 = np.array(list(map(int, np.round(freq1 - eodf)))) except: print('eodf thing') embed() DeltaF2 = np.array(list(map(int, np.round(freq2 - eodf)))) mt_sorted['DF2'] = DeltaF2 mt_sorted['DF1'] = DeltaF1 mt_sorted['phase'] = phase mt_sorted['eodf'] = eodf mt_sorted['eodf_orig'] = eodf_orig mt_sorted['DF1, DF2'] = list(zip(DeltaF1, DeltaF2)) mt_sorted['m1, m2'] = list(zip(mt_sorted['EODmult1'], mt_sorted['EODmult2'])) mt_sorted['c1'] = contrast1 mt_sorted['c2'] = contrast2 restrict = np.arange(0, len(mt.positions[:])) mt_sorted = mt_sorted.iloc[restrict] # neuen (mt_sorted) if intial == False: times_final = mt_sorted intial = True else: times_final = 
pd.concat([times_final, mt_sorted]) test = False if test == True: from utils_test import plt_freqs plt_freqs() if printing: print('predefine_grouping_frame2: ' + str(time.time() - t2)) if printing: print('predefine_grouping_frame3: ' + str(time.time() - t3)) if printing: print('predefine_grouping_frame: ' + str(time.time() - t1)) return times_final def find_freqs_three(b, mt, freq1_orig, freq2_orig, mt_sorted): freq1 = [] freq2 = [] for mt_nr_small in range(len(mt.positions[:])): ################## # get the times where to cut the stimulus zeroth_cut, first_cut, second_cut, third_cut, fish_type, fish_cuts, whole_duration, delay, cont = load_four_durations( mt, mt_sorted, mt_nr_small, mt_nr_small) ################## # get the stimulus eod_global, sampling = link_arrays_eod(b, mt.positions[:][mt_nr_small] - delay, mt.extents[:][mt_nr_small] + delay, array_name='GlobalEFieldStimulus') # fish_number_base = remove_interspace_fish_nr(fish_number) eods_glb, _ = cut_eod_sequences(eod_global, fish_cuts, cut=0, rec=False, fish_number=fish_type, fillup=True, fish_number_base=fish_type) if len(eods_glb['control_01']) > 0: f1, p1, f = calc_freq_from_psd(eods_glb['control_01'], sampling_rate) # v else: f1 = freq1_orig.iloc[mt_nr_small] freq1.append(f1) if np.max(np.abs(freq1 - freq1_orig)) > 25: print('f1 diff too big') embed() sampling_rate = 40000 _, _ = nfft_improval(sampling_rate, freq1_orig.iloc[mt_nr_small], eods_glb['control_01'], freq1) # problem: Bei kleinen Kontrasten ist das wohl keine so gute Idee.. # wir sollten dann doch davon ausgehen dass das stimmt mit den Frequenzen! 
test = False if test: fig, ax = plt.subplots(2, 1) ax[0].plot(eods_glb['control_01']) ax[1].plot(f, p1) plt.show() if len(eods_glb['control_02']) > 0: f2, p2, f = calc_freq_from_psd(eods_glb['control_02'], sampling_rate) # v else: f2 = freq2_orig.iloc[mt_nr_small] freq2.append(f2) if np.max(np.abs(freq2 - freq2_orig)) > 25: print('f2 diff too big') embed() freq1 = np.array(freq1) freq2 = np.array(freq2) return freq1, freq1_orig, freq2, freq2_orig def find_eodf_three(b, mt, eodf_orig, max_eod=False, mt_idx=[], freq_step_nfft_eod=0.6103515625): eodf = [] # je nach dem ob ich alle mt freqs extrahieren will oder nur einen bestimmten index! try: if not list(mt_idx): ranges_here = range(len(mt.positions[:])) else: ranges_here = mt_idx except: print('mt something wierd') embed() freq_steps = [] for mt_nr_small in ranges_here: ################## # get the global EOD # hier können wir alles vom mt nehmen weil sich das ja nicht im Abhängigkeit vom Stimulus ändert # eod_global, sampling = link_arrays_eod(b, mt.positions[:][mt_nr] - delay, # mt.extents[:][mt_nr] + delay, mt.positions[:][mt_nr], # load_eod_array='EOD') # die Dauer sollte mindestens eine halbe Sekunde haben sonst hat das nicht genug Power! 
duration = mt.extents[:][mt_nr_small] sampling = get_sampling(b, 'EOD') nfft_eod = int(sampling / freq_step_nfft_eod) # das ist das wir die minimal frequenz auflösung bekommen if duration < nfft_eod / sampling: duration = nfft_eod / sampling global_eod, sampling = get_global_eod_for_eodf(b, duration, mt, mt_nr_small) if len(global_eod) > 0: ################## # das sollte die minimal Frequenz Auflösung sein if max_eod: maximal_nfft = len(global_eod) else: maximal_nfft = nfft_eod eod_fr = get_eodf_here(b, eodf_orig, global_eod, mt_nr_small, maximal_nfft, sampling) try: freq_step_maximal = get_freq_steps(maximal_nfft, sampling) except: print('unclear') embed() else: eod_fr = eodf_orig[mt_nr_small] maximal_nfft = nfft_eod freq_step_maximal = get_freq_steps(maximal_nfft, sampling) freq_steps.append(freq_step_maximal) eodf.append(eod_fr) eodf = np.array(eodf) return eodf, eodf_orig, freq_steps def find_all_dir_cells(): datasets = [] data_dir = [] dirs = ['cells'] # , 'cells_o', 'cells_l', 'cells_gp' for dir in dirs: version_comp, subfolder, mod_name_slash, mod_name, subfolder_path = find_code_vs_not() if version_comp == 'develop': # version_comp == 'develop' dir_path = '../../data/' + dir else: dir_path = '../data/' + dir if os.path.exists(dir_path): list_dir = os.listdir(dir_path + '/')[::-1] for l, list_name in enumerate(list_dir): if os.path.isdir(dir_path + '/' + list_name): if 'noise' not in list_name: # ('invivo' in list_name) and if list_name not in datasets: datasets.append(list_name) data_dir.append(dir + '/') return datasets, data_dir def get_all_nix_names(b, what='Three'): all_mt_names = find_mt_all(b) mt_names = find_mt(b, what) t_names = [] for trials in b.tags: if what in trials.name: t_names.append(trials.name) return all_mt_names, mt_names, t_names def find_mt(b, what): mt_names = [] for t_nr, trials in enumerate(b.multi_tags): if what in trials.name: mt_names.append(trials.name) return mt_names def find_right_dev(devname, devs): dev_nrs = 
np.arange(len(devname)) dev_nrs = np.array(dev_nrs)[ np.array(devname) == devs[0]] return dev_nrs def load_cells_three(end, data_dir=[], datasets=[]): if end == 'v2_2021-07-06': cells = ['2021-07-06-ag-invivo-1', '2021-07-06-ab-invivo-1', '2021-07-06-ac-invivo-1', '2021-07-06-aa-invivo-1', ] # '2021-06-23-ac-invivo-1', # Das sind glaube ich nochmal vier vom falschen Quadranten elif end == 'v2_2021-07-08': cells = ['2021-07-08-ab-invivo-1', '2021-07-08-aa-invivo-1', '2021-07-08-ac-invivo-1', '2021-07-08-ad-invivo-1'] # das werden nochmal vier sein wo aber nur ein Quadrant dabei ist elif end == 'v2_2021-08-02': cells = ['2021-08-02-ab-invivo-1', '2021-08-02-ac-invivo-1', '2021-08-02-ae-invivo-1'] # das sind drei Zellen wo ich das teilweise mit dem direkt mache elif end == 'v2_2021-08-03': cells = ['2021-08-03-ac-invivo-1', '2021-08-03-af-invivo-1', '2021-08-03-ad-invivo-1'] # das sind zwei Zellen wo ich das auch mit dem direkt mache elif end == 'v2': cells = ['2021-07-08-aa-invivo-1', '2021-07-08-ab-invivo-1', '2021-07-08-ac-invivo-1', '2021-07-08-ad-invivo-1', '2021-08-03-ac-invivo-1', '2021-08-03-af-invivo-1', '2021-08-03-ad-invivo-1', '2021-08-02-ab-invivo-1', '2021-08-02-ac-invivo-1', '2021-08-02-ae-invivo-1', '2021-07-06-ag-invivo-1', '2021-07-06-ab-invivo-1', '2021-07-06-ac-invivo-1', '2021-07-06-aa-invivo-1', ] elif end == 'all': cells = datasets cells = cells[::-1] data_dir = data_dir[::-1] # todo: hier noch anpassen, weil es bei manchen nicht durchgeht!! 
return data_dir, cells def spikes_for_desired_cells(spikes, data_names=[], names=['intro', 'contrasts', 'bursts', 'ampullary_small', 'model', 'eigen_small', 'eigemania_low_cv', 'eigenmania_high_cv', 'low_cv_punit', 'ampullary', 'bursts_all']): if spikes != '': data_names = find_names_cells(names, data_names) return data_names def find_names_cells(names, data_names=[]): for name in names: try: data_names.extend(p_units_to_show(type_here=name)) except: print('embed thing') embed() return data_names def plt_model_overview2(ax, cells=[], color_special='white', color_all='grey', scores=['perc95_perc5_fr']): a = 0 nr = '2' position = 0 save_names = [load_folder_name('calc_model') + '/calc_RAM_model-2__nfft_whole_power_1_eRAM_RAMdadjusted_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_9_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV' ] # load_folder_name('calc_model') +'/calc_RAM_model-2__nfft_whole_power_1_eRAM_RAMdadjusted_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_30_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV_burstIndividual_'] cvs = [] cells_here = [] frames = [] for save_name in save_names: frame = pd.DataFrame() frame = load_model_overview(cells, frame, nr, position, save_name, redo=True) frames.append(frame) cvs.append(frame.cv_stim) cells_here.append(frame.cell) # todo: in der Burst corr version werden das weniger Zellen, schauen warum! 
for c, cv in enumerate(cvs): for s, save_name in enumerate(save_names): cells_plot2 = p_units_to_show(type_here='model') cells_plot2.extend(["2013-01-08-aa-invivo-1", "2012-12-13-an-invivo-1"]) # burst_corr, var_type, stim_type_afe, trials_stim, stim_type_noise, power, nfft, dendrid, cut_off1, trial_nrs, c_sig, c_noise, ref_type, adapt_type, noise_added, extract, a_fr, a_fe = all frame = frames[s] # load_model_overview(cells, frame, nr, position, save_name) for s, score in enumerate(scores): ax.scatter(cv, frame[score], s=3.5, color=color_all) # , color = color)#, alpha = 0.45 ax.scatter(cv[frame['cell'].isin(cells_plot2)], frame[score][frame['cell'].isin(cells_plot2)], s=5, edgecolor='black', alpha=0.5, color=color_special) # , alpha = 0.45 a += 1 def plt_model_overview(ax, cells=[], scores=['perc95_perc5_fr']): a = 0 nr = '2' position = 0 frame = pd.DataFrame() save_names = [load_folder_name('calc_model') + '/calc_RAM_model-2__nfft_whole_power_1_eRAM_RAMdadjusted_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_30_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV' , load_folder_name( 'calc_model') + '/calc_RAM_model-2__nfft_whole_power_1_eRAM_RAMdadjusted_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_30_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV_burstIndividual_'] for save_name in save_names: frame = load_model_overview(cells, frame, nr, position, save_name) for s, score in enumerate(scores): ax[a].scatter(frame['cv'], frame[score]) ax[a].set_ylabel(score) ax[a].set_xlabel('cv') a += 1 def load_model_overview(cells, frame, nr, position, save_name, redo=False): path = save_name + '.pkl' # '../'+ model = load_model_susept(path, cells, save_name, save=False) version_comp, subfolder, mod_name_slash, mod_name, subfolder_path = find_code_vs_not() trials = path.split('TrialsStim_')[1].split('_a_fr_')[0] trials_stim = int(trials) 
save_name_final = find_load_function() + '_model' + trials + '.csv' try: (not os.path.exists(save_name_final)) | (redo == True) except: print('stil problems') embed() if ((not os.path.exists(save_name_final)) | (redo == True)) & ( version_comp != 'public'): # (version_comp == 'code') | (version_comp == 'develop'): for cell in cells: if len(model) > 0: model_show = model[ (model.cell == cell)] # & (model_cell.file_name == file)& (model_cell.power == power)] new_keys = model_show.index.unique() # [0:490] try: # je nach dem in welchem folder wir sind also im übergeordneten oder untergerodneten stack_plot = model_show[list(map(str, new_keys))] except: stack_plot = model_show[new_keys] stack_plot = stack_plot.iloc[np.arange(0, len(new_keys), 1)] stack_plot.columns = list(map(float, stack_plot.columns)) model_cells = resave_small_files("models_big_fit_d_right.csv") model_params = model_cells[model_cells['cell'] == cell] if len(model_show) > 0: noise_strength = model_params.noise_strength.iloc[0] # **2/2 c_sig = 1 # todo: doch das stimmt für den Egerland Fall! 
D_derived, var, cut_off = D_derive(model_show, save_name, c_sig, nr=nr) # var_basedD=D, # doch ich glaube das stimmt schon, ich muss halt aufpassen worüber ich mittel # ok doch das stimmt auch stack_plot = RAM_norm(stack_plot, trials_stim, D_derived) diagonal, frame = get_overview_scores('', frame, stack_plot, position) diag, diagonals_prj_l = get_mat_diagonals(np.array(stack_plot)) frame = signal_to_noise_ratios(diagonals_prj_l, frame, position, '') frame = fill_frame_with_non_float_vals(frame, position, model_show) position += 1 if version_comp == 'develop': frame.to_csv(save_name_final) else: frame = pd.read_csv(save_name_final) return frame def get_overview_scores(add, frame, mat, position): mat = np.array(mat) maximum = np.max(np.array(mat)) minimum = np.min(np.array(mat)) percentiel99 = np.percentile(mat, 99) percentiel90 = np.percentile(mat, 90) percentiel80 = np.percentile(mat, 80) percentiel70 = np.percentile(mat, 70) percentiel10 = np.percentile(mat, 10) percentiel95 = np.percentile(mat, 95) percentiel5 = np.percentile(mat, 5) percentiel1 = np.percentile(mat, 1) frame.loc[position, 'std_mean' + add] = np.std(np.array(mat)) / np.mean(np.array(mat)) frame.loc[position, 'max_min' + add] = (maximum - minimum) / (maximum + minimum) frame.loc[position, 'perc80_perc5' + add] = (percentiel80 - percentiel5) / (percentiel80 + percentiel5) frame.loc[position, 'perc70_perc5' + add] = (percentiel70 - percentiel5) / (percentiel70 + percentiel5) frame.loc[position, 'perc90_perc5' + add] = (percentiel90 - percentiel5) / ( percentiel90 + percentiel5) frame.loc[position, 'perc90_perc10' + add] = (percentiel90 - percentiel10) / ( percentiel90 + percentiel10) frame.loc[position, 'perc95_perc5' + add] = (percentiel95 - percentiel5) / (percentiel95 + percentiel5) frame.loc[position, 'perc99_perc1' + add] = (percentiel99 - percentiel1) / (percentiel99 + percentiel1) test = False if test: from utils_test import test_percentile test_percentile() extra = False if extra: diagonal = 
mat.diagonal() diagonal_norm = diagonal / np.sum(diagonal) mat_norm = mat / np.sum(diagonal) entropy_mat = scipy.stats.entropy(np.concatenate(mat_norm)) entropy_diagonal = scipy.stats.entropy(diagonal_norm) frame.loc[position, 'entropy_mat' + add] = entropy_mat frame.loc[position, 'entropy_diagonal' + add] = entropy_diagonal else: diagonal = [] return diagonal, frame def fill_frame_with_non_float_vals(frame, position, stack_here): types = list(map(type, stack_here.keys())) keys_else = stack_here.keys()[np.where(np.array(types) != float)] stack_vals = stack_here[keys_else].iloc[0] if 'osf' in stack_vals.keys(): stack_vals.pop('osf') if 'spikes' in stack_vals.keys(): stack_vals.pop('spikes') stack_vals.pop('isf') stack_vals.pop('freqs') if 'freqs_idx' in stack_vals.keys(): stack_vals.pop('freqs_idx') frame.loc[position, list(np.array(stack_vals.keys()))] = stack_vals # stack_else return frame def convert_csv_str_to_float(stack_final): stack_plot = stack_final new_keys = stack_plot.index try: stack_plot = stack_plot[new_keys] except: new_keys = list(map(str, new_keys)) try: stack_plot = stack_plot[new_keys] except: new_keys = np.round(stack_plot.index, 1) new_keys = list(map(str, new_keys)) new_keys = [k + '.0' for k in new_keys] stack_plot = stack_plot[new_keys] print('stack two still not working') embed() stack_plot = stack_plot.astype(complex) stack_plot.columns = list(map(float, stack_plot.columns)) return new_keys, stack_plot def change_model_from_csv_to_plots(model_show): new_keys = model_show.index.unique() # [0:490] try: # je nach dem in welchem folder wir sind also im übergeordneten oder untergerodneten stack_plot = model_show[list(map(str, new_keys))] except: stack_plot = model_show[new_keys] stack_plot = stack_plot.iloc[np.arange(0, len(new_keys), 1)] stack_plot.columns = list(map(float, stack_plot.columns)) return stack_plot def file_names_exlude_func(frame): file_names_there = ['gwn150Hz10s0.3', 'gwn300Hz10s0.3', 'gwn300Hz10s0.3short', 'gwn300Hz50s0.3', 
'InputArr_400hz_30' ] frame_file_ex = frame[frame.file_name.isin(file_names_there)] return frame_file_ex def extract_mat(stack_here, keys=[], e=0, sens='', ends_nr=[2000], abs=True, norm=''): if len(keys) < 1: keys = get_float_keys(stack_here) mat_to_div = stack_here[keys[keys < ends_nr[e]]] mat_to_div = mat_to_div.loc[mat_to_div.index < ends_nr[e]] mat_val = ram_norm_choice(mat_to_div, norm, stack_here, abs=abs) # hier adden wir nochmal die Zell Sensitivitäts Bereinigung durch die Var der Spikes if sens != '': spikes_var = stack_here['response_modulation'].iloc[ 0] # stack_here['var_spikes'].iloc[0] / stack_here['snippets'].unique()[0] mat_val = mat_val / spikes_var # todo: das ist eigneltich egal weil die Maße danach beziehen das mit ein! mat = np.array(mat_val) return mat, mat_to_div def ram_norm_choice(mat_to_div, norm, stack_here, abs=True): if 'constnorm' in norm: # _constnorm # const norm mit diesem d_isf1 ist generell falsch # damals haben wir an der NUller vom Power SPectrum geschaut und das als Abschätung # der varianz gemacht und nicht mal über die Power sondern das abs gemittelt deswegen ist das alles falsch mat_val = RAM_norm_data(stack_here['d_isf1'].iloc[0], mat_to_div, stack_here['snippets'].unique()[0]) else: if 'old' in norm: power = 1 else: power = 2 mat_val = RAM_norm_data(stack_here['isf'].iloc[0], mat_to_div, stack_here['snippets'].unique()[0], abs=abs, power=power, stack_here=stack_here) return mat_val def get_mat_diagonals(mat): diagonals = [] shapes = [] diagonals_prj = [] diagonals_prj_l = [] for m in range(len(mat)): # todo: ich glaube ich mache diese Projektion falsch try: diagonals.append(np.diagonal(mat[:, ::-1][m::, :])) # [0:-m, 0:-m]mat[m:, m:] except: print('diagonal thing') embed() diagonals_prj.append(np.mean(np.diagonal(mat[:, ::-1][m::, :]))) diagonals_prj_l.append( np.mean(np.diagonal(np.transpose(mat[:, ::-1])[m::, :]))) shapes.append(np.shape(mat[m:, m:])) diag = np.diagonal(mat) diagonals_prj_l = diagonals_prj_l[::-1] 
diagonals_prj_l.extend(diagonals_prj) return diag, diagonals_prj_l def axis_projection(mat_val, axis='orig_axis'): if 'orig_axis' in axis: diff_val = np.diff(np.array(mat_val.index))[0] / 2 axis_d = np.arange(mat_val.index[0] - diff_val, mat_val.index[-1] + diff_val, diff_val) else: axis_new = mat_val.index + mat_val.columns diff_val = np.diff(np.array(axis_new))[0] / 2 axis_d = np.arange(axis_new[0] - diff_val, axis_new[-1] + diff_val, diff_val) return axis_d def mod_lims_modulation(cell_type_here, frame_file, score_m, std_est=None): if not std_est: if 'P-Unit' in cell_type_here: mod_limits = np.arange(0, 100, 5) # np.linspace(0,100,11)#np.max(frame_file[score_m]) mod_limits = np.concatenate([mod_limits, [np.max(frame_file[score_m])]]) else: mod_limits = np.arange(0, 60, 5) # np.linspace(0,100,11)#np.max(frame_file[score_m]) mod_limits = np.concatenate([mod_limits, [np.max(frame_file[score_m])]]) else: nbins = 70 std_estimate, center = hist_threshold(frame_file[score_m][~np.isnan(frame_file[score_m])], thresh_fac=1, nbins=nbins) mod_limits = np.linspace(0, np.median(frame_file[score_m]) + 3 * std_estimate, 30) mod_limits = np.concatenate([mod_limits, [np.max(frame_file[score_m])]]) test = False if test: _, _ = hist_threshold() # frame[score][~np.isnan(frame[score])],thresh_fac=1,nbins=nbins return mod_limits def signal_to_noise_ratios(diag_val, frame, position, add): sub_mods = ['', '-med', '-center', '-m'] # , richtig = '' perc_nrs = [99, 99.9] # ,99, 10092.5, 80, 85, 90,98, # die gaussian Werte vom Jan nbins = 70 std_estimate, center = hist_threshold(diag_val, thresh_fac=1, nbins=nbins) std_sigma = np.percentile(diag_val, 84) - np.median(diag_val) std_orig = np.std(diag_val) div_mods = ['', 'med', 'stdthunder'] # , 'stdsigma','stdorig' ,'mean']# richtig = 'med' 'stdthunder2.576', for sub_mod in sub_mods: for perc_nr in perc_nrs: for div_mod in div_mods: percentiel99 = np.percentile(diag_val, perc_nr) frame.loc[position, 'perc' + str(perc_nr) + '_' + add] = 
percentiel99 frame.loc[position, 'med' + '_' + add] = np.median( diag_val) frame.loc[position, 'stdthunder' + '_' + add] = std_estimate frame.loc[position, 'stdthunder2.576' + '_' + add] = std_estimate * 2.576 frame.loc[ position, 'stdsigma' + '_' + add] = std_sigma frame.loc[ position, 'stdorig' + '_' + add] = std_orig frame.loc[position, 'center' + '_' + add] = center if div_mod == 'med': div = np.median(diag_val) elif div_mod == 'stdthunder': div = std_estimate elif div_mod == 'stdthunder2.576': div = std_estimate * 2.576 elif div_mod == 'stdsigma': div = std_sigma elif div_mod == 'stdorig': div = std_orig elif div_mod == '': div = 1 else: div = np.mean(diag_val) if sub_mod == '-m': sub = div elif sub_mod == '-med': sub = np.median(diag_val) elif sub_mod == '-center': sub = center else: sub = 0 frame.loc[ position, 'perc' + str(perc_nr) + sub_mod + '/' + div_mod + '_' + add] = (percentiel99 - sub) / div return frame def restrict_base_durationts(duration): if duration > 30: duration = 30 else: duration = duration return duration def update_fav_snippet(nfft, fav_snippet=9): return int(np.round( fav_snippet / float(nfft.replace('sec', '')))) def find_cell_cont(redo, cell, frame, saved): if saved == True: try: np.array(frame.cell.unique()) except: print('problem') embed() if cell not in np.array(frame.cell.unique()): cont = True else: cont = False else: cont = True if redo: cont = True return cont def load_frame(redo, name): if redo == False: if os.path.exists(name): if 'csv' in name: try: frame = pd.read_csv(name, index_col=0) except: print('parse thing') embed() else: frame = pd.read_pickle(name) position = len(frame) saved = True else: frame = pd.DataFrame() position = 0 saved = False else: frame = pd.DataFrame() position = 0 saved = False return frame, position, saved def find_common_mod(save_names=[ 'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s__burstIndividual_', ]): amps = [] mods_p = [] mods_a 
= [] mods = [] cell_type_type = 'cell_type_reclassified' for ss, save_name in enumerate(save_names): save_name_here = load_folder_name('calc_RAM') + '/' + save_name + '.csv' frame_load = pd.read_csv(save_name_here, index_col=0) amps.extend(frame_load.amp.unique()) spikes_var = np.sqrt(frame_load['var_spikes'] / frame_load['snippets']) frame_load['modulation'] = spikes_var mods_p.extend(frame_load[frame_load[cell_type_type] == ' P-unit']['modulation']) mods_a.extend(frame_load[frame_load[cell_type_type] == ' Ampullary']['modulation']) mods.extend(frame_load['modulation']) mod_limits_p = np.linspace(0, 1200, 8) mod_limits_p = np.concatenate([mod_limits_p, [np.max(mods)]]) mod_limits_a = np.linspace(0, 500, 8) mod_limits_a = np.concatenate([mod_limits_a, [np.max(mods)]]) return mod_limits_a, mod_limits_p, mods_a, mods_p, frame_load def find_norm_compars(isf, isf_mean, osf, deltat, stack_plot, mean=True): f_range = np.arange(len(stack_plot)) try: _, _, _, _ = fft_matrix(osf[0], f_range, isf[0], norm='') # stimulus, except: print('fmat thing') embed() f_mat1, f_mat2, f_idx_sum, cross_norm = fft_matrix(osf[0], f_range, isf[0], norm='_normPS_') # stimulus, mats_all = [] mats_all_norm = [] scales = [] for t in range(len(osf)): f_mat1, f_mat2, f_idx_sum, mat_all = fft_matrix(osf[t], f_range, isf[t], norm='') # stimulus, f_mat1, f_mat2, f_idx_sum, cross_norm = fft_matrix(osf[t], f_range, isf[t], norm='_normPS_') # stimulus, mats_all_norm.append(cross_norm) mats_all.append(mat_all) scale = find_norm_susept(f_idx_sum, isf[t][f_range]) scales.append(scale) if mean: mats_all_here = np.mean(mats_all, axis=0) else: mats_all_here = np.sum(mats_all, axis=0) mats_all_here_norm = np.mean(mats_all_norm, axis=0) scales = np.mean(scales, axis=0) power_isf_1 = (np.abs(isf_mean[f_range])) power_isf_1 = [power_isf_1] * len(stack_plot) norm_char22 = find_norm_susept(stack_plot, isf_mean[f_range]) norm_char2 = 1 / norm_char22 return scales, cross_norm, f_mat2, mats_all_here, 
mats_all_here_norm, norm_char2


def _standard_name(value):
    """Return ``'standard'`` for an empty setting string, otherwise the value unchanged."""
    return 'standard' if value == '' else value


def _cv_sorted_cells(cells, save_name):
    """Return the cells of the fixed reference model file, ordered by ascending ``cv``.

    The reference pickle is loaded with ``load_model_susept``; every cell is kept once, at the
    position of its first occurrence after sorting the reference frame by its ``cv`` column.
    """
    # NOTE(review): the original author remarked that "something goes wrong" in this sorting -
    # confirm that the resulting order is the intended one.
    path_cell_ref = load_folder_name(
        'calc_model') + '/calc_RAM_model-2__nfft_whole_power_1_eRAM_RAMdadjusted_additiv_visual_d_4_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_100000_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV.pkl'
    model_sorting = load_model_susept(path_cell_ref, cells, save_name)
    cells_sorted = model_sorting.cell.iloc[np.argsort(model_sorting.cv)]
    unique_cells, first_index = np.unique(cells_sorted, return_index=True)
    return np.array(unique_cells)[np.array(np.argsort(first_index))]


def _cell_titles(var_items, cell, a_fe, c_sig, stim_type_afe, stim_type_noise_name, trials_stim, extract,
                 var_type, noise_added):
    """Split the textual description of one cell/parameter set into ``(titles, suptitles)``.

    Every descriptor whose key is listed in ``var_items`` is appended to ``titles``; all other
    descriptors are appended to ``suptitles``. Descriptor order is fixed.
    """
    stim_type_noise_name2, stim_type_afe_name = stim_type_names(a_fe, c_sig, stim_type_afe,
                                                                stim_type_noise_name)
    parts = (
        ('cells', cell[2:13]),
        ('internal_noise', ' intrinsic noise=' + stim_type_noise_name2),
        ('external_noise', ' additive RAM=' + stim_type_afe_name),
        ('repeats', ' $N_{repeat}=$' + str(trials_stim)),
        ('contrasts', ' contrast=' + str(a_fe)),
        ('level_extraction', ' Extract Level=' + str(extract)),
        ('D_extraction_method', str(var_type)),
        ('noises_added', ' high freq noise=' + str(noise_added)),
    )
    titles = ''
    suptitles = ''
    for key, text in parts:
        if key in var_items:
            titles += text
        else:
            suptitles += text
    return titles, suptitles


def _end_name(suptitles, a_fr, dendrid_name, ref_type_name, adapt_type_name, cut_off1, stimulus_length, power,
              restrict):
    """Build the figure suptitle describing the fixed model settings."""
    return (suptitles + ' a_fr=' + str(a_fr) + ' ' + ' dendride=' + str(dendrid_name)
            + ' refractory period=' + str(ref_type_name) + ' adapt=' + str(adapt_type_name) + ' '
            + ' cutoff1=' + str(cut_off1) + ' stimulus length=' + str(stimulus_length) + ' '
            + ' power=' + str(power) + ' ' + restrict)


def _rate_summary(model_show, D_derived, sep):
    """Format firing rates, CVs, ``D_derived`` and ``ser_sum_stim`` of the first row of ``model_show``.

    ``sep`` is inserted before the rate, CV and D groups ('\\n' for multi-line titles, '' inline).
    """
    return (sep + ' $fr_{B}$=' + str(int(np.round(model_show.fr.iloc[0])))
            + ' $fr_{S}$=' + str(int(np.round(model_show.fr_stim.iloc[0])))
            + 'Hz' + sep + ' $cv_{B}$=' + str(np.round(model_show.cv.iloc[0], 2))
            + ' $cv_{S}$=' + str(np.round(model_show.cv_stim.iloc[0], 2))
            + sep + ' $D_{sig}$=' + str(np.round(D_derived, 5))
            + ' s=' + str(np.round(model_show.ser_sum_stim.iloc[0], 2)))


def _stack_from_model(model_show):
    """Extract the square RAM matrix stored column-wise in ``model_show`` with float column labels."""
    new_keys = model_show.index.unique()
    try:
        # Depending on the working folder (parent or sub-folder) the columns are strings or not.
        stack_plot = model_show[list(map(str, new_keys))]
    except KeyError:
        stack_plot = model_show[new_keys]
    stack_plot = stack_plot.iloc[np.arange(0, len(new_keys), 1)]
    stack_plot.columns = list(map(float, stack_plot.columns))
    return stack_plot


def overview_model(individual_tag='', many=False, fs=8, hs=0.39, nffts=['whole'], powers=[1],
                   cells=["2013-01-08-aa-invivo-1"], var_items=['contrasts'], show=False, contrasts=[0],
                   noises_added=[''], fft_i='forward', fft_o='forward', spikes_unit='Hz', mV_unit='mV',
                   D_extraction_method=['additiv_visual_d_4_scaled'], internal_noise=['eRAM'],
                   external_noise=['eRAM'], level_extraction=['_RAMdadjusted'], cut_off2=300, repeats=[1000000],
                   receiver_contrast=[1], dendrids=[''], ref_types=[''], adapt_types=[''], c_noises=[0.1],
                   c_signal=[0.9], cut_offs1=[300], burst_corrs=[''], restrict='restrict'):
    """Scatter overview scores of model RAM matrices against the cells' ``cv``.

    For every combination of the list-valued parameters the saved model result is located with
    ``save_ram_model``/``load_model_susept``. Each cell (ordered by the reference CV sorting) is
    normalized with ``RAM_norm`` and scored with ``get_overview_scores``; the eight scores are
    scattered against ``frame['cv']`` in a 2x4 grid, which is saved with ``save_visualization``.
    ``var_items`` selects which descriptors are collected in the panel titles rather than the
    suptitle.
    """
    stimulus_length = 1
    trials_nrs = [1]
    variant = 'sinz'
    mimick = 'no'
    cell_recording_save_name = ''
    trans = 1
    nr = '2'
    iternames = [burst_corrs, D_extraction_method, external_noise, repeats, internal_noise, powers, nffts,
                 dendrids, cut_offs1, trials_nrs, c_signal, c_noises, ref_types, adapt_types, noises_added,
                 level_extraction, receiver_contrast, contrasts, ]
    # Fitted model parameters do not depend on the loop variables, so load them only once.
    model_fits = resave_small_files("models_big_fit_d_right.csv")
    scores = ['std_mean', 'max_min', 'perc80_perc5', 'perc70_perc5', 'perc90_perc10', 'perc95_perc5',
              'entropy_mat', 'entropy_diagonal']
    # The outer product additionally runs over ``cells``; every pass draws one complete figure.
    for _ in it.product(burst_corrs, cells, *iternames[1:]):
        fig, ax = plt.subplots(2, 4, sharex=True, figsize=(12, 5.5))
        plt.subplots_adjust(wspace=0.8, bottom=0.067, top=0.86, hspace=hs, right=0.88, left=0.075)
        ax = np.concatenate(ax)
        position = 0
        frame = pd.DataFrame()
        for combination in it.product(*iternames):
            (burst_corr, var_type, stim_type_afe, trials_stim, stim_type_noise, power, nfft, dendrid, cut_off1,
             trial_nrs, c_sig, c_noise, ref_type, adapt_type, noise_added, extract, a_fr, a_fe) = combination
            save_name = save_ram_model(stimulus_length, cut_off1, nfft, a_fe, stim_type_noise, mimick, variant,
                                       trials_stim, power, cell_recording_save_name, nr=nr, fft_i=fft_i,
                                       fft_o=fft_o, Hz=spikes_unit, mV=mV_unit, burst_corr=burst_corr,
                                       stim_type_afe=stim_type_afe, extract=extract, noise_added=noise_added,
                                       c_noise=c_noise, c_sig=c_sig, ref_type=ref_type, adapt_type=adapt_type,
                                       var_type=var_type, cut_off2=cut_off2, dendrid=dendrid, a_fr=a_fr,
                                       trials_nr=trial_nrs, trans=trans, zeros='ones')
            model = load_model_susept(save_name + '.pkl', cells, save_name)
            cells = _cv_sorted_cells(cells, save_name)
            stim_type_noise_name = stim_type_noise if 'additiv' in var_type else ''
            dendrid_name = _standard_name(dendrid)
            ref_type_name = _standard_name(ref_type)  # was wrongly derived from ``dendrid``
            adapt_type_name = _standard_name(adapt_type)
            for cell in cells:
                if len(model) > 0:
                    titles, suptitles = _cell_titles(var_items, cell, a_fe, c_sig, stim_type_afe,
                                                     stim_type_noise_name, trials_stim, extract, var_type,
                                                     noise_added)
                    model_show = model[(model.cell == cell)]
                    stack_plot = _stack_from_model(model_show)
                    model_params = model_fits[model_fits['cell'] == cell]
                    if len(model_show) > 0:
                        D = model_params.noise_strength.iloc[0]
                        D_derived, _, _ = D_derive(model_show, save_name, c_sig, D=D, base='', nr=nr)
                        stack_plot = RAM_norm(stack_plot, trials_stim, D_derived)
                        if many:
                            titles = titles + ' Ef=' + str(int(model_params.EODf.iloc[0]))
                        color = title_color(cell)
                        print(color)
                        _, frame = get_overview_scores('', frame, stack_plot, position)
                        frame = fill_frame_with_non_float_vals(frame, position, model_show)
                        position += 1
                    else:
                        print('len problem2')
                        embed()
                else:
                    print('len problem')
                    embed()
        for s, score in enumerate(scores):
            ax[s].scatter(frame['cv'], frame[score])
            ax[s].set_title(score)
        name_title = _end_name(suptitles, a_fr, dendrid_name, ref_type_name, adapt_type_name, cut_off1,
                               stimulus_length, power, restrict)
        plt.suptitle(name_title + titles + _rate_summary(model_show, D_derived, '\n'), fontsize=fs, color=color)
        save_visualization(individual_tag=individual_tag, pdf=True, show=show)


def overview_model_trials(individual_tag='', many=False, row='no', fs=8, hs=0.39, nffts=['whole'], powers=[1],
                          cells=["2013-01-08-aa-invivo-1"], col_desired=8, var_items=['contrasts'], show=False,
                          contrasts=[0], noises_added=[''], fft_i='forward', fft_o='forward', spikes_unit='Hz',
                          mV_unit='mV', D_extraction_method=['additiv_visual_d_4_scaled'],
                          internal_noise=['eRAM'], external_noise=['eRAM'],
                          scores=['std_mean', 'max_min', 'perc80_perc5', 'perc70_perc5', 'perc90_perc10',
                                  'perc95_perc5', 'entropy_mat', 'entropy_diagonal'],
                          level_extraction=['_RAMdadjusted'], cut_off2=300, repeats=[1000000],
                          receiver_contrast=[1], dendrids=[''], ref_types=[''], adapt_types=[''], c_noises=[0.1],
                          c_signal=[0.9], cut_offs1=[300], burst_corrs=[''], restrict='restrict'):
    """Like ``overview_model`` but with one subplot row per parameter combination.

    Row ``a`` scatters every entry of ``scores`` against ``frame['cv']`` and is annotated with the
    titles and rate summary of the last processed cell. When ``row == 'no'`` the grid shape is
    computed with ``find_row_col`` on the first pass; afterwards ``col_desired`` columns are used.
    """
    stimulus_length = 1
    trials_nrs = [1]
    variant = 'sinz'
    mimick = 'no'
    cell_recording_save_name = ''
    trans = 1
    nr = '2'
    iternames = [burst_corrs, D_extraction_method, external_noise, repeats, internal_noise, powers, nffts,
                 dendrids, cut_offs1, trials_nrs, c_signal, c_noises, ref_types, adapt_types, noises_added,
                 level_extraction, receiver_contrast, contrasts, ]
    model_fits = resave_small_files("models_big_fit_d_right.csv")
    for aa, _ in enumerate(it.product(burst_corrs, cells, *iternames[1:]), start=1):
        if row == 'no':
            col, row = find_row_col(np.arange(aa * 8 / len(cells)), col=col_desired)
        else:
            col = col_desired
        # squeeze=False keeps ``ax`` two-dimensional, so ``ax[a, s]`` also works for a single row/column.
        fig, ax = plt.subplots(row, col, sharex=True, figsize=(12, 5.5), squeeze=False)
        plt.subplots_adjust(wspace=0.8, bottom=0.067, top=0.86, hspace=hs, right=0.88, left=0.075)
        a = 0
        position = 0
        frame = pd.DataFrame()
        for combination in it.product(*iternames):
            (burst_corr, var_type, stim_type_afe, trials_stim, stim_type_noise, power, nfft, dendrid, cut_off1,
             trial_nrs, c_sig, c_noise, ref_type, adapt_type, noise_added, extract, a_fr, a_fe) = combination
            save_name = save_ram_model(stimulus_length, cut_off1, nfft, a_fe, stim_type_noise, mimick, variant,
                                       trials_stim, power, cell_recording_save_name, nr=nr, fft_i=fft_i,
                                       fft_o=fft_o, Hz=spikes_unit, mV=mV_unit, burst_corr=burst_corr,
                                       stim_type_afe=stim_type_afe, extract=extract, noise_added=noise_added,
                                       c_noise=c_noise, c_sig=c_sig, ref_type=ref_type, adapt_type=adapt_type,
                                       var_type=var_type, cut_off2=cut_off2, dendrid=dendrid, a_fr=a_fr,
                                       trials_nr=trial_nrs, trans=trans, zeros='ones')
            model = load_model_susept(save_name + '.pkl', cells, save_name)
            cells = _cv_sorted_cells(cells, save_name)
            stim_type_noise_name = stim_type_noise if 'additiv' in var_type else ''
            dendrid_name = _standard_name(dendrid)
            ref_type_name = _standard_name(ref_type)  # was wrongly derived from ``dendrid``
            adapt_type_name = _standard_name(adapt_type)
            for cell in cells:
                if len(model) > 0:
                    titles, suptitles = _cell_titles(var_items, cell, a_fe, c_sig, stim_type_afe,
                                                     stim_type_noise_name, trials_stim, extract, var_type,
                                                     noise_added)
                    model_show = model[(model.cell == cell)]
                    stack_plot = _stack_from_model(model_show)
                    model_params = model_fits[model_fits['cell'] == cell]
                    if len(model_show) > 0:
                        D = model_params.noise_strength.iloc[0]
                        D_derived, _, _ = D_derive(model_show, save_name, c_sig, D=D, base='', nr=nr)
                        stack_plot = RAM_norm(stack_plot, trials_stim, D_derived)
                        if many:
                            titles = titles + ' Ef=' + str(int(model_params.EODf.iloc[0]))
                        color = title_color(cell)
                        print(color)
                        _, frame = get_overview_scores('', frame, stack_plot, position)
                        frame = fill_frame_with_non_float_vals(frame, position, model_show)
                        position += 1
                    else:
                        print('len problem2')
                        embed()
                else:
                    print('len problem')
                    embed()
            for s, score in enumerate(scores):
                ax[a, s].scatter(frame['cv'], frame[score])
                ax[0, s].set_title(score)
                ax[-1, s].set_xlabel('cv')
            ax[a, 0].text(0, 1.2, titles + _rate_summary(model_show, D_derived, ''), transform=ax[a, 0].transAxes, )
            a += 1
        name_title = _end_name(suptitles, a_fr, dendrid_name, ref_type_name, adapt_type_name, cut_off1,
                               stimulus_length, power, restrict)
        plt.suptitle(name_title, fontsize=fs, color=color)
        save_visualization(individual_tag=individual_tag, pdf=True, show=show)


def model_cells(individual_tag='', nr_clim=10, many=False, width=0.02, row='no', HZ50=True, fs=8, hs=0.39,
                nffts=['whole'], powers=[1], cells=["2013-01-08-aa-invivo-1"], col_desired=2,
                var_items=['contrasts'], show=False, contrasts=[0], noises_added=[''], fft_i='forward',
                fft_o='forward', spikes_unit='Hz', mV_unit='mV', D_extraction_method=['additiv_visual_d_4_scaled'],
                internal_noise=['eRAM'], external_noise=['eRAM'], level_extraction=['_RAMdadjusted'], cut_off2=300,
                repeats=[1000000], receiver_contrast=[1], dendrids=[''], ref_types=[''], adapt_types=[''],
                c_noises=[0.1], c_signal=[0.9], cut_offs1=[300], burst_corrs=[''], clims='all',
                restrict='restrict', label=r'$\frac{1}{mV^2S}$'):
    """Plot the normalized model RAM matrix of every cell in its own panel.

    Each panel shows ``plt_RAM_perc`` of the ``RAM_norm``-normalized matrix on 0-300 Hz axes, with
    ``plt_triangle`` (and optionally ``plt_50_Hz_noise``) overlays and an outside colorbar. Colour
    limits are shared across panels via ``set_clim_shared``.
    """
    stimulus_length = 1
    trials_nrs = [1]
    variant = 'sinz'
    mimick = 'no'
    cell_recording_save_name = ''
    trans = 1
    nr = '2'
    iternames = [burst_corrs, D_extraction_method, external_noise, repeats, internal_noise, powers, nffts,
                 dendrids, cut_offs1, trials_nrs, c_signal, c_noises, ref_types, adapt_types, noises_added,
                 level_extraction, receiver_contrast, contrasts, ]
    model_fits = resave_small_files("models_big_fit_d_right.csv")
    for aa, _ in enumerate(it.product(burst_corrs, cells, *iternames[1:]), start=1):
        if row == 'no':
            col, row = find_row_col(np.arange(aa), col=col_desired)
        else:
            col = col_desired
        if row == 2:
            default_settings(column=2, length=7.5)
        elif row == 1:
            default_settings(column=2, length=4)
        # squeeze=False + ravel gives a flat, row-major Axes array for every grid shape (incl. 1x1, Nx1).
        fig, ax = plt.subplots(row, col, sharex=True, sharey=True, squeeze=False)
        if row == 2:
            plt.subplots_adjust(bottom=0.067, wspace=0.45, top=0.81, hspace=hs, right=0.88, left=0.075)
        elif row == 1:
            plt.subplots_adjust(bottom=0.1, wspace=0.45, top=0.81, hspace=hs, right=0.88, left=0.075)
        else:
            plt.subplots_adjust(wspace=0.8, bottom=0.067, top=0.86, hspace=hs, right=0.88, left=0.075)
        ax = ax.ravel()
        a = 0
        maxs = []
        mins = []
        ims = []
        perc05 = []
        perc95 = []
        for combination in it.product(*iternames):
            (burst_corr, var_type, stim_type_afe, trials_stim, stim_type_noise, power, nfft, dendrid, cut_off1,
             trial_nrs, c_sig, c_noise, ref_type, adapt_type, noise_added, extract, a_fr, a_fe) = combination
            save_name = save_ram_model(stimulus_length, cut_off1, nfft, a_fe, stim_type_noise, mimick, variant,
                                       trials_stim, power, cell_recording_save_name, nr=nr, fft_i=fft_i,
                                       fft_o=fft_o, Hz=spikes_unit, mV=mV_unit, burst_corr=burst_corr,
                                       stim_type_afe=stim_type_afe, extract=extract, noise_added=noise_added,
                                       c_noise=c_noise, c_sig=c_sig, ref_type=ref_type, adapt_type=adapt_type,
                                       var_type=var_type, cut_off2=cut_off2, dendrid=dendrid, a_fr=a_fr,
                                       trials_nr=trial_nrs, trans=trans, zeros='ones')
            model = load_model_susept(save_name + '.pkl', cells, save_name)
            cells = _cv_sorted_cells(cells, save_name)
            stim_type_noise_name = stim_type_noise if 'additiv' in var_type else ''
            dendrid_name = _standard_name(dendrid)
            ref_type_name = _standard_name(ref_type)  # was wrongly derived from ``dendrid``
            adapt_type_name = _standard_name(adapt_type)
            for cell in cells:
                if len(model) > 0:
                    titles, suptitles = _cell_titles(var_items, cell, a_fe, c_sig, stim_type_afe,
                                                     stim_type_noise_name, trials_stim, extract, var_type,
                                                     noise_added)
                    model_show = model[(model.cell == cell)]
                    stack_plot = change_model_from_csv_to_plots(model_show)
                    ax[a].set_xlim(0, 300)
                    ax[a].set_ylim(0, 300)
                    ax[a].set_aspect('equal')
                    model_params = model_fits[model_fits['cell'] == cell]
                    if len(model_show) > 0:
                        D = model_params.noise_strength.iloc[0]
                        D_derived, _, _ = D_derive(model_show, save_name, c_sig, D=D, base='', nr=nr)
                        stack_plot = RAM_norm(stack_plot, trials_stim, D_derived)
                        if many:
                            titles = titles + ' Ef=' + str(int(model_params.EODf.iloc[0]))
                        color = title_color(cell)
                        print(color)
                        ax[a].set_title(titles + _rate_summary(model_show, D_derived, '\n'), fontsize=fs,
                                        color=color)
                        perc = ''
                        im = plt_RAM_perc(ax[a], perc, stack_plot)
                        ims.append(im)
                        maxs.append(np.max(np.array(stack_plot)))
                        mins.append(np.min(np.array(stack_plot)))
                        perc05.append(np.percentile(stack_plot, 5))
                        perc95.append(np.percentile(stack_plot, 95))
                        plt_triangle(ax[a], model_show.fr.iloc[0], np.round(model_show.fr_stim.iloc[0]), 300,
                                     model_show.eod_fr.iloc[0])
                        if HZ50:
                            plt_50_Hz_noise(ax[a], 300)
                        ax[a].set_aspect('equal')
                        cbar = colorbar_outside(ax[a], im, fig, add=0, width=width)
                        # With many panels only the right-most column gets a colorbar label.
                        if not many or a % col == col - 1:
                            cbar[0].set_label(label, labelpad=100)
                        if a >= row * col - col:
                            ax[a].set_xlabel(F1_xlabel(), labelpad=20)
                        ax[0].set_ylabel(F2_xlabel())
                        if a % col == 0:
                            ax[a].set_ylabel(F2_xlabel())
                    else:
                        print('len problem2')
                        embed()
                else:
                    print('len problem')
                    embed()
                a += 1
        name_title = _end_name(suptitles, a_fr, dendrid_name, ref_type_name, adapt_type_name, cut_off1,
                               stimulus_length, power, restrict)
        plt.suptitle(name_title)
        set_clim_shared(clims, ims, maxs, mins, nr_clim, perc05, perc95)
        save_visualization(individual_tag=individual_tag, pdf=True, show=show)


def plt_punit(amp_desired=[0.5, 1, 5], xlim=[], cells_plot2=[], show=False, annotate=False):
    """Plot a P-unit/ampullary overview: baseline scatters, CV vs. stimulus CV, and per-cell panels.

    The top row holds three scatters from ``plt_scatter_three2`` plus a CV-vs-CV_stim scatter;
    the bottom grid (one row per cell in ``cells_plot2``) is filled by ``plt_cellbody_punitsingle``.
    """
    plot_style()
    default_settings(column=2, width=12, length=8)
    save_names = ['noise_data8_nfft1sec_original__LocalEOD_CutatBeginning_0.05_s_NeurDelay_0.005_s_burst_corr']
    amps_desired = amp_desired
    cell_type_type = 'cell_type_reclassified'
    frame = load_cv_base_frame(cells_plot2, cell_type_type=cell_type_type)
    cells_plot = cells_plot2
    grid = gridspec.GridSpec(2, 1, wspace=0.1, height_ratios=[1, 4.5], hspace=0.25, top=0.96, left=0.095,
                             bottom=0.07, right=0.92)
    colors = {'unkown': 'grey',
              ' P-unit': 'blue',
              ' Ampullary': 'green',
              'nan': 'grey',
              ' T-unit': 'purple',
              ' E-cell': 'red',
              ' Pyramidal': 'darkred',
              ' I-cell': 'pink',
              ' E-cell superficial': 'orange',
              ' Ovoid': 'cyan'}
    grid2 = gridspec.GridSpecFromSubplotSpec(1, 4, grid[0], wspace=0.5, hspace=0.2)
    cell_types = [' P-unit', ' Ampullary']
    ax0, ax1, ax2 = plt_scatter_three2(grid2, frame, cell_type_type, annotate, colors)
    ax3 = plt.subplot(grid2[3])
    axs = [ax3]
    burst_name = ['', ' burst corr ']
    save_names1 = [save_names[0]]
    # TODO: to use the burst-corrected CV from the baseline, it must be loaded from the baseline file too.
    for s, save_name in enumerate(save_names1):
        load_name = load_folder_name('calc_RAM') + '/' + save_name + '.csv'
        for cell_type_it in cell_types:
            frame_g = base_to_stim(load_name, frame, cell_type_type, cell_type_it)
            axs[s].scatter(np.array(frame_g['cv']), np.array(frame_g['cv_stim']), alpha=0.5, s=7,
                           color=colors[str(cell_type_it)])
        axs[s].set_xlim(0, 1.5)
        axs[s].set_ylim(0, 1.5)
        axs[s].set_ylabel('CV stim ' + burst_name[s])
        axs[s].set_xlabel('CV ' + burst_name[s])
    # One narrow and one empty column, then one equal-width column per amplitude (previously only
    # 2-4 amplitudes were supported and any other count left ``wr`` undefined).
    wr = [0.5, 0] + [1] * len(amps_desired)
    grid1 = gridspec.GridSpecFromSubplotSpec(len(cells_plot), len(amps_desired) + 2, grid[1], hspace=0.17,
                                             wspace=0.35, width_ratios=wr)
    plt_cellbody_punitsingle(grid1, ax0, ax1, ax2, frame, colors, amps_desired, save_names, cells_plot,
                             cell_type_type, ax3=ax3, xlim=xlim, plus=2, burst_corr='_burst_corr_individual')
    save_visualization(pdf=True)
    show_func(show=show)


def _scatter_cell_point(ax, frame_cell, x_col, y_col, cell_type, colors):
    """Draw one open marker for the cell at (``x_col``, ``y_col``) unless ``ax`` is the ``[]`` placeholder."""
    if ax != []:
        ax.scatter(frame_cell[x_col], frame_cell[y_col], zorder=2, alpha=1, label=cell_type, s=15,
                   color=colors[str(cell_type)], facecolor='white')


def plt_sqaure_isf2(grid1, ax0, ax1, ax2, b, cell, frame, colors, amps_desired, save_names, cells_plot,
                    cell_type_type, labeloff=True, predefined_amps2=False, norm=False):
    """Plot the stored susceptibility square of ``cell`` and mark the cell in the overview scatters.

    Returns ``(im, axs)`` from ``plt_stack_single``, or ``([], [])`` when no ``.pkl`` file exists.
    """
    print(cell)
    frame_cell = frame[(frame['cell'] == cell)]
    frame_cell = unify_cell_names(frame_cell, cell_type=cell_type_type)
    cell_type = frame_cell[cell_type_type].iloc[0]
    fr = frame_cell.fr.iloc[0]
    cv = frame_cell.cv.iloc[0]
    eod_fr = frame_cell.EODf.iloc[0]
    # Fallback title in case the square is not plotted.
    plt.suptitle(cell + ' EODf ' + str(np.round(eod_fr)) + ' Hz ' + ' % ' + ' Base: cv ' + str(
        np.round(cv, 2)) + ' fr ' + str(np.round(fr)) + ' Hz', fontsize=11, )
    load_name = load_folder_name('calc_RAM') + '/' + save_names[0] + '_' + cell
    im = []
    axs = []
    if os.path.exists(load_name + '.pkl'):
        im, axs = plt_stack_single(cell_type, load_name, b, cells_plot, norm, cell, amps_desired, labeloff, grid1,
                                   eod_fr, save_names, predefined_amps2)
    # Mark this cell in the overview scatters.
    add = ['', '_burst_corr', ]
    _scatter_cell_point(ax0, frame_cell, 'cv' + add[0], 'fr' + add[0], cell_type, colors)
    _scatter_cell_point(ax1, frame_cell, 'cv' + add[0], 'vs' + add[0], cell_type, colors)
    _scatter_cell_point(ax2, frame_cell, 'cv' + add[1], 'fr' + add[1], cell_type, colors)
    return im, axs


def hist_part2(axi, cell_type, burst_corr, colors, cell, spikes, eod_fr, ):
    """Histogram the values from ``load_spikes`` into ``axi``, adding a burst-corrected version if needed.

    When ``burst_corr`` contains 'burst' and the smallest value is below ``find_lim_here``'s limit,
    ``correct_burstiness`` is applied and both histograms are drawn (original in grey).
    """
    spikes_all, hists, _, _ = load_spikes(spikes, eod_fr)
    hists_both = [hists]
    if 'burst' in burst_corr:
        lim_here = find_lim_here(cell, burst_corr)
        print(lim_here)
        if np.min(np.concatenate(hists)) < lim_here:
            hists2, _, _ = correct_burstiness(hists, spikes_all, [eod_fr] * len(spikes_all),
                                              [eod_fr] * len(spikes_all), lim=lim_here, burst_corr=burst_corr)
            hists_both = [hists, hists2]
    if len(hists_both) > 1:
        colors_hist = ['grey', colors[str(cell_type)]]
    else:
        colors_hist = [colors[str(cell_type)]]
    for gg, hists_here in enumerate(hists_both):
        plt_hist2(axi, hists_here, colors_hist, gg)


def vals_modulation():
    """Placeholder without functionality."""
    pass


def plt_power3(spikes_all_here, axp, color='blue', only_one=False):
    """Plot the power spectrum of each spike train in ``spikes_all_here`` into ``axp``.

    Spike times are divided by 1000 before binarization with ``cr_spikes_mat`` at 40 kHz, so they
    are presumably given in ms (TODO confirm). The spectrum is computed with ``matplotlib.mlab.psd``
    (NFFT 2**14, half overlap). With ``only_one`` only the first train is used.

    Returns ``(p_array, f_array)``; entries of ``p_array`` stay ``[]`` for skipped (empty) trains.
    """
    spikes_mat = [[]] * len(spikes_all_here)
    sampling_calc = 40000
    nfft = 2 ** 14
    p_array = [[]] * len(spikes_all_here)
    f_array = []
    trains = [spikes_all_here[0]] if only_one else spikes_all_here
    for s, sp in enumerate(trains):
        if len(sp) > 0:
            try:
                spikes_mat[s] = cr_spikes_mat(np.array(sp) / 1000, sampling_rate=sampling_calc,
                                              length=int(sampling_calc * np.array(sp[-1]) / 1000))
            except:
                print('spikes_mat[s] =')
                embed()
            p_array[s], f_array = ml.psd(spikes_mat[s] - np.mean(spikes_mat[s]), Fs=sampling_calc, NFFT=nfft,
                                         noverlap=nfft // 2)
            axp.plot(f_array, p_array[s], color=color)
    axp.set_xlim(0, 1000)
    axp.set_xlabel('Hz')
    axp.set_ylabel('Hz')
    return p_array, f_array


def plt_hist2(axi, hists_here, colors_hist, gg):
    """Draw a 100-bin histogram of the concatenated ``hists_here`` labelled with its CV (std/mean)."""
    if len(hists_here) > 0:
        h = np.concatenate(hists_here)
        axi.hist(h, bins=100, color=colors_hist[gg], label='CV ' + str(np.round(np.std(h) / np.mean(h), 3)),
                 alpha=0.7)


def _align_x(ax, ref_pos):
    """Give ``ax`` the horizontal extent of ``ref_pos`` ([[x0, y0], [x1, y1]]) keeping its vertical extent."""
    own = np.array(ax.get_position())
    ax.set_position([ref_pos[0][0], own[0][1], ref_pos[1][0] - ref_pos[0][0], own[1][1] - own[0][1]])


def _burst_label(save_name):
    """Return the burst-correction description encoded in ``save_name``."""
    if '2.5' in save_name:
        return '2.5 EOD burst corr'
    if 'Individual' in save_name:
        return 'individual burst corr'
    if 'burst' in save_name:
        return '1.5 EOD burst corr'
    return ''


def plt_stack_single(cell_type, load_name, b, cells_plot, norm, cell, amps_desired, labeloff, grid1, eod_fr,
                     save_names, predefined_amps2):
    """Plot the susceptibility square of one cell with its input and output spectra above it.

    Loads ``load_name + '.pkl'``, takes the first file and the smallest amplitude, the longest
    stimulus and the largest ``trial_nr``. The square goes into ``grid1[2]``, the ``isf`` spectrum into
    ``grid1[1]`` and the ``osf`` spectrum into ``grid1[0]``.

    Returns ``(im, axs)`` of the square, or ``([], [])`` if nothing was plotted. ``predefined_amps2``
    is currently unused.
    """
    im = []
    axs = []
    stack = pd.read_pickle(load_name + '.pkl')
    fexclude = False
    if fexclude:
        files = stack['file_name'].unique()
        if len(files) > 1:
            # Case-insensitive so that labels like ' P-unit' are recognised.
            if 'p-unit' in cell_type.lower():
                file_names_exclude = punit_file_exclude()
            else:
                file_names_exclude = ampullary_file_exclude()
            stack = stack[~stack['file_name'].isin(file_names_exclude)]
    files = stack['file_name'].unique()
    amps = stack['amp'].unique()
    amps_defined = [np.min(amps)]
    stack_file = stack[stack['file_name'] == files[0]]
    for amp in amps_defined:
        if amp not in np.array(stack_file['amp']):
            continue
        stack_amp = stack_file[stack_file['amp'] == amp]
        length = np.max(stack_file['stimulus_length'].unique())
        stack_final = stack_amp[stack_amp['stimulus_length'] == length]
        trial_nr_double = stack_final.trial_nr.unique()
        # More than one trial number probably indicates an error upstream.
        if len(trial_nr_double) > 1:
            print('trial_nr_double')
        try:
            stack_final1 = stack_final[stack_final.trial_nr == np.max(trial_nr_double)]
        except:
            print('stack_final1 problem')
            embed()
        axs = plt.subplot(grid1[2])
        osf = stack_final1.osf
        isf = stack_final1.isf
        im, min_lim, max_lim = square_func([axs], stack_final1, norm=norm)
        plt.colorbar(im, ax=axs)
        ax_pos = np.array(axs.get_position())  # [[xmin, ymin], [xmax, ymax]]
        fr = stack_final1.fr.unique()[0]
        snippets = stack_final1['snippets'].unique()[0]
        cv = stack_final1.cv.unique()[0]
        ser = stack_final1.ser.unique()[0]
        cv_stim = stack_final1.cv_stim.unique()[0]
        fr_stim = stack_final1.fr_stim.unique()[0]
        ser_stim = stack_final1.ser_stim.unique()[0]
        plt.suptitle(cell + ' EODf ' + str(np.round(eod_fr)) + ' Hz ' + '' + 'S.Nr ' + str(
            snippets) + ' % ' + ' Base: cv ' + str(np.round(cv, 2)) + ' fr ' + str(
            np.round(fr)) + ' Hz' + ' ser ' + str(np.round(ser)) + ' Stim: cv ' + str(
            np.round(cv_stim, 2)) + ' fr ' + str(np.round(fr_stim)) + ' Hz' + ' ser ' + str(
            np.round(ser_stim)) + ' length ' + str(length), fontsize=11, )
        eod_fr_half_color = 'purple'
        fr_color = 'red'
        eod_fr_color = 'magenta'
        fr_stim_color = 'darkred'
        if labeloff and b != len(cells_plot) - 1:
            remove_xticks(axs)
            axs.set_xlabel('')
        # Input spectrum above the square, aligned to its horizontal extent.
        axs2 = plt.subplot(grid1[1])
        _align_x(axs2, ax_pos)
        clip_on = True
        freqs = [fr, fr * 2, fr_stim, fr_stim * 2, eod_fr, eod_fr * 2, eod_fr / 2]
        colors_f = [fr_color, fr_color, fr_stim_color, fr_stim_color, eod_fr_color, eod_fr_color,
                    eod_fr_half_color]
        plt_isf_ps_red(stack_final1, isf, 0, axs2, freqs=freqs, colors=colors_f, clip_on=clip_on)
        axs2.set_xlim(min_lim, max_lim)
        remove_xticks(axs2)
        # Output spectrum on top.
        axs1 = plt.subplot(grid1[0])
        axs1.set_title(' std ' + str(amp) + ' ' + _burst_label(save_names[0]))
        remove_xticks(axs1)
        _align_x(axs1, ax_pos)
        plt_isf_ps_red(stack_final1, osf, 0, axs1, freqs=freqs, colors=colors_f, clip_on=clip_on)
        axs1.set_xlim(min_lim, max_lim)
    return im, axs


def fft_matrix(osf, f_range, isf, norm='', quadrant=''):
    """Build the second-order cross-spectrum matrix over the frequency indices ``f_range``.

    Row ``ff`` is ``osf[f1 + f2] * R1 * R2`` (or, for a non-empty ``quadrant``,
    ``conj(osf[|f1 - f2|]) * conj(R1) * R2``), where ``R1, R2`` come from ``find_isf_matrices`` on
    ``isf[f_range]``. A non-empty ``norm`` additionally multiplies by ``find_norm_susept``'s scale.

    Returns ``(f_mat1, f_mat2, f_idx_sum, cross)`` as numpy arrays.
    """
    # Frequency indices along x (every row is f_range) and along y (the transpose).
    f_mat1 = [f_range] * len(f_range)
    f_mat2 = np.transpose(f_mat1)
    f_idx_sum = f_mat1 + f_mat2
    f_idx_diff = f_mat1 - f_mat2
    isf_range = isf[f_range]
    rate_matrix1, rate_matrix2 = find_isf_matrices(f_idx_sum, isf_range)
    scale = find_norm_susept(f_idx_sum, isf_range)
    cross = []
    for ff in range(len(f_idx_sum)):
        if quadrant == '':
            row = osf[f_idx_sum[ff]] * rate_matrix1[ff] * rate_matrix2[ff]
        else:
            # The explicit conjugates undo the conjugation of the second argument.
            row = np.conj(osf[np.abs(f_idx_diff[ff])]) * np.conj(rate_matrix1[ff]) * rate_matrix2[ff]
        if norm != '':
            row = row * scale[ff]
        cross.append(row)
    return np.array(f_mat1), np.array(f_mat2), np.array(f_idx_sum), np.array(cross)


def exclude_nans_for_corr(file_here, var_item, x=[], y=[], max_x=None, cv_name='cv_base', score='perc99/med'):
    """Drop NaN/inf pairs (and, with ``max_x``, all ``x >= max_x``) before a correlation.

    ``x``/``y`` default to the ``cv_name``/``score`` columns of ``file_here``; ``c_axis`` is the
    ``var_item`` column filtered alongside. Returns ``(c_axis, x, y, exclude_here)`` where
    ``exclude_here`` is the NaN/inf mask applied first.
    """
    if len(x) == 0:
        x = file_here[cv_name]
    if len(y) == 0:
        y = file_here[score]
    c_axis = file_here[var_item]
    exclude_here = exclude_nans(x, y)
    x = x[~exclude_here]
    y = y[~exclude_here]
    c_axis = c_axis[~exclude_here]
    if max_x:
        # The same mask decides whether and what to filter (previously values equal to max_x were
        # only dropped when some value exceeded max_x).
        keep = x < max_x
        if not np.all(keep):
            y = y[keep]
            try:
                c_axis = c_axis.loc[keep]
            except:
                print('c something')
                embed()
            x = x[keep]
    return c_axis, x, y, exclude_here


def exclude_nans(x, y):
    """Return a boolean mask that is True wherever ``x`` or ``y`` is NaN or infinite."""
    exclude_here = (np.isnan(x)) | (np.isnan(y)) | (np.isinf(x)) | (np.isinf(y))
    return exclude_here


def fav_calc_RAM_cell_sorting(save_names_load=[
    'calc_RAM_overview-_simplified_noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s'],
        base_sorted='base_ram_sorted', sorted_cv='cv_base', redo=False, redo_base=False,
        cell_types_sort=[' P-unit', ' Ampullary', ' P-unit_problem', 'unkown', ' unknown_problem',
                         ' Ampullary_problem', 'unkown', ' Pyramidal', ' T-unit']):
    """Return the cell data names sorted by ``sort_cells_base`` (CV-sorted, small CVs first).

    Save names lacking the 'calc_RAM_overview-_simplified_' prefix get it prepended. A new list is
    built, so neither the caller's list nor the shared default is modified.
    """
    prefix = 'calc_RAM_overview-_simplified_'
    cell_sorted = 'cv_cell_sorted'
    cell_type_type = 'cell_type_reclassified'
    save_names = [save if prefix in save else prefix + save for save in save_names_load]
    data_names, _, _ = sort_cells_base(small_cvs_first=True, name='calc_base_data-base_frame_overview.pkl',
                                       cell_sorted=cell_sorted, cell_type_type=cell_type_type,
                                       save_names=save_names, sorted_cv=sorted_cv, base_sorted=base_sorted,
                                       cell_type_sort=cell_types_sort, gwn_filtered=True, redo=redo,
                                       redo_base=redo_base)
    return data_names
def version_final(): save_name = 'noise_data12_nfft0.5sec_original__StimPreSaved4__first1_order_' return save_name def find_stimuli(b): names = [] for t in b.tags: if 'filestimulus' in t.name.lower(): names.append(t.name) return names def pearson_label(corr, p_value, y, n=True): if n: n_add = ', $n=%s$' % (len(y)) else: n_add = '' if p_value < 0.001: p_name = ', $p<0.001$' # *** elif p_value < 0.01: p_name = ', $p<0.01$' # ** elif p_value < 0.05: p_name = ', $p=%s$' % (np.round(p_value, 2)) # + '*' else: p_name = ', $p=%s$' % (np.round(p_value, 2)) if np.abs(corr) < 0.01: add = np.round( corr, 3) else: add = np.round( corr, 2) return ' $r=%s$' % add + p_name + n_add def chose_class_cells(cell_types_sort=[' P-unit_problem', 'unkown', ' unknown_problem', ' Ampullary_problem', 'unkown', ' Pyramidal', ' T-unit', ' P-unit', ' Ampullary', ]): cell_type_type = 'cell_type_reclassified' cell_sorted = 'cv_cell_sorted' # ''#'cv_cell_sorted'#''#'cell_sorted' if 'cell_sorted' in cell_sorted: save_names_load = [ 'calc_RAM_overview-_simplified_noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s'] # noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s base_sorted = 'stim_sorted' sorted_cv = 'cv_base' _, _, _ = sort_cells_base(small_cvs_first=True, cell_sorted=cell_sorted, cell_type_type=cell_type_type, save_names=save_names_load, sorted_cv=sorted_cv, base_sorted=base_sorted, cell_type_sort=cell_types_sort, gwn_filtered=True) else: _, _ = find_all_dir_cells() # wir brauchen eine Zelle die das nix hat (neue Zelle) und eine wo wir die RAM Datei kopiert haben def kernel_scatter(axy, axx, axs, c, cell_type_here, colors, cv_name, frame_file, score, xmin='no', alpha=1, ymin='no', color_given=None, n=True, log=True): ########################### # version comparison with all cells, and no modulation if not color_given: color_given = colors[str(cell_type_here)] x_axis = plot_kernels_on_side(axx, axy, 
color_given, cv_name, frame_file, score, xmin=xmin, ymin=ymin) # todo: hier noch das andere seiteliche histogram machen # if 'Ampullary' in cell_type_here: # embed() x_axis = plt_overview_scatter(axs, c, cell_type_here, colors, cv_name, frame_file, score, alpha=alpha, n=n, color_text=color_given, color_given=color_given) if log: axy.set_yscale('log') axs.set_yscale('log') axy.set_yticks_blank() axy.minorticks_off() join_x([axs, axx]) join_y([axy, axs]) if log: make_log_ticks([axy, axs]) remove_yticks(axy) axy.minorticks_off() return axs, x_axis def plot_kernels_on_side(ax_x, ax_y, color, cv_name, frame_file, score, step_y=0, xmin='no', ymin='no', step_x=0, ymax='no', xlim=None): x_axis, y_axis = get_axis(cv_name, frame_file, score) if xlim: x_axis = x_axis[x_axis < xlim[1]] kernel_histogram(ax_x, color, np.array(x_axis), xmin=xmin, norm='density', step=step_x, alpha=0.5) # step_x = 0.03 ax_x.show_spines('b') remove_yticks(ax_x) remove_xticks(ax_x) test = False if test: from utils_test import test_kernel test_kernel() kernel_histogram(ax_y, color, np.array(y_axis), orientation='vertical', norm=True, step=step_y, alpha=0.5, xmin=ymin, xmax=ymax) ax_y.set_yticks_blank() ax_y.show_spines('l') remove_yticks(ax_y) remove_xticks(ax_y) return x_axis def plt_albi(ax, cell_type_here, colors, max_val, species, x_axis, y_axis): try: ax.scatter(x_axis[x_axis < max_val], y_axis[x_axis < max_val], alpha=1, s=2.5, color=colors[ str(cell_type_here)], clip_on=False) ##0.45 colors[' P-unit']label = cell_type_here[1]+':'+' '+str(file), marker = marker, ax.axhline(2.576, color='grey', linestyle='--', linewidth=1) ax.set_title(species) ax.set_yscale('log') except: print('axs thing3') embed() def plt_eigen(cv_name, ax, c, cell_type_here, cells_extra, colors, frame_file, max_val, score, species): x_axis, y_axis = get_axis(cv_name, frame_file, score) x = x_axis[x_axis < max_val] y = y_axis[x_axis < max_val] try: ax.scatter(x, y, alpha=1, s=2.5, color=colors[ str(cell_type_here)], 
label='r=' + str(np.round(np.corrcoef(x, y)[0][1], 2)) + ' n=' + str( len(y)), clip_on=False) ##0.45 colors[' P-unit']label = cell_type_here[1]+':'+' '+str(file), marker = marker, ax.set_title(species) ax.set_yscale('log') ax.axhline(2.576, color='grey', linestyle='--', linewidth=1) if c == 1: ax.legend() except: print('axs thing2') embed() if cell_type_here == ' P-unit': cells_plot2 = p_units_to_show(type_here='eigen_small')[1::] else: cells_plot2 = [p_units_to_show(type_here='eigen_small')[0]] # for cell_plt in cells_plot2: try: cells_extra = frame_file[frame_file['cell'].isin(cells_plot2)].index except: print('cells extra here') embed() ax.scatter(frame_file[cv_name].loc[cells_extra], frame_file[score].loc[cells_extra], alpha=1, s=2.5, color=colors[ str(cell_type_here)], clip_on=False, marker='D', edgecolor='black') def plt_overview_scatter(ax, c, cell_type_here, colors, cv_name, frame_file, score, x_pos=0, labelpad='no', n=True, alpha=1, color_text='black', legend_spacing = 0.1, y_val=0.9, fs=7.5, ms=2.5, color_given=None, ha='left'): if not color_given: color_given = colors[str(cell_type_here)] x_axis, y_axis = x_axis_wo_c(cv_name, frame_file, score) try: x = x_axis # [x_axis < max_val] y = y_axis # [x_axis < max_val] ax.scatter(x, y, alpha=alpha, s=ms, color=color_given, clip_on=False) ##, label=corr0.45 colors[' P-unit']label = cell_type_here[1]+':'+' '+str(file), marker = marker, print(' mean(' + str(cv_name) + str(np.mean(x)) + ') ' + ' mean(' + str(score) + str(np.mean(y)) + ') ') except: print('axs thing1') embed() legend_wo_dot(ax, y_val - legend_spacing * c, x, y, ha=ha, color=color_text, fs=fs, x_pos=x_pos, n=n) if type(labelpad) != str: ax.set_xlabel(cv_name, labelpad=labelpad) else: ax.set_xlabel(cv_name) return x_axis def x_axis_wo_c(cv_name, frame_file, score): x_axis, y_axis = get_axis(cv_name, frame_file, score) exclude_here = exclude_nans(x_axis, y_axis) x_axis = x_axis[~exclude_here] y_axis = y_axis[~exclude_here] return x_axis, y_axis def 
legend_wo_dot(ax, y_pos, x, y, color='black', x_pos=0.5, ha='left', n=True, fs=7.5): # , ha = 'right' corr, p_value = stats.pearsonr(x, y) pears_l = pearson_label(corr, p_value, y, n=n) ax.text(x_pos, y_pos, pears_l, fontsize=fs, color=color, transform=ax.transAxes, ha=ha) # ha="left", va="top",corr def get_axis(cv_name, frame_file, score): cvs = frame_file[cv_name] # x_axis = cvs[frame_file[score] > 0] y_axis = np.array(frame_file[score])[frame_file[score] > 0] return x_axis, y_axis def scatter_with_marginals_colorcoded(var_item_name, ax, cell_type_here, cv_name, frame_file, score, axl=None, axk=None, ymin='no', xmin='no', ymax='no', top=False, burst_fraction_reset='burst_fraction_burst_corr_individual_base', var_item='response_modulation', labelpad=0, max_x=None, n=True, xlim=None, x_pos=0, fs=7.5, ms=2.5, c=0, burst_fraction=1, sides=True, color_text='black', ha='left', y_val=0.9, cmap_required=True, color_given=None, cbar_labelpad=0, legend_spacing=0.1): ## # function to plot scatter plot, with marignal distributions and colorbar, all optional cmap = [] x_axis = [] y_axis = [] if len(frame_file) > 0: if cmap_required: # pay attention if the cell type is not a cell but a fish this is not working anymore mod_limits = mod_lims_modulation(cell_type_here, frame_file, score) if cell_type_here == ' P-unit': cm = 'coolwarm' # 'Blues' # else: cm = 'coolwarm' # 'Greens' cmap = rainbow_cmap(np.arange(len(mod_limits) * 1.6), nrs=len(mod_limits) * 1.6, cm=cm)[ ::-1] # len(amps) cmap = cmap[0:len(mod_limits)][::-1] frame_file = frame_file[frame_file[burst_fraction_reset] < burst_fraction] colors = colors_overview() if not color_given: color_given = colors[cell_type_here] if sides: x_axis = plot_kernels_on_side(axk, axl, color_given, cv_name, frame_file, score, xmin=xmin, ymin=ymin, ymax=ymax, xlim=xlim) if var_item != '': c_axis, x_axis, y_axis, exclude_here = exclude_nans_for_corr(frame_file, var_item, cv_name=cv_name, score=score, max_x=max_x) if len(x_axis) > 0: im = 
ax.scatter(x_axis, y_axis, alpha=1, s=2.5, c=c_axis, clip_on=True, cmap=cm) # color=cmap[ # ax.set_aspect('equal') # ff]) corr = 'r=' + str(np.round(np.corrcoef(x_axis, y_axis)[0][1], 2)) # if c == 1: # legend_wo_dot(ax, 0.9 - legend_spacing * c, x_axis, y_axis, ha=ha, x_pos=x_pos, n=n) if top: ax_tag = ax else: ax_tag = axl cbar, left, bottom, width, height = colorbar_outside(ax_tag, im, plt.gcf(), width=0.01, pos_axis='top', orientation='bottom', add=cbar_labelpad, top=top) if 'burst_fraction' in cv_name: ax.set_xlim(0, 1.01) ax.set_xticks_delta(0.5) if 'burst_fraction' in score: ax.set_ylim(0, 1.01) ax.set_yticks_delta(0.5) if 'burst_fraction' in var_item: val_chosen = 1 else: val_chosen = None set_clim_same([im], clims='', val_chosen=val_chosen, lim_type='up', nr_clim='None') cbar.set_label(var_item_name) ##+cell_type_here rotation=270,, labelpad=100 else: colors = colors_overview() x_axis = plt_overview_scatter(ax, c, cell_type_here, colors, cv_name, frame_file, score, x_pos=x_pos, ha=ha, labelpad=labelpad, y_val=y_val, ms=ms, fs=fs, color_text=color_text, color_given=color_given, n=n, legend_spacing = legend_spacing) if axl: axl.get_shared_y_axes().join(*[ax, axl]) axl.show_spines('') if axk: axk.get_shared_x_axes().join(*[ax, axk]) axk.show_spines('') return cmap, x_axis, y_axis def plt_burst_modulation(var_item_name, ax, cell_type_here, cv_name, frame_file, score, var_item='response_modulation'): mod_limits = mod_lims_modulation(cell_type_here, frame_file, score) if cell_type_here == ' P-unit': cm = 'coolwarm' # 'Blues' # else: cm = 'coolwarm' # 'Greens' cmap = rainbow_cmap(np.arange(len(mod_limits) * 1.6), nrs=len(mod_limits) * 1.6, cm=cm)[ ::-1] # len(amps) cmap = cmap[0:len(mod_limits)][::-1] c_axis, x_axis, y_axis, exclude_here = exclude_nans_for_corr(frame_file, var_item, cv_name=cv_name, score=score) if len(x_axis) > 0: im = ax.scatter(x_axis, y_axis, alpha=1, s=2.5, c=c_axis, clip_on=False, cmap=cm) # color=cmap[ legend_wo_dot(ax, 0.9, x_axis, 
y_axis, x_pos=0) cbar = plt.colorbar(im, ax=ax, orientation='vertical') # pad=0.2, shrink=0.5, "horizontal" cbar.set_label(var_item_name + '\n' + cell_type_here) # rotation=270,, labelpad=100 return cmap, x_axis, y_axis def plt_modulation_overview(ax, cell_type_here, cv_name, frame_file, score, species): mod_limits = mod_lims_modulation(cell_type_here, frame_file, score) if cell_type_here == ' P-unit': cm = 'coolwarm' # 'Blues' # else: cm = 'coolwarm' # 'Greens' cmap = rainbow_cmap(np.arange(len(mod_limits) * 1.6), nrs=len(mod_limits) * 1.6, cm=cm)[ ::-1] # len(amps) cmap = cmap[0:len(mod_limits)][::-1] c_axis, x_axis, y_axis, exclude_here = exclude_nans_for_corr(frame_file, 'response_modulation', cv_name=cv_name, score=score) if len(x_axis) > 0: ax.set_yscale('log') im = ax.scatter(x_axis, y_axis, alpha=1, s=2.5, c=c_axis, clip_on=False, cmap=cm, label='r=' + str(np.round(np.corrcoef(x_axis, y_axis)[0][1], 2))) # color=cmap[ legend_wo_dot(ax, 0.98, x_axis, y_axis, x_pos=0.4) cbar = plt.colorbar(im, ax=ax, orientation='vertical') # pad=0.2, shrink=0.5, "horizontal" cbar.set_label( 'Modulation Depth\n' + cell_type_here + '(' + str(species[0:5]) + '.)') # rotation=270,, labelpad=100 return cmap, x_axis, y_axis def data_overview(): plot_style() default_settings(column=2, length=8.5) grid0 = gridspec.GridSpec(3, 1, wspace=0.54, bottom=0.1, hspace=0.25, height_ratios=[1, 1, 2], left=0.1, right=0.87, top=0.95) scoreall = 'perc99/med' scores = [scoreall + '_diagonal_proj'] ########################## # Auswahl: wir nehmen den mean um nicht Stimulus abhängigen Noise rauszumitteln save_names = [ 'calc_RAM_overview-_simplified_' + version_final(), ] # 'calc_RAM_overview-_simplified_noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s__burstIndividual_','calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s__burstIndividual_', x_axis = ["cv_base", "cv_base_w_burstcorr", "cv_base", 
] cv_name_title = ['CV', 'CV$_{BurstCorr}$', 'CV'] species_all = [' Apteronotus leptorhynchus', ' Apteronotus leptorhynchus', ' Eigenmannia virescens'] counter = 0 cell_types = [' P-unit', ' Ampullary', ] colors = colors_overview() ax_j = [] axls = [] score = scores[0] for cv_n, cv_name in enumerate(x_axis): if cv_n == 0: pass else: pass redo = False frame_load_sp = load_overview_susept(save_names[0], redo=redo, redo_class=redo) for c, cell_type_here in enumerate(cell_types): species = species_all[cv_n] frame_file = setting_overview_score(frame_load_sp, cell_type_here, min_amp='min', species=species) grid = gridspec.GridSpecFromSubplotSpec(1, 3, grid0[0], hspace=0, wspace=0.15) # if c == 0: grid_k = gridspec.GridSpecFromSubplotSpec(2, 2, grid[0, cv_n], hspace=0.1, wspace=0.1, height_ratios=[0.35, 3], width_ratios=[3, 0.5]) try: axk = plt.subplot(grid_k[0, 0]) except: print('grid something') embed() ax_j.append(axk) axs = plt.subplot(grid_k[1, 0]) ax_j.append(axs) axl = plt.subplot(grid_k[1, 1]) axls.append(axl) if c in [0, 2]: axk.set_title(species) axs, x_axis = kernel_scatter(axl, axk, axs, c, cell_type_here, colors, cv_name, frame_file, score) axs.set_xlabel(cv_name_title[cv_n]) if cv_n == 0: axs.set_ylabel('Perc(99)/Median') grid_lower = gridspec.GridSpecFromSubplotSpec(2, 2, grid0[2], hspace=0.55, wspace=0.5) # cv_name = "cv_base" species = ' Apteronotus leptorhynchus' for c, cell_type_here in enumerate(cell_types): frame_file = setting_overview_score(frame_load_sp, cell_type_here, min_amp='min', species=species) # embed() ############################################## # jetzt kommen die extra P-unit statistiken if cell_type_here == ' P-unit': if c == 0: ################################ # Modulation, cell type comparison var_types = ['burst_fraction_burst_corr_base', 'cv_base'] x_axis = ['cv_base', 'burst_fraction_burst_corr_base', ] var_item_names = ['Burst Fraction', 'CV$'+basename()+'$'] x_axis_names = ['CV$'+basename()+'$', 'Burst Fraction', ] for v, 
var_type in enumerate(var_types): ax = plt.subplot(grid_lower[0, v]) cmap, _, y_axis = plt_burst_modulation(var_item_names[v], ax, cell_type_here, x_axis[v], frame_file, score, var_item=var_type) ax.set_ylabel(score) ax.set_xlabel(x_axis_names[v]) ax.set_yscale('log') if v == 0: ############################ # extra Zellen Scatter # todo: diese Zellen müssen noch runter konvertiert werden # todo: extra funktion für Zellen über 9 Snippets schreiben und die nochmal extra machen cells_plot2 = p_units_to_show(type_here='bursts') cells_extra = frame_file[frame_file['cell'].isin(cells_plot2)].index ax.scatter(frame_file[cv_name].loc[cells_extra], frame_file[score].loc[cells_extra], s=5, color='white', edgecolor='black', alpha=0.5, clip_on=False) # colors[str(cell_type_here)] ########################################## # burst gegen CV var_types = ['burst_fraction_burst_corr_base', 'response_modulation'] var_item_names = ['Burst Fraction', 'Modulatoin'] x_axis = ['cv_base', 'burst_fraction_burst_corr_base'] x_axis_names = ['Burst Fraction$'+basename()+'$', 'Burst Fraction$'+basename()+'$'] # 'CV$'+basename()+'$' scores_here = ['coherence_', 'burst_fraction_burst_corr_stim'] # 'wo_burstcorr' for v, var_type in enumerate(var_types): if scores_here[v] in frame_file.keys(): ax = plt.subplot(grid_lower[1, v]) cmap, _, y_axis = plt_burst_modulation(var_item_names[v], ax, cell_type_here, x_axis[v], frame_file, scores_here[v], var_item=var_type) if v == 1: ax.plot([0, 1], [0, 1], color='grey', linewidth=0.5) ax.set_xlabel(x_axis_names[v]) ax.set_ylabel(scores_here[v]) else: embed() grid_lower_lower = gridspec.GridSpecFromSubplotSpec(1, 2, grid0[1], wspace=0.5, hspace=0.55) # , height_ratios = [1,3] for c, cell_type_here in enumerate(cell_types): frame_file = setting_overview_score(frame_load_sp, cell_type_here, min_amp='range', species=species) ############################################## # modulatoin comparison for both cell_types ################################ # Modulation, 
cell type comparison # todo: hier die diff werte über die zellen axs = plt.subplot(grid_lower_lower[c]) cmap, _, y_axis = plt_modulation_overview(axs, cell_type_here, cv_name, frame_file, score, species) axs.set_ylabel(score) axs.set_xlabel(cv_name) # axs.get_shared_x_axes().join(*[axs, axd]) ###################################################### # hier kommen die kontrast Punkte dazu # für die Zellen spielt Burst correctin ja keine Rolle if cell_type_here == ' P-unit': cells_plot2 = p_units_to_show(type_here='contrasts')[1::] else: cells_plot2 = [p_units_to_show(type_here='contrasts')[0]] # for cell_plt in cells_plot2: cells_extra = frame_file[frame_file['cell'].isin(cells_plot2)].index # ax = plt.subplot(grid[1, cv_n]) axs.scatter(frame_file[cv_name].loc[cells_extra], frame_file[score].loc[cells_extra], s=5, color='white', edgecolor='black', alpha=0.5, clip_on=False) # colors[str(cell_type_here)] counter += 1 ######################## # modell model = resave_small_files("models_big_fit_d_right.csv", load_folder='calc_model_core') cells = model.cell.unique() plt_model_overview2(ax_j[1], cells, scores=[scoreall + '_']) plt.subplots_adjust(left=0.07, right=0.95, top=0.98, bottom=0.05, wspace=0.45, hspace=0.55) ax_j[0].get_shared_y_axes().join(*[ax_j[1], ax_j[3], ax_j[5], axls[0], axls[1], axls[2]]) ax_j[0].get_shared_x_axes().join(*ax_j) save_visualization(pdf=True) def calc_averag_spike_per_burst_package(burst_corr, h, lim, spikes_all): # also hier nimmt man einfach all jene spikes die übrig bleiben, nur die erste Verteilung zu nehmen ist je unmöglich if 'inverse' in burst_corr: first_true = [False] first_true.extend(h < lim) else: first_true = [True] first_true.extend(h > lim) test = False # todo: also entweder man schneidet zusammen nur die teile des inputs und des outputs wo die bursts waren, # fr_base = len(spikes_all[0]) / (spikes_all[0][-1] / 1000) # # todo: hier vielleicht auch noch ein <= machen # spike_ex = np.array(spikes_all)[np.array(first_true)] 
nrs_first_spike = np.arange(0, len(spikes_all), 1)[np.array(first_true)] burst_nr = np.diff(nrs_first_spike) return burst_nr, test def get_float_keys(stack_here): types = list(map(type, stack_here.keys())) keys = stack_here.keys()[np.where(np.array(types) == float)] if len(stack_here) != len(keys): keys = stack_here.index # () return keys def calc_serial(isi): corrs2 = [] if len(isi) > 100: length = len(isi) - 50 else: length = len(isi) for l in range(1, length): previous = isi # [0:-l] next = np.roll(isi, l) cut = True if cut: previous = previous[l::] # [0:-l] next = next[l::] # np.roll(isi, l) corrs2.append(np.corrcoef(next, previous)[0][1]) corr = np.mean(corrs2) sum_corr = np.sum(corrs2) test = False if test: from utils_test import corr_test corr_test() return corr, corrs2[0], sum_corr def roc_part(titles, devs, group_mean, ranges, fig, subdevision_nr, datapoints, datapoints_way, color, c, chose_score, cell, DF1_desired_ROC, DF2_desired_ROC, contrast_small, contrast_big, contrast1, dfs, start, dev, contrast, grid2, plot_group, autodefine2='_dfchosen_', sorted_on='eod_loc_synch', c1=10, c2=10, cut_matrix='malefemale', autodefine='_dfchosen_closest_first_', chirps='', data_dir='', mean_type='MeanTrialsIndexPhaseSort', extract='', mult_type='_multsorted2_', indices=['_allindices_'], eodftype='_psdEOD_', titles_up=['Without female', 'With female']): _, fr, pivot_chosen, max_val, max_x, max_y, mult, DF1_desired_ROC_exact, DF2_desired_ROC_exact, min_y, min_x, min_val, diff_cut = chose_mat_max_value( DF1_desired_ROC, DF2_desired_ROC, extract, mult_type, eodftype, indices, cell, contrast_small, contrast_big, contrast1, dfs, start, dev, contrast, autodefine=autodefine2, cut_matrix=cut_matrix, chose_score=chose_score, mean_type=mean_type) # chose_score = 'auci02_012-auci_base_01' colors = ['orange', 'green'] base = cell.split(os.path.sep)[-1] + ".nix" if data_dir == '': path = load_folder_name('threefish') + '/' + cell else: path = '../data/' + data_dir[c] + cell 
full_path = path + '/' + base try: file = nix.File.open(full_path, nix.FileMode.ReadOnly) except: full_path = '../data/cells/' + cell + '/' + cell + ".nix" file = nix.File.open(full_path, nix.FileMode.ReadOnly) print('load extra' + full_path) b = file.blocks[0] all_mt_names, mt_names, t_names = get_all_nix_names(b, what='Three') if mt_names: nix_there = check_nix_fish(b) if nix_there: times_sort = predefine_grouping_frame(b, eodftype=eodftype) counter_waves = 0 times_sort = times_sort[ (times_sort['c2'] == c2) & (times_sort['c1'] == c1)] for gg in range(len(DF1_desired_ROC_exact)): ax1_3 = {} ################### # all trials in one grouped = times_sort.groupby( ['c1', 'c2', 'm1, m2'], as_index=False) grouped_mean = chose_certain_group(DF1_desired_ROC_exact[gg], DF2_desired_ROC_exact[gg], grouped, several=True, emb=False, concat=True) # for g in range(len(grouped_mean)): # if 'Trials' not in mean_type: ################### # groups sorted by repro tag grouped = times_sort.groupby( ['c1', 'c2', 'm1, m2', 'repro_tag_id'], as_index=False) grouped_orig = chose_certain_group(DF1_desired_ROC_exact[gg], DF2_desired_ROC_exact[gg], grouped, several=True) ################### # other group variants colors_groups = ['black', 'brown', 'red', 'pink', 'orange', 'yellow', 'lightgreen', 'green', 'darkgreen', 'lightblue', 'blue', 'navy', 'purple'] # [::-1] ######################################################### groups_variants = [[grouped_mean]] ax1_3[plot_group] = plt.subplot(grid2, aspect='auto') for g, grouped2 in enumerate(groups_variants): results_diff = grouped2[0].copy() cv0, spike_pures_split, delays_split = plt_error_bar(plot_group, group_mean, extract, ax1_3, subdevision_nr, groups_variants.copy(), b, chirps, mean_type, devs, counter_waves, results_diff, datapoints, datapoints_way, grouped_orig, sorted_on=sorted_on, color=color) # frame, devname, spikes_pure, group_name, auc_names_condition, auc_names_control = plt_only_roc_repetitive( extract, ax1_3, fig, grouped2, g, b, 
chirps, mean_type, devs, counter_waves, results_diff, datapoints, datapoints_way, grouped_orig, colors_groups, ranges=ranges, sorted_on=sorted_on, lw=1.5) fr_end = divergence_title_add_on(group_mean, fr[gg], autodefine) plt.suptitle( cell + ' c1: ' + str(group_name[0]) + '% m1: ' + str( group_name[2][0]) + ' DF1: ' + str( grouped_mean['DF1, DF2'].iloc[0][ 0]) + ' c2: ' + str( group_name[1]) + '% m2: ' + str( group_name[2][1]) + ' DF2: ' + str( grouped_mean['DF1, DF2'].iloc[0][ 1]) + '\n Trials nr ' + str( len(grouped_mean)) + ' sorted on ' + sorted_on + ' ' + mean_type + ' cv ' + str( np.round(cv0, 2)) + ' ' + fr_end) try: mt_group1 = grouped2[0][1] except: mt_group1 = grouped2[0] try: eodf = np.mean(mt_group1.eodf) except: print('eod problem4') embed() _, _ = find_length_of_all_trials(grouped2, group_name) if g == 0: if len(auc_names_control) > 0: ax1_3[plot_group].text(0.5, 2, auc_names_control[0][0] + '-' + auc_names_control[0][1], va='center', ha='center', transform=ax1_3[plot_group].transAxes, ) else: ax1_3[plot_group].text(0.5, 2, 'base' + '-' + '01', auc_names_control[0][1], va='center', ha='center', transform=ax1_3[plot_group].transAxes, ) ax1_3[plot_group].text(0.7, 1.5, 'm1: ' + str( group_name[2][0]) + ' /DF1: ' + str( int((group_name[2][0] - 1) * eodf)) + '[Hz]', va='center', ha='center', transform=ax1_3[plot_group].transAxes, color=colors[gg]) ax1_3[plot_group].text(0.7, 1.7, ' m2: ' + str( group_name[2][1]) + '/ DF2: ' + str( int((group_name[2][1] - 1) * eodf)) + '[Hz] ', va='center', ha='center', transform=ax1_3[plot_group].transAxes, ) if (gg == 0) & (g == len(groups_variants) - 1): ax1_3[plot_group].set_ylabel('Correct-Detection Rate: ' + titles[plot_group][1]) ax1_3[plot_group].set_xlabel('False-Positive Rate: ' + titles[plot_group][0]) ax1_3[plot_group].set_title(titles_up[plot_group]) return frame, devname, spikes_pure, spike_pures_split, delays_split def plt_error_bar(plot_group, group_mean, extract, ax1_3, subdevision_nr, groups_variants, b, 
chirps, mean_type, devs, counter_waves, results_diff, datapoints, datapoints_way, grouped_orig, sorted_on='eod_loc_synch', color=['grey', 'grey', 'grey', 'grey', 'grey', 'grey', 'grey', 'grey', 'grey', 'grey', 'grey', 'grey']): spike_pures = [] delays_split = [] if '_AllTrialsIndex' in mean_type: plt_error_bar_trials_roc(counter_waves, results_diff, mean_type, extract, chirps, ax1_3, plot_group, group_mean, datapoints_way, b, datapoints, devs, groups_variants, grouped_orig, test=False) else: range_nr = int(len(group_mean[1]) / subdevision_nr) grouped_borders = find_group_variants(group_mean[1], [], start=1, steps=1, ranges=[range_nr]) groups_variants = grouped_borders for g, grouped in enumerate(groups_variants): print('group_variants' + str(g)) group_name = grouped_orig[0][0] if type(list(grouped)[0]) != str: grouped = list(grouped) tp = {} fp = {} for ggg in range(len(grouped)): try: grouped2 = [group_name, grouped[ggg]] spikes_pure, fish_number_base, chirp, fish_cuts, time_array, fish_number, smoothened2, smoothed05, eod_mt, eod_interp, effective_duration, cut, devname, frame = cut_spikes_and_eod_three( grouped2, b, extract, chirps=chirps, emb=False, mean_type=mean_type, sorted_on=sorted_on) _, _, _, _ = get_mt_features3(b, grouped2) except: grouped2 = grouped[ggg] spikes_pure, fish_number_base, chirp, fish_cuts, time_array, fish_number, smoothened2, smoothed05, eod_mt, eod_interp, effective_duration, cut, devname, frame = cut_spikes_and_eod_three( grouped2, b, extract, chirps=chirps, emb=False, mean_type=mean_type) _, _, _, _ = get_mt_features3(b, grouped2) print('grouped2 problem') embed() spike_pures.append(spikes_pure) mean_isi, std_isi, fr, isi, cv0, ser0, ser_first, sum_corr = calc_baseline_char( np.array(spikes_pure.base_0), np.abs(fish_cuts[0]), len(spikes_pure.base_0)) t1 = time.time() t2 = time.time() - t1 print('spikes pure' + str(t2)) dev_nrs = find_right_dev(devname, devs) t = dev_nrs[0] frame_dev = frame[frame['dev'] == devname[t]] delays_length = 
define_delays_trials(mean_type, frame, sorted_on=sorted_on) delays_split.append(delays_length) t1 = time.time() array0, array01, array02, array012, mean_nrs, array012_all, array01_all, array02_all, array0_all = assign_trials( frame, devname[t], delays_length, mean_type) t2 = time.time() - t1 print('array' + str(t2)) plt_error_bar0(array0, array01, array012, array02, ax1_3, color, counter_waves, datapoints, datapoints_way, frame_dev, g, ggg, group_mean, plot_group, results_diff, tp=tp, fp=fp) return cv0, spike_pures, delays_split def plt_error_bar0(array0, array01, array012, array02, ax1_3, color, counter_waves, datapoints, datapoints_way, frame_dev, g, ggg, group_mean, plot_group, results_diff, tp={}, fp={}): # mt, name_here, mt threshhold, roc_0, roc_02, roc_012, tp_012_all, tp_01_all, fp_all, tp_02_all, roc_01, results_diff, counter_savename, counter_waves = calc_auci_values( array0, array01, array02, array012, datapoints_way[0], datapoints[0], results_diff, counter_waves=counter_waves, id_group=group_mean) arrow, second, counter1, counter2, auc_names_condition, roc_array_eod_control, auc_tp_condition, auc_names_control, auc_fp_control, roc_array_control, roc_array1_eod_condition, roc_array_condition = define_arrays_for_roc_plotting( [], [], roc_01, roc_012, tp_01_all, tp_012_all, frame_dev, 0, fp_all, tp_02_all, roc_0, roc_02, [], []) if ggg == 0: tp[g] = [auc_tp_condition[0][plot_group]] fp[g] = [auc_fp_control[0][plot_group]] else: tp[g].append(auc_tp_condition[0][plot_group]) fp[g].append(auc_fp_control[0][plot_group]) try: ax1_3[plot_group].plot(np.transpose(fp[g][ggg]), np.transpose(tp[g][ggg]), color=color[ggg]) # , alpha=0.5 ax1_3[plot_group].plot(np.transpose(fp[g][ggg]), np.transpose(tp[g][ggg]), color=color[ggg]) # , alpha=0.5 print(color[ggg]) except: print('ggg problem') embed() test = False if test: some_roc_test(fp, tp) def some_roc_test(fp, tp): fig, ax = plt.subplots(3, 3, sharex=True, sharey=True) ax = np.concatenate(ax) for g in range(len(tp)): 
ax[g].plot(np.transpose(fp[g]), np.transpose(tp[g])) ax[g].plot(np.percentile(fp[g], 95, axis=0), np.percentile(tp[g], 5, axis=0), color='grey', alpha=0.5) ax[g].plot( np.percentile(fp[g], 5, axis=0), np.percentile(tp[g], 5, axis=0), color='grey', alpha=0.5) def define_arrays_for_roc_plotting(roc_01_eod, roc_012_eod, roc_01, roc_012, tp_01_all, tp_012_all, frame_dev, d, fp_all, tp_02_all, roc_0, roc_02, roc_02_eod, base_here_eod): roc_array1_eod_condition = [] roc_array_condition = [] second = [] counter1 = [] counter2 = [] auc_names_condition = [] roc_array_eod_control = [] auc_array_condition = [] if frame_dev['control_02'].iloc[d] != []: second = 'first_sw' counter1 = 0 counter2 = 2 arrow = True if second == 'first_sw': ################################## # NICHT VERWIRREN LASSEN; VON OBEN NACH UNTEN LESEN; hier ist BASE und 01 das erste Bild! auc_names_control = [['base', '02', ]] auc_array_control = [[fp_all, tp_02_all, ]] roc_array_control = [[roc_0, roc_02, ]] roc_array_eod_control = [[base_here_eod, roc_02_eod, ]] arrow = False auc_names_condition = [['01', '012']] auc_array_condition = [[tp_01_all, tp_012_all]] if len(roc_01_eod) > 0: roc_array_condition = [[roc_01, roc_012]] roc_array1_eod_condition = [[roc_01_eod, roc_012_eod]] counter1 = 1 counter2 = 3 elif second == 'first': auc_names_control = [['02', 'base']] auc_array_control = [[tp_02_all, fp_all]] roc_array_control = [[roc_02, roc_0]] roc_array_eod_control = [ [roc_02_eod, base_here_eod]] auc_names_condition = [['012', '01']] auc_array_condition = [[tp_012_all, tp_01_all]] roc_array_condition = [[roc_012, roc_01]] if len(roc_01_eod) > 0: roc_array1_eod_condition = [ [roc_012_eod, roc_01_eod]] elif second == 'second': ################################## auc_names_control = [['01', 'base']] auc_array_control = [[tp_01_all, fp_all]] roc_array_control = [[roc_01, roc_0]] if len(roc_01_eod) > 0: roc_array_eod_control = [ [roc_01_eod, base_here_eod]] auc_names_condition = [['012', '02']] 
auc_array_condition = [[tp_012_all, tp_02_all]] roc_array_condition = [[roc_012, roc_02]] roc_array1_eod_condition = [[roc_012_eod, roc_02_eod]] else: ################################## # das plottet nur die zwei Kombis einmal kontrast 01 zu base und einmal kontrast 02 zu bas auc_names_control = [['01', 'base'], ['02', 'base']] auc_array_control = [[tp_01_all, fp_all], [tp_02_all, fp_all]] roc_array_control = [[roc_01, roc_0], [roc_02, roc_0]] if len(roc_01_eod) > 0: roc_array_eod_control = [ [roc_01_eod, base_here_eod], [roc_02_eod, base_here_eod]] auc_names_condition = [['012', '02'], ['012', '01']] auc_array_condition = [[tp_012_all, tp_02_all], [tp_012_all, tp_01_all]] if len(roc_012) > 0: roc_array_condition = [[roc_012, roc_02], [roc_012, roc_01]] roc_array1_eod_condition = [[roc_012_eod, roc_02_eod], [roc_012_eod, roc_01_eod]] else: auc_names_control = [['base', '01'], '012', ] auc_array_control = [[fp_all, tp_01_all], tp_012_all] roc_array_control = [[roc_0, roc_01], roc_012] return arrow, second, counter1, counter2, auc_names_condition, roc_array_eod_control, auc_array_condition, auc_names_control, auc_array_control, roc_array_control, roc_array1_eod_condition, roc_array_condition def find_length_of_all_trials(grouped, group_name): lengths = [] for l in range(len(grouped)): if len(grouped[l]) != 2: grouped2 = [group_name, grouped[l]] else: grouped2 = grouped[l] lengths.append(len(grouped2[1])) sum_trials = lengths return sum_trials, lengths def plt_error_bar_trials_roc(counter_waves, results_diff, mean_type, extract, chirps, ax1_3, plot_group, group_mean, datapoints_way, b, datapoints, devs, groups_variants, grouped_orig, test=False): for g, grouped in enumerate(groups_variants): print('group_variants' + str(g)) if type(list(grouped)[0]) != str: grouped = list(grouped) for ggg in range(len(grouped)): if len(grouped[ggg]) != 2: pass else: pass spikes_pure, fish_number_base, chirp, fish_cuts, time_array, fish_number, smoothened2, smoothed05, eod_mt, 
eod_interp, effective_duration, cut, devname, frame = cut_spikes_and_eod_three( group_mean, b, extract, chirps=chirps, emb=False, mean_type=mean_type) features, mt, name_here, l = get_mt_features3(b, grouped_mean) # not there dev_nrs = find_right_dev(devname, devs) t = dev_nrs[0] frame_dev = frame[frame['dev'] == devname[t]] delays_length = define_delays_trials(mean_type, frame) array0, array01, array02, array012, mean_nrs, array012_all, array01_all, array02_all, array0_all = assign_trials( frame, devname[t], delays_length, mean_type) range_nr = int(len(array0) / 3) array0_gr = find_group_variants(array0, [], start=1, steps=1, ranges=[range_nr]) array01_gr = find_group_variants(array01, [], start=1, steps=1, ranges=[range_nr]) array02_gr = find_group_variants(array02, [], start=1, steps=1, ranges=[range_nr]) array012_gr = find_group_variants(array012, [], start=1, steps=1, ranges=[range_nr]) tp = {} fp = {} for g in range(len(array012_gr)): print(g) for g_in in range(len(array012_gr[g])): threshhold, roc_0, roc_02, roc_012, tp_012_all, tp_01_all, fp_all, tp_02_all, roc_01, results_diff, counter_savename, counter_waves = calc_auci_values( array0_gr[g][g_in], array01_gr[g][g_in], array02_gr[g][g_in], array012_gr[g][g_in], datapoints_way[0], datapoints[0], results_diff, mean_nrs, l, features, name_here, mt, counter_waves=counter_waves, id_group=group_mean) arrow, second, counter1, counter2, auc_names_condition, roc_array_eod_control, auc_tp_condition, auc_names_control, auc_fp_control, roc_array_control, roc_array1_eod_condition, roc_array_condition = define_arrays_for_roc_plotting( [], [], roc_01, roc_012, tp_01_all, tp_012_all, frame_dev, 0, fp_all, tp_02_all, roc_0, roc_02, [], []) if g_in == 0: tp[g] = [auc_tp_condition[0][plot_group]] fp[g] = [auc_fp_control[0][plot_group]] else: tp[g].append(auc_tp_condition[0][plot_group]) fp[g].append(auc_fp_control[0][plot_group]) print(g_in) ax1_3[plot_group].plot(np.transpose(fp[g]), np.transpose(tp[g]), color='grey', 
alpha=0.5) ax1_3[plot_group].plot(np.transpose(fp[g]), np.transpose(tp[g]), color='grey', alpha=0.5) if test: from utils_test import test_groups test_groups() def plt_only_roc_repetitive(extract, ax1_3, fig, grouped, g, b, chirps, mean_type, devs, counter_waves, results_diff, datapoints, datapoints_way, grouped_orig, colors_groups, sorted_on='eod_loc_synch', ranges=[], lw=0.4): print('group_variants' + str(g)) group_name = grouped_orig[0][0] if type(list(grouped)[0]) != str: grouped = list(grouped) roc_color = colors_groups[g] # todo: diese Funktion funktioniert eigentlich nur für den Mean for ggg in range(len(grouped)): if len(grouped[ggg]) != 2: grouped2 = [group_name, grouped[ggg]] else: grouped2 = grouped[ggg] frame, devname, spikes_pure, auc_names_condition, auc_names_control = plt_only_roc_plot(extract, counter_waves, results_diff, datapoints, datapoints_way, ax1_3, fig, grouped2, b, chirps, mean_type, devs, roc_color=roc_color, sorted_on=sorted_on, range_roc=ranges, lw=lw) return frame, devname, spikes_pure, group_name, auc_names_condition, auc_names_control def plt_only_roc_plot(extract, counter_waves, results_diff, datapoints, datapoints_way, ax1_3, fig, group_mean, b, chirps, mean_type, devs, roc_color='black', sorted_on='eod_loc_synch', range_roc=[], lw=0.7): spikes_pure, fish_number_base, chirp, fish_cuts, time_array, fish_number, smoothened2, smoothed05, eod_mt, eod_interp, effective_duration, cut, devname, frame = cut_spikes_and_eod_three( group_mean, b, extract, chirps=chirps, emb=False, mean_type=mean_type, sorted_on=sorted_on) features, mt, name_here, l = get_mt_features3(b, group_mean) _, _, _, _, _, _ = get_fish_number(b, group_mean, mean_type) auc_names_condition = [] auc_names_control = [] if len(devname) > 0: dev_nrs = find_right_dev(devname, devs) t = dev_nrs[0] frame_dev = frame[frame['dev'] == devname[t]] delays_length = define_delays_trials(mean_type, frame, sorted_on=sorted_on) if len(delays_length) > 1: if not delays_length['012']: 
print('DEBUGG: add sorted_on=sorted_on in cut_spikes_and_eod_three!!!') array0, array01, array02, array012, mean_nrs, array012_all, array01_all, array02_all, array0_all = assign_trials( frame, devname[t], delays_length, mean_type) test = False if test: from utils_test import plt_arrays_sort plt_arrays_sort() plt_phase_sorted_trials(frame, devname, array0_all, array0, array01_all, array01, array02_all, array02, array012_all, array012, ) auc_names_condition, auc_names_control = plt_only_roc_plot0(array0, array01, array012, array02, ax1_3, counter_waves, datapoints, datapoints_way, features, fig, frame_dev, group_mean, l, lw, mean_nrs, mt, name_here, range_roc, results_diff, roc_color) return frame, devname, spikes_pure, auc_names_condition, auc_names_control def plt_only_roc_plot0(array0, array01, array012, array02, ax1_3, counter_waves, datapoints, datapoints_way, features, fig, frame_dev, group_mean, l, lw, mean_nrs, mt, name_here, range_roc, results_diff, roc_color): threshhold, roc_0, roc_02, roc_012, tp_012_all, tp_01_all, fp_all, tp_02_all, roc_01, results_diff, counter_savename, counter_waves = calc_auci_values( array0, array01, array02, array012, datapoints_way[0], datapoints[0], results_diff, mean_nrs, l, features, name_here, mt, counter_waves=counter_waves, id_group=group_mean) arrow, second, counter1, counter2, auc_names_condition, roc_array_eod_control, auc_tp_condition, auc_names_control, auc_fp_control, roc_array_control, roc_array1_eod_condition, roc_array_condition = define_arrays_for_roc_plotting( [], [], roc_01, roc_012, tp_01_all, tp_012_all, frame_dev, 0, fp_all, tp_02_all, roc_0, roc_02, [], []) a_all = 0 counter_a = 0 # here we choose which of the two arrays comparison we want, only the base-01 or also the 01-012 if len(range_roc) < 1: range_roc = range(len(auc_fp_control[0])) for a in range_roc: if (type(ax1_3) == list) | (type(ax1_3) == dict): ax = ax1_3[a] else: ax = ax1_3 try: plot_rocs(fig, ax, counter_a, auc_names_control[a_all][a], a, 
auc_names_condition[a_all], auc_fp_control[a_all], auc_tp_condition[a_all], results_diff, auc_names_control[a_all][a], auc_names_condition[a_all][a], pos=[0, -0.35], legend=False, arrow=arrow, add=0.2, alpha=1, counter1=counter1, counter2=counter2, roc_color=roc_color, emb=False, second_roc=False, lw=lw) except: print('ax something') embed() return auc_names_condition, auc_names_control def traces_new(array012, position_diff, array01, way, array02, array0): datapoints_all = [250, 500, 750, 1000, 1500] restricts = np.arange(1000, len(array012[0]), 4000) counter = 0 grid0 = gridspec.GridSpec(len(datapoints_all), len(restricts), bottom=0.07, top=0.93, wspace=0.24, left=0.06, right=0.92) # hspace=0.4,wspace=0.2, ax = None for d, datapoints in enumerate(datapoints_all): for r, restrict in enumerate(restricts): print(len(array012[0][0:restrict])) try: trials, results_diff, tp_012_all, tp_01_all, tp_02_all, fp_all, roc_01, roc_0, roc_02, roc_012, threshhold = calc_auci_pd( results_diff, position_diff, [array012[0][0:restrict]], [array01[0][0:restrict]], [array02[0][0:restrict]], [array0[0][0:restrict]], t_off=10, way=way, emb=False, datapoints=datapoints) # , threshhold=threshhold if type(ax) is None: ax = plt.subplot(grid1[0, 0], sharey=ax) else: ax = plt.subplot(grid1[0, 0]) ax.set_title(str(restrict) + 'dp at all, ' + str(datapoints) + 'dp window', fontsize=7) plt.plot(np.transpose(roc_0), color='orange') plt.plot(np.transpose(roc_01), color='green') plt.subplot(grid1[1, 0], sharey=ax, sharex=ax) plt.plot(np.transpose(roc_02), color='orange') plt.plot(np.transpose(roc_012), color='blue') ax = plt.subplot(grid1[:, 1]) ax.plot(fp_all, tp_01_all, label='base-01', color='green') ax.plot(tp_02_all, tp_012_all, label='02-012', color='blue') counter += 1 except: pass grid1 = gridspec.GridSpecFromSubplotSpec(2, 2, hspace=0.4, wspace=0.2, subplot_spec= grid0[counter]) # plt.legend() save_visualization() plt.show() def calc_auci_values(array0, array01, array02, array012, way, 
datapoints, results_diff, mean_nrs='', l=[], features=[], name_here=[], mt=[], counter_waves=[], t_off=10, sampling=40000, position_diff=[], time_sacrifice=0, id_group=[]): if position_diff == []: position_diff = len(results_diff) # todo: noch hier das mit Mehrfachen einbauen results_diff.loc[position_diff, 'time_sacrifice'] = time_sacrifice trials, results_diff, tp_012_all, tp_01_all, tp_02_all, fp_all, roc_01, roc_0, roc_02, roc_012, threshhold = calc_auci_pd( results_diff, position_diff, array012, array01, array02, array0, t_off=t_off, way=way, datapoints=datapoints) test = False if test: traces_new(array012, position_diff, array01, way, array02, array0) if len(features) > 0: results_diff = feature_extract_cut(mt, l, name_here, results_diff, position_diff, features) results_diff.loc[position_diff, 'datapoints'] = datapoints results_diff.loc[position_diff, 'datapoints_time'] = np.round(datapoints / sampling, 3) results_diff.loc[position_diff, 'datapoints_way'] = way results_diff.loc[position_diff, 'trial_nrs'] = len(roc_01) results_diff.loc[position_diff, 'mean_nrs'] = mean_nrs results_diff.loc[position_diff, 't_off'] = t_off results_diff = save_structure_to_frame(position_diff, results_diff, np.array(id_group[1]['mt']), name='mt') results_diff = save_structure_to_frame(position_diff, results_diff, id_group[0], name='g_idx') counter_savename = [] return threshhold, roc_0, roc_02, roc_012, tp_012_all, tp_01_all, fp_all, tp_02_all, roc_01, results_diff, counter_savename, counter_waves def concat_rocs(control_02, base_orig, control_01, array_012, datapoints, t_off): roc_02_con = [] arrays = [array_012, control_02, control_01, base_orig] names = ['012', '02', '01', 'base'] arrays_new = {} arrays_last = {} # todo: hier noch was machen for a, array in enumerate(arrays): start1 = True for d in range(len(array)): trials = np.arange(datapoints + t_off, len(array[d]), datapoints + t_off) if len(array[d]) > 0: arrays_new[names[a]] = np.split(array[d], trials) 
arrays_new[names[a]] = arrays_new[names[a]][0:-1] arrays_new[names[a]] = np.array(arrays_new[names[a]])[:, 0:-t_off] try: if len(arrays_new[names[a]]) != 1: if len(arrays_new[names[a]][-1]) != len(arrays_new[names[a]][-2]): arrays_new[names[a]] = arrays_new[names[a]][0:-1] except: print('utils func roc to short') embed() if start1 == True: arrays_last[names[a]] = arrays_new[names[a]] start1 = False else: try: prev = list(arrays_last[names[a]]) prev.extend(arrays_new[names[a]]) arrays_last[names[a]] = prev except: print('array append problem') embed() if '012' in arrays_last.keys(): roc_012_con = arrays_last['012'] else: roc_012_con = [] if 'base' in arrays_last.keys(): base_con = arrays_last['base'] else: base_con = [] if '01' in arrays_last.keys(): roc_01_con = arrays_last['01'] else: roc_01_con = [] if '02' in arrays_last.keys(): roc_02_con = arrays_last['02'] roc2_there = True else: roc2_there = False return trials, roc2_there, roc_02_con, roc_012_con, roc_01_con, base_con def calc_auci_pd(results_diff, position_diff, array_012, control_01, control_02, base_orig, add='', t_off=5, way='', emb=[], printing=False, datapoints=[], threshhold_step=50, f0='EODf', sampling=40000): ## better to not convert to pandas to much especially if it has numerous of columns.. this might take really long! 
if 'mult' in way: # 'mult_minimum','mult_env', 'mult_f1', 'mult_f2' try: datapoints = find_env(way, results_diff, position_diff, sampling, f0=f0) except: try: f0 = 'f0' datapoints = find_env(way, results_diff, position_diff, sampling, f0=f0) except: f0 = 'EODf' datapoints = find_env(way, results_diff, position_diff, sampling, f0=f0) t1 = time.time() trials, roc2_there, roc_02_con, roc_012_con, roc_01_con, base_con = concat_rocs(control_02, base_orig, control_01, array_012, datapoints, t_off) if printing: print('ROC0' + str(time.time() - t1)) t1 = time.time() tp_02, tp_01, tp_012, fp_base, threshhold = threshold_roc(threshhold_step, roc2_there, base_con, roc_01_con, roc_02_con, roc_012_con) if printing: print('ROC1' + str(time.time() - t1)) t1 = time.time() tp_012_all = np.mean(tp_012, axis=0) tp_01_all = np.mean(tp_01, axis=0) fp_base_all = np.mean(fp_base, axis=0) if roc2_there == True: tp_02_all = np.mean(tp_02, axis=0) else: tp_02_all = [] results_diff, names_present, names_present_real = calc_auc_diff(tp_02_all, add, results_diff, position_diff, tp_012_all, fp_base_all, tp_01_all, roc2_there) if printing: print('ROC2' + str(time.time() - t1)) t1 = time.time() if roc2_there == True: results_diff.loc[position_diff, 'auc_' + '02' + '_' + '01' + add] = metrics.auc(tp_02_all, tp_01_all) results_diff.loc[position_diff, 'auci_' + '02' + '_' + '01' + add] = np.abs( np.asarray(results_diff.loc[position_diff, 'auc_' + '02' + '_' + '01' + add]) - 0.5) names_present_real.append('02' + '_' + '01') try: _, interp = interp_arrays(fp_base_all, tp_02_all, step=0.05) except: print('Interp line 6662') embed() results_diff = save_structure_to_frame(position_diff, results_diff, interp, name='base_02' + add, double=False) _, interp = interp_arrays(tp_02_all, tp_012_all, step=0.05) results_diff = save_structure_to_frame(position_diff, results_diff, interp, name='02_012' + add, double=False) try: _, interp = interp_arrays(fp_base_all, tp_01_all, step=0.05) except: print('interp 
fp_base_all in utils_func') embed() test = False if test: fig, ax = plt.subplots(4, 1, sharex=True) ax[0].plot(control_02[0]) ax[1].plot(control_01[0]) ax[2].plot(array_012[0]) ax[3].plot(base_orig[0]) results_diff = save_structure_to_frame(position_diff, results_diff, interp, name='base_01' + add, double=False) time_array, interp = interp_arrays(tp_01_all, tp_012_all, step=0.05) results_diff = save_structure_to_frame(position_diff, results_diff, interp, name='01_012' + add, double=False) results_diff = save_structure_to_frame(position_diff, results_diff, interp, name='time_array' + add, double=False) if printing: print('ROC3' + str(time.time() - t1)) t1 = time.time() diff_tuples = [ ['base_012', 'base_02'], ['base_012', 'base_01'], ['02_012', 'base_02'], ['01_012', 'base_01'], ['02_012', 'base_01'], ['01_012', 'base_02'], ['01_02', 'base_01'], ['02_01', 'base_02'] ] for diff_tuple in diff_tuples: if ('auc_' + diff_tuple[0] + add in results_diff.keys()) and ( 'auc_' + diff_tuple[1] + add in results_diff.keys()): results_diff.loc[position_diff, 'auc_' + diff_tuple[0] + '-' + 'auc_' + diff_tuple[1] + add] = \ results_diff.loc[position_diff, 'auc_' + diff_tuple[0] + add] - results_diff.loc[ position_diff, 'auc_' + diff_tuple[1] + add] results_diff.loc[position_diff, 'auci' + diff_tuple[0] + '-' + 'auci_' + diff_tuple[1] + add] = \ results_diff.loc[position_diff, 'auci_' + diff_tuple[0] + add] - results_diff.loc[ position_diff, 'auci_' + diff_tuple[1] + add] if printing: print('ROC4' + str(time.time() - t1)) plot = False if plot: plot_roc_in_function() # tp_02_all, array_all, t, eod_fe, e, eod_fr, eod_fj, j, fpr, tpr,tp_012_all, fp_base_all, tp_01_all if emb: embed() return trials, results_diff, tp_012_all, tp_01_all, tp_02_all, fp_base_all, roc_01_con, base_con, roc_02_con, roc_012_con, threshhold def threshold_roc(threshhold_step, roc2_there, base_con, roc_01_con, roc_02_con, roc_012_con): if roc2_there == True: max_arrays = np.concatenate([np.nanmax(base_con, 
axis=1), np.nanmax(roc_012_con, axis=1), np.nanmax(roc_01_con, axis=1), np.nanmax(roc_02_con, axis=1)]) else: try: max_arrays = np.concatenate([np.nanmax(base_con, axis=1), np.nanmax(roc_012_con, axis=1), np.nanmax(roc_01_con, axis=1)]) except: print('base_con problem') embed() higher_max = np.nanmax(max_arrays) lower_max = np.nanmin(max_arrays) threshhold = np.linspace(0.97 * lower_max, 1.02 * higher_max, threshhold_step) try: tp_012 = np.transpose( [np.max(roc_012_con, axis=1)] * len(threshhold) > np.transpose( [threshhold] * len(np.max(roc_012_con, axis=1)))) tp_01 = np.transpose([np.max(roc_01_con, axis=1)] * len(threshhold) > np.transpose( [threshhold] * len(np.max(roc_01_con, axis=1)))) fp_base = np.transpose([np.max(base_con, axis=1)] * len(threshhold) > np.transpose( [threshhold] * len(np.max(base_con, axis=1)))) except: print('threshold in utils_func') if roc2_there == True: tp_02 = np.transpose([np.max(roc_02_con, axis=1)] * len(threshhold) > np.transpose( [threshhold] * len(np.max(roc_02_con, axis=1)))) else: tp_02 = [] return tp_02, tp_01, tp_012, fp_base, threshhold def calc_auc_diff(tp_02_all, add, results_diff, position_diff, tp_012_all, fp_base_all, tp_01_all, roc2_there): if roc2_there == True: auc_names = ['base', '01', '02', '012', ] auc_array = [fp_base_all, tp_01_all, tp_02_all, tp_012_all] else: auc_names = ['base', '01', '012', ] auc_array = [fp_base_all, tp_01_all, tp_012_all] counter_a = 0 names_present = [] names_present_real = [] for a in range(len(auc_array)): for aa in range(0, len(auc_array), 1): if auc_names[a] != auc_names[aa]: if auc_names[a] + '_' + auc_names[aa] not in names_present: names_present.append(str(auc_names[a]) + '_' + auc_names[aa]) names_present.append(str(auc_names[aa]) + '_' + auc_names[a]) names_present_real.append(str(auc_names[a]) + '_' + auc_names[aa]) results_diff = results_diff.copy() results_diff.loc[position_diff, 'auc_' + auc_names[a] + '_' + auc_names[aa] + add] = metrics.auc( auc_array[a], auc_array[aa]) 
results_diff.loc[position_diff, 'auci_' + auc_names[a] + '_' + auc_names[aa] + add] = np.abs( np.asarray( results_diff.loc[position_diff, 'auc_' + auc_names[a] + '_' + auc_names[aa] + add]) - 0.5) counter_a += 1 return results_diff, names_present, names_present_real def plot_roc_in_function(tp_02_all, array_all, t, eod_fe, e, eod_fr, eod_fj, j, fpr, tpr, tp_012_all, fp_base_all, tp_01_all): plt.title('fe' + str(eod_fe[e]) + 'Hz fj' + str( eod_fj[j]) + 'Hz fr ' + str(eod_fr) + 'Hz') plt.subplot(2, 3, 1) plt.title('012') plt.plot(array_all['012'][t], color='red') plt.subplot(2, 3, 2) plt.title('01') plt.plot(array_all['control_01'][t], color='blue') plt.subplot(2, 3, 3) plt.plot(array_all['012'][t], color='red') plt.plot(array_all['control_01'][t], color='blue') plt.subplot(2, 3, 4) plt.plot(fpr, tpr) plt.subplot(2, 3, 5) plt.hist(array_all['012'][t], bins=100, color='red') plt.hist(array_all['control_01'][t], bins=100, color='blue') plt.subplot(2, 3, 6) plt.plot(np.sort(array_all['012'][t])) plt.plot(np.sort(array_all['control_01'][t])) plt.show() plt.subplot(3, 1, 1) plt.plot(fp_base_all, tp_012_all) plt.subplot(3, 1, 2) plt.plot(fp_base_all, tp_01_all) # plt.subplot(3, 1, 3) plt.plot(fp_base_all, tp_02_all) plt.show() def calc_baseline_char(spike_adapted, stimulus_length, trials_nr_base, data_restrict=[], emb=False): if emb: embed() fr = len(np.concatenate(spike_adapted)) / (stimulus_length * trials_nr_base) if len(data_restrict) > 0: max_pos = np.argmax(data_restrict) isi = np.diff(spike_adapted[max_pos]) else: isi = np.diff(spike_adapted[0]) if len(isi) < 3: for i in range(len(spike_adapted)): if len(spike_adapted[i]) > 2: isi = np.diff(spike_adapted[i]) if len(isi) > 1: std_isi = np.std(isi) mean_isi = np.mean(isi) cv0 = std_isi / mean_isi try: ser0, ser_first, sum_corr = calc_serial(isi) except: print('ser problem') embed() else: cv0 = np.float('nan') ser0 = np.float('nan') std_isi = np.float('nan') mean_isi = np.float('nan') ser_first = np.float('nan') 
sum_corr = np.float('nan') return mean_isi, std_isi, fr, isi, cv0, ser0, ser_first, sum_corr def find_group_variants(grouped_mean, groups_variants, start=15, steps=10, ranges=[]): if len(ranges) < 1: ranges = np.arange(start, len(grouped_mean), steps) for rr in range(len(ranges)): # das hier geht über die ranges und sagt wie viele einträge jeweils in einer gruppe sein sollen # das sind sozusagen unterkategorien von means # np.shape(groups_variants[0]) # so würde das zum bespeil das gruppieren # Out[36]: (2, 15, 18523) # In [37]: np.shape(groups_variants[1]) # Out[37]: (1, 25, 18523) # In [38]: np.shape(groups_variants[2]) # Out[38]: (1, 35, 18523) splits = np.arange(ranges[rr], len(grouped_mean), ranges[rr]) splits_done = np.split(grouped_mean, splits) if len(splits_done[-1]) != ranges[rr]: splits_done = splits_done[0:-1] splits_append = splits_done groups_variants.append(splits_append) return groups_variants def plot_second_roc(ax1, fig, array1, array2_0, array2_1, results_diff, names1_0, names1_1, names2_0, names2_1, add_name='', arrow=True, arrow2=True, pos=[1, -0.45], add=0.1): if arrow: ax1.annotate('', ha='center', xy=(1, 0.5), xytext=(1.4, 0.5), arrowprops={"arrowstyle": "->", "linestyle": "-", "linewidth": 3, "color": 'black'}, zorder=1) fig.texts.append(ax1.texts.pop()) time_interp, array2 = interp_arrays(array2_0[::-1], array2_1[::-1], step=0.01) auc1 = np.round( results_diff.iloc[-1][ 'auci_' + str( names1_0) + '_' + names1_1 + add_name] * 100) / 100 auc2 = np.round( results_diff.iloc[-1][ 'auci_' + str(names2_0) + '_' + names2_1 + add_name] * 100) / 100 auci_diff = np.round((auc1 - auc2) * 100) / 100 auci_label = 'auci ' + str(auc1) + '-' + str(auc2) + '=' + str(auci_diff) auc1 = np.round( results_diff.iloc[-1][ 'auc_' + str(names1_0) + '_' + names1_1 + add_name] * 100) / 100 auc2 = np.round( results_diff.iloc[-1][ 'auc_' + str(names2_0) + '_' + names2_1 + add_name] * 100) / 100 auc_diff = np.round((auc1 - auc2) * 100) / 100 auc_label = 'auc ' + str( 
auc1) + '-' + str(auc2) + '=' + str( auc_diff) if auc_diff > 0: ax1.text(pos[0], pos[1] - add, auc_label, fontsize=10, transform=ax1.transAxes, color='red') plt.fill_between(time_interp, array2, array1, color='red', alpha=0.5) ypos1 = array2[int(len(time_interp) / 2)] ypos2 = array1[int(len(time_interp) / 2 - 5)] xpos1 = time_interp[int(len(time_interp) / 2)] xpos2 = time_interp[int(len(time_interp) / 2 - 5)] mod = np.sqrt((ypos2 - ypos1) ** 2 + (xpos2 - xpos1) ** 2) print(arrow) if arrow2 == True: if mod > 0.1: ax1.annotate('', ha='center', xy=(xpos1, ypos1), xytext=(xpos2, ypos2), arrowprops={ "arrowstyle": "<-", "linestyle": "-", "linewidth": 1, "color": 'black'}, zorder=1) else: ax1.text(pos[0], pos[1] - add, auc_label, fontsize=10, transform=ax1.transAxes, color='blue') plt.fill_between(time_interp, array2, array1, color='blue', alpha=0.5) ypos1 = array1[ int(len(time_interp) / 2)] ypos2 = array2[ int(len(time_interp) / 2 - 5)] xpos1 = time_interp[ int(len(time_interp) / 2)] xpos2 = time_interp[ int(len(time_interp) / 2 - 5)] mod = np.sqrt((ypos2 - ypos1) ** 2 + ( xpos2 - xpos1) ** 2) if arrow2 == True: if mod > 0.1: ax1.annotate('', ha='center', xy=(xpos1, ypos1), xytext=(xpos2, ypos2), arrowprops={"arrowstyle": "->", "linestyle": "-", "linewidth": 1, "color": 'black'}, zorder=1) if auci_diff > 0: ax1.text(pos[0], pos[1], auci_label, fontsize=10, transform=ax1.transAxes, color='red') # transform else: ax1.text(pos[0], pos[1], auci_label, fontsize=10, transform=ax1.transAxes, color='blue') # transform plt.plot(time_interp, array2, color='black') # label=auci_label, def plot_rocs(fig, ax, counter_a, auc_names, a, auc_names1, fp_arrays, tp_arrays, results_diff, names, names1, counter1=0, counter2=2, legend=True, roc_color=[], alpha=1, lw=0.3, emb=False, second_roc=True, arrow=True, pos=[1, -0.45], add=-0.1, add_name=''): if emb: embed() fp_array = fp_arrays[a] if len(tp_arrays) > 0: tp_array = tp_arrays[a] fp, tp = interp_arrays(fp_array[::-1], tp_array[::-1], 
step=0.01) else: fp = [] tp = [] if (counter_a == counter1) or (counter_a == counter2): if (counter_a == 0) or (counter_a == 2): array2_0 = fp_arrays[a + 1] else: array2_0 = fp_arrays[a - 1] if len(tp_arrays) > 0: ############################## # je nach dem ob man das mit dem vorhergehenden oder nachkommenden array vergleicht if (counter_a == 0) or (counter_a == 2): array2_1 = tp_arrays[a + 1] else: array2_1 = tp_arrays[a - 1] else: array2_1 = [] if len(roc_color) > 0: color = roc_color else: color = 'purple' ax.plot(fp, tp, color=color, linewidth=lw, alpha=alpha, label=(np.round(results_diff.iloc[-1][ 'auci_' + str(auc_names[a]) + '_' + auc_names1[ a] + add_name] * 100) / 100)) # label=auci_label, if second_roc: if (counter_a == 0) or (counter_a == 2): plot_second_roc(ax, fig, tp, array2_0, array2_1, results_diff, auc_names[a], auc_names1[ a], auc_names[a + 1], auc_names1[a + 1], pos=pos, linewidth=lw, add=add, add_name=add_name, arrow=arrow) else: plot_second_roc(ax, fig, tp, array2_0, array2_1, results_diff, auc_names[a], auc_names1[ a], auc_names[a - 1], auc_names1[a - 1], pos=pos, linewidth=lw, add=add, add_name=add_name, arrow=arrow) else: if len(roc_color) > 0: color = roc_color else: color = 'black' ax.plot(fp, tp, label=(np.round(results_diff.iloc[-1][ 'auci_' + str(names) + '_' + names1 + add_name] * 100) / 100), color=color, linewidth=lw, alpha=alpha) if legend: plt.legend() ax.plot([0, 1], [0, 1], color='grey', linestyle='--', linewidth=0.5) def find_c_unique(name0, contrastc2, contrastc1, ): c1_uniques = [] c2_uniques = [] combinations = [] if os.path.exists(name0): spikes_o = pd.read_pickle(name0) combinations = spikes_o.groupby( [contrastc1, contrastc2]).groups.keys() # [[contrastc1,contrastc2]].unique() c2_unique = spikes_o[contrastc2].unique() c1_unique = spikes_o[contrastc1].unique() c1_uniques.append(c1_unique) c2_uniques.append(c2_unique) c1_unique = np.unique(c1_uniques)[::-1] c2_unique = np.unique(c2_uniques)[::-1] return c2_unique, 
c1_unique, combinations def find_all_threewave_versions(): dirs = os.listdir(load_folder_name('threefish')) dir_version = [] sizes = [] for dir in dirs: if 'invivo' not in dir: if ('DetectionAnalysis' not in dir) & ('pdf' not in dir) & ('png' not in dir) & ('AllTrials' in dir): dir_version.append(dir) sizes.append(os.path.getsize(load_folder_name('threefish') + '/' + dir)) dir_version = np.array(dir_version) dir_version = dir_version[np.argsort(sizes)[::-1]] sizes = np.sort(sizes)[::-1] return dir_version, sizes def get_fish_number(b, mt_group, mean_type): mt_list = mt_group[1]['mt'] # todo: da könnte man noch die schleife rausnehmen for mt_idx, mt_nr in enumerate(list(map(int, mt_list))): # range(start_l, len(mt.positions[:])) repro_position = mt_nr features, mt, name_here, l = get_mt_features3(b, mt_group, mt_idx) # somehow we have mts with negative extend, we exclude these if (mt.extents[:][mt_nr] > 0).any(): _, _, _, _, fish_number, fish_cuts, whole_duration, cont = load_durations(mt_nr, mt, mt_group[1], mt_idx, mean_type=mean_type, emb=False) delay = np.abs(fish_cuts[0]) if cont: contrast1 = mt_group[1]['c1'].iloc[ mt_idx] # mt_group[1]['c1'].loc[indices[mt_idx]] # mt.metadata.sections[0]['fish1alone']['Contrast'] contrast2 = mt_group[1]['c2'].iloc[mt_idx] # mt.metadata.sections[0]['fish2alone']['Contrast'] return fish_cuts, whole_duration, delay, contrast1, contrast2, repro_position def model_and_data_isi(nr_clim=10, many=False, width=0.005, HZ50=True, fs=8, nffts=['whole'], powers=[1], var_items=['contrasts'], contrasts=[0], noises_added=[''], fft_i='forward', fft_o='forward', spikes_unit='Hz', mV_unit='mV', D_extraction_method=['additiv_cv_adapt_factor_scaled'], internal_noise=['RAM'], external_noise=['RAM'], level_extraction=[''], cut_off2=300, receiver_contrast=[1], dendrids=[''], ref_types=[''], adapt_types=[''], c_noises=[0.1], c_signal=[0.9], cut_offs1=[300], clims='all', restrict='restrict'): # ['eRAM'] stimulus_length = 1 # 20#550 # 30 # 
15#45#0.5#1.5 15 45 100 trials_nrs = [1] # [100, 500, 1000, 3000, 10000, 100000, 1000000] # 500 variant = 'sinz' mimick = 'no' cell_recording_save_name = '' trans = 1 # 5 rep = 500000 # 500000#0 repeats = [20, rep] # 250000 good_data, remaining = overlap_cells() cells_all = good_data default_settings(column=2, length=4.9) # 0.75 grid = gridspec.GridSpec(1, 4, wspace=0.95, bottom=0.115, hspace=0.13, left=0.04, right=0.9, top=0.92, width_ratios=[0.7, 1, 1, 1]) a = 0 maxs = [] mins = [] ims = [] perc05 = [] perc95 = [] iternames = [D_extraction_method, external_noise, internal_noise, powers, nffts, dendrids, cut_offs1, trials_nrs, c_signal, c_noises, ref_types, adapt_types, noises_added, level_extraction, receiver_contrast, contrasts, ] nr = '2' cell_contrasts = ["2013-01-08-aa-invivo-1"] cells_triangl_contrast = np.concatenate([cells_all, cell_contrasts]) rows = len(good_data) + len(cell_contrasts) perc = 'perc' lp = 10 label_model = r'Nonlinearity $\frac{1}{S}$' for all in it.product(*iternames): var_type, stim_type_afe, stim_type_noise, power, nfft, dendrid, cut_off1, trial_nrs, c_sig, c_noise, ref_type, adapt_type, noise_added, extract, a_fr, a_fe = all fig = plt.figure() hs = 0.45 ################################# # model cells adapt_type_name, ax_model, cells_all, dendrid_name, ref_type_name, suptitles, width = plt_model_part(HZ50, a, a_fe, a_fr, adapt_type, c_noise, c_sig, cell_recording_save_name, cells_all, cut_off1, cut_off2, dendrid, extract, fft_i, fft_o, fig, fs, grid, hs, ims, mV_unit, many, maxs, mimick, mins, nfft, noise_added, nr, perc05, perc95, power, ref_type, repeats, spikes_unit, stim_type_afe, stim_type_noise, stimulus_length, trans, trial_nrs, var_items, var_type, variant, width, label=label_model, rows=rows, perc=perc, xlabels=False, title=False) ################################# # data cells grid_data = gridspec.GridSpecFromSubplotSpec(rows, 1, grid[1], hspace=hs) print('here') ax_data, stack_spikes_all, eod_frs = plt_data_susept(fig, 
grid_data, cells_all, cell_type='p-unit', width=width, cbar_label=False, lp=lp, title=False) for ax in ax_data: # remove_xticks(ax) ax.set_xticks_delta(100) ax.text(-0.42, 0.87, F2_xlabel(), ha='center', va='center', transform=ax.transAxes, rotation=90) ax.text(1.66, 0.5, nonlin_title(), rotation=90, ha='center', va='center', transform=ax.transAxes) ax.arrow_spines('lb') ################################# # plt isi of data grid_isi = gridspec.GridSpecFromSubplotSpec(rows, 1, grid[0], hspace=hs) spikes_type = 'base' if spikes_type == 'base': ax_isi = [] for f, cell in enumerate(cells_triangl_contrast): ###################################################### # frame = load_cv_base_frame(good_data, cell_type_type='cell_type_reclassified') frame, frame_spikes = load_cv_vals_susept(cells_triangl_contrast, EOD_type='synch', names_keep=['spikes', 'gwn', 'fs', 'EODf', 'cv', 'fr', 'width_75', 'vs', 'cv_burst_corr_individual', 'fr_burst_corr_individual', 'width_75_burst_corr_individual', 'vs_burst_corr_individual', 'cell_type_reclassified', 'cell'], path_sp='/calc_base_data-base_frame_overview.pkl', frame_general=False) cell_type, eod_fr, fr, frs_calc, isi, spikes, spikes_all = get_base_params(cell, 'cell_type_reclassified', frame) spikes_base, isi, frs_calc, cont_spikes = load_spikes(spikes, eod_fr) ax = plt.subplot(grid_isi[f]) colors = colors_overview() plt_susept_isi_base(colors[cell_type], ax, isi, delta=5, xlim=[0, 15], ypos=-0.15, clip_on=True) ax_isi.append(ax) else: ax_isi = plt_isi(cells_all, grid_isi, stack_spikes=stack_spikes_all, eod_frs=eod_frs) ###################################################################### print('started model contrasts') # hier das mit den Kontrasten # ok der code ist jetzt halt complex, aber den hab ich jetzt halt schon # daraus wollen wir so eine Übersicht machen params_c = {'contrasts': [0, 0.01, 0.025]} # 0.01, def_repeats = [rep] params = [ {'level_extraction': level_extraction, 'repeats': def_repeats, 'contrasts': 
[params_c['contrasts'][0]], 'D_extraction_method': D_extraction_method}, {'level_extraction': level_extraction, 'repeats': def_repeats, 'contrasts': [params_c['contrasts'][1]], 'D_extraction_method': D_extraction_method}, {'level_extraction': level_extraction, 'repeats': def_repeats, 'contrasts': [params_c['contrasts'][2]], 'D_extraction_method': D_extraction_method}, ] axes_contrast = [] for a in range(3): grid_model = gridspec.GridSpecFromSubplotSpec(rows, 1, grid[1 + a], hspace=hs) ax = plt.subplot(grid_model[-1]) axes_contrast.append(ax) plt_squares_special(params, col_desired=3, var_items=['contrasts', 'repeats', 'level_extraction'], clims='', show=False, width=width, share=False, cells_given=[cells_all.iloc[0]], perc=perc, internal_noise=internal_noise, external_noise=external_noise, lp=lp, ax=axes_contrast, label='', new_plot=False, titles_plot=False) # 'D_extraction_method','all'"2013-01-08-aa-invivo-1" print('finished model contrasts') for a, ax in enumerate(axes_contrast): if a == 0: ax_data.append(ax) ax.text(-0.42, 0.87, F2_xlabel(), ha='center', va='center', transform=ax.transAxes, rotation=90) elif a == 1: ax_model.insert(2, ax) else: ax_model.append(ax) if a != 0: remove_yticks(ax) ax.text(1.05, -0.25, F1_xlabel(), ha='center', va='center', transform=ax.transAxes) ax.arrow_spines('lb') ax.set_xlabel('') ax.set_ylabel('') if a == 2: ax.text(1.5, 0.5, label_model, rotation=90, ha='center', va='center', transform=ax.transAxes) # axes.join ax_isi[0].get_shared_x_axes().join(*ax_isi) end_name = suptitles + ' a_fr=' + str(a_fr) + ' ' + ' dendride=' + str( dendrid_name) + ' refractory period=' + str(ref_type_name) + ' adapt=' + str( adapt_type_name) + ' ' + ' cutoff1=' + str(cut_off1) + '' ' stimulus length=' + str( stimulus_length) + ' ' + ' power=' + str( power) + ' ' + restrict # set_clim_same(ims, perc05=perc05, perc95=perc95, lim_type='up', nr_clim=nr_clim, clims=clims) axes = np.array([np.array(ax_isi), np.array(ax_data), 
np.array(ax_model[0:int(len(ax_model) / 2)]), np.array(ax_model[int(len(ax_model) / 2)::])]) axes = np.transpose(axes) fig.tag([list(axes[0])], xoffs=-3, yoffs=2) # , minor_index=2 fig.tag([list(axes[1])], xoffs=-3, yoffs=2) # , minor_index=2 fig.tag([list(axes[2])], xoffs=-3, yoffs=2) # , minor_index=2 # ATTENTION: niemals dieses minor index machen save_visualization(pdf=True) def create_full_matrix2(chi_2_ur_numeric, chi_2_ul_numeric): """ Creates all areas in the frequency plot from the reduced numeric matrix. :param chi_2_numeric: The numeric matrix :return: The full matrix """ steps = chi_2_ur_numeric.shape[0] chi_2_ur = chi_2_ur_numeric.copy() for i in range(steps): for j in range(steps): if i >= j: chi_2_ur[i][j] = chi_2_ur[j][i] # matrix for lower left corner chi_2_ll = np.conj(np.flip(chi_2_ur)) # ok man könnte das auch gleich richtig abspeichern aber so geht das halt auch # matrix for upper left corner chi_2_ul = np.transpose(chi_2_ul_numeric).copy() for i in range(steps): for j in range(steps): if i <= j: chi_2_ul[i][j] = np.conj(chi_2_ul[j][i]) chi_2_ul = np.flip(chi_2_ul, 1) # matrix for lower right corner chi_2_lr = np.conj(np.flip(chi_2_ul)) # put all domains together chi_2 = np.zeros(shape=(2 * steps, 2 * steps), dtype=complex) for i in range(2 * steps): for j in range(2 * steps): # upper right if i >= steps and j >= steps: chi_2[i][j] = chi_2_ur[i - steps][j - steps] # upper left if i >= steps > j: chi_2[i][j] = chi_2_ul[i - steps][j - steps] # lower left if i < steps and j < steps: chi_2[i][j] = chi_2_ll[i][j] # lower right if i < steps <= j: try: chi_2[i][j] = chi_2_lr[i][j - steps] except: print('chi something') embed() return chi_2 def create_full_matrix(chi_2_numeric): """ Creates all areas in the frequency plot from the reduced numeric matrix. 
:param chi_2_numeric: The numeric matrix :return: The full matrix """ steps = chi_2_numeric.shape[0] # matrix for upper right corner chi_2_ur = chi_2_numeric.copy() for i in range(steps): for j in range(steps): if i >= j: chi_2_ur[i][j] = chi_2_ur[j][i] # matrix for lower left corner chi_2_ll = np.conj(np.flip(chi_2_ur)) # matrix for upper left corner chi_2_ul = chi_2_numeric.copy() for i in range(steps): for j in range(steps): if i <= j: chi_2_ul[i][j] = np.conj(chi_2_numeric[j][i]) chi_2_ul = np.flip(chi_2_ul, 1) # matrix for lower right corner chi_2_lr = np.conj(np.flip(chi_2_ul)) # put all domains together chi_2 = np.zeros(shape=(2 * steps, 2 * steps), dtype=complex) for i in range(2 * steps): for j in range(2 * steps): # upper right if i >= steps and j >= steps: chi_2[i][j] = chi_2_ur[i - steps][j - steps] # upper left if i >= steps > j: chi_2[i][j] = chi_2_ul[i - steps][j - steps] # lower left if i < steps and j < steps: chi_2[i][j] = chi_2_ll[i][j] # lower right if i < steps <= j: chi_2[i][j] = chi_2_lr[i][j - steps] return chi_2 def get_axis_on_full_matrix(full_matrix, stack_final): stack_final = pd.DataFrame(full_matrix, index=np.array(list(map(int, np.concatenate( [-stack_final.index[::-1], stack_final.index])))), columns=np.array(list(map(int, np.concatenate( [-stack_final.columns[::-1], stack_final.columns]))))) return stack_final def get_stack_one_quadrant(cell, cell_add, cells_save, path1, save_name_rev, direct_load=False, redo=False, creation_time_update=False, size_update=True): stack_saved = get_stack_initial(cell, cell_add, cells_save, path1, save_name_rev, direct_load=direct_load, redo=redo, creation_time_update=creation_time_update, size_update=size_update) stack_plot = change_model_from_csv_to_plots(stack_saved) try: stack_plot1 = RAM_norm(stack_plot, model_show=stack_saved) except: print('model thing2') embed() return stack_plot1, stack_saved def get_stack_initial(cell, cell_add, cells_save, path1, save_name_rev, direct_load=False, redo=False, 
creation_time_update=False, size_update=True): if direct_load: if '.pkl' in path1: model = pd.read_pickle(path1) # pd.read_pickle(path) else: model = pd.read_csv(path1, index_col=0) # pd.read_pickle(path) else: model = load_model_susept(path1, cells_save, save_name_rev.split(r'/')[-1] + cell_add, redo=redo, creation_time_update=creation_time_update, size_update=size_update) #embed() model_show = model[(model.cell == cell)] # & (model_cell.file_name == file)& (model_cell.power == power)] return model_show def find_noise_names(b, base=''): try: noise_there = True except: noise_there = False if noise_there: if base == '': names_mt_gwns = find_names_gwn(b) else: names_mt_gwns = find_mt(b, 'init') # ich glaube das stimmt so nicht! else: names_mt_gwns = [] return names_mt_gwns, noise_there def get_groups_ram(base_properties, file_name, data_between_2017_2018=''): try: base_properties = base_properties.sort_values(by='c', ascending=False) except: print('contrast problem sorting') embed() # hier muss ich nochmal nach dem file sortieren! if data_between_2017_2018 != 'all': file_name_sorted = base_properties[base_properties.file_name == file_name] else: file_name_sorted = base_properties del base_properties if len(file_name_sorted) < 1: print('file_name problem') embed() file_name_sorted = file_name_sorted.sort_values(by='start', ascending=False)[::-1] # ich sollte auf dem level schon nach dem richtigen filename filtern! grouped = file_name_sorted.groupby('c') return grouped def calc_abs_power(isfs): cross = np.zeros(len(isfs[0]), dtype=np.complex_) for i, isf in enumerate(isfs): cross += np.abs(isf) ** 2 return cross def get_osf_restricted(deltat, max_f, spikes_hann, stimulus_hann, fft_i='forward', fft_o='forward'): f_orig = np.fft.fftfreq(len(stimulus_hann), deltat) f, restrict = restrict_freqs(f_orig, max_f) f_range = np.arange(0, len(f), 1) f_same = f_orig[restrict] # # also hier haben wir auch ortho weil das ist auch nah an der Formeln dran! 
# embed() # ALSO FORWARD IST AUCHLAUT MASCHA RICHTIG! # there none is backward and for forward you need an extra version try: osf = np.fft.fft(spikes_hann - np.mean(spikes_hann), norm=fft_o) # das sollte ortho seid # Und wenn ich den Input genau so nehme wie er produziert wird sollte das passen nicht wahr? # also forward ist hier das richtige denn es ist das gegenteil von der generierung # also das soll wohl hier hin weil das genau das ist was wir geneirt haben isf = np.fft.fft(stimulus_hann - np.mean(stimulus_hann), norm=fft_i) # /nfft # nas sollte forward sein except: if fft_o == 'backward': osf = np.fft.fft(spikes_hann) # das sollte ortho seid elif fft_o == 'forward': osf = np.fft.fft(spikes_hann) * deltat if fft_i == 'backward': isf = np.fft.fft(stimulus_hann) # /nfft # nas sollte forward sein elif fft_i == 'forward': isf = np.fft.fft(stimulus_hann) * deltat # /nfft # nas sollte forward sein left = np.argmin(np.abs(np.sort(f_orig)) - 0) - 10 left2 = np.argmin(np.abs(np.sort(f_orig)) - 0) d_isf = np.mean(np.abs(isf)[np.argsort(f_orig)][left:left2]) d_osf = np.mean(np.abs(osf)[np.argsort(f_orig)][left:left2]) # ah wir haben für dieses d_osf wieder an der Null geschaut und nicht die varianz genommen, # ja deswegen ist das so komisch, und nicht mal hoch zwei d_osf1 = np.abs(osf)[np.argsort(f_orig)][left2 - 1] d_isf1 = np.abs(isf)[np.argsort(f_orig)][left2 - 1] isf = isf[restrict] osf = osf[restrict] return d_isf, d_isf1, d_osf, d_osf1, f, f_orig, f_range, f_same, isf, osf, restrict def get_psds_for_coherence(amp, b, cut_off, file_name_save, indices, max_val, mt, names_mt_gwn, nfft, nr_snippets, sampling, p11s=[], stimulus_given=[], p12s=[], give_stimulus=False, p22s=[], spikes_mat_given=[], overlapp='', dev='original', mean=''): if overlapp == '': nr_snippets = nr_snippets * 2 if len(p11s) < 1: p22s = np.zeros(int(nfft / 2 - 1), dtype=np.complex_) p12s = np.zeros(int(nfft / 2 - 1), dtype=np.complex_) p11s = np.zeros(int(nfft / 2 - 1), dtype=np.complex_) length = 
range_with_overlap(nfft, overlapp, max_val * sampling) length_array_isf = [[]] * len(length) length_array_osf = [[]] * len(length) length_array_stimulus = [[]] * len(length) length_array_spikes = [[]] * len(length) a_mi = 0 mats = [] stims = [] count_final = 0 for mm, m in enumerate(indices): print(mm) first, minus, second, stimulus_length = find_first_second(b, names_mt_gwn, m, mt, mm=mm) if len(spikes_mat_given) > 0: spikes_mt = spikes_mat_given[m] else: spikes_mt = link_arrays_spikes(b, first, second, minus) if len(spikes_mt) > 2: if 'first' in mean: if second > 10: second = 10 if len(stimulus_given) > 0: eod_interp = stimulus_given deltat = 1 / sampling else: deltat, eod_interp, eodf_size, sampling, time_array = load_presaved(b, amp, file_name_save, first, m, mt, sampling, second) if len(spikes_mat_given) > 0: spikes_mat = spikes_mat_given[m] else: spikes_mat = cr_spikes_mat(spikes_mt, sampling, len(eod_interp)) # [0:-2]int(stimulus_length * 1 / deltat) if dev == '05': # 'original': window05 = 0.0005 * sampling spikes_mat = gaussian_filter(spikes_mat, sigma=window05) elif dev == '2': window05 = 0.002 * sampling spikes_mat = gaussian_filter(spikes_mat, sigma=window05) elif dev == '5': window05 = 0.005 * sampling spikes_mat = gaussian_filter(spikes_mat, sigma=window05) elif dev == '10': window05 = 0.01 * sampling spikes_mat = gaussian_filter(spikes_mat, sigma=window05) mats.append(spikes_mat) stims.append(eod_interp) # fft_i = 'forward' fft_o = 'forward' if overlapp == '': pass else: pass isfs = [] osfs = [] stimulus = [] spikes = [] for aa, a in enumerate(length): stimulus_array = eod_interp[int(0 + a):int(a + nfft)] spikes_array = spikes_mat[int(0 + a):int(a + nfft)] hann_true = True if hann_true: hann = np.hanning(len(stimulus_array)) try: stimulus_hann = (stimulus_array - np.mean(stimulus_array)) * hann except: print('stimulus something') embed() hann = np.hanning(len(spikes_array)) spikes_hann = (spikes_array - np.mean(spikes_array)) * hann if 
(len(stimulus_hann) == nfft) and (len(spikes_hann) == nfft): d_isf, d_isf1, d_osf, d_osf1, f, f_orig, f_range, f_same, isf, osf, restrict = get_osf_restricted( deltat, cut_off, spikes_hann, stimulus_hann, fft_i, fft_o) isfs.append(isf) osfs.append(osf) spikes.append(np.array(spikes_hann)) stimulus.append(np.array(stimulus_hann)) a_mi += 1 else: isf = float('nan') * np.zeros(int(nfft / 2 - 1), dtype=np.complex_) isfs.append(isf) osf = float('nan') * np.zeros(int(nfft / 2 - 1), dtype=np.complex_) osfs.append(osf) spikes_hann = float('nan') * np.zeros(int(nfft)) # , dtype=np.complex_ stimulus_hann = float('nan') * np.zeros(int(nfft)) # , dtype=np.complex_ spikes.append(spikes_hann) stimulus.append(stimulus_hann) try: length_array_isf[aa] except: print('length vals') embed() if len(length_array_isf[aa]) < 1: length_array_isf[aa] = [isf] if give_stimulus: length_array_spikes[aa] = [spikes_hann] length_array_stimulus[aa] = [stimulus_hann] else: try: length_array_isf[aa].append(isf) except: print('append thing') embed() if give_stimulus: length_array_spikes[aa].append(spikes_hann) length_array_stimulus[aa].append(stimulus_hann) if len(length_array_osf[aa]) < 1: length_array_osf[aa] = [osf] else: try: length_array_osf[aa].append(osf) except: print('append thing') embed() p22, count_final = crossSpectrum(osfs, osfs) p12, count_final = crossSpectrum(isfs, osfs) p11, count_final = crossSpectrum(isfs, isfs) p22s += p22 p12s += p12 p11s += p11 count_final += 1 if len(length_array_osf) > nr_snippets: print('length something0') embed() length_array_osf = np.array(length_array_osf) length_array_isf = np.array(length_array_isf) if give_stimulus: length_array_stimulus = np.array(length_array_stimulus) length_array_spikes = np.array(length_array_spikes) print('done mm') return count_final, stims, mats, a_mi, f_same, length, p11s, p12s, p22s, length_array_isf, length_array_osf, length_array_spikes, length_array_stimulus, def coherence_and_mutual_response_wo_sqrt(a_mir, a_mir2, 
cut_vals, p12_rrs, p22_rrs): coh_resp = np.abs(p12_rrs / a_mir) ** 2 / (p22_rrs.real / a_mir2) ** 2 mutual_information_resp = - np.log2(1 - coh_resp[cut_vals]) # np.sum(* np.diff(freq)[0] return coh_resp, mutual_information_resp def prepeare_test_arrays(indices, length, length_array_isf, length_array_osf, nr_snippets): length_array_isf_test = [[]] * nr_snippets length_array_osf_test = [[]] * nr_snippets for _, _ in enumerate(indices): for aa, a in enumerate(length): if len(length_array_isf_test[aa]) < 1: length_array_isf_test[aa] = [length_array_isf[0][0]] length_array_osf_test[aa] = [length_array_osf[0][0]] else: try: length_array_isf_test[aa].append(length_array_isf[0][0]) length_array_osf_test[aa].append(length_array_osf[0][0]) except: print('append thing') embed() length_array_osf_test = np.array(length_array_osf_test) length_array_isf_test = np.array(length_array_isf_test) return length_array_osf_test, length_array_isf_test def rescale_colorbar_and_values(abs_matrix, add_nonlin_title=None, resize_val=None): # das auf jeden Fall auf der finalen Matrix machen! 
if add_nonlin_title: resize_val = find_resize(add_nonlin_title) max_val = np.max(np.max(abs_matrix, axis=0), axis=0) if not resize_val: if max_val > 1000000000000: resize_val = 1000000000000 elif max_val > 1000000000: resize_val = 1000000000 elif max_val > 1000000: resize_val = 1000000 elif max_val > 1000: resize_val = 1000 elif max_val < 0.000000000001: resize_val = 0.000000000001 # pico elif max_val < 0.000000001: resize_val = 0.000000001 # nano elif max_val < 0.000001: resize_val = 0.000001 # micro elif max_val < 0.001: resize_val = 0.001 # mili else: resize_val = 1 try: abs_matrix = abs_matrix / resize_val except: print('resize thing') embed() add_nonlin_title = find_add_title(resize_val) return abs_matrix, add_nonlin_title, resize_val def find_resize(add_nonlin_title): if add_nonlin_title == 'k': resize_val = 1000 elif add_nonlin_title == 'M': resize_val = 1000000 elif add_nonlin_title == 'G': resize_val = 1000000000 elif add_nonlin_title == 'T': resize_val = 1000000000000 elif add_nonlin_title == 'P': resize_val = 1000000000000000 elif add_nonlin_title == 'E': resize_val = 1000000000000000000 elif add_nonlin_title == 'p': resize_val = 0.000000000001 elif add_nonlin_title == 'n': resize_val = 0.000000001 elif add_nonlin_title == '$\mu$': resize_val = 0.000001 elif add_nonlin_title == 'm': resize_val = 0.001 elif add_nonlin_title == 'c': resize_val = 0.01 elif add_nonlin_title == 'd': resize_val = 0.1 elif add_nonlin_title == '': resize_val = 1 return resize_val def find_add_title(resize_val): if resize_val == 1000: add_nonlin_title = 'k' elif resize_val == 1000000: add_nonlin_title = 'M' elif resize_val == 1000000000: add_nonlin_title = 'G' elif resize_val == 1000000000000: add_nonlin_title = 'T' elif resize_val == 1000000000000000: add_nonlin_title = 'P' # Peta elif resize_val == 1000000000000000000: add_nonlin_title = 'E' # Peta elif resize_val == 0.000000000001: # pico add_nonlin_title = 'p' elif resize_val == 0.000000001: # pico add_nonlin_title = 'n' elif 
resize_val == 0.000001: add_nonlin_title = '$\mu$' elif resize_val == 0.001: add_nonlin_title = 'm' # mili elif resize_val == 0.01: add_nonlin_title = 'c' # centi elif resize_val == 0.1: add_nonlin_title = 'd' # deci elif resize_val == 1: add_nonlin_title = '' # deci return add_nonlin_title def range_with_overlap(nfft, overlap, lenth_here): if overlap == '_nooverlap_': length = np.arange(0, lenth_here, nfft) else: length = list(map(int, np.arange(0, lenth_here, nfft / 2))) return length def gaussKernel(sigma, dt): """ Creates a Gaussian kernel with a given standard deviation and an integral of 1. Parameters ---------- sigma : float The standard deviation of the kernel in seconds dt : float The temporal resolution of the kernel, given in seconds. Returns: np.array The kernel in the range -4 to +4 sigma """ x = np.arange(-4. * sigma, 4. * sigma, dt) y = np.exp(-0.5 * (x / sigma) ** 2) / np.sqrt(2. * np.pi) / sigma return y def firing_rate(spikes, duration, sigma=0.005, dt=1. / 20000.): """Convert spike times to a firing rate estimated by kernel convolution with a Gaussian kernel. Args: spikes (np.array): the spike times duration (float): the trial duration sigma (float, optional): standard deviation of the Gaussian kernel. Defaults to 0.005. dt (float, optional): desired temporal resolution of the firing rate. Defaults to 1./20000.. 
Returns: np.array: the firing rate """ binary = np.zeros(int(np.round(duration / dt))) indices = np.asarray(np.round(spikes / dt), dtype=int) binary[indices[indices < len(binary)]] = 1 kernel = gaussKernel(sigma, dt) rate = np.convolve(kernel, binary, mode="same") return rate def get_rates(rr, time, dt, sigma): valid_stim_count = sum([1 for s in rr if s.duration == time[-1]]) rates = np.zeros((valid_stim_count, len(time))) index = 0 for i in range(rr.stimulus_count): if rr[i].duration != time[-1]: continue spikes = rr.spikes(i) rates[index, :] = firing_rate(spikes, time[-1], sigma, dt=dt) index += 1 return rates def coherence(rates, stim, nperseg, noverlap, dt): assert (rates.shape[1] == len(stim)) all_rate_spectra, all_stim_spectra, f = get_rates_stacked(dt, noverlap, nperseg, rates, stim) csd = cross_spectrum(all_stim_spectra, all_rate_spectra) stimasd = auto_spectrum(all_stim_spectra) respasd = auto_spectrum(all_rate_spectra) coh = csd / (stimasd * respasd) return f[f >= 0], coh[f >= 0], all_rate_spectra, all_stim_spectra def get_rates_stacked(dt, noverlap, nperseg, rates, stim): stim_segments = get_segments(stim, nperseg, noverlap) f, stim_spectra = spectra(stim_segments, dt) for i in range(rates.shape[0]): rate_segments = get_segments(rates[i, :], nperseg, noverlap) _, rate_spectra = spectra(rate_segments, dt) if i == 0: all_rate_spectra = rate_spectra all_stim_spectra = stim_spectra else: all_rate_spectra = np.vstack((all_rate_spectra, rate_spectra)) all_stim_spectra = np.vstack((all_stim_spectra, stim_spectra)) # hier hat er sie alle appended also nicht direct return all_rate_spectra, all_stim_spectra, f def exp_coherence(rates, nperseg, noverlap, dt): mrate = np.mean(rates, axis=0) mrate_segments = get_segments(mrate, nperseg, noverlap) f, mrate_spectra = spectra(mrate_segments, dt) for i in range(rates.shape[0]): rate_segments = get_segments(rates[i, :], nperseg, noverlap) _, rate_spectra = spectra(rate_segments, dt) if i == 0: all_mrate_spectra = 
mrate_spectra all_rate_spectra = rate_spectra else: all_mrate_spectra = np.vstack((all_mrate_spectra, mrate_spectra)) all_rate_spectra = np.vstack((all_rate_spectra, rate_spectra)) csd = cross_spectrum(all_mrate_spectra, all_rate_spectra) mrateasd = auto_spectrum(all_mrate_spectra) rateasd = auto_spectrum(all_rate_spectra) c = csd / (rateasd * mrateasd) return f[f >= 0], c[f >= 0] def coherences(rates, s, dt, nperseg=2 ** 14): f, gamma, all_rate_spectra, all_stim_spectra = coherence(rates, s, nperseg, nperseg // 2, dt) _, exp_gamma = exp_coherence(rates, nperseg, nperseg // 2, dt) _, rr_gamma = rr_coherence(rates, nperseg, nperseg // 2, dt) return f, gamma, exp_gamma, rr_gamma, all_rate_spectra, all_stim_spectra def plt_cohs_ich(ax, coh, coh_resp, coh_resp_mean, coh_resp_directs, coh_resp_restrict, coh_s_directs, cut_off, f_same): ax[1].plot(f_same[f_same < cut_off], np.sqrt(coh_resp_restrict[f_same < cut_off]), label='coherence_r_restrict', color='purple') ax[1].plot(f_same[f_same < cut_off], np.sqrt(coh_resp[f_same < cut_off]), label='coherence_r', color='orange') ax[1].plot(f_same[f_same < cut_off], np.sqrt(coh_resp_directs[f_same < cut_off]), label='coherence_r_direct', color='orange', linestyle='--') ax[1].plot(f_same[f_same < cut_off], np.sqrt(coh_s_directs[f_same < cut_off]), label='coh_s_directs', color='blue', linestyle='--') ax[1].plot(f_same[f_same < cut_off], coh_resp_mean[f_same < cut_off], label='coherence_r_expected', color='green') ax[1].plot(f_same[f_same < cut_off], coh[f_same < cut_off], label='coherence_s', color='blue') def plt_cohs_jan(ax, cut_off, exp_gamma, f_jan, gamma, rr_gamma): ax[0].plot(f_jan[f_jan < cut_off], gamma[f_jan < cut_off], label='gamma') ax[0].plot(f_jan[f_jan < cut_off], exp_gamma[f_jan < cut_off], label='exp_gamma') ax[0].plot(f_jan[f_jan < cut_off], rr_gamma[f_jan < cut_off], label='rr_gamma') ax[0].legend() def get_mats_same_shape(indices, mats, mt, sampling, stims): length_val = np.max(mt.extents[:][indices]) * sampling 
mats_jan = [] for mm, m in enumerate(mats): if len(m) == length_val: mats_jan.append(np.array(m)) stim_jan = stims[mm] mats_jan = np.array(mats_jan) return mats_jan, stim_jan def tranfer_xlabel(): return '$f/'+f_eod_name_core_rm()+'$' # \,[Hz] def tranfer_xlabel_hz(): return '$f$ [Hz]' # \,[Hz] def diagonal_xlabel(): return '$f_{1}+f_{2}$\,[Hz]' def diagonal_xlabel_nothz(): return '$(f_{1}+f_{2})/'+f_eod_name_core_rm()+'$' def NLI_scorename2(): return 'PNL$(f'+basename()+')$' def NLI_name2(): return 'PNL' def NLI_scorename(): return 'NLI$(f'+basename()+')$' def join_x(axts_all): axts_all[0].get_shared_x_axes().join(*axts_all) def join_y(axts_all, mult_val=1.00): axts_all[0].get_shared_y_axes().join(*axts_all) if axts_all[0].get_ylim()[-1] != axts_all[1].get_ylim()[-1]: try: for a in range(len(axts_all) - 1): first = axts_all[a].get_ylim() # [-1] second = axts_all[a + 1].get_ylim() # [-1] starting_val = np.min([first[0], second[0]]) end_val = np.max([first[1], second[1]]) axts_all[a].set_ylim(starting_val, end_val * mult_val) axts_all[a + 1].set_ylim(starting_val, end_val * mult_val) except: print('joiny something') def find_peaks_simple(eodf, freq1, freq2, name, color1, color2): if name == '01': freqs = [np.abs(freq1), eodf, freq1 + eodf] colors_peaks = [color1, 'black', color1] alpha = [1, 1, 0.5] labels = ['DF1', 'EODf', 'F1'] elif name == '02': freqs = [np.abs(freq2), eodf, freq2 + eodf] colors_peaks = [color2, 'black', color2] alpha = [1, 1, 0.5] labels = ['DF2', 'EODf', 'F2'] elif name == '012': freqs = [np.abs(freq1), np.abs(freq2), eodf, freq1 + eodf, freq2 + eodf] colors_peaks = ['blue', 'red', 'black', 'blue', 'red'] alpha = [1, 1, 1, 0.5, 0.5] labels = ['DF1', 'DF2', 'EODf', 'F1', 'F2'] elif name == '0': freqs = [np.abs(freq1), eodf, freq1 + eodf] colors_peaks = [color1, 'black', color1] alpha = [1, 1, 0.5] labels = ['DF1', 'EODf', 'F1'] return labels, alpha, colors_peaks, freqs def plt_single_contrast(f_counter, axts, axps, f, grid_ll, c_nn, freq1, 
freq2, eodf, auci_wo, auci_w, results_diff, a_f2s, fish_jammer, a, trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing, stimulus_length, model_cells, position_diff, colors_array, reshuffled, sampling, cell_here, c_nr, n, dev_name, c_nn_nr=1, xpos=1, ypos=1.35, small_peaks=True, second=True, ws=0.6, start=1, legend=False, v_mem_choice=True): array_mat, v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays, names = calc_roc_amp_core_cocktail_for_plot( [freq1 + eodf], [freq2 + eodf], auci_wo, auci_w, results_diff, a_f2s, fish_jammer, trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing, stimulus_length, model_cells, position_diff, 0.0005, cell_here, dev_name=dev_name, a_f1s=[c_nr], n=n, reshuffled=reshuffled) array_mat, v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays_orig, names = calc_roc_amp_core_cocktail_for_plot( [freq1 + eodf], [freq2 + eodf], auci_wo, auci_w, results_diff, a_f2s, fish_jammer, trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing, stimulus_length, model_cells, position_diff, 'original', cell_here, dev_name=dev_name, a_f1s=[c_nr], n=n, reshuffled=reshuffled) time = np.arange(0, len(arrays[a][0]) / sampling, 1 / sampling) if v_mem_choice: arrays_here = v_mems[start::] else: if start == 1: arrays_here = [arrays[start::][0][0], arrays[start::][1][0], arrays[start::][2][0]] else: arrays_here = [arrays[start::][0][0], arrays[start::][1][0], arrays[start::][2][0], arrays[start::][3][0]] if start == 1: arrays_here_psd = [arrays_orig[start::][0][0], arrays_orig[start::][1][0], arrays_orig[start::][2][0]] else: arrays_here_psd = [arrays_orig[start::][0][0], arrays_orig[start::][1][0], arrays_orig[start::][2][0], arrays_orig[start::][3][0]] names = np.array(['0', '01', '02', '012'])[start::] spikes_here = arrays_spikes[start::] colors_array_here = colors_array[start::] pps = [] for a in range(len(arrays_here)): 
grid_pt = gridspec.GridSpecFromSubplotSpec(1, 2, hspace=0.45, wspace=ws, width_ratios=[1, 1.2], subplot_spec=grid_ll[a, c_nn]) # hspace=0.4,wspace=0.2,len(chirps) axt = plt.subplot(grid_pt[0]) axp = plt.subplot(grid_pt[1]) if v_mem_choice: axt.eventplot(spikes_here[a], lineoffsets=np.max(arrays_here[a]), color='black') # np.max(v1)* else: axt.eventplot(spikes_here[a], lineoffsets=np.max(arrays_here[a]), linelengths=np.max(arrays_here[a]) * 0.1, color='black') # np.max(v1)* axts.append(axt) axps.append(axp) if f != 0: remove_yticks(axt) if a != len(arrays_here) - 1: remove_xticks(axt) if f_counter == 0: axt.set_ylabel(names[a]) set_amplitude_titles(a, a_f2s, arrays_here, axt, c_nr, colors_array_here, start, time) axt.set_xlim(0.1, 0.22) try: pp, ff = ml.psd(arrays_here_psd[a] - np.mean(arrays_here_psd[a]), Fs=sampling, NFFT=nfft, noverlap=nfft // 2) except: print('pp problems') embed() pps.append(pp) axp.plot(ff, pp, color=colors_array_here[a]) # colors_contrasts[c_nn] maxx = 1000 axp.set_xlim(-10, maxx) if small_peaks: labels, alpha, colors_peaks, freqs = find_peaks_simple(eodf, freq1, freq2, names[a], colors_array[1], colors_array[2]) else: alpha, labels, colors_peaks, freqs = mult_beat_freqs(eodf, maxx, freq1, color_df_mult=colors_array[1], color_eodf='black', color_stim='orange', color_stim_mult='pink', ) set_titles_freqs(a, axt, c_nn, c_nn_nr, eodf, freq1, freq2, second, start, xpos, ypos) plt_peaks_several(freqs, [pp], axp, pp, ff, labels, 0, colors_peaks, alphas=alpha) if a != 2: remove_xticks(axp) if c_nn != 0: remove_yticks(axt) remove_yticks(axp) else: axt.set_ylabel('Hz') axp.set_ylabel('Hz') if legend: axp.legend(loc=(-0.3, 1.2), ncol=3) axt.set_xlabel('Time [s]') axp.set_xlabel('Frequency [Hz]') try: f_counter += 1 except: print('counter thing') embed() return f_counter def set_amplitude_titles(a, a_f2s, arrays_here, axt, c_nr, colors_array_here, start, time): if start == 1: if a == 0: axt.set_title(' Amplitude 1 = ' + str(c_nr) + ', Amplitude 2 = 0') 
elif a == 1: axt.set_title(' Amplitude 1 = 0,' + ' Amplitude 2 = ' + str(a_f2s[0])) else: axt.set_title(' Amplitude 1 = ' + str(c_nr) + ', Amplitude 2 = ' + str(a_f2s[0])) try: axt.plot(time, arrays_here[a], color=colors_array_here[a]) # colors_contrasts[c_nn] except: print('time something') embed() def set_titles_freqs(a, axt, c_nn, c_nn_nr, eodf, freq1, freq2, second, start, xpos, ypos): if c_nn == c_nn_nr: if start == 1: if a == 0: # if second: second_part = 'F1=' + str(np.round(int(freq1 + eodf))) + 'Hz' + ' DF1=F1-EODf=' + str( int(np.round(freq1))) + 'Hz' else: second_part = '' axt.text(xpos, ypos, 'Only Frequency 1 (F1): \n' + second_part, fontweight='bold', ha='center', fontsize=10, transform=axt.transAxes, ) elif a == 1: if second: second_part = 'F2=' + str(np.round(int(freq2 + eodf))) + 'Hz ' + 'F2-EODf=' + str( int(np.round(freq2))) + ' Hz ' else: second_part = '' axt.text(xpos, ypos, 'Only Frequency 2 (F2): \n' + second_part, fontweight='bold', ha='center', fontsize=10, transform=axt.transAxes, ) else: if second: second_part = 'F1=' + str(int(np.round(freq1 + eodf))) + 'Hz' + ' F1-EODf=' + str( int(np.round(freq1))) + 'Hz' + ' F2=' + str(int(freq2 + eodf)) + 'Hz ' + 'DF2=F2-EODf=' + str( int(np.round(freq2))) + ' Hz ' else: second_part = '' axt.text(xpos, ypos, 'Frequency 1 (F1) + Frequency 2 (F2): \n' + second_part, fontweight='bold', ha='center', fontsize=10, transform=axt.transAxes, ) def find_diffs(c, frame_cell, diffs, add=''): if c == 'c1': # 'B1_diff' try: frame_cell['diff'] = np.sqrt( (frame_cell['amp_B1_012_mean' + add]) ** 2 - frame_cell['amp_B1_01_mean' + add] ** 2) * diffs except: # irgnedwann habe ich das Format geändert deswegen try: add = '_original' frame_cell['diff'] = np.sqrt( (frame_cell['amp_B1_012_mean' + add]) ** 2 - frame_cell['amp_B1_01_mean' + add] ** 2) * diffs except: add = '' frame_cell['diff'] = np.sqrt( (frame_cell['amp_B1_012_mean' + add]) ** 2 - frame_cell['amp_B1_01_mean' + add] ** 2) * diffs else: # 'B2_diff' try: 
frame_cell['diff'] = np.sqrt( frame_cell['amp_B2_012_mean' + add] ** 2 - frame_cell['amp_B2_02_mean' + add] ** 2) * diffs except: try: add = '_original' frame_cell['diff'] = np.sqrt( frame_cell['amp_B2_012_mean' + add] ** 2 - frame_cell['amp_B2_02_mean' + add] ** 2) * diffs except: add = '' frame_cell['diff'] = np.sqrt( frame_cell['amp_B2_012_mean' + add] ** 2 - frame_cell['amp_B2_02_mean' + add] ** 2) * diffs return frame_cell def find_deltas(frame_cell, c): diffs = list(np.diff(frame_cell[c])) diffs.extend([np.diff(frame_cell[c])[-1]]) diffs = np.array(diffs) return diffs def find_dfs(frame_cell): f1s = np.unique(frame_cell.f1) f2s = np.unique(frame_cell.f2) df1s = f1s - frame_cell.f0.unique() df2s = f2s - frame_cell.f0.unique() # für den Fall dass df falsch ausgerechnet wurde frame_cell['df1'] = frame_cell.f1 - frame_cell.f0 frame_cell['df2'] = frame_cell.f2 - frame_cell.f0 return frame_cell, df1s, df2s, f1s, f2s def plt_single_trace(ax_upper, ax_u1, frame_cell_orig, freq1, freq2, add='', B_replace='DF', sum=True, add_label='', alpha=[1, 1, 1, 1], linewidths=[], xscale='', c_dist_recalc=True, lw=1.5, linestyles=['-', '-', '-', '-', '-'], nr=4, labels=[], scores=['amp_B1_01_mean_original', 'amp_B1_012_mean_original', 'amp_B2_02_mean_original', 'amp_B2_012_mean_original'], colors=['green', 'blue', 'orange', 'red', 'grey'], lim_recalc = (0, 70), default_colors=True, delta=True): if default_colors: if 'amp_B1_01_mean_original' in frame_cell_orig.keys(): add = '_mean_original' else: add = '_mean' labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept( add=add, nr=nr) frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)] if len(labels) < 1: labels = scores # .replace('amp_','') c1 = c_dist_recalc_here(c_dist_recalc, frame_cell) for sss, score in enumerate(scores): if len(linewidths) > 1: lw = linewidths[sss] try: ax_u1.plot(c1, frame_cell[score], 
zorder=100, color=colors[sss], alpha=alpha[sss], label=labels[sss].replace('_mean', '').replace('amp_', '').replace('B', B_replace).replace( 'original', '').replace('original', 'or').replace('distance', 'dist').replace('power', '') + add_label, linestyle=linestyles[sss], linewidth=lw) except: # vals = frame_cell.keys()#frame_cell.filter(like='var').keys() print('linestyle problem') embed() ax_upper.append(ax_u1) if sum: ax_u1.plot(c1, np.sqrt(frame_cell['amp_B2_012_mean' + add] ** 2 + frame_cell['amp_B1_012_mean' + add] ** 2), zorder=100, color='grey', label='B1+B2_012', linestyle='--') ax_u1.plot(c1, np.sqrt(frame_cell['amp_B2_02_mean' + add] ** 2 + frame_cell['amp_B1_01_mean' + add] ** 2), zorder=100, color='black', label='B1_01+B2_02', linestyle='-') if c_dist_recalc: if lim_recalc: ax_u1.set_xlim(lim_recalc) ax_u1.set_xlabel('C1 Distance [cm]') else: ax_u1.set_xlabel('Contrast$_{1}$ [$\%$]') if xscale == 'log': ax_u1.set_xscale('log') ax_u1.set_ylabel(representation_ylabel(delta=delta)) return ax_upper def c_dist_recalc_here(c_dist_recalc, frame_cell): c1 = c_dist_recalc_func(frame_cell, cell=frame_cell.cell.unique()[0], c_dist_recalc=c_dist_recalc) if not c_dist_recalc: c1 = np.array(c1) * 100 return c1 def representation_ylabel(delta=True): if delta: val = 'Amplitude $A(\Delta f)$ [Hz]' else: val = 'Amplitude $A(f)$ [Hz]' return val def c_dist_recalc_func(frame_cell=[], mult_eod=0.5, c_nrs=[], cell=[], eod_size_change=True, c_dist_recalc=True, recalc_contrast_in_perc=1): if len(c_nrs) < 1: c_nrs = frame_cell.c1 else: c_nrs = np.array(c_nrs) if c_dist_recalc: # # ich weiß noch nicht ob das jetzt für mv oder kontraste stimmen sollte if eod_size_change: try: baseline, b, eod_size, = load_eod_size(cell, max='perc') except: try: update_ssh_file() baseline, b, eod_size, = load_eod_size(cell, max='perc') except: print('EODF SIZE ESTIMATION BIASED') eod_size = 1 else: eod_size = 1 # bassiert auf henninger 2020 print('eodfsize' + str(eod_size)) # also eod size mal 0.5 
um den maximalen wert zu haben und mal den kotnrast c1 = c_to_dist(eod_size * mult_eod * c_nrs) else: c1 = np.array(c_nrs) * recalc_contrast_in_perc # frame_cell.c1 return c1 def calc_cv_three_wave(results_diff, position_diff, arrays=[], adds=[]): for a, add in enumerate(adds): if len(arrays[0]) > 1: # das ist für mehrere Trials isi = np.diff( arrays[a][0]) # auch hier nehmen wir erstmal nur den ersten Trial sonst wird das komplex und zu viel else: # für einen Trial isi = np.diff(arrays[a][0]) try: results_diff.loc[position_diff, 'cv' + add] = np.std(isi) / np.mean(isi) except: print('ROC problem') embed() results_diff.loc[position_diff, 'std_isi' + add] = np.std(isi) results_diff.loc[position_diff, 'mean_isi' + add] = np.mean(isi) results_diff.loc[position_diff, 'mean_isi' + add] = np.median(isi) burst_1, burst_2 = calc_burst_perc(results_diff.loc[position_diff, 'f0'], isi) results_diff.loc[position_diff, 'burst_1' + add] = burst_1 results_diff.loc[position_diff, 'burst_2' + add] = burst_2 return results_diff def deltat_sampling_factor(sampling_factor, deltat, eod_fr): if sampling_factor == 'EODmult': deltat = 1 / (eod_fr * 2) elif sampling_factor != '': deltat = sampling_factor return deltat def eod_fish_e_generation(time_array, a_fe=0.2, eod_fe=[750], e=0, phaseshift_fr=0, sampling=20000, stimulus_length=1, nfft_for_morph=2 ** 14, cell_recording='', fish_morph_harmonics_var='analyze', zeros='zeros', mimick='no', fish_emitter='Alepto', fish_jammer='Alepto', thistype='emitter'): # WICHTIG: die ersten vier time_fish_e = time_array * 2 * np.pi * eod_fe[e] if (a_fe == 0) and (zeros != 'zeros'): eod_fish_e = np.ones(len(time_array)) else: # in case you want to mimick the second here you just can do it based on the dictonary of the thunderfish, # since here we were not interested in more but one could adapt this based on the eod_fish_r_genration function if ('Emitter' in mimick) and (thistype == 'emitter'): if 'Wavemorph' in mimick: input, eod_fr_data, data_array_eod, 
time_data_eod, eod_fish_e, pp, ff, p_array_new, f_new, amp, phase, b, t = thunder_morph_func( phaseshift_fr, cell_recording, eod_fe[e], sampling, stimulus_length, a_fe, nfft_for_morph, fish_morph_harmonics_var=fish_morph_harmonics_var) else: eod_fish_e = fakefish.wavefish_eods(fish_emitter, frequency=eod_fe[e], samplerate=sampling, duration=stimulus_length, phase0=0.0, noise_std=0.00) if ('Zenter' in mimick) and ('NotZentered' not in mimick): eod_fish_e = zenter_and_normalize(eod_fish_e, a_fe) elif ('Jammer' in mimick) and (thistype == 'jammer'): if 'Wavemorph' in mimick: input, eod_fr_data, data_array_eod, time_data_eod, eod_fish_e, pp, ff, p_array_new, f_new, amp, phase, b, t = thunder_morph_func( phaseshift_fr, cell_recording, eod_fe[e], sampling, stimulus_length, a_fe, nfft_for_morph, fish_morph_harmonics_var=fish_morph_harmonics_var) else: eod_fish_e = fakefish.wavefish_eods(fish_jammer, frequency=eod_fe[e], samplerate=sampling, duration=stimulus_length, phase0=0.0, noise_std=0.00) if ('Zenter' in mimick) and ('NotZentered' not in mimick): eod_fish_e = zenter_and_normalize(eod_fish_e, a_fe) else: eod_fish_e = a_fe * np.sin(time_fish_e) # this is since some of the zeros can be negative, so just made them all positiv if (a_fe == 0) and (zeros == 'zeros'): eod_fish_e = np.abs(eod_fish_e) return eod_fish_e, time_fish_e def deltat_choice(model_params, sampling_factor, eod_fr, stimulus_length, deltat=None): if not deltat: deltat = model_params.pop("deltat") deltat = deltat_sampling_factor(sampling_factor, deltat, eod_fr) sampling = 1 / deltat time_array = np.arange(0, stimulus_length, deltat) return time_array, sampling, deltat def get_arrays_for_three(cell, a_f2, a_f1, SAM, eod_stimulus, eod_fish_r, freq2, eod_fish1, eod_fish2, stimulus_length, offset, model_params, n, variant, adapt_offset, deltat, f2, trials_nr, time_array, f1, freq1, eod_fr, reshuffle='reshuffled', length_adapt=True, dev=0.0005, zeros='', a_fr=1, params_dict={'burst_corr': ''}, redo_stim=True, 
nfft='', cell_recording='', phaseshift_fr=0, fish_emitter='', fish_receiver='', beat='', fish_jammer='', fish_morph_harmonics_var='', mimick='', nfft_for_morph='', phase_right='', sampling=''): ####################################### # do the 01 params_dict['eod_fish1'] = eod_fish1 # } stimulus_01, meansmoothed05_01, spikes_01, smoothed01, mat01, offset_new, v_mem_output_01 = do_array_for_three(nfft, a_fr, fish_receiver, beat, zeros, cell_recording, fish_emitter, fish_jammer, fish_morph_harmonics_var, mimick, nfft_for_morph, phase_right, sampling, eod_fish2, SAM, eod_stimulus, eod_fish_r, freq2, a_f1, a_f2, cell, stimulus_length, offset, model_params, n, adapt_offset, deltat, f2, trials_nr, time_array, f1, freq1, [], phaseshift_fr, variant, eod_fr, length_adapt=length_adapt, dict_here=params_dict, redo_stim=redo_stim, stim_type='01', reshuffle=reshuffle, dev=dev) # do the 02 stim_02, meansmoth05_02, spikes_02, smoothed02, mat02, offset_new, v_mem_02 = do_array_for_three(nfft, a_fr, fish_receiver, beat, zeros, cell_recording, fish_emitter, fish_jammer, fish_morph_harmonics_var, mimick, nfft_for_morph, phase_right, sampling, eod_fish2, SAM, eod_stimulus, eod_fish_r, freq2, a_f1, a_f2, cell, stimulus_length, offset, model_params, n, adapt_offset, deltat, f2, trials_nr, time_array, f1, freq1, [], phaseshift_fr, variant, eod_fr, length_adapt=length_adapt, dict_here=params_dict, redo_stim=redo_stim, stim_type='02', reshuffle=reshuffle, dev=dev) ####################################### # do the 012 stimulus_012, meansmoothed05_012, spikes_012, smoothed012, mat012, offset_new, v_mem_output_012 = do_array_for_three( nfft, a_fr, fish_receiver, beat, zeros, cell_recording, fish_emitter, fish_jammer, fish_morph_harmonics_var, mimick, nfft_for_morph, phase_right, sampling, eod_fish2, SAM, eod_stimulus, eod_fish_r, freq2, a_f1, a_f2, cell, stimulus_length, offset, model_params, n, adapt_offset, deltat, f2, trials_nr, time_array, f1, freq1, [], phaseshift_fr, variant, eod_fr, 
length_adapt=length_adapt, dict_here=params_dict, redo_stim=redo_stim, stim_type='012', reshuffle=reshuffle, dev=dev) print(offset) test = False if test: from utils_test import test_stimulus test_stimulus() return np.array([v_mem_output_01, v_mem_02, v_mem_output_012]), offset_new, mat01, mat02, mat012, smoothed01, smoothed02, smoothed012, stimulus_01, stim_02, stimulus_012, meansmoothed05_01, spikes_01, meansmoth05_02, spikes_02, meansmoothed05_012, spikes_012 def calc_roc_amp_core_cocktail_for_plot(freq1, freq2, auci_wo, auci_w, results_diff, a_f2s, fish_jammer, trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing, stimulus_length, model_cells, position_diff, dev, cell_here, a_f1s=[], dev_name='05', n=1, reshuffled='', test=False, SAM='_SAM_'): model_params = model_cells[model_cells['cell'] == cell_here].iloc[0] eod_fr = model_params['EODf'] # .iloc[0] offset = model_params.pop('v_offset') cell = model_params.pop('cell') print(cell) f1 = 0 f2 = 0 sampling_factor = '' phaseshift_fr = 0 cell_recording = '' mimick = 'no' zeros = 'zeros' fish_morph_harmonics_var = 'harmonic' fish_emitter = 'Alepto' # ['Sternarchella', 'Sternopygus'] fish_receiver = 'Alepto' # phase_right = '_phaseright_' damping = 0.45 # 0.65,0.2,0.5,0.2,0.6,0.45,0.6,0.35 damping_type = '' exponential = '' # in case you want a different sampling here we can adujust time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr, stimulus_length) # generate the eod_fish_r in the four mimick variants (copy, thunderfish, mimick, just sinus) eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr, stimulus_length, phaseshift_fr, cell_recording, zeros, mimick, sampling, fish_receiver, deltat, nfft, nfft_for_morph, fish_morph_harmonics_var=fish_morph_harmonics_var, beat=beat) sampling = 1 / deltat variant = 'sinz' if exponential == '': pass # prepare for adapting offset due to baseline modification _, _ = 
prepare_baseline_array(time_array, eod_fr, nfft_for_morph, phaseshift_fr, mimick, zeros, cell_recording, sampling, stimulus_length, fish_receiver, deltat, nfft, damping_type, damping, us_name, gain, beat=beat, fish_morph_harmonics_var=fish_morph_harmonics_var) spikes_base = [[]] * trials_nr for run in range(runs): print(run) for t in range(trials_nr): stimulus_0 = eod_fish_r adapt_offset = 'adaptoffset_bisecting' cvs, adapt_output, baseline_after_b, _, rate_adapted_b, rate_baseline_before_b, rate_baseline_after_b, \ spikes_base[t], _, _, offset_new, v_mem0, noise_final = simulate(cell, offset, stimulus_0, deltat=deltat, adaptation_variant=adapt_offset, adaptation_yes_j=f2, adaptation_yes_e=f1, adaptation_yes_t=t, power_alpha=alpha, power_nr=n, reshuffle=reshuffled, **model_params) print(' offset orig ' + str(offset)) test = False if test: from utils_test import test_cvs2 test_cvs2() if t == 0: # here we record the changes in the offset due to the adaptation # and we subsequently reset the offset to be the new adapted for all subsequent trials offset = offset_new * 1 print(' Base ' + str(adapt_offset) + ' offset ' + str(offset)) if printing: print('Baseline time' + str(time.time() - t1)) _, _ = find_base_fr(spikes_base, deltat, stimulus_length, time_array, dev=dev) length_adapt = False base_cut, mat_base, smoothed0, mat0 = find_base_fr2(spikes_base, deltat, stimulus_length, dev=dev, length_adapt=length_adapt) fr = np.mean(base_cut) _, _ = ISI_frequency(time_array, spikes_base[0], fill=0.0) for aaa, a_f2 in enumerate(a_f2s): # [0] for aa, a_f1 in enumerate(a_f1s): # [0] t1 = time.time() phaseshift_f1, phaseshift_f2 = get_phaseshifts(a_f1, a_f2, phase_right, phaseshift_fr) eod_fish1, time_fish_e = eod_fish_e_generation(time_array, a_f1, freq1, f1, phaseshift_f1, sampling, stimulus_length, nfft_for_morph, cell_recording, fish_morph_harmonics_var, zeros, mimick, fish_emitter, thistype='emitter') eod_fish2, time_fish_j = eod_fish_e_generation(time_array, a_f2, freq2, f2, 
phaseshift_f2, sampling, stimulus_length, nfft_for_morph, cell_recording, fish_morph_harmonics_var, zeros, mimick, fish_jammer, thistype='jammer') eod_stimulus = eod_fish1 + eod_fish2 if test: eod_stimulus_d = eod_fish1 + eod_fish2 fig, ax = plt.subplots(2, 1, sharex=True, sharey=True) ax[0].plot(time_array, eod_stimulus_d) ax[1].plot(time_array, eod_stimulus) plt.show() if test: fig, ax = plt.subplots(4, 1, sharex=True, sharey=True) ax[0].plot(time_array, eod_fish_r) ax[1].plot(time_array, eod_fish2) ax[2].plot(time_array, eod_stimulus) ax[3].plot(time_array, eod_stimulus) ax[0].set_xlim(0, 0.1) plt.show() adapt_offset_later = '' v_mem, offset_new, mat01, mat02, mat012, smoothed01, smoothed02, smoothed012, stimulus_01, stimulus_02, stimulus_012, mat05_01, spikes_01, mat05_02, spikes_02, mat05_012, spikes_012 = get_arrays_for_three( cell, a_f2, a_f1, SAM, eod_stimulus, eod_fish_r, freq2, eod_fish1, eod_fish2, stimulus_length, offset, model_params, n, variant, adapt_offset_later, deltat, f2, trials_nr, time_array, f1, freq1, eod_fr, reshuffle=reshuffled, length_adapt=length_adapt, dev=dev) if test: fig, ax = plt.subplots(4, 1, sharex=True, sharey=True) ax[0].plot(stimulus_0) ax[1].plot(stimulus_01) ax[2].plot(stimulus_02) ax[3].plot(stimulus_012) plt.show() if printing: print('Generation process' + str(time.time() - t1)) v_mems = np.concatenate([[v_mem0], v_mem]) array0 = [mat_base] array01 = [mat05_01] array02 = [mat05_02] array012 = [mat05_012] for dev_n in dev_name: results_diff.loc[position_diff, 'fr'] = fr results_diff.loc[position_diff, 'f1'] = freq1[0] results_diff.loc[position_diff, 'f2'] = freq2[0] results_diff.loc[position_diff, 'f0'] = eod_fr results_diff.loc[position_diff, 'df1'] = np.abs(eod_fr - freq1) results_diff.loc[position_diff, 'df2'] = np.abs(eod_fr - freq2) results_diff.loc[position_diff, 'cell'] = cell results_diff.loc[position_diff, 'c1'] = a_f1 results_diff.loc[position_diff, 'c2'] = a_f2 results_diff.loc[position_diff, 'trial_nr'] = 
trials_nr #################################### # calc cvs results_diff = calc_cv_three_wave(results_diff, position_diff, arrays=[spikes_base, spikes_01, spikes_02, spikes_012], adds=['_0', '_01', '_02', '_012']) if dev_n == '05': dev = 0.0005 # tp_02_all, tp_012_all # das mit den Means ist jetzt einfach nur ein # test wie ich die std und var und psd eigentlich gruppieren müsste if dev_n == 'original': array0 = [np.mean(mat0, axis=0)] array01 = [np.mean(mat01, axis=0)] array02 = [np.mean(mat02, axis=0)] array012 = [np.mean(mat012, axis=0)] elif dev_n == '05': array0 = [np.mean(smoothed0, axis=0)] array01 = [np.mean(smoothed01, axis=0)] array02 = [np.mean(smoothed02, axis=0)] array012 = [np.mean(smoothed012, axis=0)] #################################################################### arrays_stim = [stimulus_0, stimulus_01, stimulus_02, stimulus_012] arrays = [array0, array01, array02, array012] arrays_spikes = [spikes_base, spikes_01, spikes_02, spikes_012] array_mat = [mat0, mat01, mat02, mat012] ###################################### # upper and lower bound berechnen names = ['0', '01', '02', '012'] position_diff += 1 return array_mat, v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays, names def huxley(): np.random.seed(1000) # Start and end time (in milliseconds) tmin = 0.0 tmax = 50.0 # Average potassium channel conductance per unit area (mS/cm^2) gK = 36.0 # Average sodoum channel conductance per unit area (mS/cm^2) gNa = 120.0 # Average leak channel conductance per unit area (mS/cm^2) gL = 0.3 # Membrane capacitance per unit area (uF/cm^2) Cm = 1.0 # Potassium potential (mV) VK = -12.0 # Sodium potential (mV) VNa = 115.0 # Leak potential (mV) Vl = 10.613 # Time values # Potassium ion-channel rate functions def alpha_n(Vm): return (0.01 * (10.0 - Vm)) / (np.exp(1.0 - (0.1 * Vm)) - 1.0) def beta_n(Vm): return 0.125 * np.exp(-Vm / 80.0) # Sodium ion-channel rate functions def alpha_m(Vm): return (0.1 * (25.0 - Vm)) / 
(np.exp(2.5 - (0.1 * Vm)) - 1.0) def beta_m(Vm): return 4.0 * np.exp(-Vm / 18.0) def alpha_h(Vm): return 0.07 * np.exp(-Vm / 20.0) def beta_h(Vm): return 1.0 / (np.exp(3.0 - (0.1 * Vm)) + 1.0) # n, m, and h steady-state values def n_inf(Vm=0.0): return alpha_n(Vm) / (alpha_n(Vm) + beta_n(Vm)) def m_inf(Vm=0.0): return alpha_m(Vm) / (alpha_m(Vm) + beta_m(Vm)) def h_inf(Vm=0.0): return alpha_h(Vm) / (alpha_h(Vm) + beta_h(Vm)) # Input stimulus def Id(t): if 0.0 < t < 1.0: return 150.0 elif 10.0 < t < 11.0: return 50.0 return 0.0 # Compute derivatives def damping_kashimori(stimulus, time, GbCa_=15, GbKCa_=500): @jit() # (nopython=True) def func1(x, t, R, F, stimulus, dt, GbCa, GbKCa, T, CiCa, convert_micro): volt_to_milli = 1 volt_to_milli2 = 1 convert_milli = 0.001 # valance zCl = -1 zNa = 1 zCa = 2 zK = 1 # area Sb = 1 Sa = 20 St = Sb # capacitance F/m^2 C = 0.01 # Concentration milli Mol/l ClNa = 15 # * convert_milli ClK = 0.000000001 # * convert_milli ClCa = 0.000000001 # * convert_milli CcNa = 5 # * convert_milli CcK = 150 # * convert_milli CcCa = 0.01 # * convert_milli CiNa = 150 # * convert_milli CiK = 5 # * convert_milli # Permeability m/s PaNa = 1 * (10 ** -11) PaK = 9.8 * (10 ** -11) PaCl = 0.5 * (10 ** -11) PbNa = 2 * (10 ** -9) PbK = 5 * (10 ** -9) PbCl = 1 * (10 ** -9) PtNa = 5 * (10 ** -11) PtK = 5 * (10 ** -11) PtCl = 5 * (10 ** -11) # Initialize Fa = x[0] Fb = x[1] mc = x[2] CcCaI = x[3] # CcCa = CcCaI C0 = x[4] C1 = x[5] C2 = x[6] O2 = x[7] O3 = x[8] V = x[9] m = x[10] h = x[11] n = x[12] # concentration S/m^2 ClCl = ClNa + ClK + 2 * ClCa CcCl = CcNa + CcK + 2 * CcCa CiCl = CiNa + CiK + 2 * CiCa # area ra = Sa / Sb rt = St / Sb Rat = ra + rt + rt * ra Ft = Fa - Fb # The equillimbium potential: V FaNa = ((R * T) / (zNa * F)) * np.log(ClNa / CcNa) * volt_to_milli FaK = ((R * T) / (zK * F)) * np.log(ClK / CcK) * volt_to_milli FaCl = ((R * T) / (zCl * F)) * np.log(ClCl / CcCl) * volt_to_milli FbNa = ((R * T) / (zNa * F)) * np.log(CiNa / CcNa) * 
volt_to_milli FbK = ((R * T) / (zK * F)) * np.log(CiK / CcK) * volt_to_milli FbCa = ((R * T) / (zCa * F)) * np.log(CiCa / CcCa) * volt_to_milli FbCl = ((R * T) / (zCl * F)) * np.log(CiCl / CcCl) * volt_to_milli FtNa = ((R * T) / (zNa * F)) * np.log(ClNa / CiNa) * volt_to_milli FtK = ((R * T) / (zK * F)) * np.log(ClK / CiK) * volt_to_milli FtCl = ((R * T) / (zCl * F)) * np.log(ClCl / CiCl) * volt_to_milli # equation 14 : Conductivity leaky channels S/m^2 etaa = (F * Fa / volt_to_milli) / (R * T) # no unit, Fa has to be in Volt for cancellation faNa = ((zNa ** 2) * (F ** 2) * Fa * PaNa * (ClNa - CcNa * np.exp(zNa * etaa))) / ( (R * T * (Fa - FaNa)) * (1 - np.exp(zNa * etaa))) faK = ((zK ** 2) * (F ** 2) * Fa * PaK * (ClK - CcK * np.exp(zK * etaa))) / ( (R * T * (Fa - FaK)) * (1 - np.exp(zK * etaa))) faCl = ((zCl ** 2) * (F ** 2) * Fa * PaCl * (ClCl - CcCl * np.exp(zCl * etaa))) / ( (R * T * (Fa - FaCl)) * (1 - np.exp(zCl * etaa))) etat = (F * Ft / volt_to_milli) / (R * T) ftNa = ((zNa ** 2) * (F ** 2) * Ft * PtNa * (ClNa - CiNa * np.exp(zNa * etat))) / ( (R * T * (Ft - FtNa)) * (1 - np.exp(zNa * etat))) ftK = ((zK ** 2) * (F ** 2) * Ft * PtK * (ClK - CiK * np.exp(zK * etat))) / ( (R * T * (Ft - FtK)) * (1 - np.exp(zK * etat))) ftCl = ((zCl ** 2) * (F ** 2) * Ft * PtCl * (ClCl - CiCl * np.exp(zCl * etat))) / ( (R * T * (Ft - FtCl)) * (1 - np.exp(zCl * etat))) etab = (F * Fb / volt_to_milli) / (R * T) fbNa = ((zNa ** 2) * (F ** 2) * Fb * PbNa * (CiNa - CcNa * np.exp(zNa * etab))) / ( (R * T * (Fb - FbNa)) * (1 - np.exp(zNa * etab))) fbK = ((zK ** 2) * (F ** 2) * Fb * PbK * (CiK - CcK * np.exp(zK * etab))) / ( (R * T * (Fb - FbK)) * (1 - np.exp(zK * etab))) fbCl = ((zCl ** 2) * (F ** 2) * Fb * PbCl * (CiCl - CcCl * np.exp(zCl * etab))) / ( (R * T * (Fb - FbCl)) * (1 - np.exp(zCl * etab))) # A5,6,7,8: Conductivity K and Ca k_1 = 6 * (10 ** 3) k_2 = 100 * (10 ** 3) k_3 = 30 * (10 ** 3) betac = 1000 delta1 = 0.2 delta2 = 0 delta3 = 0.2 K10 = 6 * (10 ** -3) # mmol/l K20 = 
45 * (10 ** -3) # mmol/l K30 = 20 * (10 ** -3) # mmol/l Valpha = 33 * (10 ** -3) * volt_to_milli # V alphac0 = 450 # -s Ks = 28000 # -s U = 0.2 k1 = (k_1 / K10) * np.exp( (-2 * delta1 * F * Fb / volt_to_milli) / (R * T)) # Fb has to be in V for cancellation k2 = (k_2 / K20) * np.exp((-2 * delta2 * F * Fb / volt_to_milli) / (R * T)) k3 = (k_3 / K30) * np.exp((-2 * delta3 * F * Fb / volt_to_milli) / (R * T)) alphac = alphac0 * np.exp(-Fb / Valpha) GbCa_var = GbCa * (mc ** 3) dC0 = k_1 * C1 - k1 * CcCaI * C0 dC1 = k1 * CcCaI * C0 + k_2 * C2 - (k_1 + k2 * CcCaI) * C1 dC2 = k2 * CcCaI * C1 + alphac * O2 - (k_2 + betac) * C2 dO2 = betac * C2 + k_3 * O3 - (alphac + k3 * CcCaI) * O2 dO3 = k3 * CcCaI * O2 - k_3 * O3 GbKCa_var = GbKCa * (O2 + O3) # GbKCa_var = GbKCa_ * ((np.abs(O2) + np.abs(O3))/(np.abs(C0)+np.abs(C1)+np.abs(C2)+np.abs(O2)+np.abs(O3))) Ga = faNa * (Fa - FaNa) + faK * (Fa - FaK) + faCl * ( Fa - FaCl) Gb = fbNa * (Fb - FbNa) + fbCl * ( Fb - FbCl) + fbK * (Fb - FbK) + GbCa_var * (Fb - FbCa) + GbKCa_var * ( Fb + FbK) # print('Na:') Gt = ftNa * (Fa - Fb - FtNa) + ftK * (Fa - Fb - FtK) + ftCl * ( Fa - Fb - FtCl) try: stimulus1 = stimulus[int(np.round(t / dt))] * convert_micro * 10000 * volt_to_milli except: stimulus1 = stimulus[0] * convert_micro * 10000 * volt_to_milli print('error in indexing') print(int(np.round(t / dt))) dFa = (1 / (C * Rat)) * (-(ra + rt) * stimulus1 - ra * ( 1 + rt) * Ga - rt * Gb - rt * Gt) dFb = (1 / (C * Rat)) * (ra * (ra + rt) * stimulus1 - ra * rt * Ga - ( ra + rt) * Gb + ra * rt * Gt) dFb_rest = (1 / (C * Rat)) * (- ra * rt * Ga - ( ra + rt) * Gb + ra * rt * Gt) V0 = 70 * convert_milli * volt_to_milli Vb = 6.17 * convert_milli * volt_to_milli beta0 = 0.97 Kb = 940 alpha0 = 22800 Va = 8.01 * convert_milli * volt_to_milli Ka = 510 beta = beta0 * np.exp((Fb + V0) / Vb) + Kb # s^-1 alpha = alpha0 * np.exp(-(Fb + V0) / Va) + Ka # s^-1 dmc = beta * (1 - mc) - alpha * mc # s^-1 IbCa = GbCa * (mc ** 3) * (Fb - FbCa) if IbCa == -0: IbCa = 0 l = 
25 * (10 ** -6) # m Xi = 3.4 * (10 ** -6) # fraction of the volume that aborbs Ca the bigger the fraction the smaller Ca should be const = U / (2 * Xi * l * F) # * 1000 dCcCa = const * IbCa - Ks * CcCaI Cn = 0.01 # S/m^2 gNa_ = 1200 # S/m^2 gK_ = 400 # S/m^2 gl = 2.4 # S/m^2 VNa = 0.056 * volt_to_milli2 # V VK = -0.093 * volt_to_milli2 # V Vl = -0.03 * volt_to_milli2 # V w = 4.7 eta = 0.150 * volt_to_milli2 # A/m^2 e = 0.050 * volt_to_milli2 # A/m^2 Ips = (w / (1 + np.exp(-(np.abs(IbCa * volt_to_milli2) - eta) / e))) * volt_to_milli2 dV = (-gNa_ * (m ** 3) * h * (V - VNa) - gK_ * (n ** 4) * (V - VK) - gl * (V - Vl) + Ips) / Cn # Volt betahinf = 1.8 # *(10**3)#ms mV = V * 1000 / volt_to_milli2 # mv beta_ending = 1000 # tranfer ms and mv back in v and s alpham = (-0.1 * (mV + 40) / (np.exp(-(mV + 40) / 10)) - 1) * beta_ending # s betam = (4 * np.exp(-(mV + 65) / 18)) * beta_ending # s alphah = (0.07 * np.exp(-(mV + 65) / 20)) * beta_ending # s betah = (betahinf / (np.exp(-(mV + 35) / 10) + 1)) * beta_ending # s alphan = (-0.01 * (mV + 55) / (np.exp(-(mV + 55) / 10) - 1)) * beta_ending # s betan = (0.125 * np.exp(-(mV + 65) / 80)) dm = alpham - (alpham + betam) * m # s^-1 dh = alphah - (alphah + betah) * h # s^-1 dn = alphan - (alphan + betan) * n # s^-1 vars = (dFa, dFb, dmc, dCcCa, dC0, dC1, dC2, dO2, dO3, dV, dm, dh, dn, dFb_rest) return vars convert_mili = 0.001 convert_micro = 0.000001 R = 8.31446261815324 F = 96485.3329 x = [[]] * 14 x[0] = -0.03 # *volt_to_milli # Fa = x[1] = -0.050 # *volt_to_milli # = Fb x[2] = 0 # = mc x[3] = 0.01 * (10 ** -3) # = CcCa x[4] = 0.05 # = C0 x[5] = 0.05 # = C1 x[6] = 0.1 # = C2 x[7] = 0.4 # = O2 x[8] = 0.4 # = O3 x[9] = -0.08 # *volt_to_milli # = V x[10] = 1 # = m x[11] = 0 # = h x[12] = 0.5 # = n x[13] = -0.08 CiCa = 1 # *(10**-3) #milli mol/l T = 298.5 # K in Koshimori us = odeint(func1, x, time, args=( R, F, stimulus, np.abs(time[0] - time[1]), GbCa_, GbKCa_, T, CiCa, convert_mili, convert_micro)) # ,hmin = 
np.abs(time[0]-time[1]),hmax = np.abs(time[0]-time[1]) mc = us[:, 2] Fb = us[:, 1] CcCa = us[:, 3] zCa = 2 FbCa = ((R * T) / (zCa * F)) * np.log(CiCa / CcCa) IbCa = GbCa_ * (mc ** 3) * (Fb - FbCa) return np.std(IbCa), np.abs(np.min(IbCa)) - np.abs(np.max(IbCa)), IbCa, us def damping_hundspet(mechanical, stimulus, time, damping_type): convert_SI = 1 convert_SI_minus = 1 / convert_SI convert_mili = 0.001 * convert_SI convert_micro = 0.000001 * convert_SI convert_nano = 0.000000001 * convert_SI convert_pico = 0.000000000001 * convert_SI thousand = 1 x = [[]] * 8 x[0] = -0.05 # 30 * convert_mili # = Fb x[1] = 0.27 # = mc x[2] = 0.1 # * convert_mili * thousand # = CcCa x[3] = 0.05 # = C0 x[4] = 0.05 # = C1 x[5] = 0.1 # = C2 x[6] = 0.4 # = O2 x[7] = 0.4 # = O3 @jit() # (nopython=True) def func1(x, t, R, damping_type, F, convert_SI_minus, stimulus, dt, T, convert_milli, convert_micro, convert_nano, convert_pico, mechanical): # Initialize Fb = x[0] mc = x[1] CcCaI = x[2] C0 = x[3] C1 = x[4] C2 = x[5] O2 = x[6] O3 = x[7] GbKCa_ = 16.8 * convert_nano # S FbK = -80 * convert_milli # V FbCa = 100 * convert_milli # V GbCa_ = 4.14 * convert_nano # S # same parameter for kinetic sheme k_1 = 300 * convert_SI_minus # s^-1 k_2 = 5000 * convert_SI_minus # s^-1 k_3 = 1500 * convert_SI_minus # s^-1 betac = 1000 * convert_SI_minus # s^-1 alphac0 = 450 * convert_SI_minus # s^-1 delta1 = 0.2 delta2 = 0 delta3 = 0.2 K10 = 6 * convert_micro # mol/l K20 = 45 * convert_micro # mol/l K30 = 20 * convert_micro # mol/l Valpha = 33 * convert_milli # V alphac = alphac0 * np.exp(-Fb / Valpha) # s^-1 GbCa = GbCa_ * (mc ** 3) # from kashimori: S if 'nieman' in damping_type: K1 = (K10 * np.exp((-1 * delta1 * F * Fb) / (R * T))) k1 = (k_1 / K1) K2 = (K20 * np.exp((-1 * delta2 * F * Fb) / (R * T))) k2 = (k_2 / K2) K3 = (K30 * np.exp((-1 * delta3 * F * Fb) / (R * T))) k3 = (k_3 / K3) GbKCa = GbKCa_ * (O2 + O3) dC0 = 0 C0 = 1 - (C1 + C2 + O2 + O3) else: K1 = (K10 * np.exp((-2 * delta1 * F * Fb) / (R * T))) 
# mol/l k1 = (k_1 / K1) # l/(s*mol) K2 = (K20 * np.exp((-2 * delta2 * F * Fb) / (R * T))) # mol/l k2 = (k_2 / K2) # l/(s*mol) K3 = (K30 * np.exp((-2 * delta3 * F * Fb) / (R * T))) # mol/l k3 = (k_3 / K3) # l/(s*mol) dC0 = k_1 * C1 - k1 * CcCaI * C0 # CcCaI here is mol/l GbKCa = GbKCa_ * (O2 + O3) # S dC1 = k1 * CcCaI * C0 + k_2 * C2 - (k_1 + k2 * CcCaI) * C1 dC2 = k2 * CcCaI * C1 + alphac * O2 - (k_2 + betac) * C2 dO2 = betac * C2 + k_3 * O3 - (alphac + k3 * CcCaI) * O2 dO3 = k3 * CcCaI * O2 - k_3 * O3 try: stimulus1 = stimulus[int(np.round(t / dt))] except: stimulus1 = stimulus[0] print('error in indexing') FbL = -30 * convert_milli # V fbL = 1 * convert_nano # S C = 15 * convert_pico # F x = 20 * convert_nano # m G1 = 0.75 * 1000 # kcal/mol = 1000 cal/mol G2 = 0.25 * 1000 # kcal/mol = 1000 cal/mol Z1 = 10 * 1000 * 10 ** 6 # kcal/mol microm = 1000 10**6 cal/mol* m Z2 = 2 * 1000 * 10 ** 6 # kcal/mol microm = 1000 10**6 cal/mol* m A = (G1 - Z1 * x) / (R * T) B = (G2 - Z2 * x) / (R * T) Fbt = 0 * convert_milli # V gt_ = 3 * convert_nano # S gt = gt_ / (1 + np.exp(B) * (1 + np.exp(A))) # S if mechanical == 'mechanical': mechanical = gt * (Fb - Fbt) else: mechanical = 0 dFb = -(fbL * (Fb - FbL) + GbCa * (Fb - FbCa) + GbKCa * ( Fb - FbK) + mechanical - stimulus1) / C # +fbCa*(Pb -PbCa) alpha0 = 22800 * convert_SI_minus # s^-1 V0 = 70 * convert_milli # V VA = 8.01 * convert_milli # V Ka = 510 * convert_SI_minus # s^-1 beta0 = 0.97 * convert_SI_minus # s^-1 VB = 6.17 * convert_milli # V Kb = 940 * convert_SI_minus # s^-1 beta = beta0 * np.exp((Fb + V0) / VB) + Kb alpha = alpha0 * np.exp(-(Fb + V0) / VA) + Ka dm = beta * (1 - mc) - alpha * mc IbCa = GbCa_ * (mc ** 3) * (Fb - FbCa) # A if IbCa == -0: IbCa = 0 U = 0.02 Xi = 3.4 * ( 10 ** -5) # fraction of volume that binds Ca, the bigger this is the smaller should be Ca increase therefore I guess this factor is necessary Cvol = 1.25 * convert_pico # l Ks = 2800 * convert_SI_minus ##s^-1 if 'nieman' in damping_type: const = 
-0.00061 else: const = (U / (2 * F * Cvol * Xi)) dCcCa = (const * IbCa) - Ks * CcCaI # concentration is mol/l vars = (dFb, dm, dCcCa, dC0, dC1, dC2, dO2, dO3) return vars R = 8.31446261815324 F = 96485.336 CiCa = 1 * convert_mili * thousand T = 298.5 stimulus = stimulus * 20 * convert_pico us = odeint(func1, x, time, args=( R, damping_type, F, convert_SI_minus, convert_SI, thousand, stimulus, np.abs(time[0] - time[1]), T, CiCa, convert_mili, convert_micro, convert_nano, convert_pico, mechanical), full_output=True) # ,hmin = np.abs(time[0]-time[1]),hmax = np.abs(time[0]-time[1]) # full_output = True #,hmin = np.abs(time[0]-time[1])/10,hmax = np.abs(time[0]-time[1])*10 # return us def all_damping_variants(stimulus, time_array, damping_type='', eod_fr=750, damping_gain=1, damping='', damping_element='', damping_output=[], plot=False, std_dump=0, max_dump=0, range_dump=0): # function: here you can choose the way of dumping but in realitiy only the damping_type == 'damping' is properly done, # the rest would need further modifications if damping_type == 'damping': damping_output, stimulus, std_dump, max_dump, range_dump = damping_func(stimulus, time_array, eod_fr, damping, damping_element, damping_gain) elif (damping_type == 'damping_hundspeth') or (damping_type == 'damping_nieman'): std_dump, extent, IbCa, damping_output = damping_hundspet(stimulus, time_array, damping_type) elif damping_type == 'damping_kashimori': std_dump, extent, IbCa, damping_output = damping_kashimori(stimulus, time_array) if plot == True: from utils_test import plot_kashimori plot_kashimori(stimulus, time_array, damping_output, IbCa) elif damping_type == 'damping_huxley': huxley() return std_dump, max_dump, range_dump, stimulus, damping_output, def a2(): pass def do_array_for_three(nfft, a_fr, fish_receiver, beat, zeros, cell_recording, fish_emitter, fish_jammer, fish_morph_harmonics_var, mimick, nfft_for_morph, phase_right, sampling, eod_fish2, SAM, eod_stimulus, eod_fish_r, freq2, a_f1, a_f2, 
cell, stimulus_length, offset, model_params, n, adapt_offset, deltat, f2, trials_nr, time_array, f1, freq1, stimulus, phaseshift_fr=0, variant='sinz', eod_fr=750, length_adapt=True, dict_here=[], redo_stim=False, stim_type='01', reshuffle='reshuffled', dev=0.0005, damping=''): spikes = [[]] * trials_nr spikes_bef = [[]] * trials_nr for t in range(trials_nr): if (t == 0) | (type(phaseshift_fr) == str): if redo_stim: eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr, stimulus_length, phaseshift_fr, cell_recording, zeros, mimick, sampling, fish_receiver, deltat, nfft, nfft_for_morph, fish_morph_harmonics_var=fish_morph_harmonics_var, beat=beat) eod_fish1, eod_fish2, eod_stimulus, t1 = stimulus_threefish(a_f1, a_f2, cell_recording, f1, f2, fish_emitter, fish_jammer, freq1, freq2, fish_morph_harmonics_var, mimick, nfft_for_morph, phase_right, phaseshift_fr, sampling, stimulus_length, time_array, zeros) # if we need new stimulus each time we generate it here each time if stim_type == '02': stimulus, eod_fish_sam = create_stimulus_SAM(SAM, eod_fish2, eod_fish_r, freq2, f2, eod_fr, time_array, a_f2, eod_fj=freq2, j=f2, a_fj=a_f2, ) # three='Three' elif stim_type == '012': stimulus, eod_fish_sam = create_stimulus_SAM(SAM, eod_stimulus, eod_fish_r, freq1, f1, eod_fr, time_array, a_f1, eod_fj=freq2, j=f2, a_fj=a_f2, three='Three') # SAM, eod_stimulus, eod_fish_r,freq2,a_f1, a_f2 elif stim_type == '01': stimulus, eod_fish_sam = create_stimulus_SAM(SAM, dict_here['eod_fish1'], eod_fish_r, freq1, f1, eod_fr, time_array, a_f1, eod_fj=freq1, j=f2, a_fj=a_f2, ) # three='Three' # damping variants if damping != '': embed() std_dump, max_dump, range_dump, stimulus, damping_output = all_damping_variants(stimulus, time_array, damping_type, eod_fr, damping_gain, damping, damping_variant, plot=False, std_dump=0, max_dump=0, range_dump=0) cvs, adapt_output, baseline_after, _, rate_adapted, rate_baseline_before, rate_baseline_after, spikes_bef[t], \ 
stimulus_altered, \ v_dent_output, offset_new, v_mem_output, noise_final = simulate(cell, offset, stimulus, deltat=deltat, adaptation_variant=adapt_offset, adaptation_yes_j=f2, adaptation_yes_e=f1, adaptation_yes_t=t, power_variant=variant, power_nr=n, reshuffle=reshuffle, **model_params) isi = calc_isi(spikes_bef[t], eod_fr) spikes[t] = spikes_after_burst_corr(spikes_bef[t], isi, dict_here['burst_corr'], cell, eod_fr, model_params=model_params) if length_adapt == False: spikes_mat = [[]] * len(spikes) for s in range(len(spikes)): spikes_mat[s] = cr_spikes_mat(spikes[s], 1 / deltat, int(stimulus_length * 1 / deltat)) else: spikes_mat = spikes_mat_depending_on_length(spikes, deltat, stimulus_length) sampling_rate = 1 / deltat if dev != 'original': smoothed = gaussian_filter(spikes_mat, sigma=dev * sampling_rate) else: smoothed = spikes_mat mean_smoothed = np.mean(smoothed, axis=0) return stimulus, mean_smoothed, spikes, smoothed, spikes_mat, offset_new, v_mem_output def stimulus_threefish(a_f1, a_f2, cell_recording, f1, f2, fish_emitter, fish_jammer, freq1, freq2, fish_morph_harmonics_var, mimick, nfft_for_morph, phase_right, phaseshift_fr, sampling, stimulus_length, time_array, zeros): t1 = time.time() phaseshift_f1, phaseshift_f2 = get_phaseshifts(a_f1, a_f2, phase_right, phaseshift_fr) if phaseshift_fr == 'randALL': phaseshift_f1 = np.random.rand() * 2 * np.pi phaseshift_f2 = np.random.rand() * 2 * np.pi eod_fish1, time_fish_e = eod_fish_e_generation(time_array, a_f1, freq1, f1, phaseshift_f1, sampling, stimulus_length, nfft_for_morph, cell_recording, fish_morph_harmonics_var, zeros, mimick, fish_emitter, thistype='emitter') eod_fish2, time_fish_j = eod_fish_e_generation(time_array, a_f2, freq2, f2, phaseshift_f2, sampling, stimulus_length, nfft_for_morph, cell_recording, fish_morph_harmonics_var, zeros, mimick, fish_jammer, thistype='jammer') eod_stimulus = eod_fish1 + eod_fish2 return eod_fish1, eod_fish2, eod_stimulus, t1 def 
check_peak_overlap_only_stim_big_final(stimulus_length_data=0.5, dev=0.001, reshuffled='reshuffled', printing=False, show=False, beat='', nfft_for_morph=4096 * 4, gain=1, fish_jammer='Alepto', us_name=''): runs = 1 n = 1 default_settings() # ts=13, ls=13, fs=13, lw = 0.7 # extra combination with female small # standard combination with intruder small a_f2s = [0.1] min_amps = '_minamps_' dev_name = '05' model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv") a_fr = 1 a = 0 trials_nrs = [1] datapoints = 1000 results_diff = pd.DataFrame() position_diff = 0 default_settings(column=2, length=8.5) for trials_nr in trials_nrs: # +[trials_nrs[-1]] # sachen die ich variieren will ########################################### auci_wo = [] auci_w = [] nfft = 32768 # 2**16##6#32768#2**12#32768 cells_here = ['2012-06-27-an-invivo-1'] for cell_here in cells_here: full_names = [ 'calc_model_amp_freqs-F1_250-1325-25_F2_500-525-25_C2_0.1_C1Len_25_FirstC1_0.0001_LastC1_1.0_mult__StimLen_5_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_not_log_temporal'] c_grouped = ['c1'] # , 'c2'] # adds = [-150, -50, -10, 10, 50, 150] # fig, ax = plt.subplots(4, len(adds), constrained_layout=True, figsize=(12, 5.5)) frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv') frame_cell_orig = frame[(frame.cell == cell_here)] frame_cell_orig_orignal = frame[(frame.cell == cell_here)] if len(frame_cell_orig) > 0: # (135.5, 625.0), (110.5, 650.0), (85.5, 675.0),(60.5, 700.0), (35.5, 725.0), (10.5, 750.0),(151.07000000000005, 675.0) new_f2_tuple = frame_cell_orig[['df2', 'f2']].apply(tuple, 1).unique() dfs = [tup[0] for tup in new_f2_tuple] sorted = np.argsort(np.abs(dfs)) new_f2_tuple = new_f2_tuple[sorted] frame_cell = frame[(frame.cell == cell_here)] # & (frame[c_here] == c_h)] frame_cell, df1s, df2s, f1s, f2s = find_dfs(frame_cell) diffs = find_deltas(frame_cell, c_grouped[0]) frame_cell = find_diffs(c_grouped[0], 
frame_cell, diffs, add='_original') new_frame = frame_cell.groupby(['df1', 'df2'], as_index=False).sum() # ['score'] freq1s = np.unique(new_frame.df1) freq2s = np.unique(new_frame.df2) freq_example = 30 # 65 freq1s = [freq1s[np.argmin(np.abs(freq1s - freq_example))]] freq2s = [freq2s[0]] else: freq_example = 30 # 65 freq1s = [freq_example] freq2s = [10] for freq1 in freq1s: for freq2 in freq2s: c_nrs = [0.0002, 0.2, 0.8] # 0.0002, , 0.50.01,0.075,0.1, grid0 = gridspec.GridSpec(1, 1, bottom=0.1, top=0.85, left=0.09, right=0.95, wspace=0.04) # grid00 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.15, hspace=0.27, subplot_spec=grid0[0], ) # grid_u = gridspec.GridSpecFromSubplotSpec(1, 1, hspace=0.7, wspace=0.3, subplot_spec=grid00[ 0]) # hspace=0.4,wspace=0.2,len(chirps) grid_l = gridspec.GridSpecFromSubplotSpec(1, 1, hspace=0.7, wspace=0.1, subplot_spec=grid00[ 1]) # hspace=0.4,wspace=0.2,len(chirps) if len(frame_cell_orig) > 0: # da implementiere ich das jetzt für eine Zelle # wo wir den einezlnen Punkt und Kontraste variieren c_here = 'c1' if c_here == 'c1': # 'B1_diff' pass if c_here == 'c2': # 'B1_diff' pass f_counter = 0 ax_upper = [] frame_cell_orig_orignal, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig_orignal) frame_cell_orig, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig) eodf = frame_cell_orig.f0.unique()[0] f = -1 axts_all = [] axps_all = [] axis_all = [] f += 1 # plot the baseline Peak above linestyles = [['-', '-', '-', '-'], ['-', '-', '-', '-'] , ['--', '-.', '--'], ['-', '-', '-', '-'], ] scores = [['amp_B1_01_mean_original', 'amp_f1_01_mean_original', 'amp_f0_01_mean_original'], ] # ['cv_01'] colors = [['green', 'pink', 'black', 'red'], ['grey'], ] alpha = [[1, 0.5, 1, 1], [0.5, 0.5, 0.5, 0.5], [1, 1, 1, 1], [1, 1, 1, 1], ] axs = [] for s in range(len(scores)): ax_u1 = plt.subplot(grid_u[s]) ax_upper = plt_single_trace(ax_upper, ax_u1, frame_cell_orig, freq1, freq2, alpha=alpha[s], sum=False, linestyles=linestyles[s], scores=scores[s], 
colors=colors[s]) ax_u1.legend(loc=(0, 1), ncol=3) if 'cv' not in scores[s][0]: axs.append(ax_u1) join_x(axs) join_x(axs) join_y(axs) # frame_cell_orig_orignal for ax in ax_upper: ax.scatter(c_nrs, np.zeros(len(c_nrs)), marker='^', color='black') grid_ll = gridspec.GridSpecFromSubplotSpec(1, len(c_nrs), hspace=0.55, wspace=0.2, subplot_spec=grid_l[ f]) # hspace=0.4,wspace=0.2,len(chirps) colors_array = ['grey', 'green', 'orange', 'purple'] print(cell_here + ' F1' + str(freq1) + ' F2 ' + str(freq2)) sampling = 20000 cvs = pd.read_csv(load_folder_name('calc_base') + '/csv_model_data.csv') cv = np.round(cvs[cvs['cell'] == cell_here].cv_model.iloc[0], 3) fr = np.round(cvs[cvs['cell'] == cell_here].fr_model.iloc[0]) plt.suptitle(cell_here + ' EODf=' + str(np.round(eodf)) + ' Hz' + ' cv=' + str(cv) + ' fr=' + str( np.round(fr)) + 'Hz F1=' + str(np.round(freq1 + eodf)) + ' Hz' + ' F1-EODf=' + str( np.round(freq1)) + ' Hz' + ' F2=' + str(np.round(freq2 + eodf)) + ' Hz ' + 'F2-EODf=' + str( np.round(freq2)) + ' Hz ') axts = [] axps = [] axis = [] for c_nn, c_nr in enumerate(c_nrs): f_counter = plt_single_contrast_ps_isi_01(axis, f_counter, axts, axps, f, grid_ll, c_nn, freq_example, freq2, eodf, datapoints, auci_wo, auci_w, results_diff, a_f2s, fish_jammer, a, trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing, stimulus_length_data, model_cells, position_diff, colors_array, reshuffled, dev, sampling, cell_here, c_nr, n, [dev_name], min_amps, extend=True, first=False, ypos=1.65, c_nn_nr=0, xpos=1, second=False) axts_all.extend(axts) axps_all.extend(axps) axis_all.extend(axis) axts_all[0].get_shared_y_axes().join(*axts_all) axts_all[0].get_shared_x_axes().join(*axts_all) axps_all[0].get_shared_y_axes().join(*axps_all) axps_all[0].get_shared_x_axes().join(*axps_all) axis_all[0].get_shared_x_axes().join(*axis_all) save_visualization(cell_here + '_freq1_' + str(freq1) + '_freq2_' + str(freq2), show) def plt_single_contrast_ps_isi_01(axis, f_counter, axts, 
axps, f, grid_ll, c_nn, freq1, freq2, eodf, datapoints, auci_wo, auci_w, results_diff, a_f2s, fish_jammer, a, trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing, stimulus_length, model_cells, position_diff, colors_array, reshuffled, dev, sampling, cell_here, c_nr, n, dev_name, min_amps, extend=True, c_nn_nr=1, first=True, xpos=1, second=True, val=1.5, ypos=1.35): v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays_05, names, p_arrays, ff = calc_roc_amp_core_cocktail( [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff, a_f2s, fish_jammer, trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing, stimulus_length, model_cells, position_diff, dev, cell_here, dev_name=dev_name, a_f1s=[c_nr], n=n, reshuffled=reshuffled, min_amps=min_amps) v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays_original, names, p_arrays, ff = calc_roc_amp_core_cocktail( [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff, a_f2s, fish_jammer, trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing, stimulus_length, model_cells, position_diff, 'original', cell_here, dev_name=dev_name, a_f1s=[c_nr], n=n, reshuffled=reshuffled, min_amps=min_amps) time = np.arange(0, len(arrays_05[a][0]) / sampling, 1 / sampling) arrays_here = [v_mems[1]] #::]#arrays_05[1::]#arrays_original[1::]# arrays_here_original = [arrays_original[1]] #::] spike_here = [arrays_spikes[1]] #::] stim_here = [arrays_stim[1]] #::] names = ['0', '01', '02', '012'] names_here = [names[1]] # extend=True for a in range(len(arrays_here)): grid_pt = gridspec.GridSpecFromSubplotSpec(4, 1, hspace=0.65, wspace=0.2, subplot_spec=grid_ll[c_nn]) # hspace=0.4,wspace=0.2,len(chirps) axs = plt.subplot(grid_pt[0]) axt = plt.subplot(grid_pt[1]) axp = plt.subplot(grid_pt[2]) axi = plt.subplot(grid_pt[3]) axts.append(axt) axps.append(axp) axis.append(axi) if f != 0: 
remove_yticks(axt) remove_yticks(axs) if a != len(arrays_here) - 1: remove_xticks(axt) remove_xticks(axs) if f_counter == 0: axt.set_ylabel(names[a]) axs.set_ylabel(names[a]) if a == 0: axs.set_title(' a1=' + str(c_nr) + ', a2=0') elif a == 1: axs.set_title(' a1=0,' + ' a2=' + str(a_f2s[0])) else: axs.set_title(' a1=' + str(c_nr) + ', a2=' + str(a_f2s[0])) xlim = [0.1, 0.1 + val / freq1] axs.plot(time, stim_here[a], color='grey') # color=colors_array_here[a],colors_contrasts[c_nn] axs.eventplot(spike_here[a][0], lineoffsets=np.mean(stim_here[a]), color='black') # np.max(v1)* axs.set_xlim(xlim) axt.plot(time, arrays_here[a], color='grey') # colors_array_here[a]colors_contrasts[c_nn] axt.eventplot(spike_here[a][0], lineoffsets=np.max(arrays_here[a]), color='black') # np.max(v1)* axt.set_xlim(xlim) # 1.5 axi.hist(np.diff(spike_here[a][0]) / (1 / eodf), bins=np.arange(0, np.max(np.diff(spike_here[a][0]) / (1 / eodf)), 0.1), color='grey') # colors_array_here[a] axi.axvline(x=1, color='black', linestyle='--', linewidth=0.5, zorder=100) # color = 'grey', axi.axvline(x=2, color='black', linestyle='--', linewidth=0.5, zorder=100) # color = 'grey', axi.axvline(x=3, color='black', linestyle='--', linewidth=0.5, zorder=100) # color = 'grey', axi.axvline(x=4, color='black', linestyle='--', linewidth=0.5, zorder=100) # color = 'grey', try: axi.set_xticks_delta(1) except: print('problem something') axi.set_xlim(0, 8) pp, ff = ml.psd(arrays_here_original[a][0] - np.mean(arrays_here_original[a][0]), Fs=sampling, NFFT=nfft, noverlap=nfft // 2) axp.plot(ff, pp, color='grey') # colors_contrasts[c_nn]#colors_array_here[a] maxx = 900 axp.set_xlim(0, maxx) if c_nn == c_nn_nr: if a == 0: # if second: second_part = 'F1=' + str(np.round(freq1 + eodf)) + 'Hz' + ' F1-EODf=' + str( freq1) + 'Hz' else: second_part = '' if first: first_part = 'only Frequency 1: ' else: first_part = '' axt.text(xpos, ypos, first_part + second_part, fontweight='bold', ha='center', fontsize=10, 
transform=axt.transAxes, ) elif a == 1: if second: second_part = 'F2=' + str(np.round(freq2 + eodf)) + 'Hz ' + 'F2-EODf=' + str( freq2) + ' Hz ' else: second_part = '' if first: first_part = 'only Frequency 2: ' else: first_part = '' axt.text(xpos, ypos, first_part + second_part, fontweight='bold', ha='center', fontsize=10, transform=axt.transAxes, ) else: if second: second_part = 'F1=' + str(np.round(freq1 + eodf)) + 'Hz' + ' F1-EODf=' + str( freq1) + 'Hz' + ' F2=' + str(freq2 + eodf) + 'Hz ' + 'F2-EODf=' + str(freq2) + ' Hz ' else: second_part = '' if first: first_part = 'Frequency 1 + Frequency 2: ' else: first_part = '' axt.text(xpos, ypos, first_part + second_part, fontweight='bold', ha='center', fontsize=10, transform=axt.transAxes, ) freqs, colors_peaks, labels, alphas = chose_all_freq_combos(freq2, colors_array, freq1, maxx, eodf, color_eodf='black', name=names_here[0], color_stim='pink', color_stim_mult='pink') plt_peaks_several(freqs, [pp], axp, pp, ff, labels, 0, colors_peaks, alphas=alphas, extend=extend, ms=18, clip_on=True) if c_nn != 0: remove_yticks(axt) remove_yticks(axs) remove_yticks(axp) remove_yticks(axi) else: axt.set_ylabel('mV') axp.set_ylabel('Hz') axi.set_ylabel('Nr') axt.set_xlabel('Time [s]') axi.set_xlabel('EOD mult') remove_xticks(axs) axp.set_xlabel('Frequency [Hz]') f_counter += 1 return f_counter def chose_all_freq_combos(freq2, colors_array, freq1, maxx, eodf, color_eodf='blue', stim_thing=True, color_stim='orange', name='01', color_stim_mult='orange'): if name == '01': alphas, labels, colors_peaks, freqs = mult_beat_freqs(eodf, maxx, np.abs(freq1), color_df_mult=colors_array[1], color_eodf=color_eodf, color_stim=color_stim, stim_thing=stim_thing, color_stim_mult=color_stim_mult, ) elif name == '02': freqs = [np.abs(freq2), np.abs(freq2) * 2, np.abs(freq2) * 2, np.abs(freq2) * 3, np.abs(freq2) * 4, np.abs(freq2) * 5, np.abs(freq2) * 6, np.abs(freq2) * 7, np.abs(freq2) * 8, np.abs(freq2) * 9, np.abs(freq2) * 10, eodf] colors_peaks = 
[colors_array[2], colors_array[2], colors_array[2], colors_array[2], colors_array[2], colors_array[2], colors_array[2], colors_array[2], colors_array[2], colors_array[2], 'black'] labels = ['DF2', 'DF2', 'DF2', 'DF2', 'DF2', 'DF2', 'DF2', 'DF2', 'DF2', 'DF2', 'DF2', 'DF2', 'DF2', 'DF2', 'DF2', 'DF2', 'DF2', 'DF2', 'DF2', 'EODF'] alphas = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1, 0.5] elif name == 'eodf': freqs = [eodf] colors_peaks = ['black'] labels = ['EODF'] alphas = [1] else: freqs = [freq1, np.abs(freq2), eodf] colors_peaks = ['blue', 'red', 'black'] labels = ['DF1', 'DF2', 'EODF'] alphas = [1, 0.2, 0.5] return freqs, colors_peaks, labels, alphas def find_double_spikes(eod_fr, arrays_spikes, names, results_diff, position_diff, add=''): for a, sp_array in enumerate(arrays_spikes): hist, bin_edges = np.histogram(sp_array[0], bins=np.arange(0, np.max(sp_array[0]), 1 / eod_fr)) hist_big = hist[hist > 0] results_diff.loc[position_diff, 'dsp_perc95_' + names[a] + add] = np.percentile(hist_big, 95) results_diff.loc[position_diff, 'dsp_max_' + names[a] + add] = np.max(hist_big) results_diff.loc[position_diff, 'dsp_mean_' + names[a] + add] = np.mean(hist_big) return results_diff def upper_and_lower_fr(array_smoothed, results_diff, position_diff, eod_fr, names, add=''): lim = 10 for a, array in enumerate(array_smoothed): results_diff.loc[position_diff, 'lb_' + names[a] + add] = len(array[0][array[0] < lim]) / len(array[0]) results_diff.loc[position_diff, 'ub_' + names[a] + add] = len( array[0][(array[0] < eod_fr + lim) & (array[0] > eod_fr - lim)]) / len(array[0]) results_diff.loc[position_diff, 'ub_above_' + names[a] + add] = len( array[0][(array[0] >= eod_fr + lim)]) / len(array[0]) return results_diff def calc_vs_amps(results_diff, stim, eod_fr, arrays_spikes, position_diff, names, add=''): freq_comp = ['_f0', '_f1'] freqs = [eod_fr, stim] for f, ff in enumerate(freqs): for a, array in enumerate(arrays_spikes): vs = 
calc_vectorstrength(array[0], 1 / freqs[f]) results_diff.loc[position_diff, 'vs_' + freq_comp[f] + names[a] + add] = vs[0] vs = calc_vectorstrength(array[0], 1 / freqs[f] * 2) results_diff.loc[position_diff, 'vs_harm_' + freq_comp[f] + names[a] + add] = vs[0] return results_diff def plt_subpart_cocktail(results_diff, fs, p01, p02, p012, p0): fig, ax = plt.subplots(4, 1, sharex=True, sharey=True) # arrays = [p0[0], p01[0], p02[0], p012[0]] for a, array in enumerate(arrays): ax[a].plot(fs, array) B1 = results_diff.df1.iloc[-1] B2 = results_diff.df2.iloc[-1] fr = results_diff.fr.iloc[-1] f0 = results_diff.f0.iloc[-1] freqs = [np.abs(B1), np.abs(B2), np.abs(np.abs(B1) - np.abs(B2)), np.abs(B1) + np.abs(B2), np.mean(fr), f0] colors = ['blue', 'green', 'purple', 'orange', 'red', 'black'] labels = ['DF1', 'DF2', '|DF1-DF2|', '|DF1+DF2|', 'Baseline', 'eod_fr'] plt_peaks_several(freqs, arrays, ax[a], array, fs, labels, 0, colors) ax[-1].legend() plt.show() def calc_roc_amp_core_cocktail(freq1, freq2, datapoints, auci_wo, auci_w, results_diff, a_f2s, fish_jammer, trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing, stimulus_length, model_cells, position_diff, dev, cell_here, params_dict={'burst_corr': ''}, stimulus_length_first=0, p_xlim=0, a_f1s=[], dev_name=['05'], phaseshift_fr=0, min_amps='', n=1, reshuffled='', way_all='', test=False, AUCI='AUCI', phase_right='_phaseright_', SAM='', points=5, means_different=''): # '_means_' model_params = model_cells[model_cells['cell'] == cell_here].iloc[0] eod_fr = model_params['EODf'] # .iloc[0] offset = model_params.pop('v_offset') cell = model_params.pop('cell') print(cell) f1 = 0 f2 = 0 sampling_factor = '' cell_recording = '' mimick = 'no' zeros = 'zeros' fish_morph_harmonics_var = 'harmonic' fish_emitter = 'Alepto' # ['Sternarchella', 'Sternopygus'] fish_receiver = 'Alepto' # adapt_offset = 'adaptoffset_bisecting' lower_tol = 0.995 upper_tol = 1.005 damping = 0.45 # 0.65,0.2,0.5,0.2,0.6,0.45,0.6,0.35 
damping_type = '' exponential = '' time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr, stimulus_length) sampling = 1 / deltat variant = 'sinz' if exponential == '': pass _, _ = prepare_baseline_array(time_array, eod_fr, nfft_for_morph, phaseshift_fr, mimick, zeros, cell_recording, sampling, stimulus_length, fish_receiver, deltat, nfft, damping_type, damping, us_name, gain, beat=beat, fish_morph_harmonics_var=fish_morph_harmonics_var) spikes_base = [[]] * trials_nr spikes_bef = [[]] * trials_nr for run in range(runs): print(run) stim_lengths = [] for t in range(trials_nr): if (stimulus_length_first != 0) & (t == 0): stimulus_length_here = stimulus_length_first else: stimulus_length_here = stimulus_length stim_lengths.append(stimulus_length_here) if (t == 0) | (type(phaseshift_fr) == str): time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr, stimulus_length_here, deltat=deltat) eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr, stimulus_length_here, phaseshift_fr, cell_recording, zeros, mimick, sampling, fish_receiver, deltat, nfft, nfft_for_morph, fish_morph_harmonics_var=fish_morph_harmonics_var, beat=beat) if (stimulus_length_first != 0) & (t == 1): time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr, stimulus_length_here, deltat=deltat) eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr, stimulus_length_here, phaseshift_fr, cell_recording, zeros, mimick, sampling, fish_receiver, deltat, nfft, nfft_for_morph, fish_morph_harmonics_var=fish_morph_harmonics_var, beat=beat) # baseline_after,spikes_base,rate_adapted, rate_baseline_before, rate_baseline_after, np.array(spike_times), stimulus_power, v_dent_output[int(0.05 / deltat):-1], offset, v_mem_output stimulus_0 = eod_fish_r power_here = 'sinz' adapt_offset = 'adaptoffset_bisecting' cvs, adapt_output, baseline_after_b, _, rate_adapted_b, 
rate_baseline_before_b, rate_baseline_after_b, \ spikes_bef[t], _, _, offset_new, v_mem0, noise_final = simulate(cell, offset, stimulus_0, deltat=deltat, adaptation_variant=adapt_offset, adaptation_yes_j=f2, adaptation_yes_e=f1, adaptation_yes_t=t, adaptation_upper_tol=upper_tol, adaptation_lower_tol=lower_tol, power_variant=power_here, power_alpha=alpha, power_nr=n, reshuffle=reshuffled, **model_params) isi = calc_isi(spikes_bef[t], eod_fr) try: spikes_base[t] = spikes_after_burst_corr(spikes_bef[t], isi, params_dict['burst_corr'], cell, eod_fr, model_params=model_params) except: print('assing spikes problem') print(' offset orig ' + str(offset)) test = False if test: from utils_test import test_cvs3 test_cvs3() if t == 0: # here we record the changes in the offset due to the adaptation # and we subsequently reset the offset to be the new adapted for all subsequent trials offset = offset_new * 1 print(' Base ' + str(adapt_offset) + ' offset ' + str(offset)) if printing: print('Baseline time' + str(time.time() - t1)) if test: fig, ax = plt.subplots(2, 1, sharex=True) ax[0].eventplot(spikes_bef[t], color='red') ax[1].eventplot(spikes_base[t]) ax[1].eventplot(spikes_bef[t], color='red') ax[1].set_xlim(0, 0.1) plt.show() sampling_rate = 1 / deltat stim_length_max = np.max([stimulus_length_first, stimulus_length]) base_cut, mat_base, smoothed0, mat0 = find_base_fr2(spikes_base, deltat, stim_length_max, dev=dev) fr = np.mean(base_cut) _, _ = ISI_frequency(time_array, spikes_base[0], fill=0.0) for aaa, a_f2 in enumerate(a_f2s): # [0] for aa, a_f1 in enumerate(a_f1s): # [0] eod_fish1, eod_fish2, eod_stimulus, t1 = stimulus_threefish(a_f1, a_f2, cell_recording, f1, f2, fish_emitter, fish_jammer, freq1, freq2, fish_morph_harmonics_var, mimick, nfft_for_morph, phase_right, phaseshift_fr, sampling, stimulus_length, time_array, zeros) if test: from utils_test import test_timearray test_timearray() adapt_offset_later = '' v_mem, offset_new, mat01, mat02, mat012, smoothed01, 
smoothed02, smoothed012, stimulus_01, stimulus_02, stimulus_012, mat05_01, spikes_01, mat05_02, spikes_02, mat05_012, spikes_012 = get_arrays_for_three( cell, a_f2, a_f1, SAM, eod_stimulus, eod_fish_r, freq2, eod_fish1, eod_fish2, stimulus_length, offset, model_params, n, variant, adapt_offset_later, deltat, f2, trials_nr, time_array, f1, freq1, eod_fr, reshuffle=reshuffled, dev=dev, zeros=zeros, a_fr=a_fr, params_dict=params_dict, redo_stim=True, nfft=nfft, cell_recording=cell_recording, phaseshift_fr=phaseshift_fr, fish_emitter=fish_emitter, fish_receiver=fish_receiver, beat=beat, fish_jammer=fish_jammer, fish_morph_harmonics_var=fish_morph_harmonics_var, mimick=mimick, nfft_for_morph=nfft_for_morph, phase_right=phase_right, sampling=sampling) if test: fig, ax = plt.subplots(4, 1, sharex=True, sharey=True) ax[0].plot(stimulus_0) ax[1].plot(stimulus_01) ax[2].plot(stimulus_02) ax[3].plot(stimulus_012) plt.show() if printing: print('Generation process' + str(time.time() - t1)) v_mems = np.concatenate([[v_mem0], v_mem]) for dev_n in dev_name: results_diff.loc[position_diff, 'fr'] = fr results_diff.loc[position_diff, 'f1'] = freq1[0] results_diff.loc[position_diff, 'f2'] = freq2[0] results_diff.loc[position_diff, 'f0'] = eod_fr results_diff.loc[position_diff, 'df1'] = np.abs(eod_fr - freq1) results_diff.loc[position_diff, 'df2'] = np.abs(eod_fr - freq2) results_diff.loc[position_diff, 'cell'] = cell results_diff.loc[position_diff, 'c1'] = a_f1 results_diff.loc[position_diff, 'c2'] = a_f2 results_diff.loc[position_diff, 'trial_nr'] = trials_nr results_diff = calc_cv_three_wave(results_diff, position_diff, arrays=[spikes_base, spikes_01, spikes_02, spikes_012], adds=['_0', '_01', '_02', '_012']) if dev_n == '05': dev = 0.0005 # test wie ich die std und var und psd eigentlich gruppieren müsste if dev_n == 'original': array0 = [np.mean(mat0, axis=0)] array01 = [np.mean(mat01, axis=0)] array02 = [np.mean(mat02, axis=0)] array012 = [np.mean(mat012, axis=0)] elif dev_n == 
'05': array0 = [np.mean(smoothed0, axis=0)] array01 = [np.mean(smoothed01, axis=0)] array02 = [np.mean(smoothed02, axis=0)] array012 = [np.mean(smoothed012, axis=0)] p0, p02, p01, p012, ff = calc_ps(nfft, [array012[0][time_array > p_xlim]], [array01[0][time_array > p_xlim]], [array02[0][time_array > p_xlim]], [array0[0][time_array > p_xlim]], sampling_rate=sampling_rate) if test: plt_subpart_cocktail(results_diff, ff, p01, p02, p012, p0) ######################################################## # also hier sind eben diese ganzen Amplituden # 'amp_max_012_mean' ,'amp_max_02_mean', 'amp_max_01_mean', 'amp_max_0_mean' ist zwangsweise das was wir suchen # 'amp_max_012_mean', 'amp_max_02_mean', 'amp_max_01_mean' # 'amp_B2_012_mean','amp_B1_012_mean' 'amp_B2_02_mean','amp_B1_01_mean' # das ganze ist zwangsweise auf dem gemittelten Arrays # in dieser einen Version mache ich das ohne sqrt hier und dann if np.isnan(results_diff.loc[position_diff, 'f0']): print('isnan thing4') embed() results_diff = calc_amps(ff, p0, p02, p01, p012, position_diff, [dev], 0, results_diff, results_diff, add='_mean' + '_' + dev_n, timesstamp=False, min_amps=min_amps, points=points) printing = False if printing: print(' a_f1 ' + str(aa) + ' ' + str(adapt_offset) + ' offset ' + str(offset) + ' time ' + str( time.time() - t1)) ####################################### # here calculate the fft # die arrays hier sind immer eindimmensional deswegen muss man hier nicht auf trials achten! 
# embed() # das ist das was wir vergleichen ffts_right1, freq = calc_fft(array0, array01, array012, array02, deltat, sampling) results_diff.loc[position_diff, 'diff_fft' + '_' + str(dev_n)] = np.sum(ffts_right1) * \ freq[1] #################################################################### arrays_stim = [stimulus_0, stimulus_01, stimulus_02, stimulus_012] arrays = [array0, array01, array02, array012] arrays_spikes = [spikes_base, spikes_01, spikes_02, spikes_012] names = ['0', '01', '02', '012'] for a, array in enumerate(arrays): results_diff.loc[position_diff, 'std_' + names[a] + '_' + dev_n] = np.std(array) results_diff.loc[position_diff, 'var_' + names[a] + '_' + dev_n] = np.var(array) names_saved = ['var', 'std'] for name_saved in names_saved: results_diff = calculate_the_difference(position_diff, results_diff, name_saved, dev_n, results_diff.loc[ position_diff, name_saved + '_012' + '_' + dev_n], results_diff.loc[ position_diff, name_saved + '_01' + '_' + dev_n], results_diff.loc[ position_diff, name_saved + '_02' + '_' + dev_n], results_diff.loc[ position_diff, name_saved + '_0' + '_' + dev_n]) test = False if test: fig, ax = plt.subplots(4, 1) ax[0].plot(time, results_diff.loc[position_diff, name_saved + '_0' + '_' + dev_n]) ax[1].plot(time, results_diff.loc[position_diff, name_saved + '_01' + '_' + dev_n]) ax[2].plot(time, results_diff.loc[position_diff, name_saved + '_02' + '_' + dev_n]) ax[3].plot(time, results_diff.loc[position_diff, name_saved + '_012' + '_' + dev_n]) plt.show() ########################################## # hier ist eine extra kondition falls wir das ohne vorheriges Mitteln vergleichen wollen würden! # die brauchen wir im default nicht aber zum abgleichen was was ist ist # d asmanchmal ganz nett # if (trials != 1 ) & ('_means_' & means_differnet): # für die verschiedenen Trials wollen wir verschiedene Konditions einführen # embed() if (trials_nr != 1) & ( '_means_' in means_different): # calc_model_amp_freqs_param. 
# für die verschiedenen Trials wollen wir verschiedene Konditions einführen if dev_n == 'original': array0 = mat0 array01 = mat01 array02 = mat02 array012 = mat012 # , axis = 0 elif dev_n == '05': array0 = smoothed0 array01 = smoothed01 array02 = smoothed02 array012 = smoothed012 p0, p02, p01, p012, ff = calc_ps(nfft, array012, array01, array02, array0, sampling_rate=sampling_rate) results_diff = calc_amps(ff, p0, p02, p01, p012, position_diff, [dev], 0, results_diff, results_diff, add='' + '_' + dev_n, timesstamp=False, min_amps=min_amps, points=points) ####################################### # here calculate the fft # die arrays hier sind immer eindimmensional deswegen muss man hier nicht auf trials achten! # embed() ffts_right1, freq = calc_fft(array0, array01, array012, array02, deltat, sampling) results_diff.loc[position_diff, 'diff_mean(fft)' + '_' + str(dev_n)] = np.sum(ffts_right1) * \ freq[1] arrays_stim = [stimulus_0, stimulus_01, stimulus_02, stimulus_012] arrays = [array0, array01, array02, array012] arrays_spikes = [spikes_base, spikes_01, spikes_02, spikes_012] names = [cl_3names.c0, cl_3names.c01, cl_3names.c02, cl_3names.c012, ] for a, array in enumerate(arrays): results_diff.loc[position_diff, 'mean(std)_' + names[a] + '_' + str(dev_n)] = np.mean( np.std(array, axis=1)) results_diff.loc[position_diff, 'mean(var)_' + names[a] + '_' + str(dev_n)] = np.mean( np.var(array, axis=1)) names_saved = ['mean(var)', 'mean(std)'] for name_saved in names_saved: results_diff = calculate_the_difference(position_diff, results_diff, name_saved, dev_n, results_diff.loc[ position_diff, name_saved + '_012' + '_' + dev_n], results_diff.loc[ position_diff, name_saved + '_01' + '_' + dev_n], results_diff.loc[ position_diff, name_saved + '_02' + '_' + dev_n], results_diff.loc[ position_diff, name_saved + '_0' + '_' + dev_n]) ################################################## # für den Fall dass ich das Testen will und das zurück transferieren will # zum testen von dem 
phase sorting algorithmus für die Daten! if params_dict['phase_undo'] == True: embed() mean_type = 'MeanTrialsIndexPhaseSort_Min0.25sExcluded' _, _, _, _ = phase_sort_and_cut(mean_type, frame, synaptic_flt_analysis, t, sampling_time, sorted_on=sorted_on) # not there # todo: der Teil ist halt noch nicht fürs Mitteln ausgelegt, aber vielleich tbrauche ich das auch nicht else: p_array = [] ########################################## # diese Mehrfachen berechnen für eine # das hier ist so eine Analyse um, die Zenter zu berechnen und zwar ohne die Beats erstmal # ich denke 4 Harmonische reichen da schon nicht whar? # hier muss man noch einbau dass das davon abhängt ob c1 oder c2 variert! if dev_n == 'original': array0 = [np.mean(mat0, axis=0)] array01 = [np.mean(mat01, axis=0)] array02 = [np.mean(mat02, axis=0)] array012 = [np.mean(mat012, axis=0)] elif dev_n == '05': array0 = [np.mean(smoothed0, axis=0)] array01 = [np.mean(smoothed01, axis=0)] array02 = [np.mean(smoothed02, axis=0)] array012 = [np.mean(smoothed012, axis=0)] B1 = results_diff.loc[position_diff, 'B1'] if B1 != 0: try: beats_range = np.arange(np.abs(B1), 1000, np.abs(B1)) except: print('B1 thing') embed() idx = [] for beat in beats_range: idx.append(np.argmin(np.abs(ff - beat))) names = ['0', '01', '02', '012'] p_array = [p0, p01, p02, p012] for n_nr, name in enumerate(names): vals = np.sqrt(np.mean(p_array[n_nr][0][idx] * ff[1])) results_diff.loc[position_diff, 'B1_harms_all_mean_' + name + '_' + dev_n] = vals vals = np.sqrt(np.sum(p_array[n_nr][0][idx] * ff[1])) results_diff.loc[position_diff, 'B1_harms_all_sum_' + name + '_' + dev_n] = vals nrs_excluded = [0, 1, 2, 3, 4] for nr_excluded in nrs_excluded: if len(idx) > nr_excluded: vals_bef = p_array[n_nr][0][idx[nr_excluded::]] * ff[1] vals = np.sqrt(np.mean(vals_bef)) results_diff.loc[position_diff, 'B1_harms_' + str( nr_excluded) + '_all_mean_' + name + '_' + dev_n] = vals vals = np.sqrt(np.sum(vals_bef)) results_diff.loc[position_diff, 'B1_harms_' + 
str( nr_excluded) + '_all_mean_' + name + '_' + dev_n] = vals results_diff.loc[position_diff, 'B1_harms_' + str( nr_excluded) + '_all_center_' + name + '_' + dev_n] = ff[ idx[nr_excluded::][np.argmax(vals_bef)]] # nochmal die Vector Strength berechnen results_diff = calc_vs_amps(results_diff, freq1[0], eod_fr, arrays_spikes, position_diff, names, add='_' + dev_n) ###################################### # upper and lower bound berechnen array_smoothed = [smoothed0, smoothed01, smoothed02, smoothed012] names = ['0', '01', '02', '012'] results_diff = upper_and_lower_fr(array_smoothed, results_diff, position_diff, eod_fr, names, add='') # '' ###################################### # hist für doppelte spikes vom Phase Locking try: results_diff = find_double_spikes(eod_fr, arrays_spikes, names, results_diff, position_diff, add='_' + dev_n) except: print('double spikes problem') embed() ###################################### # phasen zu dem EOD if 'AUCI' in AUCI: add = '_' + way_all + str(datapoints) + '_' + dev_n trials, results_diff, tp_012_all, tp_01_all, tp_02_all, fp_all, roc_01, base_here, roc_02, roc_012, threshhold = calc_auci_pd( results_diff, position_diff, array012, array01, array02, array0, add=add, t_off=5, way=way_all, datapoints=datapoints, f0='f0') position_diff += 1 try: v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays, names, p_array, ff except: print('missing') embed() return v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays, names, p_array, ff def calc_fft(array0, array01, array012, array02, deltat, sampling): arrays = [array012, array01, array02, array0] names = [cl_3names.c012, cl_3names.c01, cl_3names.c02, cl_3names.c0] fft = {} ffts_all = calc_FFT3(arrays, deltat, fft, names) ffts_right1 = equal_to_temporal_mean(ffts_all) freq = np.fft.fftfreq(len(ffts_right1), d=1 / sampling) return ffts_right1, freq def phase_sort_and_cut(mean_type, frame, synaptic_flt_orig, t, 
sampling_time, sorted_on='local_reconst_big_norm'): if 'DetectionAnalysis' not in mean_type: delays_length = define_delays_trials(mean_type, frame, sorted_on=sorted_on) array012_all, array01_all, array02_all, array0_all = cut_uneven_trials( frame, synaptic_flt_orig[t], mean_type, delays_length, sampling=sampling_time) if 'extended' in mean_type: # Also das hier generiert mehr trials, indem das alles in neue gruppen regeneriert, # also die gleichen Daten sind mehrmals drin nur anders gruppiert (2*15, 3*10 und am ende auch 1*30, letzte Zeile!) # brauchen wir vor allem für das ROC Ding und im Falle vom Mitteln # also das finde icch braut array0_gr = find_group_variants(array0_all, []) array01_gr = find_group_variants(array01_all, []) array02_gr = find_group_variants(array02_all, []) array012_gr = find_group_variants(array012_all, []) print('extended ' + str(len(array0_gr))) if len(array0_gr) > 0: print('extended thing') embed() # und hier appenden wir nochmal alles als variante array0_gr.append([array0_all]) array01_gr.append([array01_all]) array02_gr.append([array02_all]) array012_gr.append([array012_all]) else: # DEFAULT array0_gr = [[array0_all]] array01_gr = [[array01_all]] array02_gr = [[array02_all]] array012_gr = [[array012_all]] else: array012_gr = [[[]]] array01_gr = [[[]]] array02_gr = [[[]]] array0_gr = [[[]]] return array012_gr, array02_gr, array01_gr, array0_gr def calc_amp_value(names, freq_step, ff, ps, position_diff, fs, devname, t, results_diff, name, add, fish, points=5): length = len(ps[ff]) results_diff = results_diff.copy() vals = [] printing = False if str(names[name]) != 'no': # wir machen das wegen der Baseline, weil die nicht immer da ist! 
for trial_nr in range(len(ps[ff])): if not (np.isnan(ps[ff][trial_nr])).any(): if trial_nr == 0: if printing: print('started') results_diff.loc[position_diff, 'amp_' + name + fish + add] = 0 if 'max' in name: results_diff.loc[position_diff, 'f_' + name + fish + add] = 0 if name == '': if printing: print('name == ') results_diff.loc[position_diff, 'amp_' + name + fish + add] += np.sum(ps[ff][trial_nr]) * fs[1] else: if 'max' in name: if printing: print('max') if (devname[t] == 'original') or (devname[t] == '_original') or (devname[t] == '_eod'): try: arg = np.argmax( ps[ff][trial_nr][fs < 0.5 * results_diff.EODf.loc[position_diff] + fs[1]]) except: try: arg = np.argmax( ps[ff][trial_nr][fs < 0.5 * results_diff.f0.loc[position_diff] + fs[1]]) except: print('arg stuff') embed() else: if 'harm' in name: arg = np.argmax(ps[ff][trial_nr]) * 2 else: arg = np.argmax(ps[ff][trial_nr]) if arg < len( fs): # also wenn einer der trials drüber ist dann setzten wir alles NAN einfach weil # irgendwo später im Code könnte es komische Effetke geben try: results_diff.loc[position_diff, 'f_' + name + fish + add] += fs[arg] except: print('results diff problems') embed() else: if printing: print('else') results_diff.loc[position_diff, 'f_' + name + fish + add] += float('nan') else: try: arg = np.argmin(np.abs(fs - names[name])) except: print('arg something') embed() try: if printing: print('val') # also bei den Phaselocking Sachen da nehme ich immer nur einen Peak, weil ich kriege so viele Peaks ich möchte # einen Überlapp vermeiden! Aber wenn die Auflösung fein genug ist sollte das schon passen! 
if arg < len(ps[ff][trial_nr]): if points == 1: val = np.sum((ps[ff][trial_nr][arg]) * freq_step) elif points == 3: val = np.sum((ps[ff][trial_nr][arg - 1:arg + 2]) * freq_step) elif points == 5: val = np.sum((ps[ff][trial_nr][arg - 2:arg + 3]) * freq_step) results_diff.loc[position_diff, 'amp_' + name + fish + add] += val vals.append(val) except: print('calc_ amp') embed() else: length -= 1 # das muss ganz am Ende stehen! # davor wurde das aufsummiert jetzt wird das geteilt! if length != 0: # ok das hier ist so ein Ding wenn ich mitten in der Funtion anhalten will das ich trotzdem bei der Hälfte rauskomme? if printing: print('div') try: results_diff.loc[position_diff, 'amp_' + name + fish + add] = results_diff.loc[ position_diff, 'amp_' + name + fish + add] / length except: print('amp problem') embed() if 'max' in name: results_diff.loc[position_diff, 'f_' + name + fish + add] = results_diff.loc[ position_diff, 'f_' + name + fish + add] / length return results_diff


def peaks_of_interest(df1, df2, beat1, beat2, fr, f1, f2, eod_fr, min_amps=''):
    """Return a dict mapping peak-name prefixes to the frequency of each peak.

    Parameters
    ----------
    df1, df2 : float
        Frequency differences of the two stimuli (presumably to the EOD, Hz).
    beat1, beat2 : float
        Beat frequencies of the two stimuli.
    fr : float
        Baseline rate (presumably the firing rate, Hz).
    f1, f2 : float
        Absolute stimulus frequencies.
    eod_fr : float
        EOD frequency of the receiver.
    min_amps : str, optional
        If it contains ``'min'``, only a reduced set of peaks is returned
        (beats and their combinations, f0 and its harmonic, f1, f2, the
        envelope beat and ``fr``). Otherwise the full set is returned.

    Returns
    -------
    dict
        Keys are column-name prefixes such as ``'B1_'`` or ``'f0+B2_'``;
        values are the corresponding frequencies. ``'B1-B2_'``/``'B2-B1_'``
        and ``'B1+B2_'``/``'B2+B1_'`` are aliases with identical values.
        ``'env_beat_'`` and ``'env_beat_beatf0_'`` are derived through
        ``create_beat_corr`` (defined elsewhere in the project).
    """
    # These are all potentially interesting peaks, but usually only some of
    # them are wanted; the set should be reduced considerably (here one can
    # restrict the analysis to particular peaks).
    # Actually only 'B1_': np.abs(beat1) and 'B2_': np.abs(beat2) are of interest.
    if 'min' in min_amps:
        # Reduced peak set.
        names = {
            'B1_': np.abs(beat1),
            'B2_': np.abs(beat2),
            'B1-B2_': np.abs(beat1 - beat2),
            'B2-B1_': np.abs(beat2 - beat1),
            'B2+B1_': np.abs(beat2 + beat1),
            'B1+B2_': np.abs(beat2 + beat1),
            'f0_': np.abs(eod_fr),
            'f0_harm_': np.abs(eod_fr) * 2,
            'f1_': f1,
            'f2_': f2,
            'env_beat_': np.abs(np.abs(create_beat_corr(eod_fr - f1, np.array([eod_fr]))) - np.abs(
                create_beat_corr(eod_fr - f2, np.array([eod_fr])))),
            'fr_': fr}
    else:
        # Full peak set: DFs, beats and their harmonics, and sums/differences
        # of beats with the baseline rate (fr) and the EOD frequency (f0).
        names = {
            'DeltaF1_': np.abs(df1),
            'DeltaF2_': np.abs(df2),
            'DeltaF1_harm_': np.abs(df1) * 2,
            'DeltaF2_harm_': np.abs(df2) * 2,
            'B1_harm_': np.abs(beat1) * 2,
            'B2_harm_': np.abs(beat2) * 2,
            'B1_2harm_': np.abs(beat1) * 3,
            'B2_2harm_': np.abs(beat2) * 3,
            'B1_3harm_': np.abs(beat1) * 4,
            'B2_3harm_': np.abs(beat2) * 4,
            'F2+F1_': np.abs(f2 + f1),
            'F1-F2_': np.abs(f1 - f2),
            'F2-F1_': np.abs(f2 - f1),
            'B1_': np.abs(beat1),
            'B2_': np.abs(beat2),
            'B1-B2_': np.abs(beat1 - beat2),
            'B2-B1_': np.abs(beat2 - beat1),
            'B2+B1_': np.abs(beat2 + beat1),
            'B1+B2_': np.abs(beat2 + beat1),
            'fr-B2_': np.abs(fr - beat2),
            'fr-B1_': np.abs(fr - beat1),
            'fr-(B2+B1)_': np.abs(fr - (beat2 + beat1)),
            'fr-(B1-B2)_': np.abs(fr - np.abs(beat1 - beat2)),
            'fr-(B2-B1)_': np.abs(fr - np.abs(beat2 - beat1)),
            'fr+B2_': np.abs(fr + beat2),
            'fr+B1_': np.abs(fr + beat1),
            'fr+(B2+B1)_': np.abs(fr + (beat2 + beat1)),
            'fr+(B1-B2)_': np.abs(fr + np.abs(beat1 - beat2)),
            'fr+(B2-B1)_': np.abs(fr + np.abs(beat2 - beat1)),
            'f0-B2_': np.abs(eod_fr - beat2),
            'f0-B1_': np.abs(eod_fr - beat1),
            'f0-(B2+B1)_': np.abs(eod_fr - (beat2 + beat1)),
            'f0-(B1-B2)_': np.abs(eod_fr - np.abs(beat1 - beat2)),
            'f0-(B2-B1)_': np.abs(eod_fr - np.abs(beat2 - beat1)),
            'f0+B2_': np.abs(eod_fr + beat2),
            'f0+B1_': np.abs(eod_fr + beat1),
            'f0_': np.abs(eod_fr),
            'f0+(B2+B1)_': np.abs(eod_fr + (beat2 + beat1)),
            'f0+(B1-B2)_': np.abs(eod_fr + np.abs(beat1 - beat2)),
            'f0+(B2-B1)_': np.abs(eod_fr + np.abs(beat2 - beat1)),
            'f1_': f1,
            'f2_': f2,
            'f1_harm_': f1 * 2,
            'f2_harm_': f2 * 2,
            'env_': np.abs(np.abs(df1) - np.abs(df2)),
            'env_beat_': np.abs(np.abs(create_beat_corr(eod_fr - f1, np.array([eod_fr]))) - np.abs(
                create_beat_corr(eod_fr - f2, np.array([eod_fr])))),
            'env_beat_beatf0_': create_beat_corr(np.abs(
                np.abs(create_beat_corr(eod_fr - f1, np.array([eod_fr]))) - np.abs(
                    create_beat_corr(eod_fr - f2, np.array([eod_fr])))), np.array([eod_fr])),
            'fr_': fr,
            'fr_harm_': fr * 2}
    return names


def plt_calc_amps(results, p0, p01, p02, p012, frame, fs):
    """Debug plot: condition spectra and their difference spectra.

    Draws a 2x4 grid (flattened to 8 panels). The first four panels show
    ``p0[0]``, ``p01[0]``, ``p02[0]`` and ``p012[0]`` against ``fs``. The
    last four show the differences 012-02, 012-01, 012-01-02 and
    012-01-02+0. Expected peaks (DF2, |DF1-DF2|, |DF1+DF2|, baseline rate
    and DF1 with its harmonics) are taken from row 0 of ``results`` and drawn
    via ``plt_peaks_several``. The title is ``frame.cell.iloc[0]``; the
    figure is shown interactively.
    """
    fig, ax = plt.subplots(2, 4, sharex=True, sharey=True)
    # Flatten so that the 8 panels can be indexed 0..7 below.
    ax = ax.flatten()
    arrays = [p0[0], p01[0], p02[0], p012[0], p012[0] - p02[0], p012[0] - p01[0], p012[0] - p02[0] - p01[0],
              p012[0] - p02[0] - p01[0] + p0[0]]
    titles = ['0', '01', '02', '012', '012-02', '012-01', '012-01-02', '012-01-02+0', ]
    plt.suptitle(frame.cell.iloc[0])
    for a, array in enumerate(arrays):
        ax[a].plot(fs, array, color='black')
        ax[a].set_title(titles[a])
        ax[a].set_xlim(0, 1000)
        # NOTE(review): the loop body was reconstructed from a whitespace-
        # collapsed source; peak markers and the 0-400 limit are assumed to be
        # applied per panel -- confirm.
        B1 = results.df1.iloc[0]
        B2 = results.df2.iloc[0]
        fr = results.fr.iloc[0]
        freqs = [np.abs(B2), np.abs(np.abs(np.abs(B1) - np.abs(B2))), np.abs(np.abs(B1) + np.abs(B2)), np.mean(fr),
                 np.abs(B1), np.abs(B1) * 2, np.abs(B1) * 3, np.abs(B1) * 4, ]
        colors = ['green', 'purple', 'orange', 'red', 'blue', 'blue', 'blue', 'blue', ]
        labels = ['DF2', '|DF1-DF2|', '|DF1+DF2|', 'Baseline', 'DF1', 'DF1', 'DF1', 'DF1']
        plt_peaks_several(freqs, arrays, ax[a], array, fs, labels, 0, colors, alpha=0.5)
        ax[a].set_xlim([0, 400])
    # The legend goes on the last panel only, outside the axes on the right.
    ax[-1].legend(loc=(1, 0), ncol=1)
    plt.subplots_adjust(right=0.8)
    plt.show()


def calc_pure_amps_diffs(frame, pos, names, fishes, freq_step, ps, fs, devname, t, add, points=5): for nn, name in enumerate(names): # Hier werden die Einzelfrequenzen gemacht # also hier berechne ich die als std dann ist das Hz for ff, fish in enumerate(fishes): frame = calc_amp_value(names, freq_step, ff, ps, pos, fs, devname, t, frame, name, add, fish, points=points) # Hier werden die Diff Scores gemacht # in dem default fall sqrt == '' schauen wir uns gleich die sqrt peaks an # Summieren können wir aber nur varianzen! # if sqrt = '_sqrt_': # das mache ich um das von früheren Versionen zu differnezieren!
if 'amp_' + name + '01' + add in frame.keys(): frame.loc[pos, 'amp_' + name + '012-01' + add] = frame.loc[pos, 'amp_' + name + '012' + add] - \ frame.loc[pos, 'amp_' + name + '01' + add] if 'amp_' + name + '02' + add in frame.keys(): frame.loc[pos, 'amp_' + name + '012-02' + add] = (frame.loc[ pos, 'amp_' + name + '012' + add]) - \ (frame.loc[ pos, 'amp_' + name + '02' + add]) if 'amp_' + name + '01' + add in frame.keys(): frame.loc[pos, 'amp_' + name + 'diff' + add] = (frame.loc[ pos, 'amp_' + name + '012' + add]) - \ (frame.loc[ pos, 'amp_' + name + '01' + add]) - \ (frame.loc[ pos, 'amp_' + name + '02' + add]) + \ (frame.loc[ pos, 'amp_' + name + '0' + add]) frame.loc[pos, 'amp_' + name + '012-02-01' + add] = (frame.loc[ pos, 'amp_' + name + '012' + add]) - \ frame.loc[ pos, 'amp_' + name + '02' + add] - \ frame.loc[ pos, 'amp_' + name + '01' + add] return frame def find_B1B2_norm_amp_diffs(frame, norms_name, norms, pos, add): for n, norm in enumerate(norms): ################## # B1 & B2 divs = [2, 1] divs_name = ['/2', '', ] for d, div in enumerate(divs): # IMPORTANT B1_B2 = ((frame.loc[pos, 'amp_' + 'B1_' + '012' + add] - frame.loc[pos, 'amp_' + 'B1_' + '01' + add]) + (frame.loc[pos, 'amp_' + 'B2_' + '012' + add] - \ frame.loc[pos, 'amp_' + 'B2_' + '02' + add])) / norm # embed() B1_0102 = ((frame.loc[pos, 'amp_' + 'B1_' + '012' + add] - frame.loc[pos, 'amp_' + 'B1_' + '01' + add] - frame.loc[ pos, 'amp_' + 'B1_' + '02' + add])) / norm # frame.loc[pos, 'amp_' + 'B1_' + '0' + add] B2_0102 = (frame.loc[pos, 'amp_' + 'B2_' + '012' + add] - \ frame.loc[pos, 'amp_' + 'B2_' + '02' + add] - frame.loc[pos, 'amp_' + 'B2_' + '01' + add]) / norm B1_01020 = ((frame.loc[pos, 'amp_' + 'B1_' + '012' + add] - frame.loc[pos, 'amp_' + 'B1_' + '01' + add] - frame.loc[ pos, 'amp_' + 'B1_' + '02' + add]) + frame.loc[ pos, 'amp_' + 'B1_' + '0' + add]) / norm # frame.loc[pos, 'amp_' + 'B1_' + '0' + add] B2_01020 = (frame.loc[pos, 'amp_' + 'B2_' + '012' + add] - \ frame.loc[pos, 
'amp_' + 'B2_' + '02' + add] - frame.loc[pos, 'amp_' + 'B2_' + '01' + add] + frame.loc[pos, 'amp_' + 'B2_' + '0' + add]) / norm B1_B2_0102 = ((frame.loc[pos, 'amp_' + 'B1_' + '012' + add] - frame.loc[pos, 'amp_' + 'B1_' + '01' + add] - frame.loc[ pos, 'amp_' + 'B1_' + '02' + add]) + (frame.loc[pos, 'amp_' + 'B2_' + '012' + add] - \ frame.loc[pos, 'amp_' + 'B2_' + '02' + add] - frame.loc[ pos, 'amp_' + 'B2_' + '01' + add])) / norm frame.loc[pos, 'amp_' + 'B1&B2' + divs_name[d] + '_012-01_012-02' + norms_name[n] + add] = B1_B2 / div frame.loc[pos, 'amp_' + 'B1&B2' + divs_name[d] + '_012-01-02' + norms_name[n] + add] = B1_B2 / div frame.loc[pos, 'amp_' + 'B1' + divs_name[d] + '_012-01-02' + norms_name[n] + add] = B1_0102 / div frame.loc[pos, 'amp_' + 'B2' + divs_name[d] + '_012-01-02' + norms_name[n] + add] = B2_0102 / div frame.loc[pos, 'amp_' + 'B1' + divs_name[d] + '_012-01-02+0' + norms_name[n] + add] = B1_01020 / div frame.loc[pos, 'amp_' + 'B2' + divs_name[d] + '_012-01-02+0' + norms_name[n] + add] = B2_01020 / div # B1 & B2 Harmonische divs = [4, 1] divs_name = ['/4', '', ] Bs = ['B1_', 'B2_'] ends = ['01', '02'] B_harms = [] # here takes only two combis of Bs and ends for bb, B in enumerate(Bs): # IMPORTANT B_harm = ((frame.loc[pos, 'amp_' + B + 'harm_' + '012' + add] - \ frame.loc[pos, 'amp_' + B + 'harm_' + ends[bb] + add]) + (frame.loc[pos, 'amp_' + B + '2harm_' + '012' + add] - \ frame.loc[pos, 'amp_' + B + '2harm_' + ends[bb] + add]) + (frame.loc[pos, 'amp_' + B + '3harm_' + '012' + add] - \ frame.loc[pos, 'amp_' + B + '3harm_' + ends[bb] + add])) / norm B_harm_0102 = ((frame.loc[pos, 'amp_' + B + 'harm_' + '012' + add] - \ frame.loc[pos, 'amp_' + B + 'harm_' + '01' + add] - frame.loc[pos, 'amp_' + B + 'harm_' + '02' + add]) + (frame.loc[pos, 'amp_' + B + '2harm_' + '012' + add] - \ frame.loc[pos, 'amp_' + B + '2harm_' + '01' + add] - frame.loc[pos, 'amp_' + B + '2harm_' + '02' + add]) + (frame.loc[pos, 'amp_' + B + '3harm_' + '012' + add] - \ 
-frame.loc[pos, 'amp_' + B + '3harm_' + '01' + add] - frame.loc[pos, 'amp_' + B + '3harm_' + '02' + add])) / norm B_harm_01020 = ((frame.loc[pos, 'amp_' + B + 'harm_' + '012' + add] - \ frame.loc[pos, 'amp_' + B + 'harm_' + '01' + add] - frame.loc[pos, 'amp_' + B + 'harm_' + '02' + add] + frame.loc[pos, 'amp_' + B + 'harm_' + '0' + add]) + (frame.loc[pos, 'amp_' + B + '2harm_' + '012' + add] - \ frame.loc[pos, 'amp_' + B + '2harm_' + '01' + add] - frame.loc[pos, 'amp_' + B + '2harm_' + '02' + add] + frame.loc[pos, 'amp_' + B + '2harm_' + '0' + add]) + (frame.loc[pos, 'amp_' + B + '3harm_' + '012' + add] - \ -frame.loc[pos, 'amp_' + B + '3harm_' + '01' + add] - frame.loc[pos, 'amp_' + B + '3harm_' + '02' + add] + frame.loc[pos, 'amp_' + B + '3harm_' + '0' + add])) / norm if 'B1_' in B: B1_harm = B_harm B1_harm_0102 = B_harm_0102 B1_harm_01020 = B_harm_01020 else: B2_harm = B_harm B2_harm_0102 = B_harm_0102 B2_harm_01020 = B_harm_01020 B_harms.append(B_harm) for d, div in enumerate(divs): frame.loc[ pos, 'amp_' + B + 'harms_' + divs_name[d] + '_012-' + ends[bb] + norms_name[n] + add] = B_harm / div frame.loc[ pos, 'amp_' + B + 'harms_' + divs_name[d] + '_012-01-02' + norms_name[n] + add] = B_harm_0102 / div frame.loc[pos, 'amp_' + B + 'harms_' + divs_name[d] + '_012-01-02+0' + norms_name[ n] + add] = B_harm_01020 / div # für das aufsummierte # here I have different controls, I guess the t the mean control ist the best diff_parts = [('02', '02'), ('01', '01'), ('02', '01'), ('01', '02')] for diff_part in diff_parts: # 3) B1-B2 & B1+B2, für die verschiedenen Kontrollen prev = ((frame.loc[pos, 'amp_' + 'B1-B2_' + '012' + add] - \ frame.loc[pos, 'amp_' + 'B1-B2_' + diff_part[0] + add]) + (frame.loc[pos, 'amp_' + 'B1+B2_' + '012' + add] - \ frame.loc[pos, 'amp_' + 'B1+B2_' + diff_part[1] + add])) / norm divs = [2, 1] divs_name = ['/2', ''] for d, div in enumerate(divs): frame.loc[ pos, 'amp_' + 'B1-B2&B1+B2' + divs_name[d] + '_012-' + diff_part[0] + '_012-' + diff_part[1] 
+ '' + norms_name[n] + add] = prev / div # 4) B1 & B2 & B1 - B2 & B1 + B2 # B1 - B2 & B1 + B2 AND EXTRA B1 & B2 divs = [4, 1] divs_name = ['/4', ''] for d, div in enumerate(divs): frame.loc[ pos, 'amp_' + 'B1&B2&B1-B2&B1+B2' + divs_name[d] + '_012-' + diff_part[0] + '_012-' + diff_part[ 1] + '' + norms_name[n] + add] = (prev + B1_B2) / div # B1-B2 B1_minus_B2_0102 = (frame.loc[pos, 'amp_' + 'B1-B2_' + '012' + add] - \ frame.loc[pos, 'amp_' + 'B1-B2_' + '02' + add] - \ frame.loc[pos, 'amp_' + 'B1-B2_' + '01' + add]) / norm B1_minus_B2_01020 = (frame.loc[pos, 'amp_' + 'B1-B2_' + '012' + add] - \ frame.loc[pos, 'amp_' + 'B1-B2_' + '02' + add] - \ frame.loc[pos, 'amp_' + 'B1-B2_' + '01' + add] + frame.loc[ pos, 'amp_' + 'B1-B2_' + '0' + add]) / norm B1_minus_B2 = (frame.loc[pos, 'amp_' + 'B1-B2_' + '012' + add] - \ frame.loc[pos, 'amp_' + 'B1-B2_' + '02' + add] + frame.loc[ pos, 'amp_' + 'B1-B2_' + '012' + add] - \ frame.loc[pos, 'amp_' + 'B1-B2_' + '01' + add]) / (2 * norm) divs = [2, 1] divs_name = ['/2', '', ] for d, div in enumerate(divs): frame.loc[ pos, 'amp_' + 'B1-B2' + divs_name[d] + '_mean(012-0102)' + norms_name[n] + add] = B1_minus_B2 / div frame.loc[ pos, 'amp_' + 'B1-B2' + divs_name[d] + '_012-01-02' + norms_name[n] + add] = B1_minus_B2_0102 / div frame.loc[ pos, 'amp_' + 'B1-B2' + divs_name[d] + '_012-01-02+0' + norms_name[n] + add] = B1_minus_B2_01020 / div # B1+B2 B1_plus_B2_0102 = (frame.loc[pos, 'amp_' + 'B1+B2_' + '012' + add] - \ frame.loc[pos, 'amp_' + 'B1+B2_' + '01' + add] - \ frame.loc[pos, 'amp_' + 'B1+B2_' + '02' + add]) / norm B1_plus_B2_01020 = (frame.loc[pos, 'amp_' + 'B1+B2_' + '012' + add] - \ frame.loc[pos, 'amp_' + 'B1+B2_' + '01' + add] - \ frame.loc[pos, 'amp_' + 'B1+B2_' + '02' + add] + frame.loc[pos, 'amp_' + 'B1+B2_' + '0' + add]) / norm B1_plus_B2 = (frame.loc[pos, 'amp_' + 'B1+B2_' + '012' + add] - \ frame.loc[pos, 'amp_' + 'B1+B2_' + '01' + add] + frame.loc[pos, 'amp_' + 'B1+B2_' + '012' + add] - \ frame.loc[pos, 'amp_' + 
'B1+B2_' + '02' + add]) / (2 * norm) divs = [2, 1] divs_name = ['/2', ''] for d, div in enumerate(divs): frame.loc[pos, 'amp_' + 'B1+B2' + divs_name[d] + '_mean(012-0102)' + norms_name[n] + add] = B1_plus_B2 / div frame.loc[pos, 'amp_' + 'B1+B2' + divs_name[d] + '_012-01-02' + norms_name[n] + add] = B1_plus_B2_0102 / div frame.loc[ pos, 'amp_' + 'B1+B2' + divs_name[d] + '_012-01-02+0' + norms_name[n] + add] = B1_plus_B2_01020 / div # IMPORTANT # und hier kommt die Fortsetzung das ist das gleiche nur mit Mean B1_minus_B2_B1_plus_B2 = (frame.loc[pos, 'amp_' + 'B1-B2_' + '012' + add] - \ frame.loc[pos, 'amp_' + 'B1-B2_' + '02' + add] + frame.loc[ pos, 'amp_' + 'B1-B2_' + '012' + add] - \ frame.loc[pos, 'amp_' + 'B1-B2_' + '01' + add] + frame.loc[pos, 'amp_' + 'B1+B2_' + '012' + add] - \ frame.loc[pos, 'amp_' + 'B1+B2_' + '01' + add] + frame.loc[pos, 'amp_' + 'B1+B2_' + '012' + add] - \ frame.loc[pos, 'amp_' + 'B1+B2_' + '02' + add]) / (2 * norm) # IMPORTANT divs = [2, 1] divs_name = ['/2', ''] for d, div in enumerate(divs): frame.loc[pos, 'amp_' + 'B1-B2&B1+B2' + divs_name[d] + '_mean(012-0102_012-0102)' + norms_name[ n] + add] = B1_minus_B2_B1_plus_B2 / div divs = [4, 1] divs_name = ['/4', ''] for d, div in enumerate(divs): frame.loc[ pos, 'amp_' + 'B1&B2&B1-B2&B1+B2' + divs_name[d] + '_mean(012-0102_012-0102)' + norms_name[n] + add] = ( B1_minus_B2_B1_plus_B2 + B1_B2) / div # VERY IMPORTANT # B1&B2&B1-B2&B1+B2&Harm OHNE B1 & B2 divs = [8, 1] divs_name = ['/8', ''] for d, div in enumerate(divs): frame.loc[ pos, 'amp_' + 'B1Harm&B2Harm&B1-B2&B1+B2' + divs_name[d] + '_mean(012-0102_012-0102)' + norms_name[ n] + add] = (B1_minus_B2_B1_plus_B2 + B1_harm + B2_harm) / div # B1&B2&B1-B2&B1+B2&Harm MIT B1 & B2 divs = [10, 1] divs_name = ['/10', ''] frame = frame.copy() for d, div in enumerate(divs): frame.loc[pos, 'amp_' + 'B1&B1Harm&B2&B2Harm&B1-B2&B1+B2' + divs_name[d] + '_mean(012-0102_012-0102)' + norms_name[n] + add] = (B1_minus_B2_B1_plus_B2 + B1_harm + B2_harm + 
B1_B2) / div # VERY IMPORTANT (all with difference to two # B1&B2&B1-B2&B1+B2&Harm OHNE B1 & B2 divs = [8, 1] divs_name = ['/8', ''] for d, div in enumerate(divs): frame.loc[pos, 'amp_' + 'B1Harm&B2Harm&B1-B2&B1+B2' + divs_name[d] + '_012-01-02' + norms_name[n] + add] = ( B1_minus_B2_0102 + B1_plus_B2_0102 + B1_harm_0102 + B2_harm_0102) / div # B1&B2&B1-B2&B1+B2&Harm MIT B1 & B2 divs = [10, 1] divs_name = ['/10', ''] frame = frame.copy() for d, div in enumerate(divs): frame.loc[pos, 'amp_' + 'B1&B1Harm&B2&B2Harm&B1-B2&B1+B2' + divs_name[d] + '_012-01-02' + norms_name[n] + add] = ( B1_minus_B2_0102 + B1_plus_B2_0102 + B1_harm_0102 + B2_harm_0102 + B1_B2_0102) / div # VERY IMPORTANT (all with difference to two # B1&B2&B1-B2&B1+B2&Harm OHNE B1 & B2 divs = [8, 1] divs_name = ['/8', ''] for d, div in enumerate(divs): frame.loc[ pos, 'amp_' + 'B1Harm&B2Harm&B1-B2&B1+B2' + divs_name[d] + '_012-01-02+0' + norms_name[ n] + add] = (B1_minus_B2_01020 + B1_plus_B2_01020 + B1_harm_01020 + B2_harm_01020) / div # B1&B2&B1-B2&B1+B2&Harm MIT B1 & B2 divs = [10, 1] divs_name = ['/10', ''] frame = frame.copy() for d, div in enumerate(divs): frame.loc[pos, 'amp_' + 'B1&B1Harm&B2&B2Harm&B1-B2&B1+B2' + divs_name[d] + '_012-01-02+0' + norms_name[n] + add] = ( B1_minus_B2_01020 + B1_plus_B2_01020 + B1_harm_01020 + B2_harm_01020 + B1_01020 + B2_01020) / div test = False if test: from utils_test import test_calc_amps2 test_calc_amps2(frame, B1_B2) return frame def find_norm_amp_diff(norms, pos, frame, names, norms_name, add): frame = frame.copy() for nn, name in enumerate(names): frame = frame.copy() for n, norm in enumerate(norms): frame.loc[pos, 'amp_' + name + '012-01' + norms_name[n] + add] = (frame.loc[ pos, 'amp_' + name + '012' + add] - \ frame.loc[ pos, 'amp_' + name + '01' + add]) / norm frame.loc[pos, 'amp_' + name + '012-02' + norms_name[n] + add] = (frame.loc[ pos, 'amp_' + name + '012' + add] - \ frame.loc[ pos, 'amp_' + name + '02' + add]) / norm return frame def 
find_norms(min_amps, pos, frame, add): if 'min' in min_amps: if 'norm' in min_amps: norms_name = ['_norm_01B1', '_norm_02B2', '_norm_01B1+02B2', '_norm_eodf'] norms = [frame.loc[pos, 'amp_' + 'B1_' + '01' + add], frame.loc[pos, 'amp_' + 'B2_' + '02' + add], frame.loc[pos, 'amp_' + 'B1_' + '01' + add] + frame.loc[pos, 'amp_' + 'B2_' + '02' + add], np.mean( frame.loc[pos, 'amp_' + 'f0_' + '01' + add] + frame.loc[pos, 'amp_' + 'f0_' + '02' + add])] else: norms_name = [] norms = [] else: norms_name = ['_norm_01B1', '_norm_02B2', '_norm_01B1+02B2', '_norm_eodf'] norms = [frame.loc[pos, 'amp_' + 'B1_' + '01' + add], frame.loc[pos, 'amp_' + 'B2_' + '02' + add], frame.loc[pos, 'amp_' + 'B1_' + '01' + add] + frame.loc[pos, 'amp_' + 'B2_' + '02' + add], np.mean(frame.loc[pos, 'amp_' + 'f0_' + '01' + add] + frame.loc[pos, 'amp_' + 'f0_' + '02' + add])] return norms, norms_name def calc_var_psd_same_score(fs, p012, p01, p02, p0, pos, frame, add): # embed()power_distance_int_sqr # also den Wert könnten wir doch auch bei den Daten nehmen! 
# das ist der wert der uns vor allem interessiert p_val = np.sum(np.mean(np.array(p012) - np.array(p02) - np.array(p01) + np.array(p0), axis=0)) * fs[1] frame.loc[pos, 'power_distance_int' + add] = p_val if p_val > 0: frame.loc[pos, 'power_distance_int_sqrt' + add] = np.sqrt(p_val) else: frame.loc[pos, 'power_distance_int_sqrt' + add] = -np.sqrt(-p_val) ################################################## # das ist die eine ROC condition p_val = np.sum(np.mean(np.array(p012) - np.array(p02), axis=0)) * fs[1] frame.loc[pos, '012-02_power_distance_int' + add] = p_val if p_val > 0: frame.loc[pos, '012-02_power_distance_int_sqrt' + add] = np.sqrt(p_val) else: frame.loc[pos, '012-02_power_distance_int_sqrt' + add] = -np.sqrt(-p_val) ################################################### # das ist die andere ROC condition p_val = np.sum(np.mean(np.array(p01) - np.array(p0), axis=0)) * fs[1] frame.loc[pos, '01-0_power_distance_int' + add] = p_val if p_val > 0: frame.loc[pos, '01-0_power_distance_int_sqrt' + add] = np.sqrt(p_val) else: frame.loc[pos, '01-0_power_distance_int_sqrt' + add] = -np.sqrt(-p_val) return frame def find_norms_euc(min_amps, frame, pos, p012, p01, p02, p0, add): if 'min' in min_amps: if 'norm' in min_amps: norms_name = ['_norm_01B1', '_norm_02B2', '_norm_01B1+02B2', '_norm_eodf', '_norm_p012', '_norm_p01', '_norm_p02', '_norm_p0', ] norms = [frame.loc[pos, 'amp_' + 'B1_' + '01' + add], frame.loc[pos, 'amp_' + 'B2_' + '02' + add], frame.loc[pos, 'amp_' + 'B1_' + '01' + add] + frame.loc[pos, 'amp_' + 'B2_' + '02' + add], np.mean(frame.loc[pos, 'amp_' + 'f0_' + '01' + add] + frame.loc[pos, 'amp_' + 'f0_' + '02' + add]), p012, p01, p02, p0, ] else: norms_name = [] norms = [] else: norms_name = ['_norm_01B1', '_norm_02B2', '_norm_01B1+02B2', '_norm_eodf', '_norm_p012', '_norm_p01', '_norm_p02', '_norm_p0', ] norms = [frame.loc[pos, 'amp_' + 'B1_' + '01' + add], frame.loc[pos, 'amp_' + 'B2_' + '02' + add], frame.loc[pos, 'amp_' + 'B1_' + '01' + add] + 
frame.loc[pos, 'amp_' + 'B2_' + '02' + add], np.mean(frame.loc[pos, 'amp_' + 'f0_' + '01' + add] + frame.loc[pos, 'amp_' + 'f0_' + '02' + add]), p012, p01, p02, p0, ] return norms_name, norms def calc_euc_amp_norm(fs, diff_parts_names, add, norms_name, frame, pos, norms, diff_parts): for n, norm in enumerate(norms): for dd in range(len(diff_parts)): diffs = [] diffs_norm = [] for i in range(len(diff_parts[dd][0])): for j in range(len(diff_parts[dd][0])): diffs.append(diff_parts[dd][0][i] - diff_parts[dd][1][j]) # hier kommen die zusätzlichen norm sachen if ('B' in norms_name[n]) | ('eod' in norms_name[n]): diffs_norm.append(diff_parts[dd][0][i] / norms[n] - diff_parts[dd][1][j] / norms[n]) else: diffs_norm.append( diff_parts[dd][0][i] / np.sum(norms[n][i] * fs[1]) - diff_parts[dd][1][j] / np.sum( norms[n][i] * fs[1])) prev = np.mean(np.linalg.norm(diffs_norm, axis=1)) frame.loc[pos, 'euclidean_all_' + diff_parts_names[dd][0] + '-' + diff_parts_names[dd][1] + norms_name[ n] + add] = prev if ('B' in norms_name[n]) | ('eod' in norms_name[n]): frame.loc[pos, 'euclidean_' + diff_parts_names[dd][0] + '-' + diff_parts_names[dd][1] + norms_name[ n] + add] = np.linalg.norm( np.mean(np.array(diff_parts[dd][0]) / norms[n] - np.array(diff_parts[dd][1]) / norms[n], axis=0)) else: frame.loc[pos, 'euclidean_' + diff_parts_names[dd][0] + '-' + diff_parts_names[dd][1] + norms_name[ n] + add] = np.linalg.norm(np.mean( np.transpose(np.array(diff_parts[dd][0])) / np.sum(np.array(norms[n]) * fs[1], axis=1) - np.transpose( np.array(diff_parts[dd][1])) / np.sum(np.array(norms[n]) * fs[1], axis=1), axis=1)) frame.loc[pos, 'euclidean_all_' + 'mean(012-01_012-02)' + norms_name[n] + add] = np.mean( [frame.loc[pos, 'euclidean_all_' + '012' + '-' + '01' + norms_name[n] + add], frame.loc[pos, 'euclidean_all_' + '012' + '-' + '02' + norms_name[n] + add]]) try: frame.loc[pos, 'euclidean_' + 'mean(012-01_012-02)' + norms_name[n] + add] = np.mean( [frame.loc[pos, 'euclidean_' + '012' + '-' + '01' + 
norms_name[n] + add], frame.loc[pos, 'euclidean_' + '012' + '-' + '02' + norms_name[n] + add]]) except: print('problem euclidean') embed() frame.loc[pos, 'euclidean_all_' + 'mean(012-01_012-02)' + '_norm_p01p02' + add] = np.mean( [frame.loc[pos, 'euclidean_all_' + '012' + '-' + '01' + '_norm_p01' + add], frame.loc[pos, 'euclidean_all_' + '012' + '-' + '02' + '_norm_p02' + add]]) frame.loc[pos, 'euclidean_' + 'mean(012-01_012-02)' + '_norm_p01p02' + add] = np.mean( [frame.loc[pos, 'euclidean_' + '012' + '-' + '01' + '_norm_p01' + add], frame.loc[pos, 'euclidean_' + '012' + '-' + '02' + '_norm_p02' + add]]) return frame def calc_euc_amp(add, frame, diff_parts, pos, diff_parts_names): frame = frame.copy() for dd in range(len(diff_parts)): diffs = [] for i in range(len(diff_parts[dd][0])): for j in range(len(diff_parts[dd][0])): diffs.append(diff_parts[dd][0][i] - diff_parts[dd][1][j]) prev = np.mean(np.linalg.norm(diffs, axis=1)) frame.loc[pos, 'euclidean_all_' + diff_parts_names[dd][0] + '-' + diff_parts_names[dd][1] + add] = prev frame.loc[ pos, 'euclidean_' + diff_parts_names[dd][0] + '-' + diff_parts_names[dd][1] + add] = np.linalg.norm( np.mean(np.array(diff_parts[dd][0]) - np.array(diff_parts[dd][1]), axis=0)) frame.loc[pos, 'euclidean_all_' + 'mean(012-01_012-02)' + add] = np.mean( [frame.loc[pos, 'euclidean_all_' + '012' + '-' + '01' + add], frame.loc[pos, 'euclidean_all_' + '012' + '-' + '02' + add]]) frame.loc[pos, 'euclidean_' + 'mean(012-01_012-02)' + add] = np.mean( [frame.loc[pos, 'euclidean_' + '012' + '-' + '01' + add], frame.loc[pos, 'euclidean_' + '012' + '-' + '02' + add]]) return frame def calc_amps(fs, p0, p02, p01, p012, pos, devname, t, frame, results, timesstamp=False, add='', min_amps='', points=5, printing=False): fishes = ['012', '01', '02', '0'] ps = [p012, p01, p02, p0] test = False if test: plt_calc_amps(results, p0, p01, p02, p012, frame, fs) freq_step = np.abs(fs[1] - fs[0]) try: fr = results.fr.loc[pos] except: print('fr prob') embed() 
f2 = results.f2.loc[pos] f1 = results.f1.loc[pos] try: df1 = results.DeltaF1.loc[pos] df2 = results.DeltaF2.loc[pos] eod_fr = results.EODf.loc[pos] except: eod_fr = results.f0.loc[pos] df1 = results.df1.loc[pos] df2 = results.df2.loc[pos] try: beat1 = create_beat_corr(np.array([np.abs(df1)]), np.array([eod_fr])) except: print('beat 1 problem') embed() beat2 = create_beat_corr(np.array([np.abs(df2)]), np.array([eod_fr])) names = peaks_of_interest(df1, df2, beat1, beat2, fr, f1, f2, eod_fr, min_amps=min_amps) for name in names: frame.loc[pos, name[0:-1]] = names[name] names[''] = '' names['max_'] = '' names['max_harm_'] = '' # drei sachen # 1) erstmal nur die Veränderungen von B1 und B2 # 2) dann die Veränderung von B1,B2, # 3) dann veränderung von B1+B2 und B1-B2 und dann B1-B2, B1+B2 # 4) B1, B2, B1-B2, B1+B2 # 5) Euclidische Distanz # 6) Und noch Normierung (mit B1, B2, B1+B2) ## # VARIABLEN: 4: NORM, MEAN, VON WAS SIE DIFFERENZ, 1-6 t1 = time.time() # Nur einzelfrequenzen und deren richtigen Diffs if np.isnan(frame.loc[pos, 'f0']): print('isnan thing2') embed() frame = calc_pure_amps_diffs(frame, pos, names, fishes, freq_step, ps, fs, devname, t, add, points=points) time_first = time.time() - t1 # embed() # 1) erstmal nur die Veränderungen von einzelnen Frequenzen B1 und B2, B1+B2 und B1-B2 # hier habe ich die drei Normierungen, wir normieren immer auf B1, B2 oder beides # weil diese von der Beat Frequenz abhängen können normieren wir auch auf das EODf # das charachterisiert das Antwortverhalten der P-unit # auch diese normierungen die brauchen wir denke ich nicht # wenn alle haben will schriebe ich nix # wenn ich das absolute minimum haben will sollte min drin sein # wenn ich ein bisschen mehr haben will dann aber auch norm norms, norms_name = find_norms(min_amps, pos, frame, add) # für alle Werte auch nochmal die normierten Peaks # bei der reduzierten Version lassen wir das mit den norms, braucht je kein Mensch if norms: t1 = time.time() frame = 
find_norm_amp_diff(norms, pos, frame, names, norms_name, add) time_second = time.time() - t1 # 2) dann die Veränderung von B1 & B2, B1-B2 & B1+B2, if norms: t1 = time.time() frame = find_B1B2_norm_amp_diffs(frame, norms_name, norms, pos, add) time_third = time.time() - t1 ############### # DAS IST EIN WICHIGER SCORE # das gleiche wie die varianz # zweiter Score Dezember 2022 if (len(p02) > 0) & (len(p01) > 0): # todo für verschiedene Trials frame = calc_var_psd_same_score(fs, p012, p01, p02, p0, pos, frame, add) # embed() ############## # die restlichen (für talk in lissbo) # 5) Euclidische Distanz # np.linalg.norm(np.array(p012)-np.array(p01)) # die zwei sind das gleiche, also ob ich die direkt subtrahiere # alle gegen alle vergleich # np.linalg.norm(np.array(p012) - np.array(p01), axis=0)) # Trial für Trial Vergleich t1 = time.time() # 1) verschiedene Normierungen, 2) all vs not 3) was gegen was # ich glaube diese normierung über das spectrum machen wir nur damit das über die Zellen vergleich bar bleibt? norms_name, norms = find_norms_euc(min_amps, frame, pos, p012, p01, p02, p0, add) # fishes = ['norm_B2'] # norms = [frame.loc[pos, 'amp_' + 'B2_' + '02' + add]] diff_parts_names = [('012', '02'), ('012', '01')] diff_parts = [(p012, p02), (p012, p01)] if len(p02) > 0: frame = calc_euc_amp(add, frame, diff_parts, pos, diff_parts_names) if norms: frame = calc_euc_amp_norm(fs, diff_parts_names, add, norms_name, frame, pos, norms, diff_parts) time_forht = time.time() - t1 if printing: print(time_first) print(time_second) print(time_third) print(time_forht) # embed() # hier nehmen wir die Wurzel damit die Werte am Ende eben keine varianzen sondern std sind also in Hz! 
frame = sqrt_values(pos, frame, add) # .replace('_mean','') test = False if test: plt.plot(fs, p02[0]) plt.scatter(frame['f0'], frame['amp_f0_02_original']) plt.scatter(frame['f1'], frame['amp_f1_02_original']) plt.scatter(frame['f2'], frame['amp_f2_02_original']) plt.show() # embed() # names = 'amp_fr_012-02-01_mean' # hier nehme ich also die Sachen mit den AMPS und den Euclidischen Distanzen # und auch die Fläche und nehme nochmal die Wurzel draus! # wenn amp and euc drin ist dann nehme ich hier nochmal die Wurzel! # embed() # start_pos = np.where(np.array(keys) == 'amp_'+keys_names[0]+fishes[0]+add) if test: embed() return frame def sqrt_values(pos, frame, add=''): keys = [k for k in frame] for k in keys: if (('amp' in k) | ('euclidean' in k)) & (add in k): if frame.loc[pos, k] < 0: # für den Fall das das negativ ist machen wir das erst wieder positiv und dann wieder negativ # das gilt vor allem für die Differenz Werte frame.loc[pos, k] = -np.sqrt(-frame.loc[pos, k]) else: frame.loc[pos, k] = np.sqrt(frame.loc[pos, k]) return frame def calc_ps(nfft, array012, array01, array02, array0, sampling_rate=40000, log = '', xlim = []): p012, f = calc_ps_single(array012, nfft, sampling_rate, log = log, xlim = xlim) p01, f = calc_ps_single(array01, nfft, sampling_rate, log = log, xlim = xlim) p02, f = calc_ps_single(array02, nfft, sampling_rate, log = log, xlim = xlim) p0, f = calc_ps_single(array0, nfft, sampling_rate, log = log, xlim = xlim) return p0, p02, p01, p012, f def calc_ps_single(array012, nfft, sampling_rate, log = '', xlim = []): p012 = [[]] * len(array012) for i in range(len(array012)): p012[i], f = ml.psd(array012[i] - np.mean(array012[i]), Fs=sampling_rate, NFFT=nfft, noverlap=nfft // 2) if log == 'log': if len(xlim)>0: p012[i] = p012[i][f 0) and (len(base_2) > 0): results_diff.loc[position_diff, '012-0-1-2' + '_' + name_saved + '_' + title] = np.mean( contdition12 - base_0 - base_1 - base_2) if name_saved == 'var': # das ist der SCORE 1 FÜR DIE DIFFERENZEN 
# das ist nochmal var squared var_val = results_diff.loc[position_diff, 'diff' + '_' + name_saved + '_' + title] # titles_all[names[0]][t] if var_val > 0: # wenns positiv ist behalten wir das results_diff.loc[position_diff, 'diff' + '_' + 'var_sqrt' + '_' + title] = np.sqrt( results_diff.loc[position_diff, 'diff' + '_' + name_saved + '_' + title]) else: results_diff.loc[position_diff, 'diff' + '_' + 'var_sqrt' + '_' + title] = -np.sqrt( -results_diff.loc[position_diff, 'diff' + '_' + name_saved + '_' + title]) return results_diff def equal_to_temporal_mean(ffts_all): if np.shape(ffts_all) == 3: fft_val = np.abs(np.mean(ffts_all, axis=0)[3]) ** 2 - np.abs(np.mean(ffts_all, axis=0)[2]) ** 2 - np.abs( np.mean(ffts_all, axis=0)[1]) ** 2 + np.abs(np.mean(ffts_all, axis=0)[0]) ** 2 else: fft_val = np.abs(np.mean(ffts_all[cl_3names.c012], axis=0)) ** 2 - np.abs( np.mean(ffts_all[cl_3names.c01], axis=0)) ** 2 - np.abs( np.mean(ffts_all[cl_3names.c02], axis=0)) ** 2 + np.abs(np.mean(ffts_all[cl_3names.c0], axis=0)) ** 2 return fft_val class cl_3names: """A simple example class""" c012 = '012' c02 = '02' c01 = '01' c0 = '0' def calc_FFT3(arrays, deltat, fft, names): for a, array in enumerate(arrays): try: fft[names[a]] = np.fft.fft(array - np.mean(array), norm='forward') # /nfft # nas sollte forward sein except: fft[names[a]] = np.fft.fft(array - np.mean(array)) * deltat return fft def data_tuning(show=True): cells = ['2021-08-03-ac-invivo-1'] _, _ = find_all_threewave_versions() save_name_alls = [ 'calc_auc_three_AllTrialsIndexEodLocSynch_Min0.25sExcluded__multsorted2__psdEOD__minindices___nfft_32768three_AUCI_sqrt__points1.pkl'] plot_style() default_figsize(column=2, length=2) # ts=12, ls=12, fs=12 for save_name_all0 in save_name_alls: for c, cell in enumerate(cells): save_name_all = load_folder_name('threefish') + '/' + save_name_all0 name0 = save_name_all.split('_nfft')[0] + cell + '_nfft' + save_name_all.split('_nfft')[1] if '_dev' in save_name_all: name1 = 
save_name_all.split('_dev')[0] + cell + '_dev' + save_name_all.split('_dev')[1] else: name1 = 'xyo' if os.path.exists(name0): print(name0 + 'exists') name = name0 elif os.path.exists(name1): print(name1 + 'exists') name = name0 else: print('PROBLEM ' + str(save_name_all)) name = name0 if os.path.exists(name): frame_orig = pd.read_pickle(name) contrasts = [10] # frame_orig.c2.unique() for c, contrast2 in enumerate(contrasts): contrasts1 = [10] # frame_orig.c1.unique() for contrast1 in contrasts1: if len(frame_orig) > 0: frame = frame_orig[(frame_orig['cell'] == cell) & ( frame_orig['c2'] == contrast2) & ( frame_orig['c1'] == contrast1) & (frame_orig['dev'] == '05')] # print(np.mean(np.mean(frame.EODf.unique()))) labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept( nr=2) # also hier die Rows bestimmen gridspacing = [] # 0.02 grid0 = gridspec.GridSpec(1, 1, bottom=0.25, top=0.75, left=0.09, right=0.98, wspace=0.04) # axs = [] p_nrs = [7, 4, 2] # np.arange(0, len(pivot), 1)# grid1 = gridspec.GridSpecFromSubplotSpec(1, len(p_nrs), wspace=0.35, hspace=0.35, subplot_spec=grid0[0]) dfs = ['DF1', 'DF2'] pivot, _, indexes, resorted, orientation, cut_type = get_data_pivot_three(frame, scores[0], orientation=[], gridspacing=gridspacing, dfs=dfs, matrix_sorted='grid_sorted') if '2' in pivot.columns.name: scores = [scores[2], scores[3], scores[0], scores[1]] for s, score in enumerate(scores): pivot, _, indexes, resorted, orientation, cut_type = get_data_pivot_three(frame, score, orientation=[], gridspacing=gridspacing, dfs=dfs, matrix_sorted='grid_sorted') print('min f0 ' + str(np.min(frame.f0))) print('min f1 ' + str(np.min(frame.f1))) print('max f1 ' + str(np.max(frame.f1))) print('min f2 ' + str(np.min(frame.f2))) print('max f2 ' + str(np.max(frame.f2))) if len(pivot) > 0: if s == 0: _, _ = find_row_col(pivot) for pp, p in enumerate(p_nrs): # range(len(pivot)): ax = plt.subplot(grid1[pp]) if 'm' in 
dfs[0]: try: ax.set_title(pivot.index.name + ' ' + str(pivot.index[p])) except: print('ax something') embed() else: if s == 0: ax.text(1, 1.05, '$\Delta f_{' + stable_val() + '}=%s$' % ( int(pivot.index[p])) + '\,Hz', ha='right', transform=ax.transAxes) ax.plot(pivot.columns, pivot.iloc[p], color=colors[s], label=labels[s], linestyle=linestyles[s], linewidth=linewidths[s]) ax.set_xlabel(xlabel_vary()) # pivot.columns.name if pp != 0: remove_yticks(ax) else: ax.set_ylabel(representation_ylabel()) axs.append(ax) join_y(axs) fig = plt.gcf() fig.tag(axs[0:3], xoffs=-3, yoffs=1) if len(pivot) > 0: axs[0].legend(loc=(0, 1.2), ncol=2) individual_tag = save_name_all0 + '_' + cell + '_c1_' + str(contrast1) + '_c2_' + str( contrast2) + '_gridpsacing_' + str(gridspacing) save_visualization(individual_tag, show=show) def xlabel_vary(): return '$\Delta f_{' + vary_val() + '}$\,[Hz]' def tuning_f(freqs=[(39.5, -135.5)], cells_here='2011-10-25-ad-invivo-1'): plot_style() model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv") if len(cells_here) < 1: cells_here = np.array(model_cells.cell) trials_nrs = [1] plot_style() default_figsize(column=2, length=5.35) # 5.5)#7.5 5.75 default_figsize(column=2, length=7.5) default_figsize(width=cm_to_inch(33.6), length=cm_to_inch(17.2)) default_ticks_talks() for _ in trials_nrs: # +[trials_nrs[-1]] scatter_extra = False for cell_here in cells_here: full_names = [ 'calc_model_amp_freqs-F1_750-975-75_F2_500-725-75_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_1.0_mult__start_0.0001_end_1_StimLen_25_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal'] # full_names = ['calc_model_amp_freqs-F1_0.5-1.45-0.05_F2_0.5-1.45-0.05_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_1.0_mult__start_0.0001_end_1_StimLen_25_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal'] # full_names = 
['calc_model_amp_freqs-F1_250-1325-25_F2_720_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_2.0_mult__start_0.0001_end_2_StimLen_5_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal'] # full_names = ['calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_25_FirstC1_0.0001_LastC1_1.0_StimLen_25_nfft_32768_trialsnr_1_absolut_power_1_minamps__dev_originalAUCItemporal'] # 'calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2Len_25_FirstC2_0.0001_LastC2_1.0_C1_0.1_StimLen_2_nfft_32768_trialsnr_1_absolut_power_1_minamps__dev_05temporal'] c_grouped = ['c1'] # , 'c2'] frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv') frame_cell_orig = frame[(frame.cell == cell_here)] if len(frame_cell_orig) > 0: try: pass except: print('min thing') embed() # (135.5, 625.0), (110.5, 650.0), (85.5, 675.0),(60.5, 700.0), (35.5, 725.0), (10.5, 750.0),(151.07000000000005, 675.0) new_f2_tuple = frame_cell_orig[['df2', 'f2']].apply(tuple, 1).unique() dfs = [tup[0] for tup in new_f2_tuple] sorted = np.argsort(np.abs(dfs)) grid0 = gridspec.GridSpec(1, len(freqs), bottom=0.13, top=0.85, left=0.1, right=0.975, wspace=0.15) # top=0.895 ################################################### squares = False if squares: full_names_square = [ 'calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_25_FirstC1_0.0001_LastC1_1.0_StimLen_2_nfft_32768_trialsnr_1_absolut_power_1_minamps__dev_05temporal', ] frame_square = pd.read_csv( load_folder_name('calc_cocktailparty') + '/' + full_names_square[0] + '.csv') frame_cell_square = frame_square[(frame_square.cell == cell_here)] axes = [] axes.append(plt.subplot(grid_s[0])) axes.append(plt.subplot(grid_s[1])) axes.append(plt.subplot(grid_s[2])) frame_cell_square = single_frame_processing(c_grouped, frame_cell_square) lim, matrix, ss, ims = plt_matrix_saturation_loss(axes, frame_cell_square, add='_05') plt_cross(matrix, axes[-1]) ################################################################# # 
calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_20_FirstC1_0.0001_LastC1_1.0_StimLen_25_nfft_32768_trialsnr_1_absolut_power_1temporal.csv # devs_extra = ['stim','stim_rec','stim_am','original','05']#['original','05'] show = True # da implementiere ich das jetzt für eine Zelle # wo wir den einezlnen Punkt und Kontraste variieren ax_upper = [] frame_cell_orig, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig) f = -1 #################################################### # hier kommt die amplituden tuning curve # frame_cell = single_frame_processing(c_grouped, frame_cell_orig) c_heres = [0.1, 0.25] # 0.03, c_colors = ['dimgrey', 'darkgrey'] # ,'black', ],'silver' freq1s = np.unique(frame_cell_orig.df1) freq2s = np.unique(frame_cell_orig.df2) # np.argmin(frame_cell['amp_B1_01_mean_original'] - frame['amp_B1_012_mean_original'])].c1 f_counter = 0 ax_uss = [] letters_all = [['$\mathrm{A_{ii}}$', '$\mathrm{A_{iii}}$'], ['$\mathrm{B_{ii}}$', '$\mathrm{B_{iii}}$']] letters_all2 = ['$\mathrm{A_{i}}$', '$\mathrm{B_{i}}$'] for freq1, freq2 in freqs: grid00 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.15, hspace=0.73, subplot_spec=grid0[f_counter], height_ratios=[1.4, 2.35]) # hspace=0.35 1, 2.55 grid_u = gridspec.GridSpecFromSubplotSpec(1, 1, hspace=0.7, wspace=0.25, subplot_spec=grid00[ 0]) # hspace=0.4,wspace=0.2,len(chirps) grid_r = gridspec.GridSpecFromSubplotSpec(1, 2, hspace=0.3, wspace=0.1, subplot_spec=grid00[1]) ################################################ grid_s = gridspec.GridSpecFromSubplotSpec(1, 3, hspace=0.7, wspace=0.45, subplot_spec=grid00[-1]) freq1_here = freq1s[np.argmin(np.abs(freq1s - freq1))] freq2_here = freq2s[np.argmin(np.abs(freq2s - freq2))] f += 1 print(cell_here + ' F1' + str(freq1_here) + ' F2 ' + str(freq2_here)) ax_u1_upper = plt.subplot(grid_u[0]) c_dist_recalc = dist_recalc_phaselockingchapter() ax_upper = plt_single_trace(ax_upper, ax_u1_upper, frame_cell_orig, freq1_here, freq2_here, sum=False, nr=2, 
c_dist_recalc=c_dist_recalc, linestyles=['-', '--', '-', '--', '-']) ax_u1_upper.set_yticks_delta(100) # set_xticks_delta ax_u1_upper.set_xlim(0, 35) c_nrs_here_cm = c_dist_recalc_func(frame_cell, c_nrs=c_heres, cell=cell_here, c_dist_recalc=c_dist_recalc) height = 355 # 0 letter_plus = 30 if not c_dist_recalc: c_nrs_here_cm = np.array(c_nrs_here_cm) * 100 try: ax_u1_upper.scatter(c_nrs_here_cm, height * np.ones(len(c_nrs_here_cm)), color=c_colors, marker='v', clip_on=False, s=7) except: print('embed something') embed() for cn, cnr in enumerate(c_nrs_here_cm): ax_u1_upper.text(cnr, height + letter_plus, letters_all[f_counter][cn], ha='center', va='center', color=c_colors[cn]) ax_u1_upper.plot([cnr, cnr], [0, height], color=c_colors[cn], linewidth=lw_tuning(), zorder=100) labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept( add='_mean_original', nr=2) color_first = 'black' # 'red' # color01 color_second = 'black' # color02 color_third = 'black' ax_u1_upper.text(0, 1.2, ' $\Delta f_{' + vary_val() + '}=%s$' % freq1_here + ' Hz' + '\n $\Delta f_{' + stable_val() + '}=%s$' % ( freq2_here) + '\,Hz, ' + '$c_{' + stable_val() + '}=10\,\%$', transform=ax_u1_upper.transAxes) # transform if f_counter != 0: ax_u1_upper.set_ylabel('') remove_yticks(ax_u1_upper) frame_cell_chosen = frame_cell_orig[ (frame_cell_orig.df1 == freq1_here) & (frame_cell_orig.df2 == freq2_here)] print('Tuning curve needed for F1' + str(frame_cell_chosen.f1.unique()) + ' F2' + str( frame_cell_chosen.f2.unique()) + ' for cell ' + str(cell_here)) ################################################## # hier kommt das mit der tuning kurve freq2_here_abs = str(int(frame_cell_chosen.f2.unique())) length = '2' nfft = '4096' full_names_tunings = [ 'calc_model_amp_freqs-F1_500-1495-5_F2_' + freq2_here_abs + '_C2_0.1_C1_0.1_StimLen_' + length + '_nfft_' + nfft + '_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal', 
'calc_model_amp_freqs-F1_500-1495-5_F2_' + freq2_here_abs + '_C2_0.1_C1_0.25_StimLen_' + length + '_nfft_' + nfft + '_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal', ] # 'calc_model_amp_freqs-F1_500-1495-5_F2_' + freq2_here_abs + '_C2_0.1_C1_0.5_StimLen_' + length + '_nfft_' + nfft + '_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal', # 'calc_model_amp_freqs-F1_500-1495-5_F2_725_C2_0.1_C1_0.5_StimLen_2_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal' #_burst_added1_ ax_us = [] for ft_nr, full_names_tuning in enumerate(full_names_tunings): if os.path.exists(load_folder_name('calc_cocktailparty') + '/' + full_names_tuning + '.csv'): frame_tuning = pd.read_csv( load_folder_name('calc_cocktailparty') + '/' + full_names_tuning + '.csv') print(full_names_tuning) frame_cell_orig_tuning = frame_tuning[(frame_tuning.cell == cell_here)] try: frame_cell_orig_tuning = single_frame_processing(c_grouped, frame_cell_orig_tuning) except: print('something') embed() f_fixes = ['df2'] # 'df1', f_variables = ['df1'] # 'df2', freqs_fixed = [freq2_here] # freq1_here, for f_nr, f_fixed in enumerate(freqs_fixed): indexes = [[0, 1, 2, 3]]#[0, 2], for i, idx in enumerate(indexes): grid_rr = gridspec.GridSpecFromSubplotSpec(1, 1, hspace=0.15, wspace=0.15, subplot_spec=grid_r[ft_nr]) plt_tuning_twobeat(idx, ax_u1_upper, ax_us, c_colors, c_heres, cell_here, color_first, color_second, color_third, f_counter, f_fixed, f_fixes, f_nr, f_variables, frame_cell_orig_tuning, freq1_here, freq2_here, ft_nr, grid_rr, height, i, letter_plus, letters_all2, scatter_extra, scores, xlabel_pos = 0) f_counter += 1 if len(ax_us) > 0: join_x(ax_us) join_y(ax_us) ax_uss.append(ax_us) ######################################################### if squares: set_clim_same(ims, clims='all', same='same') join_y(ax_upper) join_x(ax_upper) join_y(ax_upper) yoffs = np.array([4, 3.5, 3.5, 3.5]) x = -3.5 # ax_uss[0][1], #embed() #tag2(plt.gcf(), [[ax_upper[0], 
ax_uss[0][0], ax_uss[0][1]]], xoffs=np.array([x, x, x, x]), # yoffs=yoffs) # ax_uss[0][1] #tag2(plt.gcf(), [[ax_upper[4], ax_uss[1][0], ax_uss[1][1]]], xoffs=np.array([x, x, x, x]), # yoffs=yoffs) # , ,ax_uss[1][1]] save_visualization(cell_here, show) def vary_contrasts_big_with_tuning3_several0(freqs=[(39.5, -135.5)], cells_here='2011-10-25-ad-invivo-1'): plot_style() model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv") if len(cells_here) < 1: cells_here = np.array(model_cells.cell) trials_nrs = [1] plot_style() default_figsize(column=2, length=5.35) # 5.5)#7.5 5.75 for _ in trials_nrs: # +[trials_nrs[-1]] scatter_extra = False for cell_here in cells_here: full_names = [ 'calc_model_amp_freqs-F1_750-975-75_F2_500-725-75_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_1.0_mult__start_0.0001_end_1_StimLen_25_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal'] # full_names = ['calc_model_amp_freqs-F1_0.5-1.45-0.05_F2_0.5-1.45-0.05_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_1.0_mult__start_0.0001_end_1_StimLen_25_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal'] # full_names = ['calc_model_amp_freqs-F1_250-1325-25_F2_720_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_2.0_mult__start_0.0001_end_2_StimLen_5_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal'] # full_names = ['calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_25_FirstC1_0.0001_LastC1_1.0_StimLen_25_nfft_32768_trialsnr_1_absolut_power_1_minamps__dev_originalAUCItemporal'] # 'calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2Len_25_FirstC2_0.0001_LastC2_1.0_C1_0.1_StimLen_2_nfft_32768_trialsnr_1_absolut_power_1_minamps__dev_05temporal'] c_grouped = ['c1'] # , 'c2'] frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv') frame_cell_orig = frame[(frame.cell == cell_here)] if len(frame_cell_orig) > 0: try: pass except: print('min thing') embed() # (135.5, 625.0), 
(110.5, 650.0), (85.5, 675.0),(60.5, 700.0), (35.5, 725.0), (10.5, 750.0),(151.07000000000005, 675.0) new_f2_tuple = frame_cell_orig[['df2', 'f2']].apply(tuple, 1).unique() dfs = [tup[0] for tup in new_f2_tuple] sorted = np.argsort(np.abs(dfs)) grid0 = gridspec.GridSpec(1, len(freqs), bottom=0.095, top=0.95, left=0.1, right=0.975, wspace=0.15) # top=0.895 ################################################### squares = False if squares: full_names_square = [ 'calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_25_FirstC1_0.0001_LastC1_1.0_StimLen_2_nfft_32768_trialsnr_1_absolut_power_1_minamps__dev_05temporal', ] frame_square = pd.read_csv( load_folder_name('calc_cocktailparty') + '/' + full_names_square[0] + '.csv') frame_cell_square = frame_square[(frame_square.cell == cell_here)] axes = [] axes.append(plt.subplot(grid_s[0])) axes.append(plt.subplot(grid_s[1])) axes.append(plt.subplot(grid_s[2])) frame_cell_square = single_frame_processing(c_grouped, frame_cell_square) lim, matrix, ss, ims = plt_matrix_saturation_loss(axes, frame_cell_square, add='_05') plt_cross(matrix, axes[-1]) ################################################################# # calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_20_FirstC1_0.0001_LastC1_1.0_StimLen_25_nfft_32768_trialsnr_1_absolut_power_1temporal.csv # devs_extra = ['stim','stim_rec','stim_am','original','05']#['original','05'] show = True # da implementiere ich das jetzt für eine Zelle # wo wir den einezlnen Punkt und Kontraste variieren ax_upper = [] frame_cell_orig, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig) f = -1 #################################################### # hier kommt die amplituden tuning curve # frame_cell = single_frame_processing(c_grouped, frame_cell_orig) c_heres = [0.1, 0.25] # 0.03, c_colors = ['dimgrey', 'darkgrey'] # ,'black', ],'silver' freq1s = np.unique(frame_cell_orig.df1) freq2s = np.unique(frame_cell_orig.df2) # np.argmin(frame_cell['amp_B1_01_mean_original'] - 
frame['amp_B1_012_mean_original'])].c1 f_counter = 0 ax_uss = [] letters_all = [['$\mathrm{A_{ii}}$', '$\mathrm{A_{iii}}$'], ['$\mathrm{B_{ii}}$', '$\mathrm{B_{iii}}$']] letters_all2 = ['$\mathrm{A_{i}}$', '$\mathrm{B_{i}}$'] for freq1, freq2 in freqs: grid00 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.15, hspace=0.53, subplot_spec=grid0[f_counter], height_ratios=[1, 2.35]) # hspace=0.35 1, 2.55 grid_u = gridspec.GridSpecFromSubplotSpec(1, 1, hspace=0.7, wspace=0.25, subplot_spec=grid00[ 0]) # hspace=0.4,wspace=0.2,len(chirps) grid_r = gridspec.GridSpecFromSubplotSpec(1, 2, hspace=0.3, wspace=0.1, subplot_spec=grid00[1]) ################################################ grid_s = gridspec.GridSpecFromSubplotSpec(1, 3, hspace=0.7, wspace=0.45, subplot_spec=grid00[-1]) freq1_here = freq1s[np.argmin(np.abs(freq1s - freq1))] freq2_here = freq2s[np.argmin(np.abs(freq2s - freq2))] f += 1 print(cell_here + ' F1' + str(freq1_here) + ' F2 ' + str(freq2_here)) ax_u1_upper = plt.subplot(grid_u[0]) c_dist_recalc = dist_recalc_phaselockingchapter() ax_upper = plt_single_trace(ax_upper, ax_u1_upper, frame_cell_orig, freq1_here, freq2_here, sum=False, nr=2, c_dist_recalc=c_dist_recalc, linestyles=['-', '--', '-', '--', '-']) ax_u1_upper.set_yticks_delta(100) # set_xticks_delta ax_u1_upper.set_xlim(0, 35) c_nrs_here_cm = c_dist_recalc_func(frame_cell, c_nrs=c_heres, cell=cell_here, c_dist_recalc=c_dist_recalc) height = 355 # 0 letter_plus = 30 if not c_dist_recalc: c_nrs_here_cm = np.array(c_nrs_here_cm) * 100 try: ax_u1_upper.scatter(c_nrs_here_cm, height * np.ones(len(c_nrs_here_cm)), color=c_colors, marker='v', clip_on=False, s=7) except: print('embed something') embed() for cn, cnr in enumerate(c_nrs_here_cm): ax_u1_upper.text(cnr, height + letter_plus, letters_all[f_counter][cn], ha='center', va='center', color=c_colors[cn]) ax_u1_upper.plot([cnr, cnr], [0, height], color=c_colors[cn], linewidth=lw_tuning(), zorder=100) labels, alpha, color01, color01_012, color02, 
color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept( add='_mean_original', nr=2) color_first = 'black' # 'red' # color01 color_second = 'black' # color02 color_third = 'black' ax_u1_upper.text(0, 1, ' $\Delta f_{' + vary_val() + '}=%s$' % freq1_here + ' Hz' + '\n $\Delta f_{' + stable_val() + '}=%s$' % ( freq2_here) + '\,Hz, ' + '$c_{' + stable_val() + '}=10\,\%$', transform=ax_u1_upper.transAxes) # transform if f_counter != 0: ax_u1_upper.set_ylabel('') remove_yticks(ax_u1_upper) frame_cell_chosen = frame_cell_orig[ (frame_cell_orig.df1 == freq1_here) & (frame_cell_orig.df2 == freq2_here)] print('Tuning curve needed for F1' + str(frame_cell_chosen.f1.unique()) + ' F2' + str( frame_cell_chosen.f2.unique()) + ' for cell ' + str(cell_here)) ################################################## # hier kommt das mit der tuning kurve freq2_here_abs = str(int(frame_cell_chosen.f2.unique())) length = '2' nfft = '4096' full_names_tunings = [ 'calc_model_amp_freqs-F1_500-1495-5_F2_' + freq2_here_abs + '_C2_0.1_C1_0.1_StimLen_' + length + '_nfft_' + nfft + '_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal', 'calc_model_amp_freqs-F1_500-1495-5_F2_' + freq2_here_abs + '_C2_0.1_C1_0.25_StimLen_' + length + '_nfft_' + nfft + '_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal', ] # 'calc_model_amp_freqs-F1_500-1495-5_F2_' + freq2_here_abs + '_C2_0.1_C1_0.5_StimLen_' + length + '_nfft_' + nfft + '_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal', # 'calc_model_amp_freqs-F1_500-1495-5_F2_725_C2_0.1_C1_0.5_StimLen_2_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal' #_burst_added1_ ax_us = [] for ft_nr, full_names_tuning in enumerate(full_names_tunings): if os.path.exists(load_folder_name('calc_cocktailparty') + '/' + full_names_tuning + '.csv'): frame_tuning = pd.read_csv( load_folder_name('calc_cocktailparty') + '/' + full_names_tuning + '.csv') print(full_names_tuning) 
frame_cell_orig_tuning = frame_tuning[(frame_tuning.cell == cell_here)] try: frame_cell_orig_tuning = single_frame_processing(c_grouped, frame_cell_orig_tuning) except: print('something') embed() f_fixes = ['df2'] # 'df1', f_variables = ['df1'] # 'df2', freqs_fixed = [freq2_here] # freq1_here, for f_nr, f_fixed in enumerate(freqs_fixed): indexes = [[0, 2], [0, 1, 2, 3]] for i, idx in enumerate(indexes): grid_rr = gridspec.GridSpecFromSubplotSpec(2, 1, hspace=0.15, wspace=0.15, subplot_spec=grid_r[ft_nr]) plt_tuning_twobeat(idx, ax_u1_upper, ax_us, c_colors, c_heres, cell_here, color_first, color_second, color_third, f_counter, f_fixed, f_fixes, f_nr, f_variables, frame_cell_orig_tuning, freq1_here, freq2_here, ft_nr, grid_rr, height, i, letter_plus, letters_all2, scatter_extra, scores) f_counter += 1 if len(ax_us) > 0: join_x(ax_us) join_y(ax_us) ax_uss.append(ax_us) ######################################################### if squares: set_clim_same(ims, clims='all', same='same') join_y(ax_upper) join_x(ax_upper) join_y(ax_upper) yoffs = np.array([3, 2.5, 2.5, 2.5]) x = -3.5 # ax_uss[0][1], tag2(plt.gcf(), [[ax_upper[0], ax_uss[0][0], ax_uss[0][2]]], xoffs=np.array([x, x, x, x]), yoffs=yoffs) # ax_uss[0][1] tag2(plt.gcf(), [[ax_upper[4], ax_uss[1][0], ax_uss[1][2]]], xoffs=np.array([x, x, x, x]), yoffs=yoffs) # , ,ax_uss[1][1]] save_visualization(cell_here, show) def plt_tuning_twobeat(idx, ax_u1_upper, ax_us, c_colors, c_heres, cell_here, color_first, color_second, color_third, f_counter, f_fixed, f_fixes, f_nr, f_variables, frame_cell_orig_tuning, freq1_here, freq2_here, ft_nr, grid_rr, height, i, letter_plus, letters_all2, scatter_extra, scores, xlabel_pos = 1): ax_u1 = plt.subplot(grid_rr[i]) frame_f = plt_tuning_curve(c_heres[ft_nr], ax_u1, frame_cell_orig_tuning, cell_here, f_fixed, f_fixed, f_fixed=f_fixes[f_nr], index=idx, f_variable=f_variables[f_nr]) if (i == 1) & (ft_nr == 0) & (f_counter == 0): ax_u1.legend(loc=(0, 2.325), ncol=2) # .legend() 
ax_u1.set_title('') ax_u1.set_yticks_delta(100) # if i == 0: # ax_u1.text(0, 1.2, 'one-beat conditions') # else: if (i == 0): # (ft_nr == 0) &(f_counter == 0) & ax_u1.text(0, 1.1, '$c_{1}=%s$' % (str(int(c_heres[ft_nr] * 100))) + '$\%$', color=c_colors[ft_nr], ha='left', va='top', transform=ax_u1.transAxes) df_extra = False if df_extra: if ft_nr == 0: ax_u1.text(0, 1.15, ' $\Delta f_{' + stable_val() + '}=%s$' % ( freq2_here) + '\,Hz ' + '$ ' + c_stable_name() + '=10 \%$', color=color_third, ha='left', va='top', transform=ax_u1.transAxes) # c_colors[ft_nr]% # + ax_u1.set_xlabel(f_variables[f_nr]) ax_u1.set_xlim(-265, 265) ax_u1.set_ylim(0, 420) # embed() frame_f = frame_f_reference(c_heres[ft_nr], cell_here, f_fixes[f_nr], frame_cell_orig_tuning, f_fixed) s_big = 25 s_small = 20 # s_big = 25 # s_small = 20 if scatter_extra: ax_u1.scatter(freq1_here, frame_f[(frame_f['df2'] == freq2_here) & ( frame_f['df1'] == freq1_here)][ scores[2]], edgecolor=color_second, facecolor='white', s=s_big, alpha=0.5, marker='o', clip_on=False, zorder=100) ax_u1.scatter(freq1_here, frame_f[frame_f['df1'] == freq1_here][scores[0]], edgecolor=color_first, marker='o', zorder=120, facecolor='white', s=s_small, alpha=0.5, clip_on=False) if ft_nr == 0: ax_u1.scatter(freq2_here, frame_f[ (frame_f['df2'] == freq2_here) & (frame_f['df1'] == freq1_here)][ scores[2]], edgecolor=color_third, facecolor='white', alpha=0.5, marker='o', clip_on=False, zorder=120, s=s_small) #################### # scatter to the upper one frame_f = frame_f_reference(c_heres[ft_nr], cell_here, f_fixes[f_nr], frame_cell_orig_tuning, f_fixed) if scatter_extra: ax_u1_upper.scatter(c_heres[ft_nr] * 100, frame_f[(frame_f['df2'] == freq2_here) & ( frame_f['df1'] == freq1_here)][ scores[2]], edgecolor=color_second, facecolor='white', s=s_big, alpha=0.5, marker='o', clip_on=False, zorder=100) ax_u1_upper.scatter(c_heres[ft_nr] * 100, frame_f[frame_f['df1'] == freq1_here][scores[0]], edgecolor=color_first, marker='o', zorder=120, 
facecolor='white', s=s_small, alpha=0.5, clip_on=False) ############################# add = -5 # if f_counter == 0: # if i == 1: ax_u1.scatter(freq1_here + add, height, color=c_colors[ft_nr], marker='v', clip_on=False, s=7) ax_u1.plot([freq1_here + add, freq1_here + add], [0, height], color=c_colors[ft_nr], linewidth=lw_tuning(), zorder=100) ax_u1.text(freq1_here + add, height + letter_plus, letters_all2[f_counter], ha='center', color=c_colors[ft_nr], va='center') # ax_u1.scatter(freq1_here, [0], color=c_colors[ft_nr], # marker='^', # clip_on=False, s=5) # ft_nr if (f_counter == 0) & (ft_nr == 0): # f_counter == 0:f_counter ax_u1.set_ylabel(representation_ylabel()) else: ax_u1.set_ylabel('') remove_yticks(ax_u1) if i in [xlabel_pos]: # f_counter ax_u1.set_xlabel(xlabel_vary()) # ax_upper.set_xlabel(xlabel_vary()) else: ax_u1.set_xlabel('') remove_xticks(ax_u1) ax_us.append(ax_u1) def lw_tuning(): return 0.55 def vary_contrasts_big_with_tuning3_several(freqs=[(39.5, -135.5)], cells_here='2011-10-25-ad-invivo-1'): default_settings() # ts=13, ls=13, fs=13, lw = 0.7 plot_style() model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv") if len(cells_here) < 1: cells_here = np.array(model_cells.cell) trials_nrs = [1] plot_style() default_settings(column=2, length=7.5) for trials_nr in trials_nrs: # +[trials_nrs[-1]] # sachen die ich variieren will ########################################### for cell_here in cells_here: full_names = [ 'calc_model_amp_freqs-F1_750-975-75_F2_500-725-75_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_1.0_mult__start_0.0001_end_1_StimLen_25_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal'] # full_names = ['calc_model_amp_freqs-F1_0.5-1.45-0.05_F2_0.5-1.45-0.05_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_1.0_mult__start_0.0001_end_1_StimLen_25_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal'] # full_names = 
['calc_model_amp_freqs-F1_250-1325-25_F2_720_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_2.0_mult__start_0.0001_end_2_StimLen_5_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal'] # full_names = ['calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_25_FirstC1_0.0001_LastC1_1.0_StimLen_25_nfft_32768_trialsnr_1_absolut_power_1_minamps__dev_originalAUCItemporal'] # 'calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2Len_25_FirstC2_0.0001_LastC2_1.0_C1_0.1_StimLen_2_nfft_32768_trialsnr_1_absolut_power_1_minamps__dev_05temporal'] c_grouped = ['c1'] # , 'c2'] frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv') frame_cell_orig = frame[(frame.cell == cell_here)] if len(frame_cell_orig) > 0: try: pass except: print('min thing') embed() new_f2_tuple = frame_cell_orig[['df2', 'f2']].apply(tuple, 1).unique() dfs = [tup[0] for tup in new_f2_tuple] sorted = np.argsort(np.abs(dfs)) grid0 = gridspec.GridSpec(1, len(freqs), bottom=0.1, top=0.87, left=0.09, right=0.95, wspace=0.3) # ################################################### # squares squares = False if squares: full_names_square = [ 'calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_25_FirstC1_0.0001_LastC1_1.0_StimLen_2_nfft_32768_trialsnr_1_absolut_power_1_minamps__dev_05temporal', ] frame_square = pd.read_csv( load_folder_name('calc_cocktailparty') + '/' + full_names_square[0] + '.csv') frame_cell_square = frame_square[(frame_square.cell == cell_here)] axes = [] axes.append(plt.subplot(grid_s[0])) axes.append(plt.subplot(grid_s[1])) axes.append(plt.subplot(grid_s[2])) frame_cell_square = single_frame_processing(c_grouped, frame_cell_square) lim, matrix, ss, ims = plt_matrix_saturation_loss(axes, frame_cell_square, add='_05') plt_cross(matrix, axes[-1]) ################################################################# show = True # da implementiere ich das jetzt für eine Zelle # wo wir den einezlnen Punkt und Kontraste variieren ax_upper = 
[] frame_cell_orig, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig) f = -1 #################################################### # hier kommt die amplituden tuning curve # frame_cell = single_frame_processing(c_grouped, frame_cell_orig) c_heres = [0.1, 0.25, 0.5] # 0.03, c_colors = ['black', 'darkgrey', 'silver'] freq1s = np.unique(frame_cell_orig.df1) freq2s = np.unique(frame_cell_orig.df2) f_counter = 0 ax_uss = [] for freq1, freq2 in freqs: grid00 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.35, hspace=0.35, subplot_spec=grid0[f_counter], height_ratios=[1, 3.55]) # grid_u = gridspec.GridSpecFromSubplotSpec(1, 1, hspace=0.7, wspace=0.25, subplot_spec=grid00[ 0]) # hspace=0.4,wspace=0.2,len(chirps) grid_r = gridspec.GridSpecFromSubplotSpec(3, 1, hspace=0.3, wspace=0.25, subplot_spec=grid00[1]) ################################################ grid_s = gridspec.GridSpecFromSubplotSpec(1, 3, hspace=0.7, wspace=0.45, subplot_spec=grid00[-1]) freq1_here = freq1s[np.argmin(np.abs(freq1s - freq1))] freq2_here = freq2s[np.argmin(np.abs(freq2s - freq2))] f += 1 print(cell_here + ' F1' + str(freq1_here) + ' F2 ' + str(freq2_here)) ax_u1_upper = plt.subplot(grid_u[0]) c_dist_recalc = dist_recalc_phaselockingchapter() ax_upper = plt_single_trace(ax_upper, ax_u1_upper, frame_cell_orig, freq1_here, freq2_here, sum=False, c_dist_recalc=c_dist_recalc, linestyles=['-', '--', '-', '--', '-']) c_nrs_here_cm = c_dist_recalc_func(frame_cell, c_nrs=c_heres, cell=cell_here, c_dist_recalc=c_dist_recalc) lw = 0.75 if not c_dist_recalc: c_nrs_here_cm = np.array(c_nrs_here_cm) * 100 try: ax_u1_upper.scatter(c_nrs_here_cm, np.zeros(len(c_nrs_here_cm)), color=c_colors, marker='^', clip_on=False, s=5) except: print('embed something') embed() for m in range(len(c_nrs_here_cm)): ax_u1_upper.axvline(c_nrs_here_cm[m], color=c_colors[m], linewidth=lw) labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept( 
add='_mean_original') color_first = 'red' # color01 color_second = 'purple' # color02 color_third = 'darkblue' rainbow_title(plt.gcf(), ax_u1_upper, [' $\Delta f_{s}=%s$' % freq2_here + ' Hz', ' $\Delta f_{p}=%s$' % freq1_here + ' Hz', '$c_{2}=10\%$'], [[color_second, color_first, 'black']], start_xpos=0, ha='left', y_pos=1.02) if f_counter != 0: ax_u1_upper.set_ylabel('') remove_yticks(ax_u1_upper) if f_counter == 0: ax_u1_upper.legend(loc=(0, 1.25), ncol=2) frame_cell_chosen = frame_cell_orig[ (frame_cell_orig.df1 == freq1_here) & (frame_cell_orig.df2 == freq2_here)] print('Tuning curve needed for F1' + str(frame_cell_chosen.f1.unique()) + ' F2' + str( frame_cell_chosen.f2.unique()) + ' for cell ' + str(cell_here)) # hier kommt das mit der tuning kurve freq2_here_abs = str(int(frame_cell_chosen.f2.unique())) length = '2' nfft = '4096' full_names_tunings = [ 'calc_model_amp_freqs-F1_500-1495-5_F2_' + freq2_here_abs + '_C2_0.1_C1_0.1_StimLen_' + length + '_nfft_' + nfft + '_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal', 'calc_model_amp_freqs-F1_500-1495-5_F2_' + freq2_here_abs + '_C2_0.1_C1_0.25_StimLen_' + length + '_nfft_' + nfft + '_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal', 'calc_model_amp_freqs-F1_500-1495-5_F2_' + freq2_here_abs + '_C2_0.1_C1_0.5_StimLen_' + length + '_nfft_' + nfft + '_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal', ] # 'calc_model_amp_freqs-F1_500-1495-5_F2_725_C2_0.1_C1_0.5_StimLen_2_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal' #_burst_added1_ ax_us = [] for ft_nr, full_names_tuning in enumerate(full_names_tunings): if os.path.exists(load_folder_name('calc_cocktailparty') + '/' + full_names_tuning + '.csv'): frame_tuning = pd.read_csv( load_folder_name('calc_cocktailparty') + '/' + full_names_tuning + '.csv') print(full_names_tuning) frame_cell_orig_tuning = frame_tuning[(frame_tuning.cell == cell_here)] try: frame_cell_orig_tuning = 
single_frame_processing(c_grouped, frame_cell_orig_tuning) except: print('something') embed() f_fixes = ['df2'] # 'df1', f_variables = ['df1'] # 'df2', freqs_fixed = [freq2_here] # freq1_here, for f_nr, f_fixed in enumerate(freqs_fixed): ax_u1 = plt.subplot(grid_r[ft_nr]) ax_u1.set_title('') ax_u1.text(1, 1.15, '$c_{1}=%s$' % (str(int(c_heres[ft_nr] * 100))) + '$\%$', color=color01, ha='right', va='top', transform=ax_u1.transAxes) if ft_nr == 0: ax_u1.text(0, 1.15, ' $\Delta ' + f_stable_name() + '=%s$' % ( freq2_here) + '\,Hz ' + '$ ' + c_stable_name() + '=10 \%$', color=color_third, ha='left', va='top', transform=ax_u1.transAxes) # c_colors[ft_nr]% ax_u1.set_xlabel(f_variables[f_nr]) ax_u1.set_xlim(-300, 300) ax_u1.set_ylim(0, 420) frame_f = frame_f_reference(c_heres[ft_nr], cell_here, f_fixes[f_nr], frame_cell_orig_tuning, f_fixed) s_big = 25 s_small = 20 ax_u1.scatter(freq1_here, frame_f[(frame_f['df2'] == freq2_here) & (frame_f['df1'] == freq1_here)][ scores[2]], edgecolor=color_second, facecolor='white', s=s_big, alpha=0.5, marker='o', clip_on=False, zorder=100) ax_u1.scatter(freq1_here, frame_f[frame_f['df1'] == freq1_here][scores[0]], edgecolor=color_first, marker='o', zorder=120, facecolor='white', s=s_small, alpha=0.5, clip_on=False) if ft_nr == 0: ax_u1.scatter(freq2_here, frame_f[ (frame_f['df2'] == freq2_here) & (frame_f['df1'] == freq1_here)][scores[2]], edgecolor=color_third, facecolor='white', alpha=0.5, marker='o', clip_on=False, zorder=120, s=s_small) #################### # scatter to the upper one frame_f = frame_f_reference(c_heres[ft_nr], cell_here, f_fixes[f_nr], frame_cell_orig_tuning, f_fixed) ax_u1_upper.scatter(c_heres[ft_nr] * 100, frame_f[(frame_f['df2'] == freq2_here) & ( frame_f['df1'] == freq1_here)][ scores[2]], edgecolor=color_second, facecolor='white', s=s_big, alpha=0.5, marker='o', clip_on=False, zorder=100) ax_u1_upper.scatter(c_heres[ft_nr] * 100, frame_f[frame_f['df1'] == freq1_here][scores[0]], edgecolor=color_first, 
marker='o', zorder=120, facecolor='white', s=s_small, alpha=0.5, clip_on=False) ############################# ax_u1.axvline(freq1_here, color=c_colors[ft_nr], linewidth=lw) ax_u1.scatter(freq1_here, [0], color=c_colors[ft_nr], marker='^', clip_on=False, s=5) if f_counter == 0: ax_u1.set_ylabel(representation_ylabel()) else: ax_u1.set_ylabel('') remove_yticks(ax_u1) if ft_nr in [2]: ax_u1.set_xlabel(xlabel_vary()) # ax_upper.set_xlabel(xlabel_vary()) else: ax_u1.set_xlabel('') remove_xticks(ax_u1) ax_us.append(ax_u1) f_counter += 1 if len(ax_us) > 0: join_x(ax_us) join_y(ax_us) ax_uss.append(ax_us) ######################################################### if squares: set_clim_same(ims, clims='all', same='same') join_y(ax_upper) join_x(ax_upper) join_y(ax_upper) yoffs = np.array([2, 2.5, 2.5, 2.5]) x = -3 tag2(plt.gcf(), [[ax_upper[0], ax_uss[0][0], ax_uss[0][1], ax_uss[0][2]]], xoffs=np.array([x, x, x, x]), yoffs=yoffs) tag2(plt.gcf(), [[ax_upper[4], ax_uss[1][0], ax_uss[1][1], ax_uss[1][2]]], xoffs=np.array([x, x, x, x]), yoffs=yoffs) save_visualization(cell_here, show) def dist_recalc_phaselockingchapter(): c_dist_recalc = False return c_dist_recalc def vary_contrasts_big_with_tuning3(freqs=[(39.5, -135.5)], cells_here='2011-10-25-ad-invivo-1'): default_settings() # ts=13, ls=13, fs=13, lw = 0.7 plot_style() model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv") if len(cells_here) < 1: cells_here = np.array(model_cells.cell) trials_nrs = [1] plot_style() default_settings(column=2, length=6.5) for _ in trials_nrs: # +[trials_nrs[-1]] # sachen die ich variieren will ########################################### for cell_here in cells_here: full_names = [ 'calc_model_amp_freqs-F1_750-975-75_F2_500-725-75_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_1.0_mult__start_0.0001_end_1_StimLen_25_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal'] # full_names = 
['calc_model_amp_freqs-F1_0.5-1.45-0.05_F2_0.5-1.45-0.05_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_1.0_mult__start_0.0001_end_1_StimLen_25_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal'] # full_names = ['calc_model_amp_freqs-F1_250-1325-25_F2_720_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_2.0_mult__start_0.0001_end_2_StimLen_5_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal'] # full_names = ['calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_25_FirstC1_0.0001_LastC1_1.0_StimLen_25_nfft_32768_trialsnr_1_absolut_power_1_minamps__dev_originalAUCItemporal'] # 'calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2Len_25_FirstC2_0.0001_LastC2_1.0_C1_0.1_StimLen_2_nfft_32768_trialsnr_1_absolut_power_1_minamps__dev_05temporal'] c_grouped = ['c1'] # , 'c2'] frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv') frame_cell_orig = frame[(frame.cell == cell_here)] if len(frame_cell_orig) > 0: try: pass except: print('min thing') embed() new_f2_tuple = frame_cell_orig[['df2', 'f2']].apply(tuple, 1).unique() dfs = [tup[0] for tup in new_f2_tuple] sorted = np.argsort(np.abs(dfs)) grid0 = gridspec.GridSpec(1, 1, bottom=0.1, top=0.87, left=0.09, right=0.95, wspace=0.04) # grid00 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.2, hspace=0.6, subplot_spec=grid0[0], height_ratios=[1, 2]) # grid_u = gridspec.GridSpecFromSubplotSpec(1, len(freqs), hspace=0.7, wspace=0.25, subplot_spec=grid00[0]) # hspace=0.4,wspace=0.2,len(chirps) grid_r = gridspec.GridSpecFromSubplotSpec(2, 2, hspace=0.15, wspace=0.25, subplot_spec=grid00[1]) grid_s = gridspec.GridSpecFromSubplotSpec(1, 3, hspace=0.7, wspace=0.45, subplot_spec=grid00[-1]) ################################################### # squares squares = False if squares: full_names_square = [ 
'calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_25_FirstC1_0.0001_LastC1_1.0_StimLen_2_nfft_32768_trialsnr_1_absolut_power_1_minamps__dev_05temporal', ] frame_square = pd.read_csv( load_folder_name('calc_cocktailparty') + '/' + full_names_square[0] + '.csv') frame_cell_square = frame_square[(frame_square.cell == cell_here)] axes = [] axes.append(plt.subplot(grid_s[0])) axes.append(plt.subplot(grid_s[1])) axes.append(plt.subplot(grid_s[2])) frame_cell_square = single_frame_processing(c_grouped, frame_cell_square) lim, matrix, ss, ims = plt_matrix_saturation_loss(axes, frame_cell_square, add='_05') plt_cross(matrix, axes[-1]) ################################################################# show = True # da implementiere ich das jetzt für eine Zelle # wo wir den einezlnen Punkt und Kontraste variieren ax_upper = [] frame_cell_orig, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig) f = -1 #################################################### # hier kommt die amplituden tuning curve ax_upper, nfft, ax_us = amplitude_tuning_curve(ax_upper, c_grouped, cell_here, f, frame_cell_orig, freqs, grid_r, grid_u) ax_upper[0].legend(loc=(0, 1.4), ncol=6) # , f_fixed, if squares: set_clim_same(ims, clims='all', same='same') join_y(ax_upper) join_x(ax_upper) join_y(ax_upper) save_visualization(cell_here, show) def amplitude_tuning_curve(ax_upper, c_grouped, cell_here, f, frame_cell_orig, freqs, grid_r, grid_u): ################################################## # frame_cell = single_frame_processing(c_grouped, frame_cell_orig) c_heres = [0.03, 0.1, 0.25, 0.5] c_colors = ['black', 'darkgrey', 'silver', 'lightgrey'] freq1s = np.unique(frame_cell_orig.df1) freq2s = np.unique(frame_cell_orig.df2) for freq1, freq2 in freqs: freq1_here = freq1s[np.argmin(np.abs(freq1s - freq1))] freq2_here = freq2s[np.argmin(np.abs(freq2s - freq2))] f += 1 print(cell_here + ' F1' + str(freq1_here) + ' F2 ' + str(freq2_here)) ax_u1 = plt.subplot(grid_u[0, f]) ax_upper = 
plt_single_trace(ax_upper, ax_u1, frame_cell_orig, freq1_here, freq2_here, sum=False, linestyles=['-', '--', '-', '--', '-']) c_nrs_here_cm = c_dist_recalc_func(frame_cell, c_nrs=c_heres, cell=cell_here) ax_u1.scatter(c_nrs_here_cm, np.zeros(len(c_nrs_here_cm)), color=c_colors, marker='^', clip_on=False) plt.suptitle(cell_here) ax_u1.set_title(' $\Delta f_{1}=%s$' % freq1_here + ' Hz $\Delta f_{2}=%s$' % freq2_here + ' Hz') ax_upper[-1].legend(loc=(0, 0.9), ncol=4) frame_cell_chosen = frame_cell_orig[(frame_cell_orig.df1 == freq1_here) & (frame_cell_orig.df2 == freq2_here)] print('Tuning curve needed for F1' + str(frame_cell_chosen.f1.unique()) + ' F2' + str( frame_cell_chosen.f2.unique()) + ' for cell ' + str(cell_here)) # hier kommt das mit der tuning kurve freq2_here_abs = str(int(frame_cell_chosen.f2.unique())) length = '2' nfft = '4096' full_names_tunings = [ 'calc_model_amp_freqs-F1_500-1495-5_F2_' + freq2_here_abs + '_C2_0.1_C1_0.03_StimLen_' + length + '_nfft_' + nfft + '_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal', 'calc_model_amp_freqs-F1_500-1495-5_F2_' + freq2_here_abs + '_C2_0.1_C1_0.1_StimLen_' + length + '_nfft_' + nfft + '_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal', 'calc_model_amp_freqs-F1_500-1495-5_F2_' + freq2_here_abs + '_C2_0.1_C1_0.25_StimLen_' + length + '_nfft_' + nfft + '_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal', 'calc_model_amp_freqs-F1_500-1495-5_F2_' + freq2_here_abs + '_C2_0.1_C1_0.5_StimLen_' + length + '_nfft_' + nfft + '_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal', ] ax_us = [] for ft_nr, full_names_tuning in enumerate(full_names_tunings): if os.path.exists(load_folder_name('calc_cocktailparty') + '/' + full_names_tuning + '.csv'): frame_tuning = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names_tuning + '.csv') print(full_names_tuning) frame_cell_orig_tuning = frame_tuning[(frame_tuning.cell == cell_here)] try: pass 
except: print('something') embed() f_variables = ['df1'] # 'df2', freqs_fixed = [freq2_here] # freq1_here, for f_nr, f_fixed in enumerate(freqs_fixed): ax_u1 = plt.subplot(grid_r[ft_nr]) plt_tuning_curve(ax_u1, f_fixed) ax_u1.set_title('') ax_u1.text(1, 1, '$c=%s$' % (c_heres[ft_nr]), color=c_colors[ft_nr], ha='right', va='top', transform=ax_u1.transAxes) ax_u1.set_xlabel(f_variables[f_nr]) ax_u1.set_xlim(-300, 300) ax_u1.scatter(freq1_here, 1, color='green', marker='^') ax_u1.scatter(freq1_here, 1, color='red', marker='^') if ft_nr in [0, 2]: ax_u1.set_ylabel('Peak Amp. [Hz]') else: ax_u1.set_ylabel('') remove_yticks(ax_u1) if ft_nr in [2, 3]: ax_u1.set_xlabel('$\Delta f_{1}$ [Hz]') else: ax_u1.set_xlabel('') remove_xticks(ax_u1) ax_us.append(ax_u1) if len(ax_us) > 0: join_x(ax_us) join_y(ax_us) return ax_upper, nfft, ax_us def single_frame_processing(c_grouped, frame_cell): frame_cell = area_vs_single_peaks_frame(frame_cell) frame_cell, df1s, df2s, f1s, f2s = find_dfs(frame_cell) diffs = find_deltas(frame_cell, c_grouped[0]) frame_cell = find_diffs(c_grouped[0], frame_cell, diffs, add='_original') #new_frame = frame_cell.groupby(['df1', 'df2'], as_index=False).sum() # ['score'] #matrix = new_frame.pivot(index='df2', columns='df1', values='diff') return frame_cell def plt_tuning_curve(c_here, ax, frame_cell, cell, freq2, dfs, f_fixed='f2', f_variable='f1', index=[0, 1, 2, 3]): labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept( add='_mean_original', nr=1) frame_f = frame_f_reference(c_here, cell, f_fixed, frame_cell, freq2) # try: # df2s ax.set_title(' DF2=' + str(int(dfs)) + ' Hz', fontsize=10) # , fontsize=7 except: print('f1 f2 problem') for sss in index: test = False # False if test: print('example') ax.scatter(np.array(frame_f[f_variable]), frame_f[score[sss]], zorder=100, linestyle=np.array(linestyles)[sss], color=np.array(colors)[sss], label=np.array(labels)[sss], 
alpha=np.array(alpha)[sss], s=3, linewidths=np.array(linewidths)[sss]) # , color = colors[sss],not found try: ax.plot(np.array(frame_f[f_variable]), frame_f[scores[sss]], zorder=100, linestyle=np.array(linestyles)[sss], color=np.array(colors)[sss], linewidth=np.array(linewidths)[sss], label=np.array(labels)[sss], alpha=np.array(alpha)[sss]) # , color = colors[sss], except: # - np.array(frame_f.f0) print('f1 thing') embed() return frame_f def frame_f_reference(c_here, cell, f_fixed, frame_cell, freq2): frame_cell = frame_cell[frame_cell['c1'] == c_here] frame_f = frame_cell[(frame_cell.cell == cell) & (frame_cell[f_fixed] == freq2)] frame_f = frame_f[frame_f.f1 != frame_f.f2] frame_f = frame_f[np.abs(frame_f.f1) != np.abs(frame_f.f2)] frame_f = frame_f[np.abs(frame_f.df1) != np.abs(frame_f.df2)] df_extra = True if df_extra: # das machen wir weil sonst kriegen wir da resonanz und die peaks sind sehr stark confidence = 10 frame_f = frame_f[np.abs(np.abs(frame_f.df1) - np.abs(frame_f.df2)) > confidence] return frame_f def plt_show_nonlin_effect_didactic_final2_only(min=0.2, cells=[], single_waves=['_SingleWave_', '_SeveralWave_', ], cell_start=13, a_f1s=[0, 0.005, 0.01, 0.05, 0.1, 0.2, ], a_frs=[1], add_half=0, show=False, nfft=int(2 ** 15), gain=1, us_name=''): model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv") if len(cells) < 1: cells = model_cells.cell.loc[range(cell_start, len(model_cells))] plot_style() for cell in cells: # sachen die ich variieren will ########################################### ####### VARY HERE for single_wave in single_waves: if single_wave == '_SingleWave_': a_f2s = [0] # , 0,0.2 else: a_f2s = [0.1] for a_f2 in a_f2s: trials_nr = 15 # 150 titles_amp = ['base eodf', 'baseline to Zero', ] for a, a_fr in enumerate(a_frs): default_figsize(column=2, length=3.5) grid = gridspec.GridSpec(3, 1, wspace=0.35, left=0.095, hspace=0.3, top=0.95, bottom=0.15, right=0.98) ax = {} vmem = False for aa, a_f1 in 
enumerate(a_f1s): SAM, cell, damping, damping_type, deltat, eod_fish_r, eod_fr, f1, f2, freqs1, freqs2, model_params, offset, phase_right, phaseshift_fr, rate_adapted, rate_baseline_after, rate_baseline_before, sampling, spike_adapted, spikes, stimuli, stimulus_altered, stimulus_length, time_array, v_dent_output, v_mem_output = outputmodel( a_fr, add_half, cell, model_cells, single_wave, trials_nr) ax[1] = plt.subplot(grid[0]) ax[1].show_spines('l') ax[1].set_ylabel('$s(t)$') ax[2] = plt.subplot(grid[1]) ax[2].show_spines('l') ax[2].set_ylabel('Repeat Nr.') ax[3] = plt.subplot(grid[2]) ax[3].show_spines('lb') power_extra = False if power_extra: ax[4] = plt.subplot(grid[:, 1]) ax[4].show_spines('lb') ax[1].set_xlim(0, xlim_here()) ax[2].set_xlim(0, xlim_here()) ax[3].set_xlim(0, xlim_here()) _, _ = find_base_fr(spike_adapted, deltat, stimulus_length, time_array) _, _ = ISI_frequency(time_array, spike_adapted[0], fill=0.0) isi = np.diff(spike_adapted[0]) cv0 = np.std(isi) / np.mean(isi) for ff, freq1 in enumerate(freqs1): print('freq1' + str(freq1 - eod_fr)) print('freq2' + str(freqs2[ff] - eod_fr)) print('a_f1' + str(a_f1)) print('a_f2' + str(freqs2[ff])) freq1 = [freq1] freq2 = [freqs2[ff]] beat1 = freq1 - eod_fr titles = False if titles: plt.suptitle('diverging from half fr by ' + str(add_half) + ' f1:' + str( np.round(freq1)[0]) + ' f2:' + str(np.round(freq2)[0]) + ' Hz \n' + str( beat1) + ' Hz Beat\n' + titles[ff] + titles_amp[a] + ' ' + cell + ' cv ' + str( np.round(cv0, 3)) + '_a_f0_' + str(a_fr) + '_a_f1_' + str(a_f1) + '_a_f2_' + str( a_f2) + ' tr_nr ' + str(trials_nr)) _, _ = get_phaseshifts(a_f1, a_f2, phase_right, phaseshift_fr) eod_fish1, time_fish_e = eod_fish_e_generation(time_array, a_f1, freq1, f1) eod_fish2, time_fish_j = eod_fish_e_generation(time_array, a_f2, freq2, f2) eod_stimulus = eod_fish1 + eod_fish2 for t in range(trials_nr): stimulus, eod_fish_sam = create_stimulus_SAM(SAM, eod_stimulus, eod_fish_r, freq1, f1, eod_fr, time_array, a_f1) # 
damping variants std_dump, max_dump, range_dump, stimulus, damping_output = all_damping_variants( stimulus, time_array, damping_type, eod_fr, gain, damping, us_name, plot=False, std_dump=0, max_dump=0, range_dump=0) stimuli.append(stimulus) cvs, adapt_output, baseline_after, _, rate_adapted[t], rate_baseline_before[t], \ rate_baseline_after[t], spikes[t], \ stimulus_altered[t], \ v_dent_output[t], offset_new, v_mem_output[t], noise_final = simulate(cell, offset, stimulus, adaptation_yes_e=f1, **model_params) spikes_mat = [[]] * len(spikes) pps = [[]] * len(spikes) for s in range(len(spikes)): spikes_mat[s] = cr_spikes_mat(spikes[s], 1 / deltat, int(stimulus_length * 1 / deltat)) pps[s], f = ml.psd(spikes_mat[s] - np.mean(spikes_mat[s]), Fs=1 / deltat, NFFT=nfft, noverlap=nfft // 2) pp_mean = np.mean(pps, axis=0) sampling_rate = 1 / deltat smoothed05 = gaussian_filter(spikes_mat, sigma=gaussian_intro() * sampling_rate) mat05 = np.mean(smoothed05, axis=0) beat1 = (freq1 - eod_fr)[0] beat2 = (freq2 - eod_fr)[0] if 'Several' in single_wave: freqs_beat = [np.abs(beat1), np.abs(beat2), np.abs(beat2 + beat1), ] # np.abs(beat2 - beat1) colors_w, colors_wo, color_base, color_01, color_02, color_012 = colors_cocktailparty_all() colors = [color_01, color_02, color_012] # 'blue' labels = ['intruder', 'female', 'intruder+female'] # , '|B1-B2|' else: freqs_beat = [np.abs(beat1), np.abs(beat1) * 2, np.abs(beat1 * 3), np.abs(beat1 * 4)] # np.abs(beat1) / 2, colors = colors_didactic() labels = labels_didactic() # colors_didactic, labels_didactic if 'Several' in single_wave: color_beat = 'black' else: color_beat = 'black' if (np.mean(stimulus) != 0) & (np.mean(stimulus) != 1): stim_redo = True if stim_redo: eod_interp = np.cos(time_array * beat1 * 2 * np.pi) + 1 else: eod_interp, eod_norm = extract_am(stimulus, time_array, sampling=sampling_rate, eodf=eod_fr, emb=False, extract='', norm=False) if (titles_amp[a] != 'baseline to Zero') and not ( (a_f2 == 0) & (a_fr == 1) & (a_f1 == 
0)): ax[1].plot((time_array - min) * 1000, eod_interp - 1, color=color_beat, clip_on=True) ax[1].set_ylim(np.min(eod_interp - 1) * 1.05, np.max(eod_interp - 1) * 1.05) for l in range(len(spikes)): spikes[l] = (spikes[l] - min) * 1000 if vmem: ax[0].plot((time_array - min) * 1000, v_mem_output[0], color='black') ax[0].eventplot(np.array(spikes[0]), lineoffsets=np.max(v_mem_output[0]), color='black') ax[0].set_xlim([0, 350]) ax[2].eventplot(np.array(spikes), color='black') ax[3].plot((time_array - min) * 1000, mat05, color='black') power_extra = False if power_extra: pp, f = ml.psd(mat05 - np.mean(mat05), Fs=1 / deltat, NFFT=nfft, noverlap=nfft // 2) log = 'log' if log: pp_mean = calc_log(pp_mean) plt_peaks_several(freqs_beat, pp_mean, ax[4], pp_mean, f, labels, 0, colors, add_log=2.5, exact=False, text_extra=True, perc_peaksize=0.2, rel='rel', ms=14, clip_on=True, log=log) # True ax[4].plot(f, pp_mean, color='black', zorder=0) ax[4].set_xlim([0, 350]) test = False if test: from utils_test import test_spikes_clusters test_spikes_clusters(eod_fish_r, spikes, mat05, sampling, s_name='ms', resamp_fact=1000) ax[1].set_xticks([]) ax[2].set_xticks([]) ax[3].set_ylabel('Firing Rate [Hz]') ax[3].set_xlabel('Time [ms]') ax[1].set_xticks([]) ax[2].set_xticks([]) fig = plt.gcf() fig.tag(fig.axes, xoffs=-6, yoffs=1.3) plt.subplots_adjust(top=0.7, left=0.15, right=0.95, hspace=0.75, wspace=0.1) individual_tag = titles_amp[a] + ' ' + cell + ' cv ' + str(cv0) + single_wave + '_a_f0_' + str( a_fr) + '_a_f1_' + str(a_f1) + '_a_f2_' + str(a_f2) + '_diverge_from_base_half' + str(add_half) save_visualization(individual_tag, show, counter_contrast=0, savename='') def calc_log(pp_mean): pp_mean = 10 * np.log10(pp_mean / np.max(pp_mean)) return pp_mean def plt_show_nonlin_effect_didactic_final2(min=0.2, cells=[], single_waves=['_SingleWave_', '_SeveralWave_', ], cell_start=13, a_f1s=[0, 0.005, 0.01, 0.05, 0.1, 0.2, ], a_frs=[1], add_half=0, xlim=[0, 350], show=False, nfft=int(2 ** 15), 
gain=1, us_name=''): model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv") if len(cells) < 1: cells = model_cells.cell.loc[range(cell_start, len(model_cells))] plot_style() for cell in cells: ####### VARY HERE for single_wave in single_waves: if single_wave == '_SingleWave_': a_f2s = [0] # , 0,0.2 else: a_f2s = [0.1] for a_f2 in a_f2s: trials_nr = 15 # 150 titles_amp = ['base eodf', 'baseline to Zero', ] for a, a_fr in enumerate(a_frs): default_figsize(column=2, length=2.3) # 3 grid = gridspec.GridSpec(3, 2, wspace=0.35, left=0.095, hspace=0.2, top=0.94, bottom=0.25, right=0.95) ax = {} for aa, a_f1 in enumerate(a_f1s): SAM, cell, damping, damping_type, deltat, eod_fish_r, eod_fr, f1, f2, freqs1, freqs2, model_params, offset, phase_right, phaseshift_fr, rate_adapted, rate_baseline_after, rate_baseline_before, sampling, spike_adapted, spikes, stimuli, stimulus_altered, stimulus_length, time_array, v_dent_output, v_mem_output = outputmodel( a_fr, add_half, cell, model_cells, single_wave, trials_nr) ax[1] = plt.subplot(grid[0]) ax[1].show_spines('') ax[2] = plt.subplot(grid[2]) ax[2].show_spines('') ax[3] = plt.subplot(grid[4]) ax[3].show_spines('lb') ax[4] = plt.subplot(grid[:, 1]) ax[4].show_spines('lb') ax[1].set_xlim(0, xlim_here()) ax[2].set_xlim(0, xlim_here()) ax[3].set_xlim(0, xlim_here()) _, _ = find_base_fr(spike_adapted, deltat, stimulus_length, time_array) _, _ = ISI_frequency(time_array, spike_adapted[0], fill=0.0) isi = np.diff(spike_adapted[0]) cv0 = np.std(isi) / np.mean(isi) for ff, freq1 in enumerate(freqs1): freq1 = [freq1] freq2 = [freqs2[ff]] beat1 = freq1 - eod_fr titles = False if titles: plt.suptitle('diverging from half fr by ' + str(add_half) + ' f1:' + str( np.round(freq1)[0]) + ' f2:' + str(np.round(freq2)[0]) + ' Hz \n' + str( beat1) + ' Hz Beat\n' + titles[ff] + titles_amp[a] + ' ' + cell + ' cv ' + str( np.round(cv0, 3)) + '_a_f0_' + str(a_fr) + '_a_f1_' + str(a_f1) + '_a_f2_' + str( a_f2) + ' 
tr_nr ' + str(trials_nr)) _, _ = get_phaseshifts(a_f1, a_f2, phase_right, phaseshift_fr) eod_fish1, time_fish_e = eod_fish_e_generation(time_array, a_f1, freq1, f1) eod_fish2, time_fish_j = eod_fish_e_generation(time_array, a_f2, freq2, f2) eod_stimulus = eod_fish1 + eod_fish2 for t in range(trials_nr): stimulus, eod_fish_sam = create_stimulus_SAM(SAM, eod_stimulus, eod_fish_r, freq1, f1, eod_fr, time_array, a_f1) std_dump, max_dump, range_dump, stimulus, damping_output = all_damping_variants( stimulus, time_array, damping_type, eod_fr, gain, damping, us_name, plot=False, std_dump=0, max_dump=0, range_dump=0) stimuli.append(stimulus) cvs, adapt_output, baseline_after, _, rate_adapted[t], rate_baseline_before[t], \ rate_baseline_after[t], spikes[t], \ stimulus_altered[t], \ v_dent_output[t], offset_new, v_mem_output[t], noise_final = simulate(cell, offset, stimulus, adaptation_yes_e=f1, **model_params) spikes_mat = [[]] * len(spikes) pps = [[]] * len(spikes) for s in range(len(spikes)): spikes_mat[s] = cr_spikes_mat(spikes[s], 1 / deltat, int(stimulus_length * 1 / deltat)) pps[s], f = ml.psd(spikes_mat[s] - np.mean(spikes_mat[s]), Fs=1 / deltat, NFFT=nfft, noverlap=nfft // 2) pp_mean = np.mean(pps, axis=0) sampling_rate = 1 / deltat smoothed05 = gaussian_filter(spikes_mat, sigma=gaussian_intro() * sampling_rate) mat05 = np.mean(smoothed05, axis=0) beat1 = (freq1 - eod_fr)[0] beat2 = (freq2 - eod_fr)[0] nr = 2 if 'Several' in single_wave: freqs_beat = [np.abs(beat1), np.abs(beat2), np.abs(np.abs(beat2) + np.abs(beat1)) ] # np.abs(beat2 - beat1),np.abs(beat2 + beat1), colors_w, colors_wo, color_base, color_01, color_02, color_012 = colors_cocktailparty_all() colors = [color_01, color_02, color_012] # 'blue' labels = ['$f_{1}=%d$' % beat1 + '\,Hz', '$f_{2}=%d$' % beat2 + '\,Hz', '$f_{1} + f_{2}=f'+basename()+'=%d$' % ( beat1 + beat2 - 1) + '\,Hz'] # small , '|B1-B2|' add_texts = [nr, nr + 0.35, nr + 0.2] # [1.1,1.1,1.1] texts_left = [-7, -7, -7, -7] else: freqs_beat = 
[np.abs(beat1), np.abs(beat1) * 2, np.abs(beat1 * 3), np.abs(beat1 * 4)] # np.abs(beat1) / 2, colors = colors_didactic() add_texts = [nr + 0.1, nr + 0.1, nr + 0.1, nr + 0.1] # [1.1,1.1,1.1,1.1] texts_left = [3, 0, 0, 0] labels = labels_didactic2() # colors_didactic, labels_didactic if 'Several' in single_wave: color_beat = 'black' else: color_beat = colors[0] if (np.mean(stimulus) != 0) & (np.mean(stimulus) != 1): eod_interp, eod_norm = extract_am(stimulus, time_array, sampling=sampling_rate, eodf=eod_fr, emb=False, extract='', norm=False) if (titles_amp[a] != 'baseline to Zero') and not ( (a_f2 == 0) & (a_fr == 1) & (a_f1 == 0)): ax[1].plot((time_array - min) * 1000, eod_interp, color=color_beat, clip_on=True) ax[1].set_ylim(np.min(eod_interp) * 0.98, np.max(eod_interp) * 1.02) for l in range(len(spikes)): spikes[l] = (spikes[l] - min) * 1000 ax[2].eventplot(np.array(spikes), color='black') ax[3].plot((time_array - min) * 1000, mat05, color='black') pp, f = ml.psd(mat05 - np.mean(mat05), Fs=1 / deltat, NFFT=nfft, noverlap=nfft // 2) log = 'log' if log: pp_mean = 10 * np.log10(pp_mean / np.max(pp_mean)) print(freqs_beat) print(labels) plt_peaks_several(freqs_beat, pp_mean, ax[4], pp_mean, f, labels, 0, colors, ha='center', add_texts=add_texts, texts_left=texts_left, add_log=2.5, rots=[0, 0, 0, 0], exact=False, text_extra=True, perc_peaksize=5, rel='rel', ms=14, clip_on=True, several_peaks=True, log=log) # True ax[4].plot(f, pp_mean, color='black', zorder=0) # 0.45 ax[4].set_xlim(xlim) test = False if test: from utils_test import test_spikes_clusters test_spikes_clusters(eod_fish_r, spikes, mat05, sampling, s_name='ms', resamp_fact=1000) ax[1].set_xticks([]) ax[2].set_xticks([]) ax[1].set_ylabel('Beat') ax[2].set_ylabel('Spikes') ax[3].set_ylabel('Firing Rate [Hz]') if log == 'log': ax[4].set_ylabel('dB') else: ax[4].set_ylabel('Amplitude [Hz]') ax[4].set_xlabel('Frequency [Hz]') ax[3].set_xlabel('Time [ms]') ax[1].set_xticks([]) ax[2].set_xticks([]) fig = plt.gcf() 
tag2(fig=fig, xoffs=[-4.5, -4.5, -4.5, -5.5], yoffs=1.25) plt.subplots_adjust(top=0.6, left=0.15, right=0.95, hspace=0.5, wspace=0.1) individual_tag = titles_amp[a] + ' ' + cell + ' cv ' + str(cv0) + single_wave + '_a_f0_' + str( a_fr) + '_a_f1_' + str(a_f1) + '_a_f2_' + str(a_f2) + '_diverge_from_base_half' + str(add_half) save_visualization(individual_tag, show, counter_contrast=0, savename='') def outputmodel(a_fr, add_half, cell, model_cells, single_wave, trials_nr, freqs_mult1=None, freqs_mult2=None): try: model_params = model_cells[model_cells['cell'] == cell].iloc[0] except: print('model extract something') embed() eod_fr = model_params['EODf'] offset = model_params.pop('v_offset') cell = model_params.pop('cell') print(cell) f1 = 0 f2 = 0 sampling_factor = '' stimulus_length = 1 phaseshift_fr = 0 phase_right = '_phaseright_' adapt_offset = 'adaptoffsetallall2' SAM = '' # , damping = 0.45 # 0.65,0.2,0.5,0.2,0.6,0.45,0.6,0.35 damping_type = '' exponential = '' # in case you want a different sampling here we can adujust time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr, stimulus_length) # generate the eod_fish_r in the four mimick variants (copy, thunderfish, mimick, just sinus) eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr) sampling = 1 / deltat if exponential == '': pass # prepare for adapting offset due to baseline modification # now we are ready for the final modeling part in this function rate_adapted = [[]] * trials_nr rate_baseline_before = [[]] * trials_nr rate_baseline_after = [[]] * trials_nr spikes = [[]] * trials_nr v_dent_output = [[]] * trials_nr stimulus_altered = [[]] * trials_nr v_mem_output = [[]] * trials_nr spike_adapted = [[]] * trials_nr stimuli = [] offset, spike_adapted = calc_the_model_spikes(a_fr, adapt_offset, cell, deltat, eod_fish_r, f1, f2, model_params, offset, spike_adapted, trials_nr) base_cut, mat_base = find_base_fr(spike_adapted, deltat, 
stimulus_length, time_array) fr = np.mean(base_cut) if freqs_mult1: freqs1 = [eod_fr + fr * freqs_mult1] freqs2 = [eod_fr + fr * freqs_mult2] else: if 'Several' in single_wave: if 'Sum' in single_wave: freqs1 = [eod_fr + fr * 0.3] freqs2 = [eod_fr + fr * 0.7] else: freqs1 = [eod_fr - fr / 2 + add_half] freqs2 = [0] * len(freqs1) return SAM, cell, damping, damping_type, deltat, eod_fish_r, eod_fr, f1, f2, freqs1, freqs2, model_params, offset, phase_right, phaseshift_fr, rate_adapted, rate_baseline_after, rate_baseline_before, sampling, spike_adapted, spikes, stimuli, stimulus_altered, stimulus_length, time_array, v_dent_output, v_mem_output def xlim_here(): # 075 return 0.1 * 1000 def calc_the_model_spikes(a_fr, adapt_offset, cell, deltat, eod_fish_r, f1, f2, model_params, offset, spike_adapted, trials_nr, add=0, dent_tau_change=1, constant_reduction=1, n=1, exp_tau=1, exponential='', lower_tol=0.995, plus=1, sig_val=1, slope=1, v_exp=1, zeros='zeros', upper_tol=1.005): for t in range(trials_nr): # get the baseline properties here # baseline_after,spike_adapted,rate_adapted, rate_baseline_before, rate_baseline_after, np.array(spike_times), stimulus_power, v_dent_output[int(0.05 / deltat):-1], offset, v_mem_output if a_fr == 0: power_here = 'sinz' + '_' + zeros else: power_here = 'sinz' cvs, adapt_output, baseline_after_b, _, rate_adapted_b, rate_baseline_before_b, rate_baseline_after_b, \ spike_adapted[t], _, _, offset_new, _, noise_final = simulate(cell, offset, eod_fish_r, deltat=deltat, adaptation_variant=adapt_offset, adaptation_yes_j=f2, adaptation_yes_e=f1, adaptation_yes_t=t, adaptation_upper_tol=upper_tol, adaptation_lower_tol=lower_tol, power_variant=power_here, power_alpha=alpha, power_nr=n, tau_change_choice=constant_reduction, tau_change_val=dent_tau_change, sigmoidal_mult=1, sigmoidal_plus=plus, sigmoidal_slope=slope, sigmoidal_add=add, sigmoidal_sigmoidal_val=sig_val, LIF_exponential=exponential, LIF_exponential_tau=exp_tau, LIF_expontential__v=v_exp, 
**model_params) if t == 0: # here we record the changes in the offset due to the adaptation # and we subsequently reset the offset to be the new adapted for all subsequent trials offset = offset_new * 1 return offset, spike_adapted def plt_show_nonlin_effect_didactic(min=0.2, text='text', cells=[], add_pp=50, single_waves=['_SingleWave_', '_SeveralWave_', ], cell_start=13, zeros='zeros', a_f1s=[0, 0.005, 0.01, 0.05, 0.1, 0.2, ] , a_frs=[1], add_half=0, show=False, nfft=int(2 ** 15)): model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv") if len(cells) < 1: cells = model_cells.cell.loc[range(cell_start, len(model_cells))] for cell in cells: ########################################### ####### VARY HERE for single_wave in single_waves: if single_wave == '_SingleWave_': a_f2s = [0] # , 0,0.2 else: a_f2s = [0.1] for a_f2 in a_f2s: trials_nr = 150 titles_amp = ['base eodf', 'baseline to Zero', ] for a, a_fr in enumerate(a_frs): grid = gridspec.GridSpec(4, 2, wspace=0.2, left=0.05, top=0.8, bottom=0.15, right=0.98) ax = {} for aa, a_f1 in enumerate(a_f1s): try: model_params = model_cells[model_cells['cell'] == cell].iloc[0] except: print('model extract something') embed() eod_fr = model_params['EODf'] offset = model_params.pop('v_offset') cell = model_params.pop('cell') print(cell) f1 = 0 f2 = 0 sampling_factor = '' stimulus_length = 1 phaseshift_fr = 0 phase_right = '_phaseright_' adapt_offset = 'adaptoffsetallall2' n = 1 lower_tol = 0.995 upper_tol = 1.005 SAM = '' # , exponential = '' # in case you want a different sampling here we can adujust time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr, stimulus_length) # generate the eod_fish_r in the four mimick variants (copy, thunderfish, mimick, just sinus) eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr, stimulus_length) sampling = 1 / deltat variant = 'sinz' if exponential == '': pass # prepare for adapting 
offset due to baseline modification _, _ = prepare_baseline_array(time_array, eod_fr) # now we are ready for the final modeling part in this function rate_adapted = [[]] * trials_nr rate_baseline_before = [[]] * trials_nr rate_baseline_after = [[]] * trials_nr spikes = [[]] * trials_nr v_dent_output = [[]] * trials_nr stimulus_altered = [[]] * trials_nr v_mem_output = [[]] * trials_nr spike_adapted = [[]] * trials_nr stimuli = [] for t in range(trials_nr): # get the baseline properties here # baseline_after,spike_adapted,rate_adapted, rate_baseline_before, rate_baseline_after, np.array(spike_times), stimulus_power, v_dent_output[int(0.05 / deltat):-1], offset, v_mem_output if a_fr == 0: power_here = 'sinz' + '_' + zeros else: power_here = 'sinz' cvs, adapt_output, baseline_after_b, _, rate_adapted_b, rate_baseline_before_b, rate_baseline_after_b, \ spike_adapted[t], _, _, offset_new, _, noise_final = simulate(cell, offset, eod_fish_r, deltat=deltat, adaptation_variant=adapt_offset, adaptation_yes_j=f2, adaptation_yes_e=f1, adaptation_yes_t=t, adaptation_upper_tol=upper_tol, adaptation_lower_tol=lower_tol, power_variant=power_here, power_alpha=alpha, power_nr=n, **model_params) if t == 0: # here we record the changes in the offset due to the adaptation # and we subsequently reset the offset to be the new adapted for all subsequent trials offset = offset_new * 1 base_cut, mat_base = find_base_fr(spike_adapted, deltat, stimulus_length, time_array) fr = np.mean(base_cut) titles = [''] if 'Several' in single_wave: if 'Sum' in single_wave: freqs1 = [eod_fr + fr * 0.3] freqs2 = [eod_fr + fr * 0.7] else: freqs1 = [eod_fr - fr / 2 + add_half] freqs2 = [0] * len(freqs1) ax[0] = plt.subplot(grid[0]) ax[1] = plt.subplot(grid[2]) ax[2] = plt.subplot(grid[4]) ax[3] = plt.subplot(grid[6]) ax[4] = plt.subplot(grid[:, 1]) ax[0].set_xlim(0, 0.125 * 1000) # 0.1 * 1000 ax[1].set_xlim(0, 0.125 * 1000) ax[2].set_xlim(0, 0.125 * 1000) ax[3].set_xlim(0, 0.125 * 1000) _, _ = 
find_base_fr(spike_adapted, deltat, stimulus_length, time_array) _, _ = ISI_frequency(time_array, spike_adapted[0], fill=0.0) isi = np.diff(spike_adapted[0]) cv0 = np.std(isi) / np.mean(isi) fs = 11 for ff, freq1 in enumerate(freqs1): freq1 = [freq1] freq2 = [freqs2[ff]] beat1 = freq1 - eod_fr plt.suptitle('diverging from half fr by ' + str(add_half) + ' f1:' + str( np.round(freq1)[0]) + ' f2:' + str(np.round(freq2)[0]) + ' Hz \n' + str( beat1) + ' Hz Beat\n' + titles[ff] + titles_amp[a] + ' ' + cell + ' cv ' + str( np.round(cv0, 3)) + '_a_f0_' + str(a_fr) + '_a_f1_' + str(a_f1) + '_a_f2_' + str( a_f2) + ' tr_nr ' + str(trials_nr)) _, _ = get_phaseshifts(a_f1, a_f2, phase_right, phaseshift_fr) eod_fish1, time_fish_e = eod_fish_e_generation(time_array, a_f1, freq1, f1) eod_fish2, time_fish_j = eod_fish_e_generation(time_array, a_f2, freq2, f2) eod_stimulus = eod_fish1 + eod_fish2 for t in range(trials_nr): stimulus, eod_fish_sam = create_stimulus_SAM(SAM, eod_stimulus, eod_fish_r, freq1, f1, eod_fr, time_array, a_f1) std_dump, max_dump, range_dump, stimulus, damping_output = all_damping_variants( stimulus, time_array) stimuli.append(stimulus) cvs, adapt_output, baseline_after, _, rate_adapted[t], rate_baseline_before[t], \ rate_baseline_after[t], spikes[t], \ stimulus_altered[t], \ v_dent_output[t], offset_new, v_mem_output[t], noise_final = simulate(cell, offset, stimulus, deltat=deltat, adaptation_variant=adapt_offset, adaptation_yes_j=f2, adaptation_yes_e=f1, adaptation_yes_t=t, adaptation_upper_tol=upper_tol, adaptation_lower_tol=lower_tol, power_variant=variant, power_alpha=alpha, power_nr=n, **model_params) spikes_mat = [[]] * len(spikes) pps = [[]] * len(spikes) for s in range(len(spikes)): spikes_mat[s] = cr_spikes_mat(spikes[s], 1 / deltat, int(stimulus_length * 1 / deltat)) pps[s], f = ml.psd(spikes_mat[s] - np.mean(spikes_mat[s]), Fs=1 / deltat, NFFT=nfft, noverlap=nfft // 2) pp_mean = np.mean(pps, axis=0) sampling_rate = 1 / deltat smoothed05 = 
gaussian_filter(spikes_mat, sigma=0.0005 * sampling_rate) mat05 = np.mean(smoothed05, axis=0) ax[0].set_title('a_f1 ' + str(a_f1), fontsize=fs) ax[0].plot((time_array - min) * 1000, stimulus, color='grey', linewidth=0.5) if (np.mean(stimulus) != 0) & (np.mean(stimulus) != 1): eod_interp, eod_norm = extract_am(stimulus, time_array, sampling=sampling_rate, eodf=eod_fr, emb=False, extract='', norm=False) if (titles_amp[a] != 'baseline to Zero') and not ( (a_f2 == 0) & (a_fr == 1) & (a_f1 == 0)): ax[1].plot((time_array - min) * 1000, eod_interp, color='red', clip_on=True) ax[0].plot((time_array - min) * 1000, eod_interp, color='red', clip_on=True) for l in range(len(spikes)): spikes[l] = spikes[l] * 1000 ax[2].eventplot(spikes, color='black') ax[3].plot((time_array - min) * 1000, mat05, color='black') pp, f = ml.psd(mat05 - np.mean(mat05), Fs=1 / deltat, NFFT=nfft, noverlap=nfft // 2) beat1 = (freq1 - eod_fr)[0] beat2 = (freq2 - eod_fr)[0] if 'Several' in single_wave: freqs_beat = [np.abs(beat1), np.abs(beat2), np.abs(beat2 + beat1), np.abs(beat2 - beat1)] colors = ['red', 'green', 'orange', 'blue'] labels = ['B1', 'B2', 'B1+B2', '|B1-B2|'] else: freqs_beat = [np.abs(beat1) / 2, np.abs(beat1), np.abs(beat1) * 2, np.abs(beat1 * 3), np.abs(beat1 * 4)] colors = ['grey', 'red', 'orange', 'blue', 'purple'] labels = ['', 'S1', 'S2 / B1', 'S3', 'S4 / B2'] for f_nr, freq_beat in enumerate(freqs_beat): f_pos = f[np.argmin(np.abs(f - np.abs(freq_beat)))] pp_pos = pp_mean[np.argmin(np.abs(f - np.abs(freq_beat)))] ax[4].scatter(f_pos, pp_pos, color=colors[f_nr], label=labels[f_nr]) if text == 'text': ax[4].text(f_pos - 15, pp_pos + add_pp, labels[f_nr], color=colors[f_nr], fontsize=15, rotation=65) if text != 'text': plt.legend() ax[4].plot(f, pp_mean, color='black') ax[4].set_xlim([0, 700]) test = False if test: from utils_test import test_spikes_clusters test_spikes_clusters(eod_fish_r, spikes, mat05, sampling, s_name='ms', resamp_fact=1000) ax[0].set_xticks([]) 
ax[1].set_xticks([]) ax[2].set_xticks([]) ax[0].set_ylabel('Amplitude') ax[1].set_ylabel('Beat') ax[2].set_ylabel('Spikes') ax[3].set_ylabel('Fr [Hz]') ax[4].set_ylabel('Amplitude [Hz]') ax[4].set_xlabel('f [Hz]') ax[3].set_xlabel('Time [ms]') ax[0].set_xticks([]) ax[1].set_xticks([]) ax[2].set_xticks([]) plt.subplots_adjust(top=0.7, left=0.15, right=0.95, hspace=0.5, wspace=0.1) individual_tag = titles_amp[a] + ' ' + cell + ' cv ' + str(cv0) + single_wave + '_a_f0_' + str( a_fr) + '_a_f1_' + str(a_f1) + '_a_f2_' + str(a_f2) + '_diverge_from_base_half' + str(add_half) save_visualization(individual_tag, show, counter_contrast=0, savename='') def get_phaseshifts(a_f1, a_f2, phase_right, phaseshift_fr): if phase_right == '_phaseright_': if a_f1 == 0: phaseshift_f1 = 0 phaseshift_f2 = 0 if a_f2 == 0: phaseshift_f1 = 0 phaseshift_f2 = 0 if (a_f2 != 0) & (a_f1 != 0): phaseshift_f1 = 2 * np.pi / 4 phaseshift_f2 = 2 * np.pi / 4 else: phaseshift_f1 = phaseshift_fr phaseshift_f2 = phaseshift_fr return phaseshift_f1, phaseshift_f2 def plt_serach_nonlinearity_cell2(color=['red', 'blue', 'orange', 'purple'], log='', show=True, cells=[], add_half=0): stimulus_lengths = [100] # , 10, 30, ] # [10, 10, 100, 10, 100]#, 10, 100, 10 for _, _ in enumerate(stimulus_lengths): _, _ = find_row_col(cells) # [ for t, cell in enumerate(cells): plot_style() default_figsize(column=2, length=1.95) # .5 , figsize=(5.5, 5,) grid = gridspec.GridSpec(1, 1, left=0.09, bottom=0.27, hspace=0.3, top=0.97, wspace=0.27, right=0.97) # width_ratios=[1.7, 1], ffts = ['fft4'] sampling = '_dt' for _, _ in enumerate(ffts): labels = labels_didactic2() names = ['_all'] # '_mean', '_one', axes = [] for n, name in enumerate(names): names_key = ['c_0' + name + sampling, 'c_1' + name + sampling, 'c_2' + name + sampling, 'c_3' + name + sampling] ########################################################################### first = True if first: save_name = 
'calc_nonlinearity_contrasts-_beat__AddToHalfFr_frange_from_10_to_400_in_0.3_afr_1_zeros_trNr_500_fft5__dev_original_len_100_adaptoffset_bisecting__transient_50s__until_0.03' frame = pd.read_pickle(load_folder_name( 'calc_model') + '/' + save_name + '.pkl') # calc_nonlinearity_contrasts-_beat__AddToHalfFr_frange_from_10_to_400_in_1_noAdapt__afr_1_zeros_trNr_500_fft5__dev_original_len_30_adaptoffset_bisecting__transient_1s__until_0.03 ax = plt.subplot(grid[0]) axes.append(ax) ax.axvline(frame.fr.unique() / 2, color='grey', linewidth=0.5) # 'Fr='++' Hz' for s, score_name in enumerate(names_key): ax.plot(np.abs(frame.f1 - frame.eod_fr), frame[score_name], color=color[s], label=labels[s].replace('\n', ',')) ax.set_xlim(0, 250) ax.set_xlabel('Frequency [Hz]') ax.set_ylabel('Power [Hz]') # Signal amplitude ax.show_spines('lb') try: ax.legend(loc=(0.7, 0.5), prop={'size': 9}) except: pass individual_tag = cell + '_AddHalf_' + str(add_half) + '_' + log fig = plt.gcf() fig.tag(axes, xoffs=-7.5) save_visualization(individual_tag, show, counter_contrast=0, savename='') if show: plt.show() def plt_serach_nonlinearity_cell(color=['red', 'blue', 'orange', 'purple'], log='', show=True, cells=[], add_half=0): trials_nr = [500] # , 500, 500, ] # [1, 100, 10, 150, 15]#, 300, 30, 500 stimulus_lengths = [100] # , 10, 30, ] # [10, 10, 100, 10, 100]#, 10, 100, 10 for _, _ in enumerate(stimulus_lengths): _, _ = find_row_col(cells) # [ for t, cell in enumerate(cells): plot_style() default_settings(column=2, length=3.5) # , figsize=(5.5, 5,) grid = gridspec.GridSpec(1, 2, width_ratios=[1.7, 1], left=0.11, bottom=0.15, hspace=0.3, wspace=0.27, right=0.99) ffts = ['fft4'] sampling = '_dt' for f, fft in enumerate(ffts): labels = labels_didactic() names = ['_all'] # '_mean', '_one', axes = [] for n, name in enumerate(names): names_key = ['c_0' + name + sampling, 'c_1' + name + sampling, 'c_2' + name + sampling, 'c_3' + name + sampling] first = True if first: save_name = 
'calc_nonlinearity_contrasts-_beat__AddToHalfFr_frange_from_10_to_400_in_0.3_afr_1_zeros_trNr_500_fft5__dev_original_len_100_adaptoffset_bisecting__transient_50s__until_0.03' frame = pd.read_pickle(load_folder_name( 'calc_model') + '/' + save_name + '.pkl') # calc_nonlinearity_contrasts-_beat__AddToHalfFr_frange_from_10_to_400_in_1_noAdapt__afr_1_zeros_trNr_500_fft5__dev_original_len_30_adaptoffset_bisecting__transient_1s__until_0.03 ax = plt.subplot(grid[0]) axes.append(ax) ax.axvline(frame.fr.unique() / 2, color='grey', linewidth=0.5) # 'Fr='++' Hz' for s, score_name in enumerate(names_key): ax.plot(np.abs(frame.f1 - frame.eod_fr), frame[score_name], color=color[s], label=labels[s]) ax.set_xlim(0, 250) ax.set_xlabel('Beat [Hz]') ax.set_ylabel('Signal amplitude [Hz]') ax.show_spines('lb') try: ax.legend(loc=(0.7, 0.8)) except: pass save_name = load_folder_name( 'calc_model') + '/calc_nonlinearity_contrasts-_beat__AddToHalfFr_0_afr_1_zeros_trNr_500_fft5__dev_original_len_100_adaptoffset_bisecting__transient_50s__until_0.5.pkl' ax = plt.subplot(grid[1]) axes.append(ax) if os.path.exists(save_name): frame = pd.read_pickle(save_name) # load_folder_name('calc_model')+'/nonlinearity_amp_var2.pkl' frame_cell = frame[frame['cell'] == cell] if fft == 'psd': plt_nonlin(ax[t, f], frame_cell) ax[t, f].set_title('trNr ' + str( trials_nr[0]) + ' len ' + str(stimulus_lengths[0]) + ' ' + fft + ' FinalTr ' + str( np.round(stimulus_lengths[0] * trials_nr[0] * np.mean(frame_cell.fr.unique()) / 2))) else: for n, name in enumerate(names): title = False if title: plt.suptitle('trNr ' + str( trials_nr[0]) + ' len ' + str( stimulus_lengths[0]) + ' ' + fft + ' ' + name + ' Sampling ' + str( sampling) + ' FinalTr ' + str( np.round( stimulus_lengths[0] * trials_nr[0] * np.nanmean(frame_cell.fr.unique()) / 2))) ax.set_title(cell + ' CV ' + str(np.mean(frame_cell.cv.unique()))) ax_axis = frame_cell['a_f1'] * 100 for s, score_name in enumerate(names_key): ax.plot(ax_axis[1::], 
frame_cell[score_name][1::], color=color[s], label=labels[s]) ax.set_ylabel('Signal Amplitude [Hz]') ax.set_aspect('equal') ax.show_spines('lb') ax.set_xlabel('contrast [%]') if log == 'log': ax.set_yscale('log') ax.set_xscale('log') individual_tag = cell + '_AddHalf_' + str(add_half) + '_' + log ax = make_simple_tags(axes) save_visualization(individual_tag, show, counter_contrast=0, savename='') if show: plt.show() def labels_didactic(): labels = ['Beat ', '2 Beat / Baseline Fr ', '3 Beat', '4 Beat / 2 Baseline Fr'] return labels # $\cdot$ def labels_didactic2(): labels = [r' $f_{Stim}$ ', '$2f_{Stim}$, $f'+basename()+'$ ', r' $3f_{Stim}$ ', ' $4f_{Stim}$, $2 f'+basename()+'$'] return labels # $\cdot$ def make_simple_tags(axes, xpos=-0.03, ypos=1.02, letters=['A', 'B'], ): fig = plt.gcf() ppi = 72.0 # points per inch: fs = mpl.rcParams['font.size'] * fig.dpi / ppi for aa, ax in enumerate(axes): ax.text(xpos, ypos, letters[aa], transform=ax.transAxes, ha='right', va='bottom', fontsize=fs) return ax def make_tags(axes=[], xoffs=-3, yoffs=1.2): fig = plt.gcf() if len(axes) < 1: axes = plt.gca() fig.tag(axes, xoffs=xoffs, yoffs=yoffs) return axes def plt_nonlin(ax, frame_cell, first='c_0_all_dt', second='c_1_all_dt', third='c_2_all_dt', forth='c_3_all_dt'): # first = 'a_fundamental_original', second = 'a_h1_original', third = 'a_h2_original', forth = 'a_h3_original' ax.plot(frame_cell['a_f1'] * 100, frame_cell[first], color='blue', label='S1') ax.plot(frame_cell['a_f1'] * 100, frame_cell[second], color='orange', label='S2 / B1 ') # Baseline f [B1] / Stimulus [S2] ax.plot(frame_cell['a_f1'] * 100, frame_cell[third], color='green', label='S3') ax.plot(frame_cell['a_f1'] * 100, frame_cell[forth], color='red', label='S4 / B2') def save_name_nonlinearity(add_half, a_f1_end=0.2, transient_s=0, adapt_offset='', n=1, stimulus_length=2, freq_type='', adapt='', a_f2s=[0], freqs2=[0], dev='original', fft='fft', a_fr=1, trials_nr=150, zeros='zeros'): dev_name = '_dev_' + str(dev) 
version_name = '_' + fft + '_' if a_f1_end == 0.2: end_name = '' else: end_name = '_until_' + str(a_f1_end) trials_nr_name = '_trNr_' + str(trials_nr) if transient_s != 0: transient_s_name = '_transient_' + str(transient_s) + 's_' else: transient_s_name = '' if n != 1: n_name = '_power' + str(n) else: n_name = '' a_fr_name = '_afr_' + str(a_fr) + '_' + zeros if 'psd' in fft: freq_type = '' add_half_name = '_AddToHalfFr_' + str(add_half) if adapt_offset != '': adapt_offset_name = '_' + adapt_offset + '_' else: adapt_offset_name = '' # die funktion dazu ist calc_nonlinearity_contrasts NOT calc_nonlinearity_contrasts_fft if (len(freqs2) != 0) & (len(a_f2s) != 0): if (len(freqs2) > 1) & (len(a_f2s) == 1): freq_afname = 'frange_from_' + str(freqs2[0]) + '_to_' + str(freqs2[-1]) + '_in_' + str( np.diff(freqs2)[0]) + '_af2_' + str(a_f2s[0]) # a_f2s =a_f2s, freqs2 = freqs2 elif (len(freqs2) > 1) & (len(a_f2s) > 1): freq_afname = 'frange_from_' + str(freqs2[0]) + '_to_' + str(freqs2[-1]) + '_in_' + str( np.diff(freqs2)[0]) + 'af2range_from_' + str(a_f2s[0]) + '_to_' + str(a_f2s[-1]) + '_in_' + str( np.diff(a_f2s)[0]) # a_f2s =a_f2s, freqs2 = freqs2 else: freq_afname = 'freq2_' + str(freqs2[0]) + '_af2_' + str(a_f2s[0]) save_name = load_folder_name( 'calc_model') + '/' + calc_nonlinearity_contrasts.__name__ + '-' + freq_type + add_half_name + a_fr_name + trials_nr_name + version_name + dev_name + '_len_' + str( stimulus_length) + adapt + adapt_offset_name + n_name + transient_s_name + end_name + freq_afname + '.pkl' return save_name def plt_single_phaselockloss(colors, frame_cell, df, scores, cell, ax, df_name='df'): frame_df = frame_cell[(frame_cell[df_name] == df) | (np.isnan(frame_cell[df_name]))] mt_types = frame_cell.mt_type.unique() if len(mt_types) < 1: embed() for s, score in enumerate(scores): ax.set_title(cell[0:14] + ' DF=' + str(df), fontsize=8) score_vals = [] score_vals05 = [] score_vals25 = [] score_vals75 = [] score_vals95 = [] contrasts_here = [] for m, 
mt_type in enumerate(mt_types): if 'base' in mt_type: marker = '*' elif 'chirp' in mt_type: marker = '^' elif 'SAM' in mt_type: marker = '.' else: marker = 'o' frame_type = frame_df[frame_df.mt_type == mt_type] if mt_type != 'base': frame_type = frame_type.groupby('contrast').mean().reset_index() frame_type75 = frame_type.groupby('contrast').quantile(0.75).reset_index() frame_type25 = frame_type.groupby('contrast').quantile(0.25).reset_index() frame_type95 = frame_type.groupby('contrast').quantile(1).reset_index() frame_type05 = frame_type.groupby('contrast').quantile(0).reset_index() contasts = np.array(list(map(float, frame_type['contrast']))) frame_type['contrast'] = contasts sorted = np.argsort(contasts) score_val = frame_type[score].iloc[sorted] score_val75 = frame_type75[score].iloc[sorted] score_val25 = frame_type25[score].iloc[sorted] score_val95 = frame_type95[score].iloc[sorted] score_val05 = frame_type05[score].iloc[sorted] contrast_here = frame_type['contrast'].iloc[sorted] nr = 1 else: score_val = [np.mean(frame_type[score])] score_val75 = [np.percentile(frame_type[score], 75)] score_val25 = [np.percentile(frame_type[score], 25)] score_val95 = [np.percentile(frame_type[score], 100)] score_val05 = [np.percentile(frame_type[score], 0)] contrast_here = [0] # np.zeros(len(frame_type[score])) nr = 2 try: ax.scatter(contrast_here, score_val, marker=marker, color=colors[s], zorder=100 * nr, alpha=0.5, s=8.5) except: print('axis problem') embed() score_vals.extend(np.array(score_val)) score_vals05.extend(np.array(score_val05)) score_vals95.extend(np.array(score_val95)) score_vals75.extend(np.array(score_val75)) score_vals25.extend(np.array(score_val25)) contrasts_here.extend(np.array(contrast_here)) ax.fill_between(np.array(contrasts_here)[np.argsort(contrasts_here)], np.array(score_vals05)[np.argsort(contrasts_here)], np.array(score_vals95)[np.argsort(contrasts_here)], color=colors[s], alpha=0.2) 
ax.fill_between(np.array(contrasts_here)[np.argsort(contrasts_here)], np.array(score_vals25)[np.argsort(contrasts_here)], np.array(score_vals75)[np.argsort(contrasts_here)], color=colors[s], alpha=0.6) ax.plot(np.array(contrasts_here)[np.argsort(contrasts_here)], np.array(score_vals)[np.argsort(contrasts_here)], color=colors[s], label=score) def plt_beats_modulation_several_with_overview_nice_from_three_final(only_first=True, limit=1, duration_exclude=0.45, nfft=int(4096), show=False): # Function to load the experimental data save_name = 'beat_results_smoothed_limit' + str(limit) + '_minimalduration_' + str(duration_exclude) + '_all' print(save_name) datas_new = [] old_cells = False if old_cells: # das ist falls ich die alten Datensätze untersuchen will _, _ = find_all_dir_cells() frame_desired = pd.read_csv('../code/calc_base/find_contrasts_SAMs-SAM_amplitudes.csv') frame_big = frame_desired[(frame_desired.contrast > 29) | (frame_desired.contrast_true > 29)] else: pass # path_new2 = load_folder_name('calc_cocktailparty') + '/calc_data_peaks_threewave-spikes_all_psdEOD_1_nfft_16384[05,original]_psdEOD__sqrt__points_5_ALL_.pkl' # aber ich will ja die neuen Datensätzte _, _, _ = find_cells_for_phaselocking() datasets = ['2023-05-12-ar-invivo-1'] frame_all = pd.read_pickle(load_folder_name( 'calc_cocktailparty') + '/calc_data_peaks_threewave-spikes_all_psdEOD_1_nfft_16384[05,original]_psdEOD__sqrt__points_5_ALL_.pkl') # pd.read_pickle('../code/calc_phaselocking/calc_phaselocking-phaselocking5_big.pkl') plot_style() default_settings(column=2, length=6) for i, cell in enumerate(datasets): path = load_folder_name('data') + 'cells/' + cell + '/' + cell + ".nix" # cell.split(os.path.sep)[-1] print(cell) if os.path.exists(path): file = nix.File.open(path, nix.FileMode.ReadOnly) b = file.blocks[0] cont2 = False names = [] names_dataarrays = [] for stims in b.data_arrays: # this seems to be reasonable because the only data with sinewave with higher number was not useful (das 
betrifft aber nur 2019-05-07-ab und 2019-05-07-ac, wobei wenn ich jetzt ganau schau scheinen das nur komische trials zu sein die wir nicht wirklich brauchen if 'sinewave-1_Contrast' in stims.name: names.append(stims.name) names_dataarrays.append(stims.name) 'sinewave''SAM' sam = find_mt(b, 'SAM') sine = find_mt(b, 'sine') if (len(sine) > 0) or (len(sam) > 0): cont2 = True test = False if test: from utils_test import test_in_plot_phaselocking test_in_plot_phaselocking(b, path) if cont2 == True: counter = 0 DF1s, DF2s, frame_data_cell, c2_unique, c2_unique_big, c1_unique_big, c1_unique = get_dfs_and_contrasts_from_calccocktailparty( cell, frame_all) if only_first: pass else: pass DF1s_here = [DF1s[0]] DF2s_here = [DF2s[0]] for d1, DF1 in enumerate(DF1s_here): for d2, DF2 in enumerate(DF2s_here): # das ist blöd man sollte die abgespeicherten Ms machen frame_df0 = frame_data_cell[(np.abs(frame_data_cell.m1 - DF1) < 0.01) & ( np.abs(frame_data_cell.m2 - DF2) < 0.01)] contrasts_2_chosen = np.sort(np.unique(frame_df0.c2)) # c1_unique_big[0:5] for c_nr2, c2 in enumerate(contrasts_2_chosen): plt.suptitle(cell + ' DF1=' + str(DF1) + ' DF2=' + str(DF2)) plt.figure() gs0 = gridspec.GridSpec(1, 2, width_ratios=[4, 1], hspace=0.4, left=0.045, right=0.97) # frame_df = frame_df0[(frame_df0.c2 == c2)] # Vergleichsplot grid1 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.2, hspace=0.2, subplot_spec=gs0[1]) if len(frame_df) > 0: frame_df_mean = frame_df # .groupby(['c1']).mean()#, 'c2']) cs = {} means = {} scores_datas = [['amp_f0_01_original', 'amp_f0_012_original', 'amp_B1_01_original', 'amp_B1_012_original'], ['amp_f0_02_original', 'amp_f0_0_original', 'amp_B2_02_original', 'amp_B2_012_original']] colorss = [['green', 'purple', 'green', 'blue'], ['orange', 'black', 'orange', 'red']] linestyless = [['--', '--', '-', '-'], ['--', '--', '-', '-']] show_lines_several_plots(colorss, cs, frame_df_mean, grid1, linestyless, means, scores_datas) find_mt_all(b) contrasts = 
np.unique(frame_df.c1) if len(contrasts) > 0: contrasts_1_chosen, indeces_show = choice_specific_indices(contrasts, negativ='positiv', units=5, cut_val=1) nr_col = int(len(contrasts_1_chosen)) grid2 = gridspec.GridSpecFromSubplotSpec(5, nr_col, height_ratios=[1, 1, 0.5, 1, 1], wspace=0.2, hspace=0.2, subplot_spec=gs0[0]) axts = [] axps = [] for c_nr, c1 in enumerate(contrasts_1_chosen): frame_c1 = frame_df[(frame_df.c1 == c1)] # | (frame_df.mt_type == 'base') print(c_nr) mt_types = frame_c1.mt_type.unique() for mt_type in mt_types: frame_type = frame_c1[ (frame_c1.mt_type == mt_type)] # | (frame_df.mt_type == 'base') V_1, eods_all, f, mts, sampling_rate, spike_times, spikes_mat, spikes_mats = load_spikes_eods_phaselock( b, cell, datas_new, frame_type, names, nfft) key = 'control_01' plt.suptitle('data ' + cell + ' ' + mts.name + ' ' + key) xlim = [200, 250] nr_example = 0 ########################################### # time spikes axt = plt.subplot(grid2[1, c_nr]) time = plt_voltage_phaselock(V_1, axt, axts, counter, nr_example, sampling_rate, spike_times, xlim, key=key) ########################################## axt = plt.subplot(grid2[0, c_nr]) plt_stimulus_phaselock(axt, axts, counter, eods_all, frame_type, nr_example, sampling_rate, spike_times, time, xlim, key=key) ########################################## # time psd axp = plt.subplot(grid2[3, c_nr]) axp2 = plt.subplot(grid2[4, c_nr]) axps.append(axp) axps.append(axp2) axts[0].get_shared_y_axes().join(*axts[0::2]) axts[1].get_shared_y_axes().join(*axts[1::2]) join_x(axts) join_y(axps) set_same_ylim(axps) join_x(axps) individual_tag = 'data' + cell + '_DF1_' + str(DF1) + '_DF2_' + str(DF2) + '_c2_' + str( c2) save_visualization(individual_tag, show) print('finished examples') embed() def plt_beats_modulation_several_with_overview_nice_from_three(only_first=True, limit=1, duration_exclude=0.45, nfft=int(4096), show=False): # Function to load the experimental data save_name = 'beat_results_smoothed_limit' + 
str(limit) + '_minimalduration_' + str(duration_exclude) + '_all' print(save_name) datas_new = [] old_cells = False if old_cells: # das ist falls ich die alten Datensätze untersuchen will _, _ = find_all_dir_cells() frame_desired = pd.read_csv('../code/calc_base/find_contrasts_SAMs-SAM_amplitudes.csv') frame_big = frame_desired[(frame_desired.contrast > 29) | (frame_desired.contrast_true > 29)] else: pass # aber ich will ja die neuen Datensätzte datasets, loss, gain = find_cells_for_phaselocking() frame_all = pd.read_pickle(load_folder_name( 'calc_cocktailparty') + '/calc_data_peaks_threewave-spikes_all_psdEOD_1_nfft_16384[05,original]_psdEOD__sqrt__points_5_ALL_.pkl') # pd.read_pickle('../code/calc_phaselocking/calc_phaselocking-phaselocking5_big.pkl') # cells_for_phaselocking, loss, gain = find_cells_for_phaselocking() plot_style() default_settings() for i, cell in enumerate(datasets): path = '../data/cells/' + cell + '/' + cell + ".nix" # cell.split(os.path.sep)[-1] print(cell) if os.path.exists(path): file = nix.File.open(path, nix.FileMode.ReadOnly) b = file.blocks[0] cont2 = False names = [] names_dataarrays = [] for stims in b.data_arrays: # this seems to be reasonable because the only data with sinewave with higher number was not useful (das betrifft aber nur 2019-05-07-ab und 2019-05-07-ac, wobei wenn ich jetzt ganau schau scheinen das nur komische trials zu sein die wir nicht wirklich brauchen if 'sinewave-1_Contrast' in stims.name: names.append(stims.name) names_dataarrays.append(stims.name) 'sinewave''SAM' sam = find_mt(b, 'SAM') sine = find_mt(b, 'sine') if (len(sine) > 0) or (len(sam) > 0): cont2 = True test = False if test: from utils_test import test_in_plot_phaselocking test_in_plot_phaselocking(b, path) if cont2 == True: counter = 0 DF1s, DF2s, frame_data_cell, c2_unique, c2_unique_big, c1_unique_big, c1_unique = get_dfs_and_contrasts_from_calccocktailparty( cell, frame_all) if only_first: DF1s_here = [DF1s[0]] else: DF1s_here = DF1s # [0]] for 
d1, DF1 in enumerate(DF1s_here): for d2, DF2 in enumerate(DF2s): # das ist blöd man sollte die abgespeicherten Ms machen frame_df0 = frame_data_cell[ (np.round(frame_data_cell.m1, 2) == DF1) & (np.round(frame_data_cell.m2, 2) == DF2)] contrasts_2_chosen = np.sort(np.unique(frame_df0.c2)) # c1_unique_big[0:5] for c_nr2, c2 in enumerate(contrasts_2_chosen): plt.suptitle(cell + ' DF1=' + str(DF1) + ' DF2=' + str(DF2)) plt.figure(figsize=(15, 9)) gs0 = gridspec.GridSpec(1, 2, width_ratios=[4, 1], hspace=0.4, left=0.045, right=0.97) # frame_df = frame_df0[(frame_df0.c2 == c2)] # Vergleichsplot grid1 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.2, hspace=0.2, subplot_spec=gs0[1]) if len(frame_df) > 0: # frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv') frame_df_mean = frame_df # .groupby(['c1']).mean()#, 'c2']) cs = {} means = {} scores_datas = [['amp_f0_01_original', 'amp_f0_012_original', 'amp_B1_01_original', 'amp_B1_012_original'], ['amp_f0_02_original', 'amp_f0_0_original', 'amp_B2_02_original', 'amp_B2_012_original']] colorss = [['green', 'purple', 'green', 'blue'], ['orange', 'black', 'orange', 'red']] linestyless = [['--', '--', '-', '-'], ['--', '--', '-', '-']] show_lines_several_plots(colorss, cs, frame_df_mean, grid1, linestyless, means, scores_datas) find_mt_all(b) contrasts = np.unique(frame_df.c1) if len(contrasts) > 0: contrasts_1_chosen, indeces_show = choice_specific_indices(contrasts, negativ='positiv', units=5, cut_val=1) nr_col = int(len(contrasts_1_chosen)) grid2 = gridspec.GridSpecFromSubplotSpec(5, nr_col, height_ratios=[1, 1, 0.5, 1, 1], wspace=0.2, hspace=0.2, subplot_spec=gs0[0]) axts = [] axps = [] for c_nr, c1 in enumerate(contrasts_1_chosen): frame_c1 = frame_df[(frame_df.c1 == c1)] # | (frame_df.mt_type == 'base') print(c_nr) mt_types = frame_c1.mt_type.unique() for mt_type in mt_types: frame_type = frame_c1[ (frame_c1.mt_type == mt_type)] # | (frame_df.mt_type == 'base') V_1, eods_all, f, mts, 
sampling_rate, spike_times, spikes_mat, spikes_mats = load_spikes_eods_phaselock( b, cell, datas_new, frame_type, names, nfft) key = 'control_01' plt.suptitle('data ' + cell + ' ' + mts.name + ' ' + key) xlim = [200, 250] nr_example = 0 ########################################### # time spikes axt = plt.subplot(grid2[1, c_nr]) time = plt_voltage_phaselock(V_1, axt, axts, counter, nr_example, sampling_rate, spike_times, xlim, key=key) ########################################## axt = plt.subplot(grid2[0, c_nr]) plt_stimulus_phaselock(axt, axts, counter, eods_all, frame_type, nr_example, sampling_rate, spike_times, time, xlim, key=key) axp = plt.subplot(grid2[3, c_nr]) axp2 = plt.subplot(grid2[4, c_nr]) axps.append(axp) axps.append(axp2) axts[0].get_shared_y_axes().join(*axts[0::2]) axts[1].get_shared_y_axes().join(*axts[1::2]) join_x(axts) join_y(axps) set_same_ylim(axps) join_x(axps) individual_tag = 'data' + cell + '_DF1_' + str(DF1) + '_DF2_' + str(DF2) + '_c2_' + str( c2) save_visualization(individual_tag, show) print('finished examples') embed() def show_lines_several_plots(colorss, cs, frame_df_mean, grid1, linestyless, means, scores_datas): for s_nr, score in enumerate(scores_datas): scores_data = scores_datas[s_nr] linestyles = linestyless[s_nr] colors = colorss[s_nr] ax = plt.subplot(grid1[s_nr]) for sss, score in enumerate(scores_data): ax.plot(np.sort(frame_df_mean['c1']), frame_df_mean[score].iloc[np.argsort(frame_df_mean['c1'])], color=colors[sss], linestyle=linestyles[ sss]) # +str(np.round(np.mean(group_restricted[score_data]))), label = 'c_small='+str(c_small)+' c_big='+str(c_big) if sss not in means.keys(): means[sss] = [] cs[sss] = [] ax.set_ylabel(score.replace('_mean', '').replace('amp_', '') + '[Hz]', fontsize=8) ax.set_xlabel('Contrast small') ax.set_xlabel('Contrast small') return ax def color_three(name): dict_here = {'0': 'grey', '01': 'green', '02': 'blue', '012': 'purple', 'base_0': 'grey', 'control_01': 'green', 'control_02': 'blue'} 
return dict_here[name] def plt_beats_modulation_several_with_overview_nice_from_three_contorol_compar_final(only_first=True, limit=1, duration_exclude=0.45, nfft=int(4096), show=False): # Function to load the experimental data save_name = 'beat_results_smoothed_limit' + str(limit) + '_minimalduration_' + str(duration_exclude) + '_all' print(save_name) datas_new = [] old_cells = False if old_cells: # das ist falls ich die alten Datensätze untersuchen will _, _ = find_all_dir_cells() frame_desired = pd.read_csv('../code/calc_base/find_contrasts_SAMs-SAM_amplitudes.csv') frame_big = frame_desired[(frame_desired.contrast > 29) | (frame_desired.contrast_true > 29)] else: pass _, _, _ = find_cells_for_phaselocking() datasets = ['2023-05-12-ar-invivo-1'] frame_all = pd.read_pickle( load_folder_name( 'calc_cocktailparty') + '/calc_data_peaks_threewave-spikes_all_psdEOD_1_nfft_16384[05,original]_psdEOD__sqrt__points_5_ALL_.pkl') # pd.read_pickle('../code/calc_phaselocking/calc_phaselocking-phaselocking5_big.pkl') plot_style() default_settings() for i, cell in enumerate(datasets): path = load_folder_name('data') + '/cells/' + cell + '/' + cell + ".nix" # cell.split(os.path.sep)[-1] print(cell) if os.path.exists(path): file = nix.File.open(path, nix.FileMode.ReadOnly) b = file.blocks[0] cont2 = False names = [] names_dataarrays = [] for stims in b.data_arrays: # this seems to be reasonable because the only data with sinewave with higher number was not useful (das betrifft aber nur 2019-05-07-ab und 2019-05-07-ac, wobei wenn ich jetzt ganau schau scheinen das nur komische trials zu sein die wir nicht wirklich brauchen if 'sinewave-1_Contrast' in stims.name: names.append(stims.name) names_dataarrays.append(stims.name) 'sinewave''SAM' sam = find_mt(b, 'SAM') sine = find_mt(b, 'sine') if (len(sine) > 0) or (len(sam) > 0): cont2 = True test = False if test: from utils_test import test_in_plot_phaselocking test_in_plot_phaselocking(b, path) if cont2 == True: counter = 0 DF1s, DF2s, 
frame_data_cell, c2_unique, c2_unique_big, c1_unique_big, c1_unique = get_dfs_and_contrasts_from_calccocktailparty( cell, frame_all) if only_first: DF1s_here = [DF1s[0]] else: DF1s_here = DF1s # [0]] for d1, DF1 in enumerate(DF1s_here): for d2, DF2 in enumerate(DF2s): # das ist blöd man sollte die abgespeicherten Ms machen frame_df0 = frame_data_cell[ (np.round(frame_data_cell.m1, 2) == DF1) & (np.round(frame_data_cell.m2, 2) == DF2)] contrasts_2_chosen = np.sort(np.unique(frame_df0.c2)) # c1_unique_big[0:5] for c_nr2, c2 in enumerate(contrasts_2_chosen): frame_df = frame_df0[(frame_df0.c2 == c2)] contrasts = np.unique(frame_df.c1)[::-1] if len(contrasts) > 0: contrasts_1_chosen = contrasts # , indeces_show = choice_specific_indices(contrasts, negativ = 'positiv', units = 5, cut_val = 1) for c_nr, c1 in enumerate(contrasts_1_chosen): if len(frame_df) > 0: frame_df_mean = frame_df # .groupby(['c1']).mean()#, 'c2']) frame_c1 = frame_df[(frame_df.c1 == c1)] # | (frame_df.mt_type == 'base') print(c_nr) mt_types = frame_c1.mt_type.unique() for mt_type in mt_types: find_mt_all(b) keys = ['base_0', 'control_01', 'control_02', '012'] # ] nr_col = len(keys) axts = [] axps = [] frame_type = frame_c1[ (frame_c1.mt_type == mt_type)] # | (frame_df.mt_type == 'base') V_1, eods_all, f, mts, sampling_rate, spike_times, spikes_mat, spikes_mats = load_spikes_eods_phaselock( b, cell, datas_new, frame_type, names, nfft) for nr_example in range(len(spike_times)): ################################### plt.suptitle(cell + ' DF1=' + str(DF1) + ' DF2=' + str(DF2)) plt.figure(figsize=(20, 14)) gs0 = gridspec.GridSpec(1, 2, width_ratios=[4, 1], hspace=0.4, left=0.045, right=0.97) # # Vergleichsplot grid1 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.2, hspace=0.2, subplot_spec=gs0[1]) plt_lines_phaselockingloss(frame_df_mean, grid1) grid2 = gridspec.GridSpecFromSubplotSpec(5, nr_col, height_ratios=[1, 1, 0.5, 1, 1], wspace=0.2, hspace=0.2, subplot_spec=gs0[0]) plt.suptitle('data ' + cell + 
' ' + mts.name + ' ' + '_DF1_' + str( DF1) + '_DF2_' + str(DF2) + '\n_c2_' + str(c2) + '_c1_' + str( c1) + ' Trial ' + str(nr_example)) xlim = [200, 250] ########################################### # time spikes for k, key in enumerate(keys): axt = plt.subplot(grid2[1, k]) time = plt_voltage_phaselock(V_1, axt, axts, counter, nr_example, sampling_rate, spike_times, xlim, key=key, color=color_three(key)) ########################################## axt = plt.subplot(grid2[0, k]) plt_stimulus_phaselock(axt, axts, counter, eods_all, frame_type, nr_example, sampling_rate, spike_times, time, xlim, key=key, color=color_three(key)) axt.set_title(key) axt.set_ylabel('') ########################################## # time psd axp = plt.subplot(grid2[3, k]) axp2 = plt.subplot(grid2[4, k]) spikes_mat = plt_psds_phaselock(axp, axp2, counter, f, nr_example, sampling_rate, spikes_mat, spikes_mats, key=key, color=color_three(key)) axps.append(axp) axps.append(axp2) axts[0].get_shared_y_axes().join(*axts[0::2]) axts[1].get_shared_y_axes().join(*axts[1::2]) join_x(axts) join_y(axps) set_same_ylim(axps) join_x(axps) individual_tag = 'data' + cell + '_DF1_' + str(DF1) + '_DF2_' + str( DF2) + '_c2_' + str(c2) + '_c1_' + str(c1) + '_trial_' + str( nr_example) save_visualization(individual_tag, show) print('finished examples') embed() def plt_beats_modulation_several_with_overview_nice_from_three_contorol_compar(only_first=True, limit=1, duration_exclude=0.45, nfft=int(4096), show=False): # Function to load the experimental data save_name = 'beat_results_smoothed_limit' + str(limit) + '_minimalduration_' + str(duration_exclude) + '_all' print(save_name) datas_new = [] old_cells = False if old_cells: # das ist falls ich die alten Datensätze untersuchen will _, _ = find_all_dir_cells() frame_desired = pd.read_csv('../code/calc_base/find_contrasts_SAMs-SAM_amplitudes.csv') frame_big = frame_desired[(frame_desired.contrast > 29) | (frame_desired.contrast_true > 29)] else: pass datasets, loss, 
gain = find_cells_for_phaselocking() frame_all = pd.read_pickle( load_folder_name( 'calc_cocktailparty') + '/calc_data_peaks_threewave-spikes_all_psdEOD_1_nfft_16384[05,original]_psdEOD__sqrt__points_5_ALL_.pkl') # pd.read_pickle('../code/calc_phaselocking/calc_phaselocking-phaselocking5_big.pkl') # cells_for_phaselocking, loss, gain = find_cells_for_phaselocking() plot_style() default_settings() for i, cell in enumerate(datasets): path = '../data/cells/' + cell + '/' + cell + ".nix" # cell.split(os.path.sep)[-1] print(cell) if os.path.exists(path): file = nix.File.open(path, nix.FileMode.ReadOnly) b = file.blocks[0] cont2 = False names = [] names_dataarrays = [] for stims in b.data_arrays: # this seems to be reasonable because the only data with sinewave with higher number was not useful (this only concerns 2019-05-07-ab and 2019-05-07-ac, and looking closely these seem to be just odd trials that we do not really need) if 'sinewave-1_Contrast' in stims.name: names.append(stims.name) names_dataarrays.append(stims.name) 'sinewave''SAM' sam = find_mt(b, 'SAM') sine = find_mt(b, 'sine') if (len(sine) > 0) or (len(sam) > 0): cont2 = True test = False if test: from utils_test import test_in_plot_phaselocking test_in_plot_phaselocking(b, path) if cont2 == True: counter = 0 DF1s, DF2s, frame_data_cell, c2_unique, c2_unique_big, c1_unique_big, c1_unique = get_dfs_and_contrasts_from_calccocktailparty( cell, frame_all) if only_first: DF1s_here = [DF1s[0]] else: DF1s_here = DF1s # [0]] for d1, DF1 in enumerate(DF1s_here): for d2, DF2 in enumerate(DF2s): # this is awkward, one should use the stored Ms instead frame_df0 = frame_data_cell[ (np.round(frame_data_cell.m1, 2) == DF1) & (np.round(frame_data_cell.m2, 2) == DF2)] contrasts_2_chosen = np.sort(np.unique(frame_df0.c2)) # c1_unique_big[0:5] for c_nr2, c2 in enumerate(contrasts_2_chosen): frame_df = frame_df0[(frame_df0.c2 == c2)] contrasts = np.unique(frame_df.c1)[::-1] if len(contrasts) > 0:
contrasts_1_chosen = contrasts # , indeces_show = choice_specific_indices(contrasts, negativ = 'positiv', units = 5, cut_val = 1) for c_nr, c1 in enumerate(contrasts_1_chosen): if len(frame_df) > 0: frame_df_mean = frame_df # .groupby(['c1']).mean()#, 'c2']) frame_c1 = frame_df[(frame_df.c1 == c1)] # | (frame_df.mt_type == 'base') print(c_nr) mt_types = frame_c1.mt_type.unique() for mt_type in mt_types: find_mt_all(b) # todo: one could also simply store the mt and the mt name here keys = ['base_0', 'control_01', 'control_02', '012'] # ] nr_col = len(keys) axts = [] axps = [] frame_type = frame_c1[ (frame_c1.mt_type == mt_type)] # | (frame_df.mt_type == 'base') V_1, eods_all, f, mts, sampling_rate, spike_times, spikes_mat, spikes_mats = load_spikes_eods_phaselock( b, cell, datas_new, frame_type, names, nfft) for nr_example in range(len(spike_times)): plt.suptitle(cell + ' DF1=' + str(DF1) + ' DF2=' + str(DF2)) plt.figure(figsize=(20, 14)) gs0 = gridspec.GridSpec(1, 2, width_ratios=[4, 1], hspace=0.4, left=0.045, right=0.97) # # comparison plot grid1 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.2, hspace=0.2, subplot_spec=gs0[1]) plt_lines_phaselockingloss(frame_df_mean, grid1) grid2 = gridspec.GridSpecFromSubplotSpec(5, nr_col, height_ratios=[1, 1, 0.5, 1, 1], wspace=0.2, hspace=0.2, subplot_spec=gs0[0]) plt.suptitle('data ' + cell + ' ' + mts.name + ' ' + '_DF1_' + str( DF1) + '_DF2_' + str(DF2) + '\n_c2_' + str(c2) + '_c1_' + str( c1) + ' Trial ' + str(nr_example)) xlim = [200, 250] ########################################### # time spikes for k, key in enumerate(keys): axt = plt.subplot(grid2[1, k]) time = plt_voltage_phaselock(V_1, axt, axts, counter, nr_example, sampling_rate, spike_times, xlim, key=key, color=color_three(key)) axt = plt.subplot(grid2[0, k]) plt_stimulus_phaselock(axt, axts, counter, eods_all, frame_type, nr_example, sampling_rate, spike_times, time, xlim, key=key, color=color_three(key)) axt.set_title(key) axt.set_ylabel('')
########################################## # time psd axp = plt.subplot(grid2[3, k]) axp2 = plt.subplot(grid2[4, k]) spikes_mat = plt_psds_phaselock(axp, axp2, counter, f, nr_example, sampling_rate, spikes_mat, spikes_mats, key=key, color=color_three(key)) axps.append(axp) axps.append(axp2) axts[0].get_shared_y_axes().join(*axts[0::2]) axts[1].get_shared_y_axes().join(*axts[1::2]) join_x(axts) join_y(axps) set_same_ylim(axps) join_x(axps) individual_tag = 'data' + cell + '_DF1_' + str(DF1) + '_DF2_' + str( DF2) + '_c2_' + str(c2) + '_c1_' + str(c1) + '_trial_' + str( nr_example) save_visualization(individual_tag, show) print('finished examples') embed() def plt_beats_modulation_several_with_overview_nice_from_three_contorol_compar_single_pdf(only_first=True, limit=1, duration_exclude=0.45, nfft=int(4096), show=False): """Plot one overview figure per trial (stimulus with AM, voltage with spikes, spike-train PSDs and an ISI histogram of the 'control_01' segment, plus the peak-amplitude-vs-contrast panels) for every cell from find_cells_for_phaselocking() and save it via save_visualization(individual_tag, show, pdf=True); only_first restricts the loop to the first DF1, nfft is passed to load_spikes_eods_phaselock, and limit/duration_exclude only enter the printed save_name. Ends in an interactive IPython embed(). NOTE(review): plt.suptitle is called before plt.figure (so the first title lands on the previously active figure) and the opened nix file is never closed.""" # Function to load the experimental data save_name = 'beat_results_smoothed_limit' + str(limit) + '_minimalduration_' + str(duration_exclude) + '_all' print(save_name) datas_new = [] old_cells = False if old_cells: # this is in case I want to examine the old datasets _, _ = find_all_dir_cells() frame_desired = pd.read_csv('../code/calc_base/find_contrasts_SAMs-SAM_amplitudes.csv') frame_big = frame_desired[(frame_desired.contrast > 29) | (frame_desired.contrast_true > 29)] else: pass # path_new2 = load_folder_name('calc_cocktailparty') + '/calc_data_peaks_threewave-spikes_all_psdEOD_1_nfft_16384[05,original]_psdEOD__sqrt__points_5_ALL_.pkl' # but I want the new datasets datasets, loss, gain = find_cells_for_phaselocking() frame_all = pd.read_pickle( load_folder_name( 'calc_cocktailparty') + '/calc_data_peaks_threewave-spikes_all_psdEOD_1_nfft_16384[05,original]_psdEOD__sqrt__points_5_ALL_.pkl') # pd.read_pickle('../code/calc_phaselocking/calc_phaselocking-phaselocking5_big.pkl') # cells_for_phaselocking, loss, gain = find_cells_for_phaselocking() plot_style() default_settings() for i, cell in enumerate(datasets): path =
'../data/cells/' + cell + '/' + cell + ".nix" # cell.split(os.path.sep)[-1] print(cell) if os.path.exists(path): file = nix.File.open(path, nix.FileMode.ReadOnly) b = file.blocks[0] cont2 = False names = [] names_dataarrays = [] for stims in b.data_arrays: # this seems to be reasonable because the only data with sinewave with higher number was not useful (this only concerns 2019-05-07-ab and 2019-05-07-ac, and looking closely these seem to be just odd trials that we do not really need) if 'sinewave-1_Contrast' in stims.name: names.append(stims.name) names_dataarrays.append(stims.name) 'sinewave''SAM' sam = find_mt(b, 'SAM') sine = find_mt(b, 'sine') if (len(sine) > 0) or (len(sam) > 0): cont2 = True test = False if test: from utils_test import test_in_plot_phaselocking test_in_plot_phaselocking(b, path) if cont2 == True: counter = 0 DF1s, DF2s, frame_data_cell, c2_unique, c2_unique_big, c1_unique_big, c1_unique = get_dfs_and_contrasts_from_calccocktailparty( cell, frame_all) if only_first: DF1s_here = [DF1s[0]] else: DF1s_here = DF1s # [0]] for d1, DF1 in enumerate(DF1s_here): for d2, DF2 in enumerate(DF2s): # this is awkward, one should use the stored Ms instead frame_df0 = frame_data_cell[ (np.round(frame_data_cell.m1, 2) == DF1) & (np.round(frame_data_cell.m2, 2) == DF2)] # frame_df0 = frame_data_cell[(np.abs(frame_data_cell.m1 - DF1) < 0.02) & ( # np.abs(frame_data_cell.m2 - DF2) < 0.02)] contrasts_2_chosen = np.sort(np.unique(frame_df0.c2)) # c1_unique_big[0:5] for c_nr2, c2 in enumerate(contrasts_2_chosen): frame_df = frame_df0[(frame_df0.c2 == c2)] contrasts = np.unique(frame_df.c1)[::-1] if len(contrasts) > 0: contrasts_1_chosen = contrasts # , indeces_show = choice_specific_indices(contrasts, negativ = 'positiv', units = 5, cut_val = 1) for c_nr, c1 in enumerate(contrasts_1_chosen): if len(frame_df) > 0: frame_df_mean = frame_df # .groupby(['c1']).mean()#, 'c2']) frame_c1 = frame_df[(frame_df.c1 == c1)] # | (frame_df.mt_type ==
'base') print(c_nr) mt_types = frame_c1.mt_type.unique() for mt_type in mt_types: # # plt_cocktailparty_lines(ax, frame_df) find_mt_all(b) keys = ['control_01'] # ,'base_0', 'control_02', '012'] nr_col = len(keys) axts = [] axps = [] frame_type = frame_c1[ (frame_c1.mt_type == mt_type)] # | (frame_df.mt_type == 'base') V_1, eods_all, f, mts, sampling_rate, spike_times, spikes_mat, spikes_mats = load_spikes_eods_phaselock( b, cell, datas_new, frame_type, names, nfft) for nr_example in range(len(spike_times)): ################################### plt.suptitle(cell + ' DF1=' + str(DF1) + ' DF2=' + str(DF2)) plt.figure(figsize=(20, 14)) gs0 = gridspec.GridSpec(1, 2, width_ratios=[4, 1], hspace=0.4, left=0.045, right=0.97) # # comparison plot grid1 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.2, hspace=0.2, subplot_spec=gs0[1]) plt_lines_phaselockingloss(frame_df_mean, grid1) grid2 = gridspec.GridSpecFromSubplotSpec(6, nr_col, height_ratios=[1, 1, 0.5, 1, 1, 1], wspace=0.2, hspace=0.2, subplot_spec=gs0[0]) plt.suptitle('data ' + cell + ' ' + mts.name + ' ' + '_DF1_' + str( DF1) + '_DF2_' + str(DF2) + '\n_c2_' + str(c2) + '_c1_' + str( c1) + ' Trial ' + str(nr_example)) xlim = [] # time spikes for k, key in enumerate(keys): axt = plt.subplot(grid2[1, k]) time = plt_voltage_phaselock(V_1, axt, axts, counter, nr_example, sampling_rate, spike_times, xlim, key=key, color=color_three(key)) ########################################## axt = plt.subplot(grid2[0, k]) plt_stimulus_phaselock(axt, axts, counter, eods_all, frame_type, nr_example, sampling_rate, spike_times, time, xlim, key=key, color=color_three(key)) axt.set_title(key) axt.set_ylabel('') ########################################## # time psd axp = plt.subplot(grid2[3, k]) axp2 = plt.subplot(grid2[4, k]) spikes_mat = plt_psds_phaselock(axp, axp2, counter, f, nr_example, sampling_rate, spikes_mat, spikes_mats, key=key, color=color_three(key)) axps.append(axp) axps.append(axp2) # hists axi = plt.subplot(grid2[5, k])
isi = calc_isi(spike_times[nr_example][key], frame_type.iloc[nr_example].EODf) axi.hist(np.concatenate(isi), bins=100) # color = 'grey', axi.axvline(x=1, color='black', linestyle='--') if len(axts) > 0: axts[0].get_shared_y_axes().join(*axts[0::2]) axts[1].get_shared_y_axes().join(*axts[1::2]) join_x(axts) join_y(axps) set_same_ylim(axps) join_x(axps) individual_tag = 'data' + cell + '_DF1_' + str(DF1) + '_DF2_' + str( DF2) + '_c2_' + str(c2) + '_c1_' + str(c1) + '_trial_' + str( nr_example) save_visualization(individual_tag, show, pdf=True) print('finished examples') embed() def plt_beats_modulation_several_with_overview_nice_from_three_contorol_compar_single(only_first=True, limit=1, duration_exclude=0.45, nfft=int(4096), show=False): """Plot one overview figure per trial (stimulus with AM, voltage with spikes, spike-train PSDs and an ISI histogram of the 'control_01' segment, plus the peak-amplitude-vs-contrast panels) for every cell from find_cells_for_phaselocking(), with the time axes zoomed to 200-250 ms, and save it via save_visualization(individual_tag, show); only_first restricts the loop to the first DF1, nfft is passed to load_spikes_eods_phaselock, and limit/duration_exclude only enter the printed save_name. Ends in an interactive IPython embed(). NOTE(review): plt.suptitle is called before plt.figure and the opened nix file is never closed.""" # Function to load the experimental data save_name = 'beat_results_smoothed_limit' + str(limit) + '_minimalduration_' + str(duration_exclude) + '_all' print(save_name) datas_new = [] old_cells = False if old_cells: # this is in case I want to examine the old datasets _, _ = find_all_dir_cells() frame_desired = pd.read_csv('../code/calc_base/find_contrasts_SAMs-SAM_amplitudes.csv') frame_big = frame_desired[(frame_desired.contrast > 29) | (frame_desired.contrast_true > 29)] else: pass # but I want the new datasets datasets, loss, gain = find_cells_for_phaselocking() frame_all = pd.read_pickle( load_folder_name( 'calc_cocktailparty') + '/calc_data_peaks_threewave-spikes_all_psdEOD_1_nfft_16384[05,original]_psdEOD__sqrt__points_5_ALL_.pkl') # pd.read_pickle('../code/calc_phaselocking/calc_phaselocking-phaselocking5_big.pkl') plot_style() default_settings() for i, cell in enumerate(datasets): path = '../data/cells/' + cell + '/' + cell + ".nix" # cell.split(os.path.sep)[-1] print(cell) if os.path.exists(path): file = nix.File.open(path, nix.FileMode.ReadOnly) b = file.blocks[0] cont2 = False names = [] names_dataarrays = [] for stims in b.data_arrays: # this seems to be reasonable because the only data with sinewave with
higher number was not useful (this only concerns 2019-05-07-ab and 2019-05-07-ac, and looking closely these seem to be just odd trials that we do not really need) if 'sinewave-1_Contrast' in stims.name: names.append(stims.name) names_dataarrays.append(stims.name) 'sinewave''SAM' sam = find_mt(b, 'SAM') sine = find_mt(b, 'sine') if (len(sine) > 0) or (len(sam) > 0): cont2 = True test = False if test: from utils_test import test_in_plot_phaselocking test_in_plot_phaselocking(b, path) if cont2 == True: counter = 0 DF1s, DF2s, frame_data_cell, c2_unique, c2_unique_big, c1_unique_big, c1_unique = get_dfs_and_contrasts_from_calccocktailparty( cell, frame_all) if only_first: DF1s_here = [DF1s[0]] else: DF1s_here = DF1s # [0]] for d1, DF1 in enumerate(DF1s_here): for d2, DF2 in enumerate(DF2s): # this is awkward, one should use the stored Ms instead frame_df0 = frame_data_cell[ (np.round(frame_data_cell.m1, 2) == DF1) & (np.round(frame_data_cell.m2, 2) == DF2)] contrasts_2_chosen = np.sort(np.unique(frame_df0.c2)) # c1_unique_big[0:5] for c_nr2, c2 in enumerate(contrasts_2_chosen): frame_df = frame_df0[(frame_df0.c2 == c2)] contrasts = np.unique(frame_df.c1)[::-1] if len(contrasts) > 0: contrasts_1_chosen = contrasts # , indeces_show = choice_specific_indices(contrasts, negativ = 'positiv', units = 5, cut_val = 1) for c_nr, c1 in enumerate(contrasts_1_chosen): if len(frame_df) > 0: frame_df_mean = frame_df # .groupby(['c1']).mean()#, 'c2']) frame_c1 = frame_df[(frame_df.c1 == c1)] # | (frame_df.mt_type == 'base') print(c_nr) mt_types = frame_c1.mt_type.unique() for mt_type in mt_types: find_mt_all(b) keys = ['control_01'] # ,'base_0', 'control_02', '012'] nr_col = len(keys) axts = [] axps = [] frame_type = frame_c1[ (frame_c1.mt_type == mt_type)] # | (frame_df.mt_type == 'base') V_1, eods_all, f, mts, sampling_rate, spike_times, spikes_mat, spikes_mats = load_spikes_eods_phaselock( b, cell, datas_new, frame_type, names, nfft) for nr_example in
range(len(spike_times)): ################################### plt.suptitle(cell + ' DF1=' + str(DF1) + ' DF2=' + str(DF2)) plt.figure(figsize=(20, 14)) gs0 = gridspec.GridSpec(1, 2, width_ratios=[4, 1], hspace=0.4, left=0.045, right=0.97) # # comparison plot grid1 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.2, hspace=0.2, subplot_spec=gs0[1]) plt_lines_phaselockingloss(frame_df_mean, grid1) grid2 = gridspec.GridSpecFromSubplotSpec(6, nr_col, height_ratios=[1, 1, 0.5, 1, 1, 1], wspace=0.2, hspace=0.2, subplot_spec=gs0[0]) plt.suptitle('data ' + cell + ' ' + mts.name + ' ' + '_DF1_' + str( DF1) + '_DF2_' + str(DF2) + '\n_c2_' + str(c2) + '_c1_' + str( c1) + ' Trial ' + str(nr_example)) xlim = [200, 250] ########################################### # time spikes for k, key in enumerate(keys): axt = plt.subplot(grid2[1, k]) time = plt_voltage_phaselock(V_1, axt, axts, counter, nr_example, sampling_rate, spike_times, xlim, key=key, color=color_three(key)) axt = plt.subplot(grid2[0, k]) plt_stimulus_phaselock(axt, axts, counter, eods_all, frame_type, nr_example, sampling_rate, spike_times, time, xlim, key=key, color=color_three(key)) axt.set_title(key) axt.set_ylabel('') ########################################## # time psd axp = plt.subplot(grid2[3, k]) axp2 = plt.subplot(grid2[4, k]) spikes_mat = plt_psds_phaselock(axp, axp2, counter, f, nr_example, sampling_rate, spikes_mat, spikes_mats, key=key, color=color_three(key)) axps.append(axp) axps.append(axp2) # embed() axi = plt.subplot(grid2[5, k]) # frame_type.iloc[nr_example].EODf isi = calc_isi(spike_times[nr_example][key], frame_type.iloc[nr_example].EODf) axi.hist(np.concatenate(isi), bins=100) # color = 'grey', axi.axvline(x=1, color='black', linestyle='--') if len(axts) > 0: axts[0].get_shared_y_axes().join(*axts[0::2]) axts[1].get_shared_y_axes().join(*axts[1::2]) join_x(axts) join_y(axps) set_same_ylim(axps) join_x(axps) individual_tag = 'data' + cell + '_DF1_' + str(DF1) + '_DF2_' + str( DF2) + '_c2_' +
str(c2) + '_c1_' + str(c1) + '_trial_' + str( nr_example) save_visualization(individual_tag, show) print('finished examples') embed() def plt_lines_phaselockingloss(frame_df_mean, grid1): cs = {} means = {} scores_datas = [['amp_f0_01_original', 'amp_f0_012_original', 'amp_B1_01_original', 'amp_B1_012_original'], ['amp_f0_02_original', 'amp_f0_0_original', 'amp_B2_02_original', 'amp_B2_012_original']] colorss = [['green', 'purple', 'green', 'blue'], ['orange', 'red', 'orange', 'red']] linestyless = [['--', '--', '-', '-'], ['--', '--', '-', '-']] for s_nr, score in enumerate(scores_datas): scores_data = scores_datas[s_nr] linestyles = linestyless[s_nr] colors = colorss[s_nr] ax = plt.subplot(grid1[s_nr]) for sss, score in enumerate(scores_data): ax.plot(np.sort(frame_df_mean['c1']), frame_df_mean[score].iloc[np.argsort(frame_df_mean['c1'])], color=colors[sss], linestyle=linestyles[ sss], label=score.replace('_mean', '').replace('amp_', '') + '[Hz]') # +str(np.round(np.mean(group_restricted[score_data]))), label = 'c_small='+str(c_small)+' c_big='+str(c_big) if sss not in means.keys(): means[sss] = [] cs[sss] = [] ax.set_ylabel('Peak Amplitude', fontsize=8) ax.set_xlabel('Contrast small') ax.set_xlabel('Contrast small') ax.legend() def find_cells_for_phaselocking(): cells_for_phaselocking = [ '2023-05-12-aq-invivo-1', '2023-05-03-aa-invivo-1', '2023-05-12-al-invivo-1', '2023-05-12-ai-invivo-1', '2023-05-24-ac-invivo-1', '2023-05-12-at-invivo-1', '2023-05-12-as-invivo-1', '2023-05-12-ap-invivo-1', '2023-05-12-af-invivo-1', '2023-05-12-ae-invivo-1', '2023-05-12-ar-invivo-1', ] loss = '2023-05-12-ap-invivo-1' # (Verlust) gain = '2023-05-03-aa-invivo-1' return cells_for_phaselocking, loss, gain def load_spikes_eods_phaselock(b, cell, datas_new, frame_type, names, nfft): frame_name = frame_type # [frame_type.mt_name == mt_name] mt_idxs = list(map(int, np.array(frame_name.mt_idx))) mt_names = frame_type.mt_name.unique() mts = b.multi_tags[mt_names[0]] print(mts.name) 
eod_frs, eod_redo = get_eod_fr_simple(b, names) names = [] for stims in b.data_arrays: names.append(stims.name) print(cell + ' Beat calculation') datas_new.append(cell) try: pass except: print('rlx problem') eods_all = [] eods_all_g = [] V_1 = [] spike_times = [] spikes_mats = [] for m in mt_idxs: # range(len(mts.positions[:])) frame_features = feature_extract_cut(mts, m) zeroth_cut, first_cut, second_cut, third_cut, fish_number, fish_cuts, whole_duration, delay, cont = load_four_durations( mts, frame_features, 0, m) try: eods, spikes_mt = load_eod_for_three(b, delay, mts, m, load_eod_array='LocalEOD-1') except: print('eods thing') embed() sampling_rate = get_sampling(b, 'EOD') if eod_redo == True: p, f = ml.psd(eods - np.mean(eods), Fs=sampling_rate, NFFT=nfft, noverlap=nfft // 2) else: pass cut = 0.05 eod_mt, spikes_mt, time_eod, time_laod_eods, timepoint = spike_times_cocktailparty(b, delay, mts, m) v_1, spikes_mt = load_eod_for_three(b, delay, mts, m, load_eod_array='V-1') eods_g, spikes_mt = load_eod_for_three(b, delay, mts, m, load_eod_array='EOD') devname, smoothened2, smoothed05, mat, time_here, arrays_calc, effective_duration, spikes_cut = cut_spikes_sequences( delay, spikes_mt, sampling_rate, fish_cuts, cut=cut, fish_number=fish_number, cut_compensate=True, devname_orig=['original'], cut_length=False) spike_times.append(spikes_cut) v_1_cut, _ = cut_eod_sequences(v_1, fish_cuts, time_eod, cut=cut, rec=False, fish_number=fish_number) eods_cut, _ = cut_eod_sequences(eods, fish_cuts, time_eod, cut=cut, rec=False, fish_number=fish_number) eods_g_cut, _ = cut_eod_sequences(eods_g, fish_cuts, time_eod, cut=cut, rec=False, fish_number=fish_number) spikes_mats.append(arrays_calc[0]) test = False if test: fig, ax = plt.subplots(2, 1) ax[0].plot(np.arange(0, len(v_1_cut['control_01']) / sampling_rate, 1 / sampling_rate), v_1_cut['control_01']) ax[0].scatter(spikes_cut['control_01'][0], np.max(v_1_cut['control_01']) * np.ones(len(spikes_cut['control_01'][0]))) 
ax[1].plot(np.arange(0, len(arrays_calc[0]['control_01']) / sampling_rate, 1 / sampling_rate), arrays_calc[0]['control_01']) ax[1].scatter(spikes_cut['control_01'][0], np.max(arrays_calc[0]['control_01']) * np.ones(len(spikes_cut['control_01'][0]))) plt.show() eods_all.append(eods_cut) V_1.append(v_1_cut) eods_all_g.append(eods_g_cut) return V_1, eods_all, f, mts, sampling_rate, spike_times, spikes_mats[0], spikes_mats def plt_voltage_phaselock(V_1, axt, axts, counter, nr_example, sampling_rate, spike_times, xlim, key='01', color='purple'): axt.set_ylabel('local') time = np.arange(0, len(V_1[nr_example][key]) / sampling_rate, 1 / sampling_rate) * 1000 axt.plot(time, V_1[nr_example][key], color=color, linewidth=0.5) if (len(spike_times[nr_example][key][0]) > 0) & (len(V_1[nr_example][key]) > 0): try: axt.scatter((spike_times[nr_example][key][0]) * 1000, np.max(V_1[nr_example][key]) * np.ones(len(spike_times[nr_example][key][0])) , color='black', s=10, marker='|') # np.max(v1)*lineoffsets=np.max(V_1[nr_example]), except: print('spikes something') embed() if len(xlim) > 0: axt.set_xlim(xlim) axt.set_xlabel('Time [ms]') if counter != 0: remove_yticks(axt) axt.set_ylabel('') axts.append(axt) return time def plt_psds_phaselock(axp, axp2, counter, f, nr_example, sampling_rate, spikes_mat, spikes_mats, color='purple', key='01'): ps = [] for s, spikes_mat in enumerate(spikes_mats): try: p, f = ml.psd(spikes_mat[key] - np.mean(spikes_mat[key]), Fs=sampling_rate, NFFT=2 ** 13, noverlap=2 ** 13 / 2) except: print('p something') embed() ps.append(p) if s == nr_example: color = color zorder = 100 axp.plot(f, p, color=color, zorder=zorder) else: color = 'grey' zorder = 1 axp2.plot(f, p, color=color, zorder=zorder) axp2.set_xlim(0, 1000) axp.set_xlim(0, 1000) remove_xticks(axp) axp2.plot(f, np.mean(ps, axis=0), color='black', zorder=2, linestyle='--') axp2.set_xlabel('Power [Hz]') if counter != 0: remove_yticks(axp2) axp2.set_ylabel('') if counter != 0: remove_yticks(axp) 
axp.set_ylabel('') return spikes_mat def plt_stimulus_phaselock(axt, axts, counter, eods_all, frame_type, nr_example, sampling_rate, spike_times, time, xlim, key='01', color='red'): stimulus = eods_all[nr_example][key] # eods_g + Efield if len(stimulus) > 0: axt.set_title(' c1' + str(np.unique(frame_type.c1)) + ' c2' + str(np.unique(frame_type.c2))) axts.append(axt) try: time = np.arange(0, len(stimulus) / sampling_rate, 1 / sampling_rate) * 1000 except: print('time all') embed() try: eods_am, eod_norm = extract_am(stimulus, time, norm=False) except: print('am something') axt.plot(time, eod_norm, color='grey', linewidth=0.5) axt.plot(time, eods_am, color=color) axt.scatter(np.array(spike_times[nr_example][key][0]) * 1000, np.mean(eod_norm) * np.ones(len(spike_times[nr_example][key][0])) , color='black', s=10, marker='|') # np.max(v1)*lineoffsets=np.max(V_1[nr_example]), if len(xlim) > 0: axt.set_xlim(xlim) remove_xticks(axt) if counter != 0: remove_yticks(axt) axt.set_ylabel('') def get_features_and_info(mts, dfs=[], contrasts=[]): features = [] id = [] for ff, f in enumerate(mts.features): if 'id' in f.data.name: id = f.data.name elif 'Contrast' in f.data.name: contrasts = mts.features[f.data.name].data[:] elif 'DeltaF' in f.data.name: dfs = mts.features[f.data.name].data[:] else: features.append(f.data.name) return features, dfs, contrasts, id def get_most_similiar_spikes(all_spikes, am_corr_cut, beat_cut, error, maxima, spikes_cut): most_similiar = np.where(error < np.sort(error)[6])[0] beat_final = [] am_final = [] spike_sm = [] spike = [] max = [] # ok wir machen das erstmal am ähnlichsten das sollte schon passen! 
max_corr = True if max_corr: for l in range(len(most_similiar)): beat_final.append(beat_cut[most_similiar[l]]) spike_sm.append(spikes_cut[most_similiar[l]]) spike.append(all_spikes[most_similiar[l]]) max.append(maxima[most_similiar[l]]) am_final.append(am_corr_cut[most_similiar[l]]) else: beat_final = beat_cut spike_sm = spikes_cut spike = all_spikes am_final = am_corr_cut return am_final, beat_final, most_similiar, spike, spike_sm def plt_beats_modulation_several_with_overview_nice_big_final3(contrasts_given=[], datasets=['2020-10-20-ad-invivo-1'], dfs_all_unique_given=[25], limit=1, duration_exclude=0.45, nfft=int(4096), show=False): """For each cell in datasets (except two excluded cells) whose nix file exists, take that cell's rows of the phase-locking table and, for every df (dfs_all_unique_given, or the values found in the table) and every non-'base' mt_type, draw one column per contrast: the first trial's local EOD trace, a spike raster of the first four period-cut snippets, their mean smoothed firing rate, and the trial-averaged spike-train PSD with frequency peaks marked; each figure is saved via save_visualization(individual_tag, show). contrasts_given overrides the contrasts found in the table; limit/duration_exclude only enter the printed save_name. NOTE(review): the PSD uses noverlap=2 ** 13 / 2, a float, whereas matplotlib.mlab.psd expects an integer overlap.""" # Function to load the experimental data save_name = 'beat_results_smoothed_limit' + str(limit) + '_minimalduration_' + str(duration_exclude) + '_all' print(save_name) frame_all = pd.read_pickle(load_folder_name('calc_phaselocking') + '/calc_phaselocking-phaselocking5_big.pkl') plot_style() for i, cell in enumerate(datasets): path = load_folder_name('data') + '/cells/' + cell + '/' + cell + ".nix" # cell.split(os.path.sep)[-1] print(cell) cells_exclude = ['2020-10-29-af-invivo-1', '2019-05-07-cb-invivo-1'] df_pos = False if cell not in cells_exclude: if os.path.exists(path): print('exists') file = nix.File.open(path, nix.FileMode.ReadOnly) b = file.blocks[0] cont2 = False names = [] names_dataarrays = [] for stims in b.data_arrays: # this seems to be reasonable because the only data with sinewave with higher number was not useful (this only concerns 2019-05-07-ab and 2019-05-07-ac, and looking closely these seem to be just odd trials that we do not really need) if 'sinewave-1_Contrast' in stims.name: names.append(stims.name) names_dataarrays.append(stims.name) 'sinewave''SAM' sam = find_mt(b, 'SAM') sine = find_mt(b, 'sine') if (len(sine) > 0) or (len(sam) > 0): cont2 = True if cont2 == True: print('cont2') counter = 0 frame_cell = frame_all[frame_all['cell'] == cell] if len(frame_cell) < 1: frame_all =
pd.read_pickle('../code/calc_phaselocking/calc_phaselocking-phaselocking_big.pkl') frame_cell = frame_all[frame_all['cell'] == cell] if len(frame_cell) > 0: if len(dfs_all_unique_given) < 1: dfs_all_unique = np.unique(frame_cell.df_sign.dropna())[::-1] df_name = 'df_sign' df_pos = '' # 'min_df' dfs_all_unique = list(dfs_all_unique) # todo: there are still problems here if len(np.unique(np.array(dfs_all_unique))) < 2: df_name = 'df' dfs_all_unique = np.unique(frame_cell.df.dropna())[::-1] dfs_all_unique = list(dfs_all_unique) else: dfs_all_unique = dfs_all_unique_given df_name = 'df_sign' if len(dfs_all_unique) > 0: if df_pos == 'min_df': try: dfs_all_unique = [dfs_all_unique[np.argmin(np.abs(dfs_all_unique))]] except: print('df min') embed() for df_chosen in dfs_all_unique: if not np.isnan(df_chosen): frame_df = frame_cell[frame_cell[df_name] == df_chosen] contrasts_all_unique = np.unique(frame_df.contrast) contrasts_all_unique = contrasts_all_unique[~np.isnan(contrasts_all_unique)] if len(contrasts_given) > 0: contrasts_all_unique = contrasts_given if len(contrasts_all_unique) > 1: mt_types = frame_df.mt_type.unique() for mt_type in mt_types: if 'base' not in mt_type: contrasts_here = [] frame_type = frame_df[ (frame_df.mt_type == mt_type)] # | (frame_df.mt_type == 'base') default_figsize(column=2, length=3.5) # 5 gs0 = gridspec.GridSpec(1, 1, hspace=0.4, bottom=0.18, left=0.1, top=0.94, right=0.97) # width_ratios=[4, 1], if (cell == '2020-10-20-ad-invivo-1') & ( 50 == df_chosen): # the first one is missing for whatever reason reduce = 0 else: reduce = 0 nr_col = int(len(np.unique(contrasts_all_unique))) - reduce grid2 = gridspec.GridSpecFromSubplotSpec(5, nr_col, height_ratios=[1, 0.5, 1, 0.1, 1, ], wspace=0.2, hspace=0.2, subplot_spec=gs0[0]) # 0.7,1 axts = [] axfs = [] axps = [] mt_names = frame_type.mt_name.unique() counters = [] for m, mt_name in enumerate(mt_names): frame_name = frame_type[frame_type.mt_name == mt_name] mt_idxs = list(map(int,
np.array(frame_name.mt_idx))) mts = b.multi_tags[mt_name] print(mts.name) name = mts.name contrast = name.split('=')[1].split('%')[0] if contrast not in contrasts_here: print(contrast) if len(np.where(np.round(contrasts_all_unique, 2) == np.round( float(contrast), 2))[0]) > 0: if np.isnan(float(contrast)): counter = 0 else: try: counter = np.where( np.round(contrasts_all_unique, 2) == np.round( float(contrast), 2))[0][0] - reduce # +1 except: print('something') embed() counters.append(counter) try: dfs = [mts.metadata[mts.name]['DeltaF']] * len( mts.positions[:]) except: dfs = mts.metadata['DeltaF'] features, dfs, contrasts, id = get_features_and_info(mts, dfs=dfs) eod_frs, eod_redo = get_eod_fr_simple(b, names) eod = b.data_arrays['LocalEOD-1'][:] names = [] for stims in b.data_arrays: names.append(stims.name) print(cell + ' Beat calculation') eods_all = [] eods_all_g = [] V_1 = [] spike_times = [] for m in mt_idxs: # range(len(mts.positions[:])) try: eods, _ = link_arrays_eod(b, mts.positions[:][m], mts.extents[:][m], 'LocalEOD-1') except: print('eods thing') embed() eods_all.append(eods) eods_g, sampling_rate = link_arrays_eod(b, mts.positions[ :][m], mts.extents[:][ m], 'EOD') v_1, sampling_rate = link_arrays_eod(b, mts.positions[:][ m], mts.extents[:][m], 'V-1') eods_all_g.append(eods_g) V_1.append(v_1) if eod_redo == True: p, f = ml.psd(eods - np.mean(eods), Fs=sampling_rate, NFFT=nfft, noverlap=nfft // 2) eod_fr = f[np.argmax(p)] else: eod_fr = eod_frs[m] print('EODF' + str(eod_fr)) spike_times.append( (mts.retrieve_data(m, 'Spikes-1')[:] - mts.positions[ m]) * 1000) # - cut print(len(spike_times)) smooth = [] spikes_mats = [] for s in range(len(spike_times)): try: spikes_mat = cr_spikes_mat(spike_times[s] / 1000, sampling_rate, int( mts.extents[:][ mt_idxs[ s]] * sampling_rate)) # time[-1] * sampling_rate except: print('mts prob') embed() spikes_mats.append(spikes_mat) # for the subsequent mean we cut everything to the length of the shortest one try:
smooth.append(gaussian_filter( spikes_mat[ 0:int(np.min(mts.extents[:]) * sampling_rate)], sigma=0.002 * sampling_rate)) except: print('embed problem') embed() try: pass except: print('smoothed thing') embed() skip_nr = 2 xlim = [0, 1000 * skip_nr / np.abs(dfs[m])] nr_example = 0 # 'no'#0 ########################################## try: axt = plt.subplot(grid2[0, counter]) except: print('axt something') embed() axts.append(axt) stimulus = eods_all[nr_example] # eods_g + Efield try: time = np.arange(0, len(stimulus) / sampling_rate, 1 / sampling_rate) * 1000 except: print('time all2') embed() eods_am, eod_norm = extract_am(stimulus, time, norm=False, kind='linear') axt.plot(time, eod_norm, color='grey', linewidth=0.5) am = False if am: axt.plot(time, eods_am, color='red') scatter_extra = False if scatter_extra: axt.scatter(spike_times[nr_example], np.mean(eod_norm) * np.ones( len(spike_times[nr_example])) , color='black', s=10, marker='|') # np.max(v1)*lineoffsets=np.max(V_1[nr_example]), axt.show_spines('') axt.set_xlim(xlim) axt.set_xlabel('Time [ms]') if counter != 0: remove_yticks(axt) axt.set_ylabel('') axt.text(1, 1, '$c_{' + vary_val() + '}=%s$' % ( contrast) + '$\%$, ' + ' $\Delta f_{' + vary_val() + '}= %s$\,Hz' % ( int(dfs[m])), ha='right', transform=axt.transAxes) ########################################### # time spikes axt = plt.subplot(grid2[1, counter]) axt.set_ylabel('local') axt.show_spines('') time = np.arange(0, len(V_1[nr_example]) / sampling_rate, 1 / sampling_rate) * 1000 # I use a fixed window, so there is a shift that advances in very small steps # I would get period 2 if the window always had the same length umstuelp = False if umstuelp: # ah, but there is also the 'umstuelpen' (inversion) from the susept analysis for the appendix!
spikes_umstuelpen(eod, sampling_rate, time) eods_cut, spikes_cut, times_cut, cut_next, smoothed_cut = cut_spike_snippets( spike_times[nr_example], period_based=True, array_cut2=np.arange(0, len( eods_all[nr_example]) / sampling_rate, skip_nr / np.abs(dfs[m])), end=2000, smoothened=smooth[nr_example], time_eod=time / 1000, norming=False) axt.eventplot(np.array(spikes_cut[0:4]) * 1000, color='black') # lineoffsets=np.max(V_1[nr_example])* np.ones( axt.set_xlim(xlim) remove_xticks(axt) if counter != 0: remove_yticks(axt) axt.set_ylabel('') axt.show_spines('') axts.append(axt) # convolved firing rate axf = plt.subplot(grid2[2, counter]) if len(smooth[nr_example]) != len(time): time_here = time[0:len(smooth[nr_example])] else: time_here = time # [0:len(smooth[nr_example])] mean_firing = True if mean_firing: lengths = [] for sm in smoothed_cut[0:4]: lengths.append(len(sm)) sms = [] for sm in smoothed_cut[0:4]: sms.append(sm[0:np.min(lengths)]) time_here = time[0:np.shape(sms)[1]] axf.plot(time_here, np.mean(sms, axis=0), color='grey') else: axf.plot(time_here, smooth[nr_example], color='grey', ) axf.show_spines('') axf.set_xlim(xlim) axfs.append(axf) ########################################## # time psd axp2 = plt.subplot(grid2[4, counter]) ps = [] maxx = 1000 for s, spikes_mat in enumerate(spikes_mats): p, f = ml.psd(spikes_mat - np.mean(spikes_mat), Fs=sampling_rate, NFFT=2 ** 13, noverlap=2 ** 13 / 2) ps.append(p) if s == nr_example: pass else: pass axp2.set_xlim(0, maxx) axp2.plot(f, np.mean(ps, axis=0), color='black', zorder=2, linestyle='-') pp = np.mean(ps, axis=0) eodf = np.mean(frame_name.eod_fr) names = ['0', '01', '02', '012'] names_here = [names[1]] # extend = False labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept() colors_array = ['pink', color01] if float(contrast) > 2: name = names_here[ 0] else: name = 'eodf' freqs, colors_peaks, labels, alphas = chose_all_freq_combos( [],
colors_array, df_chosen, maxx, eodf, color_eodf=coloer_eod_fr_core(), name= name, color_stim=color_stim_core(), color_stim_mult=color_stim_core()) # 'black'color_stim_core() plt_peaks_several(freqs, [pp], axp2, pp, f, labels, 0, colors_peaks, limit=1200, perc_peaksize=0.15, alphas=alphas, extend=extend, ms=18, clip_on=False) legend_here = False if legend_here: if (counter == 2) & (name != 'eodf'): try: handles, labels = axp2.get_legend_handles_labels() reorder_legend_handles(axp2, order=[len(labels) - 3, len(labels) - 2, len(labels) - 1], loc=(-2.5, 1), fs=9, handlelength=1, ncol=3) except: print('label something') embed() axp2.set_xlabel('Frequency [Hz]') if counter != 0: remove_yticks(axp2) else: axp2.set_ylabel(power_spectrum_name()) axps.append(axp2) ############################# # spike_times[nr_example] isis = False if isis: axi = plt.subplot(grid2[-1, counter]) plt_isis_phaselocking(axi, frame_name, spike_times) axi.set_xticks_delta(2) axi.set_xlim(0, 13) try: axts[0].get_shared_y_axes().join(*axts[0::2]) except: print('axt problem') embed() axts[1].get_shared_y_axes().join(*axts[1::2]) axts[0].get_shared_x_axes().join(*axts) join_y(axfs) join_y(axps) join_x(axps) fig = plt.gcf() fig.tag([axts[4], axts[2], axts[0], ], xoffs=-2, yoffs=1) firing_rate_scalebars(axfs[np.where(np.array(counters) == 0)[0][0]], length=10) individual_tag = 'data ' + cell + '_DF_chosen_' + str( df_chosen) + mt_type save_visualization(individual_tag, show) print('plotted') file.close() print('finished examples') def reorder_legend_handles(ax1, order=[0, 2, 4, 1, 3, 5], ncol=None, rev=False, loc=(0.65, 0.6), fs=9, handlelength=0.5): """Redraw the legend of ax1 with its entries reordered by the index list order and return the new legend; loc and handlelength are always passed to ax1.legend, fontsize=fs and ncol only when they are truthy. With rev=True only the first three entries of order are used, each mapped to len(labels) - order[i].""" handles, labels = ax1.get_legend_handles_labels() if rev: order = [len(labels) - order[0], len(labels) - order[1], len(labels) - order[2]] hand_new = [handles[i] for i in order] label_new = [labels[i] for i in order] if fs: if ncol: first_legend = ax1.legend(handles=hand_new, labels=label_new, loc=loc, fontsize=fs, handlelength=handlelength, ncol=ncol)
else: first_legend = ax1.legend(handles=hand_new, labels=label_new, loc=loc, fontsize=fs, handlelength=handlelength) else: if ncol: first_legend = ax1.legend(handles=hand_new, labels=label_new, loc=loc, handlelength=handlelength, ncol=ncol) else: first_legend = ax1.legend(handles=hand_new, labels=label_new, loc=loc, handlelength=handlelength) return first_legend def plt_beats_modulation_several_with_overview_nice_big_final2(contrasts_given=[], datasets=['2020-10-20-ad-invivo-1'], dfs_all_unique_given=[25], limit=1, duration_exclude=0.45, nfft=int(4096), show=False): # Function to load the experimental data save_name = 'beat_results_smoothed_limit' + str(limit) + '_minimalduration_' + str(duration_exclude) + '_all' print(save_name) frame_all = pd.read_pickle(load_folder_name('calc_phaselocking') + '/calc_phaselocking-phaselocking5_big.pkl') plot_style() for i, cell in enumerate(datasets): path = load_folder_name('data') + '/cells/' + cell + '/' + cell + ".nix" # cell.split(os.path.sep)[-1] print(cell) cells_exclude = ['2020-10-29-af-invivo-1', '2019-05-07-cb-invivo-1'] df_pos = False if cell not in cells_exclude: if os.path.exists(path): print('exists') file = nix.File.open(path, nix.FileMode.ReadOnly) b = file.blocks[0] cont2 = False names = [] names_dataarrays = [] for stims in b.data_arrays: # this seems to be reasonable because the only data with sinewave with higher number was not useful (this only concerns 2019-05-07-ab and 2019-05-07-ac, and looking closely these seem to be just odd trials that we do not really need) if 'sinewave-1_Contrast' in stims.name: names.append(stims.name) names_dataarrays.append(stims.name) 'sinewave''SAM' sam = find_mt(b, 'SAM') sine = find_mt(b, 'sine') if (len(sine) > 0) or (len(sam) > 0): cont2 = True if cont2 == True: print('cont2') counter = 0 frame_cell = frame_all[frame_all['cell'] == cell] if len(frame_cell) < 1: frame_all =
pd.read_pickle('../code/calc_phaselocking/calc_phaselocking-phaselocking_big.pkl') frame_cell = frame_all[frame_all['cell'] == cell] if len(frame_cell) > 0: if len(dfs_all_unique_given) < 1: dfs_all_unique = np.unique(frame_cell.df_sign.dropna())[::-1] df_name = 'df_sign' df_pos = '' # 'min_df' dfs_all_unique = list(dfs_all_unique) # todo: there are still problems here if len(np.unique(np.array(dfs_all_unique))) < 2: df_name = 'df' dfs_all_unique = np.unique(frame_cell.df.dropna())[::-1] dfs_all_unique = list(dfs_all_unique) else: dfs_all_unique = dfs_all_unique_given df_name = 'df_sign' if len(dfs_all_unique) > 0: if df_pos == 'min_df': try: dfs_all_unique = [dfs_all_unique[np.argmin(np.abs(dfs_all_unique))]] except: print('df min') embed() for df_chosen in dfs_all_unique: if not np.isnan(df_chosen): frame_df = frame_cell[frame_cell[df_name] == df_chosen] contrasts_all_unique = np.unique(frame_df.contrast) contrasts_all_unique = contrasts_all_unique[~np.isnan(contrasts_all_unique)] if len(contrasts_given) > 0: contrasts_all_unique = contrasts_given if len(contrasts_all_unique) > 1: mt_types = frame_df.mt_type.unique() for mt_type in mt_types: if 'base' not in mt_type: contrasts_here = [] frame_type = frame_df[ (frame_df.mt_type == mt_type)] # | (frame_df.mt_type == 'base') default_settings(column=2, length=5) gs0 = gridspec.GridSpec(1, 1, hspace=0.4, left=0.1, top=0.94, right=0.97) # width_ratios=[4, 1], if (cell == '2020-10-20-ad-invivo-1') & ( 50 == df_chosen): # the first one is missing for whatever reason reduce = 0 else: reduce = 0 nr_col = int(len(np.unique(contrasts_all_unique))) - reduce grid2 = gridspec.GridSpecFromSubplotSpec(7, nr_col, height_ratios=[1, 0.5, 1, 0.1, 1, 0.7, 1], wspace=0.2, hspace=0.2, subplot_spec=gs0[0]) axts = [] axfs = [] axps = [] mt_names = frame_type.mt_name.unique() counters = [] for m, mt_name in enumerate(mt_names): frame_name = frame_type[frame_type.mt_name == mt_name] mt_idxs = list(map(int, np.array(frame_name.mt_idx)))
mts = b.multi_tags[mt_name] print(mts.name) name = mts.name contrast = name.split('=')[1].split('%')[0] if contrast not in contrasts_here: print(contrast) if len(np.where(np.round(contrasts_all_unique, 2) == np.round( float(contrast), 2))[0]) > 0: if np.isnan(float(contrast)): counter = 0 else: try: counter = np.where( np.round(contrasts_all_unique, 2) == np.round( float(contrast), 2))[0][0] - reduce # +1 except: print('something') embed() counters.append(counter) try: dfs = [mts.metadata[mts.name]['DeltaF']] * len( mts.positions[:]) except: dfs = mts.metadata['DeltaF'] features, dfs, contrasts, id = get_features_and_info(mts, dfs=dfs) eod_frs, eod_redo = get_eod_fr_simple(b, names) eod = b.data_arrays['LocalEOD-1'][:] names = [] for stims in b.data_arrays: names.append(stims.name) print(cell + ' Beat calculation') eods_all = [] eods_all_g = [] V_1 = [] spike_times = [] for m in mt_idxs: # range(len(mts.positions[:])) try: eods, _ = link_arrays_eod(b, mts.positions[:][m], mts.extents[:][m], 'LocalEOD-1') except: print('eods thing') embed() eods_all.append(eods) eods_g, sampling_rate = link_arrays_eod(b, mts.positions[ :][m], mts.extents[:][ m], 'EOD') v_1, sampling_rate = link_arrays_eod(b, mts.positions[:][ m], mts.extents[:][m], 'V-1') eods_all_g.append(eods_g) V_1.append(v_1) if eod_redo == True: p, f = ml.psd(eods - np.mean(eods), Fs=sampling_rate, NFFT=nfft, noverlap=nfft // 2) eod_fr = f[np.argmax(p)] else: eod_fr = eod_frs[m] print('EODF' + str(eod_fr)) spike_times.append( (mts.retrieve_data(m, 'Spikes-1')[:] - mts.positions[ m]) * 1000) # - cut print(len(spike_times)) smooth = [] spikes_mats = [] for s in range(len(spike_times)): try: spikes_mat = cr_spikes_mat(spike_times[s] / 1000, sampling_rate, int( mts.extents[:][ mt_idxs[ s]] * sampling_rate)) # time[-1] * sampling_rate except: print('mts prob') embed() spikes_mats.append(spikes_mat) # for the subsequent mean we cut everything to the length of the shortest one try: smooth.append(gaussian_filter( spikes_mat[
0:int(np.min(mts.extents[:]) * sampling_rate)], sigma=0.002 * sampling_rate)) except: print('embed problem') embed() try: pass except: print('smoothed thing') embed() skip_nr = 2 xlim = [0, 1000 * skip_nr / np.abs(dfs[m])] nr_example = 0 # 'no'#0 ########################################## try: axt = plt.subplot(grid2[0, counter]) except: print('axt something') embed() axts.append(axt) stimulus = eods_all[nr_example] # eods_g + Efield try: time = np.arange(0, len(stimulus) / sampling_rate, 1 / sampling_rate) * 1000 except: print('time all2') embed() eods_am, eod_norm = extract_am(stimulus, time, norm=False, kind='linear') axt.plot(time, eod_norm, color='grey', linewidth=0.5) am = False if am: axt.plot(time, eods_am, color='red') scatter_extra = False if scatter_extra: axt.scatter(spike_times[nr_example], np.mean(eod_norm) * np.ones( len(spike_times[nr_example])) , color='black', s=10, marker='|') # np.max(v1)*lineoffsets=np.max(V_1[nr_example]), axt.show_spines('') axt.set_xlim(xlim) axt.set_xlabel('Time [ms]') if counter != 0: remove_yticks(axt) axt.set_ylabel('') axt.text(1, 1, '$c=%s$' % contrast + '$\%$', ha='right', transform=axt.transAxes) ########################################### # time spikes axt = plt.subplot(grid2[1, counter]) axt.set_ylabel('local') axt.show_spines('') time = np.arange(0, len(V_1[nr_example]) / sampling_rate, 1 / sampling_rate) * 1000 # ich mache ein festes fenster also habe ich einen schift der in einem sehr kleinen schritt durchgeht # das period 2 hätte ich wenn das Fenster immer die gleiche länge hätte umstuelp = False if umstuelp: # ah aber ich hab auch noch das umstuelpen aus dem susept das für den Appendix! 
spikes_umstuelpen(eod, sampling_rate, time) eods_cut, spikes_cut, times_cut, cut_next, smoothed_cut = cut_spike_snippets( spike_times[nr_example], period_based=True, array_cut2=np.arange(0, len( eods_all[nr_example]) / sampling_rate, skip_nr / np.abs(dfs[m])), end=2000, smoothened=smooth[nr_example], time_eod=time / 1000, norming=False) axt.eventplot(np.array(spikes_cut[0:4]) * 1000, color='black') # lineoffsets=np.max(V_1[nr_example])* np.ones( axt.set_xlim(xlim) remove_xticks(axt) if counter != 0: remove_yticks(axt) axt.set_ylabel('') axt.show_spines('') axts.append(axt) # convolved firing rate axf = plt.subplot(grid2[2, counter]) if len(smooth[nr_example]) != len(time): time_here = time[0:len(smooth[nr_example])] else: time_here = time # [0:len(smooth[nr_example])] mean_firing = True if mean_firing: lengths = [] for sm in smoothed_cut[0:4]: lengths.append(len(sm)) sms = [] for sm in smoothed_cut[0:4]: sms.append(sm[0:np.min(lengths)]) time_here = time[0:np.shape(sms)[1]] axf.plot(time_here, np.mean(sms, axis=0), color='grey', linewidth=0.5) else: axf.plot(time_here, smooth[nr_example], color='grey', linewidth=0.5) axf.show_spines('') axf.set_xlim(xlim) axfs.append(axf) ########################################## # time psd axp2 = plt.subplot(grid2[4, counter]) ps = [] maxx = 1000 for s, spikes_mat in enumerate(spikes_mats): p, f = ml.psd(spikes_mat - np.mean(spikes_mat), Fs=sampling_rate, NFFT=2 ** 13, noverlap=2 ** 13 / 2) ps.append(p) if s == nr_example: pass else: pass axp2.set_xlim(0, maxx) axp2.plot(f, np.mean(ps, axis=0), color='black', zorder=2, linestyle='-') pp = np.mean(ps, axis=0) eodf = np.mean(frame_name.eod_fr) names = ['0', '01', '02', '012'] names_here = [names[1]] # extend = True colors_array = ['pink', 'green'] if float(contrast) > 2: name = names_here[ 0] else: name = 'eodf' freqs, colors_peaks, labels, alphas = chose_all_freq_combos( [], colors_array, df_chosen, maxx, eodf, color_eodf='black', name= name, color_stim='grey', 
color_stim_mult='grey') plt_peaks_several(freqs, [pp], axp2, pp, f, labels, 0, colors_peaks, limit=1200, perc_peaksize=0.15, alphas=alphas, extend=extend, ms=18, clip_on=False) axp2.set_xlabel('Frequency [Hz]') if counter != 0: remove_yticks(axp2) else: axp2.set_ylabel(power_spectrum_name()) axps.append(axp2) ############################# # spike_times[nr_example] axi = plt.subplot(grid2[-1, counter]) plt_isis_phaselocking(axi, frame_name, spike_times) axi.set_xticks_delta(2) axi.set_xlim(0, 13) try: axts[0].get_shared_y_axes().join(*axts[0::2]) except: print('axt problem') embed() axts[1].get_shared_y_axes().join(*axts[1::2]) axts[0].get_shared_x_axes().join(*axts) join_y(axfs) join_y(axps) join_x(axps) fig = plt.gcf() fig.tag([axts[4], axts[2], axts[0], ], xoffs=-3) firing_rate_scalebars(axfs[np.where(np.array(counters) == 0)[0][0]], length=10) individual_tag = 'data ' + cell + '_DF_chosen_' + str( df_chosen) + mt_type save_visualization(individual_tag, show) print('plotted') file.close() print('finished examples') embed() def plt_beats_modulation_several_with_overview_nice_big_final(datasets=['2020-10-20-ad-invivo-1'], dfs_all_unique_given=[25], limit=1, duration_exclude=0.45, nfft=int(4096), show=False): # Function to load the experimental data save_name = 'beat_results_smoothed_limit' + str(limit) + '_minimalduration_' + str(duration_exclude) + '_all' print(save_name) frame_all = pd.read_pickle(load_folder_name('calc_phaselocking') + '/calc_phaselocking-phaselocking5_big.pkl') colors = ['red', 'green', 'purple', 'blue'] try: plot_style() except: print('plotstyle not there') if len(datasets) < 1: datasets, data_dir = find_all_dir_cells() datasets = np.sort(datasets)[::-1] stop_cell = '2018-11-20-af-invivo-1' datasets = find_stop_cell(datasets, stop_cell) for i, cell in enumerate(datasets): path = load_folder_name('data') + '/cells/' + cell + '/' + cell + ".nix" # cell.split(os.path.sep)[-1] print(cell) cells_exclude = ['2020-10-29-af-invivo-1', 
'2019-05-07-cb-invivo-1'] df_pos = False if cell not in cells_exclude: if os.path.exists(path): print('exists') try: file = nix.File.open(path, nix.FileMode.ReadOnly) cont0 = True except: cont0 = False if cont0: b = file.blocks[0] cont2 = False names = [] names_dataarrays = [] for stims in b.data_arrays: # this seems to be reasonable because the only data with sinewave with higher number was not useful (das betrifft aber nur 2019-05-07-ab und 2019-05-07-ac, wobei wenn ich jetzt ganau schau scheinen das nur komische trials zu sein die wir nicht wirklich brauchen if 'sinewave-1_Contrast' in stims.name: names.append(stims.name) names_dataarrays.append(stims.name) sam = find_mt(b, 'SAM') sine = find_mt(b, 'sine') if (len(sine) > 0) or (len(sam) > 0): cont2 = True if cont2 == True: print('cont2') frame_cell = frame_all[frame_all['cell'] == cell] if len(frame_cell) < 1: # falls frame 5 noch nicht fertig ist haben wir ja den Backup von davor! frame_all = pd.read_pickle( '../code/calc_phaselocking/calc_phaselocking-phaselocking_big.pkl') frame_cell = frame_all[frame_all['cell'] == cell] if len(frame_cell) > 0: if len(dfs_all_unique_given) < 1: dfs_all_unique = np.unique(frame_cell.df_sign.dropna())[::-1] df_name = 'df_sign' df_pos = '' # 'min_df' dfs_all_unique = list(dfs_all_unique) # todo: also hier gibts halt noch pobleme if len(np.unique(np.array(dfs_all_unique))) < 2: df_name = 'df' dfs_all_unique = np.unique(frame_cell.df.dropna())[::-1] dfs_all_unique = list(dfs_all_unique) else: dfs_all_unique = dfs_all_unique_given df_name = 'df_sign' if len(dfs_all_unique) > 0: if df_pos == 'min_df': try: dfs_all_unique = [dfs_all_unique[np.argmin(np.abs(dfs_all_unique))]] except: print('df min') embed() for df_chosen in dfs_all_unique: if not np.isnan(df_chosen): frame_df = frame_cell[frame_cell[df_name] == df_chosen] contrasts_all_unique = np.unique(frame_df.contrast) contrasts_all_unique = contrasts_all_unique[~np.isnan(contrasts_all_unique)] if len(contrasts_all_unique) > 0: 
mt_types = frame_df.mt_type.unique() for mt_type in mt_types: if ('base' not in mt_type) & ('chirp' not in mt_type) & ( 'SAM DC-1' not in mt_type): contrasts_here = [] frame_type = frame_df[ (frame_df.mt_type == mt_type)] # | (frame_df.mt_type == 'base') default_settings(column=2, length=5) plt.figure(figsize=(30, 8)) gs0 = gridspec.GridSpec(1, 2, width_ratios=[4, 1], hspace=0.4, left=0.1, right=0.97) # grid1 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.2, hspace=0.2, subplot_spec=gs0[1]) scores = ['amp_stim', 'amp_df', 'amp_f0', 'amp_fmax_interval'] # 'stim', 'f0', plt_single_phaselockloss(colors, frame_cell, df_chosen, scores, cell, axs) axs.set_xlim(-10, 100) axs = plt.subplot(grid1[1]) scores = ['dsp_perc95_', 'dsp_max_', 'dsp_mean_'] if scores[0] in frame_cell.keys(): plt_single_phaselockloss(colors, frame_cell, df_chosen, scores, cell, axs) axs.set_xlim(-10, 100) if (cell == '2020-10-20-ad-invivo-1') & ( 50 == df_chosen): # das erst fehlt aus welchem Grund auch immer reduce = 0 else: reduce = 0 nr_col = int(len(np.unique(contrasts_all_unique))) - reduce grid2 = gridspec.GridSpecFromSubplotSpec(7, nr_col, height_ratios=[1, 0.5, 1, 0.1, 1, 0.7, 1], wspace=0.2, hspace=0.2, subplot_spec=gs0[0]) axts = [] axfs = [] axps = [] mt_names = frame_type.mt_name.unique() for m, mt_name in enumerate(mt_names): frame_name = frame_type[frame_type.mt_name == mt_name] mt_idxs = list(map(int, np.array(frame_name.mt_idx))) mts = b.multi_tags[mt_name] print(mts.name) name = mts.name try: contrast = name.split('=')[1].split('%')[0] cont3 = True except: cont3 = False print('contrasts') if cont3: if contrast not in contrasts_here: print(contrast) if np.isnan(float(contrast)): counter = 0 else: counter = np.where( np.round(contrasts_all_unique, 2) == np.round( float(contrast), 2))[0][0] - reduce # +1 try: contrasts_here.append(contrast) except: print('embed problem') embed() try: pass except: pass features = [] for ff, f in enumerate(mts.features): if 'id' in f.data.name: pass 
elif 'Contrast' in f.data.name: pass elif 'DeltaF' in f.data.name: pass else: features.append(f.data.name) eod_frs, eod_redo = get_eod_fr_simple(b, names) eod = b.data_arrays['LocalEOD-1'][:] names = [] for stims in b.data_arrays: names.append(stims.name) print(cell + ' Beat calculation') eods_all = [] eods_all_g = [] V_1 = [] spike_times = [] for m in mt_idxs: # range(len(mts.positions[:])) try: eods, _ = link_arrays_eod(b, mts.positions[:][m], mts.extents[:][m], 'LocalEOD-1') except: print('eods thing') embed() eods_all.append(eods) eods_g, sampling_rate = link_arrays_eod(b, mts.positions[ :][m], mts.extents[ :][m], 'EOD') v_1, sampling_rate = link_arrays_eod(b, mts.positions[ :][m], mts.extents[:][ m], 'V-1') eods_all_g.append(eods_g) V_1.append(v_1) if eod_redo == True: p, f = ml.psd(eods - np.mean(eods), Fs=sampling_rate, NFFT=nfft, noverlap=nfft // 2) else: pass spikes = link_arrays_spikes(b, first=mts.positions[:][ m], second=mts.extents[:][ m], minus_spikes= mts.positions[:][ m]) * 1000 spike_times.append(spikes) # - cut# print(len(spike_times)) smooth = [] spikes_mats = [] for s in range(len(spike_times)): try: spikes_mat = cr_spikes_mat( spike_times[s] / 1000, sampling_rate, int( mts.extents[:][ mt_idxs[ s]] * sampling_rate)) # time[-1] * sampling_rate except: print('mts prob') embed() spikes_mats.append(spikes_mat) # für den Mean danach schneiden wir das wie das kürzeste try: smooth.append(gaussian_filter( spikes_mat[ 0:int(np.min( mts.extents[:]) * sampling_rate)], sigma=0.0005 * sampling_rate)) except: print('embed problem') embed() try: smooth_mean = np.mean(smooth, axis=0) except: print('smoothed thing') embed() plt.suptitle('data ' + cell + ' ' + mts.name) xlim = [0, 40] nr_example = 0 # 'no'#0 ########################################## try: axt = plt.subplot(grid2[0, counter]) except: print('axt something') embed() axts.append(axt) stimulus = eods_all[nr_example] # eods_g + Efield try: time = np.arange(0, len(stimulus) / sampling_rate, 1 / 
sampling_rate) * 1000 except: print('time all2') embed() eods_am, eod_norm = extract_am(stimulus, time, norm=False, kind='linear') # 'cubic' axt.plot(time, eod_norm, color='grey', linewidth=0.5) axt.plot(time, eods_am, color='red') axt.scatter(spike_times[nr_example], np.mean(eod_norm) * np.ones( len(spike_times[nr_example])) , color='black', s=10, marker='|') # np.max(v1)*lineoffsets=np.max(V_1[nr_example]), try: axt.show_spines('') except: print('not there') axt.set_xlim(xlim) axt.set_xlabel('Time [ms]') if counter != 0: remove_yticks(axt) axt.set_ylabel('') axt.text(1, 1, 'c=' + str(contrast), ha='right', transform=axt.transAxes) ########################################### # time spikes axt = plt.subplot(grid2[1, counter]) axt.set_ylabel('local') try: axt.show_spines('') except: print('not there') time = np.arange(0, len(V_1[nr_example]) / sampling_rate, 1 / sampling_rate) * 1000 # ich mache ein festes fenster also habe ich einen schift der in einem sehr kleinen schritt durchgeht # das period 2 hätte ich wenn das Fenster immer die gleiche länge hätte umstuelp = False if umstuelp: # ah aber ich hab auch noch das umstuelpen aus dem susept das für den Appendix! 
spikes_umstuelpen(eod, sampling_rate, time) axt.scatter(spike_times[nr_example], np.max(V_1[nr_example]) * np.ones( len(spike_times[nr_example])) , color='black', s=10, marker='|') # np.max(v1)*lineoffsets=np.max(V_1[nr_example]), axt.set_xlim(xlim) remove_xticks(axt) if counter != 0: remove_yticks(axt) axt.set_ylabel('') try: axt.show_spines('') except: print('not there') axts.append(axt) # convolved firing rate axf = plt.subplot(grid2[2, counter]) if len(smooth[nr_example]) != len(time): time_here = time[0:len(smooth[nr_example])] else: time_here = time # [0:len(smooth[nr_example])] mean_firing = True # smooth_mean if mean_firing: axf.plot(time_here, smooth_mean, color='grey', linewidth=0.5) else: axf.plot(time_here, smooth[nr_example], color='grey', linewidth=0.5) try: axt.show_spines('') except: print('not there') axf.set_xlim(xlim) axfs.append(axf) if counter == 0: firing_rate_scalebars(axf) ########################################## # time psd axp2 = plt.subplot(grid2[4, counter]) ps = [] maxx = 1000 for s, spikes_mat in enumerate(spikes_mats): p, f = ml.psd(spikes_mat - np.mean(spikes_mat), Fs=sampling_rate, NFFT=2 ** 13, noverlap=2 ** 13 / 2) ps.append(p) axp2.set_xlim(0, maxx) axp2.plot(f, np.mean(ps, axis=0), color='black', zorder=2, linestyle='-') pp = np.mean(ps, axis=0) eodf = np.mean(frame_name.eod_fr) names = ['0', '01', '02', '012'] names_here = [names[1]] # extend = True colors_array = ['pink', 'green'] if contrast > 1: name = names_here[0] else: name = 'eodf' freqs, colors_peaks, labels, alphas = chose_all_freq_combos( [], colors_array, np.abs(df_chosen), maxx, eodf, color_eodf='black', name= name, color_stim='pink', color_stim_mult='pink') plt_peaks_several(freqs, [pp], axp2, pp, f, labels, 0, colors_peaks, limit=1200, alphas=alphas, extend=extend, ms=18, clip_on=False) axp2.set_xlabel('Frequency [Hz]') if counter != 0: remove_yticks(axp2) else: axp2.set_ylabel(power_spectrum_name()) axps.append(axp2) ############################# # 
spike_times[nr_example] axi = plt.subplot(grid2[-1, counter]) plt_isis_phaselocking(axi, frame_name, spike_times) if len(axts) > 0: try: axts[0].get_shared_y_axes().join(*axts[0::2]) except: print('axt problem') embed() axts[1].get_shared_y_axes().join(*axts[1::2]) axts[0].get_shared_x_axes().join(*axts) join_y(axfs) join_y(axps) join_x(axps) individual_tag = 'data_' + cell + '_DF_chosen_' + str( df_chosen) + mt_type save_visualization(individual_tag, show) print('plotted') file.close() print('finished examples') embed() def plt_isis_phaselocking(axi, frame_name, spike_times): isis = [] for sp_nr, sp in enumerate(np.array(spike_times)): isis.append( calc_isi(sp / 1000, frame_name.eod_fr.iloc[sp_nr])) axi.hist(np.concatenate(isis), bins=100, color='grey') axi.axvline(1, color='black', linestyle='--', linewidth=0.5) try: axi.show_spines('b') except: pass axi.set_xlabel(isi_xlabel()) def firing_rate_scalebars(axt, length=10): try: axt.xscalebar(0.1, -0.02, length, 'ms', va='right', ha='bottom') ##ylim[0] axt.yscalebar(-0.02, 0.1, 500, 'Hz', va='bottom', ha='left') except: pass def spikes_umstuelpen(eod, sampling_rate, time): shift_period = 0.005 # period * 2# shifts = np.arange(0, 200 * shift_period, shift_period) time_b = np.arange(0, len(beat) / sampling_rate, 1 / sampling_rate) am_corr = extract_am(beat, time_b, eodf=eod, norm=False, extract='globalmax', kind='linear')[0] len_smoothed, smoothed_trial, all_spikes, maxima, error, spikes_cut, beat_cut, am_corr_cut = create_shifted_spikes( eod, len_smoothed_b, len_smoothed, beat, am_corr, sampling_rate, time_b, time, smoothed, shifts, plot_segment, tranformed_spikes, version=version) _, _, _, _, _ = get_most_similiar_spikes( all_spikes, am_corr_cut, beat_cut, error, maxima, spikes_cut) def plt_beats_modulation_several_with_overview_nice_big_max(limit=1, duration_exclude=0.45, nfft=int(4096), show=False): # Function to load the experimental data save_name = 'beat_results_smoothed_limit' + str(limit) + 
'_minimalduration_' + str(duration_exclude) + '_all' print(save_name) datas_new = [] old_cells = True if old_cells: # das ist falls ich die alten Datensätze untersuchen will _, _ = find_all_dir_cells() frame_desired = pd.read_csv('../code/calc_base/find_contrasts_SAMs-SAM_amplitudes.csv') big_adapt = True if big_adapt: frame_big = frame_desired[(frame_desired.contrast > 25) | (frame_desired.contrast_true > 25)] else: frame_big = frame_desired # [(frame_desired.contrast > 5) | (frame_desired.contrast_true > 5)] datasets = frame_big.cell.unique() datasets_loaded = datasets[::-1] else: frame = pd.read_pickle( load_folder_name( 'calc_cocktailparty') + '/calc_data_peaks_threewave-spikes_all_psdEOD_1_nfft_16384[05,original]_psdEOD__sqrt__points_5_ALL_.pkl') datasets_loaded = np.sort(frame.cell.unique())[::-1] datasets = ['2020-10-20-ad-invivo-1', '2020-10-27-ac-invivo-1', '2020-10-29-ai-invivo-1', '2018-09-13-aa-invivo-1', '2020-10-29-ac-invivo-1'] # [,'2020-10-29-ai-invivo-1',] datasets.extend(datasets_loaded) frame_all = pd.read_pickle('../code/calc_phaselocking/calc_phaselocking-phaselocking5_big.pkl') colors = ['red', 'green', 'purple', 'blue'] for i, cell in enumerate(datasets): path = '../data/cells/' + cell + '/' + cell + ".nix" # cell.split(os.path.sep)[-1] print(cell) cells_exclude = ['2020-10-29-af-invivo-1', '2019-05-07-cb-invivo-1'] if cell not in cells_exclude: if os.path.exists(path): print('exists') file = nix.File.open(path, nix.FileMode.ReadOnly) b = file.blocks[0] cont2 = False names = [] names_dataarrays = [] for stims in b.data_arrays: # this seems to be reasonable because the only data with sinewave with higher number was not useful (das betrifft aber nur 2019-05-07-ab und 2019-05-07-ac, wobei wenn ich jetzt ganau schau scheinen das nur komische trials zu sein die wir nicht wirklich brauchen if 'sinewave-1_Contrast' in stims.name: names.append(stims.name) names_dataarrays.append(stims.name) 'sinewave''SAM' sam = find_mt(b, 'SAM') sine = find_mt(b, 
'sine') if (len(sine) > 0) or (len(sam) > 0): cont2 = True test = False if test: from utils_test import test_rlx test_rlx() if cont2 == True: print('cont2') frame_cell = frame_all[frame_all['cell'] == cell] if len(frame_cell) < 1: frame_all = pd.read_pickle('../code/calc_phaselocking/calc_phaselocking-phaselocking_big.pkl') frame_cell = frame_all[frame_all['cell'] == cell] if len(frame_cell) > 0: dfs_all_unique = np.unique(frame_cell.df_sign.dropna())[::-1] df_name = 'df_sign' df_pos = '' # 'min_df' dfs_all_unique = list(dfs_all_unique) if len(np.unique(np.array(dfs_all_unique))) < 2: df_name = 'df' dfs_all_unique = np.unique(frame_cell.df.dropna())[::-1] dfs_all_unique = list(dfs_all_unique) if len(dfs_all_unique) > 0: if df_pos == 'min_df': try: dfs_all_unique = [dfs_all_unique[np.argmin(np.abs(dfs_all_unique))]] except: print('df min') embed() contrasts_all_unique = np.unique(frame_cell.contrast) if len(contrasts_all_unique) > 1: for df_chosen in dfs_all_unique: if np.abs(df_chosen) < 75: if not np.isnan(df_chosen): frame_df = frame_cell[frame_cell[df_name] == df_chosen] mt_types = frame_df.mt_type.unique() for mt_type in mt_types: if 'base' not in mt_type: contrasts_here = [] frame_type = frame_df[ (frame_df.mt_type == mt_type)] # | (frame_df.mt_type == 'base') gs0 = gridspec.GridSpec(1, 2, width_ratios=[4, 1], hspace=0.4, left=0.045, right=0.97) # grid1 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.2, hspace=0.2, subplot_spec=gs0[1]) scores = ['amp_stim', 'amp_df', 'amp_f0', 'amp_fmax_interval'] # 'stim', 'f0', plt_single_phaselockloss(colors, frame_cell, df_chosen, scores, cell, axs) axs.set_xlim(-10, 100) axs = plt.subplot(grid1[1]) scores = ['dsp_perc95_', 'dsp_max_', 'dsp_mean_'] if scores[0] in frame_cell.keys(): plt_single_phaselockloss(colors, frame_cell, df_chosen, scores, cell, axs) axs.set_xlim(-10, 100) nr_col = int(len(np.unique(contrasts_all_unique)) - 1) grid2 = gridspec.GridSpecFromSubplotSpec(6, nr_col, height_ratios=[1, 1, 0.5, 1, 1, 1], 
wspace=0.2, hspace=0.2, subplot_spec=gs0[0]) axts = [] axps = [] mt_names = frame_type.mt_name.unique() for m, mt_name in enumerate(mt_names): frame_name = frame_type[frame_type.mt_name == mt_name] mt_idxs = list(map(int, np.array(frame_name.mt_idx))) mts = b.multi_tags[mt_name] print(mts.name) name = mts.name contrast = name.split('=')[1].split('%')[0] if contrast not in contrasts_here: print(contrast) if np.isnan(float(contrast)): counter = 0 else: counter = np.where( np.round(contrasts_all_unique, 2) == np.round( float(contrast), 2))[ 0][0] # +1 try: contrasts_here.append(contrast) except: print('embed problem') embed() try: dfs = [mts.metadata[mts.name]['DeltaF']] * len( mts.positions[:]) except: dfs = mts.metadata['DeltaF'] features, dfs, contrasts, id = get_features_and_info(mts, dfs=dfs, contrasts=contrasts) eod_frs, eod_redo = get_eod_fr_simple(b, names) names = [] for stims in b.data_arrays: names.append(stims.name) print(cell + ' Beat calculation') datas_new.append(cell) try: pass except: print('rlx problem') eods_all = [] eods_all_g = [] V_1 = [] spike_times = [] for m in mt_idxs: # range(len(mts.positions[:])) try: eods, _ = link_arrays_eod(b, mts.positions[:][m], mts.extents[:][m], 'LocalEOD-1') except: print('eods thing') embed() eods_all.append(eods) eods_g, sampling_rate = link_arrays_eod(b, mts.positions[ :][m], mts.extents[:][ m], 'EOD') v_1, sampling_rate = link_arrays_eod(b, mts.positions[:][ m], mts.extents[:][m], 'V-1') eods_all_g.append(eods_g) V_1.append(v_1) if eod_redo == True: p, f = ml.psd(eods - np.mean(eods), Fs=sampling_rate, NFFT=nfft, noverlap=nfft // 2) else: pass spike_times.append( (mts.retrieve_data(m, 'Spikes-1')[:] - mts.positions[ m]) * 1000) # - cut print(len(spike_times)) smooth = [] spikes_mats = [] for s in range(len(spike_times)): try: spikes_mat = cr_spikes_mat(spike_times[s] / 1000, sampling_rate, int( mts.extents[:][ mt_idxs[ s]] * sampling_rate)) # time[-1] * sampling_rate except: print('mts prob') embed() 
spikes_mats.append(spikes_mat) # für den Mean danach schneiden wir das wie das kürzeste try: smooth.append(gaussian_filter( spikes_mat[ 0:int(np.min(mts.extents[:]) * sampling_rate)], sigma=0.0005 * sampling_rate)) except: print('embed problem') embed() try: pass except: print('smoothed thing') embed() plt.suptitle('data ' + cell + ' ' + mts.name) xlim = [] nr_example = 0 ########################################## # time psd axp = plt.subplot(grid2[3, counter]) axp2 = plt.subplot(grid2[4, counter]) ps = [] maxx = 1000 for s, spikes_mat in enumerate(spikes_mats): p, f = ml.psd(spikes_mat - np.mean(spikes_mat), Fs=sampling_rate, NFFT=2 ** 13, noverlap=2 ** 13 / 2) ps.append(p) if s == nr_example: color = 'purple' zorder = 100 axp.plot(f, p, color=color, zorder=zorder) eodf = np.mean(frame_name.eod_fr) names = ['0', '01', '02', '012'] names_here = [names[1]] # extend = True colors_array = ['pink', 'green'] freqs, colors_peaks, labels, alphas = chose_all_freq_combos( [], colors_array, df_chosen, maxx, eodf, color_eodf='black', name= names_here[ 0], color_stim='pink', color_stim_mult='pink') plt_peaks_several(freqs, [p], axp, p, f, labels, 0, colors_peaks, alphas=alphas, extend=extend, ms=18, clip_on=True) else: color = 'grey' zorder = 1 axp2.plot(f, p, color=color, zorder=zorder) axp2.set_xlim(0, maxx) axp.set_xlim(0, maxx) remove_xticks(axp) axp2.plot(f, np.mean(ps, axis=0), color='black', zorder=2, linestyle='--') axp2.set_xlabel('Power [Hz]') if counter != 0: remove_yticks(axp2) axp2.set_ylabel('') if counter != 0: remove_yticks(axp) axp.set_ylabel('') axps.append(axp) axps.append(axp2) ########################################### # time spikes stimulus = eods_all[nr_example] # eods_g + Efield axt = plt.subplot(grid2[0, counter]) axt.set_ylabel('local') time = np.arange(0, len(V_1[nr_example]) / sampling_rate, 1 / sampling_rate) * 1000 axt.plot(time, V_1[nr_example], color='purple', linewidth=0.5) axt.scatter(spike_times[nr_example], np.max(V_1[nr_example]) * 
np.ones( len(spike_times[nr_example])) , color='black', s=10, marker='|') # np.max(v1)*lineoffsets=np.max(V_1[nr_example]), if len(xlim) > 0: axt.set_xlim(xlim) axt.set_title(contrast) remove_xticks(axt) if counter != 0: remove_yticks(axt) axt.set_ylabel('') axts.append(axt) axt = plt.subplot(grid2[1, counter]) axts.append(axt) try: time = np.arange(0, len(stimulus) / sampling_rate, 1 / sampling_rate) * 1000 except: print('time all') embed() eods_am, eod_norm = extract_am(stimulus, time, norm=False) axt.plot(time, eod_norm, color='grey', linewidth=0.5) axt.plot(time, eods_am, color='red') axt.scatter(spike_times[nr_example], np.mean(eod_norm) * np.ones( len(spike_times[nr_example])) , color='black', s=10, marker='|') # np.max(v1)*lineoffsets=np.max(V_1[nr_example]), if len(xlim) > 0: axt.set_xlim(xlim) axt.set_xlabel('Time [ms]') if counter != 0: remove_yticks(axt) axt.set_ylabel('') ############################# # spike_times[nr_example] axi = plt.subplot(grid2[-1, counter]) isis = [] for sp_nr, sp in enumerate(np.array(spike_times)): isis.append( calc_isi(sp / 1000, frame_name.eod_fr.iloc[sp_nr])) axi.hist(np.concatenate(isis), bins=100) axi.axvline(1, color='grey', linestyle='--') try: axts[0].get_shared_y_axes().join(*axts[0::2]) except: print('axt problem') embed() axts[1].get_shared_y_axes().join(*axts[1::2]) axts[0].get_shared_x_axes().join(*axts) join_y(axps) join_x(axps) individual_tag = 'data ' + cell + '_DF_chosen_' + str( df_chosen) + mt_type save_visualization(individual_tag, show, pdf=True) print('plotted') file.close() print('finished examples') embed() def plt_beats_modulation_several_with_overview_nice_big(limit=1, duration_exclude=0.45, nfft=int(4096), show=False): # Function to load the experimental data save_name = 'beat_results_smoothed_limit' + str(limit) + '_minimalduration_' + str(duration_exclude) + '_all' print(save_name) datas_new = [] old_cells = True if old_cells: # das ist falls ich die alten Datensätze untersuchen will _, _ = 
find_all_dir_cells() frame_desired = pd.read_csv('../code/calc_base/find_contrasts_SAMs-SAM_amplitudes.csv') big_adapt = True if big_adapt: frame_big = frame_desired[(frame_desired.contrast > 25) | (frame_desired.contrast_true > 25)] else: frame_big = frame_desired # [(frame_desired.contrast > 5) | (frame_desired.contrast_true > 5)] datasets = frame_big.cell.unique() datasets_loaded = datasets[::-1] else: frame = pd.read_pickle( load_folder_name( 'calc_cocktailparty') + '/calc_data_peaks_threewave-spikes_all_psdEOD_1_nfft_16384[05,original]_psdEOD__sqrt__points_5_ALL_.pkl') datasets_loaded = np.sort(frame.cell.unique())[::-1] datasets = ['2020-10-29-ai-invivo-1', '2020-10-20-ad-invivo-1', '2018-09-13-aa-invivo-1', '2020-10-29-ac-invivo-1'] # [,'2020-10-29-ai-invivo-1',] datasets.extend(datasets_loaded) plt_spectra_compar(datas_new, datasets, nfft, show, add='plt_beats_modulation_several_with_overview_nice_big') print('finished examples') embed() def plt_beats_modulation_several_with_overview_nice(limit=1, duration_exclude=0.45, nfft=int(4096), show=False): # Function to load the experimental data save_name = 'beat_results_smoothed_limit' + str(limit) + '_minimalduration_' + str(duration_exclude) + '_all' print(save_name) datas_new = [] old_cells = True if old_cells: # das ist falls ich die alten Datensätze untersuchen will _, _ = find_all_dir_cells() frame_desired = pd.read_csv('../code/calc_base/find_contrasts_SAMs-SAM_amplitudes.csv') big_adapt = False if big_adapt: frame_big = frame_desired[(frame_desired.contrast > 29) | (frame_desired.contrast_true > 29)] else: frame_big = frame_desired # [(frame_desired.contrast > 5) | (frame_desired.contrast_true > 5)] datasets = frame_big.cell.unique() datasets_loaded = datasets[::-1] else: frame = pd.read_pickle( load_folder_name( 'calc_cocktailparty') + '/calc_data_peaks_threewave-spikes_all_psdEOD_1_nfft_16384[05,original]_psdEOD__sqrt__points_5_ALL_.pkl') datasets_loaded = np.sort(frame.cell.unique())[::-1] datasets = 
['2018-09-13-aa-invivo-1', '2020-10-20-ad-invivo-1'] datasets.extend(datasets_loaded) plt_spectra_compar(datas_new, datasets, nfft, show, add='plt_beats_modulation_several_with_overview_nice') print('finished examples') embed() def plt_spectra_compar(datas_new, datasets, nfft, show, add=''): frame_all = pd.read_pickle('../code/calc_phaselocking/calc_phaselocking-phaselocking5_big.pkl') colors = ['red', 'green', 'purple', 'blue'] for i, cell in enumerate(datasets): path = '../data/cells/' + cell + '/' + cell + ".nix" # cell.split(os.path.sep)[-1] print(cell) cells_exclude = ['2020-10-29-af-invivo-1', '2019-05-07-cb-invivo-1'] if cell not in cells_exclude: if os.path.exists(path): print('exists') file = nix.File.open(path, nix.FileMode.ReadOnly) b = file.blocks[0] cont2 = False names = [] names_dataarrays = [] for stims in b.data_arrays: # this seems to be reasonable because the only data with sinewave with higher number was not useful (das betrifft aber nur 2019-05-07-ab und 2019-05-07-ac, wobei wenn ich jetzt ganau schau scheinen das nur komische trials zu sein die wir nicht wirklich brauchen if 'sinewave-1_Contrast' in stims.name: names.append(stims.name) names_dataarrays.append(stims.name) 'sinewave''SAM' sam = find_mt(b, 'SAM') sine = find_mt(b, 'sine') if (len(sine) > 0) or (len(sam) > 0): cont2 = True test = False if test: from utils_test import tes_rlx2 tes_rlx2() if cont2 == True: print('cont2') frame_cell = frame_all[frame_all['cell'] == cell] if len(frame_cell) < 1: frame_all = pd.read_pickle('../code/calc_phaselocking/calc_phaselocking-phaselocking_big.pkl') frame_cell = frame_all[frame_all['cell'] == cell] if len(frame_cell) > 0: dfs_all_unique = np.unique(frame_cell.df_sign.dropna())[::-1] df_name = 'df_sign' df_pos = '' # 'min_df' dfs_all_unique = list(dfs_all_unique) if len(np.unique(np.array(dfs_all_unique))) < 2: df_name = 'df' dfs_all_unique = np.unique(frame_cell.df.dropna())[::-1] dfs_all_unique = list(dfs_all_unique) if len(dfs_all_unique) > 0: 
if df_pos == 'min_df': try: dfs_all_unique = [dfs_all_unique[np.argmin(np.abs(dfs_all_unique))]] except: print('df min') embed() contrasts_all_unique = np.unique(frame_cell.contrast) if len(contrasts_all_unique) > 1: for df_chosen in dfs_all_unique: if not np.isnan(df_chosen): frame_df = frame_cell[frame_cell[df_name] == df_chosen] mt_types = frame_df.mt_type.unique() for mt_type in mt_types: if 'base' not in mt_type: contrasts_here = [] frame_type = frame_df[ (frame_df.mt_type == mt_type)] # | (frame_df.mt_type == 'base') gs0 = gridspec.GridSpec(1, 2, width_ratios=[4, 1], hspace=0.4, left=0.045, right=0.97) # grid1 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.2, hspace=0.2, subplot_spec=gs0[1]) scores = ['amp_stim', 'amp_df', 'amp_f0', 'amp_fmax_interval'] # 'stim', 'f0', plt_single_phaselockloss(colors, frame_cell, df_chosen, scores, cell, axs) axs.set_xlim(-10, 100) axs = plt.subplot(grid1[1]) scores = ['dsp_perc95_', 'dsp_max_', 'dsp_mean_'] if scores[0] in frame_cell.keys(): plt_single_phaselockloss(colors, frame_cell, df_chosen, scores, cell, axs) axs.set_xlim(-10, 100) nr_col = int(len(np.unique(contrasts_all_unique)) - 1) grid2 = gridspec.GridSpecFromSubplotSpec(6, nr_col, height_ratios=[1, 1, 0.5, 1, 1, 1], wspace=0.2, hspace=0.2, subplot_spec=gs0[0]) axts = [] axps = [] mt_names = frame_type.mt_name.unique() for m, mt_name in enumerate(mt_names): frame_name = frame_type[frame_type.mt_name == mt_name] mt_idxs = list(map(int, np.array(frame_name.mt_idx))) mts = b.multi_tags[mt_name] print(mts.name) name = mts.name contrast = name.split('=')[1].split('%')[0] if contrast not in contrasts_here: print(contrast) if np.isnan(float(contrast)): counter = 0 else: counter = np.where( np.round(contrasts_all_unique, 2) == np.round( float(contrast), 2))[ 0][0] # +1 try: contrasts_here.append(contrast) except: print('embed problem') embed() try: dfs = [mts.metadata[mts.name]['DeltaF']] * len( mts.positions[:]) except: dfs = mts.metadata['DeltaF'] features, dfs, 
contrasts, id = get_features_and_info(mts, dfs=dfs, contrasts=contrasts) eod_frs, eod_redo = get_eod_fr_simple(b, names) names = [] for stims in b.data_arrays: names.append(stims.name) print(cell + ' Beat calculation') datas_new.append(cell) try: pass except: print('rlx problem') eods_all = [] eods_all_g = [] V_1 = [] spike_times = [] for m in mt_idxs: # range(len(mts.positions[:])) try: eods, _ = link_arrays_eod(b, mts.positions[:][m], mts.extents[:][m], 'LocalEOD-1') except: print('eods thing') embed() eods_all.append(eods) eods_g, sampling_rate = link_arrays_eod(b, mts.positions[:][m], mts.extents[:][m], 'EOD') v_1, sampling_rate = link_arrays_eod(b, mts.positions[:][m], mts.extents[:][m], 'V-1') eods_all_g.append(eods_g) V_1.append(v_1) if eod_redo == True: p, f = ml.psd(eods - np.mean(eods), Fs=sampling_rate, NFFT=nfft, noverlap=nfft // 2) else: pass spike_times.append( (mts.retrieve_data(m, 'Spikes-1')[:] - mts.positions[ m]) * 1000) # - cut print(len(spike_times)) smooth = [] spikes_mats = [] for s in range(len(spike_times)): try: spikes_mat = cr_spikes_mat(spike_times[s] / 1000, sampling_rate, int( mts.extents[:][ mt_idxs[ s]] * sampling_rate)) # time[-1] * sampling_rate except: print('mts prob') embed() spikes_mats.append(spikes_mat) # für den Mean danach schneiden wir das wie das kürzeste try: smooth.append(gaussian_filter( spikes_mat[ 0:int(np.min(mts.extents[:]) * sampling_rate)], sigma=0.0005 * sampling_rate)) except: print('embed problem') embed() try: pass except: print('smoothed thing') embed() plt.suptitle('data ' + cell + ' ' + mts.name) xlim = [0, 40] nr_example = 0 ########################################## # time psd axp = plt.subplot(grid2[3, counter]) axp2 = plt.subplot(grid2[4, counter]) ps = [] maxx = 1000 for s, spikes_mat in enumerate(spikes_mats): p, f = ml.psd(spikes_mat - np.mean(spikes_mat), Fs=sampling_rate, NFFT=2 ** 13, noverlap=2 ** 13 / 2) ps.append(p) if s == nr_example: color = 'purple' zorder = 100 axp.plot(f, p, color=color, 
zorder=zorder) eodf = np.mean(frame_name.eod_fr) names = ['0', '01', '02', '012'] names_here = [names[1]] # extend = True colors_array = ['pink', 'green'] freqs, colors_peaks, labels, alphas = chose_all_freq_combos( [], colors_array, df_chosen, maxx, eodf, color_eodf='black', name= names_here[ 0], color_stim='pink', color_stim_mult='pink') plt_peaks_several(freqs, [p], axp, p, f, labels, 0, colors_peaks, alphas=alphas, extend=extend, ms=18, clip_on=True) else: color = 'grey' zorder = 1 axp2.plot(f, p, color=color, zorder=zorder) axp2.set_xlim(0, maxx) axp.set_xlim(0, maxx) remove_xticks(axp) axp2.plot(f, np.mean(ps, axis=0), color='black', zorder=2, linestyle='--') axp2.set_xlabel('Power [Hz]') if counter != 0: remove_yticks(axp2) axp2.set_ylabel('') if counter != 0: remove_yticks(axp) axp.set_ylabel('') axps.append(axp) axps.append(axp2) ########################################### # time spikes stimulus = eods_all[nr_example] # eods_g + Efield axt = plt.subplot(grid2[0, counter]) axt.set_ylabel('local') time = np.arange(0, len(V_1[nr_example]) / sampling_rate, 1 / sampling_rate) * 1000 axt.plot(time, V_1[nr_example], color='purple', linewidth=0.5) axt.scatter(spike_times[nr_example], np.max(V_1[nr_example]) * np.ones( len(spike_times[nr_example])) , color='black', s=10, marker='|') # np.max(v1)*lineoffsets=np.max(V_1[nr_example]), axt.set_xlim(xlim) axt.set_title(contrast) remove_xticks(axt) if counter != 0: remove_yticks(axt) axt.set_ylabel('') axts.append(axt) axt = plt.subplot(grid2[1, counter]) axts.append(axt) try: time = np.arange(0, len(stimulus) / sampling_rate, 1 / sampling_rate) * 1000 except: print('time all') embed() eods_am, eod_norm = extract_am(stimulus, time, norm=False) axt.plot(time, eod_norm, color='grey', linewidth=0.5) axt.plot(time, eods_am, color='red') axt.scatter(spike_times[nr_example], np.mean(eod_norm) * np.ones( len(spike_times[nr_example])) , color='black', s=10, marker='|') # np.max(v1)*lineoffsets=np.max(V_1[nr_example]), 
axt.set_xlim(xlim) axt.set_xlabel('Time [ms]') if counter != 0: remove_yticks(axt) axt.set_ylabel('') axi = plt.subplot(grid2[-1, counter]) isis = [] for sp_nr, sp in enumerate(np.array(spike_times)): isis.append( calc_isi(sp / 1000, frame_name.eod_fr.iloc[sp_nr])) axi.hist(np.concatenate(isis), bins=100) axi.axvline(1, color='grey', linestyle='--') try: axts[0].get_shared_y_axes().join(*axts[0::2]) except: print('axt problem') embed() axts[1].get_shared_y_axes().join(*axts[1::2]) axts[0].get_shared_x_axes().join(*axts) join_y(axps) join_x(axps) individual_tag = 'data ' + cell + '_DF_chosen_' + str( df_chosen) + mt_type save_visualization(add + individual_tag, show) print('plotted') file.close() def get_eod_fr_simple(b, names): if 'sinewave-1_EOD Rate' in names: eod_frs = b.data_arrays['sinewave-1_EOD Rate'][:] eod_redo = False else: eod_frs = b.metadata['Recording']['Subject']['EOD Frequency'] eod_redo = True return eod_frs, eod_redo def plt_response(ax, sampling_rate, spike_times, smooth_mean, spikes_mats, counter, stimulus, extract=True): ###################### # plt local eod xlim = (0, 200) ax[0, counter].set_ylabel('mV') ax[0, counter].set_title('local') time = np.arange(0, len(stimulus) / sampling_rate, 1 / sampling_rate) * 1000 if extract: eods_am, eod_norm = extract_am(stimulus, time, norm=False) ax[0, counter].plot(time, eod_norm) ax[0, counter].plot(time, eods_am, color='red') else: ax[0, counter].plot(time, stimulus, color='red') ax[0, counter].set_xlim(xlim) color = 'grey' ##################### # plt smpikes mat ax[1, counter].eventplot(spike_times, color=color) # s, np.ones(len(spike_times)), ax[1, counter].set_xlim(xlim) ax[1, counter].set_ylabel('Run nr') remove_xticks(ax[0, counter]) remove_xticks(ax[1, counter]) remove_xticks(ax[2, counter]) try: ax[2, counter].plot(np.arange(0, len(smooth_mean) / 40000, 1 / 40000) * 1000, smooth_mean, color=color) except: print('smooth problem') embed() ax[2, counter].set_xlim(xlim) ax[2, counter].set_ylabel('FR 
[Hz]') ps = [] for spikes_mat in spikes_mats: p, f = ml.psd(spikes_mat - np.mean(spikes_mat), Fs=sampling_rate, NFFT=2 ** 13, noverlap=2 ** 13 / 2) ps.append(p) ax[3, counter].plot(f, p, color=color) ax[3, counter].set_xlim(0, 1000) remove_xticks(ax[3, counter]) ax[3, counter].plot(f, np.mean(ps, axis=0), color='black') ax[4, counter].plot(f, np.mean(ps, axis=0), color='black') ax[4, counter].set_ylabel('[Hz]') ax[4, counter].set_xlim(0, 1000) ax[4, counter].set_xlabel('F [Hz]') if counter != 0: remove_yticks(ax[1, counter]) remove_yticks(ax[0, counter]) remove_yticks(ax[2, counter]) remove_yticks(ax[3, counter]) remove_yticks(ax[4, counter]) ax[1, counter].set_ylabel('') ax[0, counter].set_ylabel('') ax[2, counter].set_ylabel('') ax[3, counter].set_ylabel('') ax[4, counter].set_ylabel('') def plt_cocktailparty_lines(ax, frame_df): frame_df_mean = frame_df # .groupby(['c1']).mean()#, 'c2']) cs = {} means = {} scores_data = ['amp_f0_01_original', 'amp_f0_012_original', 'amp_f0_02_original', 'amp_f0_0_original', 'amp_B1_01_original', 'amp_B1_012_original', 'amp_B2_02_original', 'amp_B2_012_original', ] colors = ['green', 'purple', 'orange', 'black', 'green', 'blue', 'orange', 'red'] linestyles = ['--', '--', '--', '--', '-', '-', '-', '-'] for sss, score in enumerate(scores_data): ax[sss].plot(np.sort(frame_df_mean['c1']), frame_df_mean[score].iloc[np.argsort(frame_df_mean['c1'])], color=colors[sss], linestyle=linestyles[ sss]) # +str(np.round(np.mean(group_restricted[score_data]))), label = 'c_small='+str(c_small)+' c_big='+str(c_big) if sss not in means.keys(): means[sss] = [] cs[sss] = [] ax[sss].set_ylabel(score.replace('_mean', '').replace('amp_', '') + '[Hz]', fontsize=8) ax[sss].set_xlabel('Contrast small') ax[sss].set_xlabel('Contrast small') def get_dfs_and_contrasts_from_calccocktailparty(cell, frame_data): frame_data_cell = frame_data[(frame_data['cell'] == cell)] c1_unique = np.sort(frame_data_cell.c1.unique())[::-1] c1_unique_big = c1_unique[c1_unique > 
7] c2_unique = np.sort(frame_data_cell.c2.unique())[::-1] c2_unique_big = c2_unique[c2_unique > 7] DF1s = np.unique(np.round(frame_data_cell.m1, 2)) DF2s = np.unique(np.round(frame_data_cell.m2, 2)) return DF1s, DF2s, frame_data_cell, c2_unique, c2_unique_big, c1_unique_big, c1_unique def plt_stim_response_saturation(a, arrays_here, arrays_sp, arrays_st, arrays_time, axes, axps, axts, colors_array_here, f, f_counter, grid_ll, names, nfft, sampling, time, freqs=[50], colors_peaks=['green', 'red'], xlim=[1, 1.12]): grid_pt = gridspec.GridSpecFromSubplotSpec(3, 1, hspace=0.3, wspace=0.2, subplot_spec=grid_ll) # hspace=0.4,wspace=0.2,len(chirps) ############################# axe = plt.subplot(grid_pt[0]) axes.append(axe) plt_stim_saturation(a, arrays_sp[a][0], arrays_st, axe, colors_array_here, f, f_counter, names, time, xlim=xlim) ############################# axt = plt.subplot(grid_pt[1]) axts.append(axt) plt_vmem_saturation(a, arrays_sp, arrays_time, axt, colors_array_here, f, time, xlim=xlim) ############################# axp = plt.subplot(grid_pt[2]) axps.append(axp) pp, ff = ml.psd(arrays_here[a][0] - np.mean(arrays_here[a][0]), Fs=sampling, NFFT=nfft, noverlap=nfft // 2) pp = log_calc_psd('log', pp, np.max(pp)) plt_psd_saturation(pp, ff, a, axp, colors_array_here, freqs=freqs, colors_peaks=colors_peaks) return axp, axt def plt_stim_saturation(a, arrays_sp, arrays_st, axe, colors_array_here, f, f_counter, names, time, xlim=[1, 1.12]): if f != 0: remove_yticks(axe) if a != len(arrays_st) - 1: remove_xticks(axe) if f_counter == 0: axe.set_ylabel(names[a]) try: axe.plot(time, arrays_st[a], color=colors_array_here[a], linewidth=0.5) # colors_contrasts[c_nn] except: print('axe something') embed() axe.set_xlim(xlim) axe.show_spines('') spikes_in_vmem(arrays_sp, arrays_st[a], axe, type_here='stim') def plt_vmem_saturation(a, arrays_sp, arrays_time, axt, colors_array_here, f, time, xlim=[1, 1.12]): if f != 0: remove_yticks(axt) if a != len(arrays_time) - 1: 
remove_xticks(axt) try: axt.plot(time[(time < xlim[1]) & (time > xlim[0])], arrays_time[a][(time < xlim[1]) & (time > xlim[0])], color=colors_array_here[a], clip_on=False) # colors_contrasts[c_nn] except: axt.plot(time[(time < xlim[1]) & (time > xlim[0])], arrays_time[a][0][(time < xlim[1]) & (time > xlim[0])], color=colors_array_here[a], clip_on=False) # colors_contrasts[c_nn] axt.set_xlim(xlim) axt.show_spines('') spikes_in_vmem(arrays_sp[a][0], arrays_time[a], axt, type_here='vmem') def plt_psd_saturation(pp, ff, a, axp, colors_array_here, freqs=[50, 50], colors_peaks=['blue', 'red'], xlim=(0, 300), markeredgecolor=[], labels=['DF1', 'DF2', 'DF1', 'DF2', 'DF1', 'DF2', 'DF1', 'DF2']): axp.plot(ff[ff < xlim[1]], pp[ff < xlim[1]], color=colors_array_here[a]) axp.set_xlim(xlim) plt_peaks_several(freqs, [pp], axp, pp, ff, labels, 0, colors_peaks, markeredgecolors=markeredgecolor) def vary_contrasts(freqs=[(39.5, -210.5)], printing=False, beat='', nfft_for_morph=4096 * 4, gain=1, cells_here=["2013-01-08-aa-invivo-1"], fish_jammer='Alepto', us_name='', show=True): runs = 1 n = 1 dev = 0.0005 ############################################# # plot a single ROC Curve for the model! 
# das aus dem Lissabon talk und das was wir für Jörg verwenden werden # also wir wollen hier viele Kontraste und einige Frequenzen # das will ich noch für verschiedene Frequenzen und Kontraste default_settings() # ts=13, ls=13, fs=13, lw = 0.7 reshuffled = 'reshuffled' # , # standard combination with intruder small a_f2s = [0.1] a_f1s = [0.03] # np.logspace(np.log10(0.0001), np.log10(1), 25) min_amps = '_minamps_' dev_name = ['05'] model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv") if len(cells_here) < 1: cells_here = np.array(model_cells.cell) a_fr = 1 a = 0 trials_nrs = [1] datapoints = 1000 stimulus_length = 2 results_diff = pd.DataFrame() position_diff = 0 plot_style() default_settings(column=2, length=8.5) for trials_nr in trials_nrs: # +[trials_nrs[-1]] # sachen die ich variieren will ########################################### auci_wo = [] auci_w = [] nfft = 32768 for cell_here in cells_here: full_names = [ 'calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_25_FirstC1_0.0001_LastC1_1.0_StimLen_' + str( stimulus_length) + '_nfft_' + str(nfft) + '_trialsnr_1_absolut_power_1_minamps__dev_05temporal'] c_grouped = ['c1'] # , 'c2'] frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv') frame_cell_orig = frame[(frame.cell == cell_here)] if len(frame_cell_orig) > 0: try: pass except: print('min thing') embed() get_frame_cell_params(c_grouped, cell_here, frame, frame_cell_orig) c_nrs = [0.0002, 0.05, 0.5] grid0 = gridspec.GridSpec(1, 1, bottom=0.05, top=0.92, left=0.09, right=0.95, wspace=0.04) # grid00 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.04, hspace=0.27, height_ratios=[1, 2.5], subplot_spec=grid0[0]) # grid_u = gridspec.GridSpecFromSubplotSpec(1, len(freqs), hspace=0.7, wspace=0.1, subplot_spec=grid00[0]) # hspace=0.4,wspace=0.2,len(chirps) grid_l = gridspec.GridSpecFromSubplotSpec(1, len(freqs), hspace=0.7, wspace=0.1, subplot_spec=grid00[1]) # 
hspace=0.4,wspace=0.2,len(chirps) ################################################################# # wo wir den einezlnen Punkt und Kontraste variieren f_counter = 0 frame_cell_orig, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig) eodf = frame_cell_orig.f0.unique()[0] f = -1 axts_all = [] axps_all = [] ax_us = [] for freq1, freq2 in freqs: f += 1 grid_ll = gridspec.GridSpecFromSubplotSpec(3, len(c_nrs), hspace=0.2, wspace=0.2, subplot_spec=grid_l[ f]) # hspace=0.4,wspace=0.2,len(chirps) frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)] if len(frame_cell) < 1: freq1 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df1 - freq1)))].df1 freq2 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df2 - freq2)))].df2 frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)] scores = ['amp_B1_01_mean', 'amp_B1_012_mean', 'amp_B2_02_mean', 'amp_B2_012_mean'] # 'amp_B1+B2_012_mean', colors = ['green', 'blue', 'orange', 'red', 'grey'] colors_array = ['grey', 'green', 'orange', 'purple'] linestyles = ['-', '--', '-', '--', '--'] alpha = [1, 1, 1, 1, 1] print(cell_here + ' F1' + str(freq1) + ' F2 ' + str(freq2)) sampling = 20000 ax_u1 = plt.subplot(grid_u[f, 0]) ax_us = plt_single_trace(ax_us, ax_u1, frame_cell_orig, freq1, freq2, scores=scores, colors=colors, linestyles=linestyles, alpha=alpha, sum=False, B_replace='F') if f != 0: print('hi') else: ax_u1.set_ylabel('Hz') plt.suptitle(cell_here + ' DF1=' + str(freq1) + ' DF2=' + str(freq2)) axts = [] axps = [] c_nrs = c_dist_recalc_func(frame_cell, c_nrs=c_nrs, cell=cell_here) for c_nn, c_nr in enumerate(c_nrs): ax_u1.scatter(c_nrs, np.zeros(len(c_nrs)), color='black', marker='^', clip_on=False) v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays, names, p_arrays, ff = calc_roc_amp_core_cocktail( [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff, a_f2s, fish_jammer, 
trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing, stimulus_length, model_cells, position_diff, dev, cell_here, dev_name=dev_name, a_f1s=[c_nr], n=n, reshuffled=reshuffled, min_amps=min_amps) time = np.arange(0, len(arrays[a][0]) / sampling, 1 / sampling) arrays_time = arrays[1::] # v_mems[1::] arrays_here = arrays[1::] colors_array_here = colors_array[1::] for a in range(len(arrays_here)): grid_pt = gridspec.GridSpecFromSubplotSpec(2, 1, hspace=0.3, wspace=0.2, subplot_spec=grid_ll[ a, c_nn]) # hspace=0.4,wspace=0.2,len(chirps) axt = plt.subplot(grid_pt[0]) axts.append(axt) if f != 0: remove_yticks(axt) if a != len(arrays_time) - 1: remove_xticks(axt) if f_counter == 0: axt.set_ylabel(names[a]) if a == 0: axt.set_title(' c1=' + str(a_f1s[0]) + ' c2=' + str(a_f2s[0])) axt.plot(time, arrays_time[a][0], color=colors_array_here[a]) # colors_contrasts[c_nn] axt.set_xlim(1, 1.12) ############################# axp = plt.subplot(grid_pt[1]) axps.append(axp) pp, ff = ml.psd(arrays_here[a][0] - np.mean(arrays_here[a][0]), Fs=sampling, NFFT=nfft, noverlap=nfft // 2) axp.plot(ff, pp, color=colors_array_here[a]) # colors_contrasts[c_nn] axp.set_xlim(0, 300) if a != 2: colors_peaks = [colors_array[1], colors_array[2]] else: colors_peaks = ['blue', 'red'] plt_peaks_several([freq1, np.abs(freq2)], [pp], axp, pp, ff, ['DF1', 'DF2'], 0, colors_peaks) if a != 2: remove_xticks(axp) if c_nn != 0: remove_yticks(axt) remove_yticks(axp) axt.set_xlabel('Time [s]') axp.set_xlabel('Frequency [Hz]') f_counter += 1 axts_all.extend(axts) axps_all.extend(axps) ax_us[0].legend(loc=(-0.07, 1), ncol=6) axts_all[0].get_shared_y_axes().join(*axts_all) axts_all[0].get_shared_x_axes().join(*axts_all) axps_all[0].get_shared_y_axes().join(*axps_all) axps_all[0].get_shared_x_axes().join(*axps_all) join_x(ax_us) join_y(ax_us) save_visualization(cell_here, show) def spikes_in_vmem(arrays_sp, arrays_time, axt, type_here='vmem'): if type_here == 'vmem': axt.eventplot(arrays_sp, 
lineoffsets=np.max(arrays_time)) # * np.ones(len(arrays_sp))) else: try: axt.eventplot(arrays_sp, lineoffsets=np.mean(arrays_time)) # * np.ones(len(arrays_sp))) except: print('axt something') embed() def get_frame_cell_params(c_grouped, cell_here, frame, frame_cell_orig): new_f2_tuple = frame_cell_orig[['df2', 'f2']].apply(tuple, 1).unique() dfs = [tup[0] for tup in new_f2_tuple] sorted = np.argsort(np.abs(dfs)) new_f2_tuple = new_f2_tuple[sorted] f2s = [tup[1] for tup in new_f2_tuple] f2s = np.sort(f2s) frame_cell = frame[(frame.cell == cell_here)] # & (frame[c_here] == c_h)] frame_cell, df1s, df2s, f1s, f2s = find_dfs(frame_cell) def find_mt_type(mts): if 'chirp' in mts.name: mt_type = 'chirp' elif 'SAM' in mts.name: mt_type = 'SAM' elif 'sine' in mts.name: mt_type = 'sine' elif find_gwn(mts): mt_type = 'stim' elif 'three' in mts.name: mt_type = 'three' return mt_type def choice_specific_indices(contrasts, negativ='negativ', units=2 * 7, cut_val=2): next_step = int(np.round(len(contrasts) / units)) if next_step == 0: next_step = 1 if negativ == 'negativ': indeces_show = np.argsort(contrasts)[0:int(len(contrasts) / cut_val)][0::next_step][::-1] contrasts_show = np.sort(contrasts)[0:int(len(contrasts) / cut_val)][0::next_step][::-1] elif negativ == 'positiv': try: indeces_show = np.argsort(contrasts)[::-1][0:int(len(contrasts) / cut_val)][ 0::next_step][::-1] except: print('positiv something') embed() contrasts_show = np.sort(contrasts)[::-1][0:int(len(contrasts) / cut_val)][ 0::next_step][::-1] elif negativ == 'highest': indeces_show = np.argsort(contrasts)[::-1][0:int(units / cut_val)][::-1] contrasts_show = np.sort(contrasts)[::-1][0:int(units / cut_val)][::-1] return contrasts_show, indeces_show def spike_times_cocktailparty(b, delay, mt, mt_nr, load_eod_array='LocalEOD-1'): timepoint = time.time() try: eod_mt, spikes_mt = load_eod_for_three(b, delay, mt, mt_nr, load_eod_array=load_eod_array) except: print('problem') embed() time_eod = np.arange(0, len(eod_mt) 
/ 40000, 1 / 40000) - delay time_laod_eods = time.time() - timepoint return eod_mt, spikes_mt, time_eod, time_laod_eods, timepoint def load_eod_for_three(b, delay, mt, mt_nr, load_eod_array='LocalEOD-1'): eod_mt, spikes_mt, sampling = link_arrays(b, first=mt.positions[:][mt_nr] - delay, second=mt.extents[:][mt_nr] + delay, minus_spikes=mt.positions[:][mt_nr], load_eod_array=load_eod_array) return eod_mt, spikes_mt def diagonal_points(): # global combis combis = {'off1': (0.5, 0.67), 'test_data_cell_2022-01-05-aa-invivo-1': (0.27, 1.27,), 'B1-B2_diagonal': (0.27, 1.27,), 'diagonal1': (1 / 3, 2 / 3), 'B1+B2_diagonal': (1 / 4, 3 / 4), 'B1+B2_diagonal2': (0.27, 0.73), 'B1+B2_diagonal3': (0.3, 0.7), 'B1-B2_diagonal3': (0.3, 1.3), 'B1+B2_diagonal31': (0.31, 0.69), 'B1+B2_diagonal32': (0.32, 0.68), 'B1+B2_diagonal33': (0.33, 0.67), 'B1+B2_diagonal_plus_0.2c1': ((1 / 3) + 0.2, 2 / 3), 'Half_Fr_c1': (0.5, 0.3), 'Half_Fr_c2': (0.3, 0.5), 'diagonal2': (2 / 3, 1 / 3,), 'diagonal3': (0.1, 0.9), 'vertical1': (1, 0.7), 'vertical4': (0.8, 0.6), 'vertical5': (0.8, 0.55), 'vertical2': (1, 1.05), 'vertical3': (1 + (1.167 - 1.1644) / 0.1644, 1 + (1.18 - 1.1644) / 0.1644), 'vertical6': (0.4, 1), 'vertical6': (0.4, 1), 'horizontal': (0.8, 1), 'inside': (1 / 2, 2 / 3), 'outside': (1.2, 2 / 3) } return combis def plt_ROC_model_w_female_square_nonlin(frame_names=[], female='wo_female', reshuffled='reshuffled', datapoints=1000, dev=0.0005, a_f1s=[0.03], pdf=True, printing=False, plus_q='minus', freq1_ratio=1 / 2, diagonal='diagonal', freq2_ratio=2 / 3, way='absolut', stimulus_length=0.5, runs=3, trials_nr=500, cells=[], show=False, nfft=int(2 ** 15), beat='', nfft_for_morph=4096 * 4, gain=1, talk=True, fish_jammer='Alepto', us_name=''): if talk: plt.rcParams['lines.linewidth'] = 1 try: model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv") except: embed() print('still some model something') if len(cells) < 1: cells = len(model_cells) for cell_here in 
cells: # sachen die ich variieren will ########################################### single_waves = ['_SeveralWave_'] # , '_SingleWave_'] ####### VARY HERE for single_wave in single_waves: if single_wave == '_SingleWave_': a_f2s = [0] # , 0,0.2 else: a_f2s = [0.1] for a_f2 in a_f2s: for a_f1 in a_f1s: a_frs = [1] titles_amp = ['base eodf'] # ,'baseline to Zero',] for a, a_fr in enumerate(a_frs): model_params = model_cells[model_cells['cell'] == cell_here].iloc[0] eod_fr = model_params['EODf'] # .iloc[0] offset = model_params.pop('v_offset') cell = model_params.pop('cell') print(cell) SAM, adapt_offset, cell_recording, constant_reduction, damping, damping_type, dent_tau_change, exponential, f1, f2, fish_emitter, fish_receiver, fish_morph_harmonics_var, lower_tol, mimick, n, phase_right, phaseshift_fr, sampling_factor, upper_tol, zeros = default_model0() # in case you want a different sampling here we can adujust time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr, stimulus_length) eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr, stimulus_length, phaseshift_fr, cell_recording, zeros, mimick, sampling, fish_receiver, deltat, nfft, nfft_for_morph, fish_morph_harmonics_var=fish_morph_harmonics_var, beat=beat) sampling = 1 / deltat variant = 'sinz' # prepare for adapting offset due to baseline modification _, _ = prepare_baseline_array(time_array, eod_fr, nfft_for_morph, phaseshift_fr, mimick, zeros, cell_recording, sampling, stimulus_length, fish_receiver, deltat, nfft, damping_type, damping, us_name, gain, beat=beat, fish_morph_harmonics_var=fish_morph_harmonics_var) spikes_base = [[]] * trials_nr colors_w, colors_wo, color0, color01, color02, color012 = colors_cocktailparty_all() default_figsize(width=cm_to_inch(29.21), length=cm_to_inch(12.43)) default_ticks_talks() fig = plt.figure() grid = gridspec.GridSpec(1, 2, wspace=0.35, width_ratios=[0.8, 1.6, ], hspace=0.5, left=0.08, top=0.95, bottom=0.12, 
right=0.96) # , width_ratios = [1,1,1,0.5,1] height_ratios = [1,6]bottom=0.25, top=0.8, grid0 = gridspec.GridSpecFromSubplotSpec(5, 2, wspace=0.18, hspace=0.12, subplot_spec=grid[1], height_ratios=[1, 0.6, 1, 1, 1.25]) # ,0.4,1.2 for run in range(runs): print(run) t1 = time.time() for t in range(trials_nr): stimulus = eod_fish_r stimulus_base = eod_fish_r if 'Zero' in titles_amp[a]: power_here = 'sinz' + '_' + zeros else: power_here = 'sinz' cvs, adapt_output, baseline_after_b, _, rate_adapted_b, rate_baseline_before_b, rate_baseline_after_b, \ spikes_base[t], _, _, offset_new, _, noise_final = simulate(cell, offset, stimulus, deltat=deltat, adaptation_variant=adapt_offset, adaptation_yes_j=f2, adaptation_yes_e=f1, adaptation_yes_t=t, adaptation_upper_tol=upper_tol, adaptation_lower_tol=lower_tol, power_variant=power_here, power_alpha=alpha, power_nr=n, reshuffle=reshuffled, **model_params) if t == 0: # here we record the changes in the offset due to the adaptation # and we subsequently reset the offset to be the new adapted for all subsequent trials offset = offset_new * 1 if printing: print('Baseline time' + str(time.time() - t1)) base_cut, mat_base = find_base_fr(spikes_base, deltat, stimulus_length, time_array, dev=dev) fr = np.mean(base_cut) if 'diagonal' in diagonal: two_third_fr = fr * freq2_ratio freq1_ratio = (1 - freq2_ratio) third_fr = fr * freq1_ratio else: two_third_fr = fr * freq2_ratio third_fr = fr * freq1_ratio if plus_q == 'minus': two_third_fr = -two_third_fr third_fr = -third_fr freqs2 = [eod_fr + two_third_fr] # , eod_fr - third_fr, two_third_fr, freqs1 = [ eod_fr + third_fr] # , eod_fr - two_third_fr, third_fr,two_third_fr,third_eodf, eod_fr - third_eodf,two_third_eodf, eod_fr - two_third_eodf, ] sampling_rate = 1 / deltat base_cut, mat_base, smoothed0, mat0 = find_base_fr2(spikes_base, deltat, stimulus_length, dev=dev) fr = np.mean(base_cut) _, _ = ISI_frequency(time_array, spikes_base[0], fill=0.0) isi = np.diff(spikes_base[0]) cv0 = 
np.std(isi) / np.mean(isi) for ff, freq1 in enumerate(freqs1): freq1 = [freq1] freq2 = [freqs2[ff]] t1 = time.time() phaseshift_f1, phaseshift_f2 = get_phaseshifts(a_f1, a_f2, phase_right, phaseshift_fr) eod_fish1, time_fish_e = eod_fish_e_generation(time_array, a_f1, freq1, f1, phaseshift_f1, sampling, stimulus_length, nfft_for_morph, cell_recording, fish_morph_harmonics_var, zeros, mimick, fish_emitter, thistype='emitter') eod_fish2, time_fish_j = eod_fish_e_generation(time_array, a_f2, freq2, f2, phaseshift_f2, sampling, stimulus_length, nfft_for_morph, cell_recording, fish_morph_harmonics_var, zeros, mimick, fish_jammer, thistype='jammer') eod_stimulus = eod_fish1 + eod_fish2 v_mems, offset_new, mat01, mat02, mat012, smoothed01, smoothed02, smoothed012, stimulus_01, stimulus_02, stimulus_012, mat05_01, spikes_01, mat05_02, spikes_02, mat05_012, spikes_012 = get_arrays_for_three( cell, a_f2, a_f1, SAM, eod_stimulus, eod_fish_r, freq2, eod_fish1, eod_fish2, stimulus_length, offset, model_params, n, variant, adapt_offset, deltat, f2, trials_nr, time_array, f1, freq1, eod_fr, reshuffle=reshuffled, dev=dev) if printing: print('Generation process' + str(time.time() - t1)) ################################## # power spectrum array0 = [mat_base] array01 = [mat05_01] array02 = [mat05_02] array012 = [mat05_012] t_off = 10 position_diff = 0 results_diff = pd.DataFrame() results_diff['f1'] = freq1 results_diff['f2'] = freq2 results_diff['f0'] = eod_fr trials, results_diff, tp_012_all, tp_01_all, tp_02_all, fp_all, roc_01, roc_0, roc_02, roc_012, threshhold = calc_auci_pd( results_diff, position_diff, array012, array01, array02, array0, t_off=t_off, way=way, printing=True, datapoints=datapoints, f0='f0', sampling=sampling) if run == 0: pass else: pass grid1 = gridspec.GridSpecFromSubplotSpec(2, 1, hspace=0.4, subplot_spec=grid[0]) ax_ROC = plt.subplot(grid1[0]) ax_nonlin = plt.subplot(grid1[1]) colors_wo = ['orange', 'orange', 'orange'] colors_w = ['green', 'green', 'green'] 
xlim = core_xlim_dist_roc() plt_ROC_nonlin(xlim, frame_names, ax_ROC, ax_nonlin, cells, colors_wo, colors_w) ax_nonlin.set_xlabel(core_distance_label()) ax_ROC.set_xlabel(core_distance_label()) if run == 0: plt_traces_to_roc(freq2_ratio, freq1_ratio, t_off, spikes_02, spikes_01, spikes_012, spikes_base, mat_base, mat05_01, mat05_012, mat05_02, color02, color012, a_f2, trials, sampling, a_f1, fr, female, color01, color0, grid0, eod_fr, freq2, freq1, sampling_rate, stimulus_012, stimulus_02, stimulus_01, stimulus_base, time_array - time_array[0], vlin=False, carrier=True) axs = plt_power_spectrum(grid0, color01, color02, color012, color0, fr, results_diff, female, nfft, smoothed012, smoothed01, smoothed02, smoothed0, sampling_rate, mult_val=0.15, add_to=195, wierd_charing=False) _, _ = plt.gca().get_legend_handles_labels() remove_yticks(axs[1]) # ax[6 + 1] join_y(axs) ax = fig.axes ax = ax[1::] ax[4 + 1].set_ylabel('Firing Rate [Hz]') ax[4 + 1].set_xlabel('Time [ms]') ax[5 + 1].set_xlabel('Time [ms]') ax[6 + 1].set_xlabel('Frequency [Hz]') ax[7 + 1].set_xlabel('Frequency [Hz]') ax[6 + 1].set_ylabel('Power [Hz]') for aa, ax_here in enumerate(ax[2:5]): ax_here.set_xticks([]) for aa, ax_here in enumerate(ax[1::]): if aa not in np.arange(0, 2, 2): pass else: ax_here.get_shared_y_axes().join(*ax[1 + aa:1 + aa + 2]) individual_tag = '_way_' + str(way) + '_runs_' + str(runs) + '_trial_nr_' + str( trials_nr) + '_stimulus_length_' + str( stimulus_length) + cell + ' cv ' + str(cv0) + single_wave + '_a_f0_' + str( a_fr) + '_a_f1_' + str(a_f1) + '_a_f2_' + str(a_f2) + '_trialsnr_' + str(trials_nr) fig = plt.gcf() fig.tag([fig.axes[0], fig.axes[2], fig.axes[3], fig.axes[1]], xoffs=-3.5) save_visualization(individual_tag, show, pdf=pdf, counter_contrast=0, savename='') def default_model0(): f1 = 0 f2 = 0 sampling_factor = '' phaseshift_fr = 0 cell_recording = '' mimick = 'no' zeros = 'zeros' fish_morph_harmonics_var = 'harmonic' fish_emitter = 'Alepto' # ['Sternarchella', 
'Sternopygus'] fish_receiver = 'Alepto' # phase_right = '_phaseright_' adapt_offset = 'adaptoffsetallall2' constant_reduction = '' n = 1 lower_tol = 0.995 upper_tol = 1.005 SAM = '' # , damping = 0.45 # 0.65,0.2,0.5,0.2,0.6,0.45,0.6,0.35 damping_type = '' exponential = '' dent_tau_change = 1 return SAM, adapt_offset, cell_recording, constant_reduction, damping, damping_type, dent_tau_change, exponential, f1, f2, fish_emitter, fish_receiver, fish_morph_harmonics_var, lower_tol, mimick, n, phase_right, phaseshift_fr, sampling_factor, upper_tol, zeros def plt_ROC_nonlin(xlim, frame_names, ax_ROC, ax_nonlin, cells, colors_wo, colors_w): for c, cell in enumerate(cells): for f, frame_name in enumerate(frame_names): path = load_folder_name('calc_ROC') + '/' + frame_name + '.csv' if os.path.exists(path): frame = pd.read_csv(path) path_ref = load_folder_name( 'calc_ROC') + '/' + 'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_C1_0.02_len_25_nfft_32768_trialsnr_20_mult_minimum_1temporal.csv' frame_ref = pd.read_csv(path_ref) frame_ref = frame_ref.sort_values(by='cv_0') _, _ = find_row_col(cells, row=4) frame_cell = frame[frame.cell == cell] label = ['with female', 'CLS: 100n', 'LS: 1000n'] label2 = ['without female', 'CLS: 100n', 'LS: 1000n'] for ax in [ax_ROC]: if len(frame_cell) > 0: plt_area_between(frame_cell, ax, ax, colors_wo, colors_w, f, labels_without_female=label2[f], labels_with_female=label[f]) ax.set_xlim(xlim) ax.set_ylim(0, 0.5) ax.legend(loc=(0.5, 0.8)) plt_nonlin_line(ax_nonlin, cell, 0, frame_cell, xlim) def plt_traces_to_roc(freq2_ratio, freq1_ratio, t_off, spikes_02, spikes_01, spikes_012, spikes_base, mat_base, mat05_01, mat05_012, mat05_02, color02, color012, a_f2, trials, sampling, a_f1, fr, female, color01, color0, grid0, eod_fr, freq2, freq1, sampling_rate, stimulus_012, stimulus_02, stimulus_01, stimulus_base, time_array, carrier=False, spike_events=True, vlin=True, short_title=True): beat2 = freq2 - eod_fr beat1 = 
freq1 - eod_fr ############################################# eod_interp_base, _, = extract_am(stimulus_base, time_array, sampling=sampling_rate, eodf=eod_fr, emb=False, extract='', norm=False) if len(np.shape(stimulus_01)) > 1: stimulus_01_here = stimulus_01[0] else: stimulus_01_here = stimulus_01 # [0] eod_interp_01, eod_norm = extract_am(stimulus_01_here, time_array, sampling=sampling_rate, eodf=eod_fr, emb=False, extract='', norm=False) if len(np.shape(stimulus_02)) > 1: stimulus_02_here = stimulus_02[0] else: stimulus_02_here = stimulus_02 # [0] eod_interp_02, eod_norm = extract_am(stimulus_02_here, time_array, sampling=sampling_rate, eodf=eod_fr, emb=False, extract='', norm=False) if len(np.shape(stimulus_012)) > 1: stimulus_012_here = stimulus_012[0] else: stimulus_012_here = stimulus_012 # [0] eod_interp, eod_norm = extract_am(stimulus_012_here, time_array, sampling=sampling_rate, eodf=eod_fr, emb=False, extract='', norm=False) start = 0 # 0.2 time_array = time_array - start # lim_shift xlim = (0, 0.102 * 1000) counter = 0 ax = plt_stimulus_ROC(eod_interp_base, stimulus_base, a_f1, a_f2, beat1, beat2, carrier, color0, color01, color012, color02, eod_interp, eod_interp_01, eod_interp_02, female, fr, freq1_ratio, freq2_ratio, grid0, short_title, stimulus_01, stimulus_012, stimulus_02, time_array, xlim, counter=counter) counter += 1 ############################################# # spikes_012 if spike_events: plt_eventplot_ROC(ax, female, grid0, spikes_01, spikes_012, spikes_02, spikes_base, xlim, counter=counter) counter += 1 ############################################# # smoothed plt_firingrate_ROC(female, grid0, mat05_01, mat05_012, mat05_02, mat_base, sampling, t_off, time_array, trials, vlin, xlim, counter=counter) counter += 1 def plt_firingrate_ROC(female, grid0, mat05_01, mat05_012, mat05_02, mat_base, sampling, t_off, time_array, trials, vlin, xlim, counter=2): color_mat = 'black' if 'wo_female' in female: ax = plt.subplot(grid0[counter, 0]) 
ax.set_xlim(xlim) plt.plot(time_array * 1000, mat_base, color=color_mat) if vlin: plt.vlines((trials / sampling) * 1000, ymin=0, ymax=750, color='grey', linestyle='--', linewidth=0.5) plt.axvline([0], color='grey', linestyle='--', linewidth=0.5) plt.vlines((trials - t_off / sampling) * 1000, ymin=0, ymax=750, color='grey', linestyle='--', linewidth=0.5) ax = plt.subplot(grid0[counter, 1], sharex=ax) remove_yticks(ax) plt.plot(time_array * 1000, mat05_01, color=color_mat) if vlin: plt.vlines((trials / sampling) * 1000, ymin=0, ymax=750, color='grey', linestyle='--', linewidth=0.5) plt.vlines((trials - t_off / sampling) * 1000, ymin=0, ymax=750, color='grey', linestyle='--', linewidth=0.5) plt.axvline([0], color='grey', linestyle='--', linewidth=0.5) elif 'base_female' in female: ax = plt.subplot(grid0[2, 0]) ax.set_xlim(xlim) plt.plot(time_array * 1000, mat_base, color=color_mat) if vlin: plt.vlines((trials / sampling) * 1000, ymin=0, ymax=750, color='grey', linestyle='--', linewidth=0.5) plt.axvline([0], color='grey', linestyle='--', linewidth=0.5) plt.vlines((trials - t_off / sampling) * 1000, ymin=0, ymax=750, color='grey', linestyle='--', linewidth=0.5) ax = plt.subplot(grid0[counter, 1], sharex=ax) remove_yticks(ax) plt.plot(time_array * 1000, mat05_02, color=color_mat) if vlin: plt.vlines((trials / sampling) * 1000, ymin=0, ymax=750, color='grey', linestyle='--', linewidth=0.5) plt.axvline([0], color='grey', linestyle='--', linewidth=0.5) plt.vlines((trials - t_off / sampling) * 1000, ymin=0, ymax=750, color='grey', linestyle='--', linewidth=0.5) else: ax = plt.subplot(grid0[counter, 0]) ax.set_xlim(xlim) plt.plot(time_array * 1000, mat05_02, color=color_mat) if vlin: plt.vlines((trials / sampling) * 1000, ymin=0, ymax=750, color='grey', linestyle='--', linewidth=0.5) plt.axvline([0], color='grey', linestyle='--', linewidth=0.5) plt.vlines((trials - t_off / sampling) * 1000, ymin=0, ymax=750, color='grey', linestyle='--', linewidth=0.5) ax = 
plt.subplot(grid0[counter, 1], sharex=ax) remove_yticks(ax) plt.plot(time_array * 1000, mat05_012, color=color_mat) if vlin: plt.vlines((trials / sampling) * 1000, ymin=0, ymax=750, color='grey', linestyle='--', linewidth=0.5) plt.vlines((trials - t_off / sampling) * 1000, ymin=0, ymax=750, color='grey', linestyle='--', linewidth=0.5) plt.axvline([0], color='grey', linestyle='--', linewidth=0.5) def plt_eventplot_ROC(ax, female, grid0, spikes_01, spikes_012, spikes_02, spikes_base, xlim, counter=1): if 'wo_female' in female: ax = plt.subplot(grid0[counter, 0], sharex=ax) ax.set_xlim(xlim) ax.spines['bottom'].set_visible(False) plt.eventplot(np.array(spikes_base) * 1000, color='black') ax = plt.subplot(grid0[counter, 1], sharex=ax) remove_yticks(ax) ax.spines['bottom'].set_visible(False) plt.eventplot(np.array(spikes_01) * 1000, color='black') elif 'base_female' in female: ax = plt.subplot(grid0[counter, 0], sharex=ax) ax.set_xlim(xlim) ax.spines['bottom'].set_visible(False) plt.eventplot(np.array(spikes_base) * 1000, color='black') ax = plt.subplot(grid0[counter, 1], sharex=ax) remove_yticks(ax) ax.spines['bottom'].set_visible(False) plt.eventplot(np.array(spikes_02) * 1000, color='black') else: ax = plt.subplot(grid0[counter, 0]) ax.spines['bottom'].set_visible(False) ax.set_xlim(xlim) plt.eventplot(np.array(spikes_02) * 1000, color='black') ax = plt.subplot(grid0[counter, 1], sharex=ax) remove_yticks(ax) ax.spines['bottom'].set_visible(False) plt.eventplot(np.array(spikes_012) * 1000, color='black') def plt_stimulus_ROC(eod_interp_base, stimulus_base, a_f1, a_f2, beat1, beat2, carrier, color0, color01, color012, color02, eod_interp, eod_interp_01, eod_interp_02, female, fr, freq1_ratio, freq2_ratio, grid0, short_title, stimulus_01, stimulus_012, stimulus_02, time_array, xlim, counter=0): if 'wo_female' in female: ax = plt.subplot(grid0[counter, 0]) plt_base(ax, xlim, time_array, eod_interp_base, color0, stimulus_base, carrier) if short_title: 
plt.title('Baseline', color=color0) else: plt.title('Base: 0 \n $fr=$' + str(np.round(fr)) + 'Hz', color=color0) ax = plt.subplot(grid0[counter, 1], sharex=ax) plt_base(ax, xlim, time_array, eod_interp_01, color01, stimulus_01, carrier) remove_yticks(ax) if short_title: plt.title('Intruder', color=color01) else: plt.title('Intruder: 01 \n $f=$' + str(np.round(beat1[0])) + 'Hz ' + ' $c_{1}=$' + str( a_f1 * 100) + '$\%$' + '\n' + r' $\frac{f}{fr}=$' + str(np.round(freq1_ratio, 2)), color=color01) elif 'base_female' in female: ax = plt.subplot(grid0[counter, 0]) # plt_base(ax, xlim, time_array, eod_interp_base, color0, stimulus_base, carrier) if short_title: ax.set_title('Baseline', color=color0) else: ax.set_title('Base: 0 \n $fr=$' + str(np.round(fr)) + 'Hz', color=color0) ax = plt.subplot(grid0[counter, 1], sharex=ax) remove_yticks(ax) plt_base(ax, xlim, time_array, eod_interp_02, color02, stimulus_02, carrier) if short_title: ax.set_title('Female', color=color02) else: ax.set_title('Female: 02 \n $f=$' + str(np.round(beat2[0])) + ' Hz' + ' $c_{2}$ ' + str( a_f2 * 100) + '$\%$ ' + '\n' + r'$\frac{f}{fr}={len(folder)}$' + str(np.round(freq2_ratio, 2)), color=color02) else: ax = plt.subplot(grid0[counter, 0]) plt_base(ax, xlim, time_array, eod_interp_02, color02, stimulus_02, carrier) if short_title: ax.set_title('Female', color=color02) else: ax.set_title('Female: 02 \n $f=$' + str(np.round(beat2[0])) + ' Hz' + ' $c_{2}$ ' + str( a_f2 * 100) + '$\%$ ' + '\n' + r'$\frac{f}{fr}={len(folder)}$' + str(np.round(freq2_ratio, 2)), color=color02) # eod interp ax = plt.subplot(grid0[counter, 1], sharex=ax) plt_base(ax, xlim, time_array, eod_interp, color012, stimulus_012, carrier) remove_yticks(ax) if short_title: ax.set_title('Female + Intruder', color=color012) else: ax.set_title('Fem. 
+ Int.: 012 \n $f=$' + str(np.round(beat1[0] + beat2[0])) + ' Hz', color=color012) return ax def plt_base(ax, xlim, time_array, eod_interp_base, color0, stimulus_base, carrier): ax.set_xlim(xlim) ax.plot(time_array * 1000, eod_interp_base, color=color0) if carrier: if len(np.shape(stimulus_base)) > 1: stimulus_base_here = stimulus_base[0] else: stimulus_base_here = stimulus_base ax.plot(time_array * 1000, stimulus_base_here, color='grey', linewidth=0.5) ax.set_ylim(-1.15, 1.15) else: ax.set_ylim([0.85, 1.15]) ax.spines['bottom'].set_visible(False) def plt_power_spectrum2(grid0, color01, color02, color012, color0, fr, results_diff, female, nfft, smoothed012, smoothed01, smoothed02, smoothed0, sampling_rate, counter=4, add_to=70, mult_val=0.125, wierd_charing=True, log = ''): p0, p02, p01, p012, fs = calc_ps(nfft, smoothed012, smoothed01, smoothed02, smoothed0, sampling_rate=sampling_rate, log = log, xlim = xlim_ROC_talk2()) DF1 = np.abs(results_diff.f1.iloc[-1] - results_diff.f0.iloc[-1]) DF2 = np.abs(results_diff.f2.iloc[-1] - results_diff.f0.iloc[-1]) if 'wo_female' in female: p_arrays = [p0, p01] else: four = False if four: p_arrays = [p02, p012] freqs_all = [[np.abs(DF2), np.abs(DF2) * 2], [np.abs(DF2), np.abs(DF2) * 2, np.abs(DF1), np.abs(DF1) + np.abs(DF2), (np.abs(DF1) + np.abs(DF2)) * 2, fr, fr * 2, (np.abs(DF1) + np.abs(DF2) * 2), ]] # np.abs(np.abs(DF1) - np.abs(DF2)), color0122 = color_sumpeak() colors_all = [[color02, color02], [color02, color02, color01, color012, color012, color0, color0, color0122, ]] # color01_2, labels_all = [['DF2', 'DF2 H1'], [r'$\Delta \mathrm{f_{Female}}$', '', r'$\Delta \mathrm{f_{Intruder}}$', sum_intruder_core(), '', r'$\mathrm{f'+basename()+'}$', '', r'$2|\Delta \mathrm{f_{Female}}|+|\Delta \mathrm{f_{Intruder}}|$']] # '$|Intruder-Female|$', else: p_arrays = [p02, p012] freqs_all = [[np.abs(DF2), np.abs(DF2) * 2], [np.abs(DF2), np.abs(DF2) * 2, np.abs(DF1), np.abs(DF1) + np.abs(DF2), (np.abs(DF1) + np.abs(DF2)) * 2, ]] # 
color0122, (np.abs(DF1) + np.abs(DF2) * 2)fr, fr * 2, np.abs(np.abs(DF1) - np.abs(DF2)), color0122 = color_sumpeak() colors_all = [[color02, color02], [color02, color02, color01, color012, color012, color0, color0, ]] # color01_2, '', labels_all = [[r'$\Delta \mathrm{f_{Female}}$', r'2 $\Delta \mathrm{f_{Female}}$'], [r'$\Delta \mathrm{f_{Female}}$', '', r'$\Delta \mathrm{f_{Intruder}}$', r'$|\Delta \mathrm{f_{Female}}|+|\Delta \mathrm{f_{Intruder}}|$', '', '', '', ]] # r'$\mathrm{f'+basename()+'}$''$|Intruder-Female|$',r'$2|\Delta \mathrm{f_{Female}}|+|\Delta \mathrm{f_{Intruder}}|$' axs = plt_spectra_given(DF1, add_to, color0, color01, colors_all, female, fr, freqs_all, fs, grid0, labels_all, mult_val, p0, p_arrays, wierd_charing, xlim=xlim_ROC_talk2(), counter=counter, add_texts=[0, 0, 0, 0, 0, 0, 0, 0, 500, 0, 0, 0, 0, 0], log = log, text_extra=True) return axs def plt_power_spectrum(grid0, color01, color02, color012, color0, fr, results_diff, female, nfft, smoothed012, smoothed01, smoothed02, smoothed0, sampling_rate, counter=4, add_to=70, mult_val=0.125, wierd_charing=True): p0, p02, p01, p012, fs = calc_ps(nfft, smoothed012, smoothed01, smoothed02, smoothed0, sampling_rate=sampling_rate) DF1 = np.abs(results_diff.f1.iloc[-1] - results_diff.f0.iloc[-1]) DF2 = np.abs(results_diff.f2.iloc[-1] - results_diff.f0.iloc[-1]) if 'wo_female' in female: p_arrays = [p0, p01] else: four = False if four: p_arrays = [p02, p012] freqs_all = [[np.abs(DF2), np.abs(DF2) * 2], [np.abs(DF2), np.abs(DF2) * 2, np.abs(DF1), np.abs(DF1) + np.abs(DF2), (np.abs(DF1) + np.abs(DF2)) * 2, fr, fr * 2, (np.abs(DF1) + np.abs(DF2) * 2), ]] # np.abs(np.abs(DF1) - np.abs(DF2)), color0122 = color_sumpeak() colors_all = [[color02, color02], [color02, color02, color01, color012, color012, color0, color0, color0122, ]] # color01_2, labels_all = [['DF2', 'DF2 H1'], [r'$\Delta \mathrm{f_{Female}}$', '', r'$\Delta \mathrm{f_{Intruder}}$', sum_intruder_core(), '', r'$\mathrm{f'+basename()+'}$', '', 
r'$2|\Delta \mathrm{f_{Female}}|+|\Delta \mathrm{f_{Intruder}}|$']] # '$|Intruder-Female|$', else: p_arrays = [p02, p012] freqs_all = [[np.abs(DF2), np.abs(DF2) * 2], [np.abs(DF2), np.abs(DF2) * 2, (np.abs(DF1) + np.abs(DF2) * 2), np.abs(DF1), np.abs(DF1) + np.abs(DF2), (np.abs(DF1) + np.abs(DF2)) * 2, ]] # fr, fr * 2, np.abs(np.abs(DF1) - np.abs(DF2)), color0122 = color_sumpeak() colors_all = [[color02, color02], [color02, color02, color0122, color01, color012, color012, color0, color0, ]] # color01_2, labels_all = [[r'$\Delta \mathrm{f_{Female}}$', r'2 $\Delta \mathrm{f_{Female}}$'], [r'$\Delta \mathrm{f_{Female}}$', '', '', r'$\Delta \mathrm{f_{Intruder}}$', r'$|\Delta \mathrm{f_{Female}}|+|\Delta \mathrm{f_{Intruder}}|$', '', '', '', ]] # r'$\mathrm{f'+basename()+'}$''$|Intruder-Female|$',r'$2|\Delta \mathrm{f_{Female}}|+|\Delta \mathrm{f_{Intruder}}|$' axs = plt_spectra_given(DF1, add_to, color0, color01, colors_all, female, fr, freqs_all, fs, grid0, labels_all, mult_val, p0, p_arrays, wierd_charing, xlim=xlim_ROC_talk2(), counter=counter, add_texts=[0, 0, 0, 0, 0, 0, 0, 0, 500, 0, 0, 0, 0, 0], text_extra=True) return axs def xlim_ROC_talk2(): return (0, 200) def color_sumpeak(): color0122 = 'yellow' return color0122 def plt_spectra_given(DF1, add_to, color0, color01, colors_all, female, fr, freqs_all, fs, grid0, labels_all, mult_val, p0, p_arrays, wierd_charing, counter=4, xlim=(0, 300), add_texts=[0, 0, 0, 0, 0, 0, 0, ], text_extra=True, log = ''): axs = [] for j in range(len(p_arrays)): if (j != 0) & wierd_charing: ax = plt.subplot(grid0[counter, j], sharex=ax, sharey=ax) # , sharex=ax else: ax = plt.subplot(grid0[counter, j]) # , sharex=ax axs.append(ax) p0_means = [] for i in range(len(p0)): ax.plot(fs, p_arrays[j][i], color='grey') p0_mean = np.mean(p_arrays[j], axis=0) p0_means.append(p0_mean) ax.plot(fs, p0_mean, color='black') # plt_peaks(ax[0], p01, fs, 'orange') for p in range(len(p0_means)): if 'wo_female' in female: freqs = [np.abs(DF1), fr] 
colors = [color01, color0] labels = ['DF1', 'baseline'] else: labels = labels_all[j] colors = colors_all[j] freqs = freqs_all[j] new = True ax.set_xlim(xlim) if new: plt_peaks_several(freqs, p_arrays, ax, p0_mean, fs, labels, 0, colors, add_texts=add_texts, add_log=2.5, exact=False, text_extra=True, perc_peaksize=5, rel='rel', ms=24,ha='left', clip_on=False, several_peaks=True, log=log) # True ha='center', else: df_passed = [] for f in range(len(freqs)): if int(freqs[f]) in df_passed: add = (np.max(np.max(p_arrays)) + add_to) * mult_val else: add = (np.max(np.max(p_arrays)) + add_to) * 0.05 try: _, _ = plt_peaks(ax, p0_means[p], freqs[f], fs, fr_color=colors[f], s=25, label=labels[f], add_text=add_texts[f], text_extra=text_extra, add=add, extend=False, clip_on = False) except: print('p problem') embed() df_passed.append(int(freqs[f])) return axs def sum_intruder_core(): return r'$|\Delta \mathrm{f_{Female}}|+|\Delta \mathrm{f_{Intruder}}|$' def plt_area_between(frame_cell, ax0, ax, colors_w, colors_wo, f, cut_starts=False, alphas=[1, 1, 1, 1, 1, 1, 1, 1, 1], starts=0.25, ls='-', labels_with_female='', arrow=True, fill=True, labels_without_female='', dist_redo=True): cell = frame_cell.cell.unique()[0] frame_cell = frame_cell.groupby('c1', as_index=False).mean() c1 = frame_cell.c1 if dist_redo: c1 = c_dist_recalc_func(frame_cell=frame_cell, c_nrs=frame_cell.c1, cell=cell, c_dist_recalc=True) sorting = np.argsort(c1) c1 = np.sort(c1) frame_cell['auci_02_012'].iloc[frame_cell.index] = frame_cell['auci_02_012'].iloc[sorting] frame_cell['auci_base_01'].iloc[frame_cell.index] = frame_cell['auci_base_01'].iloc[sorting] upper0 = frame_cell['auci_02_012'] * 1 lower1 = frame_cell['auci_base_01'] * 1 upper0new = [0.5] upper0new.extend(upper0) upper0 = np.array(upper0new) lower0new = [0.5] lower0new.extend(lower1) lower1 = np.array(lower0new) upper = upper0 * 1 lower = lower1 * 1 lower[lower > upper0] = upper0[lower > upper0] upper[upper > lower1] = lower1[upper > lower1] 
c1_new = [5] c1_new.extend(c1) c1 = c1_new if fill: ax0.fill_between(c1, upper0, upper, color='red', edgecolor=None, zorder=2, alpha=0.05) ax0.fill_between(c1, lower1, lower, color='blue', edgecolor=None, zorder=2, alpha=0.05) ax.set_xlim(0, ax.get_xlim()[1]) if type(starts) == list: start = starts[f] else: start = starts test = False if test: ax = plt.subplot(1, 1, 1) with_female = np.array(upper0) without_female = np.array(lower1) c1_interp = c1 reintepolate = True if reintepolate: c1_interp_new = np.arange(np.min(c1_interp), np.max(c1_interp), 1) with_female = interpolate(c1_interp, with_female, c1_interp_new, kind='linear') without_female = interpolate(c1_interp, without_female, c1_interp_new, kind='linear') c1_interp = c1_interp_new # _new pos_l = np.argmin(np.abs(with_female - start)) pos_r = np.argmin(np.abs(without_female - start)) val_l = with_female[pos_l] val_r = without_female[pos_r] pos_ll = np.min([pos_l, pos_r]) if cut_starts: ax.plot(c1_interp[pos_r::], without_female[pos_r::], alpha=alphas[f], color=colors_wo[f], label=labels_without_female, clip_on=True, linestyle=ls) # linewidth=lw, # todo: das muss man in linear machen ax.plot(c1_interp, with_female, alpha=alphas[f], color=colors_w[f], label=labels_with_female, clip_on=True) else: ax.plot(c1_interp, without_female, color=colors_wo[f], alpha=alphas[f], label=labels_without_female, clip_on=True, linestyle=ls) # linewidth = lw, ax.plot(c1_interp, with_female, color=colors_w[f], alpha=alphas[f], label=labels_with_female, clip_on=True) # , linewidth = lw if arrow: # colors_w[f] # ich will halt dass es einge gerade linie ist if val_r != val_l: val_rr = val_l if pos_l != pos_ll: ax.annotate('', xy=(c1_interp[pos_l], val_l), xytext=(c1_interp[pos_r], val_rr), arrowprops=dict(arrowstyle="->", color='black'), textcoords='data', xycoords='data', horizontalalignment='left') else: ax.annotate('', xy=(c1_interp[pos_l], val_l), xytext=(c1_interp[pos_r], val_rr), arrowprops=dict(arrowstyle="->", color='black'), 
textcoords='data', xycoords='data', horizontalalignment='left') ax.set_xlabel(core_distance_label()) ax.set_ylabel(core_auc_label()) return c1_interp, without_female, with_female def arrow_annotate(ax, c1, colors_w, f, pos_l, pos_ll, pos_r, val_l, val_r): if pos_l != pos_ll: ax.annotate('', xy=(c1[pos_l], val_l), xytext=(c1[pos_r], val_r), arrowprops=dict(arrowstyle="->", color=colors_w[f]), textcoords='data', xycoords='data', horizontalalignment='left') else: ax.annotate('', xy=(c1[pos_l], val_l), xytext=(c1[pos_r], val_r), arrowprops=dict(arrowstyle="->", color=colors_w[f]), textcoords='data', xycoords='data', horizontalalignment='left') def plt_several_ROC_declining_one_with_ROC_single_in_one_dec(bt=0.12, lf=0.07, females=[], color_01='green', color_02='red', color_012='orange', figsize=(12, 5.5), frame_names=[ 'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_20_mult_minimum_1temporal', 'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_100_mult_minimum_1temporal', 'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_1000_mult_minimum_1temporal'], reshuffled='reshuffled', datapoints=1000, dev=0.0005, a_f1s=[0.03], printing=False, plus_q='minus', way='absolut', stimulus_length=0.5, runs=3, trials_nr=500, nfft=int(2 ** 15), beat='', nfft_for_morph=4096 * 4, gain=1, fish_jammer='Alepto', us_name='', wr=[1.6, 2]): try: pass except: print('split something') freq1_ratio = float(frame_names[0].split('FrF1rel_')[1].split('_FrF2rel')[0]) freq2_ratio = float(frame_names[0].split('FrF2rel_')[1].split('_C2')[0]) cells = [ "2013-01-08-aa-invivo-1"] # , "2012-12-13-an-invivo-1", "2012-06-27-an-invivo-1", "2012-12-21-ai-invivo-1","2012-06-27-ah-invivo-1", ] cells_chosen = [ '2013-01-08-aa-invivo-1'] # , 
"2012-06-27-ah-invivo-1","2014-06-06-ac-invivo-1" ]#'2012-06-27-an-invivo-1', plt.rcParams['lines.linewidth'] = 1 model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv") grid_here = gridspec.GridSpec(1, 2, hspace=0.3, wspace=0.36, left=lf, top=0.92, bottom=bt, right=0.98, width_ratios=wr) # 1.3,1 wspace=0.16 grid_here1 = gridspec.GridSpecFromSubplotSpec(1, 1, wspace=0.6, hspace=0.75, subplot_spec=grid_here[0]) # 0.3 grid_here2 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.3, hspace=0.3, subplot_spec=grid_here[1], height_ratios=[1.5, 3]) if len(cells) < 1: cells = len(model_cells) colors_wo = [color_01] # ['limegreen', 'green', 'darkgreen'] colors_w = [color_012] # ['orange', 'darkorange','goldenrod'] for cell_here in cells: # sachen die ich variieren will ########################################### single_waves = ['_SeveralWave_'] # , '_SingleWave_'] ####### VARY HERE for single_wave in single_waves: if single_wave == '_SingleWave_': a_f2s = [0] # , 0,0.2 else: a_f2s = [0.1] for a_f2 in a_f2s: for a_f1 in a_f1s: a_frs = [1] titles_amp = ['base eodf'] # ,'baseline to Zero',] for a, a_fr in enumerate(a_frs): model_params = model_cells[model_cells['cell'] == cell_here].iloc[0] eod_fr = model_params['EODf'] # .iloc[0] offset = model_params.pop('v_offset') cell = model_params.pop('cell') print(cell) SAM, adapt_offset, cell_recording, constant_reduction, damping, damping_type, dent_tau_change, exponential, f1, f2, fish_emitter, fish_receiver, fish_morph_harmonics_var, lower_tol, mimick, n, phase_right, phaseshift_fr, sampling_factor, upper_tol, zeros = default_model0() # in case you want a different sampling here we can adujust time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr, stimulus_length) # generate the eod_fish_r in the four mimick variants (copy, thunderfish, mimick, just sinus) eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr, stimulus_length, 
phaseshift_fr, cell_recording, zeros, mimick, sampling, fish_receiver, deltat, nfft, nfft_for_morph, fish_morph_harmonics_var=fish_morph_harmonics_var, beat=beat) sampling = 1 / deltat variant = 'sinz' if exponential == '': pass # prepare for adapting offset due to baseline modification _, _ = prepare_baseline_array(time_array, eod_fr, nfft_for_morph, phaseshift_fr, mimick, zeros, cell_recording, sampling, stimulus_length, fish_receiver, deltat, nfft, damping_type, damping, us_name, gain, beat=beat, fish_morph_harmonics_var=fish_morph_harmonics_var) save_name_roc = 'decline_ROC_examples_trial_nr.csv' redo = False version_comp, subfolder, mod_name_slash, mod_name, subfolder_path = find_code_vs_not() cont_redo = ((os.path.exists(save_name_roc)) | (version_comp == 'public')) & (redo == False) for run in range(runs): print(run) t1 = time.time() if cont_redo: trials_nr_base = 1 stimulus_length = 1 model_params = model_cells[model_cells['cell'] == cell_here].iloc[0] eod_fr = model_params['EODf'] # .iloc[0] offset = model_params.pop('v_offset') cell = model_params.pop('cell') time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr, stimulus_length) eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr, stimulus_length, phaseshift_fr, cell_recording, zeros, mimick, sampling, fish_receiver, deltat, nfft, nfft_for_morph, fish_morph_harmonics_var=fish_morph_harmonics_var, beat=beat) else: trials_nr_base = trials_nr spikes_base = [[]] * trials_nr_base for t in range(trials_nr_base): # get the baseline properties here # baseline_after,spikes_base,rate_adapted, rate_baseline_before, rate_baseline_after, np.array(spike_times), stimulus_power, v_dent_output[int(0.05 / deltat):-1], offset, v_mem_output stimulus = eod_fish_r if 'Zero' in titles_amp[a]: power_here = 'sinz' + '_' + zeros else: power_here = 'sinz' cvs, adapt_output, baseline_after_b, _, rate_adapted_b, rate_baseline_before_b, rate_baseline_after_b, \ 
spikes_base[t], _, _, offset_new, _, noise_final = simulate(cell, offset, stimulus, deltat=deltat, adaptation_variant=adapt_offset, adaptation_yes_j=f2, adaptation_yes_e=f1, adaptation_yes_t=t, power_variant=power_here, power_nr=n, reshuffle=reshuffled, **model_params) if t == 0: # here we record the changes in the offset due to the adaptation # and we subsequently reset the offset to be the new adapted for all subsequent trials offset = offset_new * 1 if printing: print('Baseline time' + str(time.time() - t1)) base_cut, mat_base = find_base_fr(spikes_base, deltat, stimulus_length, time_array, dev=dev) fr = np.mean(base_cut) two_third_fr = fr * freq2_ratio third_fr = fr * freq1_ratio if plus_q == 'minus': two_third_fr = -two_third_fr third_fr = -third_fr freqs2 = [eod_fr + two_third_fr] # , eod_fr - third_fr, two_third_fr, freqs1 = [ eod_fr + third_fr] # , eod_fr - two_third_fr, third_fr,two_third_fr,third_eodf, eod_fr - third_eodf,two_third_eodf, eod_fr - two_third_eodf, ] base_cut, mat_base, smoothed0, mat0 = find_base_fr2(spikes_base, deltat, stimulus_length, dev=dev) _, _ = ISI_frequency(time_array, spikes_base[0], fill=0.0) for ff, freq1 in enumerate(freqs1): if cont_redo: frame = pd.read_csv(save_name_roc) tp_012_all = frame['tp_012'] # = tp_012_all tp_01_all = frame['tp_01'] # = tp_01_all tp_02_all = frame['tp_02'] # = tp_02_all fp_all = frame['fp_all'] # = fp_all else: freq1 = [freq1] freq2 = [freqs2[ff]] t1 = time.time() phaseshift_f1, phaseshift_f2 = get_phaseshifts(a_f1, a_f2, phase_right, phaseshift_fr) eod_fish1, time_fish_e = eod_fish_e_generation(time_array, a_f1, freq1, f1, phaseshift_f1, sampling, stimulus_length, nfft_for_morph, cell_recording, fish_morph_harmonics_var, zeros, mimick, fish_emitter, thistype='emitter') eod_fish2, time_fish_j = eod_fish_e_generation(time_array, a_f2, freq2, f2, phaseshift_f2, sampling, stimulus_length, nfft_for_morph, cell_recording, fish_morph_harmonics_var, zeros, mimick, fish_jammer, thistype='jammer') 
eod_stimulus = eod_fish1 + eod_fish2 v_mems, offset_new, mat01, mat02, mat012, smoothed01, smoothed02, smoothed012, stimulus_01, stimulus_02, stimulus_012, mat05_01, spikes_01, mat05_02, spikes_02, mat05_012, spikes_012 = get_arrays_for_three( cell, a_f2, a_f1, SAM, eod_stimulus, eod_fish_r, freq2, eod_fish1, eod_fish2, stimulus_length, offset, model_params, n, variant, adapt_offset, deltat, f2, trials_nr, time_array, f1, freq1, eod_fr, reshuffle=reshuffled, dev=dev) if printing: print('Generation process' + str(time.time() - t1)) ################################## array0 = [mat_base] array01 = [mat05_01] array02 = [mat05_02] array012 = [mat05_012] t_off = 10 position_diff = 0 results_diff = pd.DataFrame() results_diff['f1'] = freq1 results_diff['f2'] = freq2 results_diff['f0'] = eod_fr trials, results_diff, tp_012_all, tp_01_all, tp_02_all, fp_all, roc_01, roc_0, roc_02, roc_012, threshhold = calc_auci_pd( results_diff, position_diff, array012, array01, array02, array0, t_off=t_off, way=way, printing=True, datapoints=datapoints, f0='f0', sampling=sampling) frame = pd.DataFrame() frame['tp_012'] = tp_012_all frame['tp_01'] = tp_01_all frame['tp_02'] = tp_02_all frame['fp_all'] = fp_all frame['threshhold'] = threshhold if version_comp == 'develop': frame.to_csv(save_name_roc) # threshhold if run == 0: color = 'black' lw = 1 else: color = 'grey' lw = 0.5 for f_nr, female in enumerate(females): if female == 'w_female': ax1 = plt.subplot(grid_here1[0]) color_e = 'lightgrey' roc_wo_female(color, ax1, tp_02_all, tp_012_all, color_02, color_012, title_color='black', color_e=color_e) # colors_w[ff] plt.fill_between(tp_02_all, tp_02_all, tp_012_all, color=colors_w[ff], alpha=0.8) ax1.set_title('') #: 0 ax1.set_xlabel('False-Positive Rate ') #: 0 ax1.set_ylabel('Correct-Detection Rate ') # 01 elif female == 'wo_female': ax1 = plt.subplot(grid_here1[0]) color_e = None roc_female(ax1, color, fp_all, tp_01_all, lw, 'black', 'black', title_color='black', color_e=color_e) # 
colors_wo[ff] plt.fill_between(fp_all, fp_all, tp_01_all, color=colors_wo[ff], alpha=0.8) ax1.set_xlabel('False-Positive Rate ') #: 0 ax1.set_ylabel('Correct-Detection Rate ') # 01 ax1.set_title('') #: 0 else: ax1 = plt.subplot(grid_here1[0]) roc_wo_female(color, ax1, tp_02_all, tp_012_all, 'black', 'black', title_color='black', color_e=color_e) # colors_w[ff] plt.fill_between(tp_02_all, tp_02_all, tp_012_all, color=colors_w[ff], alpha=0.8) roc_female(ax1, color, fp_all, tp_01_all, lw, 'black', 'black', title_color='black') # colors_wo[ff] plt.fill_between(fp_all, fp_all, tp_01_all, color=colors_wo[ff], alpha=1) ax1.set_xlabel('False-Positive Rate ') #: 0 ax1.set_ylabel('Correct-Detection Rate ') # 01 ax1.set_title('') #: 0 ################################################ # part with the ROC declining ax0 = plt.subplot(grid_here2[0]) distance_cm = np.arange(0, 200, 0.2) xlim_dist = core_xlim_dist_roc() distances_mv = c_to_dist(distance_cm, convert='dist_to_contrast') ax0.plot(distance_cm, distances_mv, label='cubed', color='black') ax0.set_xlim(xlim_dist) ax0.set_ylabel('EOD Amplitude\n [mV]') ax0.set_yticks([]) test = False if test: from utils_test import test_distances test_distances() ax0.set_yscale('log') ax0.set_ylim(0, 2) ax1 = plt.subplot(grid_here2[1]) for c, cell in enumerate(cells_chosen): for f, frame_name in enumerate(frame_names): path = load_folder_name('calc_ROC') + '/' + frame_name + '.csv' if os.path.exists(path): frame = pd.read_csv(path) path_ref = load_folder_name( 'calc_ROC') + '/' + 'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_C1_0.02_len_25_nfft_32768_trialsnr_20_mult_minimum_1temporal.csv' frame_ref = pd.read_csv(path_ref) _, _ = find_row_col(cells, row=4) frame_cell = frame[frame.cell == cell] axs = [ax1] label_f = ['with female', 'CLS: 100n with female', 'LS: 1000n with female', ] label_f_wo = ['without female', 'CLS: 100n without female', 'LS: 1000n without female', ] labels_w = [label_f[f], 
label_f[f]] labels_wo = [label_f_wo[f], label_f_wo[f]] for a, ax in enumerate(axs): if len(frame_cell) > 0: c1 = c_dist_recalc_func(eod_size_change=True, mult_eod=0.5, frame_cell=frame_cell, c_nrs=frame_cell.c1, cell=cell, c_dist_recalc=True) lw = lw_roc() s = 15 # 100 if female == 'w_female': ax.plot(c1, frame_cell['auci_02_012'], color=colors_w[f], label=labels_w[f], clip_on=True, linewidth=lw) elif female == 'wo_female': ax.plot(c1, frame_cell['auci_base_01'], color=colors_wo[f], label=labels_wo[f], clip_on=True, linewidth=lw) # , linestyle='--' else: plt_area_between(frame_cell, ax, ax, colors_w, colors_wo, f, labels_with_female=labels_w[a], labels_without_female='', arrow=True) ax.set_xlim(xlim_dist) ax.set_ylim(0, 0.52) ax.set_yticks_delta(0.1) pos = np.argmin(np.abs(frame_cell.c1 - a_f1)) if f == 0: c1 = c_dist_recalc_func(frame_cell=frame_cell, c_nrs=[frame_cell.c1.iloc[pos]], cell=cell, c_dist_recalc=True) if female == 'wo_female': ax.scatter(c1, frame_cell['auci_base_01'].iloc[pos], clip_on=True, color=colors_wo[0], s=s) # , facecolor = 'none' elif female == 'w_female': ax.scatter(c1, frame_cell['auci_02_012'].iloc[pos], clip_on=True, color=colors_w[0], s=s) # ,facecolor='none' else: ax.scatter(c1, frame_cell['auci_base_01'].iloc[pos], clip_on=True, color=colors_wo[0], s=s) # , facecolor = 'none' ax.scatter(c1, frame_cell['auci_02_012'].iloc[pos], clip_on=True, color=colors_w[0], s=s) # , facecolor='none' if a == 0: ax.legend(loc=(0.6, 0.7)) # , fontsize = 8, handlelength = 0.5 else: ax.legend(loc=(0.6, 0.6)) # , fontsize = 8, handlelength = 0.5 ax.set_ylim(0, 0.52) if c != 0: remove_yticks(ax) ax.show_spines('lb') ax1.set_ylabel(core_auc_label()) # ax1.set_xlabel('mV/cm') ax1.set_xlabel(core_distance_label()) ax = plt.gcf().axes if f_nr == 0: fig.tag(ax[0:3], xoffs=-6.5, yoffs=1.5, ) # 0.7 plt.subplots_adjust(left=0.03, wspace=0.3) save_visualization(frame_name, False, show_anything=False, pdf=True, jpg=True, png=False, counter_contrast=0, savename='', 
add='_' + female) plt.show() def core_distance_label(): return 'Intruder Distance [cm]' def core_auc_label(): return 'AUC' # 'Determinant' def core_xlim_dist_roc(): xlim_dist = [0, 225] return xlim_dist def lw_roc(): lw = 0.75 # 2 return lw def plt_several_ROC_declining_one_with_ROC_single_in_one(bt=0.12, lf=0.07, females=[], color_01='green', color_02='red', color_012='orange', frame_names=[ 'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_20_mult_minimum_1temporal', 'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_100_mult_minimum_1temporal', 'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_1000_mult_minimum_1temporal'], reshuffled='reshuffled', datapoints=1000, dev=0.0005, a_f1s=[0.03], printing=False, plus_q='minus', way='absolut', stimulus_length=0.5, runs=3, trials_nr=500, nfft=int(2 ** 15), beat='', nfft_for_morph=4096 * 4, gain=1, fish_jammer='Alepto', us_name='', wr=[1.6, 2]): try: pass except: print('split something') freq1_ratio = float(frame_names[0].split('FrF1rel_')[1].split('_FrF2rel')[0]) freq2_ratio = float(frame_names[0].split('FrF2rel_')[1].split('_C2')[0]) cells = [ "2013-01-08-aa-invivo-1"] # , "2012-12-13-an-invivo-1", "2012-06-27-an-invivo-1", "2012-12-21-ai-invivo-1","2012-06-27-ah-invivo-1", ] cells_chosen = [ '2013-01-08-aa-invivo-1'] # , "2012-06-27-ah-invivo-1","2014-06-06-ac-invivo-1" ]#'2012-06-27-an-invivo-1', plt.rcParams['lines.linewidth'] = 1 model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv") grid_here = gridspec.GridSpec(1, 2, hspace=0.6, wspace=0.16, left=lf, top=0.92, bottom=bt, right=0.98, width_ratios=wr) # 1.3,1 grid_here1 = gridspec.GridSpecFromSubplotSpec(1, 1, wspace=0.3, hspace=0.75, 
subplot_spec=grid_here[0]) grid_here2 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.3, hspace=0.3, subplot_spec=grid_here[1], height_ratios=[1.5, 3]) if len(cells) < 1: cells = len(model_cells) colors_wo = [color_01] # ['limegreen', 'green', 'darkgreen'] colors_w = [color_012] # ['orange', 'darkorange','goldenrod'] for cell_here in cells: # sachen die ich variieren will ########################################### single_waves = ['_SeveralWave_'] # , '_SingleWave_'] ####### VARY HERE for single_wave in single_waves: if single_wave == '_SingleWave_': a_f2s = [0] # , 0,0.2 else: a_f2s = [0.1] for a_f2 in a_f2s: for a_f1 in a_f1s: a_frs = [1] titles_amp = ['base eodf'] # ,'baseline to Zero',] for a, a_fr in enumerate(a_frs): model_params = model_cells[model_cells['cell'] == cell_here].iloc[0] eod_fr = model_params['EODf'] # .iloc[0] offset = model_params.pop('v_offset') cell = model_params.pop('cell') print(cell) SAM, adapt_offset, cell_recording, constant_reduction, damping, damping_type, dent_tau_change, exponential, f1, f2, fish_emitter, fish_receiver, fish_morph_harmonics_var, lower_tol, mimick, n, phase_right, phaseshift_fr, sampling_factor, upper_tol, zeros = default_model0() # in case you want a different sampling here we can adujust time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr, stimulus_length) # generate the eod_fish_r in the four mimick variants (copy, thunderfish, mimick, just sinus) eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr, stimulus_length, phaseshift_fr, cell_recording, zeros, mimick, sampling, fish_receiver, deltat, nfft, nfft_for_morph, fish_morph_harmonics_var=fish_morph_harmonics_var, beat=beat) sampling = 1 / deltat variant = 'sinz' if exponential == '': pass # prepare for adapting offset due to baseline modification _, _ = prepare_baseline_array(time_array, eod_fr, nfft_for_morph, phaseshift_fr, mimick, zeros, cell_recording, sampling, stimulus_length, 
fish_receiver, deltat, nfft, damping_type, damping, us_name, gain, beat=beat, fish_morph_harmonics_var=fish_morph_harmonics_var) # fig = plt.figure(figsize=(11.5, 5.4)) save_name_roc = 'decline_ROC_examples_trial_nr.csv' version_comp, subfolder, mod_name_slash, mod_name, subfolder_path = find_code_vs_not() for run in range(runs): print(run) t1 = time.time() if (os.path.exists(save_name_roc)) | (version_comp == 'public'): trials_nr_base = 1 stimulus_length = 1 model_params = model_cells[model_cells['cell'] == cell_here].iloc[0] eod_fr = model_params['EODf'] # .iloc[0] offset = model_params.pop('v_offset') cell = model_params.pop('cell') time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr, stimulus_length) eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr, stimulus_length, phaseshift_fr, cell_recording, zeros, mimick, sampling, fish_receiver, deltat, nfft, nfft_for_morph, fish_morph_harmonics_var=fish_morph_harmonics_var, beat=beat) else: trials_nr_base = trials_nr spikes_base = [[]] * trials_nr_base for t in range(trials_nr_base): # get the baseline properties here # baseline_after,spikes_base,rate_adapted, rate_baseline_before, rate_baseline_after, np.array(spike_times), stimulus_power, v_dent_output[int(0.05 / deltat):-1], offset, v_mem_output stimulus = eod_fish_r if 'Zero' in titles_amp[a]: power_here = 'sinz' + '_' + zeros else: power_here = 'sinz' cvs, adapt_output, baseline_after_b, _, rate_adapted_b, rate_baseline_before_b, rate_baseline_after_b, \ spikes_base[t], _, _, offset_new, _, noise_final = simulate(cell, offset, stimulus, deltat=deltat, power_variant=power_here, power_alpha=alpha, power_nr=n, **model_params) if t == 0: # here we record the changes in the offset due to the adaptation # and we subsequently reset the offset to be the new adapted for all subsequent trials offset = offset_new * 1 if printing: print('Baseline time' + str(time.time() - t1)) base_cut, mat_base = 
find_base_fr(spikes_base, deltat, stimulus_length, time_array, dev=dev) fr = np.mean(base_cut) two_third_fr = fr * freq2_ratio third_fr = fr * freq1_ratio if plus_q == 'minus': two_third_fr = -two_third_fr third_fr = -third_fr freqs2 = [eod_fr + two_third_fr] # , eod_fr - third_fr, two_third_fr, freqs1 = [ eod_fr + third_fr] # , eod_fr - two_third_fr, third_fr,two_third_fr,third_eodf, eod_fr - third_eodf,two_third_eodf, eod_fr - two_third_eodf, ] base_cut, mat_base, smoothed0, mat0 = find_base_fr2(spikes_base, deltat, stimulus_length, dev=dev) _, _ = ISI_frequency(time_array, spikes_base[0], fill=0.0) for ff, freq1 in enumerate(freqs1): if (os.path.exists(save_name_roc)) | (version_comp == 'public'): frame = pd.read_csv(save_name_roc) tp_012_all = frame['tp_012'] # = tp_012_all tp_01_all = frame['tp_01'] # = tp_01_all tp_02_all = frame['tp_02'] # = tp_02_all fp_all = frame['fp_all'] # = fp_all else: freq1 = [freq1] freq2 = [freqs2[ff]] t1 = time.time() phaseshift_f1, phaseshift_f2 = get_phaseshifts(a_f1, a_f2, phase_right, phaseshift_fr) eod_fish1, time_fish_e = eod_fish_e_generation(time_array, a_f1, freq1, f1, phaseshift_f1, sampling, stimulus_length, nfft_for_morph, cell_recording, fish_morph_harmonics_var, zeros, mimick, fish_emitter, thistype='emitter') eod_fish2, time_fish_j = eod_fish_e_generation(time_array, a_f2, freq2, f2, phaseshift_f2, sampling, stimulus_length, nfft_for_morph, cell_recording, fish_morph_harmonics_var, zeros, mimick, fish_jammer, thistype='jammer') eod_stimulus = eod_fish1 + eod_fish2 v_mems, offset_new, mat01, mat02, mat012, smoothed01, smoothed02, smoothed012, stimulus_01, stimulus_02, stimulus_012, mat05_01, spikes_01, mat05_02, spikes_02, mat05_012, spikes_012 = get_arrays_for_three( cell, a_f2, a_f1, SAM, eod_stimulus, eod_fish_r, freq2, eod_fish1, eod_fish2, stimulus_length, offset, model_params, n, variant, adapt_offset, deltat, f2, trials_nr, time_array, f1, freq1, eod_fr, reshuffle=reshuffled, dev=dev) if printing: 
print('Generation process' + str(time.time() - t1)) array0 = [mat_base] array01 = [mat05_01] array02 = [mat05_02] array012 = [mat05_012] t_off = 10 position_diff = 0 results_diff = pd.DataFrame() results_diff['f1'] = freq1 results_diff['f2'] = freq2 results_diff['f0'] = eod_fr trials, results_diff, tp_012_all, tp_01_all, tp_02_all, fp_all, roc_01, roc_0, roc_02, roc_012, threshhold = calc_auci_pd( results_diff, position_diff, array012, array01, array02, array0, t_off=t_off, way=way, printing=True, datapoints=datapoints, f0='f0', sampling=sampling) frame = pd.DataFrame() frame['tp_012'] = tp_012_all frame['tp_01'] = tp_01_all frame['tp_02'] = tp_02_all frame['fp_all'] = fp_all frame['threshhold'] = threshhold if version_comp == 'develop': frame.to_csv(save_name_roc) if run == 0: color = 'black' lw = 1 else: color = 'grey' lw = 0.5 for female in females: if female == 'w_female': ax1 = plt.subplot(grid_here1[0]) roc_wo_female(color, ax1, tp_02_all, tp_012_all, color_02, color_012, title_color='black') # colors_w[ff] plt.fill_between(tp_02_all, tp_02_all, tp_012_all, color=colors_w[ff], alpha=0.8) ax1.set_title('') #: 0 ax1.set_xlabel('False-Positive Rate ') #: 0 ax1.set_ylabel('Correct-Detection Rate ') # 01 elif female == 'wo_female': ax1 = plt.subplot(grid_here1[0]) roc_female(ax1, color, fp_all, tp_01_all, lw, 'black', 'black', title_color='black') # colors_wo[ff] plt.fill_between(fp_all, fp_all, tp_01_all, color=colors_wo[ff], alpha=0.8) ax1.set_xlabel('False-Positive Rate ') #: 0 ax1.set_ylabel('Correct-Detection Rate ') # 01 ax1.set_title('') #: 0 else: ax1 = plt.subplot(grid_here1[0]) roc_wo_female(color, ax1, tp_02_all, tp_012_all, 'black', 'black', title_color='black') # colors_w[ff] plt.fill_between(tp_02_all, tp_02_all, tp_012_all, color=colors_w[ff], alpha=0.8) roc_female(ax1, color, fp_all, tp_01_all, lw, 'black', 'black', title_color='black') # colors_wo[ff] plt.fill_between(fp_all, fp_all, tp_01_all, color=colors_wo[ff], alpha=1) 
ax1.set_xlabel('False-Positive Rate ') #: 0 ax1.set_ylabel('Correct-Detection Rate ') # 01 ax1.set_title('') #: 0 ################################################ # part with the ROC declining ax0 = plt.subplot(grid_here2[0]) distance = np.arange(0, 200, 0.2) xlim_dist = [0, 70] distances_mv = c_to_dist_reverse(distance) # distance_changed*factor ax0.plot(distance, distances_mv, label='cubed', color='black') ax0.set_xlim(xlim_dist) ax0.set_ylabel('EOD Amplitude') ax0.set_yticks([]) test = False if test: from utils_test import test_vals test_vals() ax0.set_yscale('log') ax0.set_yticks([]) ax1 = plt.subplot(grid_here2[1]) for c, cell in enumerate(cells_chosen): for f, frame_name in enumerate(frame_names): path = load_folder_name('calc_ROC') + '/' + frame_name + '.csv' if os.path.exists(path): frame = pd.read_csv(path) path_ref = load_folder_name( 'calc_ROC') + '/' + 'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_C1_0.02_len_25_nfft_32768_trialsnr_20_mult_minimum_1temporal.csv' frame_ref = pd.read_csv(path_ref) _, _ = find_row_col(cells, row=4) frame_cell = frame[frame.cell == cell] axs = [ax1] label_f = ['with female', 'CLS: 100n with female', 'LS: 1000n with female', ] label_f_wo = ['without female', 'CLS: 100n without female', 'LS: 1000n without female', ] labels_w = [label_f[f], label_f[f]] labels_wo = [label_f_wo[f], label_f_wo[f]] for a, ax in enumerate(axs): if len(frame_cell) > 0: c1 = c_dist_recalc_func(frame_cell=frame_cell, c_nrs=frame_cell.c1, cell=cell, c_dist_recalc=True) if female == 'w_female': ax.plot(c1, frame_cell['auci_02_012'], color=colors_w[f], label=labels_w[f], clip_on=True, linewidth=2) elif female == 'wo_female': ax.plot(c1, frame_cell['auci_base_01'], color=colors_wo[f], label=labels_wo[f], clip_on=True, linewidth=2) # , linestyle='--' else: plt_area_between(frame_cell, ax, ax, colors_w, colors_wo, f, labels_with_female=labels_w[a], talk=False, labels_without_female='', arrow=True) ax.set_xlim(xlim_dist) 
ax.set_ylim(0, 0.52) ax.set_yticks_delta(0.1) pos = np.argmin(np.abs(frame_cell.c1 - a_f1)) if f == 0: c1 = c_dist_recalc_func(frame_cell=frame_cell, c_nrs=frame_cell.c1, cell=cell, c_dist_recalc=True) s = 100 if female == 'wo_female': ax.scatter(c1, frame_cell['auci_base_01'].iloc[pos], clip_on=True, color=colors_wo[0], s=s) # , facecolor = 'none' elif female == 'w_female': ax.scatter(c1, frame_cell['auci_02_012'].iloc[pos], clip_on=True, color=colors_w[0], s=s) # ,facecolor='none' else: ax.scatter(c1, frame_cell['auci_base_01'].iloc[pos], clip_on=True, color=colors_wo[0], s=s) # , facecolor = 'none' ax.scatter(c1, frame_cell['auci_02_012'].iloc[pos], clip_on=True, color=colors_w[0], s=s) # , facecolor='none' if a == 0: ax.legend(loc=(0.6, 0.7)) # , fontsize = 8, handlelength = 0.5 else: ax.legend(loc=(0.6, 0.6)) # , fontsize = 8, handlelength = 0.5 ax.set_ylim(0, 0.52) if c != 0: remove_yticks(ax) ax.show_spines('lb') ax1.set_ylabel('Determinant') ax1.set_xlabel('mV/cm') ax1.set_xlabel('Distance [cm]') plt.subplots_adjust(left=0.03, wspace=0.3) save_visualization(frame_name, False, show_anything=False, pdf=True, jpg=True, png=False, counter_contrast=0, savename='', add='_' + female) plt.show() def plt_several_ROC_square_nonlin(brust_corrs=['_burstIndividual_'], nffts=['whole'], powers=[1], contrasts=[0], column=None, noises_added=[''], D_extraction_method=['additiv_visual_d_4_scaled'], internal_noise=['eRAM'], external_noise=['eRAM'], level_extraction=['_RAMdadjusted'], repeats=[1000000], receiver_contrast=[1], dendrids=[''], ref_types=[''], adapt_types=[''], c_noises=[0.1], c_signal=[0.9], cut_offs1=[300], label=r'$\frac{1}{mV^2S}$'): plot_style() default_settings(column=column, width=12) # ts=12, ls=13, fs=11, cells = [ "2013-01-08-aa-invivo-1"] # , "2012-12-13-an-invivo-1", "2012-06-27-an-invivo-1", "2012-12-21-ai-invivo-1","2012-06-27-ah-invivo-1", ] grid = gridspec.GridSpec(1, 2, wspace=0.35, hspace=0.5, left=0.06, top=0.8, bottom=0.15, right=0.96) # , 
width_ratios = [1,1,1,0.5,1] height_ratios = [1,6]bottom=0.25, top=0.8, ################################### # plot square ax = plt.subplot(grid[0]) square_part(ax) ax.set_aspect('equal') #################################### # plot nonlin ax = plt.subplot(grid[1]) trials_nrs = [1] iternames = [brust_corrs, cells, D_extraction_method, external_noise, repeats, internal_noise, powers, nffts, dendrids, cut_offs1, trials_nrs, c_signal, c_noises, ref_types, adapt_types, noises_added, level_extraction, receiver_contrast, contrasts, ] for all in it.product(*iternames): burst_corr, cell, var_type, stim_type_afe, trials_stim, stim_type_noise, power, nfft, dendrid, cut_off1, trial_nrs, c_sig, c_noise, ref_type, adapt_type, noise_added, extract, a_fr, a_fe = all print(trials_stim, stim_type_noise, power, nfft, a_fe, a_fr, dendrid, var_type, cut_off1, trial_nrs) nr = '2' trial_nr = 250000 save_name = load_folder_name( 'calc_model') + '/' + 'calc_RAM_model-2__nfft_whole_power_1_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_' + str( trial_nr) + '_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV' path = save_name + '.pkl' # '../'+ model = pd.read_pickle(path) # load_data(path, cells, save_name) model_show = model[( model.cell == cell)] new_keys = model_show.index.unique() # [0:490] stack_plot = model_show[new_keys] # [list(map(str, new_keys))] stack_plot = np.abs(stack_plot.iloc[np.arange(0, len(new_keys), 1)]) ax.set_xlim(0, 237) ax.set_ylim(0, 237) ax.set_aspect('equal') model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv") model_params = model_cells[model_cells['cell'] == cell] noise_strength = model_params.noise_strength.iloc[0] # **2/2 D = noise_strength # (noise_strength ** 2) / 2 _, _, _ = D_derive(model_show, save_name, c_sig, D=D, base='', nr=nr) # var_based stack_plot = RAM_norm(stack_plot, trials_stim=trials_stim, model_show=model_show) perc = '10' # 'perc' im 
= plt_RAM_perc(ax, perc, stack_plot) ax.set_aspect('equal') cbar = plt.colorbar(im, ax=ax, orientation='vertical') # pad=0.2, shrink=0.5, "horizontal" cbar.set_label(label, labelpad=100) # rotation=270, ax.set_xlabel(F1_xlabel(), labelpad=20) ax.set_ylabel(F2_xlabel()) save_visualization(jpg=True, png=False) plt.show() def plt_several_ROC_declining_classified_small(): frame_names = [ 'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_20_mult_minimum_1temporal', 'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_100_mult_minimum_1temporal', 'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_1000_mult_minimum_1temporal'] cm = plt.get_cmap("hsv") cells_chosen = ['2013-01-08-aa-invivo-1', "2012-06-27-ah-invivo-1", "2014-06-06-ac-invivo-1"] # '2012-06-27-an-invivo-1', cells = ["2013-01-08-aa-invivo-1", "2012-12-13-an-invivo-1", "2012-06-27-an-invivo-1", "2012-12-21-ai-invivo-1", "2012-06-27-ah-invivo-1", ] x_pos = 0.02 grid = gridspec.GridSpec(1, 5, wspace=0.2, hspace=0.5, left=0.1, top=0.8, bottom=0.15, right=0.95, width_ratios=[1, 1, 1, 0.5, 1]) # height_ratios = [1,6]bottom=0.25, top=0.8, grid1 = gridspec.GridSpecFromSubplotSpec(3, 1, wspace=0.3, hspace=0.75, subplot_spec=grid[-1]) for c, cell in enumerate(cells_chosen): grid0 = gridspec.GridSpecFromSubplotSpec(5, 1, wspace=0.2, hspace=0.35, subplot_spec=grid[c]) # height_ratios=[1, 0.7, 1, 1], for f, frame_name in enumerate(frame_names): path = load_folder_name('calc_ROC') + '/' + frame_name + '.csv' if os.path.exists(path): frame = pd.read_csv(path) title = cut_title(frame_name, datapoints=100) plt.suptitle(title) path_ref = load_folder_name( 'calc_ROC') + 
'/calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_C1_0.02_len_25_nfft_32768_trialsnr_20_mult_minimum_1temporal.csv' frame_ref = pd.read_csv(path_ref) frame_ref = frame_ref.sort_values(by='cv_0') colr = [cm(float(i) / (len(frame_ref))) for i in range(len(frame_ref))] cells_sorted = frame_ref.cell.unique() _, _ = find_row_col(cells, row=4) frame_cell = frame[frame.cell == cell] ax0 = plt.subplot(grid0[f]) ax1 = plt.subplot(grid0[3]) ax2 = plt.subplot(grid0[4]) axs = [ax0, ax1] colors = ['black', 'grey', 'lightgrey', ] for ax in axs: if len(frame_cell) > 0: plt_area_between(frame_cell.c1, frame_cell, ax0, ax, colors, colors, f) ax.axhline(0, linestyle='--', color='grey', linewidth=0.5) col_pos = np.where(cells_sorted == cell)[0][0] if f == 0: ax0.set_title(cell[0:13] + '\n cv ' + str(np.round(np.mean(frame_cell.cv_0.unique()), 2)), color=colr[col_pos], fontsize=8) ax.set_ylim(0, 0.5) if c != 0: remove_yticks(ax) remove_yticks(ax2) else: ax2.set_ylabel('B1+B2') if f == 1: ax0.set_ylabel('Determinant') remove_xticks(ax0) remove_xticks(ax1) ax.axvline(x_pos, color='grey', linestyle='--', linewidth=0.5) ax2.plot(frame_cell.c1, frame_cell['amp_B1+B2_012-01-02+0_norm_01B1+02B2_mean'], color=colors[f]) ax2.set_xscale('log') ax2.axvline(x_pos, color='grey', linestyle='--', linewidth=0.5) ax2.set_xlabel('mV/cm') ax2.set_ylim(0, 0.5) ###################################### # plot the plot on the right upper part (Area vs CV) path = load_folder_name('calc_ROC') + '/' + frame_names[0] + '.csv' ax_scatter = plt.subplot(grid1[0]) ax_scatter_nonlin_sole = plt.subplot(grid1[1]) ax_scatter_nonlin = plt.subplot(grid1[2]) cvs, nonlin_area, diff_areas, areas_01_scatter, nonlin, areas_012_one = calc_areas(path, frame_ref, colr, x_pos, cells_chosen) ax_scatter.scatter(cvs, diff_areas, color=colr, s=15, clip_on=False) ax_scatter.axhline(0, linestyle='--', linewidth=0.5, color='grey') ax_scatter.set_xlabel('CV') ax_scatter.set_ylabel('Area Detection 
improvement') ax_scatter_nonlin_sole.scatter(cvs, nonlin_area, color=colr, s=15, clip_on=False) ax_scatter_nonlin_sole.axhline(0, linestyle='--', linewidth=0.5, color='grey') ax_scatter_nonlin_sole.set_xlabel('CV') ax_scatter_nonlin_sole.set_ylabel('Area Nonlinearity (B1+B2)') ax_scatter_nonlin.set_xlabel('Area Detection improvement') ax_scatter_nonlin.set_ylabel('Area Nonlinearity (B1+B2)') ax_scatter_nonlin.scatter(nonlin_area, diff_areas, color=colr, s=15, clip_on=False) ###################################### # plot the plot on the right lower part (Area vs Nonlin at B1+B2) save_visualization(png=False) plt.show() def plt_ROC_model_w_female2(redo=False, t_off=10, top=0.95, bottom=0.12, add_name='', color0='green', color01='blue', color02='red', color012='orange', female='wo_female', reshuffled='reshuffled', datapoints=1000, dev=0.0005, a_f1s=[0.03], pdf=True, printing=False, plus_q='minus', freq1_ratio=1 / 2, diagonal='diagonal', freq2_ratio=2 / 3, way='absolut', stimulus_length=0.5, runs=3, trials_nr=500, cells=[], show=False, nfft=int(2 ** 15), beat='', nfft_for_morph=4096 * 4, fr=None, gain=1, fish_jammer='Alepto', us_name=''): save_name_roc = 'decline_ROC_examples_trial_nr.csv' version_comp, subfolder, mod_name_slash, mod_name, subfolder_path = find_code_vs_not() cont_redo = ((os.path.exists(save_name_roc)) | (version_comp == 'public')) & (redo == False) if cont_redo: stimulus_length = 0.14 plt.rcParams['lines.linewidth'] = 1 model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv") if len(cells) < 1: cells = model_cells.cell # ) for cell_here in cells: # sachen die ich variieren will ########################################### single_waves = ['_SeveralWave_'] # , '_SingleWave_'] ####### VARY HERE for single_wave in single_waves: if single_wave == '_SingleWave_': a_f2s = [0] # , 0,0.2 else: a_f2s = [0.1] for a_f2 in a_f2s: for a_f1 in a_f1s: a_frs = [1] titles_amp = ['base eodf'] # ,'baseline to Zero',] for a, a_fr in 
enumerate(a_frs): model_params = model_cells[model_cells['cell'] == cell_here].iloc[0] eod_fr = model_params['EODf'] # .iloc[0] offset = model_params.pop('v_offset') cell = model_params.pop('cell') print(cell) SAM, adapt_offset, cell_recording, constant_reduction, damping, damping_type, dent_tau_change, exponential, f1, f2, fish_emitter, fish_receiver, fish_morph_harmonics_var, lower_tol, mimick, n, phase_right, phaseshift_fr, sampling_factor, upper_tol, zeros = default_model0() # in case you want a different sampling here we can adujust time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr, stimulus_length) # generate the eod_fish_r in the four mimick variants (copy, thunderfish, mimick, just sinus) eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr, stimulus_length, ) # phaseshift_fr, # cell_recording, zeros, mimick, # sampling, fish_receiver, deltat, # nfft, nfft_for_morph, # fish_morph_harmonics_var=fish_morph_harmonics_var, # beat=beat # embed() sampling = 1 / deltat variant = 'sinz' spikes_base = [[]] * trials_nr default_figsize(width=cm_to_inch(29.21), length=cm_to_inch(12.43)) default_figsize(width=cm_to_inch(29.21), length=cm_to_inch(13.98)) default_figsize(width=cm_to_inch(31.89), length=cm_to_inch(15)) add_bottom, add_right = implement_fig_borders(bottom=1.59) default_ticks_talks() plt.rcParams['figure.facecolor'] = 'none' fig = plt.figure() grid = gridspec.GridSpec(1, 2, wspace=0.4, left=0.09, top=top, bottom=bottom + add_bottom, right=1.02 - add_right, height_ratios=[1], width_ratios=[4, 2.8]) # 1.3,1 grid0 = gridspec.GridSpecFromSubplotSpec(3, 2, wspace=0.18, hspace=0.1, subplot_spec=grid[0], height_ratios=[1, 0.6, 1]) # ,0.4,1.2 grid1 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=grid[1]) # wspace=0.5, hspace=0.55, for run in range(runs): print(run) t1 = time.time() for t in range(trials_nr): # get the baseline properties here # baseline_after,spikes_base,rate_adapted, 
rate_baseline_before, rate_baseline_after, np.array(spike_times), stimulus_power, v_dent_output[int(0.05 / deltat):-1], offset, v_mem_output stimulus = eod_fish_r stimulus_base = eod_fish_r if 'Zero' in titles_amp[a]: power_here = 'sinz' + '_' + zeros else: power_here = 'sinz' cvs, adapt_output, baseline_after_b, _, rate_adapted_b, rate_baseline_before_b, rate_baseline_after_b, \ spikes_base[t], _, _, offset_new, _, noise_final = simulate(cell, offset, stimulus, deltat=deltat, power_variant=power_here, power_nr=n, **model_params) if t == 0: # here we record the changes in the offset due to the adaptation # and we subsequently reset the offset to be the new adapted for all subsequent trials offset = offset_new * 1 if printing: print('Baseline time' + str(time.time() - t1)) base_cut, mat_base = find_base_fr(spikes_base, deltat, stimulus_length, time_array, dev=dev) if not fr: fr = np.mean(base_cut) if 'diagonal' in diagonal: two_third_fr = fr * freq2_ratio freq1_ratio = (1 - freq2_ratio) third_fr = fr * freq1_ratio else: two_third_fr = fr * freq2_ratio third_fr = fr * freq1_ratio if plus_q == 'minus': two_third_fr = -two_third_fr third_fr = -third_fr freqs2 = [eod_fr + two_third_fr] # , eod_fr - third_fr, two_third_fr, freqs1 = [ eod_fr + third_fr] # , eod_fr - two_third_fr, third_fr,two_third_fr,third_eodf, eod_fr - third_eodf,two_third_eodf, eod_fr - two_third_eodf, ] sampling_rate = 1 / deltat base_cut, mat_base, smoothed0, mat0 = find_base_fr2(spikes_base, deltat, stimulus_length, dev=dev) fr = np.mean(base_cut) _, _ = ISI_frequency(time_array, spikes_base[0], fill=0.0) isi = np.diff(spikes_base[0]) cv0 = np.std(isi) / np.mean(isi) for ff, freq1 in enumerate(freqs1): freq1 = [freq1] freq2 = [freqs2[ff]] print(cell + ' f1' + str(freq1) + ' f2 ' + str(freq2) + ' f1' + str( freq1 - eod_fr) + ' f2 ' + str(freq2 - eod_fr)) t1 = time.time() phaseshift_f1, phaseshift_f2 = get_phaseshifts(a_f1, a_f2, phase_right, phaseshift_fr) eod_fish1, time_fish_e = 
eod_fish_e_generation(time_array, a_f1, freq1, f1, phaseshift_f1, sampling, stimulus_length, nfft_for_morph, cell_recording, fish_morph_harmonics_var, zeros, mimick, fish_emitter, thistype='emitter') eod_fish2, time_fish_j = eod_fish_e_generation(time_array, a_f2, freq2, f2, phaseshift_f2, sampling, stimulus_length, nfft_for_morph, cell_recording, fish_morph_harmonics_var, zeros, mimick, fish_jammer, thistype='jammer') eod_stimulus = eod_fish1 + eod_fish2 v_mems, offset_new, mat01, mat02, mat012, smoothed01, smoothed02, smoothed012, stimulus_01, stimulus_02, stimulus_012, mat05_01, spikes_01, mat05_02, spikes_02, mat05_012, spikes_012 = get_arrays_for_three( cell, a_f2, a_f1, SAM, eod_stimulus, eod_fish_r, freq2, eod_fish1, eod_fish2, stimulus_length, offset, model_params, n, variant, adapt_offset, deltat, f2, trials_nr, time_array, f1, freq1, eod_fr, reshuffle=reshuffled, dev=dev, redo_stim=False) if printing: print('Generation process' + str(time.time() - t1)) array0 = [mat_base] array01 = [mat05_01] array02 = [mat05_02] array012 = [mat05_012] position_diff = 0 results_diff = pd.DataFrame() results_diff['f1'] = freq1 results_diff['f2'] = freq2 results_diff['f0'] = eod_fr trials, results_diff, tp_012_all, tp_01_all, tp_02_all, fp_all, roc_01, roc_0, roc_02, roc_012, threshhold = calc_auci_pd( results_diff, position_diff, array012, array01, array02, array0, t_off=t_off, way=way, printing=True, datapoints=datapoints, f0='f0', sampling=sampling) if run == 0: color = 'black' lw = 1.5 z = 2 else: color = 'grey' lw = 0.8 z = 1 color0 = 'black' if cont_redo: frame = pd.read_csv(save_name_roc) tp_012_all = frame['tp_012'] # = tp_012_all tp_01_all = frame['tp_01'] # = tp_01_all tp_02_all = frame['tp_02'] # = tp_02_all fp_all = frame['fp_all'] # = fp_all if 'wo_female' in female: ax_roc_wof = plt.subplot(grid1[0]) roc_female(ax_roc_wof, color, fp_all, tp_01_all, lw, color0, color01, title_color=color01, z=z) elif 'base_female' in female: ax_roc_wof = plt.subplot(grid1[0]) 
roc_female(ax_roc_wof, color, fp_all, tp_02_all, lw, color0, color02, z=z, add_01='\n Female', add_base=' Baseline') ax_roc_wof.set_title('Receiver Operating Characteristics (ROC)', pad=15) else: ax_roc_wf = plt.subplot(grid1[0]) roc_wo_female(color, ax_roc_wf, tp_02_all, tp_012_all, color02, color012, title_color=color012, z=z) if run == 0: plt_traces_to_roc(freq2_ratio, freq1_ratio, t_off, spikes_02, spikes_01, spikes_012, spikes_base, mat_base, mat05_01, mat05_012, mat05_02, color02, color012, a_f2, trials, sampling, a_f1, fr, female, color01, color0, grid0, eod_fr, freq2, freq1, sampling_rate, stimulus_012, stimulus_02, stimulus_01, stimulus_base, time_array, carrier=True) ax = fig.axes remove_axes_roc_traces(ax, add=1) for aa, ax_here in enumerate(ax[2:5]): ax_here.set_xticks([]) for aa, ax_here in enumerate(ax[1::]): if aa not in np.arange(0, 100, 2): pass else: ax_here.get_shared_y_axes().join(*ax[1 + aa:1 + aa + 2]) plt.subplots_adjust(top=0.95, left=0.09, right=0.95, hspace=0.5, bottom=0.12, wspace=0.25) individual_tag = '_way_' + str(way) + '_runs_' + str(runs) + '_trial_nr_' + str( trials_nr) + '_stimulus_length_' + str( stimulus_length) + cell + ' cv ' + str(cv0) + single_wave + '_a_f0_' + str( a_fr) + '_a_f1_' + str(a_f1) + '_a_f2_' + str(a_f2) + '_trialsnr_' + str(trials_nr) save_visualization(individual_tag, show=show, add=add_name, pdf=pdf, counter_contrast=0, savename='') def remove_axes_roc_traces(ax, add=0): ax[0 + add].show_spines('') ax[1 + add].show_spines('') ax[2 + add].show_spines('') ax[3 + add].show_spines('') ax[4 + add].set_ylabel('Firing Rate [Hz]') ax[4 + add].set_xlabel('Time [ms]') ax[5 + add].set_xlabel('Time [ms]') ax[5 + add].show_spines('b') def implement_fig_borders(bottom=1.89): bottom_pp = cm_to_inch(bottom) rigth_pp = cm_to_inch(2.33) add_right = rigth_pp / plt.rcParams['figure.figsize'][0] add_bottom = bottom_pp / plt.rcParams['figure.figsize'][1] return add_bottom, add_right def plt_ROC_model_w_female(redo=False, t_off=10, 
top=0.95, bottom=0.14, add_name='', color0='green', color01='blue', color02='red', color012='orange', figsize=(11.5, 5.4), female='wo_female', reshuffled='reshuffled', datapoints=1000, dev=0.0005, a_f1s=[0.03], pdf=True, printing=False, plus_q='minus', freq1_ratio=1 / 2, diagonal='diagonal', freq2_ratio=2 / 3, way='absolut', stimulus_length=0.5, runs=3, trials_nr=500, cells=[], show=False, nfft=int(2 ** 15), beat='', nfft_for_morph=4096 * 4, fr=None, gain=1, fish_jammer='Alepto', us_name=''): save_name_roc = 'decline_ROC_examples_trial_nr.csv' version_comp, subfolder, mod_name_slash, mod_name, subfolder_path = find_code_vs_not() cont_redo = ((os.path.exists(save_name_roc)) | (version_comp == 'public')) & (redo == False) if cont_redo: stimulus_length = 0.14 plt.rcParams['lines.linewidth'] = 1 model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv") if len(cells) < 1: cells = model_cells.cell # ) for cell_here in cells: # sachen die ich variieren will ########################################### single_waves = ['_SeveralWave_'] # , '_SingleWave_'] ####### VARY HERE for single_wave in single_waves: if single_wave == '_SingleWave_': a_f2s = [0] # , 0,0.2 else: a_f2s = [0.1] for a_f2 in a_f2s: for a_f1 in a_f1s: a_frs = [1] titles_amp = ['base eodf'] # ,'baseline to Zero',] for a, a_fr in enumerate(a_frs): model_params = model_cells[model_cells['cell'] == cell_here].iloc[0] eod_fr = model_params['EODf'] # .iloc[0] offset = model_params.pop('v_offset') cell = model_params.pop('cell') print(cell) SAM, adapt_offset, cell_recording, constant_reduction, damping, damping_type, dent_tau_change, exponential, f1, f2, fish_emitter, fish_receiver, fish_morph_harmonics_var, lower_tol, mimick, n, phase_right, phaseshift_fr, sampling_factor, upper_tol, zeros = default_model0() # in case you want a different sampling here we can adujust time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr, stimulus_length) # generate 
the eod_fish_r in the four mimick variants (copy, thunderfish, mimick, just sinus) eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr, stimulus_length, phaseshift_fr, cell_recording, zeros, mimick, sampling, fish_receiver, deltat, nfft, nfft_for_morph, fish_morph_harmonics_var=fish_morph_harmonics_var, beat=beat) sampling = 1 / deltat variant = 'sinz' if exponential == '': pass # prepare for adapting offset due to baseline modification _, _ = prepare_baseline_array(time_array, eod_fr, nfft_for_morph, phaseshift_fr, mimick, zeros, cell_recording, sampling, stimulus_length, fish_receiver, deltat, nfft, damping_type, damping, us_name, gain, beat=beat, fish_morph_harmonics_var=fish_morph_harmonics_var) spikes_base = [[]] * trials_nr fig = plt.figure(figsize=figsize) grid = gridspec.GridSpec(1, 2, wspace=0.3, left=0.09, top=top, bottom=bottom, right=0.96, height_ratios=[1], width_ratios=[4, 2.8]) # 1.3,1 grid0 = gridspec.GridSpecFromSubplotSpec(3, 2, wspace=0.18, hspace=0.1, subplot_spec=grid[0], height_ratios=[1, 0.6, 1]) # ,0.4,1.2 grid1 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=grid[1]) # wspace=0.5, hspace=0.55, for run in range(runs): print(run) t1 = time.time() for t in range(trials_nr): stimulus = eod_fish_r stimulus_base = eod_fish_r if 'Zero' in titles_amp[a]: power_here = 'sinz' + '_' + zeros else: power_here = 'sinz' cvs, adapt_output, baseline_after_b, _, rate_adapted_b, rate_baseline_before_b, rate_baseline_after_b, \ spikes_base[t], _, _, offset_new, _, noise_final = simulate(cell, offset, stimulus, deltat=deltat, adaptation_variant=adapt_offset, adaptation_yes_j=f2, adaptation_yes_e=f1, adaptation_yes_t=t, power_variant=power_here, power_alpha=alpha, power_nr=n, reshuffle=reshuffled, **model_params) if t == 0: offset = offset_new * 1 if printing: print('Baseline time' + str(time.time() - t1)) base_cut, mat_base = find_base_fr(spikes_base, deltat, stimulus_length, time_array, dev=dev) if not fr: fr = 
np.mean(base_cut) if 'diagonal' in diagonal: two_third_fr = fr * freq2_ratio freq1_ratio = (1 - freq2_ratio) third_fr = fr * freq1_ratio else: two_third_fr = fr * freq2_ratio third_fr = fr * freq1_ratio if plus_q == 'minus': two_third_fr = -two_third_fr third_fr = -third_fr freqs2 = [eod_fr + two_third_fr] # , eod_fr - third_fr, two_third_fr, freqs1 = [ eod_fr + third_fr] # , eod_fr - two_third_fr, third_fr,two_third_fr,third_eodf, eod_fr - third_eodf,two_third_eodf, eod_fr - two_third_eodf, ] sampling_rate = 1 / deltat base_cut, mat_base, smoothed0, mat0 = find_base_fr2(spikes_base, deltat, stimulus_length, dev=dev) fr = np.mean(base_cut) _, _ = ISI_frequency(time_array, spikes_base[0], fill=0.0) isi = np.diff(spikes_base[0]) cv0 = np.std(isi) / np.mean(isi) for ff, freq1 in enumerate(freqs1): freq1 = [freq1] freq2 = [freqs2[ff]] print(cell + ' f1' + str(freq1) + ' f2 ' + str(freq2) + ' f1' + str( freq1 - eod_fr) + ' f2 ' + str(freq2 - eod_fr)) t1 = time.time() phaseshift_f1, phaseshift_f2 = get_phaseshifts(a_f1, a_f2, phase_right, phaseshift_fr) eod_fish1, time_fish_e = eod_fish_e_generation(time_array, a_f1, freq1, f1, phaseshift_f1, sampling, stimulus_length, nfft_for_morph, cell_recording, fish_morph_harmonics_var, zeros, mimick, fish_emitter, thistype='emitter') eod_fish2, time_fish_j = eod_fish_e_generation(time_array, a_f2, freq2, f2, phaseshift_f2, sampling, stimulus_length, nfft_for_morph, cell_recording, fish_morph_harmonics_var, zeros, mimick, fish_jammer, thistype='jammer') eod_stimulus = eod_fish1 + eod_fish2 v_mems, offset_new, mat01, mat02, mat012, smoothed01, smoothed02, smoothed012, stimulus_01, stimulus_02, stimulus_012, mat05_01, spikes_01, mat05_02, spikes_02, mat05_012, spikes_012 = get_arrays_for_three( cell, a_f2, a_f1, SAM, eod_stimulus, eod_fish_r, freq2, eod_fish1, eod_fish2, stimulus_length, offset, model_params, n, variant, adapt_offset, deltat, f2, trials_nr, time_array, f1, freq1, eod_fr, reshuffle=reshuffled, dev=dev, 
redo_stim=False) if printing: print('Generation process' + str(time.time() - t1)) array0 = [mat_base] array01 = [mat05_01] array02 = [mat05_02] array012 = [mat05_012] position_diff = 0 results_diff = pd.DataFrame() results_diff['f1'] = freq1 results_diff['f2'] = freq2 results_diff['f0'] = eod_fr trials, results_diff, tp_012_all, tp_01_all, tp_02_all, fp_all, roc_01, roc_0, roc_02, roc_012, threshhold = calc_auci_pd( results_diff, position_diff, array012, array01, array02, array0, t_off=t_off, way=way, printing=True, datapoints=datapoints, f0='f0', sampling=sampling) if run == 0: color = 'black' lw = 1.5 z = 2 else: color = 'grey' lw = 0.8 z = 1 if cont_redo: frame = pd.read_csv(save_name_roc) tp_012_all = frame['tp_012'] # = tp_012_all tp_01_all = frame['tp_01'] # = tp_01_all tp_02_all = frame['tp_02'] # = tp_02_all fp_all = frame['fp_all'] # = fp_all if 'wo_female' in female: ax_roc_wof = plt.subplot(grid1[0]) roc_female(ax_roc_wof, color, fp_all, tp_01_all, lw, color0, color01, title_color=color01, z=z) elif 'base_female' in female: ax_roc_wof = plt.subplot(grid1[0]) roc_female(ax_roc_wof, color, fp_all, tp_02_all, lw, color0, color02, z=z, add_01='\n Female', add_base=' Baseline') ax_roc_wof.set_title('Receiver Operating Characteristics (ROC)') else: ax_roc_wf = plt.subplot(grid1[0]) roc_wo_female(color, ax_roc_wf, tp_02_all, tp_012_all, color02, color012, title_color=color012, z=z) if run == 0: plt_traces_to_roc(freq2_ratio, freq1_ratio, t_off, spikes_02, spikes_01, spikes_012, spikes_base, mat_base, mat05_01, mat05_012, mat05_02, color02, color012, a_f2, trials, sampling, a_f1, fr, female, color01, color0, grid0, eod_fr, freq2, freq1, sampling_rate, stimulus_012, stimulus_02, stimulus_01, stimulus_base, time_array, carrier=True) ax = fig.axes ax[0 + 1].set_ylabel('Amplitude') ax[2 + 1].set_ylabel('Trials') ax[4 + 1].set_ylabel('Firing Rate [Hz]') ax[3 + 2].set_xlabel('Time [ms]') ax[4 + 2].set_xlabel('Time [ms]') for aa, ax_here in enumerate(ax[2:5]): 
ax_here.set_xticks([]) for aa, ax_here in enumerate(ax[1::]): if aa not in np.arange(0, 100, 2): pass else: ax_here.get_shared_y_axes().join(*ax[1 + aa:1 + aa + 2]) fig.tag([ax[1], ax[2], ax[0]], xoffs=-4.6, yoffs=1.5) plt.subplots_adjust(top=0.95, left=0.09, right=0.95, hspace=0.5, bottom=0.12, wspace=0.25) individual_tag = '_way_' + str(way) + '_runs_' + str(runs) + '_trial_nr_' + str( trials_nr) + '_stimulus_length_' + str( stimulus_length) + cell + ' cv ' + str(cv0) + single_wave + '_a_f0_' + str( a_fr) + '_a_f1_' + str(a_f1) + '_a_f2_' + str(a_f2) + '_trialsnr_' + str(trials_nr) save_visualization(individual_tag, show=show, add=add_name, pdf=pdf, counter_contrast=0, savename='') def roc_wo_female(color, ax_roc_wf, tp_02_all, tp_012_all, color02, color012, add_01='\n Intruder + Female', z=2, add_base=' Female', color_e='grey', title_color='black'): ax_roc_wf.set_title(r'With Female: ROC\ensuremath{\rm{_{Female}}}', color=title_color) # linewidth=lw,'With Female' ax_roc_wf.plot(tp_02_all, tp_012_all, color=color, zorder=z, clip_on=False) # , aspect = 'auto' ax_roc_wf.set_aspect('equal') if color_e: ax_roc_wf.plot([0, 1], [0, 1], color=color_e, linestyle='--') ax_roc_wf.set_xlabel('False-Positive Rate: ' + add_base, color=color02) ax_roc_wf.set_ylabel('Correct-Detection Rate: ' + add_01, color=color012) def roc_female(ax_roc_wof, color, fp_all, tp_01_all, lw, color0, color01, color_e='grey', z=2, title_color='black', add_01='\n Intruder', add_base=' Baseline'): ax_roc_wof.set_title(r'Without Female: ROC\ensuremath{\rm{_{NoFemale}}}', color=title_color) # 'Without Female' ax_roc_wof.plot(fp_all, tp_01_all, color=color, linewidth=lw, zorder=z, clip_on=False) # , aspect = 'auto' if color_e: ax_roc_wof.plot([0, 1], [0, 1], color=color_e, linestyle='--', clip_on=False) ax_roc_wof.set_aspect('equal') ax_roc_wof.set_xlabel('False-Positive Rate: ' + add_base, color=color0) #: 0 ax_roc_wof.set_ylabel('Correct-Detection Rate: ' + add_01, color=color01) # 01 def 
c_to_dist_reverse(distance, power=2.09, factor=12.23): c_changed = factor / distance ** power return c_changed def Pl_model(freqs=[(39.5, -210.5)], printing=False, beat='', nfft_for_morph=4096 * 4, gain=1, freq_mult=False, cells_here=[], fish_jammer='Alepto', us_name='', show=True, c_nrs_orig=[0.03, 0.2, 0.8]): # "2013-01-08-aa-invivo-1" runs = 1 n = 1 dev = 0.0005 reshuffled = 'reshuffled' # , # standard combination with intruder small a_f2s = [0.1] min_amps = '_minamps_' dev_name = ['05'] model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv") if len(cells_here) < 1: cells_here = np.array(model_cells.cell) a_fr = 1 a = 0 trials_nrs = [5] datapoints = 1000 stimulus_length = 2 results_diff = pd.DataFrame() position_diff = 0 plot_style() default_figsize(column=2, length=6) # 6.4 default_figsize(width=cm_to_inch(33.6), length=cm_to_inch(15.2)) default_ticks_talks() for trials_nr in trials_nrs: # +[trials_nrs[-1]] # sachen die ich variieren will ########################################### auci_wo = [] auci_w = [] nfft = 32768 full_names = [ 'calc_model_amp_freqs-F1_750-975-75_F2_500-725-75_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_1.0_mult__start_0.0001_end_1_StimLen_25_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal'] frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv') for cell_here in cells_here: c_grouped = ['c1'] # , 'c2'] frame_cell_orig = frame[(frame.cell == cell_here)] if freq_mult: freqs = freq_two_mult_recalc(frame_cell_orig, freqs) if len(frame_cell_orig) > 0: print('cell there') try: pass except: print('min thing') embed() get_frame_cell_params(c_grouped, cell_here, frame, frame_cell_orig) grid0 = gridspec.GridSpec(1, 1, bottom=0.2, top=0.8, left=0.115, right=0.95, wspace=0.04) # grid00 = gridspec.GridSpecFromSubplotSpec(1, 2, wspace=0.4, hspace=0.1, width_ratios = [2, 1], subplot_spec=grid0[0]) # height_ratios=[2,1], grid_ll = 
gridspec.GridSpecFromSubplotSpec(1, len(c_nrs_orig), hspace=0.35, wspace=0.2, subplot_spec=grid00[0]) # height_ratios=[1, 0.8],hspace=0.4,wspace=0.2,len(chirps) #grid_rr = gridspec.GridSpecFromSubplotSpec(2, 1, # wspace=0.04, hspace=0.1, # subplot_spec=grid0[1]) # height_ratios=[2,1], ################################################################# # calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_20_FirstC1_0.0001_LastC1_1.0_StimLen_25_nfft_32768_trialsnr_1_absolut_power_1temporal.csv # devs_extra = ['stim','stim_rec','stim_am','original','05']#['original','05'] # da implementiere ich das jetzt für eine Zelle # wo wir den einezlnen Punkt und Kontraste variieren f_counter = 0 frame_cell_orig, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig) eodf = frame_cell_orig.f0.unique()[0] f = -1 axts_all = [] axps_all = [] ax_us = [] for freq1, freq2 in freqs: f += 1 frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)] if len(frame_cell) < 1: freq1 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df1 - freq1)))].df1 freq2 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df2 - freq2)))].df2 frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)] color_stim = color_stim_core() color_eodf = coloer_eod_fr_core() print(cell_here + ' F1' + str(freq1) + ' F2 ' + str(freq2)) sampling = 20000 try: ax_u1 = plt.subplot(grid00[1]) except: print('grid search problem5') embed() add = get_mean_add(frame_cell_orig) #_original labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept( add=add) scores = ['amp_B1_01' + add, 'amp_f1_01' + add, 'amp_f0_01' + add, ] # 'amp_B1+B2_012_mean', labels = labels_pi_core() alpha = [1, 1, 1] c_dist_recalc = dist_recalc_phaselockingchapter() ax_us = plt_single_trace(ax_us, ax_u1, frame_cell_orig, freq1, freq2, scores=scores, colors=[color01, coloer_eod_fr_core(), 
color_stim_core()], linestyles=['-', '-', '-'], alpha=alpha, labels=labels, sum=False, B_replace='F', default_colors=False, c_dist_recalc=c_dist_recalc, delta=False) ax_u1.set_xlabel('Contrast$_{' + vary_val() + '}$ [$\%$]') if f != 0: print('hi') else: ax_u1.set_ylabel(representation_ylabel(delta=False)) # power_spectrum_name() axts = [] axps = [] axes = [] recalc = 100 c_nrs = c_dist_recalc_func(frame_cell, c_nrs=c_nrs_orig, cell=cell_here, c_dist_recalc=c_dist_recalc, recalc_contrast_in_perc=recalc) mults_period = 3 xlim = [1000, 1000 + (mults_period * 1000 / np.min([np.abs(freq1), np.abs(freq2)]))] letters = ['A', 'B', 'C'] height = 240 for c_nn, c_nr in enumerate(c_nrs): ax_u1.scatter(c_nrs, height * np.ones(len(c_nrs)), color='black', marker='v', clip_on=False, s=7) ax_u1.text(c_nr, height + 15, letters[c_nn], ha='center', va='center', color='black') ax_u1.plot([c_nr, c_nr], [0, height], color='black', linewidth=0.05, clip_on=False) ax_u1.set_ylim(0, 285) v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays_original, names, p_arrays, ff = calc_roc_amp_core_cocktail( [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff, a_f2s, fish_jammer, trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing, stimulus_length, model_cells, position_diff, 'original', cell_here, dev_name=dev_name, a_f1s=[c_nrs_orig[c_nn]], n=n, reshuffled=reshuffled, min_amps=min_amps) v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays, names, _, ff = calc_roc_amp_core_cocktail( [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff, a_f2s, fish_jammer, trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing, stimulus_length, model_cells, position_diff, dev, cell_here, dev_name=dev_name, a_f1s=[c_nrs_orig[c_nn]], n=n, reshuffled=reshuffled, min_amps=min_amps) time = np.arange(0, len(arrays[a][0]) / sampling, 1 / sampling) time = time * 1000 
####################### # plot the first array arrays_here, arrays_sp, arrays_st, arrays_time = choose_arrays_phaselocking(arrays, arrays_spikes, arrays_stim, choice='01') colors_array_here = ['grey', 'grey', 'grey'] # colors_array[1::] p_arrays_here = p_arrays[1::] for a in range(len(arrays_here)): print('a' + str(a)) eodf = frame_cell.f0.iloc[0] f1 = frame_cell.f1.iloc[0] colors_peaks = [color01, color_stim, color_eodf] # , 'red'] freqs_psd = [np.abs(freq1), f1, eodf] grid_pt = gridspec.GridSpecFromSubplotSpec(6, 1, hspace=0.3, wspace=0.2, subplot_spec=grid_ll[a, c_nn], height_ratios=[1, 0.7, 0, 1, 0.25, 2.2 ]) # .2 hspace=0.4,wspace=0.2,len(chirps) axe = plt.subplot(grid_pt[0]) axes.append(axe) plt_stim_saturation(a, [], arrays_st, axe, colors_array_here, f, f_counter, names, time, xlim=xlim) # np.array(arrays_sp)*1000 a_f2_cm = c_dist_recalc_func(frame_cell, c_nrs=[a_f2s[0]], cell=cell_here) if a == 2: # if (a_f1s[0] != 0) & (a_f2s[0] != 0): title_name = ' $c_{' + vary_val() + '}=%s\% c' + stable_val() + '=' % ( ((int(np.round(a_f2_cm[0]))), int(np.round(c_nrs[c_nn])))) # + '$\%$'str( elif a == 0: # elif (a_f1s[0] != 0):_{p}_{s} title_name = ' $c_{' + vary_val() + '}=%s$' % int(np.round( c_nrs[c_nn])) + '\,$\%$, ' + '\n $\Delta f_{' + vary_val() + '}= %s$\,Hz' % ( int(freq1)) # str() #+ '$\%$' elif a == 1: # elif (a_f2s[0] != 0): title_name = ' $c_{' + vary_val() + '}=%s$' % int( np.round(a_f2_cm[0])) + '\,$\%$, ' + ' $\Delta f_{' + vary_val() + '}= %s$\,Hz' % ( int(freq1)) # str() axe.text(1, 1, title_name, va='bottom', ha='right', transform=axe.transAxes) axs = plt.subplot(grid_pt[1]) plt_spikes_ROC(axs, 'grey', np.array(arrays_sp[a]) * 1000, xlim) ############################# axt = plt.subplot(grid_pt[3]) axts.append(axt) plt_vmem_saturation(a, arrays_sp, arrays_time, axt, colors_array_here, f, time, xlim=xlim) ############################# axp = plt.subplot(grid_pt[5]) axps.append(axp) log = '' # 'log' # 'log' maxx = eodf * 1.15 # 5 pp = log_calc_psd(log, 
p_arrays_here[a][0], np.nanmax(p_arrays_here)) freqs_peaks1, colors_peaks1, labels1, alphas1 = chose_all_freq_combos(freq2, colors_array, freq1, maxx, eodf, color_eodf=coloer_eod_fr_core(), name='01', stim_thing=False, color_stim=color_stim_core(), color_stim_mult=color_stim_core()) plt_peaks_several(freqs_peaks1, [pp], axp, pp, ff, labels1, 0, colors_peaks1, limit=10000, alphas=alphas1, ms=25, clip_on=False) plt_psd_saturation(pp, ff, a, axp, colors_array_here, freqs=freqs_psd, colors_peaks=colors_peaks, xlim=(0, maxx)) if log: axp.show_spines('b') if a == 0: axp.yscalebar(-0.05, 0.5, 20, 'dB', va='center', ha='left') else: axp.show_spines('lb') if c_nn != 0: remove_yticks(axp) else: axp.set_ylabel(power_spectrum_name()) if a == 0: axt.show_spines('') if c_nn == 0: axt.xscalebar(0.3, -0.1, 5, 'ms', va='right', ha='bottom') axt.yscalebar(-0.02, 0.35, 600, 'Hz', va='left', ha='top') axp.set_xlabel('Frequency [Hz]') ############################# isis = False if isis: axi = plt.subplot(grid_pt[-1]) isis = [] for t in range(len(arrays_sp[a])): isi = calc_isi(arrays_sp[a][t], eodf) isis.append(isi) axi.hist(np.concatenate(isis), bins=100, color='grey') axi.set_xlabel(isi_xlabel()) axi.show_spines('b') f_counter += 1 axts_all.extend(axts) axps_all.extend(axps) ax_us[0].legend(ncol=1, loc=(-0.25, 1.01), columnspacing=2.5) # 5 -0.07#loc=(0.9, 0.7) axts_all[0].get_shared_y_axes().join(*axts_all) axts_all[0].get_shared_x_axes().join(*axts_all) axps_all[0].get_shared_y_axes().join(*axps_all) axps_all[0].get_shared_x_axes().join(*axps_all) join_y(axts) set_same_ylim(axts) set_same_ylim(axps) join_x(axts) join_x(ax_us) join_y(ax_us) fig = plt.gcf() #fig.tag([axes[0], axes[1], axes[2]], xoffs=-3, yoffs=3) #fig.tag([ax_u1], xoffs=-3, yoffs=3) save_visualization(cell_here, show) print('finished cell here') def get_mean_add(frame_cell_orig): if 'amp_B1_01_mean_original' in frame_cell_orig.keys(): add = '_mean_original' else: add = '_mean' return add def 
vary_contrasts50(freqs=[(39.5, -210.5)], printing=False, beat='', nfft_for_morph=4096 * 4, gain=1, freq_mult=False, cells_here=[], fish_jammer='Alepto', us_name='', show=True, c_nrs_orig=[0.03, 0.2, 0.8]): # "2013-01-08-aa-invivo-1" runs = 1 n = 1 dev = 0.0005 reshuffled = 'reshuffled' # , # standard combination with intruder small a_f2s = [0.1] min_amps = '_minamps_' dev_name = ['05'] model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv") if len(cells_here) < 1: cells_here = np.array(model_cells.cell) a_fr = 1 a = 0 trials_nrs = [5] datapoints = 1000 stimulus_length = 2 results_diff = pd.DataFrame() position_diff = 0 plot_style() default_figsize(column=2, length=6) # 6.4 for trials_nr in trials_nrs: # +[trials_nrs[-1]] # sachen die ich variieren will ########################################### auci_wo = [] auci_w = [] nfft = 32768 full_names = [ 'calc_model_amp_freqs-F1_750-975-75_F2_500-725-75_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_1.0_mult__start_0.0001_end_1_StimLen_25_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal'] frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv') for cell_here in cells_here: c_grouped = ['c1'] # , 'c2'] frame_cell_orig = frame[(frame.cell == cell_here)] if freq_mult: freqs = freq_two_mult_recalc(frame_cell_orig, freqs) if len(frame_cell_orig) > 0: print('cell there') try: pass except: print('min thing') embed() get_frame_cell_params(c_grouped, cell_here, frame, frame_cell_orig) grid0 = gridspec.GridSpec(1, 1, bottom=0.08, top=0.96, left=0.115, right=0.95, wspace=0.04) # grid00 = gridspec.GridSpecFromSubplotSpec(1, 1, wspace=0.04, hspace=0.1, subplot_spec=grid0[0]) # height_ratios=[2,1], grid_ll = gridspec.GridSpecFromSubplotSpec(2, len(c_nrs_orig), hspace=0.35, wspace=0.2, height_ratios=[1, 0.8], subplot_spec=grid00[0]) # hspace=0.4,wspace=0.2,len(chirps) ################################################################# # 
calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_20_FirstC1_0.0001_LastC1_1.0_StimLen_25_nfft_32768_trialsnr_1_absolut_power_1temporal.csv # devs_extra = ['stim','stim_rec','stim_am','original','05']#['original','05'] # da implementiere ich das jetzt für eine Zelle # wo wir den einezlnen Punkt und Kontraste variieren f_counter = 0 frame_cell_orig, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig) eodf = frame_cell_orig.f0.unique()[0] f = -1 axts_all = [] axps_all = [] ax_us = [] for freq1, freq2 in freqs: f += 1 frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)] if len(frame_cell) < 1: freq1 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df1 - freq1)))].df1 freq2 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df2 - freq2)))].df2 frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)] color_stim = color_stim_core() color_eodf = coloer_eod_fr_core() print(cell_here + ' F1' + str(freq1) + ' F2 ' + str(freq2)) sampling = 20000 try: ax_u1 = plt.subplot(grid_ll[-1, :]) except: print('grid search problem2') embed() if 'amp_B1_01_mean_original' in frame_cell_orig.keys(): add = '_mean_original' else: add = '_mean' labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept( add=add) scores = ['amp_B1_01' + add, 'amp_f1_01' + add, 'amp_f0_01' + add, ] # 'amp_B1+B2_012_mean', labels = labels_pi_core() alpha = [1, 1, 1] c_dist_recalc = dist_recalc_phaselockingchapter() ax_us = plt_single_trace(ax_us, ax_u1, frame_cell_orig, freq1, freq2, scores=scores, colors=[color01, coloer_eod_fr_core(), color_stim_core()], linestyles=['-', '-', '-'], alpha=alpha, labels=labels, sum=False, B_replace='F', default_colors=False, c_dist_recalc=c_dist_recalc, delta=False) ax_u1.set_xlabel('Contrast$_{' + vary_val() + '}$ [$\%$]') if f != 0: print('hi') else: ax_u1.set_ylabel(representation_ylabel(delta=False)) # 
power_spectrum_name() axts = [] axps = [] axes = [] recalc = 100 c_nrs = c_dist_recalc_func(frame_cell, c_nrs=c_nrs_orig, cell=cell_here, c_dist_recalc=c_dist_recalc, recalc_contrast_in_perc=recalc) mults_period = 3 xlim = [1000, 1000 + (mults_period * 1000 / np.min([np.abs(freq1), np.abs(freq2)]))] letters = ['A', 'B', 'C'] height = 240 for c_nn, c_nr in enumerate(c_nrs): ax_u1.scatter(c_nrs, height * np.ones(len(c_nrs)), color='black', marker='v', clip_on=False, s=7) ax_u1.text(c_nr, height + 15, letters[c_nn], ha='center', va='center', color='black') ax_u1.plot([c_nr, c_nr], [0, height], color='black', linewidth=0.05, clip_on=False) ax_u1.set_ylim(0, 285) v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays_original, names, p_arrays, ff = calc_roc_amp_core_cocktail( [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff, a_f2s, fish_jammer, trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing, stimulus_length, model_cells, position_diff, 'original', cell_here, dev_name=dev_name, a_f1s=[c_nrs_orig[c_nn]], n=n, reshuffled=reshuffled, min_amps=min_amps) v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays, names, _, ff = calc_roc_amp_core_cocktail( [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff, a_f2s, fish_jammer, trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing, stimulus_length, model_cells, position_diff, dev, cell_here, dev_name=dev_name, a_f1s=[c_nrs_orig[c_nn]], n=n, reshuffled=reshuffled, min_amps=min_amps) time = np.arange(0, len(arrays[a][0]) / sampling, 1 / sampling) time = time * 1000 ####################### # plot the first array arrays_here, arrays_sp, arrays_st, arrays_time = choose_arrays_phaselocking(arrays, arrays_spikes, arrays_stim, choice='01') colors_array_here = ['grey', 'grey', 'grey'] # colors_array[1::] p_arrays_here = p_arrays[1::] for a in range(len(arrays_here)): print('a' + 
str(a)) eodf = frame_cell.f0.iloc[0] f1 = frame_cell.f1.iloc[0] colors_peaks = [color01, color_stim, color_eodf] # , 'red'] freqs_psd = [np.abs(freq1), f1, eodf] grid_pt = gridspec.GridSpecFromSubplotSpec(6, 1, hspace=0.3, wspace=0.2, subplot_spec=grid_ll[a, c_nn], height_ratios=[1, 0.7, 0, 1, 0.25, 2.2 ]) # .2 hspace=0.4,wspace=0.2,len(chirps) axe = plt.subplot(grid_pt[0]) axes.append(axe) plt_stim_saturation(a, [], arrays_st, axe, colors_array_here, f, f_counter, names, time, xlim=xlim) # np.array(arrays_sp)*1000 a_f2_cm = c_dist_recalc_func(frame_cell, c_nrs=[a_f2s[0]], cell=cell_here) if a == 2: # if (a_f1s[0] != 0) & (a_f2s[0] != 0): title_name = ' $c_{' + vary_val() + '}=%s\% c' + stable_val() + '=' % ( ((int(np.round(a_f2_cm[0]))), int(np.round(c_nrs[c_nn])))) # + '$\%$'str( elif a == 0: # elif (a_f1s[0] != 0):_{p}_{s} title_name = ' $c_{' + vary_val() + '}=%s$' % int(np.round( c_nrs[c_nn])) + '\,$\%$, ' + ' $\Delta f_{' + vary_val() + '}= %s$\,Hz' % ( int(freq1)) # str() #+ '$\%$' elif a == 1: # elif (a_f2s[0] != 0): title_name = ' $c_{' + vary_val() + '}=%s$' % int( np.round(a_f2_cm[0])) + '\,$\%$, ' + ' $\Delta f_{' + vary_val() + '}= %s$\,Hz' % ( int(freq1)) # str() axe.text(1, 1, title_name, va='bottom', ha='right', transform=axe.transAxes) axs = plt.subplot(grid_pt[1]) plt_spikes_ROC(axs, 'grey', np.array(arrays_sp[a]) * 1000, xlim) ############################# axt = plt.subplot(grid_pt[3]) axts.append(axt) plt_vmem_saturation(a, arrays_sp, arrays_time, axt, colors_array_here, f, time, xlim=xlim) ############################# axp = plt.subplot(grid_pt[5]) axps.append(axp) log = '' # 'log' # 'log' maxx = eodf * 1.15 # 5 pp = log_calc_psd(log, p_arrays_here[a][0], np.nanmax(p_arrays_here)) freqs_peaks1, colors_peaks1, labels1, alphas1 = chose_all_freq_combos(freq2, colors_array, freq1, maxx, eodf, color_eodf=coloer_eod_fr_core(), name='01', stim_thing=False, color_stim=color_stim_core(), color_stim_mult=color_stim_core()) plt_peaks_several(freqs_peaks1, 
[pp], axp, pp, ff, labels1, 0, colors_peaks1, limit=10000, alphas=alphas1, ms=25, clip_on=False) plt_psd_saturation(pp, ff, a, axp, colors_array_here, freqs=freqs_psd, colors_peaks=colors_peaks, xlim=(0, maxx)) if log: axp.show_spines('b') if a == 0: axp.yscalebar(-0.05, 0.5, 20, 'dB', va='center', ha='left') else: axp.show_spines('lb') if c_nn != 0: remove_yticks(axp) else: axp.set_ylabel(power_spectrum_name()) if a == 0: axt.show_spines('') axt.xscalebar(0.1, -0.1, 5, 'ms', va='right', ha='bottom') axt.yscalebar(-0.02, 0.35, 600, 'Hz', va='left', ha='top') axp.set_xlabel('Frequency [Hz]') ############################# isis = False if isis: axi = plt.subplot(grid_pt[-1]) isis = [] for t in range(len(arrays_sp[a])): isi = calc_isi(arrays_sp[a][t], eodf) isis.append(isi) axi.hist(np.concatenate(isis), bins=100, color='grey') axi.set_xlabel(isi_xlabel()) axi.show_spines('b') f_counter += 1 axts_all.extend(axts) axps_all.extend(axps) ax_us[0].legend(ncol=3, loc=(0, 1), columnspacing=2.5) # 5 -0.07#loc=(0.9, 0.7) axts_all[0].get_shared_y_axes().join(*axts_all) axts_all[0].get_shared_x_axes().join(*axts_all) axps_all[0].get_shared_y_axes().join(*axps_all) axps_all[0].get_shared_x_axes().join(*axps_all) join_y(axts) set_same_ylim(axts) set_same_ylim(axps) join_x(axts) join_x(ax_us) join_y(ax_us) fig = plt.gcf() fig.tag([axes[0], axes[1], axes[2]], xoffs=-3, yoffs=1) fig.tag([ax_u1], xoffs=-3, yoffs=3) save_visualization(cell_here, show) print('finished cell here') def color_stim_core(): return 'grey' def coloer_eod_fr_core(): return 'black' def labels_pi_core(): labels = [DF_pi_core(), f_pi_core(), f_eod_pi_core(), ] # ('+f_eod_name_core_rm()+' + f_{p})$('+f_eod_name_core_rm()+' + f_{p}) ('+f_eod_name_core_rm()+' + f_{p}) return labels def vary_contrasts5(freqs=[(39.5, -210.5)], printing=False, beat='', nfft_for_morph=4096 * 4, gain=1, freq_mult=False, cells_here=[], fish_jammer='Alepto', us_name='', show=True, c_nrs_orig=[0.05, 0.1, 0.8]): # "2013-01-08-aa-invivo-1" 
runs = 1 n = 1 dev = 0.0005 ############################################# # plot a single ROC Curve for the model! # das aus dem Lissabon talk und das was wir für Jörg verwenden werden # also wir wollen hier viele Kontraste und einige Frequenzen # das will ich noch für verschiedene Frequenzen und Kontraste default_settings() # ts=13, ls=13, fs=13, lw = 0.7 reshuffled = 'reshuffled' # , # standard combination with intruder small a_f2s = [0.1] min_amps = '_minamps_' dev_name = ['05'] model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv") if len(cells_here) < 1: cells_here = np.array(model_cells.cell) a_fr = 1 a = 0 trials_nrs = [5] datapoints = 1000 stimulus_length = 2 results_diff = pd.DataFrame() position_diff = 0 plot_style() default_settings(column=2, length=6.5) for trials_nr in trials_nrs: # +[trials_nrs[-1]] # sachen die ich variieren will ########################################### auci_wo = [] auci_w = [] nfft = 32768 full_names = [ 'calc_model_amp_freqs-F1_750-975-75_F2_500-725-75_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_1.0_mult__start_0.0001_end_1_StimLen_25_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal'] frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv') for cell_here in cells_here: c_grouped = ['c1'] # , 'c2'] frame_cell_orig = frame[(frame.cell == cell_here)] if freq_mult: freqs = freq_two_mult_recalc(frame_cell_orig, freqs) if len(frame_cell_orig) > 0: print('cell there') try: pass except: print('min thing') embed() get_frame_cell_params(c_grouped, cell_here, frame, frame_cell_orig) grid0 = gridspec.GridSpec(1, 1, bottom=0.08, top=0.92, left=0.11, right=0.95, wspace=0.04) # grid00 = gridspec.GridSpecFromSubplotSpec(1, 1, wspace=0.04, hspace=0.1, subplot_spec=grid0[0]) # height_ratios=[2,1], grid_ll = gridspec.GridSpecFromSubplotSpec(2, len(c_nrs_orig), hspace=0.35, wspace=0.2, height_ratios=[1, 0.8], subplot_spec=grid00[0]) # 
hspace=0.4,wspace=0.2,len(chirps) ################################################################# # calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_20_FirstC1_0.0001_LastC1_1.0_StimLen_25_nfft_32768_trialsnr_1_absolut_power_1temporal.csv # devs_extra = ['stim','stim_rec','stim_am','original','05']#['original','05'] # da implementiere ich das jetzt für eine Zelle # wo wir den einezlnen Punkt und Kontraste variieren f_counter = 0 frame_cell_orig, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig) eodf = frame_cell_orig.f0.unique()[0] f = -1 axts_all = [] axps_all = [] ax_us = [] for freq1, freq2 in freqs: f += 1 frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)] if len(frame_cell) < 1: freq1 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df1 - freq1)))].df1 freq2 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df2 - freq2)))].df2 # frame_cell = frame_cell_orig[ == freq1 & ] frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)] color_stim = 'grey' color_eodf = 'black' print(cell_here + ' F1' + str(freq1) + ' F2 ' + str(freq2)) sampling = 20000 try: ax_u1 = plt.subplot(grid_ll[-1, :]) except: print('grid search problem3') embed() if 'amp_B1_01_mean_original' in frame_cell_orig.keys(): add = '_mean_original' else: add = '_mean' labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept( add=add) scores = ['amp_B1_01' + add, 'amp_f0_01' + add, 'amp_f1_01' + add, ] # 'amp_B1+B2_012_mean', labels = ['$\Delta f_{p}$ peak in ' + onebeat_cond() + ' $('+f_eod_name_core_rm()+' + f_{p})$', '$'+f_eod_name_core_rm()+'$ peak in ' + onebeat_cond() + ' $('+f_eod_name_core_rm()+' + f_{p})$', '$f_{p}$ peak in ' + onebeat_cond() + ' $('+f_eod_name_core_rm()+' + f_{p})$', ] alpha = [1, 1, 1] c_dist_recalc = dist_recalc_phaselockingchapter() ax_us = plt_single_trace(ax_us, ax_u1, frame_cell_orig, freq1, 
freq2, scores=scores, colors=[color01, 'black', 'grey'], linestyles=linestyles, alpha=alpha, labels=labels, sum=False, B_replace='F', default_colors=False, c_dist_recalc=c_dist_recalc) if f != 0: print('hi') else: ax_u1.set_ylabel(power_spectrum_name()) plt.suptitle(' $\Delta f_{p}= %s $ Hz' % (int(freq1))) # + cell_here + ' DF2=' + str(freq2) axts = [] axps = [] axes = [] recalc = 100 c_nrs = c_dist_recalc_func(frame_cell, c_nrs=c_nrs_orig, cell=cell_here, c_dist_recalc=c_dist_recalc, recalc_contrast_in_perc=recalc) mults_period = 3 xlim = [1000, 1000 + (mults_period * 1000 / np.min([np.abs(freq1), np.abs(freq2)]))] for c_nn, c_nr in enumerate(c_nrs): ax_u1.scatter(c_nrs, np.zeros(len(c_nrs)), color='black', marker='^', clip_on=False) v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays_original, names, p_arrays, ff = calc_roc_amp_core_cocktail( [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff, a_f2s, fish_jammer, trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing, stimulus_length, model_cells, position_diff, 'original', cell_here, dev_name=dev_name, a_f1s=[c_nrs_orig[c_nn]], n=n, reshuffled=reshuffled, min_amps=min_amps) v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays, names, _, ff = calc_roc_amp_core_cocktail( [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff, a_f2s, fish_jammer, trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing, stimulus_length, model_cells, position_diff, dev, cell_here, dev_name=dev_name, a_f1s=[c_nrs_orig[c_nn]], n=n, reshuffled=reshuffled, min_amps=min_amps) time = np.arange(0, len(arrays[a][0]) / sampling, 1 / sampling) time = time * 1000 # plot the first array arrays_here, arrays_sp, arrays_st, arrays_time = choose_arrays_phaselocking(arrays, arrays_spikes, arrays_stim, choice='01') colors_array_here = ['grey', 'grey', 'grey'] # colors_array[1::] p_arrays_here = 
p_arrays[1::] for a in range(len(arrays_here)): print('a' + str(a)) eodf = frame_cell.f0.iloc[0] f1 = frame_cell.f1.iloc[0] colors_peaks = [color01, color_stim, color_eodf] # , 'red'] freqs_psd = [np.abs(freq1), f1, eodf] grid_pt = gridspec.GridSpecFromSubplotSpec(8, 1, hspace=0.3, wspace=0.2, subplot_spec=grid_ll[a, c_nn], height_ratios=[1, 0.7, 0, 1, 0.25, 2.2, 1, 1.2]) # hspace=0.4,wspace=0.2,len(chirps) axe = plt.subplot(grid_pt[0]) axes.append(axe) plt_stim_saturation(a, [], arrays_st, axe, colors_array_here, f, f_counter, names, time, xlim=xlim) # np.array(arrays_sp)*1000 a_f2_cm = c_dist_recalc_func(frame_cell, c_nrs=[a_f2s[0]], cell=cell_here) if a == 2: # if (a_f1s[0] != 0) & (a_f2s[0] != 0): title_name = ' c$_{p}=%s\% c2=' % ( ((int(np.round(a_f2_cm[0]))), int(np.round(c_nrs[c_nn])))) # + '$\%$'str( elif a == 0: # elif (a_f1s[0] != 0): title_name = ' c$_{p}=%s$' % int(np.round(c_nrs[c_nn])) + '$\%$' # str() #+ '$\%$' elif a == 1: # elif (a_f2s[0] != 0): title_name = ' $c2=%s$' % int(np.round(a_f2_cm[0])) + '$\%$' # str() axe.text(1, 1, title_name, va='bottom', ha='right', transform=axe.transAxes) ############################# axs = plt.subplot(grid_pt[1]) plt_spikes_ROC(axs, 'grey', np.array(arrays_sp[a]) * 1000, xlim) ############################# axt = plt.subplot(grid_pt[3]) axts.append(axt) plt_vmem_saturation(a, arrays_sp, arrays_time, axt, colors_array_here, f, time, xlim=xlim) ############################# axp = plt.subplot(grid_pt[-3]) axps.append(axp) log = '' # 'log' maxx = eodf * 1.15 # 5 pp = log_calc_psd(log, p_arrays_here[a][0], np.nanmax(p_arrays_here)) plt_psd_saturation(pp, ff, a, axp, colors_array_here, freqs=freqs_psd, colors_peaks=colors_peaks, xlim=(0, maxx)) if log: axp.show_spines('b') if a == 0: axp.yscalebar(-0.05, 0.5, 20, 'dB', va='center', ha='left') else: axp.show_spines('lb') if c_nn != 0: remove_yticks(axp) else: axp.set_ylabel(power_spectrum_name()) if a == 0: axt.show_spines('') axt.xscalebar(0.1, -0.1, 5, 'ms', 
va='right', ha='bottom') axt.yscalebar(-0.02, 0.35, 600, 'Hz', va='left', ha='top') axp.set_xlabel('Frequency [Hz]') freqs_peaks, colors_peaks, labels, alphas = chose_all_freq_combos(freq2, colors_array, freq1, maxx, eodf, color_eodf='black', name='01', color_stim='grey', color_stim_mult='grey') plt_peaks_several(freqs_peaks, [pp], axp, pp, ff, labels, 0, colors_peaks, limit=10000, alphas=alphas, ms=18, clip_on=False) ############################# axi = plt.subplot(grid_pt[-1]) isis = [] for t in range(len(arrays_sp[a])): isi = calc_isi(arrays_sp[a][t], eodf) isis.append(isi) axi.hist(np.concatenate(isis), bins=100, color='grey') axi.set_xlabel(isi_xlabel()) axi.show_spines('b') f_counter += 1 axts_all.extend(axts) axps_all.extend(axps) ax_us[0].legend(ncol=3, loc=(0, 1)) # -0.07#loc=(0.9, 0.7) axts_all[0].get_shared_y_axes().join(*axts_all) axts_all[0].get_shared_x_axes().join(*axts_all) axps_all[0].get_shared_y_axes().join(*axps_all) axps_all[0].get_shared_x_axes().join(*axps_all) join_y(axts) set_same_ylim(axts) set_same_ylim(axps) join_x(axts) join_x(ax_us) join_y(ax_us) fig = plt.gcf() fig.tag([axes[0], axes[1], axes[2]], xoffs=-3, yoffs=1) fig.tag([ax_u1], xoffs=-3, yoffs=3) save_visualization(cell_here, show) print('finished cell here') def power_spectrum_name(): return 'Power [Hz]' def choose_arrays_phaselocking(arrays, arrays_spikes, arrays_stim, choice='all'): if choice == 'all': arrays_time = arrays[1::] # [v_mems[1],v_mems[3]]#[1,2]#[1::] arrays_here = arrays[1::] # [arrays[1],arrays[3]]#arrays[1::]# arrays_st = arrays_stim[1::] # [arrays_stim[1],arrays_stim[3]]# arrays_sp = arrays_spikes[1::] # [arrays_spikes[1],arrays_spikes[3]]#arrays_spikes[1::] elif choice == '01': arrays_time = [arrays[1]] # [v_mems[1],v_mems[3]]#[1,2]#[1::] arrays_here = [arrays[1]] # [arrays[1],arrays[3]]#arrays[1::]# arrays_st = [arrays_stim[1]] # [arrays_stim[1],arrays_stim[3]]# arrays_sp = [arrays_spikes[1]] # [arrays_spikes[1],arrays_spikes[3]]#arrays_spikes[1::] return 
arrays_here, arrays_sp, arrays_st, arrays_time def f_vary_name(delta=False, freq=None): if delta: val = '\ensuremath{\Delta f_{1}}' else: val = '\ensuremath{f_{1}}' if freq: val = '$' + val + '=%s$' % freq + '\,Hz' return val def f_stable_name(freq=None, delta=False): if delta: val = '\ensuremath{\Delta f_{2}}' else: val = '\ensuremath{f_{2}}' if freq: val = '$' + val + '=%s$' % freq + '\,Hz' return val def strong_signals(yposs=[450, 450, 450], freqs=[(39.5, -210.5)], printing=False, beat='', nfft_for_morph=4096 * 4, gain=1, cells_here=["2013-01-08-aa-invivo-1"], fish_jammer='Alepto', us_name='', show=True,indexes = [[0, 1, 2, 3,4]]): runs = 1 n = 1 dev = 0.001 reshuffled = 'reshuffled' # , # standard combination with intruder small a_f2s = [0.1] min_amps = '_minamps_' dev_name = ['05'] model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv") if len(cells_here) < 1: cells_here = np.array(model_cells.cell) a_fr = 1 a = 0 trials_nrs = [5] datapoints = 1000 stimulus_length = 2 results_diff = pd.DataFrame() position_diff = 0 plot_style() default_figsize(column=2, length=7.5) default_figsize(width=cm_to_inch(21.6), length=cm_to_inch(14)) default_ticks_talks() for _ in trials_nrs: # +[trials_nrs[-1]] # sachen die ich variieren will ########################################### auci_wo = [] auci_w = [] nfft = 32768 for cell_here in cells_here: full_names = [ 'calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_25_FirstC1_0.0001_LastC1_1.0_StimLen_' + str( stimulus_length) + '_nfft_' + str(nfft) + '_trialsnr_1_absolut_power_1_minamps__dev_05temporal'] full_names = [ 'calc_model_amp_freqs-F1_750-975-75_F2_500-725-75_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_1.0_mult__start_0.0001_end_1_StimLen_25_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal'] c_grouped = ['c1'] # , 'c2'] frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv') frame_cell_orig = frame[(frame.cell == 
cell_here)] if len(frame_cell_orig) > 0: try: pass except: print('min thing') embed() get_frame_cell_params(c_grouped, cell_here, frame, frame_cell_orig) c_nrs_orig = [0.2] #0.02, 0.0002, 0.05, 0.5 trials_nr = 20 # 20 redo = False # True log = 'log' # 'log' grid0 = gridspec.GridSpec(1, 1, bottom=0.15, top=0.8, left=0.13, right=0.95, wspace=0.04) # grid00 = gridspec.GridSpecFromSubplotSpec(1, 1, wspace=0.04, hspace=0.1, subplot_spec=grid0[0]) # height_ratios=[2,1], grid_rr = gridspec.GridSpecFromSubplotSpec(1, 1,wspace=0.04, hspace=0.1,subplot_spec = grid00[0]) # height_ratios=[2,1], ################################################################# # calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_20_FirstC1_0.0001_LastC1_1.0_StimLen_25_nfft_32768_trialsnr_1_absolut_power_1temporal.csv # devs_extra = ['stim','stim_rec','stim_am','original','05']#['original','05'] # da implementiere ich das jetzt für eine Zelle # wo wir den einezlnen Punkt und Kontraste variieren f_counter = 0 frame_cell_orig, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig) eodf = frame_cell_orig.f0.unique()[0] f = -1 axts_all = [] axps_all = [] ax_us = [] for freq1, freq2 in freqs: f += 1 frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)] if len(frame_cell) < 1: freq1 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df1 - freq1)))].df1 freq2 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df2 - freq2)))].df2 frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)] print('Tuning curve needed for F1' + str(frame_cell.f1.unique()) + ' F2' + str( frame_cell.f2.unique()) + ' for cell ' + str(cell_here)) labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept( add='_mean', nr=4) #add='_mean' add = get_mean_add(frame_cell_orig) nr=4 labels = ['$A(\Delta $' + f_vary_name() + '$)$ in ' + onebeat_cond() + ' $\Delta $' + 
f_vary_name(), '$A(\Delta $' + f_vary_name() + '$)$ in ' + twobeat_cond() + ' $\Delta $' + f_vary_name() + '\,\&\,$\Delta $' + f_stable_name(), '$A(\Delta $' + f_stable_name() + '$)$ in ' + onebeat_cond() + ' $\Delta $' + f_stable_name(), '$A(\Delta $' + f_stable_name() + '$)$ in ' + twobeat_cond() + ' $\Delta $' + f_vary_name() + '\,\&\,$\Delta $' + f_stable_name() ,'amp_f0_01' + add,] labels = ['Intruder detection without female', 'Intruder detection with female', 'Receiver detection with intruder'] # 'Female detection without intruder', # 'Female detection with intruder', color01 = 'darkred' # 'darkgreen' color02 = 'darkblue' # 'darkblue' color01_012 = 'red' # 'red'#'black'##'lightgreen' # 'blue'# color02_012 = 'cyan' # 'green'#'lightblue'#'grey'# colors = ['green', 'orange', 'black']#color02, color02_012, colors_array = ['grey', 'green', 'lightgreen', 'purple']#color01 dashed = (0, (nr, nr)) linestyles = ['-', '-', '-', '-', '-'] alpha = [1, 1, 1, 1, 1, 1] linewidths = [1.6, 1.4, 1.4, 1.4]#1.6, 1.4, scores = ['amp_B1_01' + add, 'amp_f1_01' + add, 'amp_f0_01' + add,'amp_f0_01' + add ] # 'amp_B1+B2_012_mean', scores = ['amp_B1_01' + add, 'amp_B1_012' + add, 'amp_f0_01' + add] # 'amp_B2_02' + add, #'amp_B2_012' + add,'amp_B1+B2_012_mean',#($'+f_eod_name_core_rm()+'$ + $f_{p}$)($'+f_eod_name_core_rm()+'$ + $f_{p}$ + $f_{s}$)($'+f_eod_name_core_rm()+'$ + $f_{s}$ )($'+f_eod_name_core_rm()+'$ + $f_{p}$ + $f_{s}$) print(cell_here + ' F1' + str(freq1) + ' F2 ' + str(freq2)) sampling = 20000 c_dist_recalc = dist_recalc_phaselockingchapter() c_dist_recalc = True c_nrs = c_dist_recalc_func(frame_cell, c_nrs=c_nrs_orig, cell=cell_here, c_dist_recalc=c_dist_recalc) if c_dist_recalc == False: c_nrs = np.array(c_nrs) * 100 mults_period = 3 start = 200 # 1000 xlim = [start, start + (mults_period * 1000 / np.min([np.abs(freq1), np.abs(freq2)]))] letters = ['A']#, 'B' #[0, 1], [2, 3], for i, index in enumerate(indexes): try: ax_u1 = plt.subplot(grid_rr[i]) except: print('grid 
search problem') embed() try: plt_single_trace([], ax_u1, frame_cell_orig, freq1, freq2, scores=np.array(scores)[index], labels=np.array(labels)[index], colors=np.array(colors)[index], linestyles=np.array(linestyles)[index], linewidths=np.array(linewidths)[index], alpha=np.array(alpha)[index],lim_recalc = None, sum=False, B_replace='F', default_colors=False, c_dist_recalc=c_dist_recalc) except: print('something lenght') embed() ax_us.append(ax_u1) frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)] c1 = c_dist_recalc_here(c_dist_recalc, frame_cell) #ax_u1.set_xlim(0, 50) if i != 0: ax_u1.set_ylabel('') remove_yticks(ax_u1) #if i < 2: # ax_u1.fill_between(c1, frame_cell[np.array(scores)[index][0]], # frame_cell[np.array(scores)[index][1]], color='grey', alpha=0.1) #ax_u1.scatter(c_nrs, (np.array(yposs[i]) - 35) * np.ones(len(c_nrs)), color='black', marker='v', # clip_on=False) ax_us[-1].set_xlim(0, 200) ax_us[-1].set_ylim(0, 400) ax_us[-1].set_xlabel('Intruder distance [cm]') ax_us[-1].set_ylabel('Detection') axts = [] axps = [] axes = [] #if len(indexes[0]) == 3: try: reorder_legend_handles(ax_us[-1], order=[0, 2, 1],loc=(-0.14, 1.03),ncol=2, handlelength=2.5, fs = None) except:#[0, 2, 4, 1,3] print('handles not working') #ax_us[-1].legend(loc=(-0.14, 1.03), ncol=2, handlelength=2.5) # -0.07loc=(0.4,1) #axts_all[0].get_shared_y_axes().join(*axts_all) #axts_all[0].get_shared_x_axes().join(*axts_all) #axps_all[0].get_shared_y_axes().join(*axps_all) #axps_all[0].get_shared_x_axes().join(*axps_all) #join_y(axts) #set_same_ylim(axts) #set_same_ylim(axps) #join_x(axts) #join_x(ax_us) #join_y(ax_us) #fig = plt.gcf() save_visualization(cell_here, show) #fig.tag([[axes[0], axes[1], axes[2]]], xoffs=0, yoffs=3.7) #fig.tag([[axes[3], axes[4], axes[5]]], xoffs=0, yoffs=3.7) #fig.tag([ax_us[0], ax_us[1], ax_us[2]], xoffs=-2.3, yoffs=1.4) save_visualization(cell_here, show) def vary_contrasts6(yposs=[450, 450, 450], freqs=[(39.5, 
-210.5)], printing=False, beat='', nfft_for_morph=4096 * 4, gain=1, cells_here=["2013-01-08-aa-invivo-1"], fish_jammer='Alepto', us_name='', show=True): runs = 1 n = 1 dev = 0.001 reshuffled = 'reshuffled' # , # standard combination with intruder small a_f2s = [0.1] min_amps = '_minamps_' dev_name = ['05'] model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv") if len(cells_here) < 1: cells_here = np.array(model_cells.cell) a_fr = 1 a = 0 trials_nrs = [5] datapoints = 1000 stimulus_length = 2 results_diff = pd.DataFrame() position_diff = 0 plot_style() default_figsize(column=2, length=7.5) default_figsize(width=cm_to_inch(33.6), length=cm_to_inch(15.2)) default_ticks_talks() for _ in trials_nrs: # +[trials_nrs[-1]] # sachen die ich variieren will ########################################### auci_wo = [] auci_w = [] nfft = 32768 for cell_here in cells_here: full_names = [ 'calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_25_FirstC1_0.0001_LastC1_1.0_StimLen_' + str( stimulus_length) + '_nfft_' + str(nfft) + '_trialsnr_1_absolut_power_1_minamps__dev_05temporal'] c_grouped = ['c1'] # , 'c2'] frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv') frame_cell_orig = frame[(frame.cell == cell_here)] if len(frame_cell_orig) > 0: try: pass except: print('min thing') embed() get_frame_cell_params(c_grouped, cell_here, frame, frame_cell_orig) c_nrs_orig = [0.2] #0.02, 0.0002, 0.05, 0.5 trials_nr = 20 # 20 redo = False # True log = 'log' # 'log' grid0 = gridspec.GridSpec(1, 1, bottom=0.15, top=0.8, left=0.11, right=0.95, wspace=0.04) # grid00 = gridspec.GridSpecFromSubplotSpec(1, 2, wspace=0.04, hspace=0.1, width_ratios = [2.5,1], subplot_spec=grid0[0]) # height_ratios=[2,1], grid_ll = gridspec.GridSpecFromSubplotSpec(len(c_nrs_orig), 4, hspace=0.75, wspace=0.1, subplot_spec=grid00[0]) # width_ratios=[2, 1],height_ratios=[1, 1],1.2hspace=0.4,wspace=0.2,len(chirps) grid_rr = 
gridspec.GridSpecFromSubplotSpec(1, 1,wspace=0.04, hspace=0.1,subplot_spec = grid00[1]) # height_ratios=[2,1], ################################################################# # calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_20_FirstC1_0.0001_LastC1_1.0_StimLen_25_nfft_32768_trialsnr_1_absolut_power_1temporal.csv # devs_extra = ['stim','stim_rec','stim_am','original','05']#['original','05'] # da implementiere ich das jetzt für eine Zelle # wo wir den einezlnen Punkt und Kontraste variieren f_counter = 0 frame_cell_orig, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig) eodf = frame_cell_orig.f0.unique()[0] f = -1 axts_all = [] axps_all = [] ax_us = [] for freq1, freq2 in freqs: f += 1 frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)] if len(frame_cell) < 1: freq1 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df1 - freq1)))].df1 freq2 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df2 - freq2)))].df2 frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)] print('Tuning curve needed for F1' + str(frame_cell.f1.unique()) + ' F2' + str( frame_cell.f2.unique()) + ' for cell ' + str(cell_here)) labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept( add='_mean', nr=4) print(cell_here + ' F1' + str(freq1) + ' F2 ' + str(freq2)) sampling = 20000 c_dist_recalc = dist_recalc_phaselockingchapter() c_nrs = c_dist_recalc_func(frame_cell, c_nrs=c_nrs_orig, cell=cell_here, c_dist_recalc=c_dist_recalc) if c_dist_recalc == False: c_nrs = np.array(c_nrs) * 100 mults_period = 3 start = 200 # 1000 xlim = [start, start + (mults_period * 1000 / np.min([np.abs(freq1), np.abs(freq2)]))] letters = ['A']#, 'B' indexes = [[0, 1, 2, 3]]#[0, 1], [2, 3], for i, index in enumerate(indexes): try: ax_u1 = plt.subplot(grid_rr[i]) except: print('grid search problem') embed() plt_single_trace([], ax_u1, 
frame_cell_orig, freq1, freq2, scores=np.array(scores)[index], labels=np.array(labels)[index], colors=np.array(colors)[index], linestyles=np.array(linestyles)[index], linewidths=np.array(linewidths)[index], alpha=np.array(alpha)[index], sum=False, B_replace='F', default_colors=False, c_dist_recalc=c_dist_recalc) ax_us.append(ax_u1) frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)] c1 = c_dist_recalc_here(c_dist_recalc, frame_cell) ax_u1.set_xlim(0, 50) if i != 0: ax_u1.set_ylabel('') remove_yticks(ax_u1) if i < 2: ax_u1.fill_between(c1, frame_cell[np.array(scores)[index][0]], frame_cell[np.array(scores)[index][1]], color='grey', alpha=0.1) ax_u1.scatter(c_nrs, (np.array(yposs[i]) - 35) * np.ones(len(c_nrs)), color='black', marker='v', clip_on=False) for c_nn, c_nr in enumerate(c_nrs): try: ax_u1.text(c_nr, yposs[i][c_nn] + 50, letters[c_nn], color='black', ha='center', va='top') except: print('assigment thing') embed() axts = [] axps = [] axes = [] p_arrays_all = [] for c_nn, c_nr in enumerate(c_nrs): ################################# # arrays plot save_dir = load_savedir(level=0).split('/')[0] name_psd = save_dir + '_psd.npy' name_psd_f = save_dir + '_psdf.npy' if (not os.path.exists(name_psd)) | (redo == True): if log != 'log': stimulus_length_here = 0.5 nfft_here = 32768 else: stimulus_length_here = 50 trials_nr = 20 nfft_here = 2 ** 22 else: nfft_here = 2 ** 14 stimulus_length_here = 0.5 v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays, names, p_arrays_p, ff_p = calc_roc_amp_core_cocktail( [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff, a_f2s, fish_jammer, trials_nr, nfft_here, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing, stimulus_length_here, model_cells, position_diff, dev, cell_here, dev_name=dev_name, a_f1s=[c_nrs_orig[c_nn]], n=n, reshuffled=reshuffled, min_amps=min_amps) p_arrays_here = p_arrays_p[1::] xlimp = (0, 300) for p in 
range(len(p_arrays_here)): p_arrays_here[p][0] = p_arrays_here[p][0][ff_p < xlimp[1]] ff_p = ff_p[ff_p < xlimp[1]] time = np.arange(0, len(arrays[a][0]) / sampling, 1 / sampling) time = time * 1000 # plot the first array arrays_time = arrays[1::] # [v_mems[1],v_mems[3]]#[1,2]#[1::] arrays_here = arrays[1::] # [arrays[1],arrays[3]]#arrays[1::]# arrays_st = arrays_stim[1::] # [arrays_stim[1],arrays_stim[3]]# arrays_sp = arrays_spikes[1::] # [arrays_spikes[1],arrays_spikes[3]]#arrays_spikes[1::] colors_array_here = ['grey', 'grey', 'grey'] # colors_array[1::] p_arrays_all.append(p_arrays_here) for a in range(len(arrays_here)): print('a' + str(a)) if a == 0: freqs = [np.abs(freq1)] # ], np.abs(freq2)], elif a == 1: freqs = [np.abs(freq2)] else: freqs = [np.abs(freq1), np.abs(freq2)] grid_pt = gridspec.GridSpecFromSubplotSpec(5, 1, hspace=0.3, wspace=0.2, subplot_spec=grid_ll[c_nn, a], height_ratios=[1, 0.7, 1, 0.25, 2.5]) # hspace=0.4,wspace=0.2,len(chirps) axe = plt.subplot(grid_pt[0]) axes.append(axe) plt_stim_saturation(a, [], arrays_st, axe, colors_array_here, f, f_counter, names, time, xlim=xlim) # np.array(arrays_sp)*1000 a_f2_cm = c_dist_recalc_func(frame_cell, c_nrs=[a_f2s[0]], cell=cell_here, c_dist_recalc=c_dist_recalc) if c_dist_recalc == False: a_f2_cm = np.array(a_f2_cm) * 100 if a == 2: # if (a_f1s[0] != 0) & (a_f2s[0] != 0): fish = 'Three fish: $'+f_eod_name_core_rm()+'$\,\&\,' + f_vary_name() + '\,\&\,' + f_stable_name() # + '$'#' $\Delta '$\Delta$ beat_here = twobeat_cond(big=True, double=True, cond=False) + '\,' + f_vary_name( freq=int(freq1), delta=True) + ',\,$c_{1}=%s$' % ( int(np.round(c_nrs[c_nn]))) + '$\%$' + '\n' + f_stable_name( freq=int(freq2), delta=True) + ',\,$c_{2}=%s$' % ( int(np.round(a_f2_cm[0]))) + '$\%$' # +'$' title_name = fish + '\n' + beat_here # +c1+c2 elif a == 0: # elif (a_f1s[0] != 0): beat_here = ' ' + onebeat_cond(big=True, double=True, cond=False) + '\,' + f_vary_name( freq=int(freq1), delta=True) # +'$' + ' $\Delta ' fish 
= 'Two fish: $'+f_eod_name_core_rm()+'$\,\&\,' + f_vary_name() # +'$' c1 = ',\,$c_{1}=%s$' % (int(np.round(c_nrs[c_nn]))) + '$\%$ \n ' title_name = fish + '\n' + beat_here + c1 # +'cm'+'cm'+'cm' elif a == 1: # elif (a_f2s[0] != 0): beat_here = ' ' + onebeat_cond(big=True, double=True, cond=False) + '\,' + f_stable_name(freq=int(freq2), delta=True) # +'$' fish = '\n Two fish: $'+f_eod_name_core_rm()+'$\,\&\,' + f_stable_name() # +'$' c1 = ',\,$c_{2}=%s$' % (int(np.round(a_f2_cm[0]))) + '$\%$ \n' title_name = fish + '\n' + beat_here + c1 # +'cm' text = False if text: axe.text(1, 1.1, title_name, va='bottom', ha='right', transform=axe.transAxes) ############################# axs = plt.subplot(grid_pt[1]) plt_spikes_ROC(axs, 'grey', np.array(arrays_sp[a]) * 1000, xlim, lw=1) ############################# axt = plt.subplot(grid_pt[2]) axts.append(axt) plt_vmem_saturation(a, arrays_sp, arrays_time, axt, colors_array_here, f, time, xlim=xlim) axp = plt.subplot(grid_pt[-1]) axps.append(axp) if a == 0: axt.show_spines('') if c_nn == 0: axt.xscalebar(0.2, -0.1, 10, 'ms', va='right', ha='bottom') axt.yscalebar(-0.02, 0.35, 600, 'Hz', va='left', ha='top') f_counter += 1 if (not os.path.exists(name_psd)) | (redo == True): np.save(name_psd, p_arrays_all) np.save(name_psd_f, ff_p) else: ff_p = np.load(name_psd_f) # p_arrays_p p_arrays_all = np.load(name_psd) # p_arrays_p for c_nn, c_nr in enumerate(c_nrs): for a in range(len(arrays_here)): axps_here = [[axps[0], axps[1], axps[2]]]#, [axps[3], axps[4], axps[5]] axp = axps_here[c_nn][a] pp = log_calc_psd(log, p_arrays_all[c_nn][a][0], np.nanmax(p_arrays_all)) markeredgecolors = [] if a == 0: colors_peaks = [color01] # , 'red'] freqs = [np.abs(freq1)] # ], np.abs(freq2)], elif a == 1: colors_peaks = [color02] # , 'red'] freqs = [np.abs(freq2)] else: colors_peaks = [color01_012, color02_012] # , 'red'] freqs = [np.abs(freq1), np.abs(freq2)] markeredgecolors = [color01, color02] plt_psd_saturation(pp, ff_p, a, axp, colors_array_here, 
freqs=freqs, colors_peaks=colors_peaks, xlim=xlimp, markeredgecolor=markeredgecolors, ) if log: scalebar = False if scalebar: axp.show_spines('b') if a == 0: axp.yscalebar(-0.05, 0.5, 20, 'dB', va='center', ha='left') axp.set_ylim(-33, 5) else: axp.show_spines('lb') if a == 0: axp.set_ylabel('dB') # , va='center', ha='left' else: remove_yticks(axp) axp.set_ylim(-39, 5) else: axp.show_spines('lb') if a != 0: remove_yticks(axp) else: axp.set_ylabel(power_spectrum_name()) axp.set_xlabel('Frequency [Hz]') axts_all.extend(axts) axps_all.extend(axps) ax_us[-1].legend(loc=(-2.12, 1.03), ncol=2, handlelength=2.5) # -0.07loc=(0.4,1) axts_all[0].get_shared_y_axes().join(*axts_all) axts_all[0].get_shared_x_axes().join(*axts_all) axps_all[0].get_shared_y_axes().join(*axps_all) axps_all[0].get_shared_x_axes().join(*axps_all) join_y(axts) set_same_ylim(axts) set_same_ylim(axps) join_x(axts) #join_x(ax_us) #join_y(ax_us) fig = plt.gcf() #fig.tag([[axes[0], axes[1], axes[2]]], xoffs=0, yoffs=3.7) #fig.tag([[axes[3], axes[4], axes[5]]], xoffs=0, yoffs=3.7) #fig.tag([ax_us[0], ax_us[1], ax_us[2]], xoffs=-2.3, yoffs=1.4) save_visualization(cell_here, show) def vary_contrasts4(yposs=[450, 450, 450], freqs=[(39.5, -210.5)], printing=False, beat='', nfft_for_morph=4096 * 4, gain=1, cells_here=["2013-01-08-aa-invivo-1"], fish_jammer='Alepto', us_name='', show=True): runs = 1 n = 1 dev = 0.001 reshuffled = 'reshuffled' # , # standard combination with intruder small a_f2s = [0.1] min_amps = '_minamps_' dev_name = ['05'] model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv") if len(cells_here) < 1: cells_here = np.array(model_cells.cell) a_fr = 1 a = 0 trials_nrs = [5] datapoints = 1000 stimulus_length = 2 results_diff = pd.DataFrame() position_diff = 0 plot_style() default_figsize(column=2, length=7.5) for trials_nr in trials_nrs: # +[trials_nrs[-1]] # sachen die ich variieren will ########################################### auci_wo = [] auci_w = [] 
nfft = 32768 for cell_here in cells_here: full_names = [ 'calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_25_FirstC1_0.0001_LastC1_1.0_StimLen_' + str( stimulus_length) + '_nfft_' + str(nfft) + '_trialsnr_1_absolut_power_1_minamps__dev_05temporal'] c_grouped = ['c1'] # , 'c2'] frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv') frame_cell_orig = frame[(frame.cell == cell_here)] if len(frame_cell_orig) > 0: try: pass except: print('min thing') embed() get_frame_cell_params(c_grouped, cell_here, frame, frame_cell_orig) c_nrs_orig = [0.02, 0.2] # 0.0002, 0.05, 0.5 trials_nr = 20 # 20 redo = False # True log = 'log' # 'log' grid0 = gridspec.GridSpec(1, 1, bottom=0.08, top=0.93, left=0.11, right=0.95, wspace=0.04) # grid00 = gridspec.GridSpecFromSubplotSpec(1, 1, wspace=0.04, hspace=0.1, subplot_spec=grid0[0]) # height_ratios=[2,1], grid_ll = gridspec.GridSpecFromSubplotSpec(len(c_nrs_orig) + 1, 3, hspace=0.75, wspace=0.1, height_ratios=[1, 1, 0.7], subplot_spec=grid00[ 0]) # 1.2hspace=0.4,wspace=0.2,len(chirps) ################################################################# # calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_20_FirstC1_0.0001_LastC1_1.0_StimLen_25_nfft_32768_trialsnr_1_absolut_power_1temporal.csv # devs_extra = ['stim','stim_rec','stim_am','original','05']#['original','05'] # da implementiere ich das jetzt für eine Zelle # wo wir den einezlnen Punkt und Kontraste variieren f_counter = 0 frame_cell_orig, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig) eodf = frame_cell_orig.f0.unique()[0] f = -1 axts_all = [] axps_all = [] ax_us = [] for freq1, freq2 in freqs: f += 1 frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)] if len(frame_cell) < 1: freq1 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df1 - freq1)))].df1 freq2 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df2 - freq2)))].df2 frame_cell = 
frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)] print('Tuning curve needed for F1' + str(frame_cell.f1.unique()) + ' F2' + str( frame_cell.f2.unique()) + ' for cell ' + str(cell_here)) labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept( add='_mean', nr=4) print(cell_here + ' F1' + str(freq1) + ' F2 ' + str(freq2)) sampling = 20000 c_dist_recalc = dist_recalc_phaselockingchapter() c_nrs = c_dist_recalc_func(frame_cell, c_nrs=c_nrs_orig, cell=cell_here, c_dist_recalc=c_dist_recalc) if c_dist_recalc == False: c_nrs = np.array(c_nrs) * 100 mults_period = 3 start = 200 # 1000 xlim = [start, start + (mults_period * 1000 / np.min([np.abs(freq1), np.abs(freq2)]))] letters = ['A', 'B'] indexes = [[0, 1], [2, 3], [0, 1, 2, 3]] for i, index in enumerate(indexes): try: ax_u1 = plt.subplot(grid_ll[-1, i]) except: print('grid search problem4') embed() plt_single_trace([], ax_u1, frame_cell_orig, freq1, freq2, scores=np.array(scores)[index], labels=np.array(labels)[index], colors=np.array(colors)[index], linestyles=np.array(linestyles)[index], linewidths=np.array(linewidths)[index], alpha=np.array(alpha)[index], sum=False, B_replace='F', default_colors=False, c_dist_recalc=c_dist_recalc) ax_us.append(ax_u1) frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)] c1 = c_dist_recalc_here(c_dist_recalc, frame_cell) ax_u1.set_xlim(0, 50) if i != 0: ax_u1.set_ylabel('') remove_yticks(ax_u1) if i < 2: ax_u1.fill_between(c1, frame_cell[np.array(scores)[index][0]], frame_cell[np.array(scores)[index][1]], color='grey', alpha=0.1) ax_u1.scatter(c_nrs, (np.array(yposs[i]) - 35) * np.ones(len(c_nrs)), color='black', marker='v', clip_on=False) # for c_nn, c_nr in enumerate(c_nrs): ax_u1.text(c_nr, yposs[i][c_nn] + 50, letters[c_nn], color='black', ha='center', va='top') # ax_u1.plot([c_nr, c_nr], [0, 435], color='black', linewidth=0.8, 
clip_on=False) # embed() # plt.show() # embed() # if f != 0: # # remove_yticks(ax_u0) # # remove_yticks(ax_u1) # print('hi') # else: # ax_u1.set_ylabel(power_spectrum_name()) # embed() axts = [] axps = [] axes = [] p_arrays_all = [] for c_nn, c_nr in enumerate(c_nrs): ################################# # arrays plot save_dir = load_savedir(level=0).split('/')[0] name_psd = save_dir + '_psd.npy' name_psd_f = save_dir + '_psdf.npy' if (not os.path.exists(name_psd)) | (redo == True): if log != 'log': stimulus_length_here = 0.5 nfft_here = 32768 else: stimulus_length_here = 50 trials_nr = 20 nfft_here = 2 ** 22 else: nfft_here = 2 ** 14 stimulus_length_here = 0.5 v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays, names, p_arrays_p, ff_p = calc_roc_amp_core_cocktail( [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff, a_f2s, fish_jammer, trials_nr, nfft_here, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing, stimulus_length_here, model_cells, position_diff, dev, cell_here, dev_name=dev_name, a_f1s=[c_nrs_orig[c_nn]], n=n, reshuffled=reshuffled, min_amps=min_amps) p_arrays_here = p_arrays_p[1::] xlimp = (0, 300) for p in range(len(p_arrays_here)): p_arrays_here[p][0] = p_arrays_here[p][0][ff_p < xlimp[1]] ff_p = ff_p[ff_p < xlimp[1]] time = np.arange(0, len(arrays[a][0]) / sampling, 1 / sampling) time = time * 1000 # plot the first array arrays_time = arrays[1::] # [v_mems[1],v_mems[3]]#[1,2]#[1::] arrays_here = arrays[1::] # [arrays[1],arrays[3]]#arrays[1::]# arrays_st = arrays_stim[1::] # [arrays_stim[1],arrays_stim[3]]# arrays_sp = arrays_spikes[1::] # [arrays_spikes[1],arrays_spikes[3]]#arrays_spikes[1::] colors_array_here = ['grey', 'grey', 'grey'] # colors_array[1::] p_arrays_all.append(p_arrays_here) for a in range(len(arrays_here)): print('a' + str(a)) if a == 0: freqs = [np.abs(freq1)] # ], np.abs(freq2)], elif a == 1: freqs = [np.abs(freq2)] else: freqs = [np.abs(freq1), np.abs(freq2)] 
grid_pt = gridspec.GridSpecFromSubplotSpec(5, 1, hspace=0.3, wspace=0.2, subplot_spec=grid_ll[c_nn, a], height_ratios=[1, 0.7, 1, 0.25, 2.5]) # hspace=0.4,wspace=0.2,len(chirps) axe = plt.subplot(grid_pt[0]) axes.append(axe) plt_stim_saturation(a, [], arrays_st, axe, colors_array_here, f, f_counter, names, time, xlim=xlim) # np.array(arrays_sp)*1000 a_f2_cm = c_dist_recalc_func(frame_cell, c_nrs=[a_f2s[0]], cell=cell_here, c_dist_recalc=c_dist_recalc) if c_dist_recalc == False: a_f2_cm = np.array(a_f2_cm) * 100 if a == 2: # if (a_f1s[0] != 0) & (a_f2s[0] != 0): fish = 'Three fish: $'+f_eod_name_core_rm()+'$\,\&\,' + f_vary_name() + '\,\&\,' + f_stable_name() # + '$'#' $\Delta '$\Delta$ beat_here = twobeat_cond(big=True, double=True, cond=False) + '\,' + f_vary_name( freq=int(freq1), delta=True) + ',\,$c_{1}=%s$' % ( int(np.round(c_nrs[c_nn]))) + '$\%$' + '\n' + f_stable_name( freq=int(freq2), delta=True) + ',\,$c_{2}=%s$' % ( int(np.round(a_f2_cm[0]))) + '$\%$' # +'$' title_name = fish + '\n' + beat_here # +c1+c2 elif a == 0: # elif (a_f1s[0] != 0): beat_here = ' ' + onebeat_cond(big=True, double=True, cond=False) + '\,' + f_vary_name( freq=int(freq1), delta=True) # +'$' + ' $\Delta ' fish = 'Two fish: $'+f_eod_name_core_rm()+'$\,\&\,' + f_vary_name() # +'$' c1 = ',\,$c_{1}=%s$' % (int(np.round(c_nrs[c_nn]))) + '$\%$ \n ' title_name = fish + '\n' + beat_here + c1 # +'cm'+'cm'+'cm' elif a == 1: # elif (a_f2s[0] != 0): beat_here = ' ' + onebeat_cond(big=True, double=True, cond=False) + '\,' + f_stable_name(freq=int(freq2), delta=True) # +'$' fish = '\n Two fish: $'+f_eod_name_core_rm()+'$\,\&\,' + f_stable_name() # +'$' c1 = ',\,$c_{2}=%s$' % (int(np.round(a_f2_cm[0]))) + '$\%$ \n' title_name = fish + '\n' + beat_here + c1 # +'cm' axe.text(1, 1.1, title_name, va='bottom', ha='right', transform=axe.transAxes) ############################# axs = plt.subplot(grid_pt[1]) plt_spikes_ROC(axs, 'grey', np.array(arrays_sp[a]) * 1000, xlim, lw=1) ############################# 
axt = plt.subplot(grid_pt[2]) axts.append(axt) plt_vmem_saturation(a, arrays_sp, arrays_time, axt, colors_array_here, f, time, xlim=xlim) axp = plt.subplot(grid_pt[-1]) axps.append(axp) if a == 0: axt.show_spines('') axt.xscalebar(0.1, -0.1, 10, 'ms', va='right', ha='bottom') axt.yscalebar(-0.02, 0.35, 600, 'Hz', va='left', ha='top') f_counter += 1 if (not os.path.exists(name_psd)) | (redo == True): np.save(name_psd, p_arrays_all) np.save(name_psd_f, ff_p) else: ff_p = np.load(name_psd_f) # p_arrays_p p_arrays_all = np.load(name_psd) # p_arrays_p for c_nn, c_nr in enumerate(c_nrs): for a in range(len(arrays_here)): axps_here = [[axps[0], axps[1], axps[2]], [axps[3], axps[4], axps[5]]] axp = axps_here[c_nn][a] pp = log_calc_psd(log, p_arrays_all[c_nn][a][0], np.nanmax(p_arrays_all)) markeredgecolors = [] if a == 0: colors_peaks = [color01] # , 'red'] freqs = [np.abs(freq1)] # ], np.abs(freq2)], elif a == 1: colors_peaks = [color02] # , 'red'] freqs = [np.abs(freq2)] else: colors_peaks = [color01_012, color02_012] # , 'red'] freqs = [np.abs(freq1), np.abs(freq2)] markeredgecolors = [color01, color02] plt_psd_saturation(pp, ff_p, a, axp, colors_array_here, freqs=freqs, colors_peaks=colors_peaks, xlim=xlimp, markeredgecolor=markeredgecolors, ) if log: scalebar = False if scalebar: axp.show_spines('b') if a == 0: axp.yscalebar(-0.05, 0.5, 20, 'dB', va='center', ha='left') axp.set_ylim(-33, 5) else: axp.show_spines('lb') if a == 0: axp.set_ylabel('dB') # , va='center', ha='left' else: remove_yticks(axp) axp.set_ylim(-39, 5) else: axp.show_spines('lb') if a != 0: remove_yticks(axp) else: axp.set_ylabel(power_spectrum_name()) axp.set_xlabel('Frequency [Hz]') axts_all.extend(axts) axps_all.extend(axps) ax_us[-1].legend(loc=(-2.22, 1.2), ncol=2, handlelength=2.5) # -0.07loc=(0.4,1) axts_all[0].get_shared_y_axes().join(*axts_all) axts_all[0].get_shared_x_axes().join(*axts_all) axps_all[0].get_shared_y_axes().join(*axps_all) axps_all[0].get_shared_x_axes().join(*axps_all) 
join_y(axts) set_same_ylim(axts) set_same_ylim(axps) join_x(axts) join_x(ax_us) join_y(ax_us) fig = plt.gcf() fig.tag([[axes[0], axes[1], axes[2]]], xoffs=0, yoffs=3.7) fig.tag([[axes[3], axes[4], axes[5]]], xoffs=0, yoffs=3.7) fig.tag([ax_us[0], ax_us[1], ax_us[2]], xoffs=-2.3, yoffs=1.4) save_visualization(cell_here, show) def twobeat_cond(big=False, double=False, cond=True): if cond: if not big: val = 'two-beat condition' else: val = 'Two-beat condition' if double: val += ':' else: if not big: val = 'two beats' else: val = 'Two beats' if double: val += ':' return val def colors_susept(add='_mean', nr=4): scores = ['amp_B1_01' + add, 'amp_B1_012' + add, 'amp_B2_02' + add, 'amp_B2_012' + add] # 'amp_B1+B2_012_mean',#($'+f_eod_name_core_rm()+'$ + $f_{p}$)($'+f_eod_name_core_rm()+'$ + $f_{p}$ + $f_{s}$)($'+f_eod_name_core_rm()+'$ + $f_{s}$ )($'+f_eod_name_core_rm()+'$ + $f_{p}$ + $f_{s}$) lables = ['$A(\Delta $' + f_vary_name() + '$)$ in ' + onebeat_cond() + ' $\Delta $' + f_vary_name(), '$A(\Delta $' + f_vary_name() + '$)$ in ' + twobeat_cond() + ' $\Delta $' + f_vary_name() + '\,\&\,$\Delta $' + f_stable_name(), '$A(\Delta $' + f_stable_name() + '$)$ in ' + onebeat_cond() + ' $\Delta $' + f_stable_name(), '$A(\Delta $' + f_stable_name() + '$)$ in ' + twobeat_cond() + ' $\Delta $' + f_vary_name() + '\,\&\,$\Delta $' + f_stable_name() ] color01 = 'darkred' # 'darkgreen' color02 = 'darkblue' # 'darkblue' color01_012 = 'red' # 'red'#'black'##'lightgreen' # 'blue'# color02_012 = 'cyan' # 'green'#'lightblue'#'grey'# colors = [color01, color01_012, color02, color02_012, 'grey'] colors_array = ['grey', color01, color02, 'purple'] dashed = (0, (nr, nr)) linestyles = ['-', dashed, '-', dashed, dashed] alpha = [1, 1, 1, 1, 1] linewidth = [1.6, 1.4, 1.6, 1.4, 1.4] return lables, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidth def square_part(ax, shrink=1, what=[], end='.pkl', folder='calc_model', 
full_name='modell_all_cell_no_sinz1_afe1_0.03__afr0_1__afj2_0.1__phaseright__len5_adaptoffset_bisecting_0.995_1.005____ratecorrrisidual35__modelbigfit_nfft4096_StartE1_1_EndE1_1.3_in0.005_StartJ2_1_EndJ2_1.3_in0.005_trialnr20__reshuffled_ThreeDiff_SameOffset'): score = 'auci02_012-auci_base_01' # 'previous_auci02_012-auci_base_01'#['auci_02_012', 'auci_base_01', 'previous_auci02_012-auci_base_01', ] cell_orig = '2013-01-08-aa-invivo-1' # '2012-12-13-an-invivo-1' dev = '_05' # ,'_2','_original','_stim','_isi' mult = '_abs1000' # ,'_mult3' counter = 0 versions = {} if ('auci' not in score) and ('auc' not in score): mult_new = '' else: mult_new = mult if len(what) < 1: what = score + mult_new + dev mat, vers_here, cell, eod_m, fr_rate_mult = define_squares_model_three(what=what, square=[], full_name=full_name, minimum=0, folder=folder, maximum=3, end=end, cell_data=cell_orig, emb=False) lim = find_lims(what, vers_here) versions[what] = vers_here ax.set_aspect('equal') try: power = np.unique(mat['power'])[0] except: print('power thing') embed() plt.suptitle(str(cell_orig) + ' power ' + str(power) + ' dev ' + str(dev)) mult_type = '' pcolor = True im = plt_square(mat, pcolor, mult_type, vers_here, lim) square_labels(mult_type, ax, vers_here, 0) extra_labels = True if extra_labels: ax.set_xlabel('$\Delta \mathrm{f_{Intruder}}$ [Hz]') ax.set_ylabel('$\Delta \mathrm{f_{Female}}$ [Hz]') _, _, _, _, _ = colorbar_outside(ax, im, add=5, delta=0.25, round_digit=2, width=0.01, shrink=shrink) ax.text(1.3 , 0.5, core_scatter_wfemale(), va='center', ha='center', rotation=90 , transform=ax.transAxes) # va = 'center',270 im.set_clim(-0.5, 0.5) counter += 1 def find_lims(what, vers_here): if 'auci' in what: lim = [-0.5, 0.5] elif 'auc' in what: lim = [-1, 1] else: vmax = np.nanpercentile(np.abs(vers_here), 95) vmin = np.nanpercentile(np.abs(vers_here), 5) lims = np.max([vmax, np.abs(vmin)]) lim = [-lims, lims] return lim def cut_matrix_generation(condition, minimum, maximum): 
index_chosen = condition.index[(condition.index > minimum) & (condition.index < maximum)] column_chosen = condition.columns[(condition.columns > minimum) & (condition.columns < maximum)] condition = condition.loc[index_chosen, column_chosen] return condition, column_chosen, index_chosen def define_squares_model2(a_fe, nr, a_fj, cell_nr, what, step, cell=[], a_fr=1, adapt='adaptoffsetallall2', variant='no', self='', symetric='', resize=True, minimum=0.5, maximum=1.5, dist_type='SimpleDist', redo=False, beat_type='', version_sinz='sinz', varied='emitter', full_name='', emb=False): if full_name == '': name = load_folder_name('calc_model') + '/modell_all_cell_' + variant + '_' + version_sinz + str( nr) + self + '_afe' + str(a_fe) + '__afr' + str(a_fr) + '__afj' + str( str(a_fj)) + '__length1.5_' + adapt + '___stepefish' + str( step) + 'Hz_ratecorrrisidual35__modelbigfit_nfft4096' + beat_type + '.pkl' else: name = load_folder_name('calc_model') + '/' + full_name + '.pkl' if os.path.exists(name): ############################ # Simples GLOBAL scores, like std, amp etc, without temporal inforrmation what_orig = what if 'spike_times' in what: what = 'spike_times' control, condition, cell, eod_f, fr_rate_mult = get_condition(a_fe, nr, a_fj, cell_nr, what, a_fr=a_fr, variant=variant, adapt=adapt, full_name=full_name, version_sinz=version_sinz, resize=resize, symetric=symetric, minimum=minimum, maximum=maximum, beat_type=beat_type, self=self, step=step, cell=cell) base, base_matrix, baseline = load_baseline_matrix(what, cell, condition, a_fr=a_fr) control_afj, DF_e, dict_here, eod_m = get_control(nr, cell_nr, what, 'afj', a_fr=a_fr, adapt=adapt, varied=varied , symetric=symetric, duration=duration, contrast1=a_fe, beat_type=beat_type, contrast2='0', version_sinz=version_sinz, step=step, cell=cell, variant=variant, minimum=minimum, maximum=maximum, self=self) control_afe, DF_e, dict_here, eod_m = get_control(nr, cell_nr, what, 'afe', a_fr=a_fr, adapt=adapt, varied=varied, 
contrast1='0', duration=duration, symetric=symetric, beat_type=beat_type, contrast2=a_fj, version_sinz=version_sinz, step=step, cell=cell, variant=variant, minimum=minimum, maximum=maximum, self=self) # not found if 'spike_times' in what_orig: ############################# # temporal information if maximum != []: max_name = '_min' + str(minimum) + '_min' + str(maximum) else: max_name = '' name_diff = load_folder_name('calc_model') + '/diffsquare_' + variant + '_' + version_sinz + str( nr) + self + '_afe' + str(a_fe) + '__afr' + str(a_fr) + '__afj' + str( str(a_fr)) + '__length1.5_' + adapt + '___stepefish' + str( step) + 'Hz_ratecorrrisidual35__modelbigfit_nfft4096' + '_' + dist_type + beat_type + max_name 'diffsquare_no_sinz3_afe0.1__afr1__afj1__length1.5_adaptoffsetallall2___stepefish10Hz_ratecorrrisidual35__modelbigfit_nfft4096_SimpleDist_beat_min0.5_min1' if (os.path.exists(name_diff + '.pkl')) and (redo == False): diff_loaded = pd.read_pickle(name_diff + '.pkl') if cell in np.unique(diff_loaded['dataset']): diff_load = diff_loaded[diff_loaded['dataset'] == cell] diff_load.pop('dataset') if '05' in what_orig: dev = '05' elif '2' in what_orig: dev = '2' else: dev = 'original' diff_load = diff_load[diff_load['dev'] == dev] diff_load.pop('dev') versions = {} sorted = retrieve_mat(diff_load, '0-1-2') versions['diff'] = sorted sorted = retrieve_mat(diff_load, '0-1') versions['0-1'] = sorted sorted = retrieve_mat(diff_load, '0-2') versions['0-2'] = sorted diff_load.pop('dist') cont = False print('cell already there') else: cont = True else: diff_loaded = pd.DataFrame() cont = True if cont: print('load diff ' + cell) versions = {} # get parameters control, nfft, cell, eod_f, fr_rate_mult = get_condition(a_fe, nr, a_fj, cell_nr, 'nfft', a_fr=a_fr, beat_type=beat_type, variant=variant, adapt=adapt, version_sinz=version_sinz, self=self, step=step, cell=cell) control, sampling_rate, cell, eod_f, fr_rate_mult = get_condition(a_fe, nr, a_fj, cell_nr, 'sampling_rate', 
beat_type=beat_type, a_fr=a_fr, variant=variant, adapt=adapt, version_sinz=version_sinz, self=self, step=step, cell=cell) control, length, cell, eod_f, fr_rate_mult = get_condition(a_fe, nr, a_fj, cell_nr, 'length', a_fr=a_fr, beat_type=beat_type, variant=variant, adapt=adapt, version_sinz=version_sinz, self=self, step=step, cell=cell) control, cut_spikes, cell, eod_f, fr_rate_mult = get_condition(a_fe, nr, a_fj, cell_nr, 'cut_spikes', beat_type=beat_type, a_fr=a_fr, variant=variant, adapt=adapt, version_sinz=version_sinz, self=self, step=step, cell=cell) diff = pd.DataFrame() diff05 = pd.DataFrame() diff2 = pd.DataFrame() diff_1 = pd.DataFrame() diff05_1 = pd.DataFrame() diff2_1 = pd.DataFrame() diff_2 = pd.DataFrame() diff05_2 = pd.DataFrame() diff2_2 = pd.DataFrame() min_array = 0.5 max_array = 1 for i in range(len(condition)): for j in range(len(condition.iloc[0])): print(j) print(i) try: arrays = [condition.iloc[i, j][0], control_afe.iloc[i, j][0], control_afj.iloc[i, j][0], base_matrix.iloc[i, j][0]] except: embed() sampling_rate_here = sampling_rate.iloc[i, j] spikes_mat = {} mat05 = {} mat2 = {} names = ['condition', 'control_afe', 'control_afj', 'base_matrix'] for a, array in enumerate(arrays): spikes_cut, spikes_mat[names[a]], mat05[names[a]], mat2[names[a]] = create_spikes_mat( max_array - min_array, array[(array < max_array) & (array > min_array)] - min_array, sampling_rate_here) # length_here- cut_spikes_here*2 if 'SimpleDist' in dist_type: diff05.loc[condition.index[i], condition.columns[j]] = np.nanmean( mat05['condition'] - mat05['control_afe'] - mat05['control_afj'] + mat05['base_matrix']) diff2.loc[condition.index[i], condition.columns[j]] = np.nanmean( mat2['condition'] - mat2['control_afe'] - mat2['control_afj'] + mat2['base_matrix']) diff.loc[condition.index[i], condition.columns[j]] = np.nanmean( spikes_mat['condition'] - spikes_mat['control_afe'] - spikes_mat['control_afj'] + spikes_mat['base_matrix']) diff05_1.loc[condition.index[i], 
condition.columns[j]] = np.nanmean( mat05['condition'] - mat05['control_afe']) diff2_1.loc[condition.index[i], condition.columns[j]] = np.nanmean( mat2['condition'] - mat2['control_afe']) diff_1.loc[condition.index[i], condition.columns[j]] = np.nanmean( spikes_mat['condition'] - spikes_mat['control_afe']) diff05_2.loc[condition.index[i], condition.columns[j]] = np.nanmean( mat05['condition'] - mat05['control_afj']) diff2_2.loc[condition.index[i], condition.columns[j]] = np.nanmean( mat2['condition'] - mat2['control_afj']) diff_2.loc[condition.index[i], condition.columns[j]] = np.nanmean( spikes_mat['condition'] - spikes_mat['control_afj']) # todo: here noch ein paar andere Differenzen machen elif 'ConspDist' in dist_type: length_consp = 0.030 * sampling_rate_here shift = 0.005 shift_conditions = np.arange(0, len(mat2['control_afe']), shift * sampling_rate_here) shift_controls = np.arange(0, len(mat2['control_afe']), shift * sampling_rate_here) consp = pd.DataFrame() for s, shift_condition in enumerate(shift_conditions): for ss, shift_control in enumerate(shift_controls): if (int(length_consp + shift_control) < len(mat2['control_afe'])) & ( int(length_consp + shift_condition) < len(mat2['condition'])): consp.loc[s, ss] = np.sqrt(np.mean((mat2['control_afe'][ 0 + int(shift_control):int( length_consp + shift_control)] - mat2['condition'][ 0 + int(shift_condition):int( length_consp + shift_condition)]) ** 2)) embed() diff['dataset'] = cell diff05['dataset'] = cell diff2['dataset'] = cell diff_1['dataset'] = cell diff05_1['dataset'] = cell diff2_1['dataset'] = cell diff_2['dataset'] = cell diff05_2['dataset'] = cell diff2_2['dataset'] = cell diff['dist'] = '0-1-2' diff05['dist'] = '0-1-2' diff2['dist'] = '0-1-2' diff_1['dist'] = '0-1' diff05_1['dist'] = '0-1' diff2_1['dist'] = '0-1' diff_2['dist'] = '0-2' diff05_2['dist'] = '0-2' diff2_2['dist'] = '0-2' diff['dev'] = 'original' diff05['dev'] = '05' diff2['dev'] = '2' diff_1['dev'] = 'original' diff05_1['dev'] = '05' 
diff2_1['dev'] = '2' diff_2['dev'] = 'original' diff05_2['dev'] = '05' diff2_2['dev'] = '2' if len(diff_loaded) < 1: vertical_stack = pd.concat( [diff, diff05, diff2, diff_1, diff05_1, diff2_1, diff_2, diff05_2, diff2_2, ], axis=0) vertical_stack.to_pickle(name_diff + '.pkl') else: vertical_stack = pd.concat( [diff_loaded, diff, diff05, diff2, diff_1, diff05_1, diff2_1, diff_2, diff05_2, diff2_2, ], axis=0) vertical_stack.to_pickle(name_diff + '.pkl') if '05' in what_orig: dev = '05' elif '2' in what_orig: dev = '2' else: dev = 'original' dev_here = vertical_stack[vertical_stack['dev'] == dev] diff_output = dev_here[dev_here['dist'] == '0-1-2'] diff_output.pop('dist') diff_output.pop('dev') diff_output.pop('dataset') versions['diff'] = diff_output diff_output = dev_here[dev_here['dist'] == '0-1'] diff_output.pop('dist') diff_output.pop('dev') diff_output.pop('dataset') versions['0-1'] = diff_output diff_output = dev_here[dev_here['dist'] == '0-2'] diff_output.pop('dist') diff_output.pop('dev') diff_output.pop('dataset') versions['0-2'] = diff_output versions['eod'] = eod_m else: print('load diff ' + cell) else: diff_output = condition - control_afe - control_afj + base_matrix versions = {} versions['base'] = base_matrix versions['control1'] = control_afe versions['control2'] = control_afj versions['12'] = condition versions['diff'] = diff_output versions['0-1'] = condition - control_afe versions['0-2'] = condition - control_afj versions['eod'] = eod_m else: versions = [] cell = '' eod_m = '' if emb: embed() return versions, cell, eod_m def get_condition(contrast1, nr, contrast2, cell_nr, what, step=60, a_fr=1, adapt='adaptoffsetallall2', variant='no', version_sinz='sinz', full_name='', resize=True, symetric='', SAM='SAM', square=[], three='', length='1.5', duration='', folder='model', minimum=[], maximum=[], f0='f0', f2='f2', f1='f1', self='', beat_type='', cell=[], emb=False, end='.pkl'): # f0 = 'eodf' f2 = 'eodj', f1 = 'eode' if 'csv' in end: # das ist falls wir 
ein csv haben wie das simplified Threewave protokoll if full_name == '': pass else: pass control = pd.read_csv( '../data/' + folder + '/' + full_name + end, index_col=0) cell_array = control[control.cell == cell] df2 = np.round(cell_array.df2, 2) df1 = np.round(cell_array.df1, 2) cell_array.df2 = df2 cell_array.df1 = df1 # np.unique(cell_array.df2)np.unique(cell_array.df1) condition = cell_array.pivot(index='df2', columns='df1', values=what) # index=['eode', 'nnft'] this will create multiindexing eod_f = cell_array['f0'] if resize == True: fr = cell_array.fr fr_rate_mult = fr / np.mean(cell_array['f0']) else: # das ist für die pkls also vor allem für das nicht simplified protokoll, was am meisten verwendet wurde if full_name == '': control = pd.read_pickle( load_folder_name('calc_model') + +'/modell_all_cell_' + variant + '_' + version_sinz + str( nr) + self + '_afe' + str(contrast1) + '__afr' + str(a_fr) + '__afj' + str( str(contrast2)) + '__length' + str(length) + '_' + adapt + '_' + SAM + '__stepefish' + str( step) + 'Hz_ratecorrrisidual35__modelbigfit_nfft4096' + duration + beat_type + symetric + three + end) else: control = pd.read_pickle( load_folder_name('calc_model') + '/' + full_name + end) # '../data/'+#'../data/'+ test = False if test == True: from utils_test import test_controls test_controls() if square != []: sqaure_array = control[control['square'] == square] else: sqaure_array = control if not cell: if cell_nr < len(np.unique(sqaure_array['dataset'])): cell = np.unique(sqaure_array['dataset'])[cell_nr] else: cell = [] if cell != []: cell_array = sqaure_array[sqaure_array['dataset'] == cell] if square != []: base = control[control['square'] == 'base_0'] base_cell = base[base['dataset'] == cell] fr_rate = np.unique(base_cell.mean_fr) fr_rate_mult = fr_rate else: fr_rate_mult = np.mean(cell_array.rate_baseline_after.iloc[ 1::]) # todo: ok als das stimmt schon das after, ab dem zweiten trial das die ursprüngliche baseline print(fr_rate_mult) if (what in 
cell_array.keys()) and not cell_array.empty: condition = cell_array.pivot(index=f2, columns=f1, values=what) # index=['eode', 'nnft'] this will create multiindexing if square == 'base_0': sqaure_array_012 = control[control['square'] == '012'] cell_array = sqaure_array_012[sqaure_array_012['dataset'] == cell] condition_012 = cell_array.pivot(index=f2, columns=f1, values=what) condition_012[:] = condition.iloc[0, 0] condition = condition_012 DF_1 = np.unique( np.array((cell_array[f1] - cell_array[f0]) / cell_array[ f0] + 1)) DF_2 = np.round(np.unique(np.array( (cell_array[f2] - cell_array[f0]) / cell_array[ f0] + 1)), 3) eod_f = cell_array[f1] if resize == True: fr_rate_mult = fr_rate_mult / np.mean(cell_array[f0]) dict_here = dict(zip(np.unique(cell_array[f1]), np.round(DF_1, 3))) condition = condition.rename(columns=dict_here) condition = condition.set_index(DF_2) condition.columns.name = f1 + str('-f0') # 'fish2-fish0 $f_{stim}/'+f_eod_name_core_rm()+'$' # 'DeltaF-eodj-eodf' condition.index.name = f2 + str('-f0') # 'fish1-fish0 $f_{stim}/'+f_eod_name_core_rm()+'$' # 'DeltaF-eode-eodf' if maximum != []: condition, column_chosen, index_chosen = cut_matrix_generation(condition, minimum, maximum) else: condition = [] cell = [] eod_f = [] fr_rate_mult = [] else: condition = [] cell = [] eod_f = [] fr_rate_mult = [] if emb: embed() return control, condition, cell, eod_f, fr_rate_mult def define_squares_model_three(emb=False, a_fe=0.1, nr=3, a_fj=0.1, cell_nr=0, what='std', step=50, cell_data=[], a_fr=1, adapt='adaptoffsetallall2', variant='no', square=[], full_name='', self='', length=0.5, SAM='SAM', resize=True, cell=[], duration='', symmetric='', three='ThreeDiff', minimum=[], maximum=[], beat_type='', folder='model', end='.pkl', version_sinz='sinz'): if full_name == '': name = load_folder_name('calc_model') + '/modell_all_cell_' + variant + '_' + version_sinz + str( nr) + self + '_afe' + str(a_fe) + '__afr' + str(a_fr) + '__afj' + str( str(a_fj)) + '__length' + 
str(length) + '_' + adapt + '_' + SAM + '__stepefish' + str( step) + 'Hz_ratecorrrisidual35__modelbigfit_nfft4096' + duration + symmetric + three + end else: name = load_folder_name('calc_model') + '/' + full_name + end # '../data/''../data/'+ condition = [] eod_f = [] fr_rate_mult = [] control = [] print(name) if os.path.exists(name): ############################ # Simples GLOBAL scores, like std, amp etc, without temporal inforrmation control, condition, cell, eod_f, fr_rate_mult = get_condition(a_fe, nr, a_fj, cell_nr, what, a_fr=a_fr, variant=variant, adapt=adapt, full_name=full_name, version_sinz=version_sinz, SAM=SAM, symetric=symmetric, folder=folder, resize=resize, end=end, square=square, duration=duration, length=length, three=three, minimum=minimum, maximum=maximum, beat_type=beat_type, self=self, step=step, cell=cell_data, emb=False) if len(condition) > 0: if condition.iloc[0].dtype == complex: condition = np.abs(condition) if emb: embed() return control, condition, cell, eod_f, fr_rate_mult def plt_square(control, pcolor, mult_type, vers, lim): if pcolor: if 'mult' in mult_type: axs = plt.pcolormesh( np.array(list(map(float, vers.columns))), np.array(vers.index), vers, vmin=lim[0], vmax=lim[1], cmap="RdBu_r") else: axs = plt.pcolormesh( (np.array(list(map(float, vers.columns))) - 1) * control.f0.iloc[0], (np.array(vers.index) - 1) * control.f0.iloc[0], vers, vmin=lim[0], vmax=lim[1], cmap="RdBu_r") else: try: axs = plt.imshow(vers, origin='lower', cmap="RdBu_r", vmin=lim[0], vmax=lim[1], extent=[np.min(vers.columns), np.max(vers.columns), np.min(vers.index), np.max(vers.index)]) except: print('axs problenm') embed() return axs def square_labels(mult_type, ax, vers_here, w): if 'mult' in mult_type: if '2' in vers_here.index.names: ax.set_xlabel('m2') else: ax.set_xlabel('m1') else: if '2' in vers_here.index.names: ax.set_xlabel('Beat2 [Hz]') if w == 0: ax.set_ylabel('Beat1 [Hz]') else: ax.set_xlabel('Beat1 [Hz]') if w == 0: ax.set_ylabel('Beat2 [Hz]') 
def figsize_ROC_start():
    """Return the ``[width, height]`` figure size of the ROC start figure.

    The width comes from the project helper ``column2()``. The height is 2.8
    (presumably inches, as for ``plt.figure(figsize=...)``).
    """
    return [column2(), 2.8]


def plt_several_ROC_square_nonlin_single(shrink=0.5, top=0.9, loc=(0.4, 0.8), fs=14, defaultst=True,
                                         figsize=(13.5, 6.5), ):
    """Plot AUC and nonlinearity panels for two configurations next to a model square.

    Columns 0 and 1: for each frame in ``frame_names_both``, an AUC panel
    (``plt_gain_area``) above a nonlinearity panel (``plt_nonlin_line``).
    Column 3: the model square (``square_part``) with a black frame and
    numbered markers at the collected (df1, df2) positions. The figure is
    saved and shown.

    Parameters
    ----------
    shrink : float
        Colour-bar shrink factor for ``square_part``.
    top : float
        Top margin of the outer grid.
    loc, fs :
        Legend location and font size for the AUC panels.
    defaultst : bool
        Apply ``default_settings`` first.
    figsize : tuple or falsy
        Size of a new figure. If falsy, the current figure is used.
    """
    xlim = core_xlim_dist_roc()
    if defaultst:
        default_settings(width=12, ts=20, ls=20, fs=20)
    # `fig` used to be bound only when figsize was truthy, so fig.tag() below
    # raised NameError otherwise. Fall back to the current figure.
    fig = plt.figure(figsize=figsize) if figsize else plt.gcf()
    colors_w, colors_wo, _, _, _, _ = colors_cocktailparty_all()

    short = True
    frame_names_both = [
        [core_decline_ROC(trial_nr='20', short=short, absolut=True)[0][0]],
        [core_decline_ROC(trial_nr='20', pos=False, short=short, absolut=True)[0][0]],
    ]
    print(frame_names_both)
    cells = ["2013-01-08-aa-invivo-1"]
    cells_chosen = ['2013-01-08-aa-invivo-1']

    grid = gridspec.GridSpec(2, 4, wspace=0.45, width_ratios=[0.27, 0.27, 0, 0.7], hspace=0.5, left=0.1, top=top,
                             bottom=0.17, right=0.9, )
    df1 = []
    df2 = []
    axes = []
    for c, cell in enumerate(cells_chosen):
        for ff, frame_names in enumerate(frame_names_both):
            grid0 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.15, hspace=0.25, subplot_spec=grid[:, ff])
            ax1 = plt.subplot(grid0[0])
            frame_cell = plt_gain_area(ax1, c, cell, cells, colors_w, colors_wo, df1, df2, ff, frame_names, fs,
                                       loc, xlim)
            axes.append(ax1)
            ax1.set_ylabel('AUC' if ff == 0 else '')
            ax1.set_xlabel('')
            # roc nonlin part
            ax2 = plt.subplot(grid0[1])
            plt_nonlin_line(ax2, cell, ff, frame_cell, xlim)

    # squares
    ax = plt.subplot(grid[:, 3])
    axes.append(ax)
    full_name = 'modell_all_cell_no_sinz1_afe1_0.03__afr0_1__afj2_0.1__phaseright__len5_adaptoffset_bisecting_0.995_1.005____ratecorrrisidual35__modelbigfit_nfft4096_StartE1_1_EndE1_1.3_in0.005_StartJ2_1_EndJ2_1.3_in0.005_trialnr20__reshuffled_ThreeDiff_SameOffset'
    square_part(ax, shrink=shrink, full_name=full_name)
    max_val = 236
    ax.plot([-2, 0], [50, max_val], color='black', linewidth=1)
    ax.plot([max_val, max_val], [50, max_val], color='black', linewidth=1)
    ax.plot([-2, max_val], [50, 50], color='black', linewidth=1)
    ax.plot([-2, max_val], [max_val, max_val], color='black', linewidth=1)
    third_diagonal = True
    if third_diagonal:
        # third marker: same df1 as the first, df2 mirrored about the cell's fr
        df1.append(df1[0])
        df2.append(df2[0] + np.abs(frame_cell.fr.iloc[0] - df2[0]) * 2)
    plt_circles_matrix(ax, df1, df2)
    plt.suptitle('')
    fig.tag(axes, xoffs=-3.6, yoffs=2.7, )
    save_visualization(jpg=True, png=False)
    plt.show()


def plt_circles_matrix(ax, df1, df2, scat=True):
    """Annotate the points (df1[d], df2[d]) with numbered circles 1-3.

    Points 0 and 1 also get an open black square marker when ``scat`` is
    true. Point 2 only gets its number, and any further points are ignored.
    """
    titles = [r'$\numcircled{1}$', r'$\numcircled{2}$', r'$\numcircled{3}$']
    for d in range(len(df1)):
        x, y = df1[d], df2[d]
        if d == 0:
            ax.text(x + 5, y - 15, titles[d])
        elif d in (1, 2):
            ax.text(x + 5, y + 5, titles[d])
        else:
            continue
        if scat and d in (0, 1):
            ax.scatter(x, y, facecolors='none', edgecolor='black', marker='s')


def colors_cocktailparty_all():
    """Return ``(colors_w, colors_wo, color_base, color_01, color_02, color_012)``.

    ``colors_w`` is ``[color_012]`` and ``colors_wo`` is ``[color_01]``.
    """
    color_base, color_01, color_02, color_012 = colors_cocktailparty()
    colors_w = [color_012]
    colors_wo = [color_01]
    return colors_w, colors_wo, color_base, color_01, color_02, color_012


def plt_gain_area(ax1, c, cell, cells, colors_w, colors_wo, df1, df2, ff, frame_names, fs, loc, xlim):
    """Plot AUC with/without female for ``cell`` from each existing ROC frame onto ``ax1``.

    For every frame file found, the mean df1/df2 of the first ``LenNrs_``
    rows of the cell are appended to the ``df1``/``df2`` lists (mutated in
    place).

    Returns
    -------
    DataFrame
        The cell's rows from the last frame file found.
        NOTE(review): raises UnboundLocalError if no frame file exists.
    """
    label_f = ['with female', 'CLS: 100n with female', 'LS: 1000n with female', ]
    label_f2 = ['without female', 'CLS: 100n without female', 'LS: 1000n without female', ]
    for f, frame_name in enumerate(frame_names):
        path = load_folder_name('calc_ROC') + '/' + frame_name + '.csv'
        if not os.path.exists(path):
            continue
        frame = pd.read_csv(path)
        frame_cell = frame[frame.cell == cell]
        # number of contrast steps encoded in the file name ('..._LenNrs_<n>_...')
        nrs = int(frame_name.split('LenNrs_')[1].split('_')[0])
        frame_cell = frame_cell.iloc[0:nrs]
        df1.append(np.mean(frame_cell.df1.unique()))
        df2.append(np.mean(frame_cell.df2.unique()))
        axs = [ax1]
        labels = [label_f[f], label_f[f]]
        labels2 = [label_f2[f], label_f2[f]]
        for a, ax in enumerate(axs):
            if len(frame_cell) > 0:
                plt_area_between(frame_cell, ax, ax, colors_w, colors_wo, f, labels_with_female=labels[a],
                                 labels_without_female=labels2[a])
            ax.set_xlim(xlim)
            # Raw strings for the LaTeX parts; the '\n' line break stays a real newline.
            head = r' $ \Delta \mathrm{f_{Female}} + \Delta \mathrm{f_{Intruder}} $'
            titles = [
                head + '\n' + r'$ = \mathrm{f' + basename() + '}$' + r'$\numcircled{1}$ ',
                head + '\n ' + r' $\neq \mathrm{f' + basename() + '}$' + r'$\numcircled{2}$ ']
            ax.set_title(titles[ff])
            if ff == 1:
                if a == 0:
                    try:
                        ax.legend(loc=loc, fontsize=fs, ncol=1)
                    except Exception:
                        print('legend something')
                        embed()
                else:
                    ax.legend(loc=loc, fontsize=fs, ncol=1)
            ax.set_ylim(0, 0.52)
            ax.set_yticks_delta(0.25)
            if c != 0:
                remove_yticks(ax)
            ax.show_spines('lb')
            if a == 0:
                ax.set_ylabel('AUC')
            remove_xticks(ax)
            if ff != 0:
                remove_yticks(ax)
    return frame_cell


def plt_nonlin_line(ax2, cell, ff, frame_cell, xlim):
    """Plot the nonlinearity score against the recalculated distance ``c1`` on ``ax2``.

    Negative scores are clipped to 0. The y range is fixed to (0, 7). Only
    the first panel (``ff == 0``) gets a y label.
    """
    c1 = c_dist_recalc_func(frame_cell=frame_cell, c_nrs=frame_cell.c1, cell=cell, c_dist_recalc=True)
    talk = False
    if talk:
        ax2.plot(c1, sum_score(frame_cell), color='black', clip_on=True)
    else:
        # clip() returns a new Series. The former `score[score < 0] = 0` wrote
        # through into the caller's frame_cell column.
        score = val_new_core(frame_cell).clip(lower=0)
        ax2.plot(c1, score, color='black', clip_on=True)
    ax2.set_ylim(0, 7)
    ax2.set_xlim(xlim)
    if ff != 0:
        remove_yticks(ax2)
    if ff == 0:
        ax2.set_ylabel(peak_b1b2_name())
    ax2.set_xlabel(core_distance_label())


def val_new_core(frame_cell):
    """Return the normalised B1+B2 nonlinearity column of ``frame_cell``."""
    return frame_cell['amp_B1+B2_012-01-02+0_norm_01B1_mean']


def sum_score(frame_cell):
    """Return the score used in 'talk' mode (currently the same column as ``val_new_core``)."""
    return frame_cell['amp_B1+B2_012-01-02+0_norm_01B1_mean']


def find_variable_from_savename(full_name, name='trialnr'):
    """Extract the value that follows ``name`` in an underscore-separated save name.

    Example: ``'model_trialnr20_reshuffled'`` -> ``('20', 15, 13, 6)``.

    Returns
    -------
    tuple
        ``(value, line_pos, name_end, name_start)``. ``value`` is the text
        from the end of ``name`` up to the next ``'_'``, or to the end of the
        string. ``line_pos`` is the index of that ``'_'`` (-1 if there is
        none), ``name_end`` the index just after ``name`` and ``name_start``
        the index of ``name``. If ``name`` does not occur, the result is
        ``('', -1, -1, -1)``.
    """
    name_start = full_name.find(name)
    if name_start == -1:
        return '', -1, -1, -1
    name_end = name_start + len(name)
    # Search after the name so that underscores inside `name` are skipped.
    line_pos = full_name.find('_', name_end)
    # Without a following '_' the value runs to the end. Slicing to -1 would
    # silently drop its last character.
    trials_nr = full_name[name_end:] if line_pos == -1 else full_name[name_end:line_pos]
    return trials_nr, line_pos, name_end, name_start


def freq_two_mult_recalc(frame_cell_orig, freqs):
    """Convert the first frequency-ratio pair of ``freqs`` to difference frequencies.

    Each ratio r becomes ``(r - 1) * f0``, with f0 taken from the first row of
    ``frame_cell_orig``. Only ``freqs[0]`` is used. The result is a
    one-element list holding a tuple.
    """
    f0 = frame_cell_orig.f0.iloc[0]
    return [((freqs[0][0] - 1) * f0, (freqs[0][1] - 1) * f0)]


def figure_out_score_and_add(add, c_here, frame_cell):
    """Choose the score-name suffix and the three score names present in ``frame_cell``.

    If the suffixed first score column is missing, the suffix is dropped.

    Returns
    -------
    tuple
        ``(add, score1, score2, score3)``.
    """
    score1, score2, score3 = score_choice(c_here, add)
    if score1 not in frame_cell.keys():
        add = ''
        score1, score2, score3 = score_choice(c_here, add)
    return add, score1, score2, score3


def plt_matrix_saturation_loss(ax, frame_cell, c_here='c1', add='', ims=None, ims_diff=None, imshow=False,
                               xlabel=True):
    """Plot the three score matrices (df2 x df1) of ``frame_cell`` on ``ax[0..2]``.

    Rows are averaged per (df1, df2). The third matrix gets a symmetric colour
    range. Mappables are appended to ``ims`` (first two) and ``ims_diff``
    (third); new lists are used when these are not given.

    Returns
    -------
    tuple
        ``(lim, matrix, ss, ims)``: the third matrix's symmetric limit, the
        last matrix, the last index (2) and ``ims``.
    """
    # Fresh lists per call: the former `=[]` defaults accumulated across calls.
    if ims is None:
        ims = []
    if ims_diff is None:
        ims_diff = []
    add, score1, score2, score3 = figure_out_score_and_add(add, c_here, frame_cell)
    cls = ["RdBu_r", "RdBu_r", "RdBu_r"]
    for ss, score_here in enumerate([score1, score2, score3]):
        new_frame = frame_cell.groupby(['df1', 'df2'], as_index=False).mean()
        matrix = new_frame.pivot(index='df2', columns='df1', values=score_here)
        if ss == 2:
            lim = np.max([np.max(matrix), np.abs(np.min(matrix))])
        ax[ss].set_title(score_here)
        if imshow:
            # The imshow path used `lim` before it was assigned (NameError) and
            # drew every matrix on ax[2]. Each matrix now gets its own symmetric
            # limit and its own axes.
            lim_here = np.max([np.max(matrix), np.abs(np.min(matrix))])
            im = ax[ss].imshow(matrix, origin='lower', cmap="RdBu_r", vmin=-lim_here, vmax=lim_here,
                               extent=[np.min(matrix.columns), np.max(matrix.columns),
                                       np.min(matrix.index), np.max(matrix.index)])
        else:
            try:
                im = ax[ss].pcolormesh(np.array(list(map(float, matrix.columns))), np.array(matrix.index),
                                       matrix, cmap=cls[ss], rasterized=False)
            except Exception:
                print('ims probelem')
                embed()
        if ss == 2:
            im.set_clim(-lim, lim)
            ims_diff.append(im)
        else:
            ims.append(im)
        cbar = plt.colorbar(im, ax=ax[ss])
        cbar.set_label(score_here, labelpad=100)
        ax[0].set_ylabel('df2')
        if xlabel:
            ax[ss].set_xlabel('df1')
    return lim, matrix, ss, ims


def score_choice(c_here, add=''):
    """Return the three score column names for configuration ``c_here``.

    Only ``'c1'`` (B1 amplitude in the 012 vs. 01 condition, plus 'diff') is
    defined.

    Raises
    ------
    ValueError
        For any other ``c_here`` (previously an UnboundLocalError).
    """
    if c_here == 'c1':  # 'B1_diff'
        return 'amp_B1_012_mean' + add, 'amp_B1_01_mean' + add, 'diff'
    raise ValueError('unknown score configuration: ' + str(c_here))


def area_vs_single_peaks_frame(frame_cell, cs_type='all'):
    """Return ``frame_cell``, restricted to rows with ``c1 == c2 == cs_type`` unless cs_type is 'all'."""
    if cs_type != 'all':
        frame_cell = frame_cell[(frame_cell.c1 == cs_type) & (frame_cell.c2 == cs_type)]
    return frame_cell


def plt_cross(matrix, ax, type='small'):
    """Draw the diagonals (y = x and y = -x) over a matrix plot.

    ``'big'`` draws them over an integer grid one unit inside the label
    range and sets the limits to that grid. ``'small'`` keeps only the points
    inside the matrix label range and sets the limits to that range.
    ``'zerozentered'`` draws rays from the origin and then opens an
    interactive shell (unfinished).
    """
    zero_val = np.min([np.min(matrix.index) + 1, np.min(matrix.columns) + 1])
    max_val = np.max([np.max(matrix.index) - 1, np.max(matrix.columns) - 1])
    x_axis = np.arange(zero_val, max_val, 1)
    y_axis = np.arange(zero_val, max_val, 1)
    if type == 'big':
        ax.plot(x_axis, y_axis, color='black', zorder=100)
        ax.plot(np.array(x_axis), -np.array(y_axis), color='black', zorder=100)
        ax.set_xlim(x_axis[0], x_axis[-1])
        ax.set_ylim(y_axis[0], y_axis[-1])
    elif type == 'small':
        for y_axis_here in [y_axis, -y_axis]:
            restrict = ((x_axis > np.min(matrix.columns)) & (x_axis < np.max(matrix.columns))) & (
                    (y_axis_here > np.min(matrix.index)) & (y_axis_here < np.max(matrix.index)))
            ax.plot(x_axis[restrict], y_axis_here[restrict], color='black', zorder=100)
        ax.set_xlim(np.min(matrix.columns), np.max(matrix.columns))
        ax.set_ylim(np.min(matrix.index), np.max(matrix.index))
    elif type == 'zerozentered':
        x_rt = np.max(matrix.index)
        x_lb = np.min(matrix.index)
        y_rt = np.max(matrix.columns)
        y_lb = np.min(matrix.columns)
        ends = (x_rt, x_lb, y_rt, y_lb)
        # same 12 rays, in the same order: (v, v), (-v, v), (v, -v)
        for sign_x, sign_y in ((1, 1), (-1, 1), (1, -1)):
            for v in ends:
                ax.plot([0, sign_x * v], [0, sign_y * v], color='black', zorder=100)
        # NOTE(review): interactive breakpoint left in by the author (unfinished branch).
        embed()


def _axes_list(ax):
    """Return the Axes held by ``ax``: the values of a dict, otherwise the items of the iterable."""
    return list(ax.values()) if isinstance(ax, dict) else list(ax)


def plt_square_row(orientation, matrix_extend, dev, name, fig, title_pos, shrink, auc_lim, addsteps,
                   counter_contrast, cut, dfs, gridspacing, scores, contr2, cc, grid0, cell, combinations,
                   matrix_sorted='grid_sorted', xlim=[], ylim=[], text=True):
    """Plot one row of score squares, one column per (contrast11, contrast22) combination.

    Rows of the pickle ``name`` are selected by cell, ``contr2 == contrast22``
    and ``dev``. Non-AUC scores are drawn with ``plt_ingridients``, AUC
    scores with ``plt_auc``. At the end all panels share limits
    (``set_same_lim``).

    Returns
    -------
    tuple
        ``(axes, y_max, y_min, x_max, x_min)``.
    """
    axes = []
    y_max = []
    y_min = []
    x_max = []
    x_min = []
    c = -1
    for contrast11, contrast22 in combinations:
        print('print: ' + str(cell))
        grid1 = gridspec.GridSpecFromSubplotSpec(1, int(len(combinations)), hspace=0.7, wspace=0.2,
                                                 subplot_spec=grid0[cc])
        if not os.path.exists(name):
            continue
        c += 1
        df = pd.read_pickle(name)
        if len(df) <= 0:
            continue
        try:
            frame = df[(df['cell'] == cell) & (df[contr2] == contrast22) & (df['dev'] == dev)]
        except Exception:
            print('contrast think')
            embed()
        pivots = {}
        if len(frame) <= 0:
            continue
        # Bound for both branches. It used to be set only in the first one, so
        # the plt_auc branch raised NameError.
        trans = False
        if 'auc' not in scores[0]:
            grid2 = gridspec.GridSpecFromSubplotSpec(1, len(scores) + 1, hspace=0.4, wspace=0.4,
                                                     subplot_spec=grid1[c])
            ax, pivots_scores, orientation, limit_diff, axs = plt_ingridients(
                trans, gridspacing, scores, scores, xlim, ylim, cut, grid2, frame, dfs, matrix_sorted,
                orientation, matrix_extend, colorbar=False)
            axes.append(ax)
            if len(pivots_scores) > 0:
                try:
                    y_max.append(np.max(pivots_scores[scores[0]].index))
                except Exception:
                    print('y_max')
                    embed()
                y_min.append(np.min(pivots_scores[scores[0]].index))
                x_max.append(np.max(pivots_scores[scores[0]].columns))
                x_min.append(np.min(pivots_scores[scores[0]].columns))
        else:
            grid2 = gridspec.GridSpecFromSubplotSpec(1, len(scores) + 2, hspace=0.4, wspace=0.4,
                                                     subplot_spec=grid1[c])
            # plot the auc parts
            colorbar = counter_contrast == 0
            orientation, ax, limit_diff, axs, pivots = plt_auc(
                grid2, matrix_sorted, orientation, shrink, cut, xlim, ylim, pivots, trans, gridspacing,
                matrix_extend, dfs, frame, [scores], cell, auc_lim, addsteps, scores, 0, contr2, contrast22,
                pad=0.3, bar_orientation="horizontal", colorbar=colorbar, fig=fig)
            pivots_scores = pivots
            axes.append(ax)
            y_max.append(np.max(pivots[scores[0]].index))
            y_min.append(np.min(pivots[scores[0]].index))
            x_max.append(np.max(pivots[scores[0]].columns))
            x_min.append(np.min(pivots[scores[0]].columns))

        # `for _ in ax: ax.axvline(...)` called the method on the container
        # itself; draw the reference lines on every contained Axes instead.
        for ax_here in _axes_list(ax):
            ax_here.axvline(1, color='grey', linewidth=0.5, )
            ax_here.axhline(1, color='grey', linewidth=0.5, )
        if c != len(combinations) - 1:
            for a in ax:
                if a != 0:
                    try:
                        ax[a].set_xticks([])
                    except Exception:
                        a.set_xticks([])
        else:
            chose_xlabel_roc_matrix(ax, contrast11, contrast22, contr2, pivots_scores, scores)
        for s in range(len(scores)):
            if s != 0:
                ax[s].set_yticks([])
            elif c == len(combinations) - 1:
                chose_ylabel_ROC_matrix(ax, contrast11, contrast22, contr2, pivots_scores, s, scores)
        if text:
            if c == 0:
                ax[0].text(0, title_pos, cell, transform=ax[0].transAxes, fontweight='bold')
            ax[0].text(0, 1.2, 'C1 ' + str(contrast22) + ' C2 ' + str(contrast11), transform=ax[3].transAxes,
                       fontweight='bold')
        # ax[0].set_title('')
        for a in (1, 2, 3):
            ax[a].set_title('')
        for a in range(4):
            ax[a].axvline(0, color='grey', linewidth=0.5, )
            ax[a].axhline(0, color='grey', linewidth=0.5, )
        if len(ax) > 4:
            ax[4].set_title('')
    set_same_lim(xlim, ylim, y_min, y_max, x_min, x_max, axes)
    return axes, y_max, y_min, x_max, x_min


def chose_ylabel_ROC_matrix(ax, contrast11, contrast22, contrastc2, pivots_scores, s, scores):
    """Set the y label of ``ax[s]`` to the index name plus the matching contrast in %.

    ``contrast22`` is used when the index and ``contrastc2`` refer to the same
    fish ('1' or '2'); otherwise ``contrast11`` is used.
    """
    index_name = pivots_scores[scores[0]].index.name
    same_fish = ('1' in contrastc2 and '1' in index_name) or ('2' in contrastc2 and '2' in index_name)
    contrast = contrast22 if same_fish else contrast11
    ax[s].set_ylabel(index_name + ' ' + str(contrast) + '%', labelpad=-25)


def chose_xlabel_roc_matrix(ax, contrast11, contrast22, contrastc2, pivots_scores, scores):
    """Set the x label of ``ax[0]`` to the column name plus the matching contrast in %.

    ``contrast22`` is used when the columns and ``contrastc2`` refer to the
    same fish ('1' or '2'); otherwise ``contrast11`` is used.
    """
    columns_name = pivots_scores[scores[0]].columns.name
    same_fish = ('1' in contrastc2 and '1' in columns_name) or ('2' in contrastc2 and '2' in columns_name)
    contrast = contrast22 if same_fish else contrast11
    ax[0].set_xlabel(columns_name + ' ' + str(contrast) + '%', labelpad=-15)


def set_same_lim(xlim, ylim, y_min, y_max, x_min, x_max, axes):
    """Apply the same x/y limits to every panel in ``axes``.

    Explicit ``xlim``/``ylim`` are used if given. Otherwise the limits are
    the collected minima/maxima, widened by 1 %. Each entry of ``axes`` may
    be a dict of Axes or an iterable of Axes.
    """
    if len(xlim) > 0:
        ylim_here = ylim
        xlim_here = xlim
    else:
        ylim_here = [np.min(y_min) * 0.99, np.max(y_max) * 1.01]
        xlim_here = [np.min(x_min) * 0.99, np.max(x_max) * 1.01]
    for panel in axes:
        for a in panel:
            try:
                panel[a].set_ylim(ylim_here)
                panel[a].set_xlim(xlim_here)
            except Exception:
                a.set_ylim(ylim_here)
                a.set_xlim(xlim_here)


def plt_ingridients(trans, gridspacing, pivots_diff, scores, xlim, ylim, cut, grid_orig2, frame, dfs,
                    matrix_sorted='grid_sorted', orientation='f1 on x, f2 on y', matrix_extent='min', title=True,
                    colorbar_title=True, fig=[], colorbar=False, pad=0.1, bar_orientation='vertical',
                    cl_outside=False):
    """Draw each score square and the combination ``s0 - s1 - s2 + s3`` in one grid row.

    One panel per score in ``frame`` is drawn with a shared colour range. If
    every score was found, a fifth panel shows the difference square with a
    symmetric 'RdBu_r' range and a marker at (1, 1).

    Returns
    -------
    tuple
        ``(ax, pivots_scores, orientation, limit_diff, axs)``, where ``ax``
        maps panel index to Axes.
    """
    ax = {}
    pivots_scores = {}
    pivots_min = []
    pivot = []
    limit_diff = []
    axs = []
    if len(frame) > 0:
        for score in scores:
            if score in frame.keys():
                pivot, _, _, _, orientation, _ = get_data_pivot_three(frame, score, matrix_extent=matrix_extent,
                                                                      matrix_sorted=matrix_sorted,
                                                                      orientation=orientation,
                                                                      gridspacing=gridspacing, dfs=dfs)
                if trans:
                    # pivot = np.transpose(pivot)
                    # NOTE(review): nesting reconstructed from collapsed text;
                    # the sqrt of 'var' scores may belong outside `if trans`.
                    if 'var' in score:
                        pivot = np.sqrt(pivot)
                pivots_scores[score] = pivot
                pivots_min.append(pivot)
            elif len(pivot) > 0:
                pivots_scores[score] = np.ones_like(pivot)

        symbol = ['$-$', '$-$', '$+$', '$=$']
        for s, score in enumerate(scores):
            if score not in frame.keys():
                continue
            ax[s] = plt.subplot(grid_orig2[s])
            ax[s].text(1.5, 0.5, symbol[s], fontsize=15, va='center', ha='center', transform=ax[s].transAxes)
            vmax = np.nanmax(pivots_min)
            vmin = np.nanmin(pivots_min)
            square = pivots_scores[score]
            axs = plt.imshow(square, extent=[np.min(square.columns), np.max(square.columns),
                                             np.min(square.index), np.max(square.index)],
                             vmax=vmax, vmin=vmin, origin='lower')
            if title:
                plt.title(score, fontsize=7)
            if colorbar:
                if cl_outside:
                    _, _, _, _, _ = colorbar_outside(ax[s], axs, fig, orientation='bottom')
                    if colorbar_title:
                        ax[s].text(0.2, -1.4, score, transform=ax[s].transAxes)
                else:
                    plt.colorbar(orientation=bar_orientation, pad=pad)
            if cut:
                ax[s].set_ylim(ylim)
                ax[s].set_xlim(xlim)

        # combination of all four squares
        if len(scores) == len(pivots_scores):
            ax[s + 1] = plt.subplot(grid_orig2[4])
            diff = (pivots_scores[scores[0]] - pivots_scores[scores[1]] - pivots_scores[scores[2]]
                    + pivots_scores[scores[3]])
            diff_min = np.min(np.min(diff))
            diff_max = np.max(np.max(diff))
            lim = np.max([np.abs(diff_min), np.abs(diff_max)])
            axs = plt.imshow(diff, vmin=-lim, vmax=lim, origin='lower', cmap="RdBu_r",
                             extent=[np.min(diff.columns), np.max(diff.columns),
                                     np.min(diff.index), np.max(diff.index)])
            limit_diff = np.array([np.min(np.min(diff)), np.max(np.max(diff))])
            plt.yticks([])
            if cut:
                ax[s + 1].set_ylim(ylim)
                ax[s + 1].set_xlim(xlim)
            if scores != pivots_diff[-1]:
                # plt.xticks([])
                if colorbar:
                    if cl_outside:
                        _, _, _, _, _ = colorbar_outside(ax[s + 1], axs, fig, orientation='bottom', top=True)
                    else:
                        plt.colorbar(orientation=bar_orientation, pad=pad)
        # Only mark (1, 1) when the combination panel exists. Indexing
        # ax[s + 1] otherwise raised KeyError.
        if len(ax) > 0 and (s + 1) in ax:
            ax[s + 1].scatter(1, 1, marker='o', facecolors='none', edgecolors='black')
    return ax, pivots_scores, orientation, limit_diff, axs


def plt_auc(grid0, matrix, orientation, shrink, cut, xlim, ylim, pivots, trans, gridspacing, start, dfs,
            df_datapoint, pivots_diff, cell, auc_lim, addsteps, scores, di, contrastc2, contrast22, add=0,
            bar_orientation="vertical", pad=0.1, colorbar=True, cl_outside=True, fig=[]):
    # NOTE(review): this definition continues past the end of the reviewed
    # chunk. The visible part is reproduced unchanged; its nesting was
    # inferred from collapsed text -- confirm against the full file.
    good_cells = [
        '2022-01-06-ai-invivo-1',
        '2022-01-06-ag-invivo-1',
        '2022-01-08-ad-invivo-1',
        '2022-01-08-ah-invivo-1',
        '2021-07-06-ab-invivo-1',
        '2021-08-03-ac-invivo-1',
    ]
    if cell in good_cells:
        pass
    else:
        pass
    if addsteps == True:
        pass
    else:
        pass
    counter = 0
    ax_cont = []
    for ss, score in enumerate(scores):
        what = score  # + '_'+dev
        df_contrast = df_datapoint[df_datapoint[contrastc2] == contrast22]
        if len(df_contrast) > 0:
            if score in df_contrast:
                pivot, _, indexes, resorted, orientation, cut_type = get_data_pivot_three(df_contrast, what,
                                                                                         matrix_extent=start,
                                                                                         matrix_sorted=matrix,
                                                                                         orientation=orientation,
                                                                                         gridspacing=gridspacing,
                                                                                         dfs=dfs)  # 35
                try:
                    pass
                except:
                    pass
                if trans:
                    pivot = np.transpose(pivot)
                pivots[score] = pivot
                if addsteps == True:
                    ax = plt.subplot(grid0[counter])
                    counter += 1
                    plt.title(' ' + str(what), fontsize=8)  # +' dev'+str(dev)
                    ax_cont.append(ax)
                    if 'auci' in score:
                        vmin = -0.5
                        vmax = 0.5
                    elif 'auc' in score:
                        vmin = 0
                        vmax = 1
                    else:
                        vmax = np.nanmax(pivot)
                        vmin = np.nanmin(pivot)
                    axs = plt.imshow(pivot, vmin=vmin, vmax=vmax, origin='lower', cmap="RdBu_r",
                                     extent=[np.min(pivot.columns), np.max(pivot.columns),
                                             np.min(pivot.index), np.max(pivot.index)])
                    if cut:
                        plt.xlim(xlim)
                        plt.ylim(ylim)
if not ((di == 1) & (ss == 0)): pass else: plt.xlabel('EOD mult 2') plt.ylabel('EOD mult 1') if colorbar: if cl_outside: _, _, _, _, _ = colorbar_outside(ax, axs, fig, add=add, top=True) # colorbar_outside if di != len(pivots_diff) - 1: plt.xticks([]) if ss != 0: plt.yticks([]) symbol = ['$-$', '$=$'] ax.text(1.5, 0.5, symbol[ss], fontsize=15, va='center', ha='center', transform=ax.transAxes) # ha='center', if cut: ax.set_ylim(ylim) ax.set_xlim(xlim) else: print(str(score) + ' score not found') if di == 0: if addsteps == False: ax.text(1 + ws * 1, 0.5, '=', va='center', ha='center', fontsize=15, transform=ax.transAxes) # ha='center',# not found name = pivots_diff[di] ax = plt.subplot(grid0[counter]) ax_cont.append(ax) if name[0] in pivots.keys(): pivot_diff = pivots[name[0]] title = name[0] for p in range(1, len(name), 1): if name[p] in pivots.keys(): pivot_diff = pivot_diff - pivots[name[p]] title = title + ' - ' + name[p] limit_diff = np.array([np.min(np.min(pivot_diff)), np.max(np.max(pivot_diff))]) axs = plt.imshow(pivot_diff, vmin=-0.5, vmax=0.5, origin='lower', cmap="RdBu_r", extent=[np.min(pivot_diff.columns), np.max(pivot_diff.columns), np.min(pivot_diff.index), np.max(pivot_diff.index)]) if cut: plt.xlim(xlim) plt.ylim(ylim) plt.ylabel('') plt.xlabel('') if colorbar: if cl_outside: _, _, _, _, _ = colorbar_outside(ax, axs, fig, add=add, top=True) else: plt.colorbar(shrink=shrink, orientation=bar_orientation, pad=pad) if di != len(pivots_diff) - 1: plt.xticks([]) if ss != 0: plt.yticks([]) if auc_lim != 'nonlinm': ax = plt.subplot(grid0[counter + 1]) ax_cont.append(ax) vmax = np.nanpercentile(np.abs(pivot_diff), 95) vmin = np.nanpercentile(np.abs(pivot_diff), 5) lim = np.max([vmax, np.abs(vmin)]) axs = plt.imshow(pivot_diff, vmin=-lim, vmax=lim, origin='lower', cmap="RdBu_r", extent=[np.min(pivot_diff.columns), np.max(pivot_diff.columns), np.min(pivot_diff.index), np.max(pivot_diff.index)]) if cut: plt.xlim(xlim) plt.ylim(ylim) plt.ylabel('') plt.xlabel('') 
if di != len(pivots_diff) - 1: plt.xticks([]) if ss != 0: plt.yticks([]) if colorbar: if cl_outside: _, _, _, _, _ = colorbar_outside(ax, axs, fig, add=add, top=True) else: plt.colorbar(orientation=bar_orientation, pad=pad) # shrink=shrinkorientation = 'horizontal' counter += 1 if '*' in score: pass else: limit_diff = [] axs = [] return orientation, ax_cont, limit_diff, axs, pivots def condition_for_roc_thesis(): global diagonal, freq1_ratio, freq2_ratio, plus_q, way, length combis = diagonal_points() diagonal = 'B1+B2_diagonal2' # 'B1+B2_diagonal'#'diagonal11'#'test_data_cell_2022-01-05-aa-invivo-1' freq1_ratio = combis[diagonal][0] freq2_ratio = combis[diagonal][1] plus_q = 'plus' # 'minus'#'plus'##'minus' way = 'mult_minimum_1' # 'mult'#'absolut' ways = ['absolut'] # das hier brauchen wir # doch das brauchen wir hier sonst klappt das nicht mit dem ROC! length = 25 # 20 # 5 trials_nr = 20 # 100 return trials_nr, length, ways, way, plus_q, freq2_ratio, freq1_ratio, diagonal def save_RAM_to_csv(data_name, spikes_data_short, end=''): file_name, spikes, spikes_selected = save_RAM_spikes_core(data_name, end, spikes_data_short) save_RAM_overview_csv(data_name, end, spikes_data_short) file_name, spikes_selected = save_RAM_both_csv(data_name, spikes, spikes_data_short) save_RAM_eod_to_csv(data_name, spikes_selected) def save_RAM_both_csv(data_name, spikes, spikes_data_short): amp, file_name = get_min_amp_and_first_file(spikes_data_short, min_find=True) spikes_selected = spikes_data_short[(spikes_data_short.amp == amp) & (spikes_data_short.file_name == file_name)] # eod_path = load_only_spikes_RAM(data_name=data_name, emb=False, core_name='calc_RAM_data_eod_extra__first1_order__') if os.path.exists(eod_path): eod_data_short = pd.read_pickle(eod_path) amp, file_name = get_min_amp_and_first_file(eod_data_short, min_find=True) eod_selected = eod_data_short[(eod_data_short.amp == amp) & (eod_data_short.file_name == file_name)] # frame_cell = save_spikes_csv(eod_selected, 
spikes, spikes_selected) frame_cell.to_pickle('calc_RAM/spikes_and_eod_' + data_name + '.pkl') return file_name, spikes_selected def save_RAM_overview_csv(data_name, end, spikes_data_short): # , name = '' amps, file_names = get_min_amp_and_first_file(spikes_data_short) path_sascha = load_folder_name('calc_base') + '/' + 'calc_base_data-base_frame_overview.pkl' frame = pd.read_pickle(path_sascha) frame_c = frame[frame.cell == data_name] for a, amp in enumerate(amps): for file_name in file_names: frame_ov = pd.DataFrame() spikes_selected = spikes_data_short[ (spikes_data_short.amp == amp) & (spikes_data_short.file_name == file_name)] # spikes = get_array_from_pandas(spikes_selected['spikes'], abs=False) if len(spikes_selected) > 0: frame_ov['file_name'] = spikes_selected.file_name frame_ov['sampling'] = spikes_selected.sampling frame_ov['cell'] = spikes_selected.cell # das entspricht der abschätzung aus der baseline, aber ich könnte das auch aus dem RAM global EOD berechnen frame_ov['eod_fr'] = spikes_selected.eod_fr frame_ov['species'] = frame_c.species.iloc[0] lim = find_lim_here(data_name, 'individual') frame_ov['burst_corr_individual'] = lim frame_ov['cell_type_reclassified'] = frame_c.cell_type_reclassified.iloc[0] spikes, pos_reshuffled = reshuffle_spike_lengths(spikes) names = names_eodfs() vars = [] for name in names: vars.append(spikes_selected[name].iloc[0]) frame_ov = reshuffle_eodfs(frame_ov, names, pos_reshuffled, vars, res_name='res') cell_type = frame_c.cell_type_reclassified.iloc[0] name = end_calc_ram_name_eod(end='-overview_') + end_calc_ram_name(data_name, end, file_name, amp, cell_type='_' + cell_type, species=frame_c.species.iloc[ 0]) # data_name + end +'_amp_'+str(amp)+'_filename_'+str(file_name)+ '.csv' frame_ov.to_csv(name) del frame def save_RAM_spikes_core(data_name, end, spikes_data_short, cell_type='', species=''): amps, file_names = get_min_amp_and_first_file(spikes_data_short) for a, amp in enumerate(amps): for file_name in file_names: 
spikes_selected = spikes_data_short[ (spikes_data_short.amp == amp) & (spikes_data_short.file_name == file_name)] # spikes = get_array_from_pandas(spikes_selected['spikes'], abs=False) if len(spikes) > 0: spikes_df = pd.DataFrame() spikes, pos_reshuffled = reshuffle_spike_lengths(spikes) try: spikes_df = save_spikestrains_several(spikes_df, spikes) except: print('shuffling thing') embed() name = end_calc_ram_name_eod(end='-spikes_') + end_calc_ram_name(data_name, end, file_name, amp, '_' + cell_type, species.replace(' ', '')) # data_name + end + '_amp_' + str(amp) + '_filename_' + str(file_name) + '.csv' spikes_df.to_csv(name, index=False) return file_name, spikes, spikes_selected def get_min_amp_and_first_file(spikes_data_short, min_find=False): if min_find == True: amp = [np.min(spikes_data_short.amp)] if len(spikes_data_short.file_name.unique()) > 1: print('alignment problem') embed() file_name = [spikes_data_short.file_name.unique()[0]] else: amp = spikes_data_short.amp.unique() file_name = spikes_data_short.file_name.unique() # [0] return amp, file_name def save_RAM_eod_to_csv(data_name, spikes_selected): eod_path = load_only_spikes_RAM(data_name=data_name, emb=False, core_name='calc_RAM_data_eod_extra__first1_order__') if os.path.exists(eod_path): eod_data_short = pd.read_pickle(eod_path) save_RAM_wod_to_csv_core(data_name, eod_data_short) else: if spikes_selected.file_name2.iloc[0] == 'InputArr_400hz_30s': eod_path = load_only_spikes_RAM(data_name='2022-01-06-aa-invivo-1', emb=False, core_name='calc_RAM_data_eod_extra__first1_order__') eod_data_short = pd.read_pickle(eod_path) save_RAM_wod_to_csv_core(data_name, eod_data_short) else: pass def save_RAM_wod_to_csv_core(data_name, eod_data_short, end='', cell_type='', species='', ): amps, file_names = get_min_amp_and_first_file(eod_data_short) for a, amp in enumerate(amps): for file_name in file_names: eod_selected = eod_data_short[(eod_data_short.amp == amp) & (eod_data_short.file_name == file_name)] # eods = 
get_array_from_pandas(eod_selected['eod'], abs=False) if len(eods) > 0: length = [] for eod in eods: length.append(len(eod)) try: eods_df = pd.DataFrame(eods[np.argmax(length)], columns=['eod']) except: print('something') embed() name = end_calc_ram_name_eod(end='-eod_') + end_calc_ram_name(data_name, end, file_name, amp, cell_type='_' + cell_type, species=species) # +cell_type eods_df.to_csv(name, index=False) # calc_nix_RAM def end_calc_ram_name_eod(end='-eod_'): return 'calc_RAM/calc_nix_RAM' + end def end_calc_ram_name(data_name='', end='', file_name='', amp='', cell_type='', species=''): return data_name + end + '_amp_' + str(amp) + '_filename_' + str(file_name) + cell_type.replace(' ', '') + species.replace( ' ', '') + '.csv' def save_spikes_csv(eod_selected, spikes, spikes_selected): frame_cell = pd.DataFrame() frame_cell['spikes'] = spikes print(spikes_selected.file_name2.iloc[0]) frame_cell['file_name'] = spikes_selected.file_name2.iloc[0] frame_cell['sampling'] = eod_selected.sampling frame_cell['cell'] = eod_selected.cell frame_cell['eod_fr'] = eod_selected.eod_fr return frame_cell def compare_powers(show=False, step=str(30), v='diff', a_fes=[0.1], a_fjs=[0.1], names=['amp_max_05']): # a_fjs=[0, 0.01, 0.05, 0.1, 0.2] lim = [] default_settings(column=2, length=4) # cells = ['2012-07-03-ak-invivo-1' '2012-04-20-ak-invivo-1', '2012-05-10-ad-invivo-1', '2012-06-27-ah-invivo-1', '2012-06-27-an-invivo-1'] # ,'2012-07-03-ak-invivo-1' '2012-04-20-ak-invivo-1','2012-05-10-ad-invivo-1', '2012-06-27-ah-invivo-1','2012-06-27-an-invivo-1', '2012-07-03-ak-invivo-1'] nrs = [0.5, 1, 1.5, 3] adapt = 'adaptoffsetallall2' self = '' version_sinz = 'sinz' for aa, a_fe in enumerate(a_fes): for a, a_fj in enumerate(a_fjs): for what in names: plt.figure() grid = gridspec.GridSpec(len(cells), len(nrs), hspace=0.65, wspace=0.34, bottom=0.15, top=0.97) for c, cell in enumerate(cells): for n, nr in enumerate(nrs): full_name = 'modell_all_cell_no_sinz' + str( nr) + 
'_afe0.1__afr1__afj0.1__length0.5_adaptoffsetallall2_0.995_1.005____stepefish50Hz_ratecorrrisidual35__modelbigfit_nfft4096_StartEmitter0.5_EndEmitter1.5_StartJammer0.5_EndJammer1.5Three_SameOffset' versions, arrays = get_all_squares(adapt=adapt, full_name=full_name, self=self, version_sinz=version_sinz, cell=cell, a_fe=a_fe, nr=nr, a_fj=a_fj, what=what, step=step, ) if len(versions['diff']) > 0: if len(versions) > 0: plt.subplot(grid[c, n]) plt.title('Power:' + str(nr)) if lim == []: min = np.min(np.min(versions[v])) max = np.max(np.max(versions[v])) lim = np.max([np.abs(min), np.abs(max)]) axs = sns.heatmap(versions[v], vmin=-lim, vmax=lim, cmap="RdBu_r", cbar_kws={'label': 'Nonlinearity [Hz]'}) # 'location': "left" axs.invert_yaxis() plt.subplots_adjust(hspace=0.8, wspace=0.8) save_visualization(show) def get_all_squares(full_name='', self='', version_sinz='sinz', cell=[], a_fe=0.2, nr=1, a_fj=0.2, what='std', step='30', adapt='adaptoffset_bisecting', variant='no'): squares = ['012', 'control_01', 'control_02', 'base_0'] # 'base_0' versions = {} arrays = [] for s, square in enumerate(squares): control, vers_here, cell, eod_m, fr_rate_mult = define_squares_model_three(a_fe=a_fe, nr=nr, a_fj=a_fj, what=what, step=step, adapt=adapt, variant=variant, square=square, cell_data=cell, full_name=full_name, self=self, minimum=0.5, maximum=1.5, version_sinz=version_sinz) versions[square] = vers_here arrays.append(np.array(vers_here)) if len(versions['012']) > 0: versions['diff'] = versions['012'] - versions['control_01'] - versions['control_02'] + versions['base_0'] else: versions['diff'] = [] return versions, arrays def plot_shemes_lis(eod_fish_r, eod_fish_e, eod_fish_j, grid0, time, ylim=[-1.1, 1.1], g=0, jammer_label='f2', emitter_label='f1', remove=True, receiver_label='f0', waves_present=['receiver', 'emitter', 'jammer', 'all'], color='grey', eod_fr=700, add=0, xlim=[0, 0.05], extract_here=True, colors = [], sheme_shift=0, extract='', title=[], threshold=0.02, 
color_eod='grey'): stimulus = np.zeros(len(eod_fish_r)) axes = [] for ww, w in enumerate(waves_present): ax = plt.subplot(grid0[ww + sheme_shift, g]) axes.append(ax) if len(colors)<1: color_eod = 'grey' else: color_eod = colors[ww] if w == 'receiver': ax.plot(time, eod_fish_r, color=color_eod, lw=0.5) stimulus += eod_fish_r ax.set_ylim(ylim) if len(xlim) > 0: ax.set_xlim(xlim) ax.spines['bottom'].set_visible(False) if g == 0: ax.set_ylabel(receiver_label) elif w == 'emitter': ax.text(0.5, 1.01, '$+$', va='center', ha='center', transform=ax.transAxes, fontsize=20) ax.plot(time, eod_fish_e, color=color_eod, lw=0.5) stimulus += eod_fish_e ax.set_ylim(ylim) if len(xlim) > 0: ax.set_xlim(xlim) ax.spines['bottom'].set_visible(False) if g == 0: ax.set_ylabel(emitter_label) # , color='grey' elif w == 'jammer': ax.text(0.5, 1.01, '$+$', va='center', ha='center', transform=ax.transAxes, fontsize=20) ax.plot(time, eod_fish_j, color=color_eod, lw=0.5) stimulus += eod_fish_j ax.set_ylim(ylim) if len(xlim) > 0: ax.set_xlim(xlim) if g == 0: ax.set_ylabel(jammer_label) elif w == 'all': if title: ax.set_title(title, color=color) ax.text(0.5, 1.45 + add, '$=$', va='center', ha='center', transform=ax.transAxes, fontsize=20) eod_interp, eod_norm = extract_am(stimulus, time, extract=extract, norm=False, sampling=1 / time[1], eodf=eod_fr, emb=False, threshold=threshold) # , extract=extract plt.plot(time, stimulus, color=color_eod, lw=0.5) if extract_here: plt.plot(time[1::], eod_interp[1::], color=color) # , clip_on = False plt.ylim(-1.22, 1.22) if len(xlim) > 0: plt.xlim(xlim) plt.ylim(ylim) if g == 0: plt.ylabel('stimulus') if g == 0: if ww == 0: plt.ylabel(receiver_label) elif ww == 1: plt.ylabel(emitter_label) elif ww == 2: plt.ylabel(jammer_label) elif ww == 3: plt.ylabel('stimulus') if remove: ax.show_spines('') ax.set_xticks([]) ax.set_yticks([]) return ax, axes def experimental_protocol_lissbon_amps(add='', show=True, ): default_figsize(column=2, length=3) grid = 
gridspec.GridSpec(1, 1, wspace=0.7, hspace=0.5, left=0.05, top=0.99, bottom=0.07, right=0.98) # height_ratios = [1,6]bottom=0.25, top=0.8, stimulus_length = 0.1 deltat = 1 / 20000 eod_fr = 750 a_fr = 1 eod_fe = 680 # data.eodf.iloc[0] + 10 # cell_model.eode.iloc[0] a_fes = [0.01, 0.2, 0.6, 1] # ,1.2,2]1 ylim = [-2.3, 2.3] # [-2, 2] eod_fj = 730 # data.eodf.iloc[0] + 50 # cell_model.eodj.iloc[0] a_fj = 0.05 variant_cell = 'no' # 'receiver_emitter_jammer' waves_presents = [['receiver', 'emitter', 'all']] * len(a_fes) color = ['black'] * len(a_fes) symbols = [''] * len(a_fes) gs = np.arange(0, len(color), 1) grid0 = gridspec.GridSpecFromSubplotSpec(3, len(gs), wspace=0.3, hspace=0.35, subplot_spec=grid[0]) axes0 = [] for i in range(len(waves_presents)): eod_fish_j, time, time_fish_r, eod_fish_r, time_fish_e, eod_fish_e, time_fish_sam, eod_fish_sam, stimulus_am, stimulus_sam = extract_waves( variant_cell, '', stimulus_length, deltat, eod_fr, a_fr, a_fes[i], [eod_fe], 0, eod_fj, a_fj) time = time * 1000 print(eod_fe - eod_fr) xlim = [0, 30] if i == 3: extract_here = False else: extract_here = True ax, axes = plot_shemes_lis(eod_fish_r, eod_fish_e, eod_fish_j, grid0, time, g=gs[i], waves_present=waves_presents[i], color=color[i], eod_fr=eod_fr, ylim=ylim, xlim=xlim, jammer_label='', emitter_label='', receiver_label='', title='', extract_here=extract_here, remove=True, threshold=-0.25) # extract = 'globalmax', axes0.append(axes[0]) if extract_here == False: time = np.arange(0, stimulus_length, deltat) time_fish_r = time * 2 * np.pi * np.abs((eod_fr - eod_fe)) eod_fish_r = 1 + (a_fes[i] - 0.12) * np.cos(time_fish_r) time = time * 1000 ax.plot(time, eod_fish_r + 0.3, color='black') ax.set_ylim(ylim) ax.text(1, 1.03, '$c=%s' % (int(np.round(a_fes[i] * 100))) + '\,\%$', transform=ax.transAxes, ha='right') # +' distance = '+str(int(np.round(a_fes_cm[i])))+' cm') test = False if test: print(str(np.max(eod_fish_r)) + ' ' + str(np.max(eod_fish_e))) plt.plot(eod_fish_r) 
plt.plot(eod_fish_e) plt.show() if ax != []: ax.text(1.1, 0.45, symbols[i], fontsize=35, transform=ax.transAxes) if i == 0: manual = False if manual: ax.plot([0, 10], [ylim[0] + 0.01, ylim[0] + 0.01], color='black') ax.text(0, -0.1, '10 ms', va='center', fontsize=11, transform=ax.transAxes) ax.xscalebar(0.25, -0.02, 10, 'ms', va='right', ha='bottom') extra = False if extra: models = resave_small_files("models_big_fit_d_right.csv", load_folder='calc_model_core') cell = '2012-07-03-ak-invivo-1' deltat, eod_fr, model_params, offset = get_model_params(models, cell=cell) spikes = [[]] * 5 for t in range(5): cvs, adapt_output, baseline_after, _, rate_adapted, rate_baseline_before, rate_baseline_after, spikes[ t], \ stimulus_altered, \ v_dent_output, offset_new, v_mem_output, noise_final = simulate(cell, offset, eod_fish_r + eod_fish_e + eod_fish_j, EODf=eod_fr, deltat=deltat, **model_params) base_cut, mat_base = find_base_fr(spikes, deltat, stimulus_length, time, dev=0.0005) ax = plt.subplot(grid0[-2, i]) ax.eventplot(np.array(spikes) * 1000, color='grey') ax.set_xlim(xlim) ax.show_spines('') ax = plt.subplot(grid0[-1, i]) ax.plot(time, mat_base[0:len(time)], color='black') ax.set_xlim(xlim) ax.set_ylabel('Firing Rate [Hz]') fig = plt.gcf() fig.tag(axes0, xoffs=-3.5, yoffs=0.1) save_visualization('', show, jpg=True, png=False, counter_contrast=0, savename='', add=add) def experimental_protocol_lissbon(add='', color=['green', 'blue', 'red', 'orange'], titles=['receiver', 'receiver + female', 'receiver + intruder', 'receiver + female + intruder', []], waves_presents=[['receiver', '', '', 'all'], ['receiver', 'emitter', '', 'all'], ['receiver', '', 'jammer', 'all'], ['receiver', 'emitter', 'jammer', 'all'], ], figsize=(12, 5.5), show=True, ): plt.figure(figsize=figsize) grid = gridspec.GridSpec(1, 1, wspace=0.7, hspace=0.5, left=0.05, top=0.99, bottom=0.07, right=0.95) # height_ratios = [1,6]bottom=0.25, top=0.8, grid0 = gridspec.GridSpecFromSubplotSpec(4, 4, wspace=0.3, 
hspace=0.35, height_ratios=[1, 1, 1, 1], subplot_spec=grid[0]) stimulus_length = 0.3 deltat = 1 / 40000 eod_fr = 750 a_fr = 1 eod_fe = 600 # data.eodf.iloc[0] + 10 # cell_model.eode.iloc[0] a_fe = 0.5 eod_fj = 680 # data.eodf.iloc[0] + 50 # cell_model.eodj.iloc[0] a_fj = 0.05 variant_cell = 'no' # 'receiver_emitter_jammer' eod_fish_j, time, time_fish_r, eod_fish_r, time_fish_e, eod_fish_e, time_fish_sam, eod_fish_sam, stimulus_am, stimulus_sam = extract_waves( variant_cell, '', stimulus_length, deltat, eod_fr, a_fr, a_fe, [eod_fe], 0, eod_fj, a_fj) gs = [0, 1, 2, 3, 4] symbols = ['', '', '', '', ''] ylim = [-2, 2] time = time * 1000 for i in range(len(waves_presents)): ax, axes = plot_shemes_lis(eod_fish_r, eod_fish_e, eod_fish_j, grid0, time, g=gs[i], waves_present=waves_presents[i], color=color[i], eod_fr=eod_fr, ylim=ylim, xlim=[0, 70], jammer_label='intruder', emitter_label='female', receiver_label='receiver', title=titles[i]) if ax != []: ax.text(1.1, 0.45, symbols[i], fontsize=35, transform=ax.transAxes) if i == 0: ax.plot([0, 20], [ylim[0] + 0.01, ylim[0] + 0.01], color='black') ax.text(0, -0.1, '20 ms', va='center', fontsize=11, transform=ax.transAxes) ax.set_ylim(ylim) axes.append(ax) fig = plt.gcf() axes = fig.axes fig.tag(axes[0::4], xoffs=-3, yoffs=0.3) save_visualization('', show, jpg=True, png=False, counter_contrast=0, savename='', add=add) def rainbow_title(fig, axt, titles, color_add_pos, ha='left', a=0, start_xpos=0, y_pos=1.1): if type(titles) != str: for aa in range(len(titles)): if aa == 0: pos = start_xpos # + add_pos[a][aa] text = axt.text(pos, y_pos, titles[aa], color=color_add_pos[a][aa], transform=axt.transAxes, ha=ha) # verticalalignment='right', text.draw(fig.canvas.get_renderer()) ex = text.get_window_extent() ex2 = ex.transformed(axt.transAxes.inverted()) pos = ex2.get_points()[1][0] + 0.01 else: axt.text(0, 1.1, titles, color='black', transform=axt.transAxes) # verticalalignment='right',add_pos[a] def calc_areas(path, frame_ref, colr, 
x_pos, cells_chosen): default_settings(column=2, length=3) frame = pd.read_csv(path) cvs = frame_ref.cv_0 cells = frame_ref.cell.unique() areas_01 = np.ones(len(cells)) * float('nan') areas_012 = np.ones(len(cells)) * float('nan') areas_01_one = np.ones(len(cells)) * float('nan') areas_012_one = np.ones(len(cells)) * float('nan') nonlin = np.ones(len(cells)) * float('nan') nonlin_area = np.ones(len(cells)) * float('nan') areas_01_scatter = colr * 1 name = 'score' for c, cell in enumerate(cells): frame_cell = frame[frame.cell == cell] frame_cell['score'] = get_nonlin_scores(frame_cell) areas_01_one[c] = fin_min_pos(frame_cell, 'auci_base_01', x_pos) areas_012_one[c] = fin_min_pos(frame_cell, 'auci_02_012', x_pos) areas_01[c] = metrics.auc(frame_cell.c1, frame_cell['auci_base_01']) areas_012[c] = metrics.auc(frame_cell.c1, frame_cell['auci_02_012']) nonlin[c] = fin_min_pos(frame_cell, name, x_pos) nonlin_area[c] = metrics.auc(frame_cell.c1, frame_cell[ name]) # metrics.auc(frame_cell.c1,frame_cell['amp_B1+B2_012-01-02+0_norm_01B1+02B2_mean']) if cell in cells_chosen: areas_01_scatter[c] = 'black' diff_areas = np.array(areas_012) - np.array(areas_01) return cvs, nonlin_area, diff_areas, areas_01_scatter, nonlin, areas_012_one - areas_01_one def get_nonlin_scores(frame_cell): score = val_new_core(frame_cell) return score def nonlinval_core(frame_cell): return frame_cell[val_nonlin_chapter4()] / (frame_cell['amp_B1_01_mean']) def val_nonlin_chapter4(): return 'amp_B1+B2_012_mean' def fin_min_pos(frame_cell, name, x_pos): nonlin = frame_cell[name].iloc[np.argmin(np.abs( frame_cell.c1 - x_pos))] # metrics.auc(frame_cell.c1,frame_cell['amp_B1+B2_012-01-02+0_norm_01B1+02B2_mean']) return nonlin def plt_scatter_nonlin_all_main(): default_settings(column=2, length=2.3) frame_names = [ 'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_20_mult_minimum_1temporal', 
'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_100_mult_minimum_1temporal', 'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_1000_mult_minimum_1temporal'] path = load_folder_name('calc_ROC') + '/' + frame_names[1] + '.csv' path_ref = load_folder_name( 'calc_ROC') + '/' + 'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_C1_0.02_len_25_nfft_32768_trialsnr_20_mult_minimum_1temporal.csv' frame_ref = pd.read_csv(path_ref) frame_ref = frame_ref.sort_values(by='cv_0') cm = plt.get_cmap("hsv") cells_chosen = ['2013-01-08-aa-invivo-1', "2012-06-27-ah-invivo-1", "2014-06-06-ac-invivo-1"] # '2012-06-27-an-invivo-1', colr = [cm(float(i) / (len(frame_ref))) for i in range(len(frame_ref))] fig, ax = plt.subplots(1, 3) # figsize = (12,5.5) x_pos = 0.02 # 'amp_B1+B2_012-01-02+0_norm_01B1+02B2_mean', names = ['amp_B1+B2_012-01-02+0_norm_01B1_mean'] # 'amp_B1+B2_012-01-02+0_norm_01B1_mean', for _, _ in enumerate(names): cvs, nonlin_area, diff_areas, areas_01_scatter, nonlin, areas_012_one = calc_areas(path, frame_ref, colr, x_pos, cells_chosen) color = 'grey' ax[0].scatter(cvs, diff_areas, color=color, s=15, clip_on=False) # color=colr, ax[0].axhline(0, linestyle='--', linewidth=0.5, color='grey') ax[0].set_xlabel('CV') ax[0].set_xlim(0, 1.1) ax[0].set_ylabel(core_scatter_wfemale()) ax[1].scatter(cvs, nonlin_area, color=color, s=15, clip_on=False) # color = colr, ax[1].axhline(0, linestyle='--', linewidth=0.5, color='grey') ax[1].set_xlabel('CV') ax[1].set_xlim(0, 1.1) ax[1].set_ylabel(peak_b1b2_name()) ax[2].set_xlabel(core_scatter_wfemale()) ax[2].set_ylabel(peak_b1b2_name()) corr, p_value = stats.pearsonr(nonlin_area, diff_areas) label = pearson_label(corr, p_value, nonlin_area, n=True) ax[2].text(1, 1.05, label, ha='right', transform=ax[2].transAxes) 
ax[2].scatter(nonlin_area, diff_areas, color=color, s=15, clip_on=False) # color=colr, model = LinearRegression() model.fit(nonlin_area.reshape((-1, 1)), diff_areas.reshape((-1, 1))) slope = model.coef_ intercept = model.intercept_ ax[2].plot([0, np.max(nonlin_area) * 1.05], [intercept, intercept + np.max(nonlin_area) * 1.05 * slope], color='grey', linewidth=0.5) ##embed()'Correlation='+str(np.round(corr,2)) make_simple_tags(ax, xpos=-0.07, letters=['A', 'B', 'C']) plt.subplots_adjust(wspace=0.85, hspace=0.4, bottom=0.21, right=0.95) # , top = 0.6 save_visualization() plt.show() def core_scatter_wfemale(): return '$\mathrm{AUC_{Female}}-\mathrm{AUC_{NoFemale}}$ ' def peak_b1b2_name(): return 'Nonlinearity $A(\Delta \mathrm{f_{Sum}})$ [Hz]' def core_decline_ROC(trial_nr='20', absolut=True, short=True, pos=True): lastnr = '0.1' b_cond = b_cond_core() combpos = b_cond + '_FrF1rel_0.3_FrF2rel_0.7' combneg = 'vertical1_FrF1rel_1_FrF2rel_0.7' if not absolut: if not short: if pos: frame_names = [ 'calc_ROC_contrasts-ROCmodel_contrasts1_' + combpos + '_C2_0.1_LenNrs_50_1Nrs_0.0001_LastNrs_' + lastnr + '_len_25_nfft_32768_trialsnr_' + trial_nr + '_mult_minimum_1temporal'] else: frame_names = [ 'calc_ROC_contrasts-ROCmodel_contrasts1_' + combneg + '_C2_0.1_LenNrs_50_1Nrs_0.0001_LastNrs_' + lastnr + '_len_25_nfft_32768_trialsnr_' + trial_nr + '_mult_minimum_1temporal'] else: if pos: frame_names = [ 'calc_ROC_contrasts-ROCmodel_contrasts1_' + combpos + '_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_' + lastnr + '_len_25_nfft_32768_trialsnr_' + trial_nr + '_mult_minimum_1temporal'] else: frame_names = [ 'calc_ROC_contrasts-ROCmodel_contrasts1_' + combneg + '_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_' + lastnr + '_len_25_nfft_32768_trialsnr_' + trial_nr + '_mult_minimum_1temporal'] else: if not short: if pos: frame_names = [ 'calc_ROC_contrasts-ROCmodel_contrasts1_' + combpos + '_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_' + lastnr + '_len_25_nfft_32768_trialsnr_' + trial_nr + 
'_absoluttemporal'] else: frame_names = [ 'calc_ROC_contrasts-ROCmodel_contrasts1_' + combneg + '_C2_0.1_LenNrs_50_1Nrs_0.0001_LastNrs_' + lastnr + '_len_25_nfft_32768_trialsnr_' + trial_nr + '_absoluttemporal'] else: if pos: frame_names = [ 'calc_ROC_contrasts-ROCmodel_contrasts1_' + combpos + '_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_' + lastnr + '_len_25_nfft_32768_trialsnr_' + trial_nr + '_absoluttemporal'] else: frame_names = [ 'calc_ROC_contrasts-ROCmodel_contrasts1_' + combneg + '_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_' + lastnr + '_len_25_nfft_32768_trialsnr_' + trial_nr + '_absoluttemporal'] return frame_names, trial_nr def comb_talk(): return 'B1+B2_diagonal3' def comb_talk2(): return 'B1-B2_diagonal3' def b_cond_core(): return 'B1+B2_diagonal3' def calc_right_core_nonlin(transient, steps_name, T, steps, results, data_mat, position, dt, add_name=''): c = [[]] * 4 phi = [[]] * 4 c0_tmp = 0.0 a_tmp = np.zeros(4) b_tmp = np.zeros(4) # Calculate a and b for groundmode and higher harmonics for j in range(4): freq_here = (j + 1) / T time_here_all = np.arange(transient + steps) * dt # also hier gibts die option den anfang für die trans rauszunehmen, kann aber auch Null sein if j < 1: c0_tmp = np.mean(data_mat[transient:transient + len(time_here_all)]) try: a_tmp[j] = np.mean(data_mat[transient:len(time_here_all)] * np.cos( 2.0 * np.pi * time_here_all[transient:len(time_here_all)] * freq_here)) except: print('a tmp problems') embed() b_tmp[j] = np.mean(data_mat[transient:len(time_here_all)] * np.sin( 2.0 * np.pi * time_here_all[transient:len(time_here_all)] * freq_here)) c0 = c0_tmp # / (steps) a = a_tmp * 2.0 # / (steps) b = b_tmp * 2.0 # / (steps) for k in range(4): c[k] = np.sqrt(a[k] ** 2 + b[k] ** 2) phi[k] = np.arctan(a[k] / b[k]) test = False if test: c0, c, phi = average_fft() # a_tmp, b_tmp, c, c0_tmp, phi, s for c_nr in range(len(c)): try: results.loc[position, 'c_' + str(c_nr) + steps_name + add_name] = c[c_nr] except: print('c problem') embed() for c_nr 
in range(len(c)): results.loc[position, 'phi_' + str(c_nr) + steps_name + add_name] = phi[c_nr] results.loc[position, 'c0' + steps_name + add_name] = c0 return results def average_fft(a_tmp, b_tmp, c, c0_tmp, phi, s, steps=1): c0 = c0_tmp / (1.0 * steps) # (single_period) a = a_tmp * 2.0 / (1.0 * steps) # / (single_period) b = b_tmp * 2.0 / (1.0 * steps) # / (single_period) for k in range(4): c[s, k] = np.sqrt(a[k] ** 2 + b[k] ** 2) phi[s, k] = np.arctan(a[k] / b[k]) return c0, c, phi def load_savedir(level=0, individual_tag='', frame=[], save=False, csv=False, pkl=False, emb=False): # ich kann das einmal als übersichtsfile haben auf level 0 # als variierendes file auf level 1 # und vielleich ein example file auf level 1 ziehen damit man eine versinosübersicht hat if 'miniconda3' in inspect.stack()[1][1]: initial_function = \ inspect.stack()[-16][1].split('/')[-1].split('.')[0] last_function = inspect.stack()[-16][4][0].split('(')[0].strip() else: initial_function = \ inspect.stack()[-1][1].split('\\')[-1].split('.')[0] list_name = [] for i in range(len(inspect.stack())): save_name = inspect.stack()[i][3] list_name.append(save_name) pos = -2 # np.where(np.array(list_name) == '')[0][0]-1 last_function = list_name[pos] if emb: embed() print(initial_function) t1 = time.time() data_extra_fold = '' # '_data' if level == 0: # Null für die neuen # auf dem Nuller Level muss man das mit dem Funktionsnamen machen # aber das sollte man selten verwenden save_name = initial_function + data_extra_fold + '/' + last_function + '-' elif level == 1: # 1 für die alten # auf dem Nuller Level muss man das mit dem Funktionsnamen machen # aber das sollte man selten verwenden save_name = initial_function + data_extra_fold + '/' elif level == 2: # am besten tut man die Basic functions auf das 1er Level # die haben einen eigenständigen Namen und sind in dem Funktions Ordner save_name = initial_function + data_extra_fold + '/' + last_function + '/' elif level == 3: # und die Zellen etc 
Sachen bzw die Versions Sachen sind dann im nächsten Ordner if not os.path.isdir(initial_function + + data_extra_fold + '/' + last_function + '/cells'): os.mkdir(initial_function + data_extra_fold + '/' + last_function + '/cells') save_name = initial_function + data_extra_fold + '/' + last_function + '/cells/' try: if save: if csv: frame.to_csv(save_name + individual_tag + '.csv', index=False) if pkl: frame.to_pickle(save_name + individual_tag + '.pkl') except: print('save problem') embed() t2 = time.time() - t1 print(f'save time {t2}') return save_name def redo_on_cell_level(redo_level, append_cells, redo, beat_results, cell, counter_continued, cell_name='dataset', range_orig1=[], range_orig2=[]): # do_thiseod - of true do this frequency new combs = [] if 'celllevel' in redo_level: # decide if cell in sample or not if (append_cells == True) and (redo == False): if cell in np.unique(beat_results[cell_name]): if 'clusters' in redo_level: f1_present = np.unique(beat_results[beat_results[cell_name] == cell].f1) # ].f1 f2_present = np.unique(beat_results[beat_results[cell_name] == cell].f2) # ].f1 len_required = len(range_orig1) * len(range_orig2) len_present = len(f1_present) * len(f2_present) # ich subtrahiere noch die kontrolle mit 10 Hz, # also ich schau erst ob das wirklich nur eine Kontrolle ist oder doch nicht if (10 not in list(map(int, range_orig1))) & (10 not in list(map(int, range_orig2))): len_remaining = len_present - (len(f1_present) + len(f2_present)) beat_cell = beat_results[beat_results[cell_name] == cell] beat_corrected = beat_cell[(beat_cell['f2'] != 10) & (beat_cell['f1'] != 10)] combs_all = beat_corrected[['f2', 'f1']] elif 10 not in range_orig1: len_remaining = len_present - (len(f1_present)) beat_cell = beat_results[beat_results[cell_name] == cell] beat_corrected = beat_cell[beat_cell['f1'] != 10] combs_all = beat_corrected[['f2', 'f1']] elif 10 not in range_orig2: len_remaining = len_present - (len(f2_present)) beat_cell = 
beat_results[beat_results[cell_name] == cell] beat_corrected = beat_cell[beat_cell['f2'] != 10] combs_all = beat_corrected[['f2', 'f1']] else: len_remaining = len_present combs_all = beat_results[beat_results[cell_name] == cell][['f2', 'f1']] combs = np.unique(combs_all, axis=0) if len_remaining != len_required: do_thiscell = True do_thiseod = True else: do_thiscell = False do_thiseod = False counter_continued += 1 print('Nr ' + str(counter_continued) + 'already there') else: do_thiscell = False do_thiseod = False counter_continued += 1 print('Nr ' + str(counter_continued) + 'already there') else: do_thiscell = True do_thiseod = True # just do all cells irrespective if they are in any files or not else: do_thiscell = True do_thiseod = True else: do_thiscell = True do_thiseod = True return do_thiscell, do_thiseod, counter_continued, combs def redo_or_append(save_name, redo=False, name_orig=[]): # output # beatresults - preallocated array # addcell - add cells or redo # if we dont redo but continue saving if redo == False: if len(name_orig) < 1: name = save_name + '.pkl' # name1 = folder_name('calc_model')+'/modell_all_cell_' + save_name + '.pkl' # if the datataname exists add new cells to the existing if os.path.exists(name): preallocated = pd.read_pickle(name) position = len(preallocated) if len(preallocated) > 0: append_cells = True else: append_cells = False preallocated = pd.DataFrame() position = 0 else: preallocated = pd.DataFrame() position = 0 append_cells = False else: if os.path.exists(name_orig): preallocated = pd.read_pickle(name_orig) position = len(preallocated) append_cells = True else: preallocated = pd.DataFrame() position = 0 append_cells = False # else preallocate an empty array else: # if we want to redo the whole simulation # preallocated = [] preallocated = pd.DataFrame() position = 0 append_cells = False return append_cells, preallocated, position def calc_nonlinearity_contrasts(transient_s=0, cells=[], n=1, adapt_offset='adaptoffsetallall2', 
stimulus_length_orig=2, freq_type='_beat_', single_train='', fft='fft', dev='original', trials_nr=150, zeros='zeros', a_f1s=[0, 0.005, 0.01, 0.05, 0.1, 0.2, ], a_frs=[1], add_half=0, nfft=int(2 ** 15), beat='', nfft_for_morph=4096 * 4, gain=1, fish_jammer='Alepto', redo_level='celllevel', us_name='', adapt_type=''): # adapt = ''_adaptMean_ stimulus_length = stimulus_length_orig # single_train = 'single_train' model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv") cell_nrs = [11, 22, 5, 13, 12, 6, 15, 20, 4, 7, 26, 23] if len(cells) < 1: frame = pd.read_pickle('cv.pkl') frame = frame.sort_values(by='cv') cells = frame.cell # np.array(model_cells['cell']) if len(cell_nrs) < 1: pass freqs2 = [0] # freqs1#[0]# single_waves = ['_SingleWave_'] a_f2s = [np.max(a_f1s)] for a, a_fr in enumerate(a_frs): ####### VARY HERE for _ in single_waves: # if single_wave == '_SingleWave_': for a_f2 in a_f2s: try: save_name = save_name_nonlinearity(add_half, a_f2s=a_f2s, freqs2=freqs2, a_f1_end=a_f1s[-1], transient_s=transient_s, n=n, adapt_offset=adapt_offset, freq_type=freq_type, adapt=adapt_type, stimulus_length=stimulus_length, fft=fft, dev=dev, trials_nr=trials_nr, a_fr=a_fr, zeros=zeros) except: print('save name problem') embed() print(save_name) redo = False save_dir = load_savedir(level=1) append_cells, results, position = redo_or_append( save_dir + 'modell_all_cell_' + save_name, redo=redo, name_orig=save_name) counter_continued = 0 for cell in cells: # cell_nr in cell_nrs: ########################################### # fig, ax = plt.subplots(len(cell_nrs), 1, figsize=(12, 5.5)) # sharex=True, try: model_cells_here = model_cells[model_cells['cell'] == cell] model_params = model_cells_here.iloc[0] except: print('single positional index doesnt exists') embed() eod_fr = model_params['EODf'] offset = model_params.pop('v_offset') cell = model_params.pop('cell') print(cell) do_this_cell_orig, do_thiseod, counter_continued, combs_all = 
redo_on_cell_level( redo_level, append_cells, redo, results, cell, counter_continued, cell_name='cell') if type(add_half) == str: freqs1_len = freqs_array(add_half, eod_fr) else: freqs1_len = [0] if do_this_cell_orig == False: results_cell = results[results['cell'] == cell] if len(results_cell) == len(a_f1s) * len(freqs1_len): do_this_cell_now = False else: do_this_cell_now = True else: do_this_cell_now = True if do_this_cell_now: f1 = 0 f2 = 0 sampling_factor = '' phaseshift_fr = 0 cell_recording = '' mimick = 'no' fish_morph_harmonics_var = 'harmonic' fish_emitter = 'Alepto' # ['Sternarchella', 'Sternopygus'] fish_receiver = 'Alepto' # phase_right = '_phaseright_' constant_reduction = '' lower_tol = 0.995 upper_tol = 1.005 SAM = '' # , damping = 0.45 # 0.65,0.2,0.5,0.2,0.6,0.45,0.6,0.35 damping_type = '' exponential = '' dent_tau_change = 1 # in case you want a different sampling here we can adujust time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr, stimulus_length) # generate the eod_fish_r in the four mimick variants (copy, thunderfish, mimick, just sinus) eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr, stimulus_length, phaseshift_fr, cell_recording, zeros, mimick, sampling, fish_receiver, deltat, nfft, nfft_for_morph, fish_morph_harmonics_var=fish_morph_harmonics_var, beat=beat) sampling = 1 / deltat slope = 0 add = 0 plus = 0 if exponential == '': v_exp = 1 exp_tau = 0.001 # prepare for adapting offset due to baseline modification _, _ = prepare_baseline_array(time_array, eod_fr, nfft_for_morph, phaseshift_fr, mimick, zeros, cell_recording, sampling, stimulus_length, fish_receiver, deltat, nfft, damping_type, damping, us_name, gain, beat=beat, fish_morph_harmonics_var=fish_morph_harmonics_var) # now we are ready for the final modeling part in this function trials_nr_base = 10 spike_adapted = [[]] * trials_nr_base for t in range(trials_nr_base): # get the baseline properties here # 
baseline_after,spike_adapted,rate_adapted, rate_baseline_before, rate_baseline_after, np.array(spike_times), stimulus_power, v_dent_output[int(0.05 / deltat):-1], offset, v_mem_output if a_fr == 0: power_here = 'sinz' + '_' + zeros else: power_here = 'sinz' # embed() # todo: evnetuell das mit dem zeros am ende dazu! cvs, adapt_output, baseline_after_b, _, rate_adapted_b, rate_baseline_before_b, rate_baseline_after_b, \ spike_adapted[t], _, _, offset_new, _, noise_final = simulate(cell, offset, eod_fish_r, deltat=deltat, adaptation_variant=adapt_offset, adaptation_yes_j=f2, adaptation_yes_e=f1, adaptation_yes_t=t, power_variant=power_here, power_alpha=alpha, power_nr=n, **model_params) if t == 0: print('first Baseline ' + str(rate_baseline_before_b)) # here we record the changes in the offset due to the adaptation # and we subsequently reset the offset to be the new adapted for all subsequent trials offset = offset_new * 1 # Baseline Characteristics mean_isi, std_isi, fr, isi, cv0, ser0, ser_first0, ser_sum = calc_baseline_char(spike_adapted, stimulus_length, trials_nr_base) print('fr base ' + str(fr)) freqs1 = get_freqs_contrasts(a_fr, add_half, eod_fr, fr) for ff, freq1 in enumerate(freqs1): freq1 = [freq1] sampling_rate = 1 / deltat for fff, freq2 in enumerate(freqs2): # freq2 = [freq2[fff]] # Array Fouriercoefficients and phase, columns for groundmode and higher harmonics (4 total) # Array for proportional lines if 'f1' in results.keys(): results_f = results[results['f1'] == freq1[0]] else: results_f = results # [results['f1'] == freq1[0]] for aa, a_f1 in enumerate(a_f1s): if do_this_cell_orig: do_af1 = True else: if np.round(a_f1, 6) not in np.array(np.round(results_f.a_f1, 6)): do_af1 = True else: do_af1 = False print('f1_' + str(freq1) + '_af1_' + str(a_f1)) if do_af1: results, position = calc_single_af_nonlin(transient_s, freq_type, ser0, single_train, dev, fft, a_fr, trials_nr, results, nfft, damping_type, gain, save_name, cvs, position, cv0, fr, cell, 
sampling_rate, model_params, n, dent_tau_change, constant_reduction, exponential, plus, slope, add, deltat, alpha, lower_tol, upper_tol, v_exp, exp_tau, f2, fish_jammer, freq2, damping, us_name, a_f2, eod_fish_r, SAM, aa, offset, freq1, eod_fr, phase_right, a_f1, phaseshift_fr, nfft_for_morph, cell_recording, fish_morph_harmonics_var, time_array, mimick, fish_emitter, f1, sampling, stimulus_length, adapt_type=adapt_type) def get_freqs_contrasts(a_fr, add_half, eod_fr, fr): if a_fr == 1: # das ist fals wir einen freq scan haben if type(add_half) == str: freqs1 = freqs_array(add_half, eod_fr) else: beat1 = fr / 2 + add_half freqs1 = [eod_fr - beat1] else: freqs1 = [fr / 2 + add_half] return freqs1 def freqs_array(add_half, eod_fr): from_val = add_half.split('frange_from_')[1].split('_to')[0] to_val = add_half.split('to_')[1].split('_in')[0] in_step = add_half.split('in_')[1].split('_')[0] # frange_from_0_to_400_in_1 freqs1 = np.arange(eod_fr + float(from_val), eod_fr + float(to_val), float(in_step)) return freqs1 def calc_single_af_nonlin(transient_s, freq_type, ser0, single_train, dev, fft, a_fr, trials_nr, results, nfft, damping_type, gain, save_name, cvs, position, cv0, fr, cell, sampling_rate, model_params, n, dent_tau_change, constant_reduction, exponential, plus, slope, add, deltat, alpha, lower_tol, upper_tol, v_exp, exp_tau, f2, fish_jammer, freq2, damping, us_name, a_f2, eod_fish_r, SAM, aa, offset, freq1, eod_fr, phase_right, a_f1, phaseshift_fr, nfft_for_morph, cell_recording, fish_morph_harmonics_var, time_array, mimick, fish_emitter, f1, sampling, stimulus_length, adapt_type=''): print('af1_nr ' + str(aa) + ' offset ' + str(offset)) beat1 = freq1 - eod_fr phaseshift_f1, phaseshift_f2 = get_phaseshifts(a_f1, a_f2, phase_right, phaseshift_fr) eod_fish1, time_fish_e = eod_fish_e_generation(time_array, a_f1, freq1, f1, phaseshift_f1, sampling, stimulus_length, nfft_for_morph, cell_recording, fish_morph_harmonics_var, 'zeros', mimick, fish_emitter, 
thistype='emitter') eod_fish2, time_fish_j = eod_fish_e_generation(time_array, a_f2, freq2, f2, phaseshift_f2, sampling, stimulus_length, nfft_for_morph, cell_recording, fish_morph_harmonics_var, 'zeros', mimick, fish_jammer, thistype='jammer') eod_stimulus = eod_fish1 + eod_fish2 t1 = time.time() spikes = [[]] * trials_nr for t in range(trials_nr): stimulus, eod_fish_sam = create_stimulus_SAM(SAM, eod_stimulus, eod_fish_r, freq1, f1, eod_fr, time_array, a_f1, eod_fj=freq2, j=f2, a_fj=a_f2) # damping variants std_dump, max_dump, range_dump, stimulus, damping_output = all_damping_variants(stimulus, time_array, damping_type, eod_fr, gain, damping, us_name, plot=False, std_dump=0, max_dump=0, range_dump=0) # WIR BRAUCHEN KEIN ADAPT HIER DAS HABEN WIR SCHON BEI DER BASELIN GEMACHT adapt_offset_here = 'no' if a_fr == 0: power_here = 'sinz' + '_' + zeros # not found else: power_here = 'sinz' _, adapt_output, baseline_after, _, _, _, \ _, spikes[t], \ _, \ _, offset_new, _, noise_final = simulate(cell, offset, stimulus, deltat=deltat, adaptation_variant=adapt_offset_here, adaptation_yes_j=f2, adaptation_yes_e=f1, adaptation_yes_t=t, adaptation_upper_tol=upper_tol, adaptation_lower_tol=lower_tol, power_variant=power_here, power_alpha=alpha, power_nr=n, tau_change_choice=constant_reduction, tau_change_val=dent_tau_change, sigmoidal_mult=1, sigmoidal_plus=plus, sigmoidal_slope=slope, sigmoidal_add=add, LIF_adapt_type=adapt_type, LIF_exponential=exponential, LIF_exponential_tau=exp_tau, LIF_expontential__v=v_exp, **model_params) mean_isi, std_isi, fr1, isi, cv1, ser1, ser_first_stim, ser_sum_stim = calc_baseline_char(spikes, stimulus_length, trials_nr) print('fr stim' + str(fr1)) # hier noch das psd einer gemittelten rate ################## # hier das mittel der psds t2 = time.time() print('model' + str(t2 - t1)) if fft == 'psd': spikes_mat = [[]] * len(spikes) pp = [[]] * len(spikes) for s in range(len(spikes)): spikes_mat[s] = cr_spikes_mat(spikes[s], 1 / deltat, 
int(stimulus_length * 1 / deltat)) for s in range(len(spikes)): pp[s], f = ml.psd(spikes_mat[s] - np.mean(spikes_mat[s]), Fs=1 / deltat, NFFT=nfft, noverlap=nfft // 2) pp_mean = np.mean(pp, axis=0) names = peaks_1d(fr, a_fr, beat1, freq1) results = find_peaks_power(names, f, pp_mean, '_original', position, results) t2 = time.time() print('psd' + str(t2 - t1)) elif 'fft' in fft: test = False if test: from utils_test import test_spikes test_spikes() spikes_here = np.concatenate(spikes) mat_orig, time2, samp2 = calc_spikes_hist(trials_nr, stimulus_length, spikes_here, deltat=deltat) mat_orig2, time2, samp2 = calc_spikes_hist(trials_nr, stimulus_length, spikes_here, deltat=1 / 500) mat_orig_eod, time_eod, samp_eod = calc_spikes_hist(trials_nr, stimulus_length, spikes_here, deltat=1 / eod_fr) mat_orig_02, time02, samp02 = calc_spikes_hist(trials_nr, stimulus_length, spikes_here, deltat=1 / 2000) mat_orig_05, time05, samp05 = calc_spikes_hist(trials_nr, stimulus_length, spikes_here, deltat=1 / 5000) sampling_rates = [sampling_rate] if single_train == '_singletrain_': if dev == 'original': mats = [spikes_mat[0]] mat_names = [''] else: smoothed05 = gaussian_filter(spikes_mat, sigma=0.0005 * sampling_rate) mats = [smoothed05[0]] mat_names = [''] else: if dev == 'original': several = False if several == 'several': mats = [mat_orig2, mat_orig_eod, mat_orig_02, mat_orig_05, mat_orig] mat_names = ['_500', '_eodfr', '_2000', '_5000', '_dt'] test = False if test: from utils_test import test_mean test_mean() sampling_rates = [500, eod_fr, 2000, 5000, sampling_rate] else: mats = [mat_orig] mat_names = ['_dt'] sampling_rates = [sampling_rate] else: smoothed05 = gaussian_filter(spikes_mat, sigma=0.0005 * sampling_rate) # not there mat = np.mean(smoothed05, axis=0) mats = [mat] mat_names = [''] for m, mat_use in enumerate(mats): dt = 1 / sampling_rates[m] transient = int(transient_s / dt) if 'fr' in freq_type: try: freq_type_var = 0.5 * fr except: print('fr0 problem') embed() elif 
'beat' in freq_type: freq_type_var = np.abs(beat1[0]) T = 1 / freq_type_var ## 1 Period of the external signal (1/f) stimulus_times = int( (len(mat_use) - transient) * dt / T) # np.arange(1, int(stimulus_length / (1 / (0.5 * fr))), 25) try: maximal_period = int((T / dt) * stimulus_times) # Mascha hat 476 steps except: print('fr1 problem') embed() single_period = int(np.round(T / dt)) # single_period = int(T / dt+0.5) steps_all = [maximal_period] # , single_period] steps_name = ['_all'] # , '_one'] for s, steps in enumerate(steps_all): results = calc_right_core_nonlin(transient, steps_name[s], T, steps, results, mat_use, position, dt, add_name=mat_names[m]) subsequent = '' if 'subsequent' in subsequent: ################################################# # for subseuqent steps ## das ist die fft berechnung! get_cfs2(T, dt, m, mat_names, mat_use, maximal_period, position, results, single_period) results.loc[position, 'a_fr'] = a_fr results.loc[position, 'a_f2'] = a_f2 results.loc[position, 'a_f1'] = a_f1 results.loc[position, 'f1'] = freq1[0] results.loc[position, 'f2'] = freq2[0] results.loc[position, 'cell'] = cell results.loc[position, 'max_adapt'] = np.nanmin(adapt_output) results.loc[position, 'min_adapt'] = np.nanmax(adapt_output) results.loc[position, 'fr'] = fr results.loc[position, 'ser'] = ser0 results.loc[position, 'cv'] = cv0 results.loc[position, 'fr_stim'] = fr1 results.loc[position, 'ser_stim'] = ser1 results.loc[position, 'ser_sum_stim'] = ser_sum_stim results.loc[position, 'ser_first_stim'] = ser_first_stim results.loc[position, 'cv_stim'] = cv1 results.loc[position, 'eod_fr'] = eod_fr # 500 *30 for cv_all in cvs: if 'cv' in cv_all: try: results.loc[position, cv_all] = cvs[cv_all] except: print('cv something') embed() position += 1 results.to_pickle(save_name) return results, position def get_cfs2(T, dt, m, mat_names, mat_use, maximal_period, position, results, single_period): steps_all = np.arange(0, maximal_period, single_period) c = 
np.zeros([len(steps_all), 4]) phi = np.zeros([len(steps_all), 4]) for s in range(len(steps_all) - 1): c0_tmp = 0.0 a_tmp = np.zeros(4) b_tmp = np.zeros(4) # Calculate a and b for groundmode and higher harmonics for j in range(4): if j < 1: c0_tmp = np.mean(mat_use[steps_all[s]:steps_all[s + 1]]) freq_here = (j + 1) / T time_here_all = np.arange(0, single_period, 1) * dt try: a_tmp[j] = np.mean( mat_use[steps_all[s]:steps_all[s + 1]] * np.cos(2.0 * np.pi * time_here_all * freq_here)) except: print('a tmp problems') embed() b_tmp[j] = np.mean( mat_use[steps_all[s]:steps_all[s + 1]] * np.sin(2.0 * np.pi * time_here_all * freq_here)) # Average c0, c, phi = average_fft(a_tmp, b_tmp, c, c0_tmp, phi, s) row, col = np.shape(c) for c_nr in range(col): results.loc[position, 'c_' + str(c_nr) + '_mean' + mat_names[m]] = np.mean(c[:, c_nr]) for c_nr in range(col): results.loc[position, 'phi_' + str(c_nr) + '_mean' + mat_names[m]] = np.mean(phi[:, c_nr]) results.loc[position, 'c0' + '_mean' + mat_names[m]] = c0 def arg_left_corr(arg, right_step=3, left_step=2): arg_right = arg + right_step arg_left = arg - left_step if arg_left < 0: arg_right += np.abs(arg_left) arg_left = 0 return arg_left, arg_right def find_peaks_power(names, f, pp, title, position, results, a_start='a_', f_start='f_', points=5): # for name in names: arg = np.argmin(np.abs(f - names[name])) if points == 5: arg_left, arg_right = arg_left_corr(arg, right_step=3, left_step=2) results = results.copy() results.loc[position, a_start + name + title] = np.sqrt( np.sum((pp[arg_left:arg_right]) * np.abs(f[1] - f[0]))) results.loc[position, f_start + name + title] = names[name] elif points == 3: arg_left, arg_right = arg_left_corr(arg, right_step=2, left_step=1) results = results.copy() results.loc[position, a_start + name + title] = np.sqrt( np.sum((pp[arg_left:arg_right]) * np.abs(f[1] - f[0]))) results.loc[position, f_start + name + title] = names[name] elif points == 1: results = results.copy() results.loc[position, 
a_start + name + title] = np.sqrt(pp[arg] * np.abs(f[1] - f[0])) results.loc[position, f_start + name + title] = names[name] return results def plot_stimulus(ax, time_ms, beat_here, am_corr_synth, color, ylim=[-2.5, 4]): ax.show_spines('') ax.plot(time_ms, beat_here, color='grey', linewidth=0.5) ax.plot(time_ms, am_corr_synth, color=color) ax.set_xlim(0, time_ms[-1]) ax.set_ylim(ylim) def plot_raster(ax, all_spikes, color, i, plot_segment, name='raster'): ax.eventplot(all_spikes, orientation='horizontal', linelengths=0.8, linewidths=1, colors=[color]) ax.show_spines('') ax.set_xlim(0, plot_segment) if i % 3 == 0: ax.text(-0.05, 0.6, name, transform=ax.transAxes, rotation=90, va='center', ha='right') def plot_peri(ax, smoothed, sampling_rate, lw_beat_corr, color, i, name='rate'): time = np.arange(len(smoothed[-1])) / sampling_rate mean_smoothed = np.mean(smoothed, axis=0) ax.plot(1000 * time, mean_smoothed, linewidth=lw_beat_corr, color=color, clip_on=False) ax.set_xlim(0, 1000 * time[-1]) ax.set_ylim(-10, 850) ax.show_spines('') if i % 3 == 0: ax.text(-0.05, 0.3, name, transform=ax.transAxes, rotation=90, va='center', ha='right') if i % 3 == 2: ax.scalebars(1.05, 0, 10, 800, 'ms', 's$^{-1}$', ha='right', vat='bottom') def load_cell(data, fname='singlecellexample5', big_file='beat_results_smoothed_limit35minimalduration0.3', redo=False): if (not os.path.exists(fname + '.csv')) or (redo == True): print('reloaded') data_all = pd.read_pickle(load_folder_name('calc_model') + '/' + big_file + '.pkl') just_cell = data_all[data_all['dataset'] == data] spikes_data = just_cell[just_cell['contrasts'] == 20] results1 = pd.DataFrame(spikes_data) results = results1.groupby(['df']).mean() spikes = [] for d in np.unique(results1['df']): spikes.append(results1[results1['df'] == d].spike_times.iloc[0]) results['base'] = results1['amp_max_beat_05'] results['spikes'] = spikes results['df'] = np.unique(results1['df']) baseline = pd.read_pickle(load_folder_name('calc_base') + 
'/calc_base_data-base_frame.pkl') baseline_cell = baseline[baseline.cell == data] base = baseline_cell['fr'].iloc[0] results['fr'] = base spikes_data = results names = ['spikes'] results.to_csv(fname + '.csv') save_object_from_frame(results, names, fname) else: names = ['spikes'] spikes_data = load_object_to_pandas(names, fname) return spikes_data def ws_nonlin_systems(): ws = 0.15 return ws def restrict_cell_type(cells, cell_type): p_units_cells = [] pyramidals = [] for cell in cells: if cell == '2021-11-05-ai-invivo-1': p_units_cells.append(cell) elif ('2022-02-08' not in cell) & ('2022-02-07' not in cell) & ( '2022-02-04' not in cell) & ('2022-02-03' not in cell) & ( '2021-11-11' not in cell) & ('2021-11-05' not in cell) & ( '2021-11-04' not in cell): p_units_cells.append(cell) else: pyramidals.append(cell) if cell_type == 'p-units': cells = p_units_cells else: cells = pyramidals return cells, p_units_cells, pyramidals def plt_peaks_several(freqs, p_arrays, axs_p, p0_means, fs, labels=None, j=1, colors=None, emb=False, marker='o', markeredgecolors=None, zorder=2, ha='left', add_texts=None, limit=None, texts_left=None, add_log=2, rots=None, several_peaks_nr=2, exact=True, text_extra=False, perc_peaksize=0.04, rel='rel', alphas=None, extend=False, ms=25, clip_on=False, several_peaks=True, alpha=1, log='', add_not_log = 0): df_passed = [] # , p_passed = [] for ff, f in enumerate(range(len(freqs))): add, add_text = get_add_for_several_peaks(add_log, df_passed, emb, exact, f, freqs, j, log, p_arrays, perc_peaksize, rel, add_not_log = add_not_log) if alphas is not None: if len(alphas)> 0: alpha = alphas[f] if rots is not None: if len(rots)>0: rot = rots[f] else: rot = 45 else: rot = 45 if add_texts is None: add_text = 0 else: add_text = add_texts[f] print('extraf' + str(add_texts[f])) if texts_left is not None: if len(texts_left)>0: text_left = texts_left[f] print('extraf' + str(add_texts[f])) else: text_left = 0 else: text_left = 0 try: if colors is None: color = 
'black' else: color = colors[f] except: print('colors something') embed() if markeredgecolors is not None: if len(markeredgecolors)> 0: try: markeredgecolor = markeredgecolors[f] except: print('marker something') embed() else: markeredgecolor = color else: markeredgecolor = color if labels is None: label = '' else: label = labels[f] #embed() f_scatter, p_scatter = plt_peaks(axs_p, p0_means, freqs[f], fs, fr_color=color, s=ms, label=label, marker=marker, zorder=zorder, markeredgecolor=markeredgecolor, ha=ha, several_peaks_nr=several_peaks_nr, limit=limit, rot=rot, text_left=text_left, add_text=add_text, text_extra=text_extra, extend=extend, add=add, alpha=alpha, clip_on=clip_on, several_peaks=several_peaks) df_passed.append(int(freqs[f])) p_passed.append(p_scatter) return p_passed def get_add_for_several_peaks(add_log, df_passed, emb, exact, f, freqs, j, log, p_arrays, perc, rel, add_not_log = 0, j_extra=False): if rel == 'rel': peak_in_peaks = check_if_peak_occured(df_passed, exact, f, freqs) if peak_in_peaks: count = count_of_occurance(df_passed, exact, f, freqs) if log == 'log': add = (np.max(np.max(p_arrays)) * perc) + add_log * count else: add = (np.max(np.max(p_arrays)) * perc * count)+add_not_log*count add_ext = 30 add_text = add + add_ext * count if j_extra: # ich wüsste nicht warum man das brauchen würde if j == 0: add = (np.max(np.max(p_arrays)) * perc * count) add_text = add + add_ext * count if emb: embed() else: add = (np.max(np.max(p_arrays)) * 0.01) # * 0.01 add_text = (np.max(np.max(p_arrays)) * perc) # * 0.01 else: if int(freqs[f]) in df_passed: add = 60 add_text = add + 30 if j_extra: if j == 0: add = (np.max(np.max(p_arrays)) * perc) # * 0.25 add_text = add + 30 else: add = 30 add_text = (np.max(np.max(p_arrays)) * perc) # * 0.01 return add, add_text def count_of_occurance(df_passed, exact, f, freqs): count = 0 for df in df_passed: if exact == True: if int(freqs[f]) == df: count += 1 else: count = np.sum(np.abs(df_passed - freqs[f]) < 10) return 
count def check_if_peak_occured(df_passed, exact, f, freqs): if exact == True: peak_in_peaks = int(freqs[f]) in df_passed else: if len(df_passed) == 0: peak_in_peaks = False else: peak_in_peaks = np.min(np.abs(df_passed - freqs[f])) < 10 return peak_in_peaks def plt_peaks(ax, p, fr, f_axis, several_peaks_nr=2, zorder=2, marker='o', markeredgecolor=None, several_peaks=True, ha='left', limit=None, text_left=0, rot=45, add_text=0, text_extra=False, extend=True, fr_color='grey', add=0, s=12, label='',f_scatter = None, p_scatter = None, alpha=1, clip_on=False): # DAS ist die RICHTIGE Variante if fr < f_axis[-1]: minimum = np.argmin(np.abs(fr - f_axis)) #embed() try: if several_peaks: # das machen wir in der Regel bei Power Spektren max_pos, minimums = chose_beat_peak(minimum, p, several_peaks_nr) else: # das hier in den anderen Fällen minimums = [minimum] max_pos = np.argmax(p[minimums]) except: print('maxima things') embed() new_f = minimums[max_pos] max_pos_neg = np.argmin(p[minimums]) new_f_neg = minimums[max_pos_neg] if not markeredgecolor: markeredgecolor = fr_color try: if p[new_f] > p[new_f_neg]: f_scatter = f_axis[new_f] p_scatter = p[new_f] if limit: if p_scatter > limit: cont = True else: cont = scatter_peaks() else: cont = True if cont: if label != '': ax.scatter(f_scatter, p_scatter + add, color=fr_color, zorder=zorder, s=s, label=label, clip_on=clip_on, alpha=alpha, edgecolor=markeredgecolor, marker=marker) else: ax.scatter(f_scatter, p_scatter + add, color=fr_color, zorder=zorder, s=s, clip_on=clip_on, alpha=alpha, marker=marker, edgecolor=markeredgecolor) if extend: ax.plot(f_axis[new_f - 2: new_f + 3], p[new_f - 2: new_f + 3], color=fr_color, alpha=0.5, zorder=100) if text_extra: # +add_text ax.text(f_scatter - text_left, p_scatter + add + add_text, label, ha=ha, rotation=rot, color=fr_color) else: try: max_pos = np.argmin(p[minimums]) new_f = minimums[max_pos] except: new_f = minimum # minimums[minimum] f_scatter = f_axis[new_f] p_scatter = p[new_f] if 
limit: if p_scatter > limit: cont = True else: cont = False else: cont = True if cont: if label != '': ax.scatter(f_scatter, p_scatter - add, color=fr_color, zorder=2, s=s, label=label, clip_on=clip_on, alpha=alpha, edgecolor=markeredgecolor, marker=marker) else: ax.scatter(f_scatter, p_scatter - add, color=fr_color, zorder=2, s=s, clip_on=clip_on, alpha=alpha, edgecolor=markeredgecolor, marker=marker) if extend: ax.plot(f_axis[new_f - 2:new_f + 3], p[new_f - 2:new_f + 3], color=fr_color, alpha=0.5, zorder=100) if text_extra: # +add_text ax.text(f_scatter + 4, p_scatter - add + 2 + add_text, label, rotation=rot, color=fr_color, ha=ha) except: print('peaks thing inside') embed() return f_scatter, p_scatter def chose_beat_peak(minimum, p, several_peaks_nr): # das machen wir in der Regel bei Power Spektren try: minimum_array = [minimum] * (several_peaks_nr * 2 + 1) minus_array = np.arange(0, several_peaks_nr * 2 + 1, 1) - several_peaks_nr minimums = minimum_array + minus_array # [minimum - 2, minimum - 1, minimum, minimum + 1, minimum + 2] max_pos = np.argmax(p[minimums]) except: minimums = [minimum] max_pos = np.argmax(p[minimums]) return max_pos, minimums def scatter_peaks(): cont = False return cont def calc_beat_spikes(final_eod, sampling_rate, final_DF, i, cell, plus_bef, minus_bef, version='spikes', data_beat=[], trial_nr=0): ll = np.abs(plus_bef) ul = np.abs(minus_bef) df = final_DF[i] eod = final_eod[i] len_smoothed = [] len_smoothed_b = [] if version == 'spikes': if len(data_beat[data_beat['df'] == df]['spikes']) == 1: tranformed_spikes = np.array(data_beat[data_beat['df'] == df]['spikes'].iloc[0]) if len(data_beat[data_beat['df'] == df]['spikes'].iloc[0]) == 1: tranformed_spikes = np.array(data_beat[data_beat['df'] == df]['spikes'].iloc[0][0]) else: tranformed_spikes = np.array(data_beat[data_beat['df'] == df]['spikes'].iloc[trial_nr]) size = int(tranformed_spikes[-1] * sampling_rate + 5) # duration.iloc[0] spikes_mat = np.zeros(size) spikes_idx = 
np.round(tranformed_spikes * sampling_rate) for spike in spikes_idx: spikes_mat[int(spike)] = 1 * sampling_rate smoothed = gaussian_filter(spikes_mat, sigma=gaussian_intro() * sampling_rate) else: spikes_mat = [] spikes = cell[cell['df'] == df]['local'] if len(spikes) == 1: tranformed_spikes = np.array(spikes.iloc[0]) else: tranformed_spikes = np.array(spikes.iloc[trial_nr]) smoothed = tranformed_spikes * 1 smoothed[smoothed < 0] = 0 _, _ = ml.psd(smoothed ** 3 - np.mean(smoothed ** 3), Fs=sampling_rate, NFFT=2 ** 15, noverlap=2 ** 14) plot_segment = ul - ll _, _ = ml.psd(smoothed - np.mean(smoothed), Fs=sampling_rate, NFFT=4096, noverlap=4096 // 2) corr = create_beat_corr2(df, eod) # den Beat nehmen wir aus den Daten als das local EOD time = np.arange(0, len(smoothed) / sampling_rate, 1 / sampling_rate) beat_version = 'sumu' if beat_version == 'local': beat = data_beat[data_beat['df'] == df]['local'] if len(beat) == 1: beat = np.array(beat.iloc[0]) if len(beat) == 1: beat = np.array(beat.iloc[0][0]) else: beat = np.array(beat.iloc[trial_nr]) else: if len(data_beat[data_beat['df'] == df]['efield']) == 1: efield = np.array(data_beat[data_beat['df'] == df]['efield'].iloc[0]) if len(efield) == 1: efield = np.array(data_beat[data_beat['df'] == df]['efield'].iloc[0][0]) else: efield = np.array(data_beat[data_beat['df'] == df]['efield'].iloc[trial_nr]) efield = zenter_and_normalize(efield, 0.2) if len(data_beat[data_beat['df'] == df]['global']) == 1: global_eod = np.array(data_beat[data_beat['df'] == df]['global'].iloc[0]) if len(global_eod) == 1: global_eod = np.array(data_beat[data_beat['df'] == df]['global'].iloc[0][0]) else: global_eod = np.array(data_beat[data_beat['df'] == df]['global'].iloc[trial_nr]) global_eod = zenter_and_normalize(global_eod, 1) beat = global_eod + efield if 'ds' in data_beat.keys(): ds = int(data_beat.ds.iloc[0]) time_beat = np.arange(0, ds * len(beat) / sampling_rate, 1 / sampling_rate) beat = interpolate(time_beat[::ds], beat, time_beat, 
kind='cubic') beat3 = beat * 1 beat3[beat3 < 0] = 0 _, _ = ml.psd(beat3 ** 3 - np.mean(beat3 ** 3), Fs=sampling_rate, NFFT=2 ** 15, noverlap=2 ** 14) # period bestimmen wir lieber aus dem corr, weil das f_max ist nfft abhäängig period = 1 / corr if period < plot_segment: pass else: pass # und diese Shifts die sollten hatl ja die Länge des segments haben und keine 0.05 Sekunden.. ################################### # ich mache ein festes fenster also habe ich einen schift der in einem sehr kleinen schritt durchgeht # das period 2 hätte ich wenn das Fenster immer die gleiche länge hätte shift_period = 0.005 # period * 2# shifts = np.arange(0, 200 * shift_period, shift_period) time_b = np.arange(0, len(beat) / sampling_rate, 1 / sampling_rate) am_corr = extract_am(beat, time_b, eodf=eod, norm=False, extract='globalmax', kind='cubic')[0] len_smoothed, smoothed_trial, all_spikes, maxima, error, spikes_cut, beat_cut, am_corr_cut = create_shifted_spikes( eod, len_smoothed_b, len_smoothed, beat, am_corr, sampling_rate, time_b, time, smoothed, shifts, plot_segment, tranformed_spikes, version=version) am_final, beat_final, most_similiar, spike, spike_sm = get_most_similiar_spikes(all_spikes, am_corr_cut, beat_cut, error, maxima, spikes_cut) test = False if test: from utils_test import test_spikes test_spikes() test = False if test == True: from utils_test import test_maximal test_maximal() return am_final, beat_final, smoothed, tranformed_spikes, spike_sm, spike, spikes_mat, plot_segment def gaussian_intro(): return 0.001 def plot_power(ax, stim_f, spikes_mat, sampling_rate, main_color, eod, color, i, ms=3, nfft=4096 * 4): p, f = ml.psd(spikes_mat - np.mean(spikes_mat), Fs=sampling_rate, NFFT=nfft, noverlap=nfft // 2) db = 10 * np.log10(p / np.max(p)) ax.plot(f, db, zorder=1, color=main_color, linewidth=1) maxi = np.argmax(db[f < 0.5 * eod[i]]) # hier habe ich die eine Funktion wo man nur die Frequenzen und Farben reingibt und die kümmert sich um die Punkte, dass sie # nicht 
# overlap, etc.
    xlim_max = 1000
    # Only mark the stimulus frequency when it lies inside the plotted frequency range.
    if stim_f < xlim_max:
        freqs = [eod[i], f[maxi], stim_f]
    else:
        freqs = [eod[i], f[maxi]]
    plt_peaks_several(freqs, [db], ax, db, f, ['', '', ''], 5, ['white', color, 'black'],
                      markeredgecolors=['black', color, 'black'], add_log=0.1, several_peaks_nr=4, rel='rel',
                      ms=ms, clip_on=False, log='log')
    # Dashed line at half the EOD frequency (the upper bound used when searching `maxi`).
    ax.axvline(x=eod[i] / 2, color='black', linestyle='dashed', lw=0.5)
    ax.set_xlim(0, xlim_max)
    ax.set_ylim(-20, 10)
    ax.show_spines('b')
    ax.set_xticks_delta(500)
    # Decorations depend on the panel index i: x-label only for i >= 3, 'psd' label for
    # i % 3 == 0, dB scale bar for i % 3 == 2 (presumably a 2 x 3 panel layout -- TODO confirm).
    if i >= 3:
        ax.set_xlabel('Frequency [Hz]')
    else:
        ax.set_xticks_blank()
    if i % 3 == 0:
        ax.text(-0.05, 0.5, 'psd', transform=ax.transAxes, rotation=90, va='center', ha='right')
    if i % 3 == 2:
        ax.yscalebar(1.05, 0.0, 20, 'dB', ha='right')


def plt_RAM_explained_single3():
    """Plot the model's RAM matrix as a single-panel figure and save it.

    The loop over ``cells`` only repeats the model panel: the cell id is not used, so
    ``plt_model_big`` plots its own default cell. The figure is written by
    ``save_visualization(str(run), False, counter_contrast=0, savename='')``.
    """
    plot_style()
    cells = ["2012-06-27-ah-invivo-1"]  # ,"2013-01-08-aa-invivo-1" , "2014-06-06-ac-invivo-1"]
    for run in range(1):
        default_figsize(column=2, length=2.5)
        grid = gridspec.GridSpec(1, 1, wspace=0.35, left=0.1, top=0.95, bottom=0.13, right=0.87, hspace=0.35)
        grid2 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.5, hspace=0.4, subplot_spec=grid[0])
        for _, _ in enumerate(cells):
            # Disabled stimulus/response part, kept as a bare string literal (never executed):
            '''
            arrays, arrays2, colors, deltat, spikes = get_RAM_stimulus(cell, exp_tau, exponential, lower_tol, model_cells, upper_tol, v_exp)
            ##################################################
            plt_RAM_arrays(arrays, arrays2, colors, deltat, grid1, spikes)'''
            ##################################################
            ##################################
            # model part
            ax = plt.subplot(grid2[:])
            perc, im, stack_final, stack_saved = plt_model_big(ax)
            # Colour limit from a percentile (95) of |matrix| -- per set_clim_same's 'perc' mode, confirm.
            set_clim_same([im], mats=[np.abs(stack_final)], lim_type='up', nr_clim='perc', clims='', percnr=95)
        plt.subplots_adjust(hspace=0.85, wspace=0.25)
        save_visualization(str(run), False, counter_contrast=0, savename='')


def plt_RAM_explained_single2(exponential=''):
    """Plot RAM stimulus/response traces (left) next to the model RAM matrix (right) and save.

    Args:
        exponential: Passed on to ``get_RAM_stimulus`` together with the fitted model table.
    """
    plot_style()
    default_settings(width=12)  # , ts=12, ls=13, fs=11
    cells = ["2012-06-27-ah-invivo-1"]  # ,"2013-01-08-aa-invivo-1" , "2014-06-06-ac-invivo-1"]
model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv") for run in range(1): default_settings(column=2, length=3) grid = gridspec.GridSpec(1, 2, wspace=0.35, left=0.1, top=0.95, bottom=0.16, right=0.87, hspace=0.35) grid1 = gridspec.GridSpecFromSubplotSpec(3, 1, wspace=0.5, hspace=0.02, subplot_spec=grid[0]) grid2 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.5, hspace=0.4, subplot_spec=grid[1]) for c, cell in enumerate(cells): arrays, arrays2, colors, deltat, spikes = get_RAM_stimulus(cell, exponential, model_cells) ################################################## plt_RAM_arrays(arrays, arrays2, colors, deltat, grid1, spikes) ################################################## ################################## # model part ax = plt.subplot(grid2[:]) perc, im, stack_final, stack_saved = plt_model_big(ax) set_clim_same([im], mats=[np.abs(stack_final)], lim_type='up', nr_clim='perc', clims='', percnr=95) fig = plt.gcf() fig.tag([fig.axes[1], fig.axes[-2]], xoffs=-5.5, yoffs=0.7) plt.subplots_adjust(hspace=0.85, wspace=0.25) save_visualization(str(run), False, counter_contrast=0, savename='') def plt_RAM_arrays(arrays, arrays2, colors, deltat, grid1, spikes): ax = plt.subplot(grid1[1]) ax.eventplot(np.array(spikes) * 1000, color='black') ax.show_spines('') rx = 0.1 * 1000 ax.set_xlim(0, rx) ylabel = ['', '', 'Firing Rate [Hz]'] for i in range(len(arrays)): if arrays[i] != '': ax = plt.subplot(grid1[i]) try: ax.plot(np.arange(0, len(arrays[i]) * deltat, deltat) * 1000, arrays[i], color=colors[i]) except: print('arrays problem') embed() if arrays2[i] != '': ax.plot(np.arange(0, len(arrays2[i]) * deltat, deltat) * 1000, arrays2[i], color='black') ax.set_xlim(0, rx) if i < 2: ax.show_spines('') remove_xticks(ax) remove_xticks(ax) ax.set_ylabel(ylabel[i]) ax.set_xlabel('Time [ms]') def phaselocking_loss2(show=True): _, _ = find_all_dir_cells() data_names = ['2019-09-10-ae-invivo-1'] plot_style() default_figsize(column=2, 
length=3.5) for data_name in data_names: print(data_name) ############################################# # print traces name_core = load_folder_name('data') + 'cells/' + data_name nix_name = name_core + '/' + data_name + '.nix' # '/' if os.path.exists(name_core): f = nix.File.open(nix_name, nix.FileMode.ReadOnly) nix_there = True if nix_there: b = f.blocks[0] all_mt_names = find_mt_all(b) ts = find_tags_list(b, names='ficurve') if len(ts) > 0: for n, names_mt_gwn in enumerate(all_mt_names): if ('rectangle' in names_mt_gwn) | ('FI' in names_mt_gwn): mt = b.multi_tags[names_mt_gwn] features, delay_name = feature_extract(mt, ) Intensity, preIntensity, contrasts, precontrasts = find_contrasts(features, mt) if len(np.shape(contrasts)) > 1: contrasts = np.concatenate(contrasts) negativ = 'negativ' # 'positiv'#'highest'#'negativ' # 'positiv' val = 31 indeces_show = np.arange(0, len(contrasts), 1)[ (contrasts > val - 3) & (contrasts < val + 3)] ##[np.argsort(contrasts)[-1]] save_name = load_folder_name('calc_FI_Curve') + '\FI5_with_f0_nfft_16384' # FI_with_f0' frame = pd.read_csv(save_name + '.csv') names_all = [['ss_s', 'ss_r']] # , [['ss_s', 'on_s']] # , 'on_s', 'on_r', linestyles = ['-', '--', '-', '--', '-', '--', '-', '--'] _, _ = find_row_col(frame.cell.unique()) axes = [] grid0 = gridspec.GridSpec(1, 2, bottom=0.15, top=0.92, left=0.1, right=0.98, wspace=0.27, hspace=0.45, width_ratios=[2, 1.3]) # gridr = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=grid0[1], hspace=0.5) frame_cv = pd.read_pickle( load_folder_name('calc_base') + '/calc_base_data-base_frame.pkl') not_frame_punits_cells = frame_cv[ frame_cv['cell_type_reclassified'] != ' P-unit'].cell.unique() cells_saved = frame.cell.unique() c_counter = 0 cells_p_units = np.setdiff1d(list(cells_saved), list(not_frame_punits_cells)) frame_cell = frame[frame.cell == data_name] frame_cell['contrasts'] = np.round(frame_cell['contrasts']) frame_cell = frame_cell.groupby('contrasts', as_index=True).mean() sort_order 
= np.argsort(frame_cell.index) logs = [''] # , 'log' lables = ['Onset State Curve', 'Steady State Curve'] # coloro = 'red' colors = 'green' delta = 200 for l in range(len(logs)): my_curves = 'chirp' if my_curves == 'new': for nn, names in enumerate(names_all): ax = plt.subplot(gridr[nn, l]) for n, name in enumerate(names): contrast_labels = frame_cell.index[sort_order] / 100 flip = False if flip: steady_func = frame_cell[name].iloc[sort_order][::-1] ax.scatter(-contrast_labels, steady_func, color=colors[n], linestyle=linestyles[n]) # label=name, else: steady_func = frame_cell[name].iloc[sort_order] ax.scatter(contrast_labels, steady_func, color=colors[n], linestyle=linestyles[n]) # label=name, try: plt_FI_curve(contrast_labels, bolzmann_steady, steady_func, label=lables[n], color=colors[n]) except: print('color something') embed() if logs[l] == 'log': ax.set_xscale(logs[l]) else: ax.axvline(0, color='grey', linestyle='--', linewidth=0.5) if l == 0: ax.legend(ncol=2, loc=(0, 1.05)) ax.set_ylabel('Firing Frequency [Hz]') else: remove_yticks(ax) elif my_curves == 'chirp': onset_state, steady_state, sorted_contrast, steady, onset, indices, mean_indices = np.load( load_folder_name( 'calc_FI_Curve') + '/F_I_curve-distances5st_nm_w2_dm_alpha_consp2_bnr6_ROCsFI_cells.npy', allow_pickle=True) ax = plt.subplot(gridr[0]) ax.plot(sorted_contrast[data_name], steady_function(sorted_contrast[data_name], steady[data_name][0], steady[data_name][1], steady[data_name][2]), zorder=1, label='Fitted function for the steady F-I Cure', color='black') ax.plot(sorted_contrast[data_name], onset_function(sorted_contrast[data_name], onset[data_name][0], onset[data_name][1], onset[data_name][2]), label='Fitted function for the onset F-I Cure', zorder=1, color='black') s = 30 steady_val = np.array(steady_state[data_name])[sorted_contrast[data_name] == val] onset_val = np.array(onset_state[data_name])[sorted_contrast[data_name] == val] ax.scatter(sorted_contrast[data_name], 
onset_state[data_name], s=s, clip_on=False, color=coloro, zorder=120, alpha=0.5) # color='black', ax.scatter(sorted_contrast[data_name], steady_state[data_name], s=s, clip_on=False, color=colors, zorder=120, alpha=0.5) # color='grey', ax.scatter(sorted_contrast[data_name][sorted_contrast[data_name] == val], onset_val, s=s, color=coloro, clip_on=False, zorder=100, alpha=1, edgecolor='black') ax.scatter(sorted_contrast[data_name][sorted_contrast[data_name] == val], steady_val, s=s, color=colors, clip_on=False, zorder=100, alpha=1, edgecolor='black') ax.set_yticks_delta(delta) print('onsetsnip' + str(onset_val)) print('steadysnip' + str(steady_val)) else: data = pd.read_csv( '../data/Kennlinien/cell_fi_curves_csvs/' + cell + '.csv') # not found ''' df_cell['inputs'].iloc[0] df_cell['on_r'].iloc[0] df_cell['ss_r'].iloc[0] df_cell['on_s'].iloc[0] df_cell['ss_s'].iloc[0] data['contrasts']''' sort_data = np.argsort(data['contrasts']) plt.scatter(data['contrasts'][sort_data], (data['f_onset'][sort_model] - ymin) / ymax, color='black', s=7, zorder=2) # not found colors = ['green', 'blue', 'orange', 'pink', 'purple', 'red'] plt.xlabel('Contrast [%]') plt.ylabel('Firing Rate [Hz]') ax.set_xlabel('Contrast [$\%$]') ax.set_ylabel('Firing Rate [Hz]') axes.append(ax) spike_mats = [] smootheneds = [] for idx, mt_idx in enumerate(indeces_show): # range(len(mt.positions[:])) print(idx) gridl = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=grid0[0], hspace=0.1, wspace=0.35) try: delay_orig = mt.features[delay_name].data[:][mt_idx][0] except: delay_orig = mt.features[delay_name].data[:][mt_idx] delay = delay_and_reality_alignment(mt, mt.extents[mt_idx], mt_idx, mt.extents[mt_idx]) negativ = 0.5 if delay < negativ: negativ = delay if negativ < 0: negativ = delay_orig if mt_idx != len(mt.extents) - 1: try: pass except: print('positiv thing') embed() positive = 0.5 if delay < positive: positive = delay else: positive = delay # ich glaube da gibts porbleme wenn das mt davor oder 
danach negativ war # deswegen werden beide in dem fall einfach zum delay! if positive < 0: positive = negativ * 1 if positive < 0: positive = delay_orig print('positive thing') embed() duration = mt.extents[mt_idx][0] if mt.extents[mt_idx] > 0: f_snippet = np.min([duration / 2, negativ, positive]) if f_snippet < 0: print('snippet thing') embed() if f_snippet > 0.1: start_time = mt.positions[mt_idx] - negativ eod_g, spikes_mt, sampling = link_arrays(b, start_time, duration + negativ + positive, start_time, load_eod_array='LocalEOD-1') # 'EOD' spike_mats.append(spikes_mt) v1, sampling = link_arrays_eod(b, start_time, duration + negativ + positive, array_name='V-1') eod_field, sampling = link_arrays_eod(b, start_time, duration + negativ + positive, array_name='GlobalEFieldStimulus') spikes_mat = cr_spikes_mat(spikes_mt, sampling, int((mt.extents[ mt_idx] + negativ + positive) * sampling)) smoothened = gaussian_filter(spikes_mat, sigma=0.001 * sampling) smootheneds.append(smoothened) dt = 1 / sampling axs = [] xlim = [-0.1 * 1000, 0.5 * 1000] ax = plt.subplot(gridl[0]) axes.append(ax) ax.set_xlim(xlim) time_array = np.arange(0, len(eod_g) * dt, dt) - negativ time_fish_e = time_array * 2 * np.pi * 750 # eod_fe[e] eod_g = 100 * np.sin(time_fish_e) eod_g[(time_array > 0) & (time_array < 0.4)] = eod_g[(time_array > 0) & ( time_array < 0.4)] * ((100 + val) / 100) ax.plot(time_array * 1000, eod_g, color='grey', linewidth=0.5) # 0.2 axs.append(ax) ax.set_ylabel('Contrast [$\%$]') remove_xticks(ax) ax.show_spines('l') ax.set_title('Contrast\,$=%s$' % val + '\,$\%$') ax = plt.subplot(gridl[1]) axs.append(ax) ax.set_xlim(xlim) time_array = np.arange(0, len(v1) * dt, dt) - negativ ax.eventplot((spike_mats - negativ) * 1000, color='black', linewidths=0.3) ax.show_spines('l') ax.set_ylabel('Trials') remove_xticks(ax) ax = plt.subplot(gridl[-1]) ax.set_xlabel('Time [ms]') ax.set_ylabel('Firing Rate [Hz]') smoothed_mean = np.mean(smootheneds, axis=0) time_cut = time_array[0: 
len(np.mean(smootheneds, axis=0))] * 1000 ax.plot(time_cut, smoothed_mean, color='black') ax.set_xlim(xlim) ax.axhline(np.mean(smoothed_mean[time_cut < 1]), color='grey', linewidth=0.5) ax.set_yticks_delta(delta) minus = 15 onset_snip = smoothed_mean[(time_cut > start_fi_o()) & (time_cut < end_fi_o() - minus)] steady_snip = smoothed_mean[ (time_cut > (300 + start_fi_s())) & (time_cut < (300 + end_fi_s()))] print('onsetsnip' + str(np.mean(onset_snip))) print('steadysnip' + str(np.mean(steady_snip))) ax.plot(time_cut[(time_cut > start_fi_o()) & (time_cut < end_fi_o() - minus)], onset_snip, color=coloro) ax.plot(time_cut[(time_cut > (300 + start_fi_s())) & (time_cut < (300 + end_fi_s()))], steady_snip, color=colors) c_counter += 1 print(show) fig = plt.gcf() fig.tag(axes[::-1], xoffs=-3.5, yoffs=1.8) save_visualization( data_name + '_idx_' + str(mt_idx) + '_contrast_' + str(contrasts[mt_idx]), show) print('finished plotting') def plt_model_small(ax, pos_rel=-0.07, ls='--', lw=0.5, cell='2012-07-03-ak-invivo-1', colorx='black', colory='black'): cells_given = [cell] # doch das müsste jetzt mit denen hier funkionieren save_name_rev = load_folder_name( 'calc_model') + '/' + 'calc_RAM_model-2__nfft_whole_power_1_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_1000000_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV_revQuadrant_' save_name = load_folder_name( 'calc_model') + '/' + 'calc_RAM_model-2__nfft_whole_power_1_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_1000000_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV' cell_add, cells_save = find_cell_add(cells_given) perc = 'perc' path_rev = save_name_rev + '.pkl' # '../'+ path = save_name + '.pkl' # '../'+ # path_rev = 
#     'model_full_nfft_whole_p_1_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TS_1000000_a_fr_1__TrialsNr_1__revQuadrant_2012-07-03-ak-invivo-1.csv'
    # path = 'model_full_nfft_whole_p_1_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TS_1000000_a_fr_1__TrialsNr_1_2012-07-03-ak-invivo-1.csv'
    # Load the '_revQuadrant_' result and the forward result for this cell.
    stack_rev, stack_saved = get_stack_one_quadrant(cell, cell_add, cells_save, path_rev, save_name_rev,
                                                    redo=False, creation_time_update=True,
                                                    size_update=True)  # direct_load = True,
    stack, stack_saved = get_stack_one_quadrant(cell, cell_add, cells_save, path, save_name)  # direct_load = True
    # embed()
    # NOTE(review): full_matrix is computed but not used here, because stack_final is taken from
    # `stack` directly (plt_model_big passes it through get_axis_on_full_matrix instead).
    full_matrix = create_full_matrix2(np.array(stack), np.array(stack_rev))
    stack_final = stack  # get_axis_on_full_matrix(full_matrix, stack)
    stack_final, add_nonlin_title, resize_val = rescale_colorbar_and_values(
        stack_final, add_nonlin_title='k')  # , add_nonlin_title = 'k'
    # NOTE(review): this discards the title suffix returned by rescale_colorbar_and_values, so the
    # colorbar label below is built from ' [' only; plt_model_big keeps the returned suffix and
    # multiplies the result by resize_val before returning -- confirm the difference is intended.
    add_nonlin_title = ''
    im = plt_RAM_perc(ax, perc, np.abs(stack_final))
    set_clim_same([im], mats=[np.abs(stack_final)], lim_type='up', nr_clim='perc', clims='', percnr=95)
    set_xlabel_arrow(ax, xpos=1, ypos=pos_rel, color=colorx)
    set_ylabel_arrow(ax, xpos=pos_rel, ypos=0.97, color=colory)
    cbar, left, bottom, width, height = colorbar_outside(ax, im, add=5, width=0.01)
    cbar.set_label(nonlin_title(add_nonlin_title=' [' + add_nonlin_title), rotation=90, labelpad=8)
    # White dashed zero lines (both axes).
    ax.axhline(0, color='white', linewidth=lw, linestyle=ls)
    ax.axvline(0, color='white', linewidth=lw, linestyle=ls)
    return perc, im, stack_final


def plt_model_big(ax, pos_rel=-0.07, ls='--', lw=0.5, cell='2012-07-03-ak-invivo-1', colorx='black',
                  colory='black', lines=True):
    """Plot the model RAM matrix of ``cell`` on ``ax`` (forward + '_revQuadrant_' results combined).

    The forward result is loaded with ``redo=True``. Both results are merged with
    ``create_full_matrix2`` and ``get_axis_on_full_matrix``, rescaled with
    ``rescale_colorbar_and_values`` and plotted as ``|matrix|`` with arrow axis labels and an
    outside colorbar.

    Args:
        lines: If True, draw white zero lines (style ``ls``, width ``lw``).

    Returns:
        (perc, im, stack_final, stack_saved); ``stack_final`` is multiplied by the ``resize_val``
        returned by the rescaling before it is returned (presumably undoing the display scaling
        -- confirm).
    """
    cells_given = [cell]  # this should now work with these cells
    save_name_rev = load_folder_name(
        'calc_model') + '/' + 'calc_RAM_model-2__nfft_whole_power_1_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_1000000_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV_revQuadrant_'
    save_name = load_folder_name(
        'calc_model') + '/' + 'calc_RAM_model-2__nfft_whole_power_1_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_1000000_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV'
    cell_add, cells_save = find_cell_add(cells_given)
    perc = 'perc'
    path_rev = save_name_rev + '.pkl'  # '../'+
    path = save_name + '.pkl'  # '../'+
    # path_rev = 'model_full_nfft_whole_p_1_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TS_1000000_a_fr_1__TrialsNr_1__revQuadrant_2012-07-03-ak-invivo-1.csv'
    # path = 'model_full_nfft_whole_p_1_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TS_1000000_a_fr_1__TrialsNr_1_2012-07-03-ak-invivo-1.csv'
    stack_rev, stack_saved = get_stack_one_quadrant(cell, cell_add, cells_save, path_rev, save_name_rev,
                                                    redo=False, creation_time_update=True,
                                                    size_update=True)  # direct_load = True,
    stack, stack_saved = get_stack_one_quadrant(cell, cell_add, cells_save, path, save_name,
                                                redo=True)  # direct_load = True
    # embed()
    full_matrix = create_full_matrix2(np.array(stack), np.array(stack_rev))
    stack_final = get_axis_on_full_matrix(full_matrix, stack)
    add_nonlin_title = ''
    stack_final, add_nonlin_title, resize_val = rescale_colorbar_and_values(
        stack_final, add_nonlin_title=add_nonlin_title)  # , add_nonlin_title = 'k'
    im = plt_RAM_perc(ax, perc, np.abs(stack_final))
    set_clim_same([im], mats=[np.abs(stack_final)], lim_type='up', nr_clim='perc', clims='', percnr=95)
    set_xlabel_arrow(ax, xpos=1, ypos=pos_rel, color=colorx)
    set_ylabel_arrow(ax, xpos=pos_rel, ypos=0.97, color=colory)
    cbar, left, bottom, width, height = colorbar_outside(ax, im, add=5, width=0.01)
    cbar.set_label(nonlin_title(add_nonlin_title=' [' + add_nonlin_title), rotation=90, labelpad=8)
    # embed()
    if lines:
        ax.axhline(0, color='white', linewidth=lw, linestyle=ls)
        ax.axvline(0, color='white', linewidth=lw, linestyle=ls)
    stack_final = stack_final * resize_val
    return perc, im, stack_final, stack_saved


def FI_curves_plot(contrast_labels, onset_function, onset_state, steady_function, steady_state,
                   label='Fitted function for the onset F-I Cure'):
    """Plot steady-state and onset F-I data with their fitted curves on the current axes.

    The onset parameters are fitted with ``optimize.curve_fit`` (first parameter bounded to
    [0, 4 * max(onset_state)]); the steady-state part is delegated to ``plt_FI_curve``.
    """
    plt_FI_curve(contrast_labels, steady_function, steady_state)
    plt.scatter(contrast_labels, onset_state, color='black')
    params_o, params_covariance = optimize.curve_fit(onset_function, contrast_labels, onset_state, bounds=(
        [0, -np.inf, - np.inf], [np.max(onset_state) * 4, np.inf, np.inf]))
    plt.plot(contrast_labels, onset_function(contrast_labels, params_o[0], params_o[1], params_o[2]),
             label=label, color='black')
    plt.ylim([0, 720])
    plt.xlabel('Contrasts [%]', labelpad=10)
    plt.ylabel('Firing Frequency [Hz]')
    # NOTE(review): explicit label list -- matplotlib assigns these to artists in creation order.
    plt.legend(['Steady State Curve', 'Onset State Curve'])


def plt_FI_curve(contrast_labels, steady_function, steady_state, color='grey',
                 label='Fitted function for the steady F-I Cure'):
    """Scatter ``steady_state`` against ``contrast_labels`` and overlay a fitted ``steady_function``.

    NOTE(review): ``color`` only affects the fitted line; the data points are always grey.
    """
    plt.scatter(contrast_labels, steady_state, color='grey')
    params_s, params_covariance = optimize.curve_fit(steady_function, contrast_labels, steady_state, bounds=(
        [0, -np.inf, - np.inf], [np.max(steady_state) * 4, np.inf, np.inf]))
    plt.plot(contrast_labels, steady_function(contrast_labels, params_s[0], params_s[1], params_s[2]),
             label=label, color=color)


def second_saturation_freq():
    """Return the constant pair (39.5, -10.5)."""
    return (39.5, -10.5)


def time_nonlin_first_sine(sampling=20000, f0=40, duration=0.1, amp=1):
    """Return ``(delta, f0, sine, time_array)`` for a sine of ``amp`` at ``f0`` Hz.

    ``time_array`` runs from 0 to ``duration`` (exclusive) in steps of ``delta = 1 / sampling``.
    """
    delta = 1 / sampling
    time_array = np.arange(0, duration, 1 / sampling)
    time_s = time_array * 2 * np.pi * f0
    sine = np.sin(time_s) * amp
    return delta, f0, sine, time_array


def second_sine(f0, time_array, amp=1.25, phase=1):
    """Return ``amp * sin(2*pi*f0*t + phase)`` evaluated on ``time_array``."""
    time_s = time_array * 2 * np.pi * f0
    sine = amp * np.sin(time_s + phase)
    return sine


def circle_plot(ax, ax_prev, ws=None, lw=1.5):
    """Draw the filled black system box on ``ax`` (despite the name, a rectangle) and hide spines."""
    if not ws:
ws = ws_nonlin_systems() #rectangle = plt.Circle((0, 0), fc='black', ec="black") rectangle = plt.Rectangle((0, 0), 20, 20, fc='black', ec="black") ax.add_patch(rectangle) #ax.set_title('$H\{s(t)\}$') #embed() ax.show_spines('') def rectangle_plot(ax, ax_prev, ws=None, lw=1.5): if not ws: ws = ws_nonlin_systems() rectangle = plt.Rectangle((0, 0), 20, 20, fc='black', ec="black") ax.add_patch(rectangle) ax.annotate('', ha='center', xycoords='axes fraction', xy=(1 + ws, 0.5), textcoords='axes fraction', xytext=(1, 0.5), arrowprops={"arrowstyle": "->", "linestyle": "-", "linewidth": lw, "color": 'black'}, zorder=1, annotation_clip=False, transform=ax_prev.transAxes, ) ax.set_title('$H\{s(t)\}$') ax.show_spines('') def base_csvs_save(cell, frame=[], load_folder='calc_base'): if len(frame) < 1: path_sascha = load_folder_name('calc_base') + '/' + 'calc_base_data-base_frame_nfftmedium__overview.pkl' frame = pd.read_pickle(path_sascha) frame_c = frame[frame.cell == cell] frame_cell = pd.DataFrame() spikes_all, isi, frs_calc, spikes_cont = load_spikes(np.array(frame_c.spikes.iloc[0]), 1, ms_factor=1) spikes_all, pos_reshuffled = reshuffle_spike_lengths(spikes_all) save_spikestrains_several(frame_cell, spikes_all) frame_cell['fr'] = frame_c.fr.iloc[0] if len(np.shape(frame_c['freq_steps_medium'].iloc[0])) == 2: vars = [frame_c['EODfs_medium'].iloc[0][0], frame_c['freq_steps_medium'].iloc[0][0], frame_c['EODfs'].iloc[0][0], frame_c['freq_steps_trial'].iloc[0][0]] elif len(np.shape(frame_c['freq_steps_medium'].iloc[0])) == 1: vars = [frame_c['EODfs_medium'].iloc[0], frame_c['freq_steps_medium'].iloc[0], frame_c['EODfs'].iloc[0], frame_c['freq_steps_trial'].iloc[0]] elif len(np.shape(frame_c['freq_steps_medium'].iloc[0])) == 3: vars = [frame_c['EODfs_medium'].iloc[0][0][0], frame_c['freq_steps_medium'].iloc[0][0][0], frame_c['EODfs'].iloc[0][0][0], frame_c['freq_steps_trial'].iloc[0][0][0]] names = ['EODf_res', 'freq_step_res', 'EODf', 'freq_step_trial'] frame_cell = 
reshuffle_eodfs(frame_cell, names, pos_reshuffled, vars) lim = find_lim_here(cell, 'individual') frame_cell['burst_corr_individual'] = float('nan') frame_cell['burst_corr_individual'].iloc[0] = lim frame_cell['sampling'] = frame_c.sampling.iloc[0] frame_cell['cell'] = frame_c.cell.iloc[0] frame_cell['eod_fr'] = frame_c.EODf.iloc[0] save = True # .iloc[0] if save: frame_cell.to_csv(load_folder + '/base_csvs_save-spikesonly_' + cell + '.csv') del frame del frame_cell return frame_c def reshuffle_eodfs(frame_cell, names, pos_reshuffled, vars, res_name='arbitrary'): stack_sp = {} for v, var in enumerate(vars): if (res_name not in names[v]) & ('all' not in names[v]): try: var = np.array(var)[pos_reshuffled] except: print('reshuffle thing') stack_sp = resave_vars_corr(names, res_name, stack_sp, v, var) for key in stack_sp: if key not in frame_cell.keys(): frame_cell[key] = np.float('nan') try: frame_cell[key].loc[0] = stack_sp[key] # .iloc[0] except: print('sequence thing') embed() return frame_cell def save_spikestrains_several(frame_cell, spikes_all): for ss, sp in enumerate(spikes_all): try: frame_cell['spikes' + str(ss)] = sp # frame_c.spikes.iloc[0][0][0] except: frame_cell['spikes' + str(ss)] = float('nan') frame_cell['spikes' + str(ss)].iloc[0:len(sp)] = sp print('spikes something') return frame_cell def reshuffle_spike_lengths(spikes_all): lengths = [] for r in spikes_all: lengths.append(len(r)) pos_reshuffled = np.argsort(lengths)[::-1] spikes_all = np.array(spikes_all)[pos_reshuffled] return spikes_all, pos_reshuffled def rename(model_folder, dir_prev, dir_new, function_name='calc_phaselocking-'): # damit das nicht mehrmals passiert not_renamed = True if (function_name not in dir_prev) | (function_name == ''): change_to = model_folder + '/' + function_name + dir_new if not os.path.exists(change_to): try: os.rename(model_folder + '/' + dir_prev, change_to) except: print('some problem renaming') embed() not_renamed = False else: print('already there:', 
model_folder + '/' + dir_prev, 'to', change_to) not_renamed = True else: change_to = model_folder + '/' + function_name + dir_prev return not_renamed, change_to def title_motivation(): titles = [f_eod_name_rm(), r'$' + f_eod_name_core_rm() + '$ \& $f_{1}$', r'$' + f_eod_name_core_rm() + '$ \& $f_{2}$', r'$' + f_eod_name_core_rm() + '$ \& $f_{1}$ \& $f_{2}$', []] ##'receiver + ' + 'receiver + receiver return titles def rem_variable(rm_var = {'rm':True, 'size': 'small'}): return rm_var def f_eod_name_rm(): return r'$' + f_eod_name_core_rm() + '$' def f_eod_name_core_rm(): rm_var = rem_variable() if rm_var['rm'] == True: val = r'f\rm{_{EOD}}' else: val = r'f_{EOD}' return val def exp_params(exp_tau, exponential, v_exp): if exponential == '': v_exp = 1 exp_tau = 0.001 elif exponential == 'EIF': v_exp = np.array([0]) exp_tau = np.array([0.001, 0.01, 0.1]) # 10 elif exponential == 'CIF': v_exp = np.array([0, 0.5, 1, 1.5, 2, 0.2, -0.5, -1]) # exp_tau = np.array([0]) # 10 return exp_tau, v_exp def filter_square_params(c_grouped, cell_here, frame, frame_cell_orig): new_f2_tuple = frame_cell_orig[['df2', 'f2']].apply(tuple, 1).unique() dfs = [tup[0] for tup in new_f2_tuple] frame_cell = frame[(frame.cell == cell_here)] # & (frame[c_here] == c_h)] frame_cell, df1s, df2s, f1s, f2s = find_dfs(frame_cell) diffs = find_deltas(frame_cell, c_grouped[0]) frame_cell = find_diffs(c_grouped[0], frame_cell, diffs, add='_original') new_frame = frame_cell.groupby(['df1', 'df2'], as_index=False).sum() # ['score'] matrix = new_frame.pivot(index='df2', columns='df1', values='diff') return frame_cell, matrix def plt_square_here3(ax, frame, score_here, c_nr=0.1, cls="RdBu_r", c_here='c1'): cs = frame[c_here].unique() c_chosen = cs[np.argmin(np.abs(cs - c_nr))] frame_cell = frame[(frame[c_here] == c_chosen)] # & (frame[c_here] == c_h)] frame_cell = frame_cell[~ (frame_cell.f1 == frame_cell.f2)] frame_cell, df1s, df2s, f1s, f2s = find_dfs(frame_cell) new_frame = frame_cell.groupby(['df1', 
'df2'], as_index=False).mean() # ['score'] matrix = new_frame.pivot(index='df2', columns='df1', values=score_here) try: im = ax.pcolormesh( np.array(list(map(float, matrix.columns))), np.array(matrix.index), matrix, cmap=cls, rasterized=False) # 'Greens'#vmin=np.percentile(np.abs(stack_plot), 5),vmax=np.percentile(np.abs(stack_plot), 95), except: print('ims probelem') embed() return im, matrix def roc_filename2(): return 'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_1000_mult_minimum_1temporal' def roc_filename1(): return 'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_100_mult_minimum_1temporal' def roc_filename0(): return 'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_20_mult_minimum_1temporal' # 'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_1.0_len_25_nfft_32768_trialsnr_' + trial_nr + '_mult_minimum_1temporal', def isi_xlabel(): return '1/$'+f_eod_name_core_rm()+'$' def get_global_eod_for_eodf(b, duration, mt, mt_nr_small): try: global_eod, sampling = link_arrays_eod(b, mt.positions[:][mt_nr_small], duration, 'EOD') except: try: global_eod, sampling = link_arrays_eod(b, mt.positions[:][mt_nr_small], duration, 'EOD-1') except: b_names = get_data_array_names(b) for b_name in b_names: if 'eod' in b_name: print('EOD 1 prob') embed() sampling = 0 return global_eod, sampling def get_eodf_here(b, eodf_orig, global_eod, mt_nr_small, nfft_eod, sampling): if sampling > 0: if len(global_eod) / sampling > 0.1: try: if len(eodf_orig) > 0: try: eod_fr_orig = eodf_orig.iloc[mt_nr_small] except: eod_fr_orig = eodf_orig[mt_nr_small] else: eod_fr_orig = eodf_orig eod_fr, p, f = get_eodf(global_eod, b, eod_fr_orig, 
nfft_eod=nfft_eod) except: print('still eodf problem') embed() else: eod_fr = float('nan') else: eod_fr = float('nan') return eod_fr def get_eodf(global_eod, b, eodf_orig, nfft_eod=2 ** 16): if len(global_eod) > 0: sampling_rate = get_sampling(b, load_eod_array='EOD') eod_fr, p, f = calc_freq_from_psd(global_eod, sampling_rate, nfft=nfft_eod) # v else: p = None f = None if eodf_orig: eod_fr = eodf_orig else: eod_fr = float('nan') if eodf_orig: if not np.isnan(eodf_orig): if (np.max(np.array(eodf_orig)) > 300) & (np.min(np.array(eodf_orig)) < 1000): eodf_orig = np.array(eodf_orig) if np.abs(eod_fr - eodf_orig) > 25: sampling_rate = get_sampling(b, load_eod_array='EOD') # ok hier checke ich nochmal ob diese Effekte stabil sind frequency_saves, frequency = nfft_improval(sampling_rate, eodf_orig, global_eod, eod_fr) if np.min(np.abs(frequency_saves - eodf_orig)) > 300: print('EODf diff too big') # also hier sage ich wenn alle immer noch das gleiche sagen dann passt das schon if np.min(np.abs(frequency_saves - eodf_orig)) < 10: # und hier sage ich wenn es doch unterschiede gibt dann wähle das minimum eod_fr = frequency_saves[np.argmin(np.abs(frequency_saves - eodf_orig))] # print('EODf diff solvable') if (eod_fr > 1000) & (eod_fr < 350): # das ist unmöglich # wenn es unmöglich ist nehmen wir wieder die ursprüngliche Abschätzung eod_fr = eodf_orig test = False if test: plt.plot(f, p) return eod_fr, p, f def calc_freq_from_psd(noise_eod, eod_sampling_frequency, nfft=2 ** 16): p, f = ml.psd(noise_eod - np.mean(noise_eod), Fs=eod_sampling_frequency, NFFT=nfft, noverlap=nfft // 2) eod_fr = f[np.argmax(p)] return eod_fr, p, f def nfft_improval(sampling_rate, frequency_orig, global_stimulus, frequency, nffts=[2 ** 16, 2 ** 15, 2 ** 14, 2 ** 13]): # diesen Teil gibt es wegen einer Zelle, da finde ich das nfft was am nächsten zu der Urpsurngsfrequnez ist frequency_saves = [] for nfft_here in nffts: frequency_save, p, f = calc_freq_from_psd(global_stimulus, sampling_rate, 
nfft=nfft_here) if np.abs(frequency_orig - frequency_save) < 20: frequency = frequency_save frequency_saves.append(frequency_save) return frequency_saves, frequency def get_nffts_medium(baseline_eod_long, nfft, sampling): freq_step_maximal, maximal_length = get_freq_step(baseline_eod_long, sampling) eod_fr_long, p, f = calc_freq_from_psd(baseline_eod_long, sampling, nfft=nfft) return eod_fr_long, freq_step_maximal, maximal_length def get_freq_step(baseline_eod_long, sampling): maximal_length = len(baseline_eod_long) freq_step_maximal = get_freq_steps(maximal_length, sampling) return freq_step_maximal, maximal_length def get_freq_steps(maximal_length, sampling): freq_step_maximal = sampling / maximal_length return freq_step_maximal def find_eod_fr_mt(b, mts, extended=False, indices=[], freq_step_nfft_eod=0.6103515625): # also manche der arrays haben das ja nicht # Hier laden wir erstmal das was schon da ist und dann verifizieren wir das mit dem power spectrum! # ok wir nehmen einfach die weil für alle tags das zu haben ist schon schwierig! 
eod_redo, eod_frs_orig_b = find_eod_fr_orig(mts, b, mts.positions[:]) if len(eod_frs_orig_b) != len(mts.positions[:]): eod_redo, eod_frs_orig_b = find_eod_fr_orig(mts, b, mts.positions[:], type='redo') # data_array_names = get_data_array_names(b) features, delay_name = feature_extract(mts) eod_frs_orig = [] for feat in features: if 'EODf' in feat: print(True) if len(indices) > 0: eod_frs_orig = mts.features[feat].data[:] # [indices] else: eod_frs_orig = mts.features[feat].data[:] # [indices] if len(eod_frs_orig) < 1: # | (np.sum(np.isnan(np.array(eod_frs_orig_b)))> 0) if len(indices) > 0: eod_frs_orig = np.array(eod_frs_orig_b) # [indices] else: eod_frs_orig = eod_frs_orig_b eod_frs, eodf_orig, freq_steps_single = find_eodf_three(b, mts, eod_frs_orig, mt_idx=indices, freq_step_nfft_eod=freq_step_nfft_eod, max_eod=True) eod_fr_medium = [] freq_step_medium = [] freq_step_mts = [] eod_fr_mts = [] if extended: ################################## # eodfs for all mts print('doing the rest') mt_poss = concat_mts_pos(mts) # [indices] mt_poss_ind = concat_mts_pos(mts)[indices] min_pos = mt_poss_ind[np.argmin(mt_poss_ind)] mt_nr_small = np.where(mt_poss == min_pos)[0][0] start_pos = mts.positions[:][mt_nr_small] max_pos = mt_poss_ind[np.argmax(mt_poss_ind)] mt_nr_max = np.where(mt_poss == max_pos)[0][0] duration = (np.max(mt_poss_ind) + mts.extents[:][mt_nr_max]) - start_pos global_eod, sampling = get_global_eod_for_eodf(b, duration, mts, mt_nr_small) nfft_eod = len(global_eod) eod_fr_mts = get_eodf_here(b, eodf_orig, global_eod, mt_nr_small, nfft_eod, sampling) freq_step_mts, maximal_length = get_freq_step(global_eod, sampling) ################################## # eodfs for all with higher resolution # hier das mit der desired auflösung noch machen freq_step = 0.01 nfft = int(np.round(sampling / freq_step)) baseline_eod_long = link_arrays_eod(b, first=start_pos, second=nfft / sampling, array_name='EOD')[0] eod_fr_medium, freq_step_medium, maximal_length = 
get_nffts_medium(baseline_eod_long, nfft, sampling) return eod_frs, eod_frs_orig, eod_fr_medium, freq_step_medium, freq_step_mts, eod_fr_mts, freq_steps_single def find_eod_fr_orig(mts, b, mt_length, type='SAM'): names = get_data_array_names(b) if (mts.name + '_EOD Rate' in names) & (type == 'SAM'): eod_frs = b.data_arrays[mts.name + '_EOD Rate'][:] # sinewave-1 eod_redo = False else: try: eod_frs = b.metadata['Recording']['Subject']['EOD Frequency'] except: eod_frs = float('nan') # b.metadata.pprint(max_depth = -1) eod_frs = [eod_frs] * len(mt_length) eod_redo = True return eod_redo, eod_frs def concat_mts(indices, mt): if len(np.shape(mt.extents[:][indices])) > 1: try: mt_extends = np.concatenate(mt.extents[:][indices]) except: print('still some shape mt problems') embed() else: mt_extends = mt.extents[:][indices] return mt_extends def concat_mts_pos(mts): if len(np.shape(mts.positions[:])) > 1: try: mt_extends = np.concatenate(mts.positions[:]) except: print('still some shape mt problems') embed() else: mt_extends = mts.positions[:] return mt_extends def resave_vars_eodfs(names, stack_sp, vars, res_name='res'): for v, var in enumerate(vars): stack_sp = resave_vars_corr(names, res_name, stack_sp, v, var) return stack_sp def resave_vars_corr(names, res_name, stack_sp, v, var): if (res_name not in names[v]) & ('all' not in names[v]): for vv, var_trial in enumerate(var): stack_sp[names[v] + str(vv)] = var_trial stack_sp[names[v]] = np.mean(var) else: stack_sp[names[v]] = var return stack_sp def names_eodfs(): names = ['EODf', 'EODf_all', 'EODf_res', 'freq_step_trial', 'freq_step_res', 'freq_step_all', ] return names def first_saturation_freq(): return 20.5, -300.5 # (39.5, -210.5)#(39.5, -210.5) def plt_FI_data_alex(cell, data, model, sort_data, sort_model): plt.title(cell) x, d, y, ymax, ymin, _ = interp_fi(model['inputs'].iloc[0][sort_model], data['f_onset'][sort_data]) plt.plot(x, d, color='black', linewidth=2, zorder=2, label='data') 
plt.scatter(data['contrasts'][sort_data], (data['f_onset'][sort_model] - ymin) / ymax, color='black', s=7, zorder=2) colors = ['green', 'blue', 'orange', 'pink', 'purple', 'red'] plt.xlabel('Contrast [%]') plt.ylabel('Firing Rate [Hz]') return colors def load_fi_curves_alex(cell, df_cell, n, nn, results): results.append({}) results[-1]['cell'] = cell results[-1]['n'] = n model = df_cell[df_cell['n'] == n] data = pd.read_csv('../data/Kennlinien/cell_fi_curves_csvs/' + cell + '.csv') ''' df_cell['inputs'].iloc[0] df_cell['on_r'].iloc[0] df_cell['ss_r'].iloc[0] df_cell['on_s'].iloc[0] df_cell['ss_s'].iloc[0] data['contrasts']''' sort_data = np.argsort(data['contrasts']) sort_model = np.argsort(model['inputs'].iloc[0]) mses = mse(model['inputs'].iloc[0][sort_model], data['f_onset'][sort_data], model['on_r'].iloc[0][sort_model]) names = ['', '_fmax', '_k', 'half'] for nn, n in enumerate(names): results[-1]['on_r' + n] = mses[nn] mses = mse(model['inputs'].iloc[0][sort_model], data['f_onset'][sort_data], model['on_s'].iloc[0][sort_model]) for nn, n in enumerate(names): results[-1]['on_s' + n] = mses[nn] mses = mse(model['inputs'].iloc[0][sort_model], data['f_steady_state'][sort_data], model['ss_r'].iloc[0][sort_model]) for nn, n in enumerate(names): results[-1]['ss_r' + n] = mses[nn] mses = mse(model['inputs'].iloc[0][sort_model], data['f_steady_state'][sort_data], model['ss_s'].iloc[0][sort_model]) for nn, n in enumerate(names): # results[-1]['ss_s' + n] = mses[nn] return data, model, n, nn, sort_data, sort_model def interp_fi(xdata, ydata): try: popt, pcov = curve_fit(bolzmann, xdata, ydata, bounds=( [0, - np.inf, - np.inf], [np.max(ydata) * 4, np.inf, np.inf])) x = np.linspace(xdata[0] * 1.1, xdata[-1] * 1.1, 1000) y = bolzmann(x, *popt) y_norm1 = y - np.min(y) y_norm2 = y_norm1 / np.max(y_norm1) except: popt = [float('nan'), float('nan'), float('nan')] x = float('nan') y = float('nan') y_norm1 = float('nan') y_norm2 = float('nan') plot = False if plot: plt.subplot(1, 
2, 1) plt.plot(x, y) plt.scatter(xdata, ydata) plt.subplot(1, 2, 2) plt.plot(x, y_norm2) plt.scatter(xdata, ydata / (np.max(y))) plt.show() return x, y_norm2, y, np.max(y_norm1), np.min(y), popt def mse(x, data, model): _, d, _, _, _, d_popt = interp_fi(x, data) _, m, _, _, _, m_popt = interp_fi(x, model) return np.mean((d - m) ** 2), (d_popt[0] - m_popt[0]) ** 2, (d_popt[1] - m_popt[1]) ** 2, ( d_popt[2] - m_popt[2]) ** 2 def onset_function(x, f_max, k, I_half): return f_max / (1 + np.exp(-k * (x - I_half))) def steady_function(x, f_max, k, I_half): return f_max / (1 + np.exp(-k * (x - I_half))) def start_fi_o(): return 7 def end_fi_o(): return 55 def start_fi_s(): return 30 def end_fi_s(): return 90 def trial_nrs_ram_model(): trial_nrs_here = np.array([9, 11, 20, 30, 100, 500, 1000, 10000, 100000, 250000, 500000, 750000, 1000000]) return trial_nrs_here def colors_suscept_paper_dots(): color0 = 'blue' color0_burst = 'darkgreen' color01 = 'green' color02 = 'purple' color012 = 'orange' color01_2 = 'red' ## return color01, color012, color01_2, color02, color0_burst, color0 def plt_voltage_trace(cell, eod_fr, frame_cell, axs, lim_here, test, spikes_plotted_lower=True, spikes_plotted=True, dir='', scaling_factor=1, color_trace='grey', color_first_spike='black', color_second_spike='blue', xlim=0.400): spike_times_all_full = [] if os.path.exists(dir + '../data/cells/' + cell + '/' + cell + '.nix'): f, nix_exists, nix_missing = load_f(['cells'], 0, cell, dir=dir) if nix_exists: b = f.blocks[0] cont_baseline, nix_missing, ts = find_tags_baseline(b, nix_missing) if cont_baseline & (len(ts) > 0): spike_times_all = [] data_array_names = get_data_array_names(b) if 'eod' in ''.join(data_array_names).lower(): lengths = [] for t in ts: lengths.append(t.extent[:][0]) tag = ts[np.argmax(lengths)] add_mt = True if add_mt: if 'base' in tag.name.lower(): print(tag.name) tag_here = b.tags[tag.name] # 2/3 if len(tag_here.extent[:]) > 0: duration = 
restrict_base_durationts(tag_here.extent[:][0]) if duration < xlim: duration_base = xlim else: duration_base = xlim spike_times = link_arrays_spikes(b, first=tag.position[:][0], second=duration_base, minus_spikes=tag.position[:][0]) spike_times_full = link_arrays_spikes(b, first=tag.position[:][0], second=tag.extent[:][0], minus_spikes=tag.position[:][0]) spike_times_all.append(spike_times) spike_times_all_full.append(spike_times_full) eods_g, sampling = link_arrays_eod(b, first=tag.position[:][0], second=duration_base, array_name='V-1') axs.plot(np.arange(0, len(eods_g) / sampling, 1 / sampling) * scaling_factor, eods_g, linewidth=0.8, color=color_trace) if spikes_plotted: if len(spike_times) > 0: spikes_here = spike_times spikes_here = spikes_here[spikes_here < xlim] axs.scatter(spikes_here * scaling_factor, np.percentile(eods_g, 100) * np.ones(len(spikes_here)), color=color_first_spike, clip_on=False) if spikes_plotted_lower: if len(spike_times) > 0: axs.scatter(spike_times * scaling_factor, np.percentile(eods_g, 90) * np.ones(len(spike_times)), color=color_first_spike) hists2 = [(np.diff(spike_times) / (1 / eod_fr))] if len(hists2[0]) > 0: try: np.min(hists2) < 1.5 except: print('hist thing') embed() if np.min(hists2[0]) < 1.5: burst_corr = '_burstIndividual_' hists2, spikes_ex, frs_calc2 = correct_burstiness(hists2, [spike_times], [eod_fr] * len( [spike_times]), [eod_fr] * len( [spike_times]), lim=lim_here, burst_corr=burst_corr) if spikes_plotted: try: spikes_here = spikes_ex[0][spikes_ex[0] < xlim] axs.scatter(spikes_here * scaling_factor, np.percentile(eods_g, 100) * np.ones(len(spikes_here)), clip_on=False, color=color_second_spike) except: print('scatter thing') embed() if spikes_plotted_lower: axs.scatter(spikes_ex[0] * scaling_factor, np.percentile(eods_g, 90) * np.ones(len(spikes_ex[0])), color=color_second_spike) axs.set_xlim(0, xlim * scaling_factor) if test: plt.plot(np.arange(0, len(eods_g) / sampling, 1 / sampling) * scaling_factor, eods_g, 
color=color_first_spike) plt.scatter(spike_times * scaling_factor, np.percentile(eods_g, 90) * np.ones(len(spike_times)), color=color_first_spike) if len(ts) < 1: axs.set_title('no nix') return spike_times_all_full def find_tags_baseline(b, cont_rlx): try: ts = find_tags_list(b, names='baseline') cont_baseline = True except: print('ts problem') try: ts = find_tags_list(b, names='baseline') cont_baseline = True except: cont_baseline = False cont_rlx = True ts = [] return cont_baseline, cont_rlx, ts def load_f(data_dir, c, cell, dir=''): cont_rlx = False try: f = nix.File.open(dir + '../data/' + data_dir[c] + '/' + cell + '/' + cell + '.nix', nix.FileMode.ReadOnly) cont_here = True except: f = [] cont_here = False cont_rlx = True return f, cont_here, cont_rlx def perc_model_full(): return 95 def get_frame_for_base_plot(cells, save_names=None, based_on_ram_overview=True, species=' Apteronotus leptorhynchus'): frame, frame_spikes = load_cv_vals_susept(cells, EOD_type='synch', names_keep=['gwn', 'fs', 'EODf', 'cv', 'fr', 'width_75', 'vs', 'cv_burst_corr_individual', 'fr_burst_corr_individual', 'width_75_burst_corr_individual', 'vs_burst_corr_individual', 'cell_type_reclassified', 'cell']) # redo = True, cell_type_type = 'cell_type_reclassified' frame = unify_cell_names(frame, cell_type=cell_type_type) redo = False if not save_names: save_names = [ 'calc_RAM_overview-_simplified_' + version_final()] # 'calc_RAM_overview-_simplified_noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s__burstIndividual_','calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s__burstIndividual_', frame_load_sp = load_overview_susept(save_names[0], redo=redo, redo_class=redo) test = False if test: frame_load_sp[frame_load_sp.species == species] frame_load_sp.cell_type_reclassified.unique() dated_up = update_ssh_file(load_folder_name('calc_RAM') + '/' + save_names[0] + '.csv') if dated_up == 'yes': 
frame_load_sp = load_overview_susept(save_names[0], redo=True, redo_class=redo) cell_types = [' P-unit', ' Ampullary', ] if based_on_ram_overview: cells_exclude = [] for cell_type_here in cell_types: frame_file = setting_overview_score(frame_load_sp, cell_type_here, min_amp='min', species=species) cells_exclude.extend(frame_file.cell.unique()) frame_load = frame[frame['cell'].isin(cells_exclude)] else: frame_load = frame return cell_type_type, frame_load, frame_spikes def get_grids_for_cv_fr(grid_scatter): grid_panel = gridspec.GridSpecFromSubplotSpec(2, 2, grid_scatter, wspace=0, hspace=0.05, height_ratios=[0.4, 2.5], width_ratios=[2.5, 0.25]) # 0.2, 2.5,width_ratios=[2.5, 0.2, 2.5, 0.7], axx = plt.subplot(grid_panel[ 0, 0]) # , height_ratios=[0.35, 3], width_ratios=[3, 0.5], height_ratios=[0.7, 2.5],width_ratios=[2.5,0.5] axs = plt.subplot(grid_panel[1, 0]) try: axy = plt.subplot(grid_panel[1, 1]) except: print('hist thing') embed() return axs, axx, axy def plt_fr_cv_base(ax0, ax_cv, ax_fr, add, frame_load, gg, s=5, fr_lim=700, cell_types=[' P-unit', ' Ampullary'], alpha=0.3, xmax=[1.51, 1.51], cvmax=1.51, cell_type_type='cell_type_reclassified', color_given=None, colors=[], annotate=False, species=' Apteronotus leptorhynchus'): if not colors: colors = colors_overview() for c_here, cell_type in enumerate(cell_types): if not color_given: color_given = colors[str(cell_type)] # not found frame_g = frame_load[ (frame_load[cell_type_type] == cell_type) & ((frame_load.gwn == True) | (frame_load.fs == True))] print(cvmax) kernel_histogram(ax_cv, color_given, np.array(frame_g['cv' + add[gg]]), xmin=0, xmax=xmax[c_here], alpha=0.5, step=0) # step=0.06 ax_cv.show_spines('') remove_yticks(ax_cv) ax_fr.get_shared_y_axes().join(*[ax_fr, ax0]) ax_cv.get_shared_x_axes().join(*[ax_cv, ax0]) test = False if test: pass kernel_histogram(ax_fr, color_given, np.array(frame_g['fr' + add[gg]]), step=0, alpha=0.5, orientation='vertical') # step=4, ax_fr.show_spines('') 
ax_cv.show_spines('') remove_xticks(ax_cv) remove_xticks(ax_fr) remove_yticks(ax_fr) y_axis = 'fr' x_axis = 'cv' frame_g = ptl_fr_cv(add[gg], alpha, annotate, ax0, cell_type_type, cell_types, frame_load, s, cv=y_axis, fr=x_axis, color_given=color_given) add_namex = [cv_base_name(), cv_base_name_corr()] add_namey = [fbasenamehz(), fbasecorrectedname()] ax0.set_xlabel(add_namex[gg]) ax0.set_ylabel(add_namey[gg]) ax0.set_ylim(0, fr_lim) ax0.set_xlim(0, cvmax) return x_axis, y_axis def fbasecorrectedname(): rm_var = rem_variable() if rm_var['rm']: val = r'$f\rm{_{BaseCorrected}}$ [Hz]' else: val = r'$f_{BaseCorrected}$ [Hz]' return val def fbasenamehz(): rm_var = rem_variable() if rm_var['rm']: val = fbasename() + ' [Hz]' else: val = fbasename() + ' [Hz]' return val def fbasename(): return r'$f' + basename() + '$' def fbasename_small(): return r'$f' + basename_small() + '$' def stimname(): rm_var = rem_variable() if rm_var['rm']: val = r'\rm{_{Stim}}' else: val = r'_{Stim}' return val def basename(): rm_var = rem_variable() if rm_var['rm']: val = r'\rm{_{Base}}' else: val = r'_{Base}' return val def basename_small(): rm_var = rem_variable() if rm_var['rm']: val = r'\rm{_{base}}' else: val = r'_{base}' return val def cv_base_name_corr(): rm_var = rem_variable() if rm_var['rm']: val = r'CV$\rm{_{BaseCorrected}}$' else: val = r'CV$_{BaseCorrected}$' return val def annotate_left_arrow(ax, lw=1.5, ws=None): if not ws: ws = ws_nonlin_systems() ax.annotate('', ha='center', xycoords='axes fraction', xy=(1 + ws, 0.5), textcoords='axes fraction', xytext=(1, 0.5), arrowprops={"arrowstyle": "->", "linestyle": "-", "linewidth": lw, "color": 'black'}, zorder=1, annotation_clip=False, transform=ax.transAxes, ls=8) def plt_single_matrix(ax, stack_final, ls=8, y_label=True, fr_name = 'fr'): new_keys, stack_plot = convert_csv_str_to_float(stack_final) mat = RAM_norm_data(stack_final['isf'].iloc[0], stack_plot, stack_final['snippets'].unique()[0], stack_here=stack_final) # mat, 
add_nonlin_title, resize_val = rescale_colorbar_and_values(mat) im = plt_RAM_perc(ax, 'no', mat) if y_label: set_ylabel_arrow(ax, xpos=-0.17, ypos=0.97) set_xlabel_arrow(ax, xpos=1, ypos=-0.23) set_clim_same([im], mats=[mat], lim_type='up', nr_clim='perc', clims='', percnr=95) fr = stack_final[fr_name].iloc[0] #embed() plt_triangle(ax, fr, fr, new_keys[-1], eod_metrice=False, nr=1) # eod_fr_half_color='purple', power_noise_color='blue', # todo: change clim values with different Hz values cbar, left, bottom, width, height = colorbar_outside(ax, im, add=5, width=0.01, ls=ls) cbar.set_label(nonlin_title(add_nonlin_title=' [' + add_nonlin_title), rotation=90, labelpad=8) ''' eod_fr, stack_spikes = plt_data_suscept_single(ax, cbar_label, cell, cells, f, fig, file_names_exclude, lp, title, width)''' return cbar, fr, mat, im def load_stack_data_susept(cell, save_name, end=''): load_name = load_folder_name('calc_RAM') + '/' + save_name + end add = '_cell' + cell + end # str(f) # + '_amp_' + str(amp) stack_cell = load_data_susept(load_name + '_' + cell + '.pkl', load_name + '_' + cell, add=add, load_version='csv') file_names_exclude = get_file_names_exclude() stack_cell = stack_cell[~stack_cell['file_name'].isin(file_names_exclude)] file_names = stack_cell.file_name.unique() file_names = exclude_file_name_short(file_names) cut_off_nr = get_cutoffs_nr(file_names) try: maxs = list(map(float, cut_off_nr)) except: embed() file_names = file_names[np.argmax(maxs)] stack_file = stack_cell[stack_cell['file_name'] == file_names] stack_final = get_stack_final(cell, stack_file) mat, new_keys = get_mat_susept(stack_final) return mat, stack_final def get_stack_final(cell, stack_file): amps = [np.min(stack_file.amp.unique())] amps = restrict_punits(cell, amps) amp = np.min(amps) # [0] stack_amps = stack_file[stack_file['amp'] == amp] lengths = stack_amps.stimulus_length.unique() trial_nr_double = stack_amps.trial_nr.unique() trial_nr = np.max(trial_nr_double) stack_final = stack_amps[ 
(stack_amps['stimulus_length'] == np.max(lengths)) & (stack_amps.trial_nr == trial_nr)] return stack_final def data_overview_punit(cell_types=[' P-unit']): plot_style() default_figsize(width=cm_to_inch(28), length=cm_to_inch(12)) default_ticks_talks() var_it = 'Response Modulation [Hz]' var_it2 = '' #print(right) grid0 = overview_mod_grid(cell_types) ########################## # Auswahl: wir nehmen den mean um nicht Stimulus abhängigen Noise rauszumitteln save_names = ['calc_RAM_overview-_simplified_noise_data12_nfft0.5sec_original__StimPreSaved4__direct_'] species = ' Apteronotus leptorhynchus' burst_fraction = [1, 1] # ,1,1] burst_corr_reset = 'burst_fraction_burst_corr_individual_stim' redo = False counter = 0 tags = [] frame_load_sp = load_overview_susept(save_names[0], redo=redo, redo_class=redo) scores = ['max(diag5Hz)/med_diagonal_proj_fr', 'max(diag5Hz)/med_diagonal_proj_fr', ] # + '_diagonal_proj' max_xs = [[[], [], []], [[], [], []]] for c, cell_type_here in enumerate(cell_types): frame_file = setting_overview_score(frame_load_sp, cell_type_here, min_amp='range', species=species) test = False # ok das schließe ich aus weil da irgendwas in der Detektion ist, das betrifft jetzt genau 3 Zellen, also nicht so schlimm # 63 2018-08-14-af-invivo-1 # 241 2018-09-05-aj-invivo-1 # 252 2022-01-08-ah-invivo-1 frame_file = frame_file[frame_file.cv_stim < 5] if test: frame_file[frame_file.cv_base > 3].cell frame_file[frame_file.cv_stim > 3].cv_stim frame_file.groupby('cell').groups.keys() frame_file.group_by('cell') len(frame_file.cell.unique()) ############################################## # modulatoin comparison for both cell_types ################################ # Modulation, cell type comparison x_axis = ['cv_stim', 'cv_base', 'response_modulation'] # ,'fr_base']# var_item_names = [var_it, var_it, var_it2] # ,var_it2]#['Response Modulation [Hz]',] var_types = [''] # ,'response_modulation','']#,'']#'response_modulation' max_x = max_xs[c] x_axis_names = 
[x_axis_talk(), 'CV$_{stim}$', 'Response Modulation [Hz]'] # $'+basename()+'$,'Fr$'+basename()+'$',] score = scores[c] scores_here = [score, score, score] # ,score] score_name = [nonlinearity_name_talk(), NLI_scorename2(), NLI_scorename2()] # NLI_scorename()] # 'Fr/Med''Perc99/Med' ax_j = [] axls = [] axss = [] log = '' # 'logall'#''#'logy','logall'True#False for v, var_type in enumerate(var_types): #axx, axy, axs, axls, axss, ax_j = get_grid_4(ax_j, axls, axss, grid0[v, counter]) axs = plt.subplot(grid0[v, counter]) if log == 'logy': pass else: pass if (' P-unit' in cell_type_here) & ('cv' in x_axis[v]): pass else: pass xlimk = None labelpad = 0.5 # -1 fs, ms = size_talk_overview() cmap, _, y_axis = scatter_with_marginals_colorcoded(var_item_names[v], axs, cell_type_here, x_axis[v], frame_file, scores_here[v], burst_fraction_reset=burst_corr_reset, var_item=var_type, labelpad=labelpad, max_x=max_x[v], x_pos=1, fs=fs, ms=ms, burst_fraction=burst_fraction[c], sides=False, ha='right', color_given=colors_overview()[' P-unit_talk'], legend_spacing=0.15) print(cell_type_here + ' median ' + scores_here[v] + '' + str(np.nanmedian(frame_file[scores_here[v]]))) print(cell_type_here + ' max ' + x_axis[v] + '' + str(np.nanmax(frame_file[x_axis[v]]))) if v == 0: pass axs.set_ylabel(score_name[v]) axs.set_xlabel(x_axis_names[v], labelpad=labelpad) axs.set_ylim(0, 3.8) axs.set_xlim(0, 1.8) extra_lim = False if extra_lim: if (' P-unit' in cell_type_here) & ('cv' in x_axis[v]): axs.set_xlim(xlimk) if log == 'logy': axs.set_yscale('log') make_log_ticks([axs]) elif log == 'logall': axs.set_yscale('log') make_log_ticks([axs]) axs.set_xscale('log') counter += 1 save_visualization(pdf=True) def overview_mod_grid(cell_types, right = 0.98, ws = 0.75): grid0 = gridspec.GridSpec(1, len(cell_types), wspace=ws, bottom=0.2, hspace=0.45, left=0.12, right=right, top=0.95) return grid0 def nonlinearity_name_talk(): return 'Nonlinearity' def size_talk_overview(): fs = 20 ms = 13 return fs, ms def 
plt_specific_cells(axs, cell_type_here, cv_name, frame_file, score, marker=[]): ###################################################### # hier kommen die kontrast Punkte dazu if cell_type_here == ' P-unit': cells_plot2 = p_units_to_show(type_here='contrasts')[0:2] else: cells_plot2 = [p_units_to_show(type_here='amp')[0]] cells_extra = frame_file[frame_file['cell'].isin(cells_plot2)].index if not marker: axs.scatter(frame_file[cv_name].loc[cells_extra], frame_file[score].loc[cells_extra], s=9, facecolor="None", edgecolor='black', alpha=0.7, clip_on=False) # colors[str(cell_type_here)] else: axs.scatter(frame_file[cv_name].loc[cells_extra][0:2], frame_file[score].loc[cells_extra][0:2], s=9, facecolor="None", marker=marker[1], edgecolor='black', alpha=0.7, clip_on=False) # colors[str(cell_type_here)] axs.scatter(frame_file[cv_name].loc[cells_extra][2:4], frame_file[score].loc[cells_extra][2:4], s=9, facecolor="None", marker=marker[0], edgecolor='black', alpha=0.7, clip_on=False) # colors[str(cell_type_here)] def get_grid_4(ax_j, axls, axss, grid0): grid_k = gridspec.GridSpecFromSubplotSpec(2, 2, grid0, hspace=0.1, wspace=0.1, height_ratios=[0.35, 3], width_ratios=[3, 0.5]) axk = plt.subplot(grid_k[0, 0]) ax_j.append(axk) axs = plt.subplot(grid_k[1, 0]) axss.append(axs) axl = plt.subplot(grid_k[1, 1]) axls.append(axl) return axk, axl, axs, axls, axss, ax_j def conv_integers(threshold, power_kernel): power_threshold = [] for i in range(len(threshold)): power_threshold.append(np.array(power_kernel) + np.array(threshold[i])) power_threshold = np.concatenate(power_threshold) return power_threshold def NLI_burstcorr_name2(): return 'PNL($f_{BurstCorr}$)' def grid_evolutionary(): #gridr = gridspec.GridSpec(1, 2, wspace=0.45, hspace=0.5, top=0.85, left=0.05, bottom=0.45, right=1, # width_ratios=[2,1.3]) # 2, 2,, height_ratios = [1,3] #gridl = gridspec.GridSpecFromSubplotSpec(1, 2, gridr[0], wspace=0.6, hspace=0.5, # width_ratios=[1, 1]) # 2, 2,, height_ratios = [1,3] gridr = 
gridspec.GridSpec(1, 4, wspace=0.6, hspace=0.5, top=0.83, left=0.05, bottom=0.24, right=0.89, width_ratios=[2.2,0, 1.8, 2.2]) # 2, 2,, height_ratios = [1,3] bottom = 0.45 return gridr def didactic_sine_spectrum(axps, axts, color, sampling, sines, time_array, titles, freqs=None, colors=None, colors_peaks=[['red', 'purple']], labels=[[r'$f_{1}$', r'$f_{2}$']]): for ss, sine in enumerate(sines): if colors: color = colors[ss] axts[ss].plot(time_array * 1000, sine, color=color) axts[ss].set_ylim(np.min(sine) * 1.02, np.max(sine) * 1.02) axts[ss].set_xlim(0, 0.1 * 1000) axts[ss].show_spines('lb') axts[ss].set_xlabel('Time [ms]') axts[ss].set_title(titles[ss]) # , transform=axts[ss].transAxes) # r$\tilde{s}(f)$' ################################################################ p_array, f = ml.psd(sine - np.mean(sine), Fs=sampling, NFFT=2 ** 17, noverlap=2 ** 15 // 2) log = True if log: p_array = calc_log(p_array) axps[ss].plot(f, p_array, color='black') axps[ss].set_xlim(0, 100) if ss == 0: axts[ss].set_ylabel('Signal') axps[ss].set_ylabel('PSD [1/Hz]') else: remove_yticks(axps[ss]) remove_yticks(axts[ss]) axps[ss].set_xlabel('Frequency [Hz]') ################################################################ # embed() if freqs: plt_peaks_several(freqs[ss], [p_array], axps[ss], p_array, f, labels=labels[ss], colors=colors_peaks[ss], perc_peaksize=2) if log: axps[ss].set_ylim(-25,0) axps[ss].set_ylabel('dB') def retrieve_mat(diff_load, name): droped = diff_load.dropna(axis=1) cleaned = droped[droped['dist'] == name] cleaned.pop('dist') output = cleaned.reindex(sorted(cleaned.columns), axis=1) return output def load_baseline_matrix(what, cell, pivot1, a_fr=1): baseline = pd.read_pickle( load_folder_name( 'calc_model') + '/modell_all_cell_no_sinz1_afe0__afr1__afj0__length1.5_adaptoffsetallall2___stepefish10Hz_ratecorrrisidual35__modelbigfit_nfft4096_base.pkl') baseline_cell = baseline[baseline['dataset'] == cell] base_matrix = pivot1 * 1 if what != 'spike_times': base = 
np.nanmean(baseline_cell[what]) if a_fr == 1: base_matrix[:] = base * np.ones_like(pivot1) else: base_matrix[:] = np.zeros_like(pivot1) else: base = baseline_cell.iloc[0]['spike_times'] for i in range(len(base_matrix)): for j in range(len(base_matrix.iloc[0])): base_matrix.iloc[i, j] = baseline_cell.iloc[0]['spike_times'] # base_matrix = base * np.ones_like(pivot1) return base, base_matrix, baseline def get_control(nr, cell_nr, what, afe, contrast1='0.1', a_fr=1, contrast2='0', minimum=0.5, maximum=1.5, cell=[], version_sinz='sinz', adapt='adaptoffsetallall2', symetric='', beat_type='', step=10, variant='no', self='', varied='emitter'): name = 'modell_all_cell_' + variant + '_' + version_sinz + str( nr) + self + '_afe' + str(contrast1) + '__afr' + str(a_fr) + '__afj' + str( contrast2) + '__length1.5_' + adapt + '___stepefish' + str( step) + 'Hz_ratecorrrisidual35__modelbigfit_nfft4096' + duration + beat_type + symetric + '.pkl' control = pd.read_pickle( load_folder_name('calc_model') + '/' + name) if not cell: cell = np.unique(control['dataset'])[cell_nr] control_array = control[control['dataset'] == cell] DF_j = np.unique(np.array( (control_array['eodj'] - control_array['eodf']) / control_array[ 'eodf'] + 1)) * 100 dict_here = dict( zip(np.unique(control_array['eodj']), np.round(DF_j) / 100)) control_afj = rename(columns=dict_here) DF_e = np.round(np.unique( np.array((control_array['eode'] - control_array['eodf']) / control_array[ 'eodf'] + 1)) * 100) / 100 control_afj = control_afj.set_index(DF_e) if varied == 'emitter': control_afj.columns.name = 'fish1-fish0 $f_{stim}/f_{EOD}$' # 'DeltaF-eodj-eodf' control_afj.index.name = 'fish2-fish0 $f_{stim}/f_{EOD}$' # 'DeltaF-eode-eodf' else: control_afj.columns.name = 'fish2-fish0 $f_{stim}/f_{EOD}$' # 'DeltaF-eodj-eodf' control_afj.index.name = 'fish1-fish0 $f_{stim}/f_{EOD}$' # 'DeltaF-eode-eodf' if maximum != []: control_afj, column_chosen, index_chosen = cut_matrix_generation(control_afj, minimum, maximum) return 
control_afj, DF_e, dict_here, control_array['eodf'] def create_spikes_mat(length, spikes_cut, sampling_rate, results=[], trial_nr=1, test_saturation=False): # reset to the first spike spikes_mat = [[]] * len(spikes_cut) for s in range(len(spikes_cut)): spikes_mat[s] = cr_spikes_mat(spikes_cut[s], sampling_rate, int(length * sampling_rate)) smoothed05 = gaussian_filter(spikes_mat, sigma=0.0005 * sampling_rate) smoothed2 = gaussian_filter(spikes_mat, sigma=0.002 * sampling_rate) smoothened_spikes_mat05 = np.mean(smoothed05, axis=0) smoothened_spikes_mat2 = np.mean(smoothed2, axis=0) if test_saturation: # plt_saturation_effect(sampling_rate, smoothened_spikes_mat2, smoothed2, smoothed05, results, smoothened_spikes_mat05, spikes_mat) from utils_test import plt_saturation_effect2 plt_saturation_effect2(sampling_rate, smoothened_spikes_mat2, smoothed2, smoothed05, results, smoothened_spikes_mat05, spikes_mat, show=False) return spikes_cut, spikes_mat, smoothened_spikes_mat05, smoothened_spikes_mat2 def get_cut_off_for_wn(cut_off_nr, file_name): split = file_name.lower().split('hz')[0] if 'wn' in split: cut_off_nr.append(split.split('wn')[1]) else: cut_off_nr.append(split.split('_')[1]) return cut_off_nr def color_beats(): return 'red' def power_didactic_subplots(): plt.subplots_adjust(wspace=0.25, top=0.85, left=0.1, hspace=0.85, bottom=0.2, right=0.97) def reset_yaxis_cords(axes, ypos=-0.1): for ax in axes: ax.yaxis.set_label_coords(ypos, 0.5) def x_axis_talk(): return 'CV (Noise)' def cellscompar2(cells_plot2, amp_desired=[5, 20]): # [0, 1.1] plot_style() # 'noise_data8_nfft1sec_original__LocalEOD_CutatBeginning_0.05_s_NeurDelay_0.005_s',#__burstIndividual_ # ] # save_names = ['noise_data10_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s_spikes_', # 'noise_data10_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s_spikes_', # 
'noise_data10_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s_spikes_'] save_names = [version_final()] amps_desired = amp_desired # amps_desired, cell_type_type, cells_plot2, frame, cell_types = load_isis(save_names, amps_desired = amp_desired, cell_class = cell_class) cell_type_type = 'cell_type_reclassified' # frame = load_cv_base_frame(cells_plot2, cell_type_type=cell_type_type, redo = True) frame, frame_spikes = load_cv_vals_susept(cells_plot2, EOD_type='synch', names_keep=['spikes', 'gwn', 'fs', 'EODf', 'cv', 'fr', 'width_75', 'vs', 'cv_burst_corr_individual', 'fr_burst_corr_individual', 'width_75_burst_corr_individual', 'vs_burst_corr_individual', 'cell_type_reclassified', 'cell'], path_sp='/calc_base_data-base_frame_overview.pkl', frame_general=False) default_settings_cells_susept(cells_plot2) # 0.21 cell_types = [' P-unit']#, ' Ampullary'] cell_types_name = [' P-units']#, 'Ampullary cells', ] plot_style() size_evolutionary() default_ticks_talks() default_lw_RAM_talks() names = [' P-unit_talk', ' eigen_P-unit_talk'] gridr = grid_evolutionary() species = [' Apteronotus leptorhynchus', ' Eigenmannia virescens'] for c, cell in enumerate(cells_plot2): print(cell) cell_type, eod_fr, fr, frs_calc, isi, spikes, spikes_all = get_base_params(cell, cell_type_type, frame) ims = [] add_here = '_cell' + cell mats = [] zorders = [100, 50] for s, save_name in enumerate(save_names): load_name = load_folder_name('calc_RAM') + '/' + save_name + '_' + cell stack = load_data_susept(load_name + '.pkl', load_name, add=add_here, load_version='csv', cells=cells_plot2) if len(stack) > 0: files, stack = exclude_cut_filenames(cell_type, stack, fexclude=True) stack_file = stack[stack['file_name'] == files[0]] #embed() amps = stack_file['amp'].unique() predefined_amp = True if predefined_amp: amps_defined = amps_desired else: amps_defined = amps trues = [] for amp in amps_defined: if amp in amps: trues.append(True) # if len(trues) < len(amps): 
amps_defined = [np.min(amps)] cells_amp = ['2017-10-25-am-invivo-1', '2010-11-26-an-invivo-1'] if cell == cells_amp: print('cell thing') embed() ims = [] for aa, amp in enumerate(amps_defined): mat, stack_final = load_stack_data_susept(cell, save_name=version_final(), end='') if amp in np.array(stack_file['amp']): print(zorders[aa]) ax = plt.subplot(gridr[c*3]) cbar, fr, mat, im = plt_single_matrix(ax, stack_final, ls=None) colors = colors_overview() ax.set_title(species[c], color=colors[names[c]], pad = 30) set_clim_same(ims, mats=mats, lim_type='up', percnr=95) ########################## # Auswahl: wir nehmen den mean um nicht Stimulus abhängigen Noise rauszumitteln # save_names = [] save_names2 = ['calc_RAM_overview-_simplified_noise_data12_nfft0.5sec_original__StimPreSaved4__direct_'] burst_fraction = [1, 1] # ,1,1] burst_corr_reset = 'burst_fraction_burst_corr_individual_stim' redo = False counter = 0 tags = [] frame_load_sp = load_overview_susept(save_names2[0], redo=redo, redo_class=redo) scores = ['max(diag5Hz)/med_diagonal_proj_fr_base_w_burstcorr','max(diag5Hz)/med_diagonal_proj_fr', ] # + '_diagonal_proj' x_axiss = ['cv_stim', 'cv_stim', ]# 'burst_fraction_burst_corr_individual_base'] max_xs = [[[], [], []], [[], [], []]] for c, species in enumerate([' Apteronotus leptorhynchus',' Eigenmannia virescens'][0:len(cells_plot2)]): for cc, cell_type_here in enumerate(cell_types): frame_file = setting_overview_score(frame_load_sp, cell_type_here, min_amp='range', species=species) frame_file = frame_file[frame_file.cv_stim < 5] ################################ # Modulation, cell type comparison var_it = 'Response Modulation [Hz]' x_axis = [x_axiss[c]]#, 'cv_base', 'response_modulation'] # ,'fr_base']# var_item_names = [var_it, var_it] # ,var_it2]#['Response Modulation [Hz]',] var_types = [''] # ,'response_modulation','']#,'']#'response_modulation' max_x = max_xs[c] x_axis_names = [x_axis_talk()]#, 'CV$_{stim}$', 'Response Modulation [Hz]'] # 
$'+basename()+'$,'Fr$'+basename()+'$',] score = scores[c] scores_here = [score]#, score_burst_corr, score] # ,score] score_name = ['Nonlinearity($f_{BaseCorr}$)']#nonlinearity_name_talk()]#, NLI_name2(),NLI_scorename2()] # NLI_scorename()] # 'Fr/Med''Perc99/Med' axss = [] log = '' # 'logall'#''#'logy','logall'True#False for v, var_type in enumerate(var_types): if c == 0: axs = plt.subplot(gridr[2]) axss.append(axs) if log == 'logy': ymin = 'no' else: ymin = 0 xmin = 0 xlimk = None labelpad = 0.5 # -1 colors = colors_overview() fs, ms = size_talk_overview() cmap, _, y_axis = scatter_with_marginals_colorcoded(var_item_names[v], axs, cell_type_here, x_axis[v], frame_file, scores_here[v], ymin=ymin, xmin=xmin, burst_fraction_reset=burst_corr_reset, var_item=var_type, labelpad=labelpad, max_x=max_x[v], xlim=xlimk, x_pos=1, fs=fs, ms=ms, c=c, burst_fraction=burst_fraction[c], sides=False, color_text=colors[names[c]], ha='right', y_val=1.15, color_given=colors[names[c]], legend_spacing=0.1) # : 'tab:blue', print(cell_type_here + ' median ' + scores_here[v] + '' + str( np.nanmedian(frame_file[scores_here[v]]))) print(cell_type_here + ' max ' + x_axis[v] + '' + str(np.nanmax(frame_file[x_axis[v]]))) axs.set_ylabel(score_name[v]) axs.set_xlabel(x_axis_names[v], labelpad=labelpad) axs.set_ylim(0,7) axs.set_xlim(0,1.7) if log == 'logy': axs.set_yscale('log') make_log_ticks([axs]) elif log == 'logall': axs.set_yscale('log') make_log_ticks([axs]) axs.set_xscale('log') make_log_ticks([axs]) counter += 1 save_visualization(pdf=True, individual_tag=cells_plot2[0]) def size_evolutionary(): default_figsize(width=cm_to_inch(33.4), length=cm_to_inch(11.8)) def cellscompar(amp_desired=[5, 20], xlim=[], cells_plot2=[], RAM=True, scale_val=False): # [0, 1.1] plot_style() # 'noise_data8_nfft1sec_original__LocalEOD_CutatBeginning_0.05_s_NeurDelay_0.005_s',#__burstIndividual_ # ] # save_names = 
['noise_data10_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s_spikes_', # 'noise_data10_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s_spikes_', # 'noise_data10_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s_spikes_'] save_names = [version_final()] amps_desired = amp_desired # amps_desired, cell_type_type, cells_plot2, frame, cell_types = load_isis(save_names, amps_desired = amp_desired, cell_class = cell_class) cell_type_type = 'cell_type_reclassified' # frame = load_cv_base_frame(cells_plot2, cell_type_type=cell_type_type, redo = True) frame, frame_spikes = load_cv_vals_susept(cells_plot2, EOD_type='synch', names_keep=['spikes', 'gwn', 'fs', 'EODf', 'cv', 'fr', 'width_75', 'vs', 'cv_burst_corr_individual', 'fr_burst_corr_individual', 'width_75_burst_corr_individual', 'vs_burst_corr_individual', 'cell_type_reclassified', 'cell'], path_sp='/calc_base_data-base_frame_overview.pkl', frame_general=False) default_settings_cells_susept(cells_plot2) if len(cells_plot2) == 1: pass else: pass grid = gridspec.GridSpec(1, 1, wspace=0.4, hspace=0.5, top=0.9, left=0.25, bottom=0.12, right=0.99) # 0.21 cell_types = [' P-unit', ' Ampullary', ] cell_types_name = [' P-units', 'Ampullary cells', ] style() default_figsize(width=cm_to_inch(30.5), length=cm_to_inch(12.19)) default_figsize(width=cm_to_inch(33.4), length=cm_to_inch(13.19)) default_figsize(width=cm_to_inch(33.4), length=cm_to_inch(11.8)) size_evolutionary() #default_figsize(width=cm_to_inch(33.4), length=cm_to_inch(13.19)) default_ticks_talks() default_lw_RAM_talks() gridr = grid_evolutionary() # 0.21 #grid1 = gridspec.GridSpecFromSubplotSpec(2, 2, grid[0], hspace=0.35, # wspace=0.35) # , # axos = [] # axds = [] name = [' P-unit_talk', ' Ampullary_talk'] for c, cell in enumerate(cells_plot2): print(cell) cell_type, eod_fr, fr, frs_calc, isi, spikes, spikes_all = get_base_params(cell, cell_type_type, frame) # embed() # 
stack_final ims = [] add_here = '_cell' + cell # str(c) # axo2 = None # axd2 = None mats = [] zorders = [100, 50] if c == 1: y_label = True else: y_label = True for s, save_name in enumerate(save_names): load_name = load_folder_name('calc_RAM') + '/' + save_name + '_' + cell stack = load_data_susept(load_name + '.pkl', load_name, add=add_here, load_version='csv', cells=cells_plot2) if len(stack) > 0: files, stack = exclude_cut_filenames(cell_type, stack, fexclude=True) stack_file = stack[stack['file_name'] == files[0]] amps = stack_file['amp'].unique() predefined_amp = True if predefined_amp: amps_defined = amps_desired else: amps_defined = amps trues = [] for amp in amps_defined: if amp in amps: trues.append(True) # if len(trues) < len(amps): amps_defined = [np.min(amps)] #, np.max(amps) # embed() cells_amp = ['2017-10-25-am-invivo-1', '2010-11-26-an-invivo-1'] if cell == cells_amp: print('cell thing') embed() ims = [] for aa, amp in enumerate(amps_defined): mat, stack_final = load_stack_data_susept(cell, save_name=version_final(), end='') if amp in np.array(stack_file['amp']): print(zorders[aa]) ax = plt.subplot(gridr[c*3]) colors = colors_overview() #axx.set_title(cell_types_name[c], color=colors[cell_type_here]) ax.set_title(cell_types_name[c], color = colors_overview()[name[c]]) cbar, fr, mat, im = plt_single_matrix(ax, stack_final, y_label = y_label, ls=None) #if c == 0: #cbar.set_label('') #set_clim_same(ims, mats=mats, lim_type='up', percnr=95) #if c == 1: # remove_yticks(ax) ################################# # overveiw ################################### ############################### # Das ist der Finale Score # 'max(diag5Hz)/med_diagonal_proj_fr','max(diag5Hz)/med_diagonal_proj_fr_base_w_burstcorr', ################################### # scores = [scoreall+'_diagonal_proj'] ########################## # Auswahl: wir nehmen den mean um nicht Stimulus abhängigen Noise rauszumitteln # save_names = [] save_names2 = 
['calc_RAM_overview-_simplified_noise_data12_nfft0.5sec_original__StimPreSaved4__direct_'] # save_names = ['calc_RAM_overview-_simplified_noise_data12_nfft0.5sec_original__StimPreSaved4__abs_'] ##################################################### # grid_lower_lower = gridspec.GridSpecFromSubplotSpec(1, 2, grid0[1], wspace = 0.5, hspace=0.55)#, height_ratios = [1,3] species = ' Apteronotus leptorhynchus' burst_fraction = [1, 1] # ,1,1] burst_corr_reset = 'burst_fraction_burst_corr_individual_stim' redo = False # embed() counter = 0 tags = [] frame_load_sp = load_overview_susept(save_names2[0], redo=redo, redo_class=redo) scores = ['max(diag5Hz)/med_diagonal_proj_fr', 'max(diag5Hz)/med_diagonal_proj_fr', ] # + '_diagonal_proj' max_xs = [[[], [], []], [[], [], []]] for c, cell_type_here in enumerate(cell_types[0:len(cells_plot2)]): frame_file = setting_overview_score(frame_load_sp, cell_type_here, min_amp='range', species=species) # embed() test = False # ok das schließe ich aus weil da irgendwas in der Detektion ist, das betrifft jetzt genau 3 Zellen, also nicht so schlimm # 63 2018-08-14-af-invivo-1 # 241 2018-09-05-aj-invivo-1 # 252 2022-01-08-ah-invivo-1 frame_file = frame_file[frame_file.cv_stim < 5] if test: frame_file[frame_file.cv_base > 3].cell frame_file[frame_file.cv_stim > 3].cv_stim # frame_file.groupby('cell').count() frame_file.groupby('cell').groups.keys() frame_file.group_by('cell') len(frame_file.cell.unique()) ############################################## # modulatoin comparison for both cell_types ################################ # Modulation, cell type comparison var_it = 'Response Modulation [Hz]' x_axis = ['cv_stim', 'cv_base', 'response_modulation'] # ,'fr_base']# var_item_names = [var_it, var_it] # ,var_it2]#['Response Modulation [Hz]',] var_types = [''] # ,'response_modulation','']#,'']#'response_modulation' max_x = max_xs[c] x_axis_names = [x_axis_talk(), 'CV$_{stim}$', 'Response Modulation [Hz]'] # $'+basename()+'$,'Fr$'+basename()+'$',] 
# score = scores[0] score = scores[c] scores_here = [score, score, score] # ,score] score_name = [nonlinearity_name_talk(), NLI_scorename2(), NLI_scorename2()] # NLI_scorename()] # 'Fr/Med''Perc99/Med' ax_j = [] axls = [] axss = [] # embed() # frame_max = frame_file[frame_file[score]>5] log = '' # 'logall'#''#'logy','logall'True#False for v, var_type in enumerate(var_types): # ax = plt.subplot(grid0[1+v])#grid_lower[0, v] if c == 0: #axx, axy, axs, axls, axss, ax_j = get_grid_4(ax_j, axls, axss, gridr[1]) axs = plt.subplot(gridr[2]) axss.append(axs) if log == 'logy': ymin = 'no' else: ymin = 0 xmin = 0 xlimk = None labelpad = 0.5 # -1 fs, ms = size_talk_overview() cmap, _, y_axis = scatter_with_marginals_colorcoded(var_item_names[v], axs, cell_type_here, x_axis[v], frame_file, scores_here[v], ymin=ymin, xmin=xmin, burst_fraction_reset=burst_corr_reset, var_item=var_type, labelpad=labelpad, max_x=max_x[v], xlim=xlimk, x_pos=1, fs=fs, ms=ms, c=c, burst_fraction=burst_fraction[c], sides=False, color_text=colors_overview()[name[c]], ha='right', y_val=1.15, color_given=colors_overview()[name[c]], legend_spacing=0.1) print(cell_type_here + ' median ' + scores_here[v] + '' + str( np.nanmedian(frame_file[scores_here[v]]))) print(cell_type_here + ' max ' + x_axis[v] + '' + str(np.nanmax(frame_file[x_axis[v]]))) axs.set_xlim(0, 1.7) axs.set_ylim(0, 35) axs.set_ylabel(score_name[v]) axs.set_xlabel(x_axis_names[v], labelpad=labelpad) extra_lim = False if extra_lim: if (' P-unit' in cell_type_here) & ('cv' in x_axis[v]): axs.set_xlim(xlimk) if log == 'logy': axs.set_yscale('log') make_log_ticks([axs]) elif log == 'logall': axs.set_yscale('log') make_log_ticks([axs]) axs.set_xscale('log') make_log_ticks([axs]) counter += 1 save_visualization(pdf=True, individual_tag=cells_plot2[0]) def default_lw_RAM_talks(): plt.rcParams['lines.linewidth'] = 3 #plt.rcParams['axes.linewidth'] = 22 def diff_label(): return '$|\Delta f_{1} - \Delta f_{2}|$' def two_deltaf2_label(): return 
'$2|\Delta f_{2}|$' def two_deltaf1_label(): return '$2|\Delta f_{1}|$' def sum_label(): return '$|\Delta f_{1} + \Delta f_{2}|$' def deltaf2_label(): return '$|\Delta f_{2}|$' def deltaf1_label(): return '$|\Delta f_{1}|$'