susceptibility1/utils_paper.py
2024-04-27 22:44:48 +02:00

35777 lines
1.7 MiB

import ast
import csv
import warnings
from random import sample
import numpy
import seaborn as sns
from scipy.optimize import curve_fit
from scipy.signal import vectorstrength
from scipy.stats import alpha, gaussian_kde
from sklearn import metrics
from sklearn.linear_model import LinearRegression
from thunderfish.eventdetection import hist_threshold
# from utils_all import cr_spikes_mat, load_folder_name
warnings.filterwarnings("ignore", message="WARNING:root:MultiTag type relacs.stimulus.segment")
from scipy import optimize, stats
'''try:
from utils_all import default_settings, load_folder_name
except:
# das ist das gleiche wie drüber nur dass es einen anderen Namen hat
from utils_all_down import column2,find_mean_period, default_settings, load_folder_name, chose_mat_max_value, \
create_stimulus_SAM, \
default_settings, find_code_vs_not, load_folder_name, plt_peaks, \
plt_peaks_several, resave_small_files, \
restrict_cell_type, thresh_crossings, zenter_and_normalize'''
try:
from utils_all import *
except:
    # Fallback: utils_all_down provides the same helpers as utils_all under a different module name
from utils_all_down import *
try:
pass
except:
a = 0
try:
import nixio as nix
except:
print('nixio not there')
import numpy as np
import pandas as pd
import scipy
from IPython import embed
from matplotlib import gridspec, pyplot as plt, pyplot, ticker as ticker
import os
import matplotlib.mlab as ml
import matplotlib.gridspec as gridspec
from scipy.ndimage import gaussian_filter
from thunderfish import fakefish
try:
import rlxnix as rlx
except:
a = 5
try:
    from numba import jit
except ImportError:
    def jit(*args, **kwargs):
        """No-op stand-in for ``numba.jit`` used when numba is not installed.

        Supports both decorator forms accepted by numba:

        * ``@jit`` -- the function itself is passed as the only positional argument;
        * ``@jit(...)`` / ``@jit()`` -- options (or a signature string) are passed and a
          decorator is returned.

        In both cases the decorated function is returned unchanged.
        """
        if len(args) == 1 and not kwargs and callable(args[0]):
            # bare ``@jit`` usage: return the function untouched
            return args[0]

        def decorator_jit(func):
            return func

        return decorator_jit
import inspect
if 'cv_cell_types' not in inspect.stack()[-1][1]:
try:
from plotstyle import plot_style, plot_style as style, spines_params
except:
a = 5
import itertools as it
def plot_rec_stimulus(grid, transform_fact, stimulus, color1, time, counter, eod_fr, deltat, nfft, xlim=0.05, shift=0,
                      lw=0.5):
    """Plot a stimulus trace with its extracted envelope and compute the stimulus PSD.

    The trace goes into ``grid[0]``. An axes for the spectrum is created in
    ``grid[1]`` and returned together with the PSD, so the caller can draw it later.

    Parameters
    ----------
    grid : sequence of two subplot specs (trace, spectrum).
    transform_fact : float
        Factor applied to the shifted time axis (presumably 1000 for s -> ms; confirm).
    stimulus : array
        Stimulus samples; only the first ``len(stimulus)`` entries of ``time`` are used
        for the plotted time axis.
    color1 : colour of the stimulus trace. The curve returned by ``extract_am``
        (presumably the amplitude envelope) is drawn in red.
    time : array
        Time axis; divided by 1000 before being handed to ``extract_am``.
    counter : int
        Running counter; returned incremented by one.
    eod_fr, deltat, nfft :
        Forwarded to ``extract_am`` (sampling rate ``1 / deltat``) and ``calc_psd``.
    xlim : float
        Samples whose transformed time is ``>= xlim * transform_fact`` are not plotted.
    shift : float
        Offset subtracted from ``time`` before scaling.
    lw : float
        Line width of the stimulus trace.

    Returns
    -------
    tuple
        ``(counter + 1, trace_axes, psd_frequencies, psd_power, spectrum_axes)``.
    """
    ax_trace = plt.subplot(grid[0])

    window_end = xlim * transform_fact
    scaled_time = (time[0:len(stimulus)] - shift) * transform_fact
    visible = scaled_time < window_end
    stimulus_visible = stimulus[visible]

    envelope, _ = extract_am(stimulus, time / 1000, norm=False, extract='globalmax', sampling=1 / deltat,
                             eodf=eod_fr)
    envelope_visible = envelope[visible]
    time_visible = scaled_time[visible]

    ax_trace.plot(time_visible, stimulus_visible, color=color1, linewidth=lw)
    ax_trace.plot(time_visible, envelope_visible, color='red', linewidth=1)
    ax_trace.set_xlim(0, window_end)
    ax_trace.set_ylim(-1.2, 1.7)
    ax_trace.show_spines('lb')
    ax_trace.axhline(0, color='black', lw=0.5)
    ax_trace.set_xticks_blank()

    ax_spectrum = plt.subplot(grid[1])
    freqs, power = calc_psd(stimulus, deltat, nfft)
    ax_spectrum.set_xticks_blank()
    return counter + 1, ax_trace, freqs, power, ax_spectrum
def plot_lowpass(g_p_t, transform_fact, time, shift, v_dent_output, color1, deltat, fft_type, nfft, eod_fr,
                 extract=True, lw=0.5,
                 xlim=0.05):
    """Plot the dendritic low-pass output and return its power spectra.

    ``time_psd_calculation`` supplies the time axis, an extracted trace and two spectra.
    The trace is drawn by ``plt_time_arrays`` into ``g_p_t[0]``. An axes for the
    spectrum is created in ``g_p_t[1]`` with blank x ticks, and the caller draws into it.

    Returns
    -------
    tuple
        ``(trace_axes, ff, pp, spectrum_axes, ff_am, pp_am)`` -- the spectra exactly as
        returned by ``time_psd_calculation``.
    """
    psd_results = time_psd_calculation(deltat, eod_fr, extract, fft_type, nfft, shift,
                                       time, transform_fact, v_dent_output)
    freqs, freqs_am, power, power_am, time_axis, extracted_trace = psd_results

    ax_trace = plt_time_arrays(color1, g_p_t, lw, v_dent_output, extracted_trace, xlim, time_axis,
                               transform_fact=transform_fact)

    ax_spectrum = plt.subplot(g_p_t[1])
    ax_spectrum.set_xticks_blank()
    return ax_trace, freqs, power, ax_spectrum, freqs_am, power_am
def plt_time_arrays_here(color1, g_p_t, lw, time_here, transform_fact, v_dent_output, xlim):
    """Draw ``v_dent_output`` over ``time_here`` into the first cell of ``g_p_t``.

    Only the left and bottom spines are shown. The x range is ``[0, xlim * transform_fact]``
    and a thin black zero line is added.

    Returns
    -------
    The created axes.
    """
    trace_ax = plt.subplot(g_p_t[0])
    trace_ax.plot(time_here, v_dent_output, color=color1, linewidth=lw)
    trace_ax.show_spines('lb')
    x_upper = xlim * transform_fact
    trace_ax.set_xlim(0.0, x_upper)
    trace_ax.axhline(0, color='black', lw=0.5)
    return trace_ax
def model_sheme_only(grid_sheme, stimulus_length=5, a_fr=1, a_fe=0.2,
                     v_exp=1, exp_tau=0.1):
    """Draw only the three model-schematic panels (no traces, no spectra).

    ``grid_sheme`` is split into three stacked rows, filled with:

    * row 0 -- ``plot_sheme_nonlinearity``
    * row 1 -- ``plot_sheme_lowpass``
    * row 2 -- ``plot_sheme_IF``

    The parameters of model cell '2012-07-03-ak-invivo-1' (from
    ``models_big_fit_d_right.csv``) are used to build one stimulus with
    ``make_paramters`` at an emitter frequency 50 Hz above the model EODf.

    Parameters
    ----------
    grid_sheme : SubplotSpec to subdivide.
    stimulus_length, a_fr, a_fe : forwarded to ``make_paramters``.
    v_exp, exp_tau : forwarded to ``plot_sheme_IF``.
    """
    models = resave_small_files("models_big_fit_d_right.csv", load_folder='calc_model_core')
    model_params = models[models['cell'] == '2012-07-03-ak-invivo-1'].iloc[0]
    eod_fr = model_params.pop('EODf')
    deltat = model_params.pop("deltat")
    eod_fe = [eod_fr + 50]

    grey = 'grey'
    rows = gridspec.GridSpecFromSubplotSpec(3, 1,
                                            subplot_spec=grid_sheme, wspace=0.2, hspace=0.95)

    for mult_nr, _freq in enumerate(eod_fe):
        try:
            time, stimulus_rec, eod_fish_r, eod_fish_e, stimulus = make_paramters(
                stimulus_length, deltat, eod_fr, a_fr, a_fe, eod_fe, mult_nr)
        except:
            print('parameter thing6')
            embed()

        # half-wave rectified copy of the receiver signal
        rectified = eod_fish_r * 1
        rectified[rectified < 0] = 0
        _, _, _, _, _ = titles_EIF(eod_fish_r, rectified, grey,
                                   grey, mult_nr, eod_fr, eod_fe, stimulus,
                                   stimulus_rec, False)

        # schematic column: nonlinearity, dendritic low-pass, integrate-and-fire
        plot_sheme_nonlinearity(plt.subplot(rows[0]), grey)
        plot_sheme_lowpass(plt.subplot(rows[1]))
        plot_sheme_IF(plt.subplot(rows[2]), exp_tau, v_exp, '')
def model_and_data_vertical(nr_clim=10, many=False, width=0.005, row='no', HZ50=True, fs=8, hs=0.39, nffts=['whole'],
                            powers=[1], cells=["2013-01-08-aa-invivo-1"], col_desired=2, var_items=['contrasts'],
                            contrasts=[0], noises_added=[''], fft_i='forward', fft_o='forward', spikes_unit='Hz',
                            mV_unit='mV',
                            D_extraction_method=['additiv_visual_d_4_scaled'], internal_noise=['eRAM'],
                            external_noise=['eRAM'], level_extraction=['_RAMdadjusted'], cut_off2=300,
                            receiver_contrast=[1], dendrids=[''], ref_types=[''], adapt_types=[''], c_noises=[0.1],
                            c_signal=[0.9],
                            cut_offs1=[300], clims='all', restrict='restrict'):
    """Plot recorded and modelled matrices of three P-units together with the model schematic.

    Figure layout: a 1x2 grid. The right cell holds ``model_sheme_in_one``. The left
    cell is split into four rows:

    0. histograms from ``load_spikes`` for each recorded cell (x label 'EODf multiple'),
    1. the recorded matrices drawn by ``plt_data_up``,
    2. model matrices (``load_model_susept``) for ``repeats[0]`` = 30 repetitions,
    3. model matrices for ``repeats[1]`` = 100000 repetitions.

    The figure is saved with ``save_visualization(pdf=True)``.

    Parameters
    ----------
    nr_clim, clims : forwarded to ``set_clim_shared`` to choose shared colour limits.
    many : bool
        If ``== True``, the model EODf is appended to each model panel title.
    width : float
        Colorbar width forwarded to ``colorbar_outside`` and ``plt_data_up``.
    row : 'no' or int
        If 'no', a row count is derived with ``find_row_col``. The layout is
        afterwards fixed by ``row = 5`` regardless.
    HZ50 : bool
        If True, ``plt_50_Hz_noise`` is drawn into every model panel.
    fs : float
        Font size of the model panel titles.
    hs : float
        ``hspace`` for ``plt.subplots_adjust``; the sub-grids use a fixed 0.25.
    cells : list
        Only its length enters the layout computation. The plotted cells are fixed and
        re-derived from the loaded model frame.
    var_items : forwarded to ``titles_susept_names``.

    The remaining list-valued parameters (``D_extraction_method`` ... ``contrasts``) are
    iterated as a Cartesian product. Each combination selects one precomputed model file
    via ``save_ram_model``. These lists are only read, never mutated.
    """
    plot_style()
    default_settings(lw=0.5, column=2, length=8.5)
    stimulus_length = 1
    trials_nrs = [1]
    variant = 'sinz'
    mimick = 'no'
    cell_recording_save_name = ''
    trans = 1
    # number of stimulus repetitions; one model row per entry
    repeats = [30, 100000]
    _, _ = overlap_cells()
    cells_all = ['2012-07-03-ak-invivo-1',
                 '2018-05-08-ae-invivo-1',
                 '2011-10-25-ad-invivo-1']
    good_data = cells_all

    # Number of parameter combinations; only used to derive the layout below.
    aa = sum(1 for _ in it.product(
        cells, D_extraction_method, external_noise, internal_noise, powers, nffts, dendrids, cut_offs1, trials_nrs,
        c_signal, c_noises, ref_types, adapt_types, noises_added, level_extraction, receiver_contrast, contrasts))
    if row == 'no':
        col, row = find_row_col(np.arange(aa), col=col_desired)
    if row == 2:
        default_settings(column=2, length=7.5)
    elif row == 1:
        default_settings(column=2, length=4)
    # The layout below is fixed. Because of this override the row == 2 / row == 1
    # subplots_adjust branches further down are never taken.
    row = 5

    fig = plt.figure(figsize=(6.8, 7.5))
    # left cell: data and model panels; right cell: model schematic
    grid_orig = gridspec.GridSpec(1, 2, wspace=0.15, bottom=0.05,
                                  hspace=0.1, left=0.07, width_ratios=[4, 1.3], right=0.99,
                                  top=0.88)
    # rows of the left cell: histograms, data, model (repeats[0]), model (repeats[1])
    grid_lower = gridspec.GridSpecFromSubplotSpec(4, 1, grid_orig[0], wspace=0.05,
                                                  hspace=0.53, height_ratios=[0.2, 1, 1, 1])
    wr = [1, 1, 1]
    if row == 2:
        plt.subplots_adjust(bottom=0.067, wspace=0.45, top=0.81, hspace=hs, right=0.88,
                            left=0.075)
    elif row == 1:
        plt.subplots_adjust(bottom=0.1, wspace=0.45, top=0.81, hspace=hs, right=0.88,
                            left=0.075)
    else:
        plt.subplots_adjust(wspace=0.8, bottom=0.067, top=0.86, hspace=hs, right=0.88,
                            left=0.075)
    maxs = []
    mins = []
    ims = []
    perc05 = []
    perc95 = []
    iternames = [D_extraction_method, external_noise,
                 internal_noise, powers, nffts, dendrids, cut_offs1, trials_nrs, c_signal,
                 c_noises,
                 ref_types, adapt_types, noises_added, level_extraction, receiver_contrast, contrasts, ]
    nr = '2'
    # loop variable named `combination` so the builtin `all` is not shadowed
    for combination in it.product(*iternames):
        (var_type, stim_type_afe, stim_type_noise, power, nfft, dendrid, cut_off1, trial_nrs, c_sig, c_noise,
         ref_type, adapt_type, noise_added, extract, a_fr, a_fe) = combination
        hs = 0.25
        #################################
        # model cells: one row per entry of `repeats`
        ax_model = []
        for t, trials_stim in enumerate(repeats):
            grid_model = gridspec.GridSpecFromSubplotSpec(1, len(good_data), grid_lower[2 + t], hspace=hs,
                                                          width_ratios=wr)
            save_name = save_ram_model(stimulus_length, cut_off1, nfft, a_fe, stim_type_noise, mimick, variant,
                                       trials_stim, power,
                                       cell_recording_save_name, nr=nr, fft_i=fft_i, fft_o=fft_o, Hz=spikes_unit,
                                       mV=mV_unit, stim_type_afe=stim_type_afe, extract=extract,
                                       noise_added=noise_added,
                                       c_noise=c_noise, c_sig=c_sig, ref_type=ref_type, adapt_type=adapt_type,
                                       var_type=var_type, cut_off2=cut_off2, dendrid=dendrid, a_fr=a_fr,
                                       trials_nr=trial_nrs, trans=trans, zeros='ones')
            path = save_name + '.pkl'
            model = load_model_susept(path, cells_all, save_name)
            adapt_type_name, ref_type_name, dendrid_name, stim_type_noise_name = define_names(var_type,
                                                                                              stim_type_noise,
                                                                                              dendrid, ref_type,
                                                                                              adapt_type)
            # panel columns are ordered by the 'cv_stim' column of the loaded model frame
            cells_all = model.groupby('cv_stim').first().sort_values(by='cv_stim').cell
            for c, cell in enumerate(cells_all):
                print(c)
                ax = plt.subplot(grid_model[c])
                if len(model) > 0:
                    stim_type_noise_name2, stim_type_afe_name = stim_type_names(a_fe, c_sig, stim_type_afe,
                                                                                stim_type_noise_name)
                    suptitles, titles = titles_susept_names(a_fe, extract, noise_added, stim_type_afe_name,
                                                            stim_type_noise_name2, trials_stim, var_items,
                                                            var_type)
                    model_show = model[(model.cell == cell)]
                    new_keys = model_show.index.unique()
                    # column labels may be stored as strings or as the raw index values
                    try:
                        stack_plot = model_show[list(map(str, new_keys))]
                    except:
                        stack_plot = model_show[new_keys]
                    stack_plot = stack_plot.iloc[np.arange(0, len(new_keys), 1)]
                    stack_plot.columns = list(map(float, stack_plot.columns))
                    ax.set_xlim(0, 300)
                    ax.set_ylim(0, 300)
                    ax.set_aspect('equal')
                    ax.set_xticks_delta(100)
                    ax.set_yticks_delta(100)
                    model_cells = resave_small_files("models_big_fit_d_right.csv")
                    model_params = model_cells[model_cells['cell'] == cell]
                    if len(model_show) > 0:
                        noise_strength = model_params.noise_strength.iloc[0]
                        D = noise_strength
                        D_derived, var, cut_off = D_derive(model_show, save_name, c_sig, D=D, base='')
                        stack_plot = RAM_norm(stack_plot, trials_stim, D_derived)
                    if many == True:
                        titles = titles + ' Ef=' + str(int(model_params.EODf.iloc[0]))
                    color = title_color(cell)
                    print(color)
                    if t == 0:
                        panel_title = (titles
                                       + ' $fr_{B}$=' + str(int(np.round(model_show.fr.iloc[0])))
                                       + ' $fr_{S}$=' + str(int(np.round(model_show.fr_stim.iloc[0])))
                                       + 'Hz\n $cv_{B}$=' + str(np.round(model_show.cv.iloc[0], 2))
                                       + ' $cv_{S}$=' + str(np.round(model_show.cv_stim.iloc[0], 2))
                                       + ' $D_{sig}$=' + str(np.round(D_derived, 5))
                                       + ' s=' + str(np.round(model_show.ser_sum_stim.iloc[0], 2)))
                        ax.set_title(panel_title, fontsize=fs, color=color)
                    perc = ''
                    im = plt_RAM_perc(ax, perc, stack_plot)
                    ims.append(im)
                    maxs.append(np.max(np.array(stack_plot)))
                    mins.append(np.min(np.array(stack_plot)))
                    perc05.append(np.percentile(stack_plot, 5))
                    perc95.append(np.percentile(stack_plot, 95))
                    plt_triangle(ax, model_show.fr.iloc[0], np.round(model_show.fr_stim.iloc[0]), 300,
                                 model_show.eod_fr.iloc[0])
                    if HZ50:
                        plt_50_Hz_noise(ax, 300)
                    ax.set_aspect('equal')
                    # The returned colorbar geometry goes into its own name. Rebinding `width`
                    # here would change the colorbar width used for later panels.
                    cbar, left, bottom, _cbar_width, height = colorbar_outside(ax, im, fig, add=0, width=width)
                    if c == 0:
                        ax.set_ylabel(F2_xlabel())
                    else:
                        remove_yticks(ax)
                    if c == 2:
                        cbar.set_label(nonlin_title(), rotation=90, labelpad=10)
                    if t == 1:
                        ax.set_xlabel(F1_xlabel(), labelpad=20)
                    else:
                        remove_xticks(ax)
                    print(c)
                    ax_model.append(ax)
        model_sheme_in_one(grid_orig[1])
        #################################
        # data cells
        grid_data = gridspec.GridSpecFromSubplotSpec(1, len(good_data), grid_lower[1],
                                                     hspace=hs, width_ratios=wr)
        grid_isi = gridspec.GridSpecFromSubplotSpec(1, len(good_data), grid_lower[0],
                                                    hspace=hs, width_ratios=wr)
        frame = load_cv_base_frame(cells_all)
        ax_isi = []
        for f, cell in enumerate(cells_all):
            ax = plt.subplot(grid_data[f])
            # colorbar label only on the rightmost data panel (as c == 2 in the model rows)
            plot = f == 2
            ax_data = plt_data_up(cell, ax, fig, cells_all, cell_type='p-unit', cbar_label=plot, width=width)
            # y label on the leftmost column only, matching the model rows (c == 0);
            # previously this compared against len(cells) - 1, an unrelated parameter
            if f == 0:
                ax.set_ylabel(F2_xlabel())
            else:
                remove_yticks(ax)
            remove_xticks(ax)
            axi = plt.subplot(grid_isi[f])
            frame_cell = frame[(frame['cell'] == cell)]
            spikes = frame_cell.spikes.iloc[0]
            eod_fr = frame_cell.EODf.iloc[0]
            spikes_all, hists, frs_calc, cont_spikes = load_spikes(spikes, eod_fr)
            # each further histogram is drawn slightly more transparent
            alpha_start = 1
            for hh, h in enumerate(hists):
                axi.hist(h, bins=100, color='blue', alpha=float(alpha_start - 0.05 * hh))
            ax_isi.append(axi)
            axi.set_ylabel('Nr')
            if f == len(cells_all) - 1:
                axi.set_xlabel('EODf multiple')
        # NOTE(review): get_shared_x_axes().join is deprecated in newer Matplotlib
        # (Axes.sharex is the replacement) -- confirm against the pinned version.
        ax_isi[0].get_shared_x_axes().join(*ax_isi)
        end_name = (suptitles + ' a_fr=' + str(a_fr) + ' ' + ' dendride=' + str(dendrid_name)
                    + ' refractory period=' + str(ref_type_name) + ' adapt=' + str(adapt_type_name)
                    + ' ' + ' cutoff1=' + str(cut_off1) + ' stimulus length=' + str(stimulus_length)
                    + ' ' + ' power=' + str(power) + ' ' + restrict)
        end_name = cut_title(end_name, datapoints=120)
        plt.suptitle(end_name)
        set_clim_shared(clims, ims, maxs, mins, nr_clim, perc05, perc95)
        # NOTE(review): assumes plt_data_up returns the data-row axes (one per column),
        # so every row below has the same length for the transpose -- confirm.
        axes = np.array([np.array(ax_data), np.array(ax_model[0:int(len(ax_model) / 2)]),
                         np.array(ax_model[int(len(ax_model) / 2)::]), np.array(ax_isi)])
        fig.tag(np.transpose(axes), xoffs=-3, yoffs=1.2, minor_index=2)
        save_visualization(pdf=True)
def model_sheme_in_one(grid_sheme, time_transform=1000, ws=0.1, nfft=4096 * 6, stimulus_length=5, fft_type='mppsd',
                       a_fr=1, a_fe=0.2,
                       v_exp=1, exp_tau=0.1, counter=0, shift=0.25):
    """Draw the model schematic interleaved with example traces and power spectra in one column.

    ``grid_sheme`` is split 1x2 (width ratios 1:3). Only ``grid0[nrs[0]]`` (== ``grid0[1]``)
    is used; ``grid0[0]`` stays empty in this function. The used cell is divided into six
    rows that alternate schematic and trace/PSD pairs:

    * row 0: ``plot_sheme_nonlinearity``
    * row 1: stimulus trace and PSD (``plot_rec_stimulus``)
    * row 2: ``plot_sheme_lowpass``
    * row 3: dendritic low-pass output and PSD (``plot_lowpass``)
    * row 4: ``plot_sheme_IF``
    * row 5: membrane voltage / spikes and PSD (``plot_spikes``)

    The traces come from ``simulate2`` for model cell '2012-07-03-ak-invivo-1'
    (``models_big_fit_d_right.csv``), driven by the stimulus returned by
    ``do_withenoise_stimulus``. All PSD panels are then brought to a common range
    (``create_same_max`` / ``create_same_range``) and share their y axis.

    Parameters
    ----------
    grid_sheme : SubplotSpec the figure part is drawn into.
    time_transform : factor applied to the trace time axes (presumably 1000 -> ms; confirm).
    ws : ``wspace`` of the trace/PSD sub-grids.
    nfft, fft_type : spectral settings forwarded to the plotting helpers.
    stimulus_length : forwarded to ``make_paramters`` and ``do_withenoise_stimulus``.
    a_fr, a_fe : forwarded to ``make_paramters``.
    v_exp, exp_tau : exponential-IF parameters forwarded to ``simulate2`` and ``plot_sheme_IF``.
    counter : running counter threaded through ``plot_rec_stimulus``.
    shift : time offset forwarded to the trace plotting helpers.
    """
    # Model cell '2012-07-03-ak-invivo-1' is hard-coded. Translated historical notes on
    # candidate cells (indices into an older parameter table): 0 and 9 have a moderate
    # range; 0 and 10 occasionally spike; 9 shows a less clean presynaptic oscillation;
    # 5 looked good; cell 0 was preferred for its appearance despite occasional spikes.
    # Distance remarks: 2 very close, 8 closest, 0/1/3/7 far, 4/5/6 very far.
    load_name = 'models_big_fit_d_right.csv'
    models = resave_small_files("models_big_fit_d_right.csv", load_folder='calc_model_core')
    model_params = models[models['cell'] == '2012-07-03-ak-invivo-1'].iloc[0]
    # pop() removes these entries so they are not forwarded again via simulate2(**model_params)
    eod_fr = model_params.pop('EODf')
    deltat = model_params.pop("deltat")
    v_offset = model_params.pop("v_offset")
    # a single emitter frequency, 50 Hz above the receiver EODf
    eod_fe = [eod_fr + 50]
    # colours of the remaining rows (all grey)
    color_p3 = 'grey'
    color_p1 = 'grey'
    color_diagonal = 'grey'
    colors = [color_diagonal, color_p1, color_p1, color_p3]
    # per-column axes slots; entries are replaced by index, never mutated in place
    ax_rec = [[]] * 4
    ax_low = [[]] * 4
    axt_IF2 = []
    delta_f = [50]
    counter_here = 0
    nrs = [1, 2, 3, 4]
    grid0 = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=grid_sheme, width_ratios=[1, 3], wspace=0.35)
    # NOTE(review): dead code -- `pass` cannot raise; leftover of a removed grid construction
    try:
        pass
    except:
        print('grid thing5')
        embed()
    lw = 0.5
    wr = [1, 1.2]
    xlim = 0.065
    axps = []
    axps_lowpass = []
    axps_stimulus = []
    pps = []
    pps_lowpass = []
    pps_stimulus = []
    colors_chosen = []
    counter_g = 0
    for mult_nr in range(len(eod_fe)):
        try:
            time, stimulus_rec, eod_fish_r, eod_fish_e, stimulus = make_paramters(
                stimulus_length, deltat, eod_fr, a_fr, a_fe, eod_fe, mult_nr)
        except:
            print('parameter thing5')
            embed()
        colorful_title = False
        # (turning this into a loop over stimuli would require several changes)
        # half-wave rectified copy of the receiver signal
        eod_fish_r_rec = eod_fish_r * 1
        eod_fish_r_rec[eod_fish_r_rec < 0] = 0
        add_pos, color_add_pos, titles, stimuli, eod_fish_rs = titles_EIF(eod_fish_r, eod_fish_r_rec, color_p1,
                                                                          color_p3, mult_nr, eod_fr, eod_fe, stimulus,
                                                                          stimulus_rec, colorful_title)
        stimulus_here = do_withenoise_stimulus(deltat, eod_fr, stimulus_length)
        titles = [titles[1]]
        g = 0
        color = colors[counter_here]
        # six rows: schematic / trace+PSD, three times
        grid_power_col = gridspec.GridSpecFromSubplotSpec(6, 1,
                                                          subplot_spec=grid0[nrs[counter_here]],
                                                          height_ratios=[0.7, 1, 0.7, 1, 0.7, 1], wspace=0.45,
                                                          hspace=0.65)
        # row 1: stimulus trace and its PSD
        grid_lowpass = gridspec.GridSpecFromSubplotSpec(1, 2,
                                                        subplot_spec=grid_power_col[1], wspace=ws, hspace=1.3,
                                                        width_ratios=wr)
        counter, ax_rec[counter_here], ff, pp, axp = plot_rec_stimulus(grid_lowpass, time_transform,
                                                                       stimulus_here, color, time, counter, eod_fr,
                                                                       deltat,
                                                                       nfft, shift=shift,
                                                                       lw=lw,
                                                                       xlim=xlim)
        pps_stimulus.append(pp)
        axps_stimulus.append(axp)
        colors_chosen.append(color)
        if counter_here == 0:
            ax_rec[counter_here].text(-7, 0, '0', color='black', ha='center', va='center')
        ax_rec[counter_here].text(add_pos[g], 1.1, titles[g],
                                  transform=ax_rec[counter_here].transAxes, )
        # row 0: nonlinearity schematic
        axsheme = plt.subplot(grid_power_col[0])
        plot_sheme_nonlinearity(axsheme, color_p1)
        # simulate the model (plain IF, exponential='')
        exponential = ''
        manual_offset = False
        if manual_offset:
            spike_times, v_dent_output, v_mem_output = simulate2(load_name,
                                                                 v_offset,
                                                                 eod_fish_rs[g], deltat=deltat,
                                                                 exponential=exponential, v_exp=v_exp,
                                                                 exp_tau=exp_tau,
                                                                 **model_params)
            print('Firing rate baseline ' + str(len(spike_times) / stimulus_length))
        spike_times, v_dent_output, v_mem_output = simulate2(load_name,
                                                             v_offset,
                                                             stimulus_here, deltat=deltat, exponential=exponential,
                                                             v_exp=v_exp, exp_tau=exp_tau,
                                                             **model_params)
        # row 3: dendritic low-pass output and its PSD
        grid_lowpass = gridspec.GridSpecFromSubplotSpec(1, 2,
                                                        subplot_spec=grid_power_col[3], width_ratios=wr, wspace=ws,
                                                        hspace=1.3)
        ax_low[counter_here], ff, pp, axp_p2, ff_am, pp_am = plot_lowpass(grid_lowpass, time_transform, time,
                                                                          shift, v_dent_output, color, deltat,
                                                                          fft_type, nfft, eod_fr,
                                                                          xlim=xlim, lw=lw)
        pps_lowpass.append(pp)
        axps_lowpass.append(axp_p2)
        colors_chosen.append(color)
        if counter_here == 0:
            ax_low[counter_here].text(-7, 0, '0', color='black', ha='center', va='center')
        # row 2: low-pass schematic
        axsheme = plt.subplot(grid_power_col[2])
        plot_sheme_lowpass(axsheme)
        # row 5: membrane voltage / spikes and spike-train PSD
        grid_lowpass = gridspec.GridSpecFromSubplotSpec(1, 2, width_ratios=wr,
                                                        subplot_spec=grid_power_col[5], wspace=ws,
                                                        hspace=1.45)
        axt_IF, axp_IF, ff, pp, axp_s, pp_s = plot_spikes(grid_lowpass, time_transform, v_mem_output, time, color,
                                                          spike_times, shift, deltat, fft_type, nfft, eod_fr,
                                                          xlim=xlim,
                                                          exponential=exponential,
                                                          counter_here=counter_here)
        axps.append(axp_s)
        pps.append(pp_s)
        colors_chosen.append('black')
        axt_IF2.append(axt_IF)
        if g == 0:
            # row 4: integrate-and-fire schematic
            axsheme = plt.subplot(grid_power_col[4])
            plot_sheme_IF(axsheme, exp_tau, v_exp, exponential)
        counter_g += 1
        counter_here += 1
    ####################################
    # Power spectra on a shared range.
    # Cut the first frequency bins: otherwise the dip at the beginning distorts the range.
    ff, pps_stimulus, pps_lowpass, pps = cut_first_parts(ff, pps_stimulus, pps_lowpass, pps, ll=0)
    # Take the log and give all spectra the same range. Some cells spike even without
    # thresholding, so their noise floor is kept down to hide peaks in the noise.
    pp3_stimulus = create_same_max(np.concatenate([pps_stimulus, pps_lowpass]), same=True)
    axps_stimulus = np.concatenate([axps_stimulus, axps_lowpass, ])
    pps3 = create_same_max(pps, same=True)
    pp3 = create_same_range(np.concatenate([pp3_stimulus, pps3]))
    axps_stimulus = np.concatenate([axps_stimulus, axps])
    # Order of axps_stimulus / pp3: stimulus PSDs, low-pass PSDs, spike PSDs.
    colors = [color_diagonal, color_p1, color_p1, color_p3,
              color_diagonal, color_p1, color_p1, color_p3,
              color_diagonal, color_p1, color_p1, color_p3, ]
    plot_points = [[], 'yes', [], 'yes',
                   [], 'yes', [], 'yes',
                   [], 'yes', [], 'yes',
                   [], 'yes', [], 'yes', ]
    for a, axp in enumerate(axps_stimulus):
        lw_p = 0.8
        # NOTE(review): delta_f is a list; `delta_f / eod_fr` only works because eod_fr is
        # presumably a numpy scalar taken from a pandas row -- confirm.
        plot_power_common_lim(axp, pp3[a], ff / eod_fr, colors[a], lw_p, plot_points[a], delta_f / eod_fr)
        # NOTE(review): with a single eod_fe entry there are only 3 PSD axes, so this
        # branch never fires and no dB scalebar is drawn.
        if a % 4 == 3:
            axp.yscalebar(1, 0.5, 20, 'dB', va='center', ha='right')
    # NOTE(review): get_shared_y_axes().join is deprecated in newer Matplotlib
    # (Axes.sharey is the replacement) -- confirm against the pinned version.
    axps_stimulus[0].get_shared_y_axes().join(*axps_stimulus)
def model_sheme(grid_sheme, time_transform=1000, ws=0.1, nfft=4096 * 6, stimulus_length=5, fft_type='mppsd', a_fr=1,
                a_fe=0.2,
                v_exp=1, exp_tau=0.1, counter=0, shift=0.25):
    """Draw the model schematic next to example traces and power spectra.

    ``grid_sheme`` is split 1x2 (width ratios 1:3):

    * left cell (``grid0[0]``): three schematic rows -- ``plot_sheme_nonlinearity``,
      ``plot_sheme_lowpass``, ``plot_sheme_IF``. The parameter name ``grid_sheme`` is
      rebound to this 3x1 sub-grid.
    * right cell (``grid0[nrs[0]]`` == ``grid0[1]``): three trace/PSD rows -- the stimulus
      (``plot_rec_stimulus``), the dendritic low-pass output (``plot_lowpass``) and the
      membrane voltage / spikes (``plot_spikes``).

    The traces come from ``simulate2`` for model cell '2012-07-03-ak-invivo-1'
    (``models_big_fit_d_right.csv``), driven by the stimulus returned by
    ``do_withenoise_stimulus``. All PSD panels are then brought to a common range and
    share their y axis.

    Parameters
    ----------
    grid_sheme : SubplotSpec the figure part is drawn into.
    time_transform : factor applied to the trace time axes (presumably 1000 -> ms; confirm).
    ws : ``wspace`` of the trace/PSD sub-grids.
    nfft, fft_type : spectral settings forwarded to the plotting helpers.
    stimulus_length : forwarded to ``make_paramters`` and ``do_withenoise_stimulus``.
    a_fr, a_fe : forwarded to ``make_paramters``.
    v_exp, exp_tau : exponential-IF parameters forwarded to ``simulate2`` and ``plot_sheme_IF``.
    counter : running counter threaded through ``plot_rec_stimulus``.
    shift : time offset forwarded to the trace plotting helpers.
    """
    # Model cell '2012-07-03-ak-invivo-1' is hard-coded. Translated historical notes on
    # candidate cells (indices into an older parameter table): 0 and 9 have a moderate
    # range; 0 and 10 occasionally spike; 9 shows a less clean presynaptic oscillation;
    # 5 looked good; cell 0 was preferred for its appearance despite occasional spikes.
    # Distance remarks: 2 very close, 8 closest, 0/1/3/7 far, 4/5/6 very far.
    load_name = 'models_big_fit_d_right.csv'
    models = resave_small_files("models_big_fit_d_right.csv", load_folder='calc_model_core')
    model_params = models[models['cell'] == '2012-07-03-ak-invivo-1'].iloc[0]
    # pop() removes these entries so they are not forwarded again via simulate2(**model_params)
    eod_fr = model_params.pop('EODf')
    deltat = model_params.pop("deltat")
    v_offset = model_params.pop("v_offset")
    # a single emitter frequency, 50 Hz above the receiver EODf
    eod_fe = [eod_fr + 50]
    # colours of the remaining rows (all grey)
    color_p3 = 'grey'
    color_p1 = 'grey'
    color_diagonal = 'grey'
    colors = [color_diagonal, color_p1, color_p1, color_p3]
    # per-column axes slots; entries are replaced by index, never mutated in place
    ax_rec = [[]] * 4
    ax_low = [[]] * 4
    axt_IF1 = []
    # NOTE(review): stays empty -- `exponentials` below has a single entry (ee == 0 only)
    axt_IF2 = []
    delta_f = [50]
    counter_here = 0
    nrs = [1, 2, 3, 4]
    # left cell: schematics; right cell: traces and spectra
    grid0 = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=grid_sheme, width_ratios=[1, 3], wspace=0.35)
    # Grid for the schematics (rebinds the grid_sheme parameter)
    try:
        grid_sheme = gridspec.GridSpecFromSubplotSpec(3, 1,
                                                      subplot_spec=grid0[0], wspace=0.2, hspace=0.95)
    except:
        print('grid thing2')
        embed()
    lw = 0.5
    wr = [1, 1.2]
    xlim = 0.065
    axps = []
    axps_lowpass = []
    axps_stimulus = []
    pps = []
    pps_lowpass = []
    pps_stimulus = []
    colors_chosen = []
    counter_g = 0
    for mult_nr in range(len(eod_fe)):
        try:
            time, stimulus_rec, eod_fish_r, eod_fish_e, stimulus = make_paramters(
                stimulus_length, deltat, eod_fr, a_fr, a_fe, eod_fe, mult_nr)
        except:
            print('parameter thing3')
            embed()
        colorful_title = False
        # (turning this into a loop over stimuli would require several changes)
        # half-wave rectified copy of the receiver signal
        eod_fish_r_rec = eod_fish_r * 1
        eod_fish_r_rec[eod_fish_r_rec < 0] = 0
        add_pos, color_add_pos, titles, stimuli, eod_fish_rs = titles_EIF(eod_fish_r, eod_fish_r_rec, color_p1,
                                                                          color_p3, mult_nr, eod_fr, eod_fe, stimulus,
                                                                          stimulus_rec, colorful_title)
        stimulus_here = do_withenoise_stimulus(deltat, eod_fr, stimulus_length)
        titles = [titles[1]]
        g = 0
        color = colors[counter_here]
        # three trace/PSD rows in the right cell
        grid_power_col = gridspec.GridSpecFromSubplotSpec(3, 1,
                                                          subplot_spec=grid0[nrs[counter_here]],
                                                          height_ratios=[1, 1, 1], wspace=0.45, hspace=0.5)
        # row 0: stimulus trace and its PSD
        grid_lowpass = gridspec.GridSpecFromSubplotSpec(1, 2,
                                                        subplot_spec=grid_power_col[0], wspace=ws, hspace=1.3,
                                                        width_ratios=wr)
        counter, ax_rec[counter_here], ff, pp, axp = plot_rec_stimulus(grid_lowpass, time_transform,
                                                                       stimulus_here, color, time, counter, eod_fr,
                                                                       deltat,
                                                                       nfft, shift=shift,
                                                                       lw=lw,
                                                                       xlim=xlim)
        pps_stimulus.append(pp)
        axps_stimulus.append(axp)
        colors_chosen.append(color)
        if counter_here == 0:
            ax_rec[counter_here].text(-7, 0, '0', color='black', ha='center', va='center')
        ax_rec[counter_here].text(add_pos[g], 1.1, titles[g],
                                  transform=ax_rec[counter_here].transAxes, )
        # matching schematic: nonlinearity
        if g == 0:
            axsheme = plt.subplot(grid_sheme[0])
            plot_sheme_nonlinearity(axsheme, color_p1)
        # remaining rows: dendritic filter and IF neuron (only the plain IF, exponential='')
        exponentials = ['']
        for ee, exponential in enumerate(exponentials):
            manual_offset = False
            if manual_offset:
                spike_times, v_dent_output, v_mem_output = simulate2(load_name,
                                                                     v_offset,
                                                                     eod_fish_rs[g], deltat=deltat,
                                                                     exponential=exponential, v_exp=v_exp,
                                                                     exp_tau=exp_tau,
                                                                     **model_params)
                print('Firing rate baseline ' + str(len(spike_times) / stimulus_length))
            spike_times, v_dent_output, v_mem_output = simulate2(load_name,
                                                                 v_offset,
                                                                 stimulus_here, deltat=deltat, exponential=exponential,
                                                                 v_exp=v_exp, exp_tau=exp_tau,
                                                                 **model_params)
            if ee == 0:
                # row 1: dendritic low-pass output and its PSD
                grid_lowpass = gridspec.GridSpecFromSubplotSpec(1, 2,
                                                                subplot_spec=grid_power_col[1], width_ratios=wr,
                                                                wspace=ws, hspace=1.3)
                ax_low[counter_here], ff, pp, axp_p2, ff_am, pp_am = plot_lowpass(grid_lowpass, time_transform,
                                                                                  time, shift, v_dent_output, color,
                                                                                  deltat,
                                                                                  fft_type, nfft, eod_fr,
                                                                                  xlim=xlim, lw=lw)
                pps_lowpass.append(pp)
                axps_lowpass.append(axp_p2)
                colors_chosen.append(color)
                if counter_here == 0:
                    ax_low[counter_here].text(-7, 0, '0', color='black', ha='center', va='center')
                if g == 0:
                    axsheme = plt.subplot(grid_sheme[1])
                    plot_sheme_lowpass(axsheme)
            # row ee + 2: membrane voltage / spikes and spike-train PSD
            grid_lowpass = gridspec.GridSpecFromSubplotSpec(1, 2, width_ratios=wr,
                                                            subplot_spec=grid_power_col[ee + 2], wspace=ws,
                                                            hspace=1.45)
            axt_IF, axp_IF, ff, pp, axp_s, pp_s = plot_spikes(grid_lowpass, time_transform, v_mem_output, time, color,
                                                              spike_times, shift, deltat, fft_type, nfft, eod_fr,
                                                              xlim=xlim, exponential=exponential,
                                                              counter_here=counter_here)
            axps.append(axp_s)
            pps.append(pp_s)
            colors_chosen.append('black')
            if ee == 0:
                axt_IF1.append(axt_IF)
            else:
                axt_IF2.append(axt_IF)
            if g == 0:
                axsheme = plt.subplot(grid_sheme[ee + 2])
                plot_sheme_IF(axsheme, exp_tau, v_exp, exponential)
        counter_g += 1
        counter_here += 1
    ####################################
    # Power spectra on a shared range.
    # Cut the first frequency bins: otherwise the dip at the beginning distorts the range.
    ff, pps_stimulus, pps_lowpass, pps = cut_first_parts(ff, pps_stimulus, pps_lowpass, pps, ll=0)
    # Take the log and give all spectra the same range. Some cells spike even without
    # thresholding, so their noise floor is kept down to hide peaks in the noise.
    pp3_stimulus = create_same_max(np.concatenate([pps_stimulus, pps_lowpass]), same=True)
    axps_stimulus = np.concatenate([axps_stimulus, axps_lowpass, ])
    pps3 = create_same_max(pps, same=True)
    pp3 = create_same_range(np.concatenate([pp3_stimulus, pps3]))
    axps_stimulus = np.concatenate([axps_stimulus, axps])
    # Order of axps_stimulus / pp3: stimulus PSDs, low-pass PSDs, spike PSDs.
    colors = [color_diagonal, color_p1, color_p1, color_p3,
              color_diagonal, color_p1, color_p1, color_p3,
              color_diagonal, color_p1, color_p1, color_p3, ]
    plot_points = [[], 'yes', [], 'yes',
                   [], 'yes', [], 'yes',
                   [], 'yes', [], 'yes',
                   [], 'yes', [], 'yes', ]
    for a, axp in enumerate(axps_stimulus):
        lw_p = 0.8
        # NOTE(review): delta_f is a list; `delta_f / eod_fr` only works because eod_fr is
        # presumably a numpy scalar taken from a pandas row -- confirm.
        plot_power_common_lim(axp, pp3[a], ff / eod_fr, colors[a], lw_p, plot_points[a], delta_f / eod_fr)
        # NOTE(review): with a single eod_fe entry there are only 3 PSD axes, so this
        # branch never fires and no dB scalebar is drawn.
        if a % 4 == 3:
            axp.yscalebar(1, 0.5, 20, 'dB', va='center', ha='right')
    # NOTE(review): get_shared_y_axes().join is deprecated in newer Matplotlib
    # (Axes.sharey is the replacement) -- confirm against the pinned version.
    axps_stimulus[0].get_shared_y_axes().join(*axps_stimulus)
def model_sheme_vertical(grid_sheme_orig, time_transform=1000, ws=0.1, nfft=4096 * 6, stimulus_length=5,
                         fft_type='mppsd', a_fr=1, a_fe=0.2,
                         v_exp=1, exp_tau=0.1, counter=0, shift=0.25):
    """Draw the one-row model scheme (stimulus -> dendritic low-pass -> spiking) with spectra.

    The row is laid out in a 1 x 6 sub-grid of ``grid_sheme_orig``: columns 0, 2 and 4
    hold small scheme panels (nonlinearity, low-pass, integrate-and-fire) and columns
    1, 3 and 5 hold a trace/power-spectrum pair for the corresponding stage.  The model
    cell '2012-07-03-ak-invivo-1' from "models_big_fit_d_right.csv" is simulated with
    ``simulate2``, driven by a fake EOD (``fakefish.wavefish_eods``) whose amplitude is
    modulated by the noise 'gwn300Hz50s0.3' with depth 0.2.

    Parameters
    ----------
    grid_sheme_orig : subplot spec the whole row is placed into.
    time_transform : forwarded to the trace plotting helpers.
    ws : ``wspace`` of the trace/power sub-grids.
    nfft : FFT length forwarded to the spectrum helpers.
    stimulus_length : simulated duration; used with ``deltat`` for the time axis and as
        the divisor of the printed baseline firing rate (presumably seconds).
    fft_type : forwarded to ``plot_lowpass`` / ``plot_spikes``.
    a_fr, a_fe : amplitudes forwarded to ``make_paramters``.
    v_exp, exp_tau : exponential-term parameters forwarded to ``simulate2`` and
        ``plot_sheme_IF``.
    counter : panel counter forwarded to (and updated by) ``plot_rec_stimulus``.
    shift : forwarded to the trace plotting helpers.

    Draws into the current figure; returns None.
    """
    # need to reduce parameters
    # parameters = pd.read_csv("models_big_fit_d_right.csv", index_col=0)
    # load_name = "models_big_fit_d_right.csv"#"models_big_fit.csv"
    # parameters = pd.read_csv(load_name)
    # potential cells where the range is not too wide: 0, 9
    # problem 0, 10: spikes occasionally
    # problem 9: presynaptic oscillation does not look so nice
    # good cells: 5
    # use cell zero because it looks nicest; it spikes occasionally but we ignore that for now
    # good_cells = pd.read_csv("good_model_cells.csv", index_col=0)
    # 2 is apparently very close, that works!
    # 0,1,3,7 far away
    # 4,5,6 very far away
    # 8 is apparently the closest!
    # embed()
    # model_params = parameters.iloc[0]
    load_name = 'models_big_fit_d_right.csv'
    models = resave_small_files("models_big_fit_d_right.csv", load_folder='calc_model_core')
    deltat, eod_fr, model_params, v_offset = get_model_params(models, cell='2012-07-03-ak-invivo-1')
    eod_fe = [eod_fr + 50]  # eod_fr*1+50,, eod_fr * 2 + 50
    # REMAINING rows
    color_p3 = 'grey'  # 'red'#palette['red']
    color_p1 = 'grey'  # 'blue'#palette['blue']
    color_diagonal = 'grey'  # 'cyan'#palette['cyan']
    colors = [color_diagonal, color_p1, color_p1, color_p3]
    # NOTE: [[]] * 4 repeats the same list object; the slots are only ever reassigned, never mutated
    ax_rec = [[]] * 4
    ax_low = [[]] * 4
    axt_IF1 = []
    delta_f = [50]  # create_beat_corr(np.array([eod_fe[mult_nr] - eod_fr]), np.array([eod_fr]))[0]
    counter_here = 0
    # first row for the stimulus, and then three cols for the sheme, and the power 1 and power 3
    grid0 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=grid_sheme_orig, wspace=0.35)  # height_ratios=[1, 3]
    # Grid for the sheme
    lw = 0.5
    wr = [1, 1.2]
    hr = [1]
    xlim = 0.065
    # collected power-spectrum axes / values, normalised jointly after the loop
    axps = []
    axps_lowpass = []
    axps_stimulus = []
    pps = []
    pps_lowpass = []
    pps_stimulus = []
    colors_chosen = []
    for mult_nr in range(len(eod_fe)):
        try:
            time, stimulus_rec, eod_fish_r, eod_fish_e, stimulus = make_paramters(
                stimulus_length, deltat, eod_fr, a_fr, a_fe, eod_fe, mult_nr)
        except:
            print('parameter thing2')
            embed()
        colorful_title = False
        # turning this into a plain loop over stimuli would require several changes
        eod_fish_r_rec = eod_fish_r * 1
        # half-wave rectified copy of the receiver EOD
        eod_fish_r_rec[eod_fish_r_rec < 0] = 0
        add_pos, color_add_pos, titles, stimuli, eod_fish_rs = titles_EIF(eod_fish_r, eod_fish_r_rec, color_p1,
                                                                          color_p3, mult_nr, eod_fr, eod_fe, stimulus,
                                                                          stimulus_rec, colorful_title)
        sampling = 1 / deltat
        time_eod = np.arange(0, stimulus_length, deltat)
        # noise stimulus resampled onto the simulation time grid
        eod_interp, time_wn_cut, _ = load_noise('gwn300Hz50s0.3')
        eod_interp = interpolate(time_wn_cut, eod_interp,
                                 time_eod,
                                 kind='cubic')
        fake_fish = fakefish.wavefish_eods('Alepto', frequency=eod_fr,
                                           samplerate=sampling,
                                           duration=len(time_eod) / sampling,
                                           phase0=0.0, noise_std=0.00)
        # the actual model input: fake EOD amplitude-modulated by the noise (depth 0.2)
        stimulus_here = fake_fish * (1 + eod_interp * 0.2)
        titles = [titles[1]]
        g = 0
        color = colors[counter_here]
        hs = 0.2  # 1.3
        # A grid for a single POWER column
        grid_power_col = gridspec.GridSpecFromSubplotSpec(1, 6,
                                                          subplot_spec=grid0[0], width_ratios=[0.5, 1, 0.5, 1, 0.5, 1],
                                                          wspace=0.45, hspace=0.5)
        # FIRST Row: Rectified stimulus
        grid_lowpass = gridspec.GridSpecFromSubplotSpec(1, 2,
                                                        subplot_spec=grid_power_col[1], wspace=ws, hspace=hs,
                                                        width_ratios=wr, height_ratios=hr)
        counter, ax_rec[counter_here], ff, pp, axp = plot_rec_stimulus(grid_lowpass, time_transform,
                                                                        stimulus_here, color, time, counter, eod_fr,
                                                                        deltat,
                                                                        nfft, shift=shift,
                                                                        lw=lw,
                                                                        xlim=xlim)
        pps_stimulus.append(pp)
        axps_stimulus.append(axp)
        colors_chosen.append(color)
        if counter_here == 0:
            ax_rec[counter_here].text(-7, 0, '0', color='black', ha='center', va='center')
        ax_rec[counter_here].text(add_pos[g], 1.1, titles[g],
                                  transform=ax_rec[counter_here].transAxes, )  # verticalalignment='right',
        # And plot correspoding sheme
        axsheme = plt.subplot(grid_power_col[0])
        plot_sheme_nonlinearity(axsheme, color_p1)
        # REMAINING Rows: dendridic filter / LIF /EIF stimulus
        exponential = ''  # , 'EIF'
        # when enabled, a baseline run (no noise modulation) is simulated just to print its rate
        manual_offset = False
        if manual_offset:
            spike_times, v_dent_output, v_mem_output = simulate2(load_name,
                                                                 v_offset,
                                                                 eod_fish_rs[g], deltat=deltat,
                                                                 exponential=exponential, v_exp=v_exp,
                                                                 exp_tau=exp_tau,
                                                                 **model_params)
            print('Firing rate baseline ' + str(len(spike_times) / stimulus_length))
        spike_times, v_dent_output, v_mem_output = simulate2(load_name,
                                                             v_offset,
                                                             stimulus_here, deltat=deltat, exponential=exponential,
                                                             v_exp=v_exp, exp_tau=exp_tau,
                                                             **model_params)
        # SECOND Row: Dendridic Low pass filter
        grid_lowpass = gridspec.GridSpecFromSubplotSpec(1, 2,
                                                        subplot_spec=grid_power_col[3], height_ratios=hr,
                                                        width_ratios=wr, wspace=ws, hspace=hs)
        ax_low[counter_here], ff, pp, axp_p2, ff_am, pp_am = plot_lowpass(grid_lowpass, time_transform, time,
                                                                           shift, v_dent_output, color, deltat,
                                                                           fft_type, nfft, eod_fr,
                                                                           xlim=xlim, lw=lw)
        pps_lowpass.append(pp)
        axps_lowpass.append(axp_p2)
        colors_chosen.append(color)
        if counter_here == 0:
            ax_low[counter_here].text(-7, 0, '0', color='black', ha='center', va='center')
        axsheme = plt.subplot(grid_power_col[2])
        plot_sheme_lowpass(axsheme)
        # THIRD /FORTH Row: LIF /EIF
        # plot the voltage of the exponentials
        grid_lowpass = gridspec.GridSpecFromSubplotSpec(1, 2, height_ratios=hr, width_ratios=wr,
                                                        subplot_spec=grid_power_col[5], wspace=ws,
                                                        hspace=hs)
        axt_IF, axp_IF, ff, pp, axp_s, pp_s = plot_spikes(grid_lowpass, time_transform, v_mem_output, time, color,
                                                          spike_times, shift, deltat, fft_type, nfft, eod_fr, xlim=xlim,
                                                          exponential=exponential,
                                                          counter_here=counter_here)  # , add = add
        axps.append(axp_s)
        pps.append(pp_s)
        colors_chosen.append('black')
        axt_IF1.append(axt_IF)
        axsheme = plt.subplot(grid_power_col[4])
        plot_sheme_IF(axsheme, exp_tau, v_exp, exponential)
    # plot psd with shared log lim
    ####################################
    # cut first parts
    # because otherwise there is a dip at the beginning and thats a problem for the range thing
    # NOTE: `ff` is the frequency axis returned by the last plot_spikes call of the loop
    ff, pps_stimulus, pps_lowpass, pps = cut_first_parts(ff, pps_stimulus, pps_lowpass, pps, ll=0)
    # here I calculate the log and do the same range for all power spectra
    # this is kind of complicated but some cells spike even withouth thresholding and we want to keep their noise floor down
    # not to see the peaks in the noise
    pp3_stimulus = create_same_max(np.concatenate([pps_stimulus, pps_lowpass]), same=True)
    axps_stimulus = np.concatenate([axps_stimulus, axps_lowpass, ])
    pps3 = create_same_max(pps, same=True)
    pp3 = create_same_range(np.concatenate([pp3_stimulus, pps3]))
    # axes order now matches pp3: stimulus, low-pass, spikes
    axps_stimulus = np.concatenate([axps_stimulus, axps])
    # there are only few cells where the distance is not so high and this cells spike occationally but very randomly still we dont wanna se their power specturm
    # therefore we dont show it
    colors = [color_diagonal, color_p1, color_p1, color_p3,
              color_diagonal, color_p1, color_p1, color_p3,
              color_diagonal, color_p1, color_p1, color_p3, ]
    plot_points = [[], 'yes', [], 'yes',
                   [], 'yes', [], 'yes',
                   [], 'yes', [], 'yes',
                   [], 'yes', [], 'yes', ]
    for a, axp in enumerate(axps_stimulus):
        lw_p = 0.8
        # NOTE(review): delta_f is a list; `delta_f / eod_fr` only works when eod_fr is a numpy scalar
        plot_power_common_lim(axp, pp3[a], ff / eod_fr, colors[a], lw_p, plot_points[a], delta_f / eod_fr)
        if a % 4 == 3:
            axp.yscalebar(1, 0.5, 20, 'dB', va='center', ha='right')
    # NOTE(review): Grouper.join on get_shared_y_axes() is deprecated since matplotlib 3.6 and removed in 3.8
    axps_stimulus[0].get_shared_y_axes().join(*axps_stimulus)
def get_model_params(models, cell='2012-07-03-ak-invivo-1'):
    """Split the parameter row of one model cell into timing values and model parameters.

    Parameters
    ----------
    models : DataFrame with a ``'cell'`` column plus ``'EODf'``, ``'deltat'`` and
        ``'v_offset'`` columns; the first row whose ``cell`` equals *cell* is used.
    cell : name of the model cell to pick.

    Returns
    -------
    tuple
        ``(deltat, eod_fr, model_params, v_offset)``, where ``model_params`` is the
        selected row with ``'EODf'``, ``'deltat'`` and ``'v_offset'`` removed.
    """
    row = models.loc[models['cell'] == cell].iloc[0]
    # pop in the same order as before: EODf, deltat, v_offset
    popped = {name: row.pop(name) for name in ('EODf', 'deltat', 'v_offset')}
    return popped['deltat'], popped['EODf'], row, popped['v_offset']
def _share_y_axis_group(group):
    """Share the y-axis of every Axes in *group* with the first one (old and new matplotlib)."""
    leader = group[0]
    shared = leader.get_shared_y_axes()
    if hasattr(shared, 'join'):
        # matplotlib < 3.8 (Grouper.join deprecated since 3.6)
        shared.join(*group)
        return
    for other in group[1:]:
        if other is leader or shared.joined(leader, other):
            continue
        try:
            other.sharey(leader)
        except ValueError:
            # `other` already shares its y-axis with a different axes; limits are still equalised below
            pass


def share_yaxis(axes):
    """Link and equalise the y-limits inside each group of axes.

    Parameters
    ----------
    axes : iterable of sequences of matplotlib Axes
        Each inner sequence is one group.  All axes of a group get a shared y-axis and
        the common limits ``(min of all lower limits, max of all upper limits)``.
        Empty groups are skipped.
    """
    for group in axes:
        group = list(group)
        if not group:
            continue
        # read the limits before sharing: Axes.sharey copies the leader's limits
        lowers = [np.nanmin(a.get_ylim()) for a in group]
        uppers = [np.nanmax(a.get_ylim()) for a in group]
        _share_y_axis_group(group)
        common = (np.min(lowers), np.max(uppers))
        for a in group:
            a.set_ylim(*common)
def flowchart():
    """Plot noise / RAM-extraction traces and their spectra for the model cell 2013-01-08-aa-invivo-1.

    Loads the saved arrays 'RAM_extraction_a0_<var_type>.npy' and
    'RAM_extraction_a1_<var_type>.npy' from the 'calc_RAM' folder. For each array it
    plots the time trace (first 0.1 s) and the magnitude of its FFT. It also estimates
    the spectral level just around 0 Hz (``d_new_zero``) and annotates the ratio of the
    RAM and V_dent values.

    NOTE(review): this function cannot run as written. Several names are undefined here:
      * ``var_types`` (subplot count)
      * ``cut_off`` (``var_desired``)
      * ``grid_orig`` (sub-grid)
      * ``v`` (column offset)
    In addition, ``ax`` is reassigned to a single subplot inside the loop and then indexed
    as a 2-D array. It looks like a leftover from a multi-``var_type`` version. Confirm
    the intent before using it.
    """
    default_settings(column=2, length=7)
    cell = "2013-01-08-aa-invivo-1"
    model_cells = resave_small_files("models_big_fit_d_right.csv", load_folder='calc_model_core')
    model_params = model_cells[model_cells.cell == cell].iloc[0]
    noise_strength = model_params.noise_strength  # **2/2
    a_fr = 1  # ,0]#0,,0,]#,0 ,0 ] # ,0,]#0#1
    eod_fr = model_params['EODf']
    deltat = model_params.pop("deltat")
    cut_offs = [eod_fr / 2]
    var_type = 'additiv_cutoff_scaled'  # ]#'additiv_visual_d_4_scaled']
    # NOTE(review): `var_types` is not defined in this function (NameError)
    fig, ax = plt.subplots(6, (len(cut_offs) + 1) * len(var_types))  # , constrained_layout = True figsize=(12, 5),
    colors_title = [['black', 'purple', 'black', 'black', 'black', 'black'],
                    ['black', 'black', 'black', 'black', 'black', 'purple']]
    d_new_zeros = {}
    tags = []
    c_sig = 0.9
    c_noise = 0.1
    d_new_zeros[var_type] = []
    # NOTE(review): Windows-style '\R' separator; '\R' is not a valid escape, so Python keeps the
    # backslash literally (and warns). This path only resolves on Windows.
    arrays2 = np.load(load_folder_name('calc_RAM') + '\RAM_extraction_a1_' + var_type + '.npy', allow_pickle=True)
    arrays = np.load(load_folder_name('calc_RAM') + '\RAM_extraction_a0_' + var_type + '.npy', allow_pickle=True)
    titles = ['Noise', 'RAM', 'RAM*Carrier',
              'RAM*Carrier to RAM', 'V_dent', 'V_dent to RAM']
    colors = ['grey', 'red', 'grey', 'red', 'grey', 'red']
    max_f = 9999.875  # taken from np.fft.fft(noise)
    # NOTE(review): `cut_off` is undefined here (only `cut_offs` exists)
    var_desired = '$var_{desired} = $' + str(np.round(np.var(arrays[0]) * c_sig * cut_off / max_f, 5))
    D_desired = np.round(np.sqrt(noise_strength * 2 * c_sig), 5)
    plt.suptitle(
        '$Contrast_{receiver}=$' + str(a_fr) + ', $c_{noise}=$' + str(c_noise) + ', $c_{signal}=$' + str(
            c_sig) + ', ' + var_desired + r' $\sqrt{2D*c_{signal}}$=' + str(
            np.round(D_desired, 5)))
    # NOTE(review): `grid_orig` is undefined here
    grid = gridspec.GridSpecFromSubplotSpec(1, 4, grid_orig[0], wspace=1.2,
                                            hspace=0.13)
    for i in range(len(arrays)):
        sampling = 1 / deltat
        # NOTE(review): this overwrites the 2-D `ax` array from plt.subplots; the `ax[row, col]` indexing
        # below (and the undefined `v`) would then fail
        ax = plt.subplot(grid[i])
        # np.arange with a float step can yield one sample too many; drop it to match the data length
        if len(np.arange(0, len(arrays[i]) / sampling, 1 / sampling)) > len(arrays[i]):
            ax[0 + i, 0 + v * 2].plot(np.arange(0, len(arrays[i]) / sampling, 1 / sampling)[0:-1], arrays[i],
                                      color=colors[i])
        else:
            ax[0 + i, 0 + v * 2].plot(np.arange(0, len(arrays[i]) / sampling, 1 / sampling), arrays[i],
                                      color=colors[i])
        if len(arrays2[i]) > 0:
            if len(np.arange(0, len(arrays2[i]) * deltat, deltat)) > len(arrays2[i]):
                ax[0 + i, 0 + v * 2].plot(np.arange(0, len(arrays2[i]) * deltat, deltat)[0:-1], arrays2[i],
                                          color='red')
            else:
                ax[0 + i, 0 + v * 2].plot(np.arange(0, len(arrays2[i]) * deltat, deltat), arrays2[i],
                                          color='red')
        tags.append(ax[0 + i, 0 + v * 2])
        ax[0 + i, 0 + v * 2].set_title(titles[i] + ' var=' + str(np.round(np.var(arrays[i]), 5)),
                                       color=colors_title[v][i],
                                       fontsize=8)  # +' var/c='+str(np.round(np.var(arrays[i])/cs[i],5))
        ax[0 + i, 0 + v * 2].set_xlim(0, 0.1)
        # magnitude spectrum of the mean-free signal, sorted by frequency for plotting
        p_array_fft = np.fft.fft(arrays[i] - np.mean(arrays[i]), norm='forward')
        f = np.fft.fftfreq(len(arrays[i]), deltat)
        f_sorted = np.sort(f)
        p_sorted = np.abs(p_array_fft)[np.argsort(f)]
        ax[0 + i, 1 + v * 2].plot(f_sorted, p_sorted, color='grey')  # np.log10(p_noise / np.max(p_noise))
        # mean magnitude over the 10 bins just below the bin returned by argmin
        # NOTE(review): np.abs(f_sorted) - 0 is intended to find 0 Hz, but argmin of |f| - 0 ... verify the window
        left = np.argmin(np.abs(f_sorted) - 0) - 10
        left2 = np.argmin(np.abs(f_sorted) - 0)
        d_new_zero = np.mean(p_sorted[left:left2])
        ax[0 + i, 1 + v * 2].plot(f_sorted[left:left2], p_sorted[left:left2], color='blue')
        d_new_zeros[var_type].append(d_new_zero)
        ax[0 + i, 1 + v * 2].set_title('D close to 0 = ' + str(np.round(d_new_zero, 5)),
                                       color=colors_title[v][i])
        ax[0 + i, 1 + v * 2].set_xlim(-eod_fr / 2 * 1.2, eod_fr / 2 * 1.2)
        if i < len(arrays) - 1:
            remove_yticks(ax[0 + i, 0])
            remove_yticks(ax[0 + i, 1])
            remove_yticks(ax[0 + i, 0 + v * 2])
            remove_yticks(ax[0 + i, 1 + v * 2])
    # ratio of the near-0 Hz level of 'RAM' (index 1) to that of the last array ('V_dent to RAM')
    ax[0, 0 + v * 2].text(0, 1.6, var_type + ': $D_{RAM}/V_dent_{RAM} =$' + str(
        np.round(d_new_zeros[var_type][1] / d_new_zeros[var_type][-1], 3)),
                          transform=ax[0, 0 + v * 2].transAxes, color='purple')
    ax[-1, 0 + v * 2].set_xlabel('Time [s]')
    ax[-1, 1 + v * 2].set_xlabel('Frequency [Hz]')
    ax[-1, 1 + v * 2].set_ylabel('[Hz]')
    # NOTE(review): Grouper.join on get_shared_*_axes() is removed in matplotlib 3.8
    ax[0, 0 + v * 2].get_shared_x_axes().join(*np.concatenate([ax[:, 0], ax[:, 0 + v * 2]]))
    ax[0, 1 + v * 2].get_shared_x_axes().join(*np.concatenate([ax[:, 1], ax[:, 1 + v * 2]]))
    ax[2, 1 + v * 2].get_shared_y_axes().join(ax[2, 1], ax[4, 1], ax[2, 1 + v * 2], ax[4, 1 + v * 2])
    ax[1, 1 + v * 2].get_shared_y_axes().join(ax[0, 1], ax[1, 1], ax[3, 1], ax[5, 1], ax[0, 1 + v * 2],
                                              ax[1, 1 + v * 2], ax[3, 1 + v * 2], ax[5, 1 + v * 2])
    ax[2, 0 + v * 2].get_shared_y_axes().join(ax[2, 0], ax[4, 0], ax[2, 0 + v * 2], ax[4, 0 + v * 2])
    ax[1, 0 + v * 2].get_shared_y_axes().join(ax[1, 0], ax[3, 0], ax[5, 0], ax[1, 0 + v * 2], ax[3, 0 + v * 2],
                                              ax[5, 0 + v * 2])  # ax[0, 0],ax[0, 0+v*2],
# Recordings excluded by hand from the example-cell selection of find_cells().
# Cells with strong artefacts:
_FIND_CELLS_ARTEFACT_CELLS = (
    '2010-08-25-ab-invivo-1',
    '2010-11-08-aa-invivo-1',
    '2010-11-11-al-invivo-1',
    '2011-02-18-ab-invivo-1',
    '2011-09-21-ab-invivo-1',
    '2011-10-25-ac-invivo-1',
    '2011-11-10-ab-invivo-1',
    '2011-11-10-ag-invivo-1',
    '2012-12-19-aa-invivo-1',
    '2012-12-19-ab-invivo-1',
    '2012-12-19-ac-invivo-1',
    '2013-02-21-aa-invivo-1',
    '2013-04-09-ab-invivo-1',
    '2013-04-16-aa-invivo-1',
    '2013-04-16-ab-invivo-1',
    '2013-04-16-ac-invivo-1',
    '2013-04-17-af-invivo-1',
    '2013-04-18-ac-invivo-1',
)
_FIND_CELLS_EXCLUDED_WO_BASE = (
    '2010-11-08-ab-invivo-1', '2010-07-29-ae-invivo-1', '2011-11-10-ah-invivo-1',
    '2011-11-10-ak-invivo-1', '2012-05-30-aa-invivo-1', '2012-07-12-al-invivo-1',
    '2012-10-19-aa-invivo-1', '2012-10-19-ad-invivo-1', '2012-12-20-af-invivo-1',
    '2013-04-16-ad-invivo-1', '2013-04-11-ab-invivo-1', '2014-01-23-ac-invivo-1',
    '2014-01-16-aj-invivo-1',
)
_FIND_CELLS_EXCLUDED_WO_BASE_NOT_NICE = ('2010-11-26-al-invivo-1', '2010-11-11-aj-invivo-1')


def _select_cells_from_stack(stack_files, cells, sorting, max_cells=16):
    """Return up to *max_cells* unique cells of *stack_files* (stimulus > 3) minus the hand-excluded ones."""
    stack_cells = stack_files[stack_files['cell'].isin(cells)]
    cvs = np.array(stack_cells[sorting])
    cell_names = np.array(stack_cells.cell)
    # keep only recordings whose stimulus_length exceeds 3
    long_enough = np.array(stack_cells['stimulus_length'] > 3)
    cell_names, _ = make_cell_unique(cvs[long_enough], cell_names[long_enough])
    selected = list(cell_names)
    excluded = (_FIND_CELLS_ARTEFACT_CELLS + _FIND_CELLS_EXCLUDED_WO_BASE
                + _FIND_CELLS_EXCLUDED_WO_BASE_NOT_NICE)
    for cell_rem in excluded:
        if cell_rem in selected:
            selected.remove(cell_rem)
    return selected[0:max_cells]


def find_cells(file_names_exclude, sorting, cells_chosen, cell_type, cell_type_type, cell_type_chosen, load_path):
    """Choose the cells to plot for one cell class.

    Candidate cells are those of the baseline (CV) table whose ``cell_type_type`` column
    equals ``cell_type_chosen``.  They are then matched against the stack CSV at
    ``load_path`` (with ``.csv`` appended if missing), excluding files listed in
    ``file_names_exclude``.

    Parameters
    ----------
    file_names_exclude : file names dropped from the stack.
    sorting : stack column passed to ``make_cell_unique`` together with the cell names.
    cells_chosen : explicit cell list; used instead of the automatic selection for
        P-units when non-empty.
    cell_type : cell-type selector; ``'p-unit' in cell_type`` or
        ``cell_type == [' A-unit', ' Ampullary']`` trigger the automatic selection.
    cell_type_type : column of the CV table holding the cell type.
    cell_type_chosen : value of that column to keep.
    load_path : path of the stack CSV (with or without the ``.csv`` suffix).

    Returns
    -------
    list or array of cell names (at most 16 when selected automatically; otherwise the
    baseline candidates or ``cells_chosen``).
    """
    frame_base = load_cv_table()
    frame_base = unify_cell_names(frame_base, cell_type=cell_type_type)
    frame_base = frame_base[frame_base[cell_type_type] == cell_type_chosen]
    cells = frame_base.cell.unique()

    csv_path = load_path if '.csv' in load_path else load_path + '.csv'
    stack = pd.read_csv(csv_path)
    stack_files = stack[~stack['file_name'].isin(file_names_exclude)]

    if 'p-unit' in cell_type:
        if len(cells_chosen) == 0:
            cells = _select_cells_from_stack(stack_files, cells, sorting)
        else:
            cells = cells_chosen
    elif cell_type == [' A-unit', ' Ampullary']:
        cells = _select_cells_from_stack(stack_files, cells, sorting)
    return cells
def load_cell_types(file_name_exclude, sorting, load_path, cell_type, cells_chosen=None,
                    cell_type_chosen=' Ampullary',
                    cell_type_type='cell_type_reclassified'):
    """Load the susceptibility stack and select the cells used by the P-unit/ampullary figures.

    Parameters
    ----------
    file_name_exclude : file names to exclude (forwarded to ``find_cells``).
    sorting : column used for ordering/uniqueness in ``find_cells``.
    load_path : stack path without the ``.csv`` suffix.
    cell_type : values of the stack's ``'celltype'`` column to keep.
    cells_chosen : optional explicit cell list (``None`` means "select automatically").
    cell_type_chosen, cell_type_type : forwarded to ``find_cells``.

    Returns
    -------
    stack_files : rows of the loaded stack whose ``'celltype'`` is in ``cell_type``.
    cells : unique cells of the loaded stack.
        NOTE(review): taken from the unfiltered stack, not from ``stack_files``.
    """
    if cells_chosen is None:
        cells_chosen = []
    if os.path.exists(load_path + '.csv'):
        # only load the cells we actually want
        cells = find_cells(file_name_exclude, sorting, cells_chosen, cell_type, cell_type_type, cell_type_chosen,
                           load_path)
        stack = load_data_susept(load_path + '.csv', load_path, cells=cells)
    else:
        # summary CSV not saved yet: load without a cell restriction
        stack = load_data_susept(load_path + '.csv', load_path)
    stack_files = stack[stack['celltype'].isin(cell_type)]
    cells = stack.cell.unique()
    return stack_files, cells
def colorbar_outside_right(ax, fig, im, shrink=0.6, width=0.02, plusx=0.01):
    """Create a vertical colorbar for *im* in a new axes placed relative to *ax*.

    The new axes has ``width`` and the same height as *ax*. Its left edge is at
    ``ax.x0 + plusx`` and its bottom edge at ``ax.y1`` (figure coordinates).

    NOTE(review): despite the name, this places the colorbar above *ax*, aligned to its
    left edge, not to its right. Callers presumably rely on this placement.
    ``shrink`` is forwarded to ``fig.colorbar``, which presumably ignores it when
    ``cax`` is given; verify.

    Returns
    -------
    cbar, left, bottom, width, height
    """
    (x_lo, y_lo), (_, y_hi) = np.array(ax.get_position())
    left = x_lo + plusx
    bottom = y_hi
    height = y_hi - y_lo
    colorbar_axes = fig.add_axes([left, bottom, width, height])
    colorbar_axes.xaxis.set_label_position('bottom')
    colorbar_axes.set_xticklabels(colorbar_axes.get_xticklabels(), rotation='vertical')
    colorbar_axes.tick_params(labelsize=6)
    cbar = fig.colorbar(im, orientation="vertical", cax=colorbar_axes, shrink=shrink)
    return cbar, left, bottom, width, height
def plt_cv_part(cell, frame_save, frame, cell_nr, ax, lim_here=None, color_bar='grey', xlim=(0, 17)):
    """Draw the stored ISI histogram of one cell as a bar plot.

    Parameters
    ----------
    cell : cell name; selects the rows of ``frame_save``. ``hist[0]`` of the first
        matching row's ``'hist'`` entry is used as ``(counts, bin_edges)``.
    frame_save : DataFrame with ``'cell'`` and ``'hist'`` columns.
    frame : DataFrame with a ``'cv'`` column; row ``cell_nr`` gives the returned CV.
    cell_nr : positional row index into ``frame``.
    ax : matplotlib Axes to draw into.
    lim_here : optional lower limit. When given (anything but ``None`` or ``[]``), only
        bins whose left edge lies above it are drawn.
    color_bar : bar colour.
    xlim : x-limits applied to ``ax`` when truthy.

    Returns
    -------
    color : colour from ``title_color(cell)``.
    cv_isi : the ``cv`` value of ``frame.iloc[cell_nr]``.
    """
    cv_isi = frame.iloc[cell_nr].cv
    color = title_color(cell)
    frame_here = frame_save[frame_save.cell == cell]
    try:
        hist = frame_here['hist'].iloc[0][0]
    except Exception:
        print('hist problem')
        embed()
    counts = np.asarray(hist[0])
    edges = np.asarray(hist[1])
    width = edges[1] - edges[0]
    no_limit = lim_here is None or (isinstance(lim_here, list) and len(lim_here) == 0)
    if no_limit:
        x = edges[:-1] + width / 2
        y = counts
    else:
        above = edges > lim_here
        y = counts[above[:-1]]
        # centre the bars on the bins (ax.bar aligns on centre), consistent with the branch above;
        # previously the bars were centred on the left bin edges here
        x = edges[above][:-1] + width / 2
    ax.bar(x, height=y, width=width, color=color_bar)
    if xlim:
        ax.set_xlim(xlim)
    return color, cv_isi
def plt_psd_traces(grid1, grid2, axs, min_lim, max_lim, eod_fr, fr, fr_stim, stack_final1, fr_color='red',
                   fr_stim_color='darkred', peaks=True, db='', rmv_axo=True, stack_isf=[],
                   stack_osf=[], rmv_axi=True, reset_pos=True, eod_fr_color='magenta', eod_fr_half_color='purple'):
    """Plot the input ('isf') and output ('osf') power spectra next to the matrix axes *axs*.

    The input spectrum goes into ``grid2`` and the output spectrum into ``grid1``; both are
    drawn by ``plt_trace`` with the same frequency markers.

    Returns
    -------
    (axo, axi) : the output-spectrum axes and the input-spectrum axes.
    """
    common = (axs, db, eod_fr, eod_fr_color, eod_fr_half_color, fr,
              fr_color, fr_stim, fr_stim_color)
    # input spectrum first, output spectrum second (same order as the subplots are created)
    input_result = plt_trace(*common, grid2, max_lim, min_lim,
                             peaks, reset_pos, rmv_axi, stack_final1, stack_isf,
                             isf_name='isf')
    output_result = plt_trace(*common, grid1, max_lim, min_lim,
                              peaks, reset_pos, rmv_axo, stack_final1, stack_osf,
                              isf_name='osf')
    axi = input_result[1]
    axo = output_result[1]
    return axo, axi
def plt_trace(axs, db, eod_fr, eod_fr_color, eod_fr_half_color, fr, fr_color, fr_stim, fr_stim_color, grid2, max_lim,
              min_lim, peaks, reset_pos, rmv_axi, stack_final1, stack_isf, isf_name='isf', clip_on=True):
    """Plot one power spectrum into a new subplot ``grid2`` aligned with *axs*.

    Parameters
    ----------
    axs : reference axes; with ``reset_pos`` the new axes gets its horizontal extent.
    db : ``'db'`` for a decibel scale (forwarded to ``plt_isf_ps_red``).
    eod_fr, fr, fr_stim : frequencies marked (together with 2x harmonics and eod_fr / 2).
    *_color : colours of the corresponding frequency markers.
    grid2 : subplot spec for the new axes.
    max_lim, min_lim : x-limits of the spectrum.
    peaks : whether to mark the frequencies.
    reset_pos : align the new axes horizontally with *axs*.
    rmv_axi : remove the x tick labels of the new axes.
    stack_final1 : stack whose column ``isf_name`` holds the spectra (used when
        ``stack_isf`` is empty).
    stack_isf : already resaved spectra; used instead of the stack column when non-empty.
    isf_name : stack column to use ('isf' or 'osf').
    clip_on : forwarded to the peak markers.

    Returns
    -------
    ax_pos, axi, colors_f, freqs, isf_resaved
    """
    ax_pos = np.array(axs.get_position())  # [[xmin, ymin], [xmax, ymax]].
    if len(stack_isf) == 0:
        isf = stack_final1[isf_name]
        isf_resaved = False
    else:
        isf = stack_isf
        isf_resaved = True
    axi = plt.subplot(grid2)
    ax_pos2 = np.array(axi.get_position())  # also possible: .y0, .y1, .x0, .x1, .width
    if reset_pos:
        axi.set_position([ax_pos[0][0], ax_pos2[0][1], ax_pos[1][0] - ax_pos[0][0],
                          ax_pos2[1][1] - ax_pos2[0][1]])
    freqs = [fr, fr * 2, fr_stim, fr_stim * 2, eod_fr, eod_fr * 2, eod_fr / 2]
    colors_f = [fr_color, fr_color, fr_stim_color, fr_stim_color, eod_fr_color, eod_fr_color,
                eod_fr_half_color]
    # `osf_resaved` was previously passed an undefined name (NameError); the flag that
    # describes the spectra handed over here is `isf_resaved`
    plt_isf_ps_red(stack_final1, isf, 0, axi, freqs=freqs, colors=colors_f,
                   clip_on=clip_on, peaks=peaks, db=db, max_lim=max_lim, osf_resaved=isf_resaved, )
    axi.set_xlim(min_lim, max_lim)
    if rmv_axi:
        remove_xticks(axi)
    return ax_pos, axi, colors_f, freqs, isf_resaved
def plt_isf_ps_red(stack_final, isf, l, ax_i, color='black', power=1, several=False, maxi=1, peaks=True,
                   freqs=None, db='', max_lim=None, colors=None, osf_resaved=False, clip_on=False):
    """Plot the mean power spectrum of *isf* (optionally one trial and peak markers) on *ax_i*.

    Parameters
    ----------
    stack_final : stack passed to ``find_f`` to obtain the frequency axis.
    isf : spectra, either a column of the stack (each entry's ``[0]`` is one spectrum)
        or, with ``osf_resaved``, a resaved table that is transposed before averaging.
    l : index of the single trial drawn when ``several`` is set (non-resaved case).
    ax_i : matplotlib Axes to draw into.
    color : colour of the mean spectrum.
    power : exponent applied to ``abs(spectrum)``.
    several : also draw the single spectrum ``p`` in grey.
    maxi : reference for the dB conversion.
    peaks : mark ``freqs`` with ``plt_peaks`` in the colours ``colors``.
    freqs, colors : frequencies to mark and their colours (default: none).
    db : ``'db'`` converts to ``10 * log10(x / maxi)``.
    max_lim : if truthy, only frequencies below it are drawn; otherwise the x-range is
        set to 0-700.
    osf_resaved : ``isf`` is in the resaved layout.
    clip_on : forwarded to ``plt_peaks``.
    """
    freqs = [] if freqs is None else freqs
    colors = [] if colors is None else colors
    f = find_f(stack_final)
    try:
        if osf_resaved:
            f_axis = f[0:len(isf)]
            means = np.transpose(isf)
            means_all = np.mean(np.abs(means) ** power, axis=0)
            p = np.abs(means.iloc[0]) ** power
        else:
            f_axis = f[0:len(isf.iloc[l][0])]
            means = get_array_from_pandas(isf)
            means_all = np.mean(np.abs(means) ** power, axis=0)
            p = np.abs(isf.iloc[l][0]) ** power
        if db == 'db':
            p = 10 * np.log10(p / maxi)
            means_all = 10 * np.log10(means_all / maxi)
        # offset for the peak markers; it used to be defined only in the dB branch, so
        # peaks=True on a linear scale raised a NameError
        add = np.percentile(means_all, 90)
        if max_lim:
            below = f_axis < max_lim
            if several:
                ax_i.plot(f_axis[below], p[below], color='grey', zorder=1)
            ax_i.plot(f_axis[below], means_all[below], color=color, zorder=1)
        else:
            if several:
                ax_i.plot(f_axis, p, color='grey', zorder=1)
            ax_i.plot(f_axis, means_all, color=color, zorder=1)
            ax_i.set_xlim(0, 700)
        if peaks:
            for freq, freq_color in zip(freqs, colors):
                plt_peaks(ax_i, p, freq, f_axis, fr_color=freq_color, add=add, clip_on=clip_on)
    except Exception:
        print('f axis problem')
        embed()
def find_cells_plot(save_names, amps_desired=None, cell_class=' Ampullary'):
    """Find the cells of *cell_class* recorded at all desired amplitudes, ordered by CV.

    Parameters
    ----------
    save_names : the first entry names the overview CSV in the 'calc_RAM' folder.
    amps_desired : amplitudes a cell must have been recorded with (default [5, 10, 20]).
        If empty, cells with more than three distinct amplitudes are taken instead.
    cell_class : cell-type group (value of 'cell_type_reclassified') to keep.

    Returns
    -------
    amps_desired, cell_type_type, cells_plot, frame, cell_types
        ``cells_plot`` is sorted by the baseline ``'cv'`` and excludes
        '2012-06-08-ae-invivo-1'.
    """
    if amps_desired is None:
        amps_desired = [5, 10, 20]
    load_name = load_folder_name('calc_RAM') + '/' + save_names[0] + '.csv'
    frame_csv_overview = pd.read_csv(load_name, low_memory=False)
    ########################
    # find the cells that were recorded with the desired amplitudes
    unique_combos = frame_csv_overview[['cell', 'amp']].drop_duplicates()
    if len(amps_desired) > 0:
        combos_three = unique_combos[unique_combos.amp.isin(amps_desired)]
        cell_counts = combos_three.cell.value_counts()
        cell_to_plot = cell_counts[cell_counts == len(amps_desired)].keys()
    else:
        cell_counts = unique_combos.cell.value_counts()
        cell_to_plot = cell_counts[cell_counts > 3].keys()
    # only keep cells that are also in the baseline (CV) table of the requested class
    frame = load_cv_table()
    cell_type_type = 'cell_type_reclassified'
    frame = unify_cell_names(frame, cell_type=cell_type_type)
    cell_types = frame[cell_type_type].unique()
    cells_dict = cluster_cells_by_group_dict(cell_types, frame, cell_type_type)
    cells = cells_dict[cell_class]
    cells_plot = cell_to_plot[cell_to_plot.isin(cells)]
    frame_cv = frame[frame.cell.isin(cells_plot)]
    frame_cv = frame_cv.sort_values('cv')
    cells_plot = list(frame_cv.cell)
    if '2012-06-08-ae-invivo-1' in cells_plot:
        cells_plot.remove('2012-06-08-ae-invivo-1')
    return amps_desired, cell_type_type, cells_plot, frame, cell_types
def load_cells_in_sample(cell_class, save_names, amps_desired, cell_type_type, frame):
    """Select the cells of *cell_class* that were recorded at the desired amplitudes.

    Parameters
    ----------
    cell_class : cell-type group (value of ``frame[cell_type_type]``) to keep.
    save_names : the first entry names the overview CSV in the 'calc_RAM' folder.
    amps_desired : amplitudes a cell must have all been recorded with. If empty, cells
        with more than three distinct amplitudes are taken instead.
    cell_type_type : column of *frame* holding the cell type.
    frame : baseline table with ``'cell'``, ``'cv_min'`` and ``cell_type_type`` columns.

    Returns
    -------
    cells_plot : cell names sorted by ``'cv_min'``, without '2012-06-08-ae-invivo-1'.
    cell_types : unique values of ``frame[cell_type_type]``.
    """
    load_name = load_folder_name('calc_RAM') + '/' + save_names[0] + '.csv'
    frame_csv_overview = pd.read_csv(load_name, low_memory=False)
    unique_combos = frame_csv_overview[['cell', 'amp']].drop_duplicates()
    if len(amps_desired) > 0:
        combos_three = unique_combos[unique_combos.amp.isin(amps_desired)]
        cell_counts = combos_three.cell.value_counts()
        cell_to_plot = cell_counts[cell_counts == len(amps_desired)].keys()
    else:
        cell_counts = unique_combos.cell.value_counts()
        cell_to_plot = cell_counts[cell_counts > 3].keys()
    # only keep cells that are also in the baseline table of the requested class
    cell_types = frame[cell_type_type].unique()
    cells_dict = cluster_cells_by_group_dict(cell_types, frame, cell_type_type)
    cells = cells_dict[cell_class]
    cells_plot = cell_to_plot[cell_to_plot.isin(cells)]
    frame_cv = frame[frame.cell.isin(cells_plot)]
    frame_cv = frame_cv.sort_values('cv_min')
    cells_plot = list(frame_cv.cell)
    if '2012-06-08-ae-invivo-1' in cells_plot:
        cells_plot.remove('2012-06-08-ae-invivo-1')
    return cells_plot, cell_types
def load_isis(save_names, amps_desired=None, cells_given=None, cell_class=' Ampullary'):
    """Load the baseline table and pick the cells of *cell_class* to plot.

    Parameters
    ----------
    save_names : forwarded to ``load_cells_in_sample`` (first entry names the overview CSV).
    amps_desired : required amplitudes (default [5, 10, 20]).
    cells_given : cells forwarded to ``load_cv_base_frame`` (default: empty list).
    cell_class : cell-type group to keep.

    Returns
    -------
    amps_desired, cell_type_type, cells_plot, frame, cell_types
    """
    # fresh lists per call: amps_desired is returned, so a shared default could be mutated by callers
    if amps_desired is None:
        amps_desired = [5, 10, 20]
    if cells_given is None:
        cells_given = []
    cell_type_type = 'cell_type_reclassified'
    frame = load_cv_base_frame(cells_given, cell_type_type=cell_type_type)
    # find out which cells to plot
    cells_plot, cell_types = load_cells_in_sample(cell_class, save_names, amps_desired, cell_type_type, frame)
    return amps_desired, cell_type_type, cells_plot, frame, cell_types
def remove_tick_marks(ax):
    """Blank the major x tick labels of *ax* and return it.

    Only the labels are hidden (``NullFormatter``); the tick marks themselves stay.
    """
    no_labels = ticker.NullFormatter()
    ax.xaxis.set_major_formatter(no_labels)
    return ax
def plt_scatter_two(ax0, ax2, frame, cell_types, cell_type_type, annotate, colors):
    """Scatter CV vs. firing rate per cell type, plain on *ax0* and burst-corrected on *ax2*.

    Only cells with a noise (``gwn``) or ``fs`` recording are shown. ``plt_cv_fr`` is
    called with the column suffix '' for *ax0* and '_burst_corr_individual' for *ax2*.

    Returns
    -------
    (ax0, ax2)
    """
    has_recording = (frame.gwn == True) | (frame.fs == True)  # noqa: E712 (element-wise)

    def rows_of(group):
        # a fresh selection on every call, exactly as before
        return frame[(frame[cell_type_type] == group) & has_recording]

    for group in cell_types:
        plt_cv_fr(annotate, ax0, '', rows_of(group), colors, group)
    ax2.set_title('burst')
    for group in cell_types:
        plt_cv_fr(annotate, ax2, '_burst_corr_individual', rows_of(group), colors, group)
    return ax0, ax2
def square_func(ax, stack_final, perc_min=5, perc_max=95, norm='', s=0):
    """Show the (f1, f2) susceptibility matrix of *stack_final* as an image on ``ax[s]``.

    Colour limits are the ``perc_min`` / ``perc_max`` percentiles (NaN-aware) of the matrix
    normalised by ``ram_norm_choice``. The triangle of characteristic frequencies is drawn
    on top with ``plt_triangle``.

    Returns
    -------
    im : the image artist.
    first, last : first and last column label of the matrix (the f2 range).
    """
    new_keys, stack_plot = retrieve_mat_plot(stack_final)
    eod_fr = stack_final.eod_fr.unique()[0]
    fr2 = np.unique(stack_final.fr_stim)
    fr = stack_final.fr.unique()[0]
    # todo: adapt the normalisation here
    mat = ram_norm_choice(stack_plot, norm, stack_final)
    target = ax[s]
    lower = np.nanpercentile(mat, perc_min)
    upper = np.nanpercentile(mat, perc_max)
    first_col, last_col = mat.columns[0], mat.columns[-1]
    im = target.imshow(mat, vmin=lower, vmax=upper,
                       extent=[mat.index[0], mat.index[-1], first_col, last_col],
                       origin='lower', cmap='viridis')
    target.set_aspect('equal')
    target.set_xlabel(F1_xlabel())
    target.set_ylabel(F2_xlabel(), labelpad=0.2)
    target.set_xlim(mat.index[0], mat.index[-1])
    target.set_ylim(first_col, last_col)
    plt_triangle(target, fr, np.mean(fr2), new_keys[-1], eod_fr, eod_fr_half_color='purple', fr_color='red',
                 eod_fr_color='magenta', fr_stim_color='darkred')
    return im, first_col, last_col
def retrieve_mat_plot(stack_final):
    """Return the row labels of *stack_final* and the sub-frame of the columns named like them.

    Selecting the columns by the row labels yields the square (index x index) matrix and
    drops any additional bookkeeping columns of the frame.

    Returns
    -------
    new_keys : ``stack_final.index``.
    stack_plot : ``stack_final[new_keys]``.
    """
    new_keys = stack_final.index
    return new_keys, stack_final[new_keys]
def plot_square_core(ax, stack_final, s=0, nr=3, eod_metrice=True, fr=None, cbar_do=True, perc=True, line_length=1 / 4,
                     add_nonlin_title=None):
    """Show the normalised susceptibility matrix of *stack_final* on ``ax[s]``.

    The matrix is built with ``convert_csv_str_to_float``, normalised by ``RAM_norm_data``
    with the stack's 'isf' entry, and rescaled by ``rescale_colorbar_and_values``. It is
    then drawn with ``imshow`` (5th/95th percentile colour limits when ``perc``).
    ``plt_triangle`` overlays the characteristic frequencies.

    Parameters
    ----------
    fr : baseline rate for the triangle; taken from the stack when falsy.
    cbar_do : add a colorbar (an empty list is returned if that fails).
    nr, eod_metrice, line_length : forwarded to ``plt_triangle``.
    add_nonlin_title : forwarded to / returned from ``rescale_colorbar_and_values``.

    Returns
    -------
    cbar, mat, im, add_nonlin_title
    """
    new_keys, stack_plot = convert_csv_str_to_float(stack_final)
    eod_fr = stack_final.eod_fr.unique()[0]
    fr2 = np.unique(stack_final.fr_stim)
    if not fr:
        fr = stack_final.fr.unique()[0]
    mat = RAM_norm_data(stack_final['isf'].iloc[0], stack_plot,
                        stack_final['snippets'].unique()[0], stack_here=stack_final)
    mat, add_nonlin_title, _ = rescale_colorbar_and_values(mat, add_nonlin_title=add_nonlin_title)
    print(add_nonlin_title)
    target = ax[s]
    extent = [mat.index[0], mat.index[-1], mat.columns[0], mat.columns[-1]]
    limits = {}
    if perc:
        limits['vmin'] = np.nanpercentile(mat, 5)
        limits['vmax'] = np.nanpercentile(mat, 95)
    im = target.imshow(mat, extent=extent, cmap='viridis', origin='lower', **limits)
    target.set_aspect('equal')
    target.set_xlabel(F1_xlabel())
    target.set_ylabel(F2_xlabel(), labelpad=0.2)
    target.set_xlim(mat.index[0], mat.index[-1])
    target.set_ylim(mat.columns[0], mat.columns[-1])
    plt_triangle(target, fr, np.mean(fr2), new_keys[-1], eod_fr, line_length=line_length, eod_metrice=eod_metrice,
                 nr=nr)
    cbar = []
    if cbar_do:
        try:
            cbar = plt.colorbar(im, ax=target, shrink=0.6)
        except:
            print('colorbar problem')
            cbar = []
    return cbar, mat, im, add_nonlin_title
def cluster_cells_by_group_dict(cell_types, frame, cell_type_type):
    """Group the cell names of *frame* by cell type.

    Parameters
    ----------
    cell_types : cell-type labels to build groups for (iterated in sorted order).
    frame : DataFrame with a ``'cell'`` column and the column ``cell_type_type``.
    cell_type_type : name of the column holding the cell-type label.

    Returns
    -------
    dict
        Maps each label to a sorted array of the unique cell names (as strings) whose
        ``cell_type_type`` equals that label.
    """
    groups = {}
    for label in np.sort(cell_types):
        names = frame.loc[frame[cell_type_type] == label, 'cell'].astype('str')
        groups[label] = np.sort(np.asarray(names.unique()))
    return groups
def plt_cell_body_isf_single_rotate2(axi, grid1, ax0, ax1, ax2, b, cell, frame, colors, amps_desired, save_names,
                                     cell_type_type,
                                     xlim=[0, 13], burst_corr='', predefined_amps2=False, norm=False):
    """Plot the overview panels for one cell.

    The panels are:
      * ISI histograms (with and without burst correction), drawn on ``axi``.
      * The susceptibility matrix with its input/output spectra, via ``square_isf`` in ``grid1``.
      * The cell's point in the CV/rate (``ax0``, ``ax2``) and CV/vs (``ax1``) scatter plots.

    The matrix uses the file with the highest cut-off (``find_optimal_files``) and the
    smallest amplitude of the pickled stack ``calc_RAM/<save_names[0]>_<cell>.pkl``.

    Returns
    -------
    axs, axo, axin : axes returned by ``square_isf``.

    NOTE(review): ``axs``/``axo``/``axin`` are only assigned when the spikes are usable,
    the pickle exists and the amplitude is present. Otherwise the final ``return`` raises
    UnboundLocalError. Likewise ``hists_both`` is undefined when ``hists`` is empty.
    ``xlim`` is a mutable default; it is only read here.
    """
    print(cell)
    frame_cell = frame[(frame['cell'] == cell)]
    frame_cell = unify_cell_names(frame_cell, cell_type=cell_type_type)
    cell_type = frame_cell[cell_type_type].iloc[0]
    spikes = frame_cell.spikes.iloc[0]
    fr = frame_cell.fr.iloc[0]
    cv = frame_cell.cv.iloc[0]
    eod_fr = frame_cell.EODf.iloc[0]
    spikes_all, hists, frs_calc, cont = load_spikes(spikes, eod_fr)
    # Original (German) note: `cont` refers to the spikes being NaN or empty, i.e. nothing
    # at all to show (not even a scatter) -- the blank figures that are not needed.
    # NOTE(review): the code proceeds only when `cont` is truthy; confirm the polarity in load_spikes
    if cont:
        # these two checks ensure there are at least a few spikes; without them no figure is needed
        if len(hists) > 0:
            if len(np.concatenate(hists)) > 0:
                lim_here = find_lim_here(cell, burst_corr=burst_corr)
                # burst-correct only if some ISIs fall below the limit; otherwise reuse the raw ISIs
                if np.min(np.concatenate(hists)) < lim_here:
                    hists2, spikes_ex, frs_calc2 = correct_burstiness(hists, spikes_all, [eod_fr] * len(spikes_all),
                                                                      [eod_fr] * len(spikes_all), lim=lim_here,
                                                                      burst_corr=burst_corr)
                    hists_both = [hists, hists2]
                else:
                    hists_both = [hists, hists]
                # this is the title in case the square (matrix) plot is not drawn
                plt.suptitle(cell + ' EODf ' + str(np.round(eod_fr)) + ' Hz ' + ' % ' +
                             ' Base: cv ' + str(np.round(cv, 2)) + ' fr ' + str(
                    np.round(fr)) + ' Hz',
                             fontsize=11, )  # cell[0:13] + color=color+ cell_type
                load_name = load_folder_name('calc_RAM') + '/' + save_names[0] + '_' + cell
                if os.path.exists(load_name + '.pkl'):
                    stack = pd.read_pickle(load_name + '.pkl')
                    if 'p-unit' in cell_type:  # == ['p-unit', ' P-unit']:
                        file_names_exclude = punit_file_exclude()  #
                    else:
                        file_names_exclude = ampullary_file_exclude()  #
                    files = stack['file_name'].unique()
                    # file exclusion is currently switched off
                    fexclude = False
                    if fexclude:
                        if len(files) > 1:
                            stack = stack[~stack['file_name'].isin(file_names_exclude)]
                    files = stack['file_name'].unique()
                    amps = stack['amp'].unique()
                    _, _ = find_row_col(np.arange(len(amps) * len(files)))
                    # predefined_amp is computed but has no effect (both branches pass)
                    predefined_amp = True
                    if predefined_amps2:
                        for a, amp in enumerate(amps):
                            if amp not in amps_desired:
                                predefined_amp = False
                    if predefined_amp:
                        pass
                    else:
                        pass
                    # only the smallest recorded amplitude of the best file is plotted
                    amps_defined = [np.min(amps)]
                    file, cut_offs = find_optimal_files(files)
                    stack_file = stack[stack['file_name'] == file]
                    for a, amp in enumerate(amps_defined):
                        if amp in np.array(stack_file['amp']):
                            axs, axo, axin = square_isf(grid1, norm, b, cell, stack_file, amp, eod_fr, file)
        ################################
        # do the scatter of these cells
        add = ['', '_burst_corr', ]
        try:
            ax0.scatter(frame_cell['cv' + add[0]], frame_cell['fr' + add[0]], zorder=2, alpha=1,
                        label=cell_type, s=15,
                        color=colors[str(cell_type)], facecolor='white')
        except:
            print('colors_f problem')
            embed()
        if len(ax1) > 0:
            ax1.scatter(frame_cell['cv' + add[0]], frame_cell['vs' + add[0]], zorder=2, alpha=1,
                        label=cell_type, s=15,
                        color=colors[str(cell_type)], facecolor='white')
        ax2.scatter(frame_cell['cv' + add[1]], frame_cell['fr' + add[1]], zorder=2, alpha=1,
                    label=cell_type, s=15,
                    color=colors[str(cell_type)], facecolor='white')
        plt_hists(axi, cell_type, colors, hists_both, xlim, b, alpha=1)
    return axs, axo, axin
def find_optimal_files(files):
    """Select the recording file with the highest cut-off frequency.

    Parameters
    ----------
    files : sequence of str
        Candidate file names, e.g. ``stack['file_name'].unique()``.

    Returns
    -------
    file : str
        The entry of ``files`` whose ``calc_cut_offs`` value is largest (first one on ties).
    cut_offs : list
        The ``calc_cut_offs`` value of every entry of ``files``, in the same order.
    """
    cut_offs = [calc_cut_offs(candidate) for candidate in files]
    best_index = np.argmax(cut_offs)
    return files[best_index], cut_offs
def square_isf(grid1, norm, b, cell, stack_file, amp, eod_fr, file):
    """Draw the susceptibility square of one file/amplitude plus the PSD traces above it.

    Selects the rows of ``stack_file`` with amplitude ``amp`` and the longest
    ``stimulus_length``. The square goes into ``grid1[2]`` via ``square_func``. The PSD
    traces from ``plt_psd_traces`` go into ``grid1[0]`` and ``grid1[1]``. The figure
    suptitle summarises the baseline and stimulus statistics stored in the stack.

    Parameters
    ----------
    grid1 : gridspec-like
        Items 0, 1 and 2 are used.
    norm : passed through to ``square_func``.
    b : int
        When non-zero, the colorbar gets the ``nonlin_title()`` label.
    cell : str
        Cell name, used in the suptitle.
    stack_file : pandas.DataFrame
        Stack rows of one file.
    amp : value compared against ``stack_file['amp']``.
    eod_fr : float
        EOD frequency of the cell.
    file : str
        File name, shown in the title of the trace axis.

    Returns
    -------
    axs, axo, axin
        The square axis and the two axes returned by ``plt_psd_traces``.
    """
    stack_amp = stack_file[stack_file['amp'] == amp]
    lengths = stack_file['stimulus_length'].unique()
    length = np.max(lengths)
    stack_final = stack_amp[stack_amp['stimulus_length'] == length]
    trial_nr_double = stack_final.trial_nr.unique()
    # More than one trial number here is probably a sign of an error.
    if len(trial_nr_double) > 1:
        print('trial_nr_double')
        embed()
    # An earlier save bug stored some entries twice, so simply keep the last (largest) trial number.
    try:
        stack_final1 = stack_final[stack_final.trial_nr == np.max(trial_nr_double)]
    except:
        print('stack_final1 problem')
        embed()
    try:
        axs = plt.subplot(grid1[2])
    except:
        print('grid problem6')
        embed()
    im, min_lim, max_lim = square_func([axs], stack_final1, norm=norm)
    cbar = plt.colorbar(im, ax=axs, orientation='vertical')
    if b != 0:
        # NOTE(review): labelpad=1000 likely pushes the label far outside the figure;
        # the other square plots in this file use labelpad=10 — confirm intent.
        cbar.set_label(nonlin_title(), rotation=270, labelpad=1000)
    # Baseline (fr, cv, ser) and stimulus (*_stim) statistics shown in the suptitle.
    fr = stack_final1.fr.unique()[0]
    snippets = stack_final1['snippets'].unique()[0]
    cv = stack_final1.cv.unique()[0]
    ser = stack_final1.ser.unique()[0]
    cv_stim = stack_final1.cv_stim.unique()[0]
    fr_stim = stack_final1.fr_stim.unique()[0]
    ser_stim = stack_final1.ser_stim.unique()[0]
    plt.suptitle(cell + ' EODf ' + str(np.round(eod_fr)) + ' Hz ' + '' + 'S.Nr ' + str(
        snippets) + ' % ' +
                 ' Base: cv ' + str(np.round(cv, 2)) + ' fr ' + str(
        np.round(fr)) + ' Hz' + ' ser ' + str(np.round(ser))
                 + ' Stim: cv ' + str(np.round(cv_stim, 2)) + ' fr ' + str(
        np.round(fr_stim)) + ' Hz' + ' ser ' + str(np.round(ser_stim)) + ' length ' + str(
        length)
                 ,
                 fontsize=11, )  # cell[0:13] + color=color+ cell_type
    # Marker colours for the reference frequencies drawn by plt_psd_traces.
    eod_fr_half_color = 'purple'
    fr_color = 'red'
    eod_fr_color = 'magenta'
    fr_stim_color = 'darkred'
    axo, axin = plt_psd_traces(grid1[0], grid1[1], axs, min_lim, max_lim, eod_fr, fr, fr_stim, stack_final1, fr_color,
                               fr_stim_color,
                               eod_fr_color, eod_fr_half_color)
    axo.set_title(' std ' + str(amp) + ' ' + file)
    return axs, axo, axin
def plt_hists(axi, cell_type, colors, hists_both, xlim, b, alpha=1):
    """Plot the ISI histograms of one cell into ``axi``.

    Parameters
    ----------
    axi : matplotlib axes
        Target axis.
    cell_type : str
        Key into ``colors`` (via ``str(cell_type)``).
    colors : dict
        Maps cell types to colours.
    hists_both : list of lists of array-like
        Groups of ISI samples, e.g. ``[raw, burst_corrected]``. With more than one group,
        the first group is drawn grey and the others in the cell-type colour.
    xlim : sequence
        Passed to ``axi.set_xlim`` when non-empty.
    b : int
        The first ``b + 1`` groups are drawn. The count is capped at ``len(hists_both)``.
    alpha : float
        Opacity of the first histogram of each group. Every further histogram is 0.05
        more transparent, but never below 0.05.
    """
    if len(hists_both) > 1:
        colors_hist = ['grey', colors[str(cell_type)]]
    else:
        colors_hist = [colors[str(cell_type)]]
    # Never request more groups than exist; the original indexed past the end (IndexError).
    n_groups = min(b + 1, len(hists_both))
    for gg in range(n_groups):
        for hh, h in enumerate(hists_both[gg]):
            # Fade later trials. Clamp so alpha stays a valid, visible value:
            # without the clamp it turned negative after 20 histograms and matplotlib rejected it.
            alpha_here = max(float(alpha - 0.05 * hh), 0.05)
            try:
                axi.hist(h, bins=100, color=colors_hist[gg], label='CV ' + str(np.round(np.std(h) / np.mean(h), 3)),
                         alpha=alpha_here)
            except Exception:
                print('alpha problem4')
                embed()
        axi.legend(ncol=2)
    if len(xlim) > 0:
        axi.set_xlim(xlim)
    axi.set_xlabel('isi')
def plt_cv_fr(annotate, ax0, add, frame_g, colors, cell_type):
    """Scatter CV (x) against baseline firing rate (y) for one cell type.

    Parameters
    ----------
    annotate : bool
        Label every point that has finite CV and rate with ``cell[2:13]``.
    ax0 : matplotlib axes
        Target axis.
    add : str
        Column suffix, so the columns ``'cv' + add`` and ``'fr' + add`` are used.
    frame_g : pandas.DataFrame
        Rows to plot; needs the two value columns and ``cell``.
    colors : dict
        Maps ``str(cell_type)`` to a colour.
    cell_type : str
        Legend label and colour key.
    """
    cv_column = 'cv' + add
    fr_column = 'fr' + add
    type_color = colors[str(cell_type)]
    ax0.scatter(frame_g[cv_column], frame_g[fr_column], alpha=0.5, label=cell_type, s=7, color=type_color)
    has_nan = np.isnan(frame_g[cv_column]) | np.isnan(frame_g[fr_column])
    finite_rows = frame_g[~has_nan]
    if annotate:
        for x_val, y_val, name in zip(finite_rows[cv_column], finite_rows[fr_column], finite_rows.cell):
            ax0.text(x_val, y_val, name[2:13], rotation=45, color=type_color, fontsize=6)
    ax0.set_xlim(0, 1.5)
    ax0.set_ylabel('Base Freq [Hz]')
    ax0.set_xlabel('CV')
def plt_cv_part_several(row, col, cell, frame_save, frame, cell_nr, counter, ax):
    """Draw the stored ISI histogram of one cell into panel ``ax[counter]``.

    Parameters
    ----------
    row, col : int
        Grid shape. The panel at index ``row * col - col`` gets the axis labels.
    cell : str
        Cell name. Its stored histogram is looked up in ``frame_save``, and ``cell[0:-9]``
        is written above the panel.
    frame_save : pandas.DataFrame
        Holds a ``hist`` column; ``hist.iloc[0][0]`` is a (counts, bin_edges) pair.
    frame : pandas.DataFrame
        Row ``cell_nr`` provides ``cv`` and ``fr`` for the panel title.
    cell_nr : int
        Positional row in ``frame``.
    counter : int
        Index of the panel to draw into.
    ax : sequence of axes

    Returns
    -------
    int
        ``counter + 1``, the index of the next free panel.
    """
    cell_row = frame.iloc[cell_nr]
    cv_isi = cell_row.cv
    fr = cell_row.fr
    colour_by_cv = False  # switch: colour the title by CV bin instead of by cell
    if colour_by_cv:
        # The first upper bound the CV falls below picks the colour; blue otherwise.
        cv_bins = ((0.2, 'red'), (0.3, 'purple'), (0.4, 'orange'), (0.7, 'green'))
        color = next((bin_color for upper, bin_color in cv_bins if cv_isi < upper), 'blue')
    else:
        color = title_color(cell)
    frame_here = frame_save[frame_save.cell == cell]
    panel = ax[counter]
    panel.text(0, 1.26, cell[0:-9] + '\n cv ' + str(np.round(cv_isi, 2)) + ' fr ' + str(np.round(fr)) + ' Hz',
               transform=panel.transAxes, color=color, fontsize=8)
    try:
        hist = frame_here['hist'].iloc[0][0]
    except:
        print('hist problem')
        embed()
    counts, edges = hist[0], hist[1]
    bin_width = (edges[1] - edges[0])
    # Bars are centred on each bin, so shift the left edges by half a bin.
    panel.bar(edges[0:-1] + bin_width / 2, height=counts, width=bin_width, )
    if counter == row * col - col:
        panel.set_xlabel('Inter Spike Interval, EODf multiples')
        panel.set_ylabel('nr')
    panel.set_xlim(0, 17)
    return counter + 1
def plt_squares_special(params, col_desired=2, var_items=['contrasts'], show=False, contrasts=[0], noises_added=[''],
                        fft_i='forward', fft_o='forward', spikes_unit='Hz', mV_unit='mV',
                        D_extraction_method=['additiv_visual_d_4_scaled'], internal_noise=['RAM'],
                        external_noise=['RAM'], level_extraction=['_RAMdadjusted'], cut_off2=300,
                        repeats=[1000000], receiver_contrast=[1], visualize=True, dendrids=[''], ref_types=[''],
                        adapt_types=[''],
                        c_noises=[0.1], perc='', share=True, c_signal=[0.9], new_plot=True, cut_offs1=[300],
                        clims='all', restrict='restrict',
                        label=r'$\frac{1}{mV^2S}$', width=0.005, cells_given=None, lp=100, ax=[], titles_plot=True):
    """Plot a grid of model susceptibility squares, one panel per parameter set (and cell).

    The list-valued keyword arguments are iterated with ``itertools.product``. For every
    combination and every entry of ``params``, the model save name is built with
    ``save_ram_model``, the model table is loaded with ``load_model_susept``, and the square
    of each selected cell is drawn with ``plt_single_square_modl``.

    Parameters
    ----------
    params : list of dict
        Each dict provides ``'contrasts'``, ``'D_extraction_method'`` and
        ``'level_extraction'`` (first element used), and optionally ``'repeats'``.
        These override the corresponding values from the product loop.
    col_desired : int
        Requested number of columns (passed to ``find_row_col``).
    cells_given : list of str or None
        Cells to plot. ``None`` plots only the first cell of each model.
    new_plot : bool
        Create a fresh figure (``plot_style`` / ``default_settings`` / ``plt.subplots``).
        Otherwise draw into the given ``ax`` with ``col = col_desired``.
    titles_plot : bool
        Write per-panel titles and the figure suptitle.
    share : bool
        Give all panels a common colour range (``set_clim_same``).
    visualize, show : bool
        Save the figure (``save_visualization``) and/or call ``plt.show()``.
    label, lp : str, float
        Colorbar label and its pad, applied to panels in the last column.

    The remaining list arguments are the model settings swept by the product loop. The
    scalar ones are forwarded to ``save_ram_model``. None of the list defaults are mutated.
    """
    nffts = ['whole']
    stimulus_length = 1
    trials_nrs = [1]
    powers = [1]
    variant = 'sinz'
    mimick = 'no'
    cell_recording_save_name = ''
    trans = 1
    if new_plot:
        plot_style()
        if cells_given:
            params_len = np.arange(0, len(params) * len(cells_given), 1)
        else:
            params_len = params
        col, row = find_row_col(params_len, col=col_desired)
        if col == 2:
            default_settings(column=2, length=7.5)
        elif col == 1:
            default_settings(column=2, length=4)
        elif col > 2:
            if row == 2:
                default_settings(column=2, length=4.5)
            else:
                default_settings(column=2, length=7.5)
        else:
            default_settings(column=2, length=7.5)
        fig, ax_orig = plt.subplots(row, col, sharex=True,
                                    sharey=True)
        # Flatten the axes grid so panels can be addressed with a single running index.
        if row != 1:
            ax = np.concatenate(ax_orig)
        else:
            ax = ax_orig
        if col == 2:
            plt.subplots_adjust(bottom=0.067, top=0.81, hspace=0.39, right=0.95,
                                left=0.075)
        elif col == 1:
            plt.subplots_adjust(bottom=0.1, top=0.81, hspace=0.39, right=0.95,
                                left=0.075)
        else:
            if row == 2:
                plt.subplots_adjust(bottom=0.07, top=0.76, wspace=0.9, hspace=0.4, right=0.85,
                                    left=0.075)
            else:
                plt.subplots_adjust(bottom=0.05, top=0.81, wspace=0.9, hspace=0.2, right=0.85,
                                    left=0.075)
    else:
        col = col_desired
    maxs = []
    mins = []
    ims = []
    #######################################################################
    # Core: one square per parameter set (and per cell).
    a = 0  # running panel index into ax
    for (var_type, stim_type_afe, trials_stim, stim_type_noise, power, nfft, dendrid, cut_off1, trial_nrs, c_sig,
         c_noise, ref_type, adapt_type, noise_added, extract, a_fr, a_fe) in it.product(
            D_extraction_method, external_noise, repeats, internal_noise, powers, nffts, dendrids, cut_offs1,
            trials_nrs, c_signal, c_noises, ref_types, adapt_types, noises_added, level_extraction,
            receiver_contrast, contrasts):
        nr = '2'
        for param in params:
            # Per-panel settings override the values from the product loop.
            a_fe = param['contrasts'][0]
            var_type = param['D_extraction_method'][0]
            extract = param['level_extraction'][0]
            if 'repeats' in param:
                trials_stim = param['repeats'][0]
            save_name = save_ram_model(stimulus_length, cut_off1, nfft, a_fe, stim_type_noise, mimick, variant,
                                       trials_stim,
                                       power,
                                       cell_recording_save_name, nr=nr, fft_i=fft_i, fft_o=fft_o, Hz=spikes_unit,
                                       mV=mV_unit, stim_type_afe=stim_type_afe, extract=extract,
                                       noise_added=noise_added,
                                       c_noise=c_noise, c_sig=c_sig, ref_type=ref_type, adapt_type=adapt_type,
                                       var_type=var_type, cut_off2=cut_off2, dendrid=dendrid, a_fr=a_fr,
                                       trials_nr=trial_nrs, trans=trans, zeros='ones')
            adapt_type_name, dendrid_name, ref_type_name, stim_type_noise_name = add_ends(adapt_type, dendrid,
                                                                                          ref_type,
                                                                                          stim_type_noise, var_type)
            stim_type_noise_name2, stim_type_afe_name = stim_type_names(a_fe, c_sig, stim_type_afe,
                                                                        stim_type_noise_name)
            path = save_name + '.pkl'
            cell_add, cells_save = find_cell_add(cells_given)
            model = load_model_susept(path, cells_save, save_name.split(r'/')[-1] + cell_add)
            if len(model) > 0:
                cells = model.cell.unique()
                if not cells_given:
                    cells = [cells[0]]
                else:
                    cells = cells_given
                for cell in cells:
                    suptitles, titles = titles_susept_names(a_fe, extract, noise_added, stim_type_afe_name,
                                                            stim_type_noise_name2,
                                                            trials_stim, var_items, var_type)
                    if len(cells) > 1:
                        titles = cell + ' ' + titles
                    # titles_plot must be passed by keyword: the 7th positional parameter of
                    # plt_single_square_modl is bias_factor, so the old positional call
                    # normalised with bias_factor=titles_plot and never drew the titles.
                    add_nonlin_title, cbar, fig, stack_plot, im = plt_single_square_modl(ax[a], cell, model, perc,
                                                                                         titles, width,
                                                                                         titles_plot=titles_plot)
                    ims.append(im)
                    maxs.append(np.max(np.array(stack_plot)))
                    mins.append(np.min(np.array(stack_plot)))
                    # Only panels in the last column get a colorbar label.
                    if a in np.arange(col - 1, 100, col):
                        cbar.set_label(label, labelpad=lp)
                    if new_plot:
                        if a >= row * col - col:
                            ax[a].set_xlabel(F1_xlabel(), labelpad=20)
                    if len(cells) > 1:
                        a += 1
                ax[0].set_ylabel(F2_xlabel())
                # First-column panels get the y label.
                if a in np.arange(0, len(ax), 1) * col:
                    if a < len(ax):
                        try:
                            ax[a].set_ylabel(F2_xlabel())
                        except:
                            print('ax a thing')
                            embed()
                if len(cells) == 1:
                    a += 1
    if titles_plot:
        end_name = suptitles + ' a_fr=' + str(a_fr) + ' ' + ' dendride=' + str(
            dendrid_name) + ' refractory period=' + str(ref_type_name) + ' adapt=' + str(
            adapt_type_name) + ' ' + ' cutoff1=' + str(cut_off1) + '' ' stimulus length=' + str(
            stimulus_length) + ' ' + ' power=' + str(
            power) + ' ' + restrict
        end_name = cut_title(end_name, datapoints=120)
        name_title = end_name
        plt.suptitle(name_title)
    if share:
        set_clim_same(clims=clims, ims=ims, maxs=maxs, mins=mins, lim_type='', nr_clim='10', perc='')
    if new_plot:
        if col < 3:
            fig.tag(ax, xoffs=-3, yoffs=5.8)
        else:
            if row == 2:
                fig.tag([ax_orig[0, :], ax_orig[1, :]], xoffs=-5.5, yoffs=3.8)
            else:
                fig.tag([ax_orig[0, :], ax_orig[1, :], ax_orig[2, :]], xoffs=-3, yoffs=3.8)
    if visualize:
        save_visualization(pdf=True)
    if show:
        plt.show()
def plt_single_square_modl(ax, cell, model, perc, titles, width, bias_factor=1, fr_print=False, eod_metrice=True, nr=3,
                           titles_plot=False, xpos=1.1, resize=False, ls=8):
    """Draw the susceptibility square of one model cell into ``ax``.

    Parameters
    ----------
    ax : matplotlib axes
        Target axis; limits are fixed to 0..300 on both axes.
    cell : str
        Cell whose rows are selected from ``model`` (via ``get_stack``).
    model : pandas.DataFrame
        Model table with a ``cell`` column.
    perc : passed through to ``plt_RAM_perc``.
    titles : str
        Panel title, written only when ``titles_plot`` is True.
    width : float
        Colorbar width, passed to ``colorbar_outside``.
    bias_factor : passed through to ``get_stack`` / ``RAM_norm``.
    fr_print : bool
        Append the stimulus firing rate and CV to the title.
    eod_metrice, nr : passed through to ``plt_triangle``.
    xpos : float
        x position (axes coordinates) of the right-aligned title.
    resize : bool
        Rescale the matrix with ``rescale_colorbar_and_values``.
    ls : passed through to ``colorbar_outside``.

    Returns
    -------
    add_nonlin_title : str
        Suffix from ``rescale_colorbar_and_values`` when ``resize`` is set, else ``''``.
    cbar
        Colorbar from ``colorbar_outside``, or ``[]`` when ``cell`` has no rows in ``model``.
    fig
        ``plt.gcf()``.
    stack_plot
        The matrix that was plotted (normalised, optionally rescaled).
    im
        Result of ``plt_RAM_perc``, or ``[]`` when nothing was drawn.
    """
    model_show, stack_plot, stack_plot_wo_norm = get_stack(cell, model, bias_factor=bias_factor)
    if resize:
        stack_plot, add_nonlin_title, resize_val = rescale_colorbar_and_values(stack_plot)
    else:
        add_nonlin_title = ''
    try:
        ax.set_xlim(0, 300)
    except Exception:
        print('aa thing')
        embed()
    ax.set_ylim(0, 300)
    ax.set_aspect('equal')
    cbar = []
    im = []
    # Bind fig up front. It used to be assigned only inside the branch below, so a cell
    # without rows in `model` raised UnboundLocalError at the return statement.
    fig = plt.gcf()
    if len(model_show) > 0:
        if fr_print:
            add_here = '\n fr$_{S}$=' + str(int(np.round(model_show.fr_stim.iloc[0]))) + 'Hz cv$_{S}$=' + str(
                np.round(model_show.cv_stim.iloc[0], 2))
        else:
            add_here = ''
        if titles_plot:
            ax.text(xpos, 1.05, titles + add_here, ha='right',
                    transform=ax.transAxes)
        im = plt_RAM_perc(ax, perc, stack_plot)
        # Overlay the reference frequencies (baseline rate, stimulus rate, EOD frequency).
        plt_triangle(ax, model_show.fr.iloc[0], np.round(model_show.fr_stim.iloc[0]), 300, model_show.eod_fr.iloc[0],
                     eod_metrice=eod_metrice, nr=nr)
        ax.set_aspect('equal')
        fig = plt.gcf()
        cbar, left, bottom, width, height = colorbar_outside(ax, im, fig, add=0, ls=ls, shrink=0.6, width=width)
    return add_nonlin_title, cbar, fig, stack_plot, im
def get_stack(cell, model, bias_factor=1):
    """Select one cell from a model table and build its susceptibility matrix.

    Parameters
    ----------
    cell : str
        Value matched against ``model.cell``.
    model : pandas.DataFrame
        Model table.
    bias_factor : passed through to ``RAM_norm``.

    Returns
    -------
    model_show : pandas.DataFrame
        Rows of ``model`` belonging to ``cell``.
    stack_plot
        Matrix from ``change_model_from_csv_to_plots``, normalised with ``RAM_norm``.
    stack_plot_wo_norm
        The same matrix before normalisation.
    """
    try:
        cell_rows = model[model.cell == cell]
    except:
        print('cell something')
        embed()
    raw_matrix = change_model_from_csv_to_plots(cell_rows)
    normed_matrix = RAM_norm(raw_matrix, model_show=cell_rows, bias_factor=bias_factor)
    return cell_rows, normed_matrix, raw_matrix
def plt_all_scatter_rotated(ax0, ax1, frame, cell_types, add='', alpha=0.5, s=7, annotate=False,
                            cell_type_type='cell_type_info'):
    """Population scatters with CV on the y axis: rate vs CV on ``ax0``, VS vs CV on ``ax1``.

    ``ax0`` is drawn by ``ptl_fr_cv``. ``ax1`` gets, per cell type, the rows with a ``gwn``
    or ``fs`` recording, coloured with ``colors_overview()``.

    Parameters
    ----------
    ax0, ax1 : matplotlib axes
    frame : pandas.DataFrame
        Needs ``cell_type_type``, ``gwn``, ``fs``, ``cell`` and the ``cv/fr/vs + add`` columns.
    cell_types : iterable of str
    add : str
        Column suffix. When it contains ``'burst'``, the CV labels are marked burst corrected.
    alpha, s : scatter opacity and marker size.
    annotate : bool
        Label points with finite VS and CV with ``cell[2:13]``.
    cell_type_type : str
        Column holding the cell type.
    """
    ptl_fr_cv(add, alpha, annotate, ax0, cell_type_type, cell_types, frame, s)
    colors = colors_overview()
    for cell_type in cell_types:
        # Select this type's rows. Previously the frame returned by ptl_fr_cv (only the LAST
        # cell type) was reused here, so every type showed the same points.
        frame_g = frame[(frame[cell_type_type] == cell_type) & ((frame.gwn == True) | (frame.fs == True))]
        vs = np.array(list(map(float, np.array(frame_g['vs' + add]))))
        cv = np.array(list(map(float, np.array(frame_g['cv' + add]))))
        exclude = np.isnan(cv) | np.isnan(vs)
        frame_g_ex = frame_g[~exclude]
        if annotate:
            for f in range(len(frame_g_ex)):
                ax1.text(frame_g_ex['vs' + add].iloc[f], frame_g_ex['cv' + add].iloc[f], frame_g_ex.cell.iloc[f][2:13],
                         rotation=45,
                         color=colors[str(cell_type)], fontsize=6)
        ax1.scatter(frame_g['vs' + add], frame_g['cv' + add], alpha=alpha, label=cell_type, s=s,
                    color=colors[str(cell_type)])
    ax0.set_ylim(0, 1.5)
    ax1.set_ylim(0, 1.5)
    ax1.set_xlim(0, 1)
    if 'burst' in add:
        ax1.set_ylabel('$CV_{Burst Corr}$')
        ax0.set_ylabel('$CV_{Burst Corr}$')
    else:
        ax0.set_ylabel('CV')
        ax1.set_ylabel('CV')
    ax0.set_xlabel('Base Freq [Hz]')
    ax1.set_xlabel('VS')
    plt.subplots_adjust(wspace=0.25, bottom=0.1)
def ptl_fr_cv(add, alpha, annotate, ax0, cell_type_type, cell_types, frame, s, color_given=None, cv='cv', fr='fr'):
    """Scatter baseline rate (x) against CV (y) on ``ax0``, one colour per cell type.

    Only rows with a ``gwn`` or ``fs`` recording are used. For each type, a correlation
    legend is written with ``legend_wo_dot`` (inputs from ``exclude_nans_for_corr``).

    Parameters
    ----------
    add : str
        Column suffix appended to ``fr`` and ``cv``.
    alpha, s : scatter opacity and marker size.
    annotate : bool
        Label points with finite rate and CV with ``cell[2:13]``.
    ax0 : matplotlib axes
    cell_type_type : str
        Column holding the cell type.
    cell_types : iterable of str
    frame : pandas.DataFrame
    color_given : colour or None
        Used for every cell type when given; otherwise ``colors_overview()[str(cell_type)]``.
    cv, fr : str
        Base names of the CV and rate columns.

    Returns
    -------
    pandas.DataFrame
        The selected rows of the last cell type (empty if ``cell_types`` is empty).
    """
    colors = colors_overview()
    frame_g = frame.iloc[0:0]
    x_col = fr + add
    y_col = cv + add
    for c, cell_type in enumerate(cell_types):
        print(cell_type)
        frame_g = frame[(frame[cell_type_type] == cell_type) & ((frame.gwn == True) | (frame.fs == True))]
        # Resolve the colour per type. The old code overwrote color_given on the first
        # iteration, so every later type was drawn in the first type's colour.
        color_here = color_given if color_given else colors[str(cell_type)]
        try:
            ax0.scatter(frame_g[x_col], frame_g[y_col], alpha=alpha, label=cell_type, s=s,
                        color=color_here, clip_on=True)
        except Exception:
            print('scatter thing')
            embed()
        print('mean(' + str(x_col) + str(np.mean(frame_g[x_col])) + ' ' + 'mean(' + str(y_col) + str(
            np.mean(frame_g[y_col])))
        c_axis, x_axis, y_axis, exclude_here = exclude_nans_for_corr(frame_g, y_col, cv_name=x_col,
                                                                     score=y_col)
        try:
            legend_wo_dot(ax0, 0.9 - 0.1 * c, x_axis, y_axis, ha='left', color=color_here, x_pos=0)
        except Exception:
            print('something')
            embed()
        # Annotate with the same columns that were scattered (previously hard-coded 'cv'/'fr').
        exclude = np.isnan(frame_g[y_col]) | np.isnan(frame_g[x_col])
        frame_g_ex = frame_g[~exclude]
        if annotate:
            for f in range(len(frame_g_ex)):
                ax0.text(frame_g_ex[x_col].iloc[f], frame_g_ex[y_col].iloc[f], frame_g_ex.cell.iloc[f][2:13],
                         rotation=45,
                         color=color_here, fontsize=6)
    return frame_g
def plt_all_scatter(ax0, ax1, frame, cell_types, colors, add='', alpha=0.5, s=7, annotate=False,
                    cell_type_type='cell_type_info'):
    """Scatter the population: CV vs baseline rate on ``ax0``, CV vs VS on ``ax1``.

    Only rows with a ``gwn`` or ``fs`` recording are drawn, coloured by
    ``colors[str(cell_type)]``. With ``annotate``, every point with finite coordinates is
    labelled with ``cell[2:13]``. ``add`` is the column suffix. If it contains ``'burst'``,
    the axis labels are marked burst corrected. The legend is drawn only for ``add == ''``.
    """
    key_cv = 'cv' + add
    key_fr = 'fr' + add
    key_vs = 'vs' + add

    def _label_points(axis, rows, x_key, y_key, tint):
        # Label only points where both coordinates are finite.
        finite = rows[~(np.isnan(rows[x_key]) | np.isnan(rows[y_key]))]
        if not annotate:
            return
        for x_val, y_val, name in zip(finite[x_key], finite[y_key], finite.cell):
            axis.text(x_val, y_val, name[2:13], rotation=45, color=tint, fontsize=6)

    for cell_type in cell_types:
        frame_g = frame[(frame[cell_type_type] == cell_type) & ((frame.gwn == True) | (frame.fs == True))]
        tint = colors[str(cell_type)]
        ax0.scatter(frame_g[key_cv], frame_g[key_fr], alpha=alpha, label=cell_type, s=s, color=tint)
        _label_points(ax0, frame_g, key_cv, key_fr, tint)
        _label_points(ax1, frame_g, key_cv, key_vs, tint)
        ax0.set_xlim(0, 1.5)
        ax1.scatter(frame_g[key_cv], frame_g[key_vs], alpha=alpha, label=cell_type, s=s, color=tint)
    ax1.set_xlim(0, 1.5)
    ax1.set_ylim(0, 1)
    if 'burst' in add:
        ax0.set_xlabel('$CV_{Burst Corr}$')
        ax1.set_xlabel('$CV_{Burst Corr}$')
        ax1.set_ylabel('$VS_{Burst Corr}$')
        ax0.set_ylabel('Base Freq [Hz] $_{Burst Corr}$')
    else:
        ax0.set_xlabel('CV')
        ax1.set_xlabel('CV')
        ax0.set_ylabel('Base Freq [Hz]')
        ax1.set_ylabel('VS')
    if add == '':
        ax0.legend(ncol=5, loc=(0, 1.05))
    plt.subplots_adjust(wspace=0.25, bottom=0.1)
def plt_all_width_rotated(frame, cell_types, frame_cell, add, gg, cell_type, ax2, annotate=False, alpha=1, xlim=[0, 25],
                          s=15):
    """Scatter peak width at 75 % (x) against CV (y) for the population and one highlighted cell.

    Parameters
    ----------
    frame : pandas.DataFrame
        Population table. Rows with a ``gwn`` or ``fs`` recording are drawn, grouped by
        ``cell_type_reclassified``.
    cell_types : iterable of str
    frame_cell : pandas.DataFrame
        Rows of the highlighted cell, drawn as open markers when the width column exists.
    add : sequence of str
        Column suffixes; ``add[gg]`` is used.
    gg : int
    cell_type : str
        Type of the highlighted cell (colour and label).
    ax2 : matplotlib axes
    annotate : bool
        Label points with finite CV and width with ``cell[2:13]``.
    alpha, s : scatter opacity and marker size of the population.
    xlim : sequence
        x-axis limits (not mutated).
    """
    colors = colors_overview()
    cv_col = 'cv' + add[gg]
    width_col = 'width_75' + add[gg]
    if width_col in frame_cell.keys():
        # Highlight on the same axes as the population (x = width, y = CV);
        # previously the width was plotted against itself.
        ax2.scatter(frame_cell[width_col], frame_cell[cv_col], alpha=1, label=cell_type, s=15,
                    color=colors[str(cell_type)], facecolor='white')
    for cell_type_it in cell_types:
        frame_g = frame[(frame['cell_type_reclassified'] == cell_type_it) & ((frame.gwn == True) | (frame.fs == True))]
        ax2.scatter(frame_g[width_col], frame_g[cv_col], alpha=alpha, label=cell_type_it, s=s,
                    color=colors[str(cell_type_it)])
        exclude = np.isnan(frame_g[cv_col]) | np.isnan(frame_g[width_col])
        frame_g_ex = frame_g[~exclude]
        if annotate:
            for f in range(len(frame_g_ex)):
                ax2.text(frame_g_ex[width_col].iloc[f], frame_g_ex[cv_col].iloc[f],
                         frame_g_ex.cell.iloc[f][2:13], rotation=45,
                         color=colors[str(cell_type_it)], fontsize=6)
    ax2.set_ylim(0, 1.5)
    if 'burst' in add[gg]:
        ax2.set_ylabel('$CV_{Burst Corr}$')
    else:
        ax2.set_ylabel('CV')
    ax2.set_xlabel('Width at 75 %')
    ax2.set_xlim(xlim)
def plt_all_width(frame, cell_types, frame_cell, add, gg, colors, cell_type, ax2, annotate=False, alpha=1, s=15):
    """Scatter CV (x) against peak width at 75 % (y) for the population and one highlighted cell.

    Parameters
    ----------
    frame : pandas.DataFrame
        Population table. Rows with a ``gwn`` or ``fs`` recording are drawn, grouped by
        ``cell_type_reclassified``.
    cell_types : iterable of str
    frame_cell : pandas.DataFrame
        Rows of the highlighted cell, drawn as open markers when the width column exists.
    add : sequence of str
        Column suffixes; ``add[gg]`` is used.
    gg : int
    colors : dict
        Maps ``str(cell_type)`` to a colour.
    cell_type : str
        Type of the highlighted cell (colour and label).
    ax2 : matplotlib axes
    annotate : bool
        Label points with finite CV and width with ``cell[2:13]``.
    alpha, s : scatter opacity and marker size of the population.
    """
    cv_col = 'cv' + add[gg]
    width_col = 'width_75' + add[gg]
    if width_col in frame_cell.keys():
        # Highlight on the same axes as the population (x = CV, y = width);
        # previously the width was plotted against itself.
        ax2.scatter(frame_cell[cv_col], frame_cell[width_col], alpha=1, label=cell_type, s=15,
                    color=colors[str(cell_type)], facecolor='white')
    for cell_type_it in cell_types:
        frame_g = frame[(frame['cell_type_reclassified'] == cell_type_it) & ((frame.gwn == True) | (frame.fs == True))]
        ax2.scatter(frame_g[cv_col], frame_g[width_col], alpha=alpha, label=cell_type_it, s=s,
                    color=colors[str(cell_type_it)])
        exclude = np.isnan(frame_g[cv_col]) | np.isnan(frame_g[width_col])
        frame_g_ex = frame_g[~exclude]
        if annotate:
            for f in range(len(frame_g_ex)):
                ax2.text(frame_g_ex[cv_col].iloc[f], frame_g_ex[width_col].iloc[f],
                         frame_g_ex.cell.iloc[f][2:13], rotation=45,
                         color=colors[str(cell_type_it)], fontsize=6)
        ax2.set_xlim(0, 0.9)
    if 'burst' in add[gg]:
        ax2.set_xlabel('$CV_{Burst Corr}$')
        ax2.set_ylabel('Width at 75 % $_{Burst Corr}$')
    else:
        ax2.set_ylabel('Width at 75 %')
        ax2.set_xlabel('CV')
    ax2.set_ylim(0, 25)
def plt_scatter_three2(grid2, frame, cell_type_type, annotate, colors, cell_types=[' P-unit', ' Ampullary'],
                       add=['', '_burst_corr_individual']):
    """Create the three population scatter panels in ``grid2[0..2]``.

    Panel 0 shows CV vs rate, panel 1 CV vs VS (both with suffix ``add[0]``), and
    panel 2 CV vs rate with suffix ``add[1]`` (burst corrected). Only cells that have a
    ``gwn`` or ``fs`` recording are drawn.
    Returns the three axes.
    """
    # The same rows are used on every panel: cells of each type with a gwn or fs recording.
    # NOTE: more recordings per cell could be included here for the figure.
    per_type = [
        (cell_type_it,
         frame[(frame[cell_type_type] == cell_type_it) & ((frame.gwn == True) | (frame.fs == True))])
        for cell_type_it in cell_types
    ]
    ax0 = plt.subplot(grid2[0])
    for cell_type_it, frame_g in per_type:
        plt_cv_fr(annotate, ax0, add[0], frame_g, colors, cell_type_it)
    ax1 = plt.subplot(grid2[1])
    for cell_type_it, frame_g in per_type:
        plt_cv_vs(frame_g, ax1, add[0], annotate, colors, cell_type_it)
    ax2 = plt.subplot(grid2[2])
    for cell_type_it, frame_g in per_type:
        plt_cv_fr(annotate, ax2, add[1], frame_g, colors, cell_type_it)
    ax2.set_ylabel('Base Freq [Hz] $_{Burst Corr}$')
    ax2.set_xlabel('$CV_{Burst Corr}$')
    return ax0, ax1, ax2
def plt_cell_body_single_amp(grid1, ax0, ax1, ax2, frame, colors, amps_desired, save_names, cells_plot, cell_type_type,
                             ax3=[]):
    """Per cell: draw the susceptibility square with PSD traces, highlight the cell in scatters, and plot its ISI histograms.

    For every cell in ``cells_plot``, the baseline spikes are loaded from ``frame`` via
    ``load_spikes``. The stack ``<calc_RAM folder>/<save_names[0]>_<cell>.pkl`` is read. For
    each amplitude of ``amps_desired`` present in the first file of that stack, the square
    (``plot_square_core``) and the traces (``plt_psd_traces``) go into a 5-row sub-grid of
    ``grid1[c]``. The cell is highlighted in ``ax0`` (CV vs fr), ``ax1`` (CV vs VS) and
    ``ax2`` (burst-corrected CV vs fr); an axis passed as a list is skipped. If ``ax3`` is
    given, the cell is also highlighted in the CV vs CV_stim scatter. The ISI histograms go
    into the last row of the sub-grid.

    Parameters
    ----------
    grid1 : gridspec-like
        Indexed with the cell position ``c``.
    ax0, ax1, ax2 : axes or list
        Scatter axes; a list disables the highlight.
    frame : pandas.DataFrame
        Cell table; needs ``cell``, ``spikes``, ``fr``, ``cv``, ``vs``, ``EODf``, the
        ``_burst_corr`` columns and ``cell_type_type``.
    colors : dict
        Maps ``str(cell_type)`` to a colour.
    amps_desired : iterable
        Amplitudes to plot.
    save_names : list of str
        ``save_names[0]`` is the stack file prefix.
    cells_plot : iterable of str
    cell_type_type : str
        Column holding the cell type.
    ax3 : axes or list, optional
        CV vs CV_stim scatter; ``[]`` disables it.
    """
    for c, cell in enumerate(cells_plot):
        print(cell)
        frame_cell = frame[(frame['cell'] == cell)]
        frame_cell = unify_cell_names(frame_cell, cell_type=cell_type_type)
        cell_type = frame_cell[cell_type_type].iloc[0]
        spikes = frame_cell.spikes.iloc[0]
        fr = frame_cell.fr.iloc[0]
        cv = frame_cell.cv.iloc[0]
        eod_fr = frame_cell.EODf.iloc[0]
        spikes_all, hists, frs_calc, cont_spikes = load_spikes(spikes, eod_fr)
        # cont_spikes relates to whether the spikes are NaN or empty; such cells have nothing to show
        # (not even a scatter) and would only give blank figures.
        # NOTE(review): the branch below runs when cont_spikes is truthy, so presumably True means
        # "spikes usable" — confirm in load_spikes.
        if cont_spikes:
            # These two checks make sure there are more than a few spikes; without them no figure is needed.
            if len(hists) > 0:
                if len(np.concatenate(hists)) > 0:
                    # ISIs below 1.5 (presumably in EOD periods) indicate bursts, so also compute a burst-corrected version.
                    if np.min(np.concatenate(hists)) < 1.5:
                        hists2, spikes_ex, frs_calc2 = correct_burstiness(hists, spikes_all, [eod_fr] * len(spikes_all),
                                                                          [eod_fr] * len(spikes_all))
                        hists_both = [hists, hists2]
                    else:
                        hists_both = [hists]
                    # This title remains in case the square does not get plotted.
                    plt.suptitle(cell + ' EODf ' + str(np.round(eod_fr)) + ' Hz ' + ' % ' +
                                 ' Base: cv ' + str(np.round(cv, 2)) + ' fr ' + str(
                        np.round(fr)) + ' Hz',
                                 fontsize=11, )  # cell[0:13] + color=color+ cell_type
                    load_name = load_folder_name('calc_RAM') + '/' + save_names[0] + '_' + cell
                    if os.path.exists(load_name + '.pkl'):
                        stack = pd.read_pickle(load_name + '.pkl')
                        file_names_exclude = file_names_to_exclude(cell_type)
                        files = stack['file_name'].unique()
                        # File exclusion is currently switched off.
                        fexclude = False
                        if fexclude:
                            if len(files) > 1:
                                stack = stack[~stack['file_name'].isin(file_names_exclude)]
                        files = stack['file_name'].unique()
                        amps = stack['amp'].unique()
                        _, _ = find_row_col(np.arange(len(amps) * len(files)))
                        predefined_amp = True
                        if predefined_amp:
                            amps_defined = amps_desired
                        else:
                            amps_defined = amps
                        # Only the first file of the stack is plotted.
                        stack_file = stack[stack['file_name'] == files[0]]
                        amps = stack_file['amp'].unique()
                        for a, amp in enumerate(amps_defined):
                            if amp in np.array(stack_file['amp']):
                                stack_amp = stack_file[stack_file['amp'] == amp]
                                lengths = stack_file['stimulus_length'].unique()
                                length = np.max(lengths)
                                stack_final = stack_amp[stack_amp['stimulus_length'] == length]
                                trial_nr_double = stack_final.trial_nr.unique()
                                # More than one trial number here is probably a sign of an error.
                                if len(trial_nr_double) > 1:
                                    print('trial_nr_double')
                                    embed()
                                # An earlier save bug stored some entries twice, so keep the last (largest) trial number.
                                try:
                                    stack_final1 = stack_final[stack_final.trial_nr == np.max(trial_nr_double)]
                                except:
                                    print('stack_final1 problem')
                                    embed()
                                try:
                                    # Rows: traces, traces, square, (unused), ISI histogram.
                                    grid_s = gridspec.GridSpecFromSubplotSpec(5, 1, grid1[c],
                                                                              height_ratios=[1.5, 1.5, 5, 1.5, 1.5, ],
                                                                              hspace=0)
                                    axs = plt.subplot(grid_s[2])
                                except:
                                    print('grid problem5')
                                    embed()
                                cbar, mat, im = plot_square_core([axs], stack_final1)
                                # NOTE(review): `a` indexes amps_defined while `amps` holds the file's amplitudes — confirm the label lands on the intended panel.
                                if a == len(amps) - 1:
                                    cbar.set_label(nonlin_title(), rotation=90, labelpad=10)
                                fr = stack_final1.fr.unique()[0]
                                fr_stim = stack_final1.fr_stim.unique()[0]
                                # NOTE(review): passes stack_final (not the de-duplicated stack_final1) — confirm intent.
                                axo, axi = plt_psd_traces(grid_s[0], grid_s[1], axs, np.min(mat.columns),
                                                          np.max(mat.columns), eod_fr, fr, fr_stim, stack_final,
                                                          )
                                if c == 0:
                                    axi.set_title(' std = ' + str(amp) + '$\%$')  # files[0] + ' l ' + str(length)
                                if a != 0:
                                    axi.set_ylabel('')
                        # Highlight this cell in the population scatters.
                        add = ['', '_burst_corr', ]
                        if type(ax0) != list:
                            ax0.scatter(frame_cell['cv' + add[0]], frame_cell['fr' + add[0]], zorder=2, alpha=1,
                                        label=cell_type, s=15,
                                        color=colors[str(cell_type)], facecolor='white')
                        if type(ax1) != list:
                            ax1.scatter(frame_cell['cv' + add[0]], frame_cell['vs' + add[0]], zorder=2, alpha=1,
                                        label=cell_type, s=15,
                                        color=colors[str(cell_type)], facecolor='white')
                        if type(ax2) != list:
                            ax2.scatter(frame_cell['cv' + add[1]], frame_cell['fr' + add[1]], zorder=2, alpha=1,
                                        label=cell_type, s=15,
                                        color=colors[str(cell_type)], facecolor='white')
                        if ax3 != []:
                            frame_g = base_to_stim(load_name, frame, cell_type_type, cell_type)
                            try:
                                ax3.scatter(frame_g['cv'], frame_g['cv_stim'], zorder=2, alpha=1,
                                            label=cell_type, s=15,
                                            color=colors[str(cell_type)], facecolor='white')
                            except:
                                print('scatter problem')
                                embed()
                        ################################
                        # ISI histogram(s) in the last row of the sub-grid.
                        # NOTE(review): grid_s is only bound if at least one amplitude matched above.
                        alpha = 1
                        axi = plt.subplot(grid_s[-1])
                        if len(hists_both) > 1:
                            colors_hist = ['grey', colors[str(cell_type)]]
                        else:
                            colors_hist = [colors[str(cell_type)]]
                        try:
                            for gg in range(len(hists_both)):
                                hists_here = hists_both[gg]
                                for hh, h in enumerate(hists_here):
                                    try:
                                        # Each further trial is drawn 0.05 more transparent.
                                        axi.hist(h, bins=100, color=colors_hist[gg], alpha=float(alpha - 0.05 * hh))
                                    except:
                                        print('alpha problem5')
                                    axi.set_title(
                                        'CV ' + str(np.round(np.std(h) / np.mean(h), 3)) + ' ' + cell)  # +' VS '+str(vs)
                            axi.set_xlabel('isi')
                        except:
                            print('hists not there yet')
def plt_cell_body3(grid1, ax0, ax1, ax2, frame, colors, amps_desired, save_names, cells_plot, cell_type_type, ax3=[],
                   xlim=[]):
    """Per cell (one grid row): ISI histogram in column 0 and one square with PSD traces per amplitude.

    For each cell ``c`` in ``cells_plot``, the stack ``<calc_RAM folder>/<save_names[0]>_<cell>.pkl``
    is loaded. Each amplitude of ``amps_desired`` present in its first file gets a 3-row
    sub-grid in ``grid1[c, a + 1]`` (two trace rows from ``plt_psd_traces``, then the square
    from ``plot_square_core``). The cell is highlighted in ``ax0`` (CV vs fr), ``ax1``
    (CV vs VS), ``ax2`` (burst-corrected CV vs fr) and optionally ``ax3`` (CV vs CV_stim).
    Its ISI histograms are drawn into ``grid1[c, 0]``.

    Parameters
    ----------
    grid1 : gridspec-like
        Indexed as ``grid1[row, column]``.
    ax0, ax1, ax2 : matplotlib axes
        Scatter axes.
    frame : pandas.DataFrame
        Cell table; needs ``cell``, ``spikes``, ``EODf``, ``cv``, ``fr``, ``vs``, the
        ``_burst_corr`` columns and ``cell_type_type``.
    colors : dict
        Maps ``str(cell_type)`` to a colour.
    amps_desired : iterable
        Amplitudes to plot.
    save_names : list of str
        ``save_names[0]`` is the stack file prefix.
    cells_plot : sequence of str
    cell_type_type : str
        Column holding the cell type.
    ax3 : axes or list, optional
        CV vs CV_stim scatter; ``[]`` disables it.
    xlim : sequence, optional
        When non-empty, used as both x and y limits of each square.
    """
    for c, cell in enumerate(cells_plot):
        print(cell)
        frame_cell = frame[(frame['cell'] == cell)]
        frame_cell = unify_cell_names(frame_cell, cell_type=cell_type_type)
        try:
            cell_type = frame_cell[cell_type_type].iloc[0]
        except:
            embed()
        spikes = frame_cell.spikes.iloc[0]
        eod_fr = frame_cell.EODf.iloc[0]
        spikes_all, hists, frs_calc, cont_spikes = load_spikes(spikes, eod_fr)
        # cont_spikes relates to whether the spikes are NaN or empty; such cells have nothing to show
        # (not even a scatter) and would only give blank figures.
        # NOTE(review): the branch below runs when cont_spikes is truthy — confirm its meaning in load_spikes.
        if cont_spikes:
            # These two checks make sure there are more than a few spikes; without them no figure is needed.
            if len(hists) > 0:
                if len(np.concatenate(hists)) > 0:
                    if np.min(np.concatenate(hists)) < 1.5:
                        hists2, spikes_ex, frs_calc2 = correct_burstiness(hists, spikes_all, [eod_fr] * len(spikes_all),
                                                                          [eod_fr] * len(spikes_all))
                        hists_both = [hists, hists2]
                    else:
                        hists_both = [hists]
                    # (Leftover note: "title in case the square is not plotted" — no suptitle is set here.)
                    load_name = load_folder_name('calc_RAM') + '/' + save_names[0] + '_' + cell
                    if os.path.exists(load_name + '.pkl'):
                        stack = pd.read_pickle(load_name + '.pkl')
                        file_names_exclude = file_names_to_exclude(cell_type)
                        files = stack['file_name'].unique()
                        # File exclusion is currently switched off.
                        fexclude = False
                        if fexclude:
                            if len(files) > 1:
                                stack = stack[~stack['file_name'].isin(file_names_exclude)]
                        files = stack['file_name'].unique()
                        amps = stack['amp'].unique()
                        _, _ = find_row_col(np.arange(len(amps) * len(files)))
                        predefined_amp = True
                        if predefined_amp:
                            amps_defined = amps_desired
                        else:
                            amps_defined = amps
                        # Only the first file of the stack is plotted.
                        stack_file = stack[stack['file_name'] == files[0]]
                        amps = stack_file['amp'].unique()
                        for a, amp in enumerate(amps_defined):
                            if amp in np.array(stack_file['amp']):
                                stack_amp = stack_file[stack_file['amp'] == amp]
                                lengths = stack_file['stimulus_length'].unique()
                                length = np.max(lengths)
                                stack_final = stack_amp[stack_amp['stimulus_length'] == length]
                                trial_nr_double = stack_final.trial_nr.unique()
                                # More than one trial number here is probably a sign of an error.
                                if len(trial_nr_double) > 1:
                                    print('trial_nr_double')
                                    embed()
                                # An earlier save bug stored some entries twice, so keep the last (largest) trial number.
                                try:
                                    stack_final1 = stack_final[stack_final.trial_nr == np.max(trial_nr_double)]
                                except:
                                    print('stack_final1 problem')
                                    embed()
                                try:
                                    # Column 0 holds the histogram, so amplitude a goes to column a + 1.
                                    grid_s = gridspec.GridSpecFromSubplotSpec(3, 1, grid1[c, a + 1],
                                                                              height_ratios=[1.5, 1.5, 5],
                                                                              hspace=0)
                                    axs = plt.subplot(grid_s[2])
                                except:
                                    print('grid problem4')
                                    embed()
                                cbar, mat, im = plot_square_core([axs], stack_final1)
                                if xlim:
                                    axs.set_xlim(xlim)
                                    axs.set_ylim(xlim)
                                # NOTE(review): `a` indexes amps_defined while `amps` holds the file's amplitudes — confirm.
                                if a == len(amps) - 1:
                                    cbar.set_label(nonlin_title(), rotation=90, labelpad=10)
                                fr = stack_final1.fr.unique()[0]
                                fr_stim = stack_final1.fr_stim.unique()[0]
                                axo, axi = plt_psd_traces(grid_s[0], grid_s[1], axs, np.min(mat.columns),
                                                          np.max(mat.columns), eod_fr, fr, fr_stim, stack_final,
                                                          )
                                if c == 0:
                                    axo.set_title(' $std=$' + str(amp) + ' %')  # files[0] + ' l ' + str(length)
                                # Only the first amplitude column keeps y labels.
                                if a != 0:
                                    axi.set_ylabel('')
                                    axo.set_ylabel('')
                                    axs.set_ylabel('')
                                # NOTE(review): hard-coded 2 assumes three cells per figure.
                                if c != 2:
                                    axs.set_xlabel('')
                                    remove_xticks(axi)
                                if a == 1:
                                    axo.set_title(cell)  # +' VS '+str(vs)
                        ################################
                        # Highlight this cell in the population scatters.
                        add = ['', '_burst_corr', ]
                        ax0.scatter(frame_cell['cv' + add[0]], frame_cell['fr' + add[0]], zorder=2, alpha=1,
                                    label=cell_type, s=15,
                                    color=colors[str(cell_type)], facecolor='white')
                        ax1.scatter(frame_cell['cv' + add[0]], frame_cell['vs' + add[0]], zorder=2, alpha=1,
                                    label=cell_type, s=15,
                                    color=colors[str(cell_type)], facecolor='white')
                        ax2.scatter(frame_cell['cv' + add[1]], frame_cell['fr' + add[1]], zorder=2, alpha=1,
                                    label=cell_type, s=15,
                                    color=colors[str(cell_type)], facecolor='white')
                        if ax3 != []:
                            frame_g = base_to_stim(load_name, frame, cell_type_type, cell_type)
                            try:
                                ax3.scatter(frame_g['cv'], frame_g['cv_stim'], zorder=2, alpha=1,
                                            label=cell_type, s=15,
                                            color=colors[str(cell_type)], facecolor='white')
                            except:
                                print('scatter problem')
                                embed()
                        ################################
                        # ISI histogram(s) in column 0 of this cell's row.
                        alpha = 1
                        axi = plt.subplot(grid1[c, 0])
                        if len(hists_both) > 1:
                            colors_hist = ['grey', colors[str(cell_type)]]
                        else:
                            colors_hist = [colors[str(cell_type)]]
                        for gg in range(len(hists_both)):
                            hists_here = hists_both[gg]
                            for hh, h in enumerate(hists_here):
                                try:
                                    # Each further trial is drawn 0.05 more transparent.
                                    axi.hist(h, bins=100, color=colors_hist[gg], alpha=float(alpha - 0.05 * hh),
                                             label='CV ' + str(
                                                 np.round(np.std(h) / np.mean(h), 3)))
                                except:
                                    print('alpha problem6')
                                    embed()
                        axi.legend()
                        axi.set_xlim(0, 13)
                        # Only the bottom row keeps x ticks and the x label.
                        if c != len(cells_plot) - 1:
                            remove_xticks(axi)
                        else:
                            axi.set_xlabel('isi')
def plt_cell_body(grid1, ax0, ax1, ax2, frame, colors, amps_desired, save_names, cells_plot, cell_type_type, ax3=[],
                  xlim=[]):
    """Per cell (one grid row): one susceptibility square with PSD traces per amplitude, plus scatter highlights.

    For each cell ``c`` in ``cells_plot``, the stack ``<calc_RAM folder>/<save_names[0]>_<cell>.pkl``
    is loaded. Each amplitude of ``amps_desired`` present in its first file gets a 3-row
    sub-grid in ``grid1[c, a]`` (two trace rows from ``plt_psd_traces``, then the square
    from ``plot_square_core``). The cell is highlighted in ``ax0`` (CV vs fr), ``ax1``
    (CV vs VS), ``ax2`` (burst-corrected CV vs fr) and optionally ``ax3`` (CV vs CV_stim).

    Parameters
    ----------
    grid1 : gridspec-like
        Indexed as ``grid1[row, column]``.
    ax0, ax1, ax2 : matplotlib axes
        Scatter axes.
    frame : pandas.DataFrame
        Cell table; needs ``cell``, ``spikes``, ``EODf``, ``cv``, ``fr``, ``vs``, the
        ``_burst_corr`` columns and ``cell_type_type``.
    colors : dict
        Maps ``str(cell_type)`` to a colour.
    amps_desired : iterable
        Amplitudes to plot.
    save_names : list of str
        ``save_names[0]`` is the stack file prefix.
    cells_plot : sequence of str
    cell_type_type : str
        Column holding the cell type.
    ax3 : axes or list, optional
        CV vs CV_stim scatter; ``[]`` disables it.
    xlim : sequence, optional
        When non-empty, used as both x and y limits of each square.
    """
    for c, cell in enumerate(cells_plot):
        print(cell)
        frame_cell = frame[(frame['cell'] == cell)]
        frame_cell = unify_cell_names(frame_cell, cell_type=cell_type_type)
        try:
            cell_type = frame_cell[cell_type_type].iloc[0]
        except:
            embed()
        spikes = frame_cell.spikes.iloc[0]
        eod_fr = frame_cell.EODf.iloc[0]
        spikes_all, hists, frs_calc, cont_spikes = load_spikes(spikes, eod_fr)
        # cont_spikes relates to whether the spikes are NaN or empty; such cells have nothing to show
        # (not even a scatter) and would only give blank figures.
        # NOTE(review): the branch below runs when cont_spikes is truthy — confirm its meaning in load_spikes.
        if cont_spikes:
            # These two checks make sure there are more than a few spikes; without them no figure is needed.
            if len(hists) > 0:
                if len(np.concatenate(hists)) > 0:
                    # The burst correction result is discarded here (no histogram is drawn in this variant).
                    if np.min(np.concatenate(hists)) < 1.5:
                        _, _, _ = correct_burstiness(hists, spikes_all, [eod_fr] * len(spikes_all),
                                                     [eod_fr] * len(spikes_all))
                    else:
                        pass
                    # (Leftover note: "title in case the square is not plotted" — no suptitle is set here.)
                    load_name = load_folder_name('calc_RAM') + '/' + save_names[0] + '_' + cell
                    if os.path.exists(load_name + '.pkl'):
                        stack = pd.read_pickle(load_name + '.pkl')
                        if 'p-unit' in cell_type:  # == ['p-unit', ' P-unit']:
                            file_names_exclude = punit_file_exclude()  #
                        else:
                            file_names_exclude = ampullary_file_exclude()  #
                        files = stack['file_name'].unique()
                        # File exclusion is currently switched off.
                        fexclude = False
                        if fexclude:
                            if len(files) > 1:
                                stack = stack[~stack['file_name'].isin(file_names_exclude)]
                        files = stack['file_name'].unique()
                        amps = stack['amp'].unique()
                        _, _ = find_row_col(np.arange(len(amps) * len(files)))
                        predefined_amp = True
                        if predefined_amp:
                            amps_defined = amps_desired
                        else:
                            amps_defined = amps
                        # Only the first file of the stack is plotted.
                        stack_file = stack[stack['file_name'] == files[0]]
                        amps = stack_file['amp'].unique()
                        for a, amp in enumerate(amps_defined):
                            if amp in np.array(stack_file['amp']):
                                stack_amp = stack_file[stack_file['amp'] == amp]
                                lengths = stack_file['stimulus_length'].unique()
                                length = np.max(lengths)
                                stack_final = stack_amp[stack_amp['stimulus_length'] == length]
                                trial_nr_double = stack_final.trial_nr.unique()
                                # More than one trial number here is probably a sign of an error.
                                if len(trial_nr_double) > 1:
                                    print('trial_nr_double')
                                    embed()
                                # An earlier save bug stored some entries twice, so keep the last (largest) trial number.
                                try:
                                    stack_final1 = stack_final[stack_final.trial_nr == np.max(trial_nr_double)]
                                except:
                                    print('stack_final1 problem')
                                    embed()
                                try:
                                    grid_s = gridspec.GridSpecFromSubplotSpec(3, 1, grid1[c, a],
                                                                              height_ratios=[1.5, 1.5, 5],
                                                                              hspace=0)
                                    axs = plt.subplot(grid_s[2])
                                except:
                                    print('grid problem3')
                                    embed()
                                cbar, mat, im = plot_square_core([axs], stack_final1)
                                if xlim:
                                    axs.set_xlim(xlim)
                                    axs.set_ylim(xlim)
                                # NOTE(review): `a` indexes amps_defined while `amps` holds the file's amplitudes — confirm.
                                if a == len(amps) - 1:
                                    cbar.set_label(nonlin_title(), rotation=90, labelpad=10)
                                # Of these statistics only fr and fr_stim are used below.
                                fr = stack_final1.fr.unique()[0]
                                snippets = stack_final1['snippets'].unique()[0]
                                fr1 = np.unique(stack_final1.fr)
                                cv = stack_final1.cv.unique()[0]
                                ser = stack_final1.ser.unique()[0]
                                cv_stim = stack_final1.cv_stim.unique()[0]
                                fr_stim = stack_final1.fr_stim.unique()[0]
                                ser_stim = stack_final1.ser_stim.unique()[0]
                                axo, axi = plt_psd_traces(grid_s[0], grid_s[1], axs, np.min(mat.columns),
                                                          np.max(mat.columns), eod_fr, fr, fr_stim, stack_final,
                                                          )
                                if c == 0:
                                    axo.set_title(' $std=$' + str(amp) + ' %')  # files[0] + ' l ' + str(length)
                                # Only the first amplitude column keeps y labels.
                                if a != 0:
                                    axi.set_ylabel('')
                                    axo.set_ylabel('')
                                    axs.set_ylabel('')
                                # NOTE(review): hard-coded 2 assumes three cells per figure.
                                if c != 2:
                                    axs.set_xlabel('')
                                    remove_xticks(axi)
                                if a == 1:
                                    axo.set_title(cell)  # +' VS '+str(vs)
                        ################################
                        # Highlight this cell in the population scatters.
                        add = ['', '_burst_corr', ]
                        ax0.scatter(frame_cell['cv' + add[0]], frame_cell['fr' + add[0]], zorder=2, alpha=1,
                                    label=cell_type, s=15,
                                    color=colors[str(cell_type)], facecolor='white')
                        ax1.scatter(frame_cell['cv' + add[0]], frame_cell['vs' + add[0]], zorder=2, alpha=1,
                                    label=cell_type, s=15,
                                    color=colors[str(cell_type)], facecolor='white')
                        ax2.scatter(frame_cell['cv' + add[1]], frame_cell['fr' + add[1]], zorder=2, alpha=1,
                                    label=cell_type, s=15,
                                    color=colors[str(cell_type)], facecolor='white')
                        if ax3 != []:
                            frame_g = base_to_stim(load_name, frame, cell_type_type, cell_type)
                            try:
                                ax3.scatter(frame_g['cv'], frame_g['cv_stim'], zorder=2, alpha=1,
                                            label=cell_type, s=15,
                                            color=colors[str(cell_type)], facecolor='white')
                            except:
                                print('scatter problem')
                                embed()
def base_to_stim(load_name, frame, cell_type_type, cell_type_it, stack=None):
    """Return one stimulus-frame row per cell of a given cell type.

    Parameters
    ----------
    load_name : str
        Path of a CSV file holding the stimulus frame. It is read only when
        *stack* is None or empty.
    frame : pandas.DataFrame
        Frame with a ``cell`` column and the column named *cell_type_type*.
    cell_type_type : str
        Name of the column in *frame* that holds the cell-type labels.
    cell_type_it : str
        Cell type whose cells are selected.
    stack : pandas.DataFrame, optional
        Pre-loaded stimulus frame with a ``cell`` column. When given and
        non-empty it is used instead of reading *load_name*.

    Returns
    -------
    pandas.DataFrame
        Rows of the stimulus frame belonging to the selected cells, keeping
        only the first row of every cell.

    Raises
    ------
    FileNotFoundError
        If no stack is given and *load_name* does not exist.
    """
    if stack is None or len(stack) == 0:
        if not os.path.exists(load_name):
            # Previously this path ended in an AttributeError on an empty list.
            raise FileNotFoundError(load_name)
        stack_stim = pd.read_csv(load_name, low_memory=False)
    else:
        # Previously the `else` belonged to the file check, so a provided
        # stack was never used (UnboundLocalError).
        stack_stim = stack
    cells = frame[frame[cell_type_type] == cell_type_it].cell.unique()
    frame_gr = stack_stim[stack_stim.cell.isin(cells)]
    # Keep the first row of every cell.
    first_rows = frame_gr['cell'].drop_duplicates().index
    return frame_gr.loc[first_rows]
def plt_cv_vs(frame_g, ax1, add, annotate, colors, cell_type):
    """Scatter CV against vector strength (VS) for one cell type on *ax1*.

    Columns ``'cv' + add`` and ``'vs' + add`` of *frame_g* are plotted in the
    colour ``colors[str(cell_type)]``. If *annotate* is true, every point
    without a NaN in either column is labelled with characters 2..12 of its
    ``cell`` name. The axes are fixed to CV 0..1.5 and VS 0..1.
    """
    cv_key = 'cv' + add
    vs_key = 'vs' + add
    color = colors[str(cell_type)]
    ax1.scatter(frame_g[cv_key], frame_g[vs_key], alpha=0.5, label=cell_type, s=7, color=color)
    has_values = ~(np.isnan(frame_g[cv_key]) | np.isnan(frame_g[vs_key]))
    labelled = frame_g[has_values]
    if annotate:
        for x_val, y_val, name in zip(labelled[cv_key], labelled[vs_key], labelled.cell):
            ax1.text(x_val, y_val, name[2:13], rotation=45, color=color, fontsize=6)
    ax1.set_xlim(0, 1.5)
    ax1.set_ylim(0, 1)
    ax1.set_xlabel('CV')
    ax1.set_ylabel('VS')
def plt_data_up(cell, ax, fig, cells_chosen, cell_type='p-unit', width=0.005, cbar_label=True):
    """Plot the measured second-order susceptibility matrix of one recorded cell.

    The per-cell noise results are loaded from ``calc_RAM``. The save name is
    ``noise_data9`` for '2012-07-03-ak-invivo-1' and ``noise_data8`` otherwise.

    Stimulus files listed by ``get_file_names_exclude(cell_type)`` are dropped.
    Of the remaining files, the one with the highest cut-off is used. Within it,
    the amplitudes returned by ``restrict_punits`` for its smallest amplitude
    are plotted, each with its longest stimulus.

    The normalised matrix is drawn with ``imshow`` on *ax* (colour range from
    the 5th to the 95th NaN-ignoring percentile), limited to 0..300, and gets
    a colorbar on *fig*.

    Parameters
    ----------
    cell : str
        Name of the recorded cell.
    ax : matplotlib axes
        Axes that receives the matrix.
    fig : matplotlib figure
        Figure that receives the colorbar.
    cells_chosen : sequence of str
        All cells of the panel; only passed to ``find_row_col``.
    cell_type : str
        Selects the excluded stimulus files and is used as figure suptitle.
    width : float
        Colorbar width.
    cbar_label : bool
        Whether the colorbar gets the nonlinearity label.

    Returns
    -------
    list
        ``[ax]``.
    """
    # Same exclusion lists as before, now shared with the other loaders.
    file_names_exclude = get_file_names_exclude(cell_type)
    if len(cells_chosen) > 0:
        # `cells` used to be unbound here (NameError) for an empty selection.
        _, _, = find_row_col(cells_chosen, col=4)
    ax_data = [ax]
    if cell == '2012-07-03-ak-invivo-1':
        save_name = 'noise_data9_nfft1sec_original__LocalEOD_CutatBeginning_0.05_s_NeurDelay_0.005_s'  # _burst_corr
    else:
        save_name = 'noise_data8_nfft1sec_original__LocalEOD_CutatBeginning_0.05_s_NeurDelay_0.005_s'  # _burst_corr
    load_name = load_folder_name('calc_RAM') + '/' + save_name
    # The single-cell results are stored in per-cell pkl files.
    stack_cell = load_data_susept(load_name + '_' + cell + '.pkl', load_name + '_' + cell)
    try:
        stack_cell = stack_cell[~stack_cell['file_name'].isin(file_names_exclude)]
    except Exception:
        print('stack cell problem')
        stack_cell = []
    if len(stack_cell):
        file_names = stack_cell.file_name.unique()
        # Same cut-off parsing as the former inline loop.
        cut_off_nr = get_cutoffs_nr(file_names)
        try:
            maxs = list(map(float, cut_off_nr))
        except (TypeError, ValueError):
            print('maxs something')
            embed()
        file_name = file_names[np.argmax(maxs)]
        stack_file = stack_cell[stack_cell['file_name'] == file_name]
        amps = restrict_punits(cell, [np.min(stack_file.amp.unique())])
        for amp in amps:
            stack_amps = stack_file[stack_file['amp'] == amp]
            lengths = stack_amps.stimulus_length.unique()
            try:
                length_max = [np.max(lengths)]
            except (TypeError, ValueError):
                print('length something')
                embed()
            for length in length_max:
                stack_final = stack_amps[stack_amps['stimulus_length'] == length]
                if len(stack_final) < 1:
                    embed()
                snippets = stack_final['snippets'].unique()[0]
                eod_fr = stack_final.eod_fr.unique()[0]
                cv = stack_final.cv.unique()[0]
                fr = stack_final.fr.unique()[0]
                cv_stim = stack_final.cv_stim.unique()[0]
                fr_stim = stack_final.fr_stim.unique()[0]
                ax.set_title(cell[0:13] + stack_final.celltype.unique()[0] + 'S.Nr ' + str(snippets)
                             + ' EODf ' + str(np.round(eod_fr)) + ' Hz ' + ' std ' + str(amp) + ' % '
                             + '\n $cv_{B}=$' + str(np.round(cv, 2)) + ',$f_{B}=$' + str(np.round(fr)) + 'Hz'
                             + ',$cv_{S}=$' + str(np.round(cv_stim, 2)) + ',$f_{S}=$' + str(np.round(fr_stim))
                             + 'Hz',
                             fontsize=7)
                # Matrix columns are selected by the labels of the row index
                # (presumably the same frequency axis).
                new_keys = stack_final.index
                try:
                    stack_plot = stack_final[new_keys]
                except KeyError:
                    # Columns were stored as strings.
                    new_keys = list(map(str, new_keys))
                    stack_plot = stack_final[new_keys]
                stack_plot = stack_plot.astype(complex)
                stack_plot.columns = list(map(float, stack_plot.columns))
                mat = RAM_norm_data(stack_final['d_isf1'].iloc[0], stack_plot, snippets)
                im = ax.imshow(mat, origin='lower',
                               extent=[float(np.min(mat.columns)), float(np.max(mat.columns)),
                                       float(np.min(mat.index)), float(np.max(mat.index))],
                               vmin=np.nanpercentile(mat, 5),
                               vmax=np.nanpercentile(mat, 95),
                               cmap='viridis')
                plt.suptitle(cell_type)
                ax.set_xlim(0, 300)
                ax.set_ylim(0, 300)
                ax.set_aspect('equal')
                plt_triangle(ax, fr, fr_stim, new_keys[-1], eod_fr)
                plt_50_Hz_noise(ax, new_keys[-1])
                # Do not rebind `width`: later amplitudes need the requested width.
                cbar, _, _, _, _ = colorbar_outside(ax, im, fig, add=0, shrink=0.6, width=width)
                if cbar_label:
                    cbar.set_label(nonlin_title(), rotation=90, labelpad=10)
    return ax_data
def plt_data_susept(fig, grid, cells_chosen, eod_metrice=True, fr_print=False, amp_given=None, nr=3, cell_type='p-unit',
                    xlabel=True, lp=10, title=True, cbar_label=True, xpos=1.1, width=0.005, n_print=True):
    """Plot the measured susceptibility matrix of each chosen cell into consecutive grid slots.

    Parameters
    ----------
    fig : matplotlib figure
        Figure used to place the colorbars.
    grid : gridspec
        Grid whose f-th slot receives the f-th cell.
    cells_chosen : sequence of str
        Cells to plot. An empty sequence plots nothing and returns three empty
        lists (it previously raised a NameError).
    cell_type : str
        Selects the excluded stimulus files via ``get_file_names_exclude``.

    All remaining keyword arguments are forwarded unchanged to
    ``plt_data_suscept_single``.

    Returns
    -------
    ax_data, stack_spikes_all, eod_frs
        One entry per cell: the created axes, and the spike data and EOD
        frequency returned by ``plt_data_suscept_single``.
    """
    file_names_exclude = get_file_names_exclude(cell_type)
    cells = cells_chosen
    ax_data = []
    stack_spikes_all = []
    eod_frs = []
    for f, cell in enumerate(cells):
        ax = plt.subplot(grid[f])
        ax_data.append(ax)
        eod_fr, stack_spikes = plt_data_suscept_single(ax, cbar_label, cell, cells, f, fig, file_names_exclude, lp,
                                                       title, width, fr_print=fr_print, nr=nr,
                                                       eod_metrice=eod_metrice, amp_given=amp_given,
                                                       n_print=n_print, xpos=xpos, xlabel=xlabel)
        stack_spikes_all.append(stack_spikes)
        eod_frs.append(eod_fr)
    return ax_data, stack_spikes_all, eod_frs
def plt_data_suscept_single(ax, cbar_label, cell, cells, f, fig, file_names_exclude, lp, title, width, fr_print=False,
                            eod_metrice=True, xpos=1.1, ypos=1.05, n_print=True, nr=3, xlabel=True, amp_given=None):
    """Plot the measured second-order susceptibility of one recorded cell on *ax*.

    The per-cell results of ``version_final()`` are loaded from ``calc_RAM``.
    Files in *file_names_exclude* and files containing 'short' are dropped.
    The plot uses:

    * the file with the highest cut-off;
    * the amplitudes returned by ``restrict_punits`` for its smallest amplitude;
    * the longest stimulus and the highest ``trial_nr``.

    The rescaled matrix is drawn with ``plt_mat_susept`` (0..300, equal
    aspect), overlaid by ``plt_triangle``, and gets a colorbar on *fig*.

    Parameters
    ----------
    ax, fig : matplotlib axes / figure
        Target axes and the figure for the colorbar.
    cbar_label : bool
        Whether to label the colorbar.
    cell : str
        Name of the recorded cell.
    cells : sequence
        All cells of the panel; only ``len(cells)`` is used.
    f : int
        Position of *cell* in *cells*; the last one keeps its x-ticks.
    file_names_exclude : sequence of str
        Stimulus files to ignore.
    lp : float
        Colorbar label pad.
    title : bool
        Whether to write the 'Recorded P-unit' text.
    width : float
        Colorbar width.
    fr_print, n_print : bool
        Add the stimulus firing rate and CV, respectively the number of
        snippets N, to the text.
    eod_metrice, nr
        Passed to ``plt_triangle``.
    xpos, ypos : float
        Text position in axes coordinates (*xpos* is also used for the x-label).
    xlabel : bool
        Whether the last panel gets an x-label.
    amp_given : optional
        Accepted for backward compatibility; it never influenced the plot.

    Returns
    -------
    eod_fr, stack_spikes
        EOD frequency and spike data of the last plotted amplitude, or two
        empty lists if no data is available.
    """
    eod_fr = []
    stack_spikes = []
    save_name = version_final()
    load_name = load_folder_name('calc_RAM') + '/' + save_name
    # The single-cell results are stored per cell; `add` is the data key.
    add = '_cell' + cell
    stack_cell = load_data_susept(load_name + '_' + cell + '.pkl', load_name + '_' + cell, add=add,
                                  load_version='csv')
    try:
        stack_cell = stack_cell[~stack_cell['file_name'].isin(file_names_exclude)]
    except Exception:
        print('stack cell problem')
        stack_cell = []
    if not len(stack_cell):
        return eod_fr, stack_spikes
    file_names = exclude_file_name_short(stack_cell.file_name.unique())
    cut_off_nr = get_cutoffs_nr(file_names)
    try:
        maxs = list(map(float, cut_off_nr))
    except (TypeError, ValueError):
        print('error1')
        embed()
    file_name = file_names[np.argmax(maxs)]
    try:
        stack_file = stack_cell[stack_cell['file_name'] == file_name]
    except Exception:
        print('stack file something')
        embed()
    amps = restrict_punits(cell, [np.min(stack_file.amp.unique())])
    for amp in amps:
        stack_amps = stack_file[stack_file['amp'] == amp]
        lengths = stack_amps.stimulus_length.unique()
        try:
            length_max = [np.max(lengths)]
        except (TypeError, ValueError):
            print('length thing')
            embed()
        for length in length_max:
            trial_nr = np.max(stack_amps.trial_nr.unique())
            stack_final = stack_amps[
                (stack_amps['stimulus_length'] == length) & (stack_amps.trial_nr == trial_nr)]
            stack_spikes = load_data_susept(load_name + '_' + cell + '.pkl', load_name, load_version='csv',
                                            load_type='spikes', add=add, trial_nr=trial_nr,
                                            stimulus_length=length,
                                            amp=amp, file_name=file_name)
            snippets = stack_final['snippets'].unique()[0]
            eod_fr = stack_final.eod_fr.unique()[0]
            fr = stack_final.fr.unique()[0]
            cv_stim = stack_final.cv_stim.unique()[0]
            fr_stim = stack_final.fr_stim.unique()[0]
            if title:
                # Separate name: overwriting `add` broke the data key for later amplitudes.
                title_add = '\n $N = % s$' % snippets if n_print else ''
                if fr_print:
                    title_add += '\n fr$_{S}$=' + str(int(np.round(fr_stim))) + 'Hz' + ' cv$_{S}$=' + str(
                        np.round(cv_stim, 2))
                ax.text(xpos, ypos, 'Recorded P-unit' + title_add, ha='right',
                        transform=ax.transAxes)
            mat, new_keys = get_mat_susept(stack_final)
            mat, add_nonlin_title, resize_val = rescale_colorbar_and_values(mat)
            im, plot = plt_mat_susept(ax, mat)
            if f == len(cells) - 1:
                ax.set_xticks_delta(100)
                if xlabel:
                    set_xlabel_arrow(ax, xpos=xpos)
            else:
                remove_xticks(ax)
            ax.set_xlim(0, 300)
            ax.set_ylim(0, 300)
            ax.set_aspect('equal')
            plt_triangle(ax, fr, fr_stim, new_keys[-1], eod_fr, lines=False, eod_metrice=eod_metrice, nr=nr)
            set_clim_same([im], mats=[mat], lim_type='up', nr_clim='perc', clims='', percnr=perc_model_full())
            if plot:
                # Do not rebind `width`: later amplitudes need the requested width.
                cbar, _, _, _, _ = colorbar_outside(ax, im, fig, add=0, shrink=0.6, width=width)
                if cbar_label:
                    cbar.set_label(nonlin_title(' [' + add_nonlin_title), rotation=90, labelpad=lp)
    return eod_fr, stack_spikes
def set_xlabel_arrow(ax, xpos=1.05, ypos=-0.35, color='black', arrow = False):
    """Write the F1 frequency x-label at (*xpos*, *ypos*) and optionally arrow the bottom spine.

    The label text comes from ``F1_xlabel()``; placement is delegated to
    ``set_xlabel_arrow_core``. With *arrow* true, ``ax.arrow_spines('b')`` is
    called as well.
    """
    set_xlabel_arrow_core(ax, F1_xlabel(), xpos, ypos, color=color)
    if not arrow:
        return
    ax.arrow_spines('b')
def exclude_file_name_short(file_names):
    """Return, in their original order, the file names that do not contain ``'short'``."""
    return [name for name in file_names if 'short' not in name]
def plt_mat_susept(ax, mat):
    """Draw a susceptibility matrix as an image on *ax*.

    The image settings are:

    * ``origin='lower'``;
    * an extent spanning min..max of ``mat.columns`` (x) and ``mat.index`` (y);
    * the 'viridis' colormap;
    * a colour range from the 5th to the 95th NaN-ignoring percentile of *mat*.

    Returns
    -------
    im, plot
        The image artist and the constant ``True``, which callers use to decide
        whether to add a colorbar.
    """
    x_lo = float(np.min(mat.columns))
    x_hi = float(np.max(mat.columns))
    y_lo = float(np.min(mat.index))
    y_hi = float(np.max(mat.index))
    lower = np.nanpercentile(mat, 5)
    upper = np.nanpercentile(mat, 95)
    image = ax.imshow(mat, origin='lower', extent=[x_lo, x_hi, y_lo, y_hi],
                      vmin=lower, vmax=upper, cmap='viridis')
    return image, True
def get_mat_susept(stack_final):
    """Return the normalised susceptibility matrix of *stack_final* and its frequency keys.

    The numeric matrix part is extracted with ``convert_csv_str_to_float`` and
    normalised by ``RAM_norm_data``. That call uses the first ``'isf'`` entry,
    the first unique ``'snippets'`` value, and the frame itself.
    """
    keys, numeric_part = convert_csv_str_to_float(stack_final)
    first_isf = stack_final['isf'].iloc[0]
    n_snippets = stack_final['snippets'].unique()[0]
    normed = RAM_norm_data(first_isf, numeric_part, n_snippets, stack_here=stack_final)
    return normed, keys
def get_file_names_exclude(cell_type='p-unit'):
    """Return the list of stimulus file names to ignore for *cell_type*.

    A *cell_type* containing the substring ``'p-unit'`` (case-sensitive) gets
    the P-unit list. Every other type gets a list that additionally excludes
    ``'blwn125Hz10s0.3'``; that list contains ``'gwn25Hz10s0.3'`` twice.
    A fresh list is returned on every call.
    """
    if 'p-unit' not in cell_type:
        return ['blwn125Hz10s0.3', 'gwn50Hz10s0.3',
                'InputArr_350to400hz_30', 'InputArr_250to300hz_30',
                'InputArr_150to200hz_30', 'InputArr_50to100hz_30',
                'InputArr_50hz_30',
                'FileStimulus-file-gaussian50.0', 'FileStimulus-file-gaussian25.0',
                'gwn25Hz10s0.3', 'gwn50Hz10.3', 'gwn50Hz10s0.3short',
                'gwn50Hz50s0.3', 'gwn25Hz10s0.3']
    return ['InputArr_350to400hz_30', 'InputArr_250to300hz_30',
            'InputArr_150to200hz_30', 'InputArr_50to100hz_30',
            'gwn25Hz10s0.3', 'InputArr_50hz_30',
            'FileStimulus-file-gaussian50.0', 'gwn50Hz10.3',
            'gwn50Hz10s0.3short', 'gwn50Hz50s0.3',
            'FileStimulus-file-gaussian25.0', 'gwn50Hz10s0.3']
def get_cutoffs_nr(file_names):
    """Extract a cut-off token (as a string) from each stimulus file name.

    Each name is handled by the first matching rule:

    * contains 'hz' (any case): handled by ``get_cut_off_for_wn``, which
      receives the running list and returns the list to continue with;
    * contains 'gaussian' (case-sensitive): contributes the text after
      'gaussian';
    * otherwise: contributes its last five characters.

    Callers convert the tokens to float.
    """
    cutoffs = []
    for name in file_names:
        if 'hz' in name.lower():
            cutoffs = get_cut_off_for_wn(cutoffs, name)
            continue
        if 'gaussian' in name:
            token = name.split('gaussian')[1]
        else:
            token = name[-5:]
        cutoffs.append(token)
    return cutoffs
def find_eod(frame_cell, EOD='EOD', sp=0):
    """Load one stored EOD trace of a cell and build its time axis.

    Parameters
    ----------
    frame_cell : pandas.DataFrame
        Frame whose first row provides the stored EOD (column *EOD*), the EOD
        frequency ('EODf'), the sampling rate ('sampling') and the
        downsampling factor ('downsample').
    EOD : str
        Name of the column holding the stored EOD traces.
    sp : int
        Index of the trace to return among those decoded by ``load_spikes``.

    Returns
    -------
    eod, sampling_rate, ds, time_eod
        The selected trace, the sampling rate, the downsampling factor as int,
        and a time axis with exactly ``len(eod)`` samples spaced
        ``1 / sampling_rate`` apart.

    Raises
    ------
    KeyError
        If *frame_cell* has no column named *EOD*.
    """
    if EOD not in frame_cell:
        # Previously this fell through to an UnboundLocalError on return.
        raise KeyError('column ' + repr(EOD) + ' not found in frame_cell')
    eods, _, _, _ = load_spikes(frame_cell[EOD].iloc[0], frame_cell['EODf'].iloc[0])
    try:
        eod = eods[sp]
    except (IndexError, KeyError):
        print('eod sp thing')
        embed()
    sampling_rate = frame_cell.sampling.iloc[0]
    ds = int(frame_cell.downsample.iloc[0])
    # Integer sample indices divided by the rate always give len(eod) points.
    # np.arange with a float step can be off by one and used to truncate the EOD.
    time_eod = np.arange(len(eod)) / sampling_rate
    return eod, sampling_rate, ds, time_eod
def plot_lin_nonlin(aa, add, amp, amps_defined, axds, axos, c, cells_plot, file_name, grid_s1, ims,
                    load_name, stack_file, xlim=[], test_clim=False, power_type=False,
                    permuted=False, peaks_extra=False, zorder=1, alpha=1, extra_input=False, fr=None,
                    title_square='', fr_diag=None, nr=1, line_length=1 / 4, text_scalebar=False,
                    xpos_xlabel=-0.2, add_nonlin_title=None, amp_give=True, color='grey',
                    axo2=None, axd2=None, axi=None, eod_metrice=True, ax_square=None,
                    transfer=True, base_extra=False, color_same=True, snippets = 20, iterate_var=[0, 1], normval=1):
    """Plot the linear and nonlinear panels of one stimulus amplitude.

    Steps:

    1. Preprocess *stack_file* for *amp* via ``stack_preprocessing``.
    2. Load the output ('osf') and input ('isf') spectra with
       ``load_data_susept``.
    3. De-duplicate the frame with ``reduce_dubplicates``, then normalise and
       rescale the susceptibility matrix.
    4. Draw one of two variants:

       * *power_type* truthy: the PSDs via ``plt_psds_all``;
       * otherwise (default): the matrix diagonal in dB on *axd2* via
         ``plt_diagonal``. Optionally add a permuted-trial band
         (*permuted*) and the transfer function on *axo2* (*transfer*).

    5. If *extra_input* is set, also plot the input power on *axi*.
    6. Draw the second-order susceptibility square with ``plt_square_here``.
       A colorbar is drawn when *aa* is the last index of *iterate_var* or
       *test_clim* is set.

    Returns
    -------
    tuple
        ``(diagonals_prj_l, axi, eod_fr, fr, stack_final1, axds, axos,
        ax_square, axo2, axd2, mat, add_nonlin_title)``.
        ``mat`` is the matrix returned by ``plt_square_here``.

    NOTE(review): ``prob``, ``xmin``, ``xmax`` and ``diagonals_prj_l`` are only
    assigned in the non-``power_type`` branch. With ``power_type`` truthy,
    ``if prob:`` at the end raises NameError. If ``plt_diagonal`` fails,
    ``xmin``/``xmax`` are unbound when *transfer* is true. The ``xlim`` and
    ``iterate_var`` defaults are mutable lists (not mutated here).
    """
    if not fr_diag:
        fr_diag = fr
    eod_fr, length, stack_final, stack_final1, trial_nr = stack_preprocessing(amp, stack_file, snippets = snippets)
    stack_osf = load_data_susept(load_name + '.pkl', load_name, load_version='csv',
                                 load_type='osf', trial_nr=trial_nr,
                                 stimulus_length=length, add=add, amp=amp, file_name=file_name)
    stack_isf = load_data_susept(load_name + '.pkl', load_name, load_version='csv',
                                 load_type='isf', trial_nr=trial_nr,
                                 stimulus_length=length, add=add, amp=amp, file_name=file_name)
    test_limits = False
    # Remove duplicated rows here (modifies stack_final in place).
    add_nonlin_title, stack_plot = reduce_dubplicates(add_nonlin_title, stack_final)
    mat = RAM_norm_data(stack_final['isf'].iloc[0], stack_plot,
                        stack_final['snippets'].unique()[0], stack_here=stack_final)  #
    mat, add_nonlin_title, resize_val = rescale_colorbar_and_values(mat, add_nonlin_title=add_nonlin_title)
    axis_d = axis_projection(mat, axis='')
    if power_type:
        ##################################
        # This is the variant where the PSD should be plotted.
        axd, _, axo2 = plt_psds_all(axd2, axo2, mat,
                                    stack_final, stack_osf, test_limits, xlim,
                                    color=color, alpha=alpha,
                                    db='db', fr=fr, power_type=power_type, zorder=zorder, eod_fr=eod_fr,
                                    peaks_extra=peaks_extra)
    else:
        ###################################
        # This is the DEFAULT variant.
        # plot the diagonal
        db_diag = 'db'  #
        try:
            xmax, xmin, diagonals_prj_l = plt_diagonal(axd2, color, db_diag, fr_diag, mat, alpha, eod_fr, peaks_extra,
                                                       xlim=xlim,
                                                       zorder=zorder, normval=normval, color_same=color_same)
            prob = False
        except:
            # Best effort: remember the failure; the debugger is opened at the end.
            print('diagonal of mat not working')
            diagonals_prj_l = []
            prob = True
        if permuted:
            add_nonlin_title = plt_permuted_diagonal(add_nonlin_title, axd2, axis_d, db_diag, stack_final)
        # ###################################
        # plt transferfunction
        axd = axd2
        if transfer:
            plt_transferfunction(alpha, axo2, color, stack_final, label=title_square, normval=normval,
                                 zorder=zorder)
            xmax_tf = 400
            # With normval != 1 the x-limit is 400 divided by eod_fr;
            # presumably the axis is in units of EODf (confirm).
            if normval != 1:
                axo2.set_xlim(xmin, xmax_tf / eod_fr)
            else:
                axo2.set_xlim(xmin, xmax / 2)
        axos.append(axo2)  # np.max(mat.columns)
        axds.append(axd)  # np.max(mat.columns)
    ########################################
    # plot input if we need it
    # NOT DEFAULT
    if extra_input:
        axis_d = axis_projection(mat, axis='')
        xmax, xmin = get_xlim_psd(axis_d, xlim)
        plt_power_trace(alpha, axi, color, 'db', stack_final, stack_isf, test_limits, xmax,
                        eod_fr=eod_fr)
        axi.set_xlim(xmin, xmax)
    #############################
    # plot second-order susceptibility
    if not ax_square:
        ax_square = plt.subplot(grid_s1[:, 2 + aa])
    # Only the last panel of the row (or every panel when testing clims) gets a colorbar.
    if (aa == len(iterate_var) - 1) | test_clim:
        cbar_true = True
    else:
        cbar_true = False
    # embed()
    mat, test_limits, im, add_nonlin_title = plt_square_here(aa, amp, amps_defined, ax_square, c, cells_plot, ims,
                                                             stack_final1, [], perc=False, cbar_true=cbar_true, xpos=0,
                                                             ypos=1.05,
                                                             color=color, fr=fr, base_extra=base_extra,
                                                             eod_metrice=eod_metrice, nr=nr, amp_give=amp_give,
                                                             title_square=title_square, line_length=line_length,
                                                             ha='left', xpos_xlabel=xpos_xlabel, alpha=alpha,
                                                             add_nonlin_title=add_nonlin_title)
    ims.append(im)
    if text_scalebar:
        if (aa == len(iterate_var) - 1) | test_clim:
            fig = plt.gcf()
            _, _, _, _, _ = colorbar_outside(ax_square, im, fig, add=5, width=0.01)
            ax_square.text(1.45, 0.25, nonlin_title(' [' + add_nonlin_title), ha='center', rotation=90,
                           transform=ax_square.transAxes)
    if prob:
        print('prob something')
        embed()
    return diagonals_prj_l, axi, eod_fr, fr, stack_final1, axds, axos, ax_square, axo2, axd2, mat, add_nonlin_title
def reduce_dubplicates(add_nonlin_title, stack_final):
    """Drop duplicated matrix rows of *stack_final* and return its numeric matrix part.

    Rows are compared on the keys returned by ``convert_csv_str_to_float``.
    Duplicates are removed IN PLACE, so the caller's frame is modified; the
    matrix is then rebuilt from the cleaned frame. *add_nonlin_title* is
    returned unchanged.
    """
    keys, numeric_part = convert_csv_str_to_float(stack_final)
    if stack_final.duplicated(subset=keys).any():
        stack_final.drop_duplicates(subset=keys, inplace=True)
        keys, numeric_part = convert_csv_str_to_float(stack_final)
    return add_nonlin_title, numeric_part
def plt_permuted_diagonal(add_nonlin_title, axd2, axis_d, db_diag, stack_final, n_perm=300, n_sample=20):
    """Plot a percentile band of diagonals from trial-shuffled susceptibility matrices.

    ``get_fft_matrices`` supplies one matrix per mismatched trial pair
    (output trial != input trial). For each of *n_perm* draws:

    1. Pick *n_sample* of these matrices at random without replacement.
    2. Sum them, normalise with ``norm_suscept_whole`` and rescale.
    3. Take the diagonal projection (converted to dB when *db_diag* == 'db').

    The 95th and 5th percentiles across draws are plotted on *axd2* against
    *axis_d* in dark grey.

    Parameters
    ----------
    add_nonlin_title : str
        Colorbar-title suffix, updated by the rescaling.
    axd2 : matplotlib axes
        Axes receiving the two percentile lines.
    axis_d : array-like
        x-values of the diagonal projection.
    db_diag : str
        'db' to convert the diagonals with ``10 * log10``.
    stack_final : pandas.DataFrame
        Per-trial spectra of one amplitude / stimulus length.
    n_perm : int
        Number of random draws (default 300, as before).
    n_sample : int
        Matrices summed per draw (default 20, as before).

    Returns
    -------
    str
        The updated *add_nonlin_title*.

    Raises
    ------
    ValueError
        From ``random.sample`` if fewer than *n_sample* shuffled matrices exist.
    """
    add_nonlin_title, isfs_all, isfs_correct, mats_all, _ = get_fft_matrices(stack_final, add_nonlin_title)
    mats_all = np.array(mats_all)
    isfs_all = np.array(isfs_all)
    diags_permuted = []
    for _perm in range(n_perm):
        # Draw from ALL shuffled matrices; the former range(1, len(...)) never picked index 0.
        picked = sample(range(len(mats_all)), n_sample)
        summed = np.sum(mats_all[picked], axis=0)
        # NOTE(review): normalises with len(isfs_correct) as before; confirm this
        # count is intended for a sum of n_sample matrices.
        normed = norm_suscept_whole(abs, isfs_all[picked], stack_final, summed, len(isfs_correct))
        normed, add_nonlin_title, _resize = rescale_colorbar_and_values(normed, add_nonlin_title=add_nonlin_title)
        _diag, diag_proj = get_mat_diagonals(np.array(normed))
        if db_diag == 'db':
            diag_proj = 10 * np.log10(diag_proj)
        diags_permuted.append(diag_proj)
    diags_perm = np.transpose(diags_permuted)
    axd2.plot(axis_d, np.percentile(diags_perm, 95, axis=1), color='darkgrey')
    axd2.plot(axis_d, np.percentile(diags_perm, 5, axis=1), color='darkgrey')
    return add_nonlin_title
def get_fft_matrices2(stack_final, add_nonlin_title=''):
    """Build the trial-matched second-order matrix from the per-trial spectra.

    For every trial t, ``fft_matrix`` is evaluated on the output spectrum
    ``osfs[t]`` and the input spectrum ``isfs[t]`` (both from ``get_isfs``).
    The per-trial matrices are summed, normalised by ``norm_suscept_whole``
    and rescaled by ``rescale_colorbar_and_values``. A progress line is
    printed for each trial.

    Returns
    -------
    add_nonlin_title, isfs_correct, mats_all_correct2
        Updated title suffix, the list of input spectra used, and the final
        matrix.
    """
    isfs = get_isfs(stack_final, isf_name='isf')
    osfs = get_isfs(stack_final, isf_name='osf')
    f_range = np.arange(len(stack_final))
    per_trial_mats = []
    used_isfs = []
    for idx in range(len(osfs)):
        print('t' + str(idx))
        _, _, _, trial_mat = fft_matrix(osfs[idx], f_range, isfs[idx], norm='')
        per_trial_mats.append(trial_mat)
        used_isfs.append(isfs[idx])
    summed = np.sum(np.array(per_trial_mats), axis=0)
    normed = norm_suscept_whole(abs, used_isfs, stack_final, summed, len(used_isfs))
    normed, add_nonlin_title, _ = rescale_colorbar_and_values(normed, add_nonlin_title=add_nonlin_title)
    return add_nonlin_title, used_isfs, normed
def get_fft_matrices(stack_final, add_nonlin_title=''):
    """Compute second-order matrices for all (output trial, input trial) pairs.

    ``fft_matrix`` is evaluated for every pair (out_idx, in_idx) of trials
    using ``osfs[out_idx]`` and ``isfs[in_idx]`` from ``get_isfs``. Pairs with
    equal indices are "matched"; all others are "shuffled". The matched
    matrices are summed, normalised by ``norm_suscept_whole`` and rescaled.
    A progress line is printed for every pair.

    Returns
    -------
    add_nonlin_title, isfs_all, isfs_correct, mats_all, mats_all_correct2
        Updated title suffix; input spectra of the shuffled pairs; input
        spectra of the matched pairs; matrices of the shuffled pairs (list);
        final matched matrix.
    """
    isfs = get_isfs(stack_final, isf_name='isf')
    osfs = get_isfs(stack_final, isf_name='osf')
    f_range = np.arange(len(stack_final))
    shuffled_mats, shuffled_isfs = [], []
    matched_mats, matched_isfs = [], []
    for out_idx in range(len(osfs)):
        for in_idx in range(len(osfs)):
            print('t' + str(out_idx) + ' tt' + str(in_idx))
            _, _, _, pair_mat = fft_matrix(osfs[out_idx], f_range, isfs[in_idx], norm='')
            if out_idx == in_idx:
                mats_dst, isfs_dst = matched_mats, matched_isfs
            else:
                mats_dst, isfs_dst = shuffled_mats, shuffled_isfs
            mats_dst.append(pair_mat)
            isfs_dst.append(isfs[in_idx])
    summed = np.sum(np.array(matched_mats), axis=0)
    normed = norm_suscept_whole(abs, matched_isfs, stack_final, summed, len(matched_isfs))
    normed, add_nonlin_title, _ = rescale_colorbar_and_values(normed, add_nonlin_title=add_nonlin_title)
    return add_nonlin_title, shuffled_isfs, matched_isfs, shuffled_mats, normed
def plt_psds_in_one_squares(aa, add, amp, amps_defined, axds, axes, axis, axos, c, cells_plot, colors_b, file_name,
                            files,
                            grid_s1, grid_s2, ims, load_name, stack_file, wss, xlim,
                            axo2=None, axd2=None, iterate_var=[0, 1]):
    """Plot PSD traces and the susceptibility square of one amplitude into a shared layout.

    Behaviour for the first amplitude (``aa == 0``):

    * create the output (``grid_s2[0, 0]``) and projection (``grid_s2[1, 0]``)
      axes;
    * plot the grey PSDs of the normalised matrix there via ``plt_psds_all``;
    * call ``plt_square_with_psds`` with ``grid_s1`` and ``axd``/``axo`` set to
      None.

    For later amplitudes, ``grid_s2`` and the existing axes are passed instead.
    Y-ticks are replaced by 10 dB scalebars and labels (labels only on
    ``aa == 0``). A colorbar is added when *aa* is the last index of
    *iterate_var*.

    Returns
    -------
    tuple
        ``(axi, eod_fr, fr, stack_final1, stack_spikes, axds, axos, ax_square,
        axo2, axd2)``.

    NOTE(review): *wss* is not used in this function. The ``iterate_var``
    default is a mutable list (not mutated here).
    """
    if aa == 0:
        try:
            axd2 = plt.subplot(grid_s2[1, 0])  # plt.subplot(grid_s[0])
            axo2 = plt.subplot(grid_s2[0, 0])
        except:
            print('grid thing3')
            embed()
        eod_fr, length, stack_final, stack_final1, trial_nr = stack_preprocessing(amp, stack_file)
        stack_osf = load_data_susept(load_name + '.pkl', load_name, load_version='csv',
                                     load_type='osf', trial_nr=trial_nr,
                                     stimulus_length=length, add=add, amp=amp, file_name=file_name)
        test_limits = False
        new_keys, stack_plot = convert_csv_str_to_float(stack_final)
        mat = RAM_norm_data(stack_final['isf'].iloc[0], stack_plot,
                            stack_final['snippets'].unique()[0], stack_here=stack_final)  #
        axd, axi, axo2 = plt_psds_all(axd2, axo2, mat,
                                      stack_final, stack_osf, test_limits, xlim,
                                      color='grey',
                                      db='db')
        grid_s = grid_s1
        # None here: the square plot receives no existing PSD axes for the first amplitude.
        axd = None
        axo = None
    else:
        grid_s = grid_s2
        axd = axd2
        axo = axo2
    ax_square, axi, eod_fr, fr, stack_final1, stack_spikes, im, axd, axo = plt_square_with_psds(aa, amp,
                                                                                                amps_defined,
                                                                                                axes, axis,
                                                                                                c, cells_plot,
                                                                                                files,
                                                                                                grid_s,
                                                                                                ims,
                                                                                                load_name,
                                                                                                stack_file,
                                                                                                xlim,
                                                                                                cbar_true=False,
                                                                                                axd=axd,
                                                                                                axo=axo,
                                                                                                color=
                                                                                                colors_b[aa],
                                                                                                add=add,
                                                                                                file_name=file_name)
    axos.append(axo)  # np.max(mat.columns)
    axds.append(axd)  # np.max(mat.columns)
    # Hard-coded switch: False replaces y-ticks by scalebars and text labels.
    test_limits = False
    if test_limits:
        axo.set_ylabel('Output')
        axd.set_ylabel('Projection')
    else:
        remove_yticks(axo)
        remove_yticks(axd)
        if aa == 0:
            axo.text(-0.45, 0, 'Output', rotation=90, transform=axo.transAxes)
            axd.text(-0.45, 0, 'Projection', rotation=90, transform=axd.transAxes)
            axd.yscalebar(-0.1, 0.5, 10, 'dB', va='center', ha='left')
            axo.yscalebar(-0.1, 0.5, 10, 'dB', va='center', ha='left')
        axd.show_spines('b')
        axo.show_spines('b')
    ims.append(im)
    if aa == len(iterate_var) - 1:
        fig = plt.gcf()
        cbar, left, bottom, width, height = colorbar_outside(ax_square, im, fig, add=5, width=0.01)
        cbar.set_label(nonlin_title(), rotation=90, labelpad=10)
    return axi, eod_fr, fr, stack_final1, stack_spikes, axds, axos, ax_square, axo2, axd2
def nix_load(cell, stack_final1):
    """Load spike times and the local EOD of the gwn noise segments of one cell from its NIX file.

    The file ``load_folder_name('data') + 'cells/' + cell + '/' + cell + '.nix'``
    is opened read-only. The multi-tag named by the first unique
    ``stack_final1['names_mt_gwn']`` selects the noise stimulus. Segments are
    restricted to the contrast ``stack_final1['c_orig']`` and matched with
    ``find_indices_to_match_contrats``.

    Parameters
    ----------
    cell : str
        Cell name; used as both folder name and file name of the recording.
    stack_final1 : pandas.DataFrame
        Frame whose first unique ``'names_mt_gwn'`` and ``'c_orig'`` values
        select the multi-tag and the contrast.

    Returns
    -------
    eod_mt, sampling, spikes_loaded
        The 'LocalEOD-1' trace and sampling value of the LAST loaded segment,
        and a list with the spike times of every loaded segment, each
        multiplied by 1000 (presumably s -> ms; confirm against
        ``link_arrays_spikes``).

    NOTE(review): the NIX file ``f`` is never closed. ``eod_mt`` and
    ``sampling`` are only assigned inside the ``rlx_problem`` branch, when the
    block has a data array containing 'eod' and at least one matching index.
    Otherwise the return statement raises UnboundLocalError.
    """
    data_dir = 'cells/'
    data_name = cell
    name_core = load_folder_name('data') + data_dir + data_name
    nix_name = name_core + '/' + data_name + '.nix'  # '/'
    f = nix.File.open(nix_name, nix.FileMode.ReadOnly)
    b = f.blocks[0]
    try:
        names_mt_gwn = stack_final1['names_mt_gwn'].unique()[0]
    except:
        print('names mt')
        embed()
    mt = b.multi_tags[names_mt_gwn]
    features, id, data_between_2017_2018, mt_ids = find_feature_gwn_id(mt)
    dataset, rlx_problem = load_rlxnix(nix_name)
    # We do this only via rlx here, because that is the only way to get the contrast.
    spikes_loaded = []
    if rlx_problem:
        file_name, file_name_save, cut_off, file, sd = find_file_names(nix_name, mt,
                                                                       names_mt_gwn)
        file_extra, idx_c, base_properties, id_names = get_contrasts_over_rlx_calc_RAM(dataset)
        dataset.close()
        # contrasts_sort_idx = np.argsort(base_properties)
        try:
            base_properties = base_properties.sort_values(by='c', ascending=False)
        except:
            print('contrast problem sorting')
            embed()
        # Filter by the file name again here!
        if data_between_2017_2018 != 'all':
            file_name_sorted = base_properties[base_properties.file_name == file_name]
        else:
            file_name_sorted = base_properties
        if len(file_name_sorted) < 1:
            print('file_name problem')
            embed()
        # Descending sort reversed again, i.e. ascending by 'start'.
        file_name_sorted = file_name_sorted.sort_values(by='start', ascending=False)[::-1]
        # Filtering by the correct file name should already happen at this level!
        file_name_sorted = file_name_sorted[file_name_sorted['c_orig'] == stack_final1['c_orig'].unique()[0]]
        grouped = file_name_sorted.groupby('c')
        # There is apparently one cell, the first one ('2010-06-15-af'), where the stimulus is not called
        # input arr but gwn 300. What happened there is unclear, but all other cells look fine; it seems to be the only one.
        data_array_names = get_data_array_names(b)  # ,find_indices_to_match_contrats,get_data_array_names
        if 'eod' in ''.join(data_array_names).lower():
            for g, group in enumerate(grouped):
                # First sort everything by contrast here.
                sd, start, end, rep, cut_off, c_len, c_unit, c_orig, c_len, files_load, cc, id_g, amplsel = open_group_gwn(
                group,
                file_name,
                cut_off,
                sd,
                data_between_2017_2018)
                indices, ends_mt = find_indices_to_match_contrats(grouped, group, mt, id_g, mt_ids,
                                                                  data_between_2017_2018)
                indices = list(map(int, indices))
                max_f = cut_off
                if max_f == 0:
                    print('max f = 0')
                    embed()
                for mm, m in enumerate(indices):
                    first, minus, second, stimulus_length = find_first_second(b, names_mt_gwn, m, mt,
                                                                              False,
                                                                              mm=mm, ends_mt=ends_mt)
                    spikes_mt = link_arrays_spikes(b, first,
                                                   second, minus)  #
                    spikes_loaded.append(spikes_mt * 1000)
                    # Overwritten on every segment: only the last EOD is returned.
                    eod_mt, sampling = link_arrays_eod(b, first,
                                                       second,
                                                       array_name='LocalEOD-1')
                # TODO: also load the stimpresaved data here
    else:
        print('rlx thing')
    return eod_mt, sampling, spikes_loaded
def burst_data():
    """Build and save the didactic burst figure.

    For each cell returned by ``p_units_to_show(type_here='burst_didactic')``
    the figure gets one grid row:

    * upper part: one EOD snippet and an eventplot of the baseline spikes
      cut by EOD period (left), and the baseline ISI distribution (right);
    * lower part, one row per stimulus amplitude of the first non-excluded
      file: the stimulus, the spike raster and the second-order
      susceptibility square.

    Panels are tagged and the figure is written with ``save_visualization``.

    NOTE(review): ``cell == cells_amp`` compares a str with a list and is
    always False, so that debugging branch never runs. ``spikes_base`` is only
    assigned inside the amplitude loop but is used in the ISI section.
    """
    plot_style()
    cells = p_units_to_show(type_here='burst_didactic')
    save_names = [version_final()]
    # amps_desired, cell_type_type, cells_plot, frame, cell_types = load_isis(save_names, amps_desired = amp_desired, cell_class = cell_class)
    cell_type_type = 'cell_type_reclassified'
    # frame = load_cv_base_frame(cells, cell_type_type=cell_type_type, redo=True)
    default_settings(column=2, width=12, length=8.5)  # ts=10, fs=10, ls=10,
    frame, frame_spikes = load_cv_vals_susept(cells, EOD_type='synch',
                                              names_keep=['spikes', 'gwn', 'fs', 'EODf', 'cv', 'fr', 'width_75', 'vs',
                                                          'cv_burst_corr_individual',
                                                          'fr_burst_corr_individual',
                                                          'width_75_burst_corr_individual',
                                                          'vs_burst_corr_individual', 'cell_type_reclassified',
                                                          'cell'],
                                              path_spikes='/calc_base_data-base_frame_EOD1__overview.pkl',
                                              frame_general=False)
    frame = unify_cell_names(frame, cell_type=cell_type_type)
    frame_load = frame  # [frame['cell'].isin(cells_exclude)]
    colors = colors_overview()
    tags_cell = []
    grid = gridspec.GridSpec(len(cells), 1, wspace=0.1, hspace=0.21, top=0.97, left=0.105, bottom=0.085, right=0.9)
    for c, cell in enumerate(cells):
        print(cell)
        cell_type, eod_fr, fr, frs_calc, isi, spikes, spikes_all = get_base_params(cell, cell_type_type, frame)
        ims = []
        tags = []
        add_here = '_cell' + cell  # str(c)
        for s, save_name in enumerate(save_names):
            load_name = load_folder_name('calc_RAM') + '/' + save_name + '_' + cell
            stack = load_data_susept(load_name + '.pkl', load_name, add=add_here, load_version='csv')
            try:
                grid_base_stim = gridspec.GridSpecFromSubplotSpec(2, 1, grid[c], height_ratios=[3, 5],
                                                                  hspace=0.3)
            except:
                print('cell thing3')
                embed()
            # Upper part of the row: baseline panels (EOD/spikes left, ISI right).
            grid_base = gridspec.GridSpecFromSubplotSpec(2, 2, grid_base_stim[0],
                                                         hspace=0.3)
            if len(stack) > 0:
                files, stack = exclude_cut_filenames(cell_type, stack, fexclude=True)
                file_name = files[0]
                stack_file = stack[stack['file_name'] == files[0]]
                amps = stack_file['amp'].unique()
                amps_defined = amps
                # Lower part of the row: one sub-row per amplitude.
                grid_stim = gridspec.GridSpecFromSubplotSpec(len(amps), 1, grid_base_stim[1], hspace=0.3)
                trues = []
                for amp in amps_defined:
                    if amp in amps:
                        trues.append(True)
                cells_amp = ['2017-10-25-am-invivo-1', '2010-11-26-an-invivo-1']
                # NOTE(review): str == list is always False; `cell in cells_amp` was probably meant.
                if cell == cells_amp:
                    print('cell thing')
                    embed()
                ims = []
                xlim_e = [0, 200]
                for aa, amp in enumerate(amps_defined):
                    add_save = '_cell' + str(cell) + '_amp_' + str(amp)
                    # Alpha of the spike raster decreases by this step per amplitude.
                    alpha_min = (1 - 0.2) / len(np.unique(stack_file['amp']))  # 25
                    if amp in np.array(stack_file['amp']):
                        grid_stim_aa = gridspec.GridSpecFromSubplotSpec(2, 2, grid_stim[aa], height_ratios=[3, 5],
                                                                        hspace=0.3, width_ratios=[5, 2])
                        ax_square = plt.subplot(grid_stim_aa[:, -1])
                        eod_fr, length, stack_final, stack_final1, trial_nr = stack_preprocessing(amp, stack_file)
                        _, _, _, _ = plt_square_here(aa, amp, amps_defined, ax_square, c,
                                                     cells, ims,
                                                     stack_final1, [], amp_give=False,
                                                     cbar_true=False)
                        tags.append(ax_square)
                        ######################################################
                        spikes_base, isi, frs_calc, cont_spikes = load_spikes(spikes, eod_fr)
                        axe = plt.subplot(grid_stim_aa[0, 0])
                        plt_stimulus(eod_fr, axe, stack_final1, xlim_e, file_name=files[0])
                        tags.insert(1, axe)
                        ################################
                        # spikes
                        ax_spikes = plt.subplot(grid_stim_aa[1, 0])
                        eod_fr, length, stack_final, stack_final1, trial_nr = stack_preprocessing(amp, stack_file)
                        # todo: load more trials here
                        stack_spikes = load_data_susept(load_name + '.pkl', load_name, add=add_save, load_version='csv',
                                                        load_type='spikes',
                                                        trial_nr=trial_nr, stimulus_length=length, amp=amp,
                                                        file_name=file_name)
                        test = False
                        if test:
                            pass
                        # Only the amplitude with index 2 gets a scalebar.
                        if aa == 2:
                            scale = True
                        else:
                            scale = False
                        # embed()
                        # todo: I still want to include the histogram
                        plt_spikes(c, cells, colors[str(cell_type)], ax_spikes, stack_final1, stack_spikes,
                                   alpha=1 - alpha_min * aa, scale=scale)
                        ax_spikes.text(1, 0.5, str(amp) + '$\%$', transform=ax_spikes.transAxes, )
                set_clim_same(ims, clims='all', same='same')
                ################################
                # isi
                if len(isi) > 0:
                    ax_isi = plt.subplot(grid_base[:, 1])
                    plt_susept_isi_base(colors[str(cell_type)], ax_isi, isi)
                    tags.insert(0, ax_isi)
                    cell_type_type = 'cell_type_reclassified'
                    frame_spikes_cell = frame_spikes[(frame_spikes['cell'] == cell)]
                    eod, sampling_rate, ds, time_eod = find_eod(frame_spikes_cell, EOD='EOD')
                    eod_period, zero_crossings, zero_crossings_eod2 = find_mean_period(
                        eod, sampling_rate)
                    nrs = 6
                    # Cut the baseline spikes into EOD-period segments for the eventplot below.
                    spikes_cut, eods_cut, times_cut = cut_spikes_to_certain_lenght_in_period(time_eod, ax_isi, eod, False, nrs,
                                                                                             spikes_base[0], xlim_e,
                                                                                             zero_crossings)
                    axe = plt.subplot(grid_base[0, 0])
                    for nr in range(1):
                        axe.plot(times_cut[nr], eods_cut[nr])
                    ax_isi = plt.subplot(grid_base[1, 0])
                    ax_isi.eventplot(spikes_cut)
                ###############################
                # stimulus
        tags_cell.append(tags)
    fig = plt.gcf()
    fig.tag(tags_cell, xoffs=-4.7, yoffs=1.9)  # -1.5 these offsets are not intuitive
    save_visualization()
def plt_cellbody_eigen(grid1, frame, amps_desired, save_names, cells_plot, cell_type_type, ax3=[], xlim=[],
                       titles=['Baseline \n Susceptibility', 'Half EODf \n Susceptibility'],
                       peaks_extra=[False, False, False], base_extra=False):
    """Draw per-cell susceptibility overview panels for a fixed amplitude of 20.

    For every cell in ``cells_plot`` the baseline parameters are read from
    ``frame`` (``get_base_params``). For every entry of ``save_names`` the
    stored results ``load_folder_name('calc_RAM') + '/' + save_name + '_' + cell``
    are loaded. The following are drawn into ``grid1[c]``:

    * linear/nonlinear susceptibility panels (``plot_lin_nonlin``),
    * one spike raster per amplitude (``plt_spikes``),
    * the baseline ISI histogram and PSD (``base_cells_susept``), after the last amplitude,
    * the stimulus trace (``plt_stimulus``), after the last amplitude.

    Optionally a ``cv`` vs. ``cv_stim`` scatter point is added to ``ax3``.
    All collected panels are finally tagged with ``tags_susept_pictures``.

    Parameters
    ----------
    grid1 : gridspec-like
        Indexed with the cell index ``c``, giving one sub-grid per cell.
    frame : pandas.DataFrame
        Baseline data; must contain rows for the cells in ``cells_plot``.
    amps_desired : sequence
        NOTE(review): effectively ignored, because ``amps_defined`` is
        overwritten with ``[20]`` below.
    save_names : sequence of str
        Prefixes of the saved result files.
    cells_plot : sequence of str
        Cell identifiers. NOTE(review): at most two cells are supported,
        since ``lengths`` has two entries and is indexed with ``c``.
    cell_type_type : str
        Column of ``frame`` holding the cell type.
    ax3 : axis or []
        If not ``[]``, receives a scatter point of ``cv`` vs. ``cv_stim``.
    xlim : sequence
        Forwarded to ``base_cells_susept`` (and ``plt_psds_in_one_squares``).
    titles : sequence of str
        Forwarded to ``base_cells_susept``, which indexes it with ``c``.
    peaks_extra : sequence of bool
        Per-cell flag forwarded to ``plot_lin_nonlin``.
    base_extra : bool
        Forwarded to ``base_cells_susept``.
    """
    colors = colors_overview()
    axis = []
    tags_cell = []
    # line_length for plot_lin_nonlin, one entry per cell index c
    lengths = [0.5, 0.25]
    for c, cell in enumerate(cells_plot):
        print(cell)
        cell_type, eod_fr, fr, frs_calc, isi, spikes, spikes_all = get_base_params(cell, cell_type_type, frame)
        ims = []
        tags = []
        add_here = '_cell' + cell  # str(c)
        mats = []
        # x-limits of the spike raster and of the stimulus trace
        xlim_e = [0, 70]
        # zorder per amplitude index aa
        zorders = [100, 50]
        for s, save_name in enumerate(save_names):
            load_name = load_folder_name('calc_RAM') + '/' + save_name + '_' + cell
            stack = load_data_susept(load_name + '.pkl', load_name, add=add_here, load_version='csv', cells=cells_plot)
            axes = []
            if len(stack) > 0:
                files, stack = exclude_cut_filenames(cell_type, stack, fexclude=True)
                file_name = files[0]
                # keep rows with trial_nr > 1 and only those of the first file
                stack_here = stack[stack.trial_nr > 1]
                stack_file = stack_here[stack_here['file_name'] == files[0]]
                amps = stack_file['amp'].unique()
                predefined_amp = True
                if predefined_amp:
                    amps_defined = amps_desired
                else:
                    amps_defined = amps
                # NOTE(review): `trues` is filled but never used afterwards.
                trues = []
                for amp in amps_defined:
                    if amp in amps:
                        trues.append(True)
                # ok, this was selected specifically for the examples
                # (overrides amps_desired / amps chosen above)
                amps_defined = [20]  # [amps[nr_e[c]]]
                cells_amp = ['2017-10-25-am-invivo-1', '2010-11-26-an-invivo-1']
                # NOTE(review): compares a string with a list, so this is always
                # False and the debug branch never runs.
                if cell == cells_amp:
                    print('cell thing')
                    embed()
                ws = 0.5
                first = 1
                wr_l = [first, 0, 1.3, 0, 1]  # 1]
                wr_u = [1.15, 0.15, 1.5]  # , 1]
                ws_total = np.sum(wr_u) + len(wr_u) * ws
                # NOTE(review): expression without effect (result discarded).
                ws_total - 2 * ws - 0.15 - 1.5
                grid_cell, grid_upper = grids_upper_susept_pics(c, grid1, ws=ws, hr=[1, 0.4], row=2, col=3, wr_u=wr_u)
                ws = 0.3
                ims = []
                axds = []
                axos = []
                extra_input = False
                several = False
                axd2, axi, axo2, grid_lower, grid_s1, grid_s2 = grids_for_psds2(amps_defined, extra_input, grid_cell,
                                                                                several, wr=wr_l, ws=ws, add=1)
                add_nonlin_title = None
                xpos_xlabel = -0.25
                normval = 1
                for aa, amp in enumerate(amps_defined):
                    # NOTE: local `alpha` shadows the scipy.stats.alpha import.
                    alpha = find_alpha_val(aa, amps_defined)
                    add_save = '_cell' + str(cell) + '_amp_' + str(amp)
                    wss = ws_for_susept_pic()
                    colors_b = ['grey', colors[cell_type]]
                    right = False
                    if amp in np.array(stack_file['amp']):
                        print(zorders[aa])
                        # several is False above, so only the first branch runs.
                        if not several:
                            diagonals_prj_l, axi, eod_fr, fr, stack_final1, axds, axos, ax_square, axo2, axd2, mat, add_nonlin_title = plot_lin_nonlin(
                                aa, add_save, amp, amps_defined, axds, axos, c, cells_plot, file_name,
                                grid_lower, ims, load_name, stack_file, xlim=[], peaks_extra=peaks_extra[c],
                                zorder=zorders[aa], alpha=alpha, extra_input=extra_input, line_length=lengths[c],
                                xpos_xlabel=xpos_xlabel, add_nonlin_title=add_nonlin_title, color=colors[cell_type],
                                axo2=axo2, axd2=axd2, axi=axi, iterate_var=amps_defined, normval=normval)
                            mats.append(mat)
                        else:
                            # NOTE(review): `add` is only assigned further down in this
                            # loop body, so it may be undefined on the first pass.
                            axi, eod_fr, fr, stack_final1, stack_spikes, axds, axos, ax_square, axo2, axd2 = plt_psds_in_one_squares(
                                aa, add, amp,
                                amps_defined, axds, axes,
                                axis, axos, c, cells_plot,
                                colors_b, file_name, files, grid_s1, grid_s2, ims,
                                load_name, stack_file, wss, xlim, axo2=axo2, axd2=axd2, iterate_var=amps_defined)
                        # collect the panels to be tagged (only once, for the first amplitude)
                        if aa == 0:
                            if len(axi) < 1:
                                tags.append(axo2)
                            else:
                                tags.append(axi)
                            tags.append(ax_square)
                            tags.append(axd2)
                        ######################################################
                        spikes_base, isi, frs_calc, cont_spikes = load_spikes(spikes, eod_fr)
                        ################################
                        # spikes
                        ax_spikes = plt.subplot(grid_upper[1 + aa, -2::])
                        eod_fr, length, stack_final, stack_final1, trial_nr = stack_preprocessing(amp, stack_file)
                        stack_spikes = load_data_susept(load_name + '.pkl', load_name, add=add_save, load_version='csv',
                                                        load_type='spikes',
                                                        trial_nr=trial_nr, stimulus_length=length, amp=amp,
                                                        file_name=file_name)
                        # scale bar only on the raster of the first amplitude
                        if aa == 0:
                            scale = True
                        else:
                            scale = False
                        plt_spikes(c, cells_plot, colors[str(cell_type)], ax_spikes, stack_final1, stack_spikes,
                                   alpha=alpha, xlim_e=xlim_e, sc=10,
                                   scale=scale)  # 1 - alpha_min * aa
                        amp_name = round_for_nice_float_strs(amp)
                        ax_spikes.text(1.01, 0.55, str(amp_name) + '$\%$', va='center', transform=ax_spikes.transAxes,
                                       color=colors[str(cell_type)], alpha=alpha)
                        labels_for_psds(axd2, axi, axo2, extra_input, xpos_xlabel=xpos_xlabel, chi_pos=-0.1, right=right)
                        set_clim_same(ims, mats=mats, lim_type='up')
                        # do the scatter of these cells
                        # (`add` is only consumed by the disabled `several` branch above)
                        add = ['', '_burst_corr_individual']
                        if len(stack) > 0:
                            load_name = load_folder_name('calc_RAM') + '/' + save_names[s] + '_' + cell
                            if ax3 != []:
                                try:
                                    frame_g = base_to_stim(load_name, frame, cell_type_type, cell_type, stack=stack)
                                except:
                                    print('stim problem')
                                    embed()
                                try:
                                    ax3.scatter(frame_g['cv'], frame_g['cv_stim'], zorder=2, alpha=1,
                                                label=cell_type, s=15,
                                                color=colors[str(cell_type)], facecolor='white')
                                except:
                                    print('scatter problem')
                                    embed()
                            ################################
                            # isi (baseline ISI + PSD, drawn once after the last amplitude)
                            if len(isi) > 0:
                                if aa == len(amps_defined) - 1:
                                    grid_p = gridspec.GridSpecFromSubplotSpec(1, 2, grid_upper[:, 0], width_ratios=[1.4, 2],
                                                                              wspace=0.35,
                                                                              hspace=0.55)
                                    ax_isi = plt.subplot(grid_p[0])
                                    ax_p = plt.subplot(grid_p[1])
                                    # embed()
                                    ax_isi = base_cells_susept(ax_isi, ax_p, c, cell, cell_type, cells_plot, colors, eod_fr, frame,
                                                               isi, right, spikes_base, stack, xlim, base_extra=base_extra,
                                                               texts_left=[90, 0], titles=titles, peaks=True, xlim_i=[0, 4])
                                    # ax_isi = base_cells_susept(ax_isi, ax_p, c, cell, cell_type, colors, frame,
                                    #                           isi, base_extra=base_extra,
                                    #                           titles=titles, xlim_i=[0, 4])
                                    tags.insert(0, ax_isi)
                            ###############################
                            # stimulus (drawn once after the last amplitude)
                            if aa == len(amps_defined) - 1:
                                axe = plt.subplot(grid_upper[0, -2::])
                                plt_stimulus(eod_fr, axe, stack_final1, xlim_e, file_name=files[0])
                                tags.insert(1, axe)
        tags_cell.append(tags)
    try:
        tags_susept_pictures(tags_cell, xoffs=np.array([-4.7, -3.2, -4.7, -4.3, -6.3]),
                             yoffs=np.array([1.1, 1.1, 2, 2, 2]))  # , xoffs=np.array([-5.2, -4.2, -5.2, -5.7, -4.7,-5])
    except:
        print('tags here')
        embed()
def plt_cellbody_singlecell(grid1, frame, amps_desired, save_names, cells_plot, cell_type_type, ax3=[], xlim=[],
                            permuted=False, RAM=True, isi_delta=None,
                            titles=['Low CV P-unit', 'High CV P-unit', 'Ampullary cell'],
                            peaks_extra=[False, False, False], base_extra=False, color_same=True, fr_name='$f_{Base}$',
                            eod_metrice=True, tags_individual=False, xlim_p=[0, 1.1], add_texts=[0.25, 0],
                            scale_val=False):
    """Draw the single-cell susceptibility figure for the smallest and largest stored amplitude.

    For every cell in ``cells_plot`` the baseline parameters are read from
    ``frame``. For every entry of ``save_names`` the stored results are
    loaded. For ``[min(amps), max(amps)]`` of the first file the following
    are plotted:

    * linear/nonlinear susceptibility panels (``plot_lin_nonlin``),
    * one spike raster per amplitude,
    * an arrow on the diagonal-projection axis between the medians of both
      amplitudes,
    * the baseline ISI histogram and PSD,
    * the stimulus trace.

    Optionally a ``cv`` vs. ``cv_stim`` scatter point is added to ``ax3``.
    Finally all collected panels are tagged.

    Parameters
    ----------
    grid1 : gridspec-like
        Indexed with the cell index ``c``.
    frame : pandas.DataFrame
        Baseline data of the cells.
    amps_desired : sequence
        NOTE(review): effectively ignored, because ``amps_defined`` is
        replaced by ``[min(amps), max(amps)]`` below.
    save_names : sequence of str
        Prefixes of the saved result files (``<save_name>_<cell>``).
    cells_plot : sequence of str
        Cell identifiers.
    cell_type_type : str
        Column of ``frame`` holding the cell type.
    ax3 : axis or []
        If not ``[]``, receives a scatter point of ``cv`` vs. ``cv_stim``.
    xlim : sequence
        Forwarded to ``plot_lin_nonlin``.
    permuted, eod_metrice, color_same, base_extra : bool
        Forwarded to the plotting helpers.
    RAM : bool
        Forwarded to ``plt_stimulus``.
    isi_delta : float or None
        Tick spacing of the ISI axis, if truthy.
    titles : sequence of str
        Per-cell titles (indexed with ``c``) for the baseline panel.
    peaks_extra : sequence of bool
        Per-cell flag for ``plot_lin_nonlin``.
    fr_name : str
        Label of the firing-rate peak in the baseline PSD.
    tags_individual : bool
        If True, the baseline PSD axis gets its own tag.
    xlim_p : sequence
        x-limits of the baseline PSD.
    add_texts : sequence
        Text offsets for the PSD peak labels.
    scale_val : bool
        If True, draw the raster scale bar for every amplitude (otherwise
        only for the second one).
    """
    colors = colors_overview()
    tags_cell = []
    for c, cell in enumerate(cells_plot):
        print(cell)
        cell_type, eod_fr, fr, frs_calc, isi, spikes, spikes_all = get_base_params(cell, cell_type_type, frame)
        ims = []
        tags = []
        add_here = '_cell' + cell  # str(c)
        mats = []
        # zorder per amplitude index aa
        zorders = [100, 50]
        for s, save_name in enumerate(save_names):
            load_name = load_folder_name('calc_RAM') + '/' + save_name + '_' + cell
            # this cell is evaluated with 4 snippets, all others with 20
            # NOTE(review): only plot_lin_nonlin receives `snippets`; the later
            # stack_preprocessing call uses its default (20) - confirm intended.
            if cell == '2012-07-03-ak-invivo-1':
                snippets = 4
            else:
                snippets = 20
            stack = load_data_susept(load_name + '.pkl', load_name, add=add_here, load_version='csv', cells=cells_plot)
            if len(stack) > 0:
                files, stack = exclude_cut_filenames(cell_type, stack, fexclude=True)
                file_name = files[0]
                stack_file = stack[stack['file_name'] == files[0]]
                amps = stack_file['amp'].unique()
                predefined_amp = True
                if predefined_amp:
                    amps_defined = amps_desired
                else:
                    amps_defined = amps
                # NOTE(review): `trues` is filled but never used afterwards.
                trues = []
                for amp in amps_defined:
                    if amp in amps:
                        trues.append(True)
                # overrides the selection above: smallest and largest stored amplitude
                amps_defined = [np.min(amps), np.max(amps)]
                cells_amp = ['2017-10-25-am-invivo-1', '2010-11-26-an-invivo-1']
                # NOTE(review): compares a string with a list, so this is always False.
                if cell == cells_amp:
                    print('cell thing')
                    embed()
                wr_l = wr_l_cells_susept()
                wr_u = [1.4, 0.1, 1, 1]
                grid_cell, grid_upper = grids_upper_susept_pics(c, grid1, wr_u=wr_u)
                ims = []
                axds = []
                axos = []
                extra_input = False
                several = False
                axd2, axi, axo2, grid_lower, grid_s1, grid_s2 = grids_for_psds(amps_defined, extra_input, grid_cell,
                                                                               several, wr=wr_l)
                power_type = False
                ax_psds = []
                add_nonlin_title = None
                xpos_xlabel = -0.23
                # median of the diagonal projection per amplitude (used for the arrow below)
                diag_vals = []
                for aa, amp in enumerate(amps_defined):
                    alpha = find_alpha_val(aa, amps_defined)
                    add_save = '_cell' + str(cell) + '_amp_' + str(amp)
                    right = 'middle'  # ,
                    normval = 1
                    if amp in np.array(stack_file['amp']):
                        print(zorders[aa])
                        diagonals_prj_l, axi, eod_fr, fr, stack_final1, axds, axos, ax_square, axo2, axd2, mat, add_nonlin_title = plot_lin_nonlin(
                            aa, add_save, amp, amps_defined, axds, axos, c, cells_plot, file_name,
                            grid_lower, ims, load_name, stack_file, xlim, power_type=power_type,
                            permuted=permuted, peaks_extra=peaks_extra[c], zorder=zorders[aa], alpha=alpha,
                            extra_input=extra_input, fr=fr, xpos_xlabel=xpos_xlabel,
                            add_nonlin_title=add_nonlin_title, color=colors[cell_type], axo2=axo2, axd2=axd2,
                            axi=axi, eod_metrice=eod_metrice, base_extra=base_extra, color_same=color_same,
                            iterate_var=amps_defined, normval=normval, snippets = snippets)
                        diag_vals.append(np.median(diagonals_prj_l))
                        mats.append(mat)
                        # collect panels for tagging: first-amplitude panels, then
                        # the projection axis once after the second amplitude
                        if aa == 0:
                            if len(axi) < 1:
                                tags.append(axo2)
                            else:
                                tags.append(axi)
                            tags.append(ax_square)
                        if aa == 1:
                            tags.append(axd2)
                        spikes_base, isi, frs_calc, cont_spikes = load_spikes(spikes, eod_fr)
                        ################################
                        # spikes
                        ax_spikes = plt.subplot(grid_upper[1 + aa, -2::])
                        eod_fr, length, stack_final, stack_final1, trial_nr = stack_preprocessing(amp, stack_file)
                        stack_spikes = load_data_susept(load_name + '.pkl', load_name, add=add_save, load_version='csv',
                                                        load_type='spikes',
                                                        trial_nr=trial_nr, stimulus_length=length, amp=amp,
                                                        file_name=file_name)
                        # scale bar on the second raster, or on all rasters if scale_val
                        if (aa == 1) | (scale_val == True):
                            scale = True
                        else:
                            scale = False
                        plt_spikes(c, cells_plot, colors[str(cell_type)], ax_spikes, stack_final1, stack_spikes,
                                   alpha=alpha,
                                   scale=scale)  # 1 - alpha_min * aa
                        amp_name = round_for_nice_float_strs(amp)
                        ax_spikes.text(1.01, 0.55, str(int(amp_name)) + '$\%$', va='center', transform=ax_spikes.transAxes,
                                       color=colors[str(cell_type)], alpha=alpha)
                ax_psds.extend([axd2])
                ax_psds.extend([axo2])
                # double arrow between the median projections of both amplitudes
                # NOTE(review): needs two entries in diag_vals; raises IndexError if
                # one amplitude was not present in stack_file.
                axd2.annotate('', ha='center',
                              xy=(1, diag_vals[1]),
                              xytext=(1, diag_vals[0]),
                              arrowprops={"arrowstyle": "<->",
                                          "linestyle": "-",
                                          "linewidth": 0.7,
                                          "color": 'black'},
                              zorder=1, annotation_clip=False)
                labels_for_psds(axd2, axi, axo2, extra_input, right=right, xpos_xlabel=xpos_xlabel, normval=normval)
                set_same_ylimscale(ax_psds)
                # todo: maybe use a percentile here as well so the colors are not washed out
                set_clim_same(ims, mats=mats, lim_type='up', percnr=95)
            # do the scatter of these cells
            if len(stack) > 0:
                load_name = load_folder_name('calc_RAM') + '/' + save_names[s] + '_' + cell
                if ax3 != []:
                    try:
                        frame_g = base_to_stim(load_name, frame, cell_type_type, cell_type, stack=stack)
                    except:
                        print('stim problem')
                        embed()
                    try:
                        ax3.scatter(frame_g['cv'], frame_g['cv_stim'], zorder=2, alpha=1,
                                    label=cell_type, s=15,
                                    color=colors[str(cell_type)], facecolor='white')
                    except:
                        print('scatter problem')
                        embed()
                ################################
                # isi (baseline ISI + PSD; `aa` is the last amplitude index of the loop above)
                if len(isi) > 0:
                    if aa == len(amps_defined) - 1:
                        grid_p = gridspec.GridSpecFromSubplotSpec(1, 2, grid_upper[:, 0], width_ratios=[1.4, 2],
                                                                  wspace=0.38,
                                                                  hspace=0.55)
                        ax_isi = plt.subplot(grid_p[0])
                        ax_p = plt.subplot(grid_p[1])
                        ax_isi = base_cells_susept(ax_isi, ax_p, c, cell, cell_type, cells_plot, colors, eod_fr, frame,
                                                   isi, right, spikes_base, stack, xlim_p, base_extra=base_extra,
                                                   add_texts=add_texts, titles=titles, peaks=True, fr_name=fr_name)
                        if isi_delta:
                            ax_isi.set_xticks_delta(isi_delta)
                        if tags_individual:
                            tags.insert(0, ax_p)
                        tags.insert(0, ax_isi)
                ###############################
                # stimulus
                xlim_e = [0, 100]
                if aa == len(amps_defined) - 1:
                    axe = plt.subplot(grid_upper[0, -2::])
                    plt_stimulus(eod_fr, axe, stack_final1, xlim_e, RAM=RAM, file_name=files[0])
                    if tags_individual:
                        tags.insert(2, axe)
                    else:
                        tags.insert(1, axe)
        tags_cell.append(tags)
    try:
        if len(cells_plot) == 1:
            if tags_individual:
                tags_susept_pictures(tags_cell[0], xoffs=np.array([-4.7, -3.2, -3.2, -4.7, -6.3, -2.7, -3.2]),
                                     yoffs=np.array([3, 3, 3, 5.5, 5.5, 5.5, 5.5]))
            else:
                tags_susept_pictures(tags_cell, yoffs=np.array([3, 3, 5.5, 5.5, 5.5, 5.5]))
        else:
            tags_susept_pictures(tags_cell)
    except:
        print('tags here')
        embed()
def base_cells_susept(ax_isi, ax_p, c, cell, cell_type, cells_plot, colors, eod_fr, frame, isi, right, spikes_base,
                      stack, xlim, texts_left=(0.25, 0), clip_on=True, add_texts=(0.25, 0), base_extra=False,
                      titles=('', '', '', '', ''), pos=-0.25, peaks=False, fr_name='$f_{Base}$', xlim_i=(0, 16)):
    """Plot the baseline ISI distribution and baseline PSD of one cell and title it.

    The ISI kernel histogram of ``isi[0]`` is drawn into ``ax_isi``. The
    trial-averaged baseline power spectrum of ``spikes_base`` is drawn into
    ``ax_p``. A title made of ``titles[c]`` plus the CV is placed above
    ``ax_isi``; with ``base_extra`` the baseline firing rate is included too.
    CV and firing rate are taken from the ``cv``/``fr`` columns of the first
    row of ``frame`` with ``frame.cell == cell``.

    Parameters
    ----------
    ax_isi, ax_p : matplotlib axes
        Axes for the ISI distribution and the PSD (project plotstyle
        extensions such as ``arrow_spines`` are required).
    c : int
        Cell index, selects the entry of ``titles``.
    cell : str
        Cell identifier looked up in ``frame.cell``.
    cell_type : str
        Only used for the title colour in the (currently disabled) plain-title
        variant.
    cells_plot : sequence
        Unused; kept for interface compatibility.
    colors : dict
        Maps ``str(cell_type)`` to a colour.
    eod_fr : float
        EOD frequency, forwarded to ``plt_susept_psd_base``.
    frame : pandas.DataFrame
        Must provide the columns ``cell``, ``cv`` and ``fr``.
    isi : sequence
        ``isi[0]`` holds the baseline inter-spike intervals.
    right, xlim, texts_left, add_texts, peaks, fr_name :
        Forwarded to ``plt_susept_psd_base``.
    spikes_base : sequence
        Baseline spike trains for the PSD.
    stack : pandas.DataFrame
        ``stack.sampling.iloc[0]`` is used as sampling rate of the PSD.
    clip_on, xlim_i :
        Forwarded to ``plt_susept_isi_base``.
    base_extra : bool
        If True, show firing rate and CV in the title; otherwise only the CV.
    titles : sequence of str
        Title per cell index.
    pos : float
        x position (axes coordinates) of the title.

    Returns
    -------
    ax_isi : the ISI axis.
    """
    # Baseline ISI distribution.
    plt_susept_isi_base('grey', ax_isi, isi, xlim=xlim_i, clip_on=clip_on)
    # normval == 1: frequency axis in Hz; otherwise normalized by the EOD frequency.
    normval = 1
    if normval != 1:
        freq_label = f_eod_label_core()
    else:
        freq_label = f_eod_label_core_hz()
    ax_p.text(1.1, -0.4, freq_label, ha='center', va='center', transform=ax_p.transAxes)
    ax_p.arrow_spines('b')
    ax_p = plt_susept_psd_base('grey', eod_fr, ax_p, spikes_base, xlim, right=right,
                               add_texts=add_texts, normval=normval, texts_left=texts_left,
                               sampling_rate=stack.sampling.iloc[0], peaks=peaks,
                               fr_name=fr_name)
    cvs = True
    if cvs:
        frame_cell = frame[frame.cell == cell]
        cv = frame_cell.cv.iloc[0]
        fr = frame_cell.fr.iloc[0]
        if base_extra:
            # The rate label must open and close the mathrm group around the
            # name suffix.  Previously a non-empty title closed the group right
            # after "f" and then added a second closing brace, which yields
            # unbalanced mathtext.  Both cases now use the balanced form.
            rate_txt = r'$\mathrm{f' + basename_small() + (r'}=%.0f$\,Hz,' % fr)
            cv_txt = r' $\mathrm{CV' + basename_small() + (r'}=%.2f$' % cv)
            ax_isi.text(pos, 1.25, titles[c] + rate_txt + cv_txt,
                        transform=ax_isi.transAxes)
        else:
            ax_isi.text(pos, 1.25, titles[c] + (r' $\rm{CV}=%.2f$' % cv),
                        transform=ax_isi.transAxes)
    else:
        ax_isi.text(pos, 1.2, titles[c], color=colors[str(cell_type)], transform=ax_isi.transAxes)
    return ax_isi
def f_eod_label_core():
    """Return the mathtext label of a frequency axis normalized by the EOD frequency."""
    eod_name = f_eod_name_core_rm()
    return '$f/{}$'.format(eod_name)
def f_eod_label_core_hz():
    """Return the axis label of a frequency axis in Hz."""
    label = '$f$ [Hz]'
    return label
def wr_l_cells_susept():
    """Return the width ratios of the lower grid of the single-cell susceptibility figure.

    A new list is built on every call, so callers may modify it freely.
    """
    return list((0.5, 0, 1, 1, 0.2, 0.5))
def set_same_ylimscale(ax_psds):
    """Give all axes the same y-range width, keeping each axis centred.

    The widest y-range among ``ax_psds`` is determined. Every axis is then
    widened symmetrically around its current centre to that width, so that
    equal heights correspond to equal data differences in all panels.

    Parameters
    ----------
    ax_psds : sequence of matplotlib axes
        Axes adjusted in place. An empty sequence is a no-op.
    """
    if len(ax_psds) == 0:
        # Nothing to align; np.max of an empty list would raise.
        return
    ranges = [upper - lower for lower, upper in (ax.get_ylim() for ax in ax_psds)]
    widest = np.max(ranges)
    for ax in ax_psds:
        # Re-read the limits: axes may share their y-axis with a previous one.
        lower, upper = ax.get_ylim()
        pad = (widest - (upper - lower)) / 2
        ax.set_ylim(lower - pad, upper + pad)
def peaks_extra_fillbetween(axd2, eod_fr, fr, mats, normval=1):
    """Shade a band of width 10 around ``fr`` on the diagonal-projection axis ``axd2``.

    The band spans ``fr - 5`` to ``fr + 5`` horizontally. When ``normval != 1``
    these bounds are divided by ``eod_fr``. Vertically the band covers the
    range (in dB, ``10 * log10``) of the diagonals of all matrices in ``mats``,
    as returned by ``get_mat_diagonals``.
    """
    divisor = normval if normval == 1 else eod_fr
    collected = []
    for matrix in mats:
        _, projections = get_mat_diagonals(np.array(matrix))
        collected.extend(projections)
    in_db = 10 * np.log10(collected)
    bottom = np.min(in_db)
    top = np.max(in_db)
    band_x = [(fr - 5) / divisor, (fr + 5) / divisor]
    axd2.fill_between(band_x, [bottom, bottom], [top, top], color='grey', alpha=0.5, zorder=0)
def plt_susept_psd_base(colors, eod_fr, ax_p, spikes_base, xlim, add_texts=[0, 0], normval=1,
                        texts_left=[0.22, 0], right='middle', fr_name='$f_{Base}$', sampling_rate=40000, peaks=False):
    """Plot the trial-averaged baseline power spectrum (in dB) of ``spikes_base`` on ``ax_p``.

    Parameters
    ----------
    colors : colour
        Colour of the PSD line.
    eod_fr : float
        EOD frequency. Frequencies are divided by it when ``normval != 1``;
        with ``normval == 1`` it scales the upper x-limit instead.
    ax_p : matplotlib axis
        Target axis (project plotstyle extensions required).
    spikes_base : sequence
        Spike trains passed to ``calc_psd_from_spikes``. ``spikes_base[0]``
        also gives the firing rate when ``peaks`` is set (presumably spike
        times in ms, given the division by 1000 - TODO confirm).
    xlim : sequence
        If non-empty, the x-limits (upper limit multiplied by ``eod_fr`` when
        ``normval == 1``).
    add_texts, texts_left, fr_name :
        Forwarded to ``plt_peaks_several``.
    normval : number
        1 for a frequency axis in Hz; any other value normalizes by ``eod_fr``.
    right : str
        'right' or 'left' place a dB scale bar and a 'Baseline' label on
        that side; anything else only places the scale bar.
    sampling_rate : float
        Sampling rate used for the PSD.
    peaks : bool
        If True, mark the firing-rate and EOD-frequency peaks.

    Returns
    -------
    ax_p
    """
    _, f_array, p_array = calc_psd_from_spikes(int(sampling_rate / 2), sampling_rate, spikes_base)
    power_db = 10 * np.log10(np.mean(p_array, axis=0))
    divisor = normval if normval == 1 else eod_fr

    if len(xlim) > 0:
        if divisor == 1:
            xlim = [xlim[0], xlim[1] * eod_fr]
        ax_p.set_xlim(xlim)

    ax_p.plot(f_array / divisor, power_db, color=colors)
    ax_p.show_spines('b')

    if right == 'right':
        ax_p.yscalebar(1, 0.35, 20, 'dB', va='center', ha='right')
        ax_p.text(1.15, 0, 'Baseline', rotation=90, transform=ax_p.transAxes)
    elif right == 'left':
        ax_p.yscalebar(-0.03, 0.5, 20, 'dB', va='center', ha='left')
        ax_p.text(-0.23, 0.5, 'Baseline', va='center', rotation=90, transform=ax_p.transAxes)
    else:
        ax_p.yscalebar(1.05, 0.35, 20, 'dB', va='center', ha='right')

    if peaks:
        mean_interval = np.mean(np.diff(np.array(spikes_base[0]) / 1000))
        rate = 1 / mean_interval
        peak_freqs = [rate / divisor, eod_fr / divisor]
        peak_names = [fr_name, f_eod_name_rm()]
        plt_peaks_several(peak_freqs, [power_db], ax_p, power_db, f_array / divisor, peak_names,
                          0, ['grey', 'grey'], add_texts=add_texts, texts_left=texts_left, add_log=2.5,
                          exact=False, text_extra=True, perc_peaksize=0.08, ms=14, clip_on=False, log='log')
    return ax_p
def round_for_nice_float_strs(amp):
    """Return ``amp`` in a compact form for labels.

    Whole numbers become ``int`` (``20.0 -> 20``). Values with a fractional
    part are rounded to one decimal with ``np.round``.
    """
    has_fraction = amp % 1 > 0
    return np.round(amp, 1) if has_fraction else int(amp)
def ws_for_susept_pic():
    """Return the pair of horizontal spacings (wspace) used for the two susceptibility sub-grids."""
    first_ws, second_ws = 0.4, 0.2
    return [first_ws, second_ws]
def grids_upper_susept_pics(c, grid1, row=3, hr=[1, 0.4, 0.4], hs=0.65, ws=0.08, col=4, wr_u=[1, 2, 2]):
    """Split ``grid1[c]`` into an upper and lower part and subdivide the upper part.

    ``grid1[c]`` is divided into two rows (height ratio 3:5, vertical spacing
    ``hs``). The upper row is then split into a ``row`` x ``col`` grid with
    height ratios ``hr``, width ratios ``wr_u``, no vertical spacing and
    horizontal spacing ``ws``.

    Returns
    -------
    grid_cell : the two-row grid of the cell panel
    grid_upper : the sub-grid of ``grid_cell[0]``
    """
    try:
        cell_spec = grid1[c]
        grid_cell = gridspec.GridSpecFromSubplotSpec(2, 1, cell_spec, hspace=hs,
                                                     height_ratios=[3, 5])
    except:
        print('cell thing3')
        grid_cell = []
        embed()
    upper_spec = grid_cell[0]
    grid_upper = gridspec.GridSpecFromSubplotSpec(row, col, upper_spec, height_ratios=hr,
                                                  width_ratios=wr_u, wspace=ws, hspace=0.0)
    return grid_cell, grid_upper
def tags_susept_pictures(tags_cell, xoffs=np.array([-4.7, -3.2, -4.7, -6.3, -2.7, -3.2]),
                         yoffs=np.array([1.1, 1.1, 2, 2, 2, 2])):
    """Tag the axes in ``tags_cell`` on the current figure via ``tag2``.

    ``tags_cell`` is either a list of axes or a list of such lists (callers in
    this file use both forms). ``xoffs``/``yoffs`` are per-tag offsets
    forwarded unchanged to ``tag2``.
    """
    # The offsets are not intuitive; they were tuned by hand.
    current_figure = plt.gcf()
    tag2(current_figure, tags_cell, xoffs=xoffs, yoffs=yoffs)
def grids_for_psds2(amps_defined, extra_input, grid_cell, several, ws=0.3, wss=[], add=0, widht_ratios=[],
                    wr=[1, 0.2, 1, 1]):
    """Create the lower sub-grid of a cell panel and its PSD axes (layout variant 2).

    Parameters
    ----------
    amps_defined : sequence
        Amplitudes to be shown; their number determines the column count.
    extra_input : bool
        If True (and not ``several``), the first column holds three rows:
        input axis ``axi``, output axis ``axo2`` and projection axis ``axd2``.
    grid_cell : gridspec
        ``grid_cell[1]`` is subdivided.
    several : bool
        If True, create ``len(amps_defined) + 1`` columns and two 2x2
        sub-grids in the first two columns instead of axes.
    ws : float
        wspace of the lower grid when not ``several``.
    wss : sequence of two floats
        wspace of the two 2x2 sub-grids; required when ``several``.
    add : int
        Extra columns added to ``len(amps_defined) + 3`` when not ``several``.
    widht_ratios : sequence
        Column width ratios when ``several``; empty means equal widths.
    wr : sequence
        Column width ratios when not ``several``. Matplotlib requires its
        length to equal the column count.

    Returns
    -------
    axd2, axi, axo2, grid_lower, grid_s1, grid_s2
        Objects that are not created stay empty lists.
    """
    axd2 = []
    axi = []
    axo2 = []
    grid_s1 = []
    grid_s2 = []
    if several:
        # Matplotlib raises if the number of width ratios differs from the
        # number of columns, so the empty default must mean "equal widths".
        if widht_ratios is None or len(widht_ratios) == 0:
            width_ratios = None
        else:
            width_ratios = widht_ratios
        grid_lower = gridspec.GridSpecFromSubplotSpec(1, len(amps_defined) + 1, grid_cell[1], hspace=0.1,
                                                      wspace=0.2, width_ratios=width_ratios)
        grid_s1 = gridspec.GridSpecFromSubplotSpec(2, 2, grid_lower[0],
                                                   hspace=0.1, wspace=wss[0],
                                                   width_ratios=[0.8, 1])
        # same layout again in the next column
        grid_s2 = gridspec.GridSpecFromSubplotSpec(2, 2, grid_lower[1],
                                                   hspace=0.1, wspace=wss[1],
                                                   width_ratios=[0.8, 1])
    else:
        row_nrs = 3 if extra_input else 2
        try:
            grid_lower = gridspec.GridSpecFromSubplotSpec(row_nrs, len(amps_defined) + 3 + add, grid_cell[1],
                                                          hspace=0.1,
                                                          wspace=ws,
                                                          width_ratios=wr)
        except:
            print('grid bursts0')
            embed()
        if extra_input:
            axi = plt.subplot(grid_lower[0, 0])
            axd2 = plt.subplot(grid_lower[2, 0])
            axo2 = plt.subplot(grid_lower[1, 0])
        else:
            axd2 = plt.subplot(grid_lower[:, -1])
            axo2 = plt.subplot(grid_lower[:, 0])
            axi = []
    return axd2, axi, axo2, grid_lower, grid_s1, grid_s2
def grids_for_psds(amps_defined, extra_input, grid_cell, several, ws=0.25, wss=[], add=0, widht_ratios=[],
                   wr=[1, 0.2, 1, 1]):
    """Create the lower sub-grid of a cell panel and its PSD axes.

    Parameters
    ----------
    amps_defined : sequence
        Amplitudes to be shown; their number determines the column count.
    extra_input : bool
        If True (and not ``several``), the first column holds three rows:
        input axis ``axi``, output axis ``axo2`` and projection axis ``axd2``.
    grid_cell : gridspec
        ``grid_cell[1]`` is subdivided.
    several : bool
        If True, create ``len(amps_defined) + 1`` columns and two 2x2
        sub-grids in the first two columns instead of axes.
    ws : float
        wspace of the lower grid when not ``several``.
    wss : sequence of two floats
        wspace of the two 2x2 sub-grids; required when ``several``.
    add : int
        Extra columns added to ``len(amps_defined) + 4`` when not ``several``.
    widht_ratios : sequence
        Column width ratios when ``several``; empty means equal widths.
    wr : sequence
        Column width ratios when not ``several``. Matplotlib requires its
        length to equal the column count.

    Returns
    -------
    axd2, axi, axo2, grid_lower, grid_s1, grid_s2
        Objects that are not created stay empty lists.
    """
    axd2 = []
    axi = []
    axo2 = []
    grid_s1 = []
    grid_s2 = []
    if several:
        # Matplotlib raises if the number of width ratios differs from the
        # number of columns, so the empty default must mean "equal widths".
        if widht_ratios is None or len(widht_ratios) == 0:
            width_ratios = None
        else:
            width_ratios = widht_ratios
        grid_lower = gridspec.GridSpecFromSubplotSpec(1, len(amps_defined) + 1, grid_cell[1], hspace=0.1,
                                                      wspace=0.2, width_ratios=width_ratios)
        grid_s1 = gridspec.GridSpecFromSubplotSpec(2, 2, grid_lower[0],
                                                   hspace=0.1, wspace=wss[0],
                                                   width_ratios=[0.8, 1])
        # same layout again in the next column
        grid_s2 = gridspec.GridSpecFromSubplotSpec(2, 2, grid_lower[1],
                                                   hspace=0.1, wspace=wss[1],
                                                   width_ratios=[0.8, 1])
    else:
        row_nrs = 3 if extra_input else 2
        try:
            grid_lower = gridspec.GridSpecFromSubplotSpec(row_nrs, len(amps_defined) + 4 + add, grid_cell[1],
                                                          hspace=0.1,
                                                          wspace=ws,
                                                          width_ratios=wr)
        except:
            print('grid bursts')
            embed()
        if extra_input:
            axi = plt.subplot(grid_lower[0, 0])
            axd2 = plt.subplot(grid_lower[2, 0])
            axo2 = plt.subplot(grid_lower[1, 0])
        else:
            axd2 = plt.subplot(grid_lower[:, -1])
            axo2 = plt.subplot(grid_lower[:, 0])
            axi = []
    return axd2, axi, axo2, grid_lower, grid_s1, grid_s2
def labels_for_psds(axd2, axi, axo2, extra_input, right='middle', chi_pos=-0.3, normval=1,
                    xpos_xlabel=-0.23, power_label=r'$|\chi_1|$', log_transfer=False):
    """Label the PSD panels of a susceptibility figure.

    Decorates the transfer-function axis ``axo2`` and the diagonal-projection
    axis ``axd2``. When ``extra_input`` is set, the input axis ``axi`` is
    decorated as well.

    Parameters
    ----------
    axd2, axo2 : matplotlib axes
        Project plotstyle extensions (``yscalebar``, ``show_spines``,
        ``arrow_spines``) are required.
    axi : axis or []
        Only used when ``extra_input`` is true.
    extra_input : bool
        Whether an input axis exists.
    right : str
        'right' or 'left' put text labels and scale bars on that side; any
        other value uses ``axd_labels`` for ``axd2`` plus a y-label (or scale
        bar with ``log_transfer``) and an arrowed x-axis for ``axo2``.
    chi_pos : float
        x position of the projection label in the default layout.
    normval : number
        1 for frequency labels in Hz; otherwise the normalized labels are used.
    xpos_xlabel : float
        y position (axes coordinates) of the x-axis labels.
    power_label : str
        Label of ``axo2`` in the 'right'/'left' layouts and with ``log_transfer``.
    log_transfer : bool
        In the default layout, use a scale bar instead of a y-axis for ``axo2``.
    """
    test_limits = False
    if test_limits:
        axo2.set_ylabel(power_label)
        axd2.set_ylabel('Projection')
        return
    if right == 'right':
        remove_yticks(axo2)
        remove_yticks(axd2)
        axo2.text(1.15, 0.5, power_label, rotation=90, va='center', transform=axo2.transAxes)
        axd2.text(1.15, 0.5, 'Proj.', rotation=90, va='center', transform=axd2.transAxes)
        axd2.yscalebar(1, 0.5, 10, 'dB', va='center', ha='right')
        axo2.yscalebar(1, 0.5, 10, trasnfer_ylabel(), va='center', ha='right')
        axd2.show_spines('b')
        axo2.show_spines('b')
        if extra_input:
            axi.text(1.15, 0.5, 'Input', rotation=90, va='center', transform=axi.transAxes)
            axi.yscalebar(1, 0.5, 10, 'dB', va='center', ha='right')
            axi.show_spines('b')
    elif right == 'left':
        remove_yticks(axo2)
        remove_yticks(axd2)
        axo2.text(-0.23, 0.5, power_label, rotation=90, va='center', transform=axo2.transAxes)
        # drawn once (it was previously drawn twice on top of itself)
        axd2.text(-0.23, 0.5, 'Proj.', rotation=90, va='center', transform=axd2.transAxes)
        axd2.yscalebar(-0.03, 0.5, 10, 'dB', va='center', ha='left')
        axo2.yscalebar(-0.03, 0.5, 10, 'dB', va='center', ha='left')
        axd2.show_spines('b')
        axo2.show_spines('b')
        if extra_input:
            axi.text(-0.23, 0.5, 'Input', rotation=90, va='center', transform=axi.transAxes)
            axi.yscalebar(-0.03, 0.5, 10, 'dB', va='center', ha='left')
            axi.show_spines('b')
    else:
        axd_labels(axd2, chi_pos=chi_pos, normval=normval, xpos_xlabel=xpos_xlabel)
        if log_transfer == True:
            axo2.text(-0.13, 0.5, power_label, rotation=90, va='center', transform=axo2.transAxes)
            axo2.yscalebar(1, 0.5, 10, trasnfer_ylabel(), va='center', ha='right')
            axo2.show_spines('b')
            remove_yticks(axo2)
        else:
            axo2.set_ylabel(trasnfer_ylabel())
            axo2.show_spines('lb')
        if normval != 1:
            axo2.text(1.05, xpos_xlabel, tranfer_xlabel(), ha='center', va='center',
                      transform=axo2.transAxes)
        else:
            axo2.text(1.05, xpos_xlabel, tranfer_xlabel_hz(), ha='center', va='center',
                      transform=axo2.transAxes)
        axo2.arrow_spines('b')
        if extra_input:
            axi.text(1.15, 0.5, 'Input', rotation=90, va='center', transform=axi.transAxes)
            axi.yscalebar(1, 0.5, 10, 'dB', va='center', ha='right')
            axi.show_spines('b')
def axd_labels(axd2, chi_pos=-0.3, normval=1, xpos_xlabel=-0.23):
    """Decorate the diagonal-projection axis ``axd2``.

    Adds the rotated projection label at ``chi_pos``, a 10 dB scale bar, the
    bottom spine with an arrow, and the frequency label at height
    ``xpos_xlabel``. The label is in Hz when ``normval == 1``, otherwise
    normalized. Finally the y ticks are removed.
    """
    axd2.text(chi_pos, 0.5, ylabel_projected(), rotation=90, va='center', transform=axd2.transAxes)
    axd2.yscalebar(1.1, 0.5, 10, 'dB', va='center', ha='right')
    axd2.show_spines('b')
    freq_label = diagonal_xlabel() if normval == 1 else diagonal_xlabel_nothz()
    axd2.text(1.05, xpos_xlabel, freq_label, ha='center', va='center', transform=axd2.transAxes)
    axd2.arrow_spines('b')
    remove_yticks(axd2)
def ylabel_projected():
    """Return the mathtext label of the projected |chi_2| axis."""
    label = '$|\\bar{\\chi_2}|$'
    return label
def trasnfer_ylabel():
    """Return the y-axis label of the first-order transfer function |chi_1| in Hz.

    A raw string keeps the LaTeX backslashes from being parsed as (invalid)
    Python escape sequences, which trigger a SyntaxWarning on Python 3.12+.
    """
    return r'$|\mathcal{\chi}_{1}|\,$[Hz]'
def get_base_params(cell, cell_type_type, frame):
    """Look up the baseline parameters of ``cell`` in ``frame``.

    The rows with ``frame['cell'] == cell`` are passed through
    ``unify_cell_names``. Values are then taken from the first remaining row.

    Returns
    -------
    cell_type, eod_fr, fr, frs_calc, isi, spikes, spikes_all
        ``cell_type`` comes from column ``cell_type_type``, ``eod_fr`` from
        ``EODf``, ``fr`` from ``fr`` and ``spikes`` from ``spikes``.
        ``frs_calc``, ``isi`` and ``spikes_all`` are empty lists.
    """
    try:
        frame_cell = frame[(frame['cell'] == cell)]
    except:
        print('frame thing')
        embed()
    frame_cell = unify_cell_names(frame_cell, cell_type=cell_type_type)

    def first(column):
        # value of `column` in the first row of this cell
        return frame_cell[column].iloc[0]

    try:
        cell_type = first(cell_type_type)
    except:
        print('cell type prob')
        embed()
    spikes = first('spikes')
    fr = first('fr')
    eod_fr = first('EODf')
    # placeholders kept for the callers' unpacking order
    frs_calc, isi, spikes_all = [], [], []
    return cell_type, eod_fr, fr, frs_calc, isi, spikes, spikes_all
def plt_susept_isi_base(color, ax_isi, isi, delta=None, ypos=-0.4, xlim=[0, 13],
                        clip_on=False):
    """Plot the baseline ISI distribution ``isi[0]`` as a kernel density on ``ax_isi``.

    Only the bottom spine (with arrow) is shown and the y ticks are removed.
    ``xlim`` is applied when non-empty. ``delta`` sets the x tick spacing
    when truthy. The ISI axis label is placed at height ``ypos``.
    """
    ax_isi.show_spines('b')
    kernel_histogram(ax_isi, color, isi[0], extend=False, clip_on=clip_on, step=0.01)
    if len(xlim) > 0:
        ax_isi.set_xlim(xlim)
    # the label position does not depend on whether xlim was set
    ax_isi.text(1.1, ypos, isi_label_core(), transform=ax_isi.transAxes, ha='center',
                va='center')
    if delta:
        ax_isi.set_xticks_delta(delta)
    ax_isi.arrow_spines('b')
    remove_yticks(ax_isi)
def isi_label_core():
    """Return the mathtext ISI axis label (intervals in units of one EOD period)."""
    return '$1/%s$' % f_eod_name_core_rm()
def kernel_histogram(ax_isi, color, isi, norm='no', clip_on=False, step=0.1, label='', orientation='horizontal',
                     alpha=1, xmin='no', xmin_perc=False, perc_min=0.01, xmax='no', height_val=1, extend=True):
    """Plot a Gaussian kernel density estimate of ``isi`` as a line plus filled area.

    Non-finite samples are discarded first. If fewer than two samples remain,
    or all remaining samples are identical, no density can be estimated and
    nothing is drawn.

    Parameters
    ----------
    ax_isi : matplotlib axis
        Target axis.
    color, label, alpha, clip_on :
        Line/fill styling.
    isi : sequence of numbers
        Samples (e.g. inter-spike intervals).
    norm : {'no', 'density', 'maximum'}
        'density' divides the curve by its sum; 'maximum' scales its peak to
        ``height_val``.
    step : float
        Absolute kernel bandwidth. 0 uses scipy's default bandwidth rule.
    orientation : str
        'horizontal' plots density over value; anything else plots value over
        density.
    xmin, xmax : 'no' or number
        Evaluation range when ``extend``. With 'no', 0.8 x min / 1.1 x max of
        the data are used, or min/max shifted by the ``perc_min`` percentile
        when ``xmin_perc``.
    extend : bool
        If True, evaluate on 1000 points between ``xmin`` and ``xmax``;
        otherwise on the sorted samples themselves.
    """
    if len(isi) <= 1:
        return
    isi = np.array(list(map(float, np.array(isi))))
    # Drop NaN and +-inf; they would break the KDE.
    isi = isi[np.isfinite(isi)]
    if len(isi) < 2:
        # gaussian_kde needs at least two samples.
        return
    spread = np.std(isi, ddof=1)
    if spread == 0:
        # Identical samples give a singular covariance in gaussian_kde.
        return
    if step == 0:
        kernel = gaussian_kde(isi)
    else:
        # A scalar bw_method is a factor on the data's (ddof=1) standard
        # deviation, so this yields an absolute kernel width of `step`.
        kernel = gaussian_kde(isi, step / spread)
    isi_sorted = np.sort(isi)
    if xmin == 'no':
        if xmin_perc:
            # A percentile offset is not ideal: the distribution can well have
            # large values in its middle.
            xmin = np.min(isi_sorted) - np.percentile(isi_sorted, perc_min)
        else:
            xmin = np.min(isi_sorted) * 0.8
    if xmax == 'no':
        if xmin_perc:
            xmax = np.max(isi_sorted) + np.percentile(isi_sorted, perc_min)
        else:
            xmax = np.max(isi_sorted) * 1.1
    # points at which the density is evaluated
    x = np.linspace(xmin, xmax, 1000) if extend else isi_sorted
    kde = kernel(x)
    if norm == 'density':
        kde = kde / np.sum(kde)
    elif norm == 'maximum':
        kde = height_val * kde / np.max(kde)
    if orientation == 'horizontal':
        ax_isi.plot(x, kde, color=color, label=label, alpha=alpha, clip_on=clip_on)
        ax_isi.fill_between(x, kde, color=color, alpha=alpha, clip_on=clip_on)
    else:
        ax_isi.plot(kde, x, color=color, label=label, alpha=alpha, clip_on=clip_on)
        order = np.argsort(x)
        ax_isi.fill_betweenx(x[order], kde[order], color=color, alpha=alpha, clip_on=clip_on)
def plt_square_with_psds(aa, amp, amps_defined, axes, axis, c, cells_plot, files, grid_s, ims, load_name,
                         stack_file, xlim, cbar_true=True, axd=None, axo=None, square_plot=True, color='black',
                         add='', file_name=None):
    """Plot the chi_2 matrix of one amplitude together with its projection/PSD panels.

    The rows of ``stack_file`` for ``amp`` are selected with
    ``stack_preprocessing``. The ``'osf'`` and ``'spikes'`` data for that
    amplitude are loaded with ``load_data_susept``. The matrix is drawn into
    ``grid_s[:, 1]`` (``plt_square_here``). The panels of ``plt_psds_all`` go
    into ``axd`` / ``axo``, which default to ``grid_s[1, 0]`` /
    ``grid_s[0, 0]`` when not given.

    Side effects: appends to the caller-owned lists ``axes`` and ``axis`` and,
    via ``plt_square_here``, to ``ims``.

    Returns
    -------
    ax_square, axi, eod_fr, fr, stack_final1, stack_spikes, im, axd, axo
        NOTE(review): ``plt_psds_all`` returns ``(axd, axo, axo)``, so ``axi``
        is the same axis as ``axo``.
    """
    eod_fr, length, stack_final, stack_final1, trial_nr = stack_preprocessing(amp, stack_file)
    if not file_name:
        file_name = files[0]
    stack_osf = load_data_susept(load_name + '.pkl', load_name, load_version='csv',
                                 load_type='osf', trial_nr=trial_nr,
                                 stimulus_length=length, add=add, amp=amp, file_name=file_name)
    stack_spikes = load_data_susept(load_name + '.pkl', load_name, add=add, load_version='csv', load_type='spikes',
                                    trial_nr=trial_nr, stimulus_length=length, amp=amp, file_name=file_name)
    # NOTE(review): dead code, `extra` is always False.
    extra = False
    if extra:
        try:
            pass
        except:
            pass
    #############################
    # load both amps for common amp db limit
    # NOTE(review): the returned maxima are discarded; only the loading work is done.
    _, _, _ = get_max_several_amp_squares(add, amp, amps_defined, files, load_name, stack_file)
    if aa == 0:
        pass
    else:
        pass
    try:
        ax_square = plt.subplot(grid_s[:, 1])
    except:
        print('grid problem4')
        embed()
    # NOTE(review): with square_plot=False, `mat`, `test_limits` and `im` stay
    # undefined and the plt_psds_all call / return below raise NameError.
    if square_plot:
        mat, test_limits, im, add_nonlin_title = plt_square_here(aa, amp, amps_defined, ax_square, c, cells_plot, ims,
                                                                 stack_final1, [], cbar_true=cbar_true)
    ############################################
    # psd part
    fr = stack_final1.fr.unique()[0]
    if not axd:
        try:
            axd = plt.subplot(grid_s[1, 0])
        except:
            print('axd thing')
            embed()
    if not axo:
        axo = plt.subplot(grid_s[0, 0])
    axd, axi, axo = plt_psds_all(axd, axo, mat,
                                 stack_final, stack_osf, test_limits, xlim, color=color,
                                 db='db')
    axes.append(axi)  # np.min(mat.columns)
    axes.append(axo)  # np.max(mat.columns)
    axis.append(axi)
    return ax_square, axi, eod_fr, fr, stack_final1, stack_spikes, im, axd, axo
def stack_preprocessing(amp, stack_file, snippets=20):
    """Select the rows of ``stack_file`` that belong to one stimulus amplitude.

    The rows are first restricted to one snippet count. ``snippets`` is tried
    first, then 20, 10, 9 and 4, and the first count with any rows wins.
    From those rows, the ones with amplitude ``amp`` and with the longest
    ``stimulus_length`` present in the whole ``stack_file`` are kept. Finally
    the rows of the maximal trial number are kept.

    Returns
    -------
    eod_fr : ``eod_fr`` of the first selected row
    length : longest stimulus length in ``stack_file``
    stack_final : rows matching snippet count, ``amp`` and ``length``
    stack_final1 : ``stack_final`` restricted to ``trial_nr``
    trial_nr : largest trial number in ``stack_final``
    """
    for snippet_count in (snippets, 20, 10, 9, 4):
        by_snippets = stack_file[stack_file['snippets'] == snippet_count]
        if len(by_snippets) >= 1:
            break
    by_amp = by_snippets[by_snippets['amp'] == amp]
    length = np.max(stack_file['stimulus_length'].unique())
    stack_final = by_amp[by_amp['stimulus_length'] == length]
    # todo: also support 20 trials here
    trial_nrs = stack_final.trial_nr.unique()
    try:
        eod_fr = stack_final.eod_fr.iloc[0]
    except:
        print('trial thing')
        embed()
    # Several trial numbers are probably a sign of an error: an earlier save
    # bug stored some entries twice, so the last (largest) one is used.
    if len(trial_nrs) > 1:
        print('trial_nr_double1')
        embed()
    try:
        trial_nr = np.max(trial_nrs)
    except:
        print('trial something')
        embed()
    try:
        stack_final1 = stack_final[stack_final.trial_nr == trial_nr]
    except:
        print('stack_final1 problem')
        embed()
    return eod_fr, length, stack_final, stack_final1, trial_nr
def get_max_several_amp_squares(add, amp, amps_defined, files, load_name, stack_file):
    """Return the maxima of output power, input power and diagonal projection over several amplitudes.

    For every amplitude in ``amps_defined``:

    * the rows of that amplitude with the longest stimulus length in
      ``stack_file`` are selected,
    * their matrix is normalized with ``RAM_norm_data``,
    * the maximum of its diagonal projection (``get_mat_diagonals``) is
      recorded,
    * the ``'isf'`` and ``'osf'`` data are loaded with ``load_data_susept``
      and the maxima of their squared values are recorded. Only the first
      ``len(f[f < 2 * stack_final.index[-1]])`` entries are considered, with
      ``f = find_f(stack_final)``.

    Parameters
    ----------
    add : str
        Suffix forwarded to ``load_data_susept``.
    amp : number
        Unused; kept for interface compatibility (the loop amplitude is used).
    amps_defined : sequence
        Amplitudes to evaluate.
    files : sequence of str
        ``files[0]`` is the file name passed to ``load_data_susept``.
    load_name : str
        Base path of the stored results.
    stack_file : pandas.DataFrame
        Rows of one recording file.

    Returns
    -------
    maxo, maxi, maxd : maxima over all amplitudes (output, input, projection)
    """
    # The longest stimulus length does not depend on the amplitude.
    length = np.max(stack_file['stimulus_length'].unique())
    maxs_i = []
    maxs_o = []
    maxs_d = []
    for amp_here in amps_defined:
        # Use the amplitude of this iteration; previously the outer `amp` was
        # used here, so every iteration evaluated the same rows.
        stack_amp = stack_file[stack_file['amp'] == amp_here]
        stack_final = stack_amp[stack_amp['stimulus_length'] == length]
        trial_nr_double = stack_final.trial_nr.unique()
        new_keys, stack_plot = convert_csv_str_to_float(stack_final)
        norm_d = False  # todo: move this into something like a higher-level settings dict
        if norm_d:
            mat = RAM_norm_data(stack_final['d_isf1'].iloc[0], stack_plot, stack_final['snippets'].unique()[0])
        else:
            mat = RAM_norm_data(stack_final['isf'].iloc[0], stack_plot,
                                stack_final['snippets'].unique()[0], stack_here=stack_final)
        diag, diagonals_prj_l = get_mat_diagonals(np.array(mat))
        maxs_d.append(np.max(diagonals_prj_l))
        trial_nr = np.max(trial_nr_double)
        isf = load_data_susept(load_name + '.pkl', load_name, load_version='csv',
                               load_type='isf', trial_nr=trial_nr,
                               stimulus_length=length, add=add, amp=amp_here, file_name=files[0])
        f = find_f(stack_final)
        f_max = stack_final.index[-1] * 2
        f_restict = f[f < f_max]
        maxs_i.append(np.max(isf.iloc[0:len(f_restict)]) ** 2)
        osf = load_data_susept(load_name + '.pkl', load_name, load_version='csv',
                               load_type='osf', add=add, trial_nr=trial_nr,
                               stimulus_length=length, amp=amp_here, file_name=files[0])
        maxs_o.append(np.max(osf.iloc[0:len(f_restict)]) ** 2)
    maxo = np.max(maxs_o)
    maxi = np.max(maxs_i)
    maxd = np.max(maxs_d)
    return maxo, maxi, maxd
def cv_stim_name_rm():
    """Return the upright mathtext name of the CV during stimulation (without $ delimiters).

    A raw string keeps the LaTeX backslash from being parsed as an invalid
    Python escape sequence, which triggers a SyntaxWarning on Python 3.12+.
    """
    return r'\mathrm{CV_{stim}}'
def plt_square_here(aa, amp, amps_defined, ax_square, c, cells_plot, ims, stack_final1, xlim, line_length=1 / 4,
                    alpha=1, cbar_true=True,
                    perc=True, amp_give=True, base_extra=False, eod_metrice=True, ha='right', nr=3, fr=None,
                    title_square='', xpos_xlabel=-0.2, ypos=0.05, xpos=0.1, color='white', add_nonlin_title=None):
    """Draw the chi_2 matrix of ``stack_final1`` into ``ax_square`` and annotate it.

    The matrix is drawn by ``plot_square_core`` and the resulting image is
    appended to ``ims``. Both axis limits are ``xlim`` if given, otherwise
    ``0 .. max(mat.columns)``. With ``cbar_true`` a colorbar is added outside
    the axis. A text with the amplitude (if ``amp_give``) and the CV of
    ``stack_final1.cv_stim`` is written at (``xpos``, ``ypos``) in axes
    coordinates. Only the first amplitude (``aa == 0``) keeps y ticks and
    gets the f_2 label.

    Returns
    -------
    mat : matrix returned by ``plot_square_core``
    test_limits : always False
    im : image returned by ``plot_square_core``
    add_nonlin_title : string returned by ``plot_square_core`` (inserted into
        the colorbar label after ' [')
    """
    # NOTE(review): both branches set False, so no colorbar is drawn by plot_square_core.
    if aa == len(amps_defined) - 1:
        cbar_do = False
    else:
        cbar_do = False
    print(add_nonlin_title)
    cbar, mat, im, add_nonlin_title = plot_square_core([ax_square], stack_final1, eod_metrice=eod_metrice, fr=fr, nr=nr,
                                                       line_length=line_length, cbar_do=cbar_do, perc=perc,
                                                       add_nonlin_title=add_nonlin_title)
    ims.append(im)
    if xlim:
        ax_square.set_xlim(xlim)
        ax_square.set_ylim(xlim)
    else:
        ax_square.set_xlim(0, np.max(mat.columns))
        ax_square.set_ylim(0, np.max(mat.columns))
    ax_square.set_title('')
    test_limits = False
    if test_limits:
        fig = plt.gcf()
        _, _, _, _, _ = colorbar_outside(ax_square, im, fig, add=5, width=0.01)
        ax_square.text(1.05, 0.5, nonlin_title(), ha='center', rotation=90)
    else:
        if cbar_true:
            fig = plt.gcf()
            cbar, left, bottom, width, height = colorbar_outside(ax_square, im, fig, add=5, width=0.007)
            cbar.set_label(nonlin_title(' [' + add_nonlin_title), rotation=90, labelpad=3)
    amp_name = round_for_nice_float_strs(amp)
    cvs = True
    if amp_give:
        # NOTE(review): non-raw string with invalid escapes (backslash-comma,
        # backslash-percent); kept verbatim, Python leaves the backslashes in.
        amp_val = '$c=%s$' % (int(amp_name)) + '$\,\%$, '
    else:
        amp_val = ''
    if cvs:  # $\mathcal{X}_{2}
        if base_extra:
            ax_square.text(xpos, ypos,
                           title_square + amp_val + r'$'+cv_stim_name_rm()+'=%.2f$' % (stack_final1.cv_stim.iloc[0]),
                           ha=ha,
                           transform=ax_square.transAxes, color=color,
                           alpha=alpha)  # (np.round(stack_final1.cv_stim.iloc[0], 2))'white' files[0] + ' l ' + str(length)chi_name()+
        else:
            ax_square.text(xpos, ypos, title_square + amp_val + r'$\rm{CV}=%.2f$' % (stack_final1.cv_stim.iloc[0]),
                           ha=ha,
                           transform=ax_square.transAxes, color=color,
                           alpha=alpha)  # (np.round(stack_final1.cv_stim.iloc[0], 2))'white' files[0] + ' l ' + str(length)chi_name()+
    else:
        ax_square.text(xpos, ypos, title_square + chi_name() + '$($' + str(int(amp_name)) + '$\%$)', ha=ha,
                       transform=ax_square.transAxes, color=color,
                       alpha=alpha)  # 'white' files[0] + ' l ' + str(length)
    ax_square.set_xticks_delta(100)
    ax_square.set_yticks_delta(100)
    # x label: written as text next to the axis arrow for the last cell only.
    # NOTE(review): the set_xlabel(F1_xlabel()) below is cleared again by
    # set_xlabel('') at the end of the function.
    if c != len(cells_plot) - 1:
        ax_square.set_xlabel('')
    else:
        ax_square.set_xlabel(F1_xlabel())
        ax_square.text(1.05, xpos_xlabel, F1_xlabel(), ha='center', va='center',
                       transform=ax_square.transAxes)
    # y label / ticks only on the first amplitude's square
    if aa != 0:
        remove_yticks(ax_square)
        ax_square.set_ylabel('')
    else:
        set_ylabel_arrow(ax_square)
    ax_square.arrow_spines('lb')
    ax_square.set_xlabel('')
    ax_square.set_ylabel('')
    return mat, test_limits, im, add_nonlin_title
def set_ylabel_arrow(ax_square, xpos=-0.15, ypos=0.97, color='black', arrow=False):
    """Write the f2 frequency label as vertical text beside *ax_square*.

    The label text is ``F2_xlabel()``, placed at (*xpos*, *ypos*) in
    axes-relative coordinates and rotated by 90 degrees.

    Parameters
    ----------
    ax_square : axes object
        Target axis. Must also provide ``arrow_spines`` when *arrow* is true.
    xpos, ypos : float
        Label position in axes coordinates.
    color : str
        Text color.
    arrow : bool
        When true, also call ``ax_square.arrow_spines('l')``.
    """
    text_style = dict(ha='center', va='center', rotation=90, color=color)
    ax_square.text(xpos, ypos, F2_xlabel(), transform=ax_square.transAxes, **text_style)
    if not arrow:
        return
    ax_square.arrow_spines('l')
def F2_xlabel():
    """Return the LaTeX axis label for the second stimulus frequency f2 in Hz.

    The literal is a raw string. The original plain string relied on the
    backslash-comma thin-space command not being a recognised Python escape,
    which emits a SyntaxWarning on Python 3.12+. The returned text is unchanged.
    """
    return r'$f_{2}$\,[Hz]'
def nonlin_title(add_nonlin_title=''):
    """Build a colorbar/axis title from ``chi_name()``.

    The result is ``chi_name()`` followed by *add_nonlin_title* and the literal
    suffix ``'Hz]'``. Callers in this module pass, e.g., an opening ``' ['``
    plus extra unit text as *add_nonlin_title*.
    """
    pieces = (chi_name(), add_nonlin_title, 'Hz]')
    return ''.join(pieces)
def chi_name():
    """Return the LaTeX symbol ``|chi_2|`` used to label the nonlinearity plots."""
    symbol = r'$|\chi_{2}|$'
    return symbol
def plt_psds_all(axd, axo, mat, stack_final, stack_osf,
                 test_limits, xlim, alpha=1, color='black', power_type=False, db='', fr=None, peaks_extra=False,
                 zorder=1, eod_fr=1):
    """Plot the diagonal projection of *mat* on *axd* and an output trace on *axo*.

    When *power_type* is true, *axo* shows the power trace from
    ``plt_power_trace``; otherwise it shows the transfer function from
    ``plt_transferfunction``. Both axes share the x-range returned by
    ``plt_diagonal``.

    Returns
    -------
    tuple
        ``(axd, axo, axo)``. The last two entries are the same axis; the
        triple is kept for backward compatibility with existing callers.
    """
    # plt_diagonal's positional order is
    # (axd, color, db, fr, mat, alpha, eod_fr, peaks_extra, label, xlim, zorder, ...).
    # Passing xlim/zorder positionally used to land them in `label`/`xlim`,
    # which made get_xlim_psd subscript an int. Use keywords instead.
    xmax, xmin, _ = plt_diagonal(axd, color, db, fr, mat, alpha=alpha, eod_fr=eod_fr,
                                 peaks_extra=peaks_extra, xlim=xlim, zorder=zorder)
    if power_type:
        plt_power_trace(alpha, axo, color, db, stack_final, stack_osf, test_limits, xmax,
                        eod_fr=eod_fr, zorder=zorder)
    else:
        plt_transferfunction(alpha, axo, color, stack_final, zorder=zorder)
    axo.set_xlim(xmin, xmax)
    return axd, axo, axo
def plt_diagonal(axd, color, db, fr, mat, alpha=1, eod_fr=750, peaks_extra=True, label='', xlim=[], zorder=1, normval=1,
                 color_same=True):
    """Plot the diagonal projection of matrix *mat* on *axd*.

    The projection comes from ``get_mat_diagonals`` and the frequency axis
    from ``axis_projection``. With ``db == 'db'`` the projection is converted
    to ``10 * log10``. When *peaks_extra* is true, a dashed median line and a
    marker at *fr* (via ``plt_peaks_several``) are added.

    Parameters
    ----------
    normval : number
        Acts as a switch: any value other than 1 divides the frequency axis
        (and *fr*) by *eod_fr*; a value of 1 leaves them unscaled.
    color_same : bool
        When false, the median line is drawn black and the peak marker uses
        *alpha*; when true, the marker is grey with full opacity.

    Returns
    -------
    (xmax, xmin, projection)
        The x-limits applied to *axd* and the (possibly dB) projection.
    """
    _, diag_projection = get_mat_diagonals(np.array(mat))
    freq_axis = axis_projection(mat, axis='')
    scale = eod_fr if normval != 1 else normval
    if db == 'db':
        diag_projection = 10 * np.log10(diag_projection)
    axd.plot(freq_axis / scale, diag_projection, color=color, alpha=alpha - 0.05, zorder=zorder, label=label)
    if peaks_extra:
        if color_same:
            color_dot, alpha_dot = 'grey', [1]
        else:
            color = 'black'
            color_dot, alpha_dot = 'black', [alpha]
        axd.axhline(np.median(diag_projection), linewidth=0.9, linestyle='--', color=color,
                    alpha=alpha, zorder=zorder + 1)
        plt_peaks_several([fr / scale], [diag_projection], axd, diag_projection, freq_axis / scale, [''], 0,
                          [color_dot], zorder=zorder + 1, alphas=alpha_dot, ms=5)
    xmax, xmin = get_xlim_psd(freq_axis / scale, xlim)
    print(xmax)
    axd.set_xlim(xmin, xmax)
    axd.arrow_spines('b')
    return xmax, xmin, diag_projection
def get_xlim_psd(axis_d, xlim):
    """Return the x-limits ``(xmax, xmin)`` for a spectrum plot.

    A truthy *xlim* is used as ``(xlim[0], xlim[1])``. Otherwise the limits
    run from 0 to the last value of *axis_d*. Note the return order is
    ``(xmax, xmin)``.
    """
    if not xlim:
        return axis_d[-1], 0
    lower, upper = xlim[0], xlim[1]
    return upper, lower
def plt_transferfunction(alpha, axo, color, stack_final, zorder=1, label='', normval=1, log_transfer=False):
    """Plot the transfer function of *stack_final* in dB on *axo*.

    The gain from ``get_transferfunction`` is converted to ``10 * log10``.
    Frequencies are divided by *normval* and, when the cut-off from
    ``calc_cut_offs`` (also divided by *normval*) is truthy, only frequencies
    below it are drawn.

    NOTE(review): ``log_transfer`` has no effect. Both of its original branches
    applied the same dB conversion.
    """
    f_axis, gain = get_transferfunction(stack_final)
    gain_db = 10 * np.log10(gain)
    cutoff = calc_cut_offs(stack_final.file_name.iloc[0]) / normval
    freqs = f_axis / normval
    line_style = dict(color=color, zorder=zorder, alpha=alpha, label=label)
    if cutoff:
        below = freqs < cutoff
        axo.plot(freqs[below], gain_db[below], **line_style)
    else:
        axo.plot(freqs, gain_db, **line_style)
def get_transferfunction(stack_final):
    """Estimate the gain ``|<osf * conj(isf)>| / <|isf|^2>`` over trials.

    Only rows whose ``'osf'`` entry is a list contribute. For each such row,
    the first elements of ``osf`` and ``isf`` are used.

    Returns
    -------
    (f_axis, vals)
        ``f_axis`` is ``find_f(stack_final)`` truncated to the spectrum length.
        ``vals`` is the magnitude of the mean cross spectrum divided by the
        mean input power.

    Raises
    ------
    ValueError
        If no row of ``'osf'`` is a list. Previously this surfaced as an
        UnboundLocalError.
    """
    osf = stack_final['osf']
    isf = stack_final['isf']
    f = find_f(stack_final)
    f_axis = f[0:len(isf.iloc[0][0])]
    csd_sum = None
    psd_sum = None
    counter = 0
    for t in range(len(osf)):
        if not isinstance(osf.iloc[t], list):
            continue
        in_spec = isf.iloc[t][0]
        cross = osf.iloc[t][0] * np.conj(in_spec)
        power = np.abs(in_spec) ** 2
        # Initialise on the first *valid* row. The original initialised only
        # when t == 0, which broke if row 0 was not a list.
        if csd_sum is None:
            csd_sum, psd_sum = cross, power
        else:
            csd_sum += cross
            psd_sum += power
        counter += 1
    if counter == 0:
        raise ValueError('get_transferfunction: no list-valued osf rows to average')
    vals = np.abs(csd_sum / counter) / (psd_sum / counter)
    return f_axis, vals
def plt_power_trace(alpha, axo, color, db, stack_final, stack_osf, test_limits, xmax, eod_fr=1,
                    zorder=1):
    """Plot the trial-averaged output spectrum magnitude on *axo*.

    When *stack_osf* is non-empty it is used directly (transposed). Otherwise
    the ``'osf'`` column of *stack_final* is converted with
    ``get_array_from_pandas``. With ``db == 'db'`` the mean is converted to
    ``10 * log10``. Frequencies are divided by *eod_fr*, and only values below
    a truthy *xmax* are drawn. Unless *test_limits* is true, x tick labels are
    removed via ``remove_xticks``.
    """
    isf_resaved = len(stack_osf) != 0
    # Despite the name, `isf` holds the output spectra ('osf').
    isf = stack_osf if isf_resaved else stack_final['osf']
    f = find_f(stack_final)
    power = 1
    if isf_resaved:
        f_axis = f[0:len(isf)]
        means = np.transpose(isf)
    else:
        # The original indexed `isf.iloc[l]` with an undefined name `l`
        # (NameError); the spectrum length is taken from the first trial.
        f_axis = f[0:len(isf.iloc[0][0])]
        means = get_array_from_pandas(isf)
    means_all = np.mean(np.abs(means) ** power, axis=0)
    if db == 'db':
        means_all = 10 * np.log10(means_all)
    axis = f_axis / eod_fr
    if xmax:
        keep = axis < xmax
        axo.plot(axis[keep], means_all[keep], color=color, zorder=zorder, alpha=alpha)
    else:
        axo.plot(axis, means_all, color=color, alpha=alpha, zorder=zorder)
    if not test_limits:
        remove_xticks(axo)
def plt_spikes(c, cells_plot, color, ax_spikes, stack_final1,
               stack_spikes, alpha=1, xlim_e=[0, 200], sc=20, scale=True, spikes_max=5):
    """Draw a spike raster of at most *spikes_max* trials on *ax_spikes*.

    Spike trains come from ``create_spikes``. X tick labels are removed
    except for the last cell row (``c == len(cells_plot) - 1``), all spines
    are hidden, and, when *scale* is true, an x scale bar of *sc* ms is added.
    """
    trains = create_spikes(stack_final1, stack_spikes)
    if len(trains) > 0:
        shown = trains[0:spikes_max] if len(trains) > spikes_max else trains
        ax_spikes.eventplot(shown, color=color, alpha=alpha)
    ax_spikes.set_xlim(xlim_e)
    is_last_row = c == len(cells_plot) - 1
    if not is_last_row:
        remove_xticks(ax_spikes)
    ax_spikes.show_spines('')
    if scale:
        bar_x = float(1 - sc / xlim_e[-1])
        ax_spikes.xscalebar(bar_x, -0.03, sc, 'ms', va='left', ha='bottom')
def plt_stimulus(eod_fr, axe, stack_final1, xlim_e, RAM=True, file_name=None, alpha = 0.5, add=0.07):
    """Draw a synthetic EOD carrier modulated by the recorded noise stimulus.

    A fake 'Alepto' wave-fish EOD at *eod_fr* is generated with
    ``fakefish.wavefish_eods`` and modulated by the stimulus from
    ``get_stimulus_here`` (scaled by 0.5). The modulation is multiplicative
    when *RAM* is true and additive otherwise. The scaled stimulus itself is
    overlaid in red, shifted up by ``1 + add``.

    Parameters
    ----------
    eod_fr : float
        Carrier frequency passed to ``wavefish_eods``.
    axe : axes object
        Target axis; spines are hidden and x tick labels removed.
    stack_final1 :
        Forwarded to ``get_stimulus_here`` to locate the stimulus.
    xlim_e : sequence of two numbers
        x-limits in ms. ``xlim_e[1]`` also bounds how much stimulus is loaded.
    RAM : bool
        Multiplicative (True) or additive (False) modulation of the carrier.
    file_name : str or None
        Explicit stimulus file name, forwarded to ``get_stimulus_here``.
    alpha : float
        Opacity of the carrier trace in both modes. The additive branch
        previously hard-coded 0.5, which equals the default, so default calls
        are unchanged.
    add : float
        Vertical offset of the red stimulus trace.
    """
    axe.show_spines('')
    # Jan G. pointed out that the stimulus should be shifted by the neuronal
    # delay (ms) to compensate for it.
    neuronal_delay = 5
    max_here = (xlim_e[1] + neuronal_delay) / 1000
    eod_interp, sampling_interp, time_eod_interp = get_stimulus_here(file_name, stack_final1, sampling=40000,
                                                                     max=max_here)
    fake_fish = fakefish.wavefish_eods('Alepto', frequency=eod_fr,
                                       samplerate=sampling_interp,
                                       duration=len(time_eod_interp) / sampling_interp,
                                       phase0=0.0, noise_std=0.00)
    size_fake_am = 0.5
    time_ms = time_eod_interp * 1000 - neuronal_delay
    am = eod_interp * size_fake_am
    carrier_color = 'lightgrey' if RAM else 'grey'
    try:
        # The arithmetic stays inside the try: presumably the fake EOD and the
        # interpolated stimulus can differ in length, which breaks broadcasting.
        carrier = fake_fish * (1 + am) if RAM else fake_fish + am
        axe.plot(time_ms, carrier, color=carrier_color, alpha=alpha, clip_on=True)
    except Exception:
        print('axe thing')
        embed()
    axe.plot(time_ms, am + 1 + add, color='red', linewidth=1)
    axe.set_xlim(xlim_e)
    ylim_e = np.array(axe.get_ylim()) * 1.05
    axe.set_ylim(ylim_e)
    remove_xticks(axe)
def get_stimulus_here(file_name, stack_final1, max=0.4, sampling=None):
    """Load a noise stimulus and resample it onto a regular time grid.

    The grid is ``np.arange(0, max, 1 / sampling)``. When *sampling* is falsy,
    it is taken from ``stack_final1.sampling``. The stimulus is loaded with
    ``load_noise`` from *file_name* or, if that is falsy, from
    ``stack_final1.file_name`` / ``stack_final1.file_name2``. Each lookup falls
    back to the name with an ``'s'`` appended. The result is resampled with
    cubic ``interpolate``.

    Returns
    -------
    (eod_interp, sampling, time_eod)
    """
    if not sampling:
        sampling = stack_final1.sampling.iloc[0]
    time_eod = np.arange(0, max, 1 / sampling)
    if file_name:
        try:
            noise, noise_time, _ = load_noise(file_name)
        except Exception:
            noise, noise_time, _ = load_noise(file_name + 's')
    else:
        try:
            noise, noise_time, _ = load_noise(stack_final1.file_name.iloc[0])
        except Exception:
            try:
                noise, noise_time, _ = load_noise(stack_final1.file_name2.iloc[0])
            except Exception:
                noise, noise_time, _ = load_noise(stack_final1.file_name.iloc[0] + 's')
                print('open problem thing2')
    resampled = interpolate(noise_time, noise,
                            time_eod,
                            kind='cubic')
    return resampled, sampling, time_eod
def same_lims_susept(axds, axis, axos, ims):
    """Give all images in *ims* a common upper colour limit, then equalise y-limits.

    The y-limits are equalised within each axes group, in the order
    *axos*, *axis*, *axds*.
    """
    set_clim_same(ims, clims='all', same='same', lim_type='up')
    for axes_group in (axos, axis, axds):
        set_same_ylim(axes_group)
def create_spikes(stack_final1, stack_spikes=[]):
    """Collect spike trains as numpy arrays scaled by 1000 (presumably s -> ms).

    Parameters
    ----------
    stack_final1 :
        Used only when *stack_spikes* is empty: the trains are read from
        ``stack_final1.spikes.iloc[0][0]``.
    stack_spikes : list or DataFrame-like
        A list of trains is indexed by position. Anything else is indexed by
        the column labels ``0 .. n_columns - 1``. Entries with ``dropna`` get
        their NaNs removed first.

    Returns
    -------
    list of numpy.ndarray
    """
    trains_ms = []
    if len(stack_spikes) == 0:
        for train in stack_final1.spikes.iloc[0][0]:
            try:
                trains_ms.append(np.array(train) * 1000)
            except Exception:
                print('spike thing')
                embed()
        return trains_ms
    n_trains = len(stack_spikes) if type(stack_spikes) is list else np.shape(stack_spikes)[1]
    for idx in range(n_trains):
        train = stack_spikes[idx]
        try:
            train = train.dropna()
        except Exception:
            # Plain sequences have no dropna; use them unchanged.
            pass
        trains_ms.append(np.array(train) * 1000)
    return trains_ms
def load_mt_data(axi, axs, c, cell, cells_plot, colors_hist, gg, grid_upper, stack_final1, xlim):
    """Load the noise-stimulus spike trains of *cell* from its NIX file and plot them.

    Opens ``load_folder_name('data') + 'cells/' + cell + '/' + cell + '.nix'``
    and selects the multi tag named in ``stack_final1['names_mt_gwn']``. It
    collects the spike times (multiplied by 1000) of the segments whose
    ``c_orig`` matches ``stack_final1['c_orig']`` and draws them as an event
    plot in a new subplot at ``grid_upper[1, 1::]``.

    Parameters
    ----------
    axi : axes
        ISI axis; only its x-limits, ticks and labels are adjusted here.
    axs : axes
        Replaced by a new subplot; the passed value is not used.
    c, cells_plot :
        Row index of this cell and the list of plotted cells. Only the last
        row keeps its x tick labels.
    colors_hist, gg :
        ``colors_hist[gg]`` is the event-plot color.
    grid_upper :
        Grid providing the raster subplot slot ``[1, 1::]``.
    stack_final1 :
        Provides ``names_mt_gwn`` and ``c_orig``.
    xlim :
        x-limits of the spike raster.

    Returns
    -------
    (axs, eod_mt, sampling, xlim_e)

    NOTE(review): ``eod_mt``, ``sampling`` and ``xlim_e`` are only assigned on
    the ``rlx_problem`` path (``eod_mt``/``sampling`` only when an 'eod' data
    array exists and at least one segment is found), so the final return can
    raise UnboundLocalError. The NIX file handle ``f`` is never closed.
    """
    data_dir = 'cells/'
    data_name = cell
    name_core = load_folder_name('data') + data_dir + data_name
    nix_name = name_core + '/' + data_name + '.nix'  # '/'
    f = nix.File.open(nix_name, nix.FileMode.ReadOnly)
    b = f.blocks[0]
    names_mt_gwn = stack_final1['names_mt_gwn'].unique()[0]
    try:
        mt = b.multi_tags[names_mt_gwn]
    except:
        # Fall back to the first noise multi tag found in the block.
        names_mt_gwns = find_names_gwn(b)
        mt = b.multi_tags[names_mt_gwns[0]]
        print('mt thing')
        embed()
    features, id, data_between_2017_2018, mt_ids = find_feature_gwn_id(mt)
    dataset, rlx_problem = load_rlxnix(nix_name)
    # We do this via rlx only because that is the only way to get at the contrast.
    spikes_loaded = []
    if rlx_problem:
        file_name, file_name_save, cut_off, file, sd = find_file_names(nix_name, mt,
                                                                       names_mt_gwn)
        file_extra, idx_c, base_properties, id_names = get_contrasts_over_rlx_calc_RAM(dataset)
        dataset.close()
        try:
            base_properties = base_properties.sort_values(by='c', ascending=False)
        except:
            print('contrast problem sorting')
            embed()
        if data_between_2017_2018 != 'all':
            file_name_sorted = base_properties[base_properties.file_name == file_name]
        else:
            file_name_sorted = base_properties
        if len(file_name_sorted) < 1:
            print('file_name problem')
            embed()
        # Descending sort then reversal, i.e. ascending by 'start'.
        file_name_sorted = file_name_sorted.sort_values(by='start', ascending=False)[::-1]
        # I should already filter for the right file name at this level!
        file_name_sorted = file_name_sorted[file_name_sorted['c_orig'] == stack_final1['c_orig'].unique()[0]]
        grouped = file_name_sorted.groupby('c')
        # There is apparently one cell, the first one ('2010-06-15-af'), where the
        # array is not called 'input arr' but 'gwn 300'. I cannot tell what
        # happened there, but all other cells look fine; it seems to be the only one.
        data_array_names = get_data_array_names(b)  # ,find_indices_to_match_contrats,get_data_array_names
        if 'eod' in ''.join(data_array_names).lower():
            for g, group in enumerate(grouped):
                # First sort everything by contrast here.
                sd, start, end, rep, cut_off, c_len, c_unit, c_orig, c_len, files_load, cc, id_g, amplsel = open_group_gwn(
                    group,
                    file_name,
                    cut_off,
                    sd,
                    data_between_2017_2018)
                indices, ends_mt = find_indices_to_match_contrats(grouped, group, mt, id_g, mt_ids,
                                                                  data_between_2017_2018)
                indices = list(map(int, indices))
                max_f = cut_off
                if max_f == 0:
                    print('max f = 0')
                    embed()
                for mm, m in enumerate(indices):
                    first, minus, second, stimulus_length = find_first_second(b, names_mt_gwn, m, mt,
                                                                              False,
                                                                              mm=mm, ends_mt=ends_mt)
                    spikes_mt = link_arrays_spikes(b, first,
                                                   second, minus)  #
                    spikes_loaded.append(spikes_mt * 1000)
                    eod_mt, sampling = link_arrays_eod(b, first,
                                                       second,
                                                       array_name='LocalEOD-1')
                    # Also load the stimpresaved stimulus here.
                    # todo: possibly adjust this
        axi.set_xlim(0, 13)
        xlim_e = [0, 200]
        axs = plt.subplot(grid_upper[1, 1::])
        axs.eventplot(spikes_loaded, color=colors_hist[gg], )
        axs.set_xlim(xlim)  # spikes_both[gg]
        if c != len(cells_plot) - 1:
            remove_xticks(axi)
            remove_xticks(axs)
        else:
            axi.set_xlabel('ISI')  # ($I_{spikes}/I_{EODf}$)
            axs.set_xlabel('Time [ms]')
        axi.set_ylabel('Nr')
        axs.set_ylabel('Nr')
    else:
        print('rlx thing')
    return axs, eod_mt, sampling, xlim_e
def exclude_cut_filenames(cell_type, stack, fexclude=False):
    """Optionally drop excluded stimulus files from *stack*.

    The exclusion list comes from ``create_file_names_exclude(cell_type)``.
    Rows are only dropped when *fexclude* is true and *stack* contains more
    than one distinct ``file_name``.

    Returns
    -------
    (files, stack)
        The remaining unique file names and the (possibly filtered) frame.
    """
    excluded = create_file_names_exclude(cell_type)
    files = stack['file_name'].unique()
    if not fexclude or len(files) <= 1:
        return files, stack
    stack = stack[~stack['file_name'].isin(excluded)]
    files = stack['file_name'].unique()
    print('file names excluded')
    print(files)
    return files, stack
def plt_cellbody_punitsingle(grid1, ax0, ax1, ax2, frame, colors, amps_desired, save_names, cells_plot, cell_type_type,
                             plus=1, ax3=[], xlim=[], burst_corr='_burst_corr_individual'):
    """Plot baseline and susceptibility panels for every cell in *cells_plot*.

    For each cell (row ``c`` of *grid1*):

    * For each save name whose pickle exists and each amplitude in
      *amps_desired*, draw a susceptibility matrix (``plot_square_core``) with
      PSD traces (``plt_psd_traces``) in column ``aa + plus``.
    * Scatter baseline ``cv`` vs ``fr`` on *ax0*, ``cv`` vs ``vs`` on *ax1*,
      and the burst-corrected ``cv``/``fr`` on *ax2*. When *ax3* is given,
      also scatter ``cv`` vs ``cv_stim`` on it.
    * In column 0, draw ISI histograms, the EOD trace and a spike raster. When
      the ISIs fall below ``find_lim_here``, burst-corrected versions are
      drawn as well.

    NOTE(review): ``spikes_both``/``hists_both`` exist only when ISIs were
    found. The scatter block reuses the last loop value ``s``.
    ``axs.set_title(' ' + cell)`` targets the last susceptibility axis, and
    ``aa == len(amps) - 1`` compares an index into *amps_desired* with the
    length of the file's own amplitudes. Mutable default arguments are not
    mutated here.
    """
    stack = []
    for c, cell in enumerate(cells_plot):
        print(cell)
        frame_cell = frame[(frame['cell'] == cell)]
        frame_cell = unify_cell_names(frame_cell, cell_type=cell_type_type)
        try:
            cell_type = frame_cell[cell_type_type].iloc[0]
        except:
            print('cell type prob')
            embed()
        spikes = frame_cell.spikes.iloc[0]
        eod, sampling_rate, ds, time_eod = find_eod(frame_cell)
        eod_fr = frame_cell.EODf.iloc[0]
        spikes_all, hists, frs_calc, cont_spikes = load_spikes(spikes, eod_fr)
        # cont_spikes concerns spikes that are NaN or empty: then there is nothing
        # at all, not even a scatter point; those are the blank figures we don't need.
        # So gating on cont_spikes is fine here, because we do want these cells!
        if cont_spikes:
            # These two checks test whether there are more than a few spikes;
            # without them we do not need a figure either.
            if len(hists) > 0:
                if len(np.concatenate(hists)) > 0:
                    lim_here = find_lim_here(cell, burst_corr)
                    print(lim_here)
                    if np.min(np.concatenate(hists)) < lim_here:
                        hists2, spikes_ex, frs_calc2 = correct_burstiness(hists, spikes_all, [eod_fr] * len(spikes_all),
                                                                         [eod_fr] * len(spikes_all), lim=lim_here,
                                                                         burst_corr=burst_corr)
                        spikes_both = [spikes_all, spikes_ex]
                        hists_both = [hists, hists2]
                    else:
                        spikes_both = [spikes_all]
                        hists_both = [hists]
                    # This is the title in case the square does not plot.
                    for s, save_name in enumerate(save_names):
                        load_name = load_folder_name('calc_RAM') + '/' + save_name + '_' + cell
                        axes = []
                        if os.path.exists(load_name + '.pkl'):
                            stack = pd.read_pickle(load_name + '.pkl')
                            file_names_exclude = create_file_names_exclude(cell_type)
                            files = stack['file_name'].unique()
                            fexclude = False
                            if fexclude:
                                if len(files) > 1:
                                    stack = stack[~stack['file_name'].isin(file_names_exclude)]
                                    files = stack['file_name'].unique()
                            amps = stack['amp'].unique()
                            _, _ = find_row_col(np.arange(len(amps) * len(files)))
                            predefined_amp = True
                            if predefined_amp:
                                amps_defined = amps_desired
                            else:
                                amps_defined = amps
                            # Only the first stimulus file is plotted.
                            stack_file = stack[stack['file_name'] == files[0]]
                            amps = stack_file['amp'].unique()
                            for aa, amp in enumerate(amps_defined):
                                if amp in np.array(stack_file['amp']):
                                    stack_amp = stack_file[stack_file['amp'] == amp]
                                    lengths = stack_file['stimulus_length'].unique()
                                    length = np.max(lengths)
                                    stack_final = stack_amp[stack_amp['stimulus_length'] == length]
                                    trial_nr_double = stack_final.trial_nr.unique()
                                    # OK, I think this is a sign of an error.
                                    if len(trial_nr_double) > 1:
                                        print('trial_nr_double')
                                        embed()
                                    # I once had a saving error, so some entries exist twice;
                                    # I assume simply taking the last one is fine.
                                    try:
                                        stack_final1 = stack_final[stack_final.trial_nr == np.max(trial_nr_double)]
                                    except:
                                        print('stack_final1 problem')
                                        embed()
                                    try:
                                        grid_s = gridspec.GridSpecFromSubplotSpec(3, 1, grid1[c, aa + plus],
                                                                                  height_ratios=[1.5, 1.5, 5],
                                                                                  hspace=0.1)
                                        axs = plt.subplot(grid_s[2])
                                    except:
                                        print('grid problem2')
                                        embed()
                                    cbar, mat, im = plot_square_core([axs], stack_final1)
                                    if xlim:
                                        axs.set_xlim(xlim)
                                        axs.set_ylim(xlim)
                                    axs.set_title('')
                                    if aa == len(amps) - 1:
                                        cbar.set_label(nonlin_title(), rotation=90, labelpad=10)
                                    fr = stack_final1.fr.unique()[0]
                                    fr_stim = stack_final1.fr_stim.unique()[0]
                                    if xlim:  # [0]
                                        pass
                                    else:
                                        pass
                                    try:
                                        axo, axi = plt_psd_traces(grid_s[0], grid_s[1], axs, xlim[0], xlim[-1], eod_fr,
                                                                  fr, fr_stim, stack_final, )
                                    except:
                                        print('psd problem')
                                        embed()
                                    if aa == 1:
                                        axo.text(0.5, 1, cell, ha='center', transform=axo.transAxes)
                                    axes.append(axi)  # np.min(mat.columns)
                                    axes.append(axo)  # np.max(mat.columns)
                                    if aa == 0:
                                        axo.set_ylabel('Otp.')
                                        axi.set_ylabel('Inp.')
                                    if c == 0:
                                        axo.set_title(' std = ' + str(amp) + '$\%$')  # files[0] + ' l ' + str(length)
                                    if aa != 0:
                                        axi.set_ylabel('')
                                    if c != len(cells_plot) - 1:
                                        axs.set_xlabel('')
                                        remove_xticks(axs)
                                    else:
                                        axs.set_xlabel(F1_xlabel())
                                    if aa != 0:
                                        remove_yticks(axs)
                                        axs.set_ylabel('')
                    # do the scatter of these cells
                    add = ['', '_burst_corr_individual']
                    ax0.scatter(frame_cell['cv' + add[0]], frame_cell['fr' + add[0]], zorder=2, alpha=1,
                                label=cell_type, s=15,
                                color=colors[str(cell_type)], facecolor='white')
                    ax1.scatter(frame_cell['cv' + add[0]], frame_cell['vs' + add[0]], zorder=2, alpha=1,
                                label=cell_type, s=15,
                                color=colors[str(cell_type)], facecolor='white')
                    ax2.scatter(frame_cell['cv' + add[1]], frame_cell['fr' + add[1]], zorder=2, alpha=1,
                                label=cell_type, s=15,
                                color=colors[str(cell_type)], facecolor='white')
                    if len(stack) > 0:
                        # Uses the last save name from the loop above.
                        load_name = load_folder_name('calc_RAM') + '/' + save_names[s] + '_' + cell
                        if ax3 != []:
                            try:
                                frame_g = base_to_stim(load_name, frame, cell_type_type, cell_type, stack=stack)
                            except:
                                print('stim problem')
                                embed()
                            try:
                                ax3.scatter(frame_g['cv'], frame_g['cv_stim'], zorder=2, alpha=1,
                                            label=cell_type, s=15,
                                            color=colors[str(cell_type)], facecolor='white')
                            except:
                                print('scatter problem')
                                embed()
                    # Baseline panels in column 0: EOD trace, spike raster, ISI histogram.
                    alpha = 1
                    grid_s = gridspec.GridSpecFromSubplotSpec(4, 1, grid1[c, 0], height_ratios=[2.5, 1.5, 1.2, 2.5],
                                                              hspace=0.25)
                    axi = plt.subplot(grid_s[-1])
                    axs.set_title(' ' + cell)
                    if len(hists_both) > 1:
                        colors_hist = ['grey', colors[str(cell_type)]]
                    else:
                        colors_hist = [colors[str(cell_type)]]
                    for gg in range(len(hists_both)):
                        hists_here = hists_both[gg]
                        for hh, h in enumerate(hists_here):
                            try:
                                axi.hist(h, bins=100, color=colors_hist[gg], alpha=float(alpha - 0.05 * hh))
                            except:
                                print('alpha problem2')
                                embed()
                        axi.set_xlim(0, 13)
                        axe = plt.subplot(grid_s[0])
                        axe.plot(time_eod * 1000, eod, color='grey', linewidth=0.5)
                        axe.set_xlim(0, 40)
                        axs = plt.subplot(grid_s[1])
                        axs.eventplot(spikes_both[gg], color=colors_hist[gg], )
                        axs.set_xlim(0, 40)
                        if c != len(cells_plot) - 1:
                            remove_xticks(axi)
                            remove_xticks(axs)
                        else:
                            axi.set_xlabel('isi')
                            axs.set_xlabel('Time [ms]')
                        remove_xticks(axe)
                        axi.set_ylabel('Nr')
                        axs.set_ylabel('Nr')
                        axe.set_ylabel('mV')
def create_file_names_exclude(cell_type):
    """Return the stimulus file names to exclude for *cell_type*.

    Cell types containing the lowercase substring ``'p-unit'`` get
    ``punit_file_exclude()``; every other type gets
    ``ampullary_file_exclude()``.
    """
    is_punit = 'p-unit' in cell_type
    return punit_file_exclude() if is_punit else ampullary_file_exclude()
def punit_file_exclude():
    """Return the stimulus file names excluded for P-units (a new list each call)."""
    input_arr_files = ['InputArr_350to400hz_30',
                       'InputArr_250to300hz_30',
                       'InputArr_150to200hz_30',
                       'InputArr_50to100hz_30',
                       'InputArr_50hz_30']
    gwn_files = ['gwn50Hz50s0.3',
                 'gwn50Hz10s0.3',
                 'gwn50Hz10.3',
                 'gwn50Hz10s0.3short',
                 'gwn25Hz10s0.3']
    file_stimulus_files = ['FileStimulus-file-gaussian50.0',
                           'FileStimulus-file-gaussian25.0']
    return input_arr_files + gwn_files + file_stimulus_files
def plt_cell_body2(grid1, frame, colors, cells_plot, cell_type_type, ax3=[],
                   xlim=[]):
    """Plot susceptibility panels and ISI histograms for every cell in *cells_plot*.

    For each cell (row ``c`` of *grid1*), and for each of three hard-coded save
    names whose pickle exists, draw the susceptibility matrix at the smallest
    stimulus amplitude (``plot_square_core``) plus PSD traces in column
    ``a + 1``. When *ax3* is given, also scatter ``cv`` vs ``cv_stim`` on it.
    Finally, draw the ISI histograms via ``do_hist``. When the minimum ISI is
    below 1.5, burst-corrected histograms are included too.

    NOTE(review): ``hists_both`` is only defined when ISIs were found.
    ``a == len(amps) - 1`` compares a save-name index with an amplitude count,
    and ``c != 2`` hard-codes three rows; confirm both are intended.
    """
    for c, cell in enumerate(cells_plot):
        print(cell)
        frame_cell = frame[(frame['cell'] == cell)]
        frame_cell = unify_cell_names(frame_cell, cell_type=cell_type_type)
        try:
            cell_type = frame_cell[cell_type_type].iloc[0]
        except:
            embed()
        spikes = frame_cell.spikes.iloc[0]
        fr = frame_cell.fr.iloc[0]
        cv = frame_cell.cv.iloc[0]
        eod_fr = frame_cell.EODf.iloc[0]
        spikes_all, hists, frs_calc, cont_spikes = load_spikes(spikes, eod_fr)
        # cont_spikes concerns spikes that are NaN or empty: then there is nothing
        # at all, not even a scatter point; those are the blank figures we don't need.
        # So gating on cont_spikes is fine here, because we do want these cells!
        if cont_spikes:
            # These two checks test whether there are more than a few spikes;
            # without them we do not need a figure either.
            if len(hists) > 0:
                if len(np.concatenate(hists)) > 0:
                    if np.min(np.concatenate(hists)) < 1.5:
                        hists2, spikes_ex, frs_calc2 = correct_burstiness(hists, spikes_all, [eod_fr] * len(spikes_all),
                                                                         [eod_fr] * len(spikes_all))
                        hists_both = [hists, hists2]
                    else:
                        hists_both = [hists]
                    # This is the title in case the square does not plot.
                    plt.suptitle(cell + ' EODf ' + str(np.round(eod_fr)) + ' Hz ' + ' % ' +
                                 ' Base: cv ' + str(np.round(cv, 2)) + ' fr ' + str(
                        np.round(fr)) + ' Hz',
                                 fontsize=11, )  # cell[0:13] + color=color+ cell_type
                    save_names = [
                        'noise_data8_nfft1sec_original__LocalEOD_CutatBeginning_0.05_s_NeurDelay_0_s_burst_corr',
                        'noise_data8_nfft1sec_original__LocalEOD_CutatBeginning_0.05_s_NeurDelay_0.005_s_burst_corr',
                        'noise_data8_nfft1sec_original__LocalEOD_mean2__CutatBeginning_0.05_s_NeurDelay_0.005_s_burst_corr',
                    ]
                    for a, save_name in enumerate(save_names):
                        load_name = load_folder_name('calc_RAM') + '/' + save_name + '_' + cell
                        if os.path.exists(load_name + '.pkl'):
                            stack = pd.read_pickle(load_name + '.pkl')
                            if 'p-unit' in cell_type:  # == ['p-unit', ' P-unit']:
                                file_names_exclude = punit_file_exclude()  #
                            else:
                                file_names_exclude = ampullary_file_exclude()
                            files = stack['file_name'].unique()
                            fexclude = False
                            if fexclude:
                                if len(files) > 1:
                                    stack = stack[~stack['file_name'].isin(file_names_exclude)]
                                    files = stack['file_name'].unique()
                            amps = stack['amp'].unique()
                            _, _ = find_row_col(np.arange(len(amps) * len(files)))
                            predefined_amp = True
                            if predefined_amp:
                                pass
                            else:
                                pass
                            # Only the first stimulus file and its smallest amplitude are plotted.
                            stack_file = stack[stack['file_name'] == files[0]]
                            amps = stack_file['amp'].unique()
                            amps_defined = [np.min(amps)]
                            for aa, amp in enumerate(amps_defined):
                                if amp in np.array(stack_file['amp']):
                                    stack_amp = stack_file[stack_file['amp'] == amp]
                                    lengths = stack_file['stimulus_length'].unique()
                                    length = np.max(lengths)
                                    stack_final = stack_amp[stack_amp['stimulus_length'] == length]
                                    trial_nr_double = stack_final.trial_nr.unique()
                                    # OK, I think this is a sign of an error.
                                    if len(trial_nr_double) > 1:
                                        print('trial_nr_double')
                                        embed()
                                    # I once had a saving error, so some entries exist twice;
                                    # I assume simply taking the last one is fine.
                                    try:
                                        stack_final1 = stack_final[stack_final.trial_nr == np.max(trial_nr_double)]
                                    except:
                                        print('stack_final1 problem')
                                        embed()
                                    try:
                                        grid_s = gridspec.GridSpecFromSubplotSpec(3, 1, grid1[c, a + 1],
                                                                                  height_ratios=[1.5, 1.5, 5],
                                                                                  hspace=0)
                                        axs = plt.subplot(grid_s[2])
                                    except:
                                        print('grid problem0')
                                        embed()
                                    cbar, mat, im = plot_square_core([axs], stack_final1)
                                    if xlim:
                                        axs.set_xlim(xlim)
                                        axs.set_ylim(xlim)
                                    if a == len(amps) - 1:
                                        cbar.set_label(nonlin_title(), rotation=90, labelpad=10)
                                    fr = stack_final1.fr.unique()[0]
                                    fr_stim = stack_final1.fr_stim.unique()[0]
                                    axo, axi = plt_psd_traces(grid_s[0], grid_s[1], axs, np.min(mat.columns),
                                                              np.max(mat.columns), eod_fr, fr, fr_stim, stack_final,
                                                              )
                                    if c == 0:
                                        axi.set_title(' std = ' + str(amp) + '$\%$')  # files[0] + ' l ' + str(length)
                                    if a != 0:
                                        axi.set_ylabel('')
                                    if c != 2:
                                        axs.set_xlabel('')
                                        remove_xticks(axi)
                            ################################
                            # do the scatter of these cells
                            if ax3 != []:
                                frame_g = base_to_stim(load_name, frame, cell_type_type, cell_type)
                                try:
                                    ax3.scatter(frame_g['cv'], frame_g['cv_stim'], zorder=2, alpha=1,
                                                label=cell_type, s=15,
                                                color=colors[str(cell_type)], facecolor='white')
                                except:
                                    print('scatter problem')
                                    embed()
                    ################################
                    # do the hist
                    do_hist(grid1, c, colors, cell_type, hists_both, cell, cells_plot)
def ampullary_file_exclude():
    """Return the stimulus file names excluded for ampullary cells.

    The original list contained ``'gwn25Hz10s0.3'`` twice. The duplicate was
    removed (first-occurrence order kept). Callers in this module only use the
    list for membership tests (``isin``), so filtering results are unchanged.
    """
    file_names_exclude = ['InputArr_350to400hz_30',
                          'InputArr_250to300hz_30',
                          'InputArr_150to200hz_30',
                          'InputArr_50to100hz_30',
                          'InputArr_50hz_30',
                          'blwn125Hz10s0.3',
                          'gwn50Hz10s0.3',
                          'FileStimulus-file-gaussian50.0',
                          'FileStimulus-file-gaussian25.0',
                          'gwn25Hz10s0.3',
                          'gwn50Hz10.3',
                          'gwn50Hz10s0.3short',
                          'gwn50Hz50s0.3',
                          ]
    return file_names_exclude
def do_hist(grid1, c, colors, cell_type, hists_both, cell, cells_plot):
    """Draw the ISI histograms of one cell in ``grid1[c, 0]``.

    Each group in *hists_both* gets one color: grey and then the cell-type
    color when there are two groups, otherwise just the cell-type color.
    Within a group, each successive histogram is drawn 0.05 more transparent.
    The title shows the CV (std/mean) of the last histogram drawn plus the
    cell name. Only the last row (``c == len(cells_plot) - 1``) keeps its x
    tick labels.
    """
    base_alpha = 1
    axi = plt.subplot(grid1[c, 0])
    cell_color = colors[str(cell_type)]
    colors_hist = ['grey', cell_color] if len(hists_both) > 1 else [cell_color]
    for gg, hists_here in enumerate(hists_both):
        for hh, h in enumerate(hists_here):
            try:
                axi.hist(h, bins=100, color=colors_hist[gg], alpha=float(base_alpha - 0.05 * hh))
            except Exception:
                print('alpha problem3')
                embed()
        cv_text = str(np.round(np.std(h) / np.mean(h), 3))
        axi.set_title('CV ' + cv_text + ' ' + cell)
    axi.set_xlim(0, 13)
    if c == len(cells_plot) - 1:
        axi.set_xlabel('isi')
    else:
        remove_xticks(axi)
def cells_eigen(base_extra=False, amp_desired=[0.5, 1, 5], xlim=[0, 1.1], cells_plot2=[],
                titles=['Baseline \n Susceptibility', 'Half EODf \n Susceptibility'],
                peaks_extra=[False, False, False]):
    """Build and save (as PDF) the susceptibility figure for the cells in *cells_plot2*.

    Baseline data are loaded with ``load_cv_vals_susept``. The figure height
    scales with the number of cells (3.25 per cell), and the panels are
    drawn by ``plt_cellbody_eigen`` using the final save-name version.
    """
    plot_style()
    save_names = [version_final()]
    cell_type_type = 'cell_type_reclassified'
    kept_columns = ['spikes', 'gwn', 'fs', 'EODf', 'cv', 'fr', 'width_75', 'vs',
                    'cv_burst_corr_individual',
                    'fr_burst_corr_individual',
                    'width_75_burst_corr_individual',
                    'vs_burst_corr_individual', 'cell_type_reclassified',
                    'cell']
    frame, _ = load_cv_vals_susept(cells_plot2, EOD_type='synch', names_keep=kept_columns,
                                   path_sp='/calc_base_data-base_frame_overview.pkl',
                                   frame_general=False)
    default_figsize(column=2, width=12, length=3.25 * len(cells_plot2))
    grid1 = big_grid_susept_pics(cells_plot2, top=0.94, bottom=0.065)
    plt_cellbody_eigen(grid1, frame, amp_desired, save_names, cells_plot2, cell_type_type, xlim=xlim,
                       base_extra=base_extra, titles=titles,
                       peaks_extra=peaks_extra)
    save_visualization(pdf=True)
def fr_name_rm():
    """Return the LaTeX firing-rate label built from ``basename_small()``.

    When ``rem_variable()['rm']`` is truthy, the subscript is wrapped in
    ``\\rm{...}`` (upright font); otherwise it is appended directly.
    """
    use_rm = rem_variable()['rm']
    prefix, suffix = (r'$f\rm{', '}$') if use_rm else ('$f', '$')
    return prefix + basename_small() + suffix
def ampullary_punit(permuted=False, eod_metrice=True, base_extra=False, color_same=True, fr_name='$f_{Base}$',
                    amp_desired=[5, 20], isi_delta=None, xlim_p=[0, 1.1], tags_individual=False, xlim=[],
                    add_texts=[0.25, 0], cells_plot2=[], RAM=True, scale_val=False,
                    titles=['Low-CV P-unit,', 'High-CV P-unit', 'Ampullary cell,'],
                    peaks_extra=[True, True, True]):  # [0, 1.1]
    """Build and save (as PDF) the single-cell susceptibility figure.

    Baseline data for *cells_plot2* are loaded with ``load_cv_vals_susept``.
    The panels are drawn by ``plt_cellbody_singlecell``, and the figure is
    saved tagged with the first cell name. A single cell uses a taller grid
    margin (top 0.9, bottom 0.12).
    """
    plot_style()
    save_names = [version_final()]
    cell_type_type = 'cell_type_reclassified'
    kept_columns = ['spikes', 'gwn', 'fs', 'EODf', 'cv', 'fr', 'width_75', 'vs',
                    'cv_burst_corr_individual',
                    'fr_burst_corr_individual',
                    'width_75_burst_corr_individual',
                    'vs_burst_corr_individual', 'cell_type_reclassified',
                    'cell']
    frame, _ = load_cv_vals_susept(cells_plot2, EOD_type='synch', names_keep=kept_columns,
                                   path_sp='/calc_base_data-base_frame_overview.pkl',
                                   frame_general=False)
    default_settings_cells_susept(cells_plot2)
    margins = dict(top=0.9, bottom=0.12) if len(cells_plot2) == 1 else dict(bottom=0.065)
    grid1 = big_grid_susept_pics(cells_plot2, **margins)
    plt_cellbody_singlecell(grid1, frame, amp_desired, save_names, cells_plot2, cell_type_type, xlim=xlim,
                            permuted=permuted, base_extra=base_extra,
                            color_same=color_same, fr_name=fr_name, eod_metrice=eod_metrice, isi_delta=isi_delta,
                            tags_individual=tags_individual, RAM=RAM, add_texts=add_texts, titles=titles,
                            xlim_p=xlim_p, peaks_extra=peaks_extra, scale_val=scale_val, )
    save_visualization(pdf=True, individual_tag=cells_plot2[0])
def default_settings_cells_susept(cells_plot, l=3.7):
    """Set the default figure size: two columns, width 12, height *l* per cell."""
    figure_length = l * len(cells_plot)
    default_figsize(column=2, width=12, length=figure_length)
def big_grid_susept_pics(cells_plot, top=0.96, bottom=0.065):
    """Return a one-column sub-grid with one row per entry of *cells_plot*.

    The outer 1x1 grid spans the figure between *bottom* and *top* with fixed
    left (0.08) and right (0.95) margins.
    """
    outer = gridspec.GridSpec(1, 1, wspace=0.1, hspace=0.5, top=top, left=0.08, bottom=bottom, right=0.95)
    n_rows = len(cells_plot)
    return gridspec.GridSpecFromSubplotSpec(n_rows, 1, outer[0], hspace=0.35, wspace=0.35)
def show_func(show=True):
    """Show the current figure interactively, or close it.

    The figure is shown only when *show* is true and the machine-specific
    relative directory checked below exists; in every other case it is
    closed. NOTE(review): the path uses Windows separators, so on POSIX
    systems it never exists and figures are always closed.
    """
    # Raw string: the original plain literal relied on the backslash-c sequence
    # not being a recognised escape (SyntaxWarning on Python 3.12+).
    # The path value is unchanged.
    if show and os.path.exists(r'..\code\calc_model'):
        plt.show()
    else:
        plt.close()
def plt_RAM_overview_all_filename_selected():
    """Scatter four susceptibility scores against baseline CV, per cell type and stimulus file.

    Reads ``load_folder_name('calc_RAM') + '/' + save_name + '.csv'``. Rows of
    the 2x4 figure are the cell types, columns are the scores, and each
    stimulus file gets its own color. Only the bottom row gets an x label,
    y-limits are equalised per column, and the figure is saved with
    ``save_visualization``.
    """
    save_name = 'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s'
    frame_load = pd.read_csv(load_folder_name('calc_RAM') + '/' + save_name + '.csv')
    scores = ['perc95_perc5_fr',
              'perc80_perc5_fr',
              'entropy_mat_fr',
              'entropy_diagonal_fr',
              ]
    col = 4
    row = 2
    cell_types = [' Ampullary', ' P-unit', ]
    # Neither the stimulus files nor their colors depend on the cell type or
    # score, so they are built once.
    file_names = ['gwn150Hz10s0.3',
                  'gwn300Hz10s0.3', 'gwn300Hz10s0.3short', 'gwn300Hz50s0.3', 'InputArr_400hz_30'
                  ]
    cmap = rainbow_cmap(file_names, nrs=len(file_names))
    all_cell_types = np.array(frame_load.cell_type)
    fig, ax = plt.subplots(row, col, sharex=True, figsize=(14, 7.5))  # constrained_layout=True,
    for c, cell_type_here in enumerate(cell_types):
        p_pos = np.where(all_cell_types == cell_type_here)
        frame = frame_load.loc[p_pos]
        plt.suptitle(cell_type_here + ' \n ' + save_name)
        for s, score in enumerate(scores):
            for ff, stim_file in enumerate(file_names):
                frame_file = frame[frame.file_name == stim_file]
                try:
                    ax[c, s].scatter(frame_file['cv_wo_burstcorr'], np.array(frame_file[score]), color=cmap[ff],
                                     label=stim_file)
                except Exception:
                    print('axs thing1')
                    embed()
            ax[c, s].set_title(score, fontsize=11)
            # The original tested the score index `s` against `row * col - col`
            # (a flat-index check). That is never true for s < col, so no
            # x label was ever drawn. Label the bottom row instead.
            if c == row - 1:
                ax[c, s].set_xlabel('CV')
            ax[c, s].set_xlim(0, 2)
            if s == len(scores) - 1:
                ax[c, s].legend(ncol=1, loc=(1.3, 0))
        ax[c, 0].set_ylabel(cell_type_here)
    for a in range(col):
        set_same_ylim(ax[:, a])
    plt.subplots_adjust(left=0.06, right=0.8, top=0.83, wspace=0.45, hspace=0.3)
    save_visualization(individual_tag='_score_' + str(score) + '_celltype_' + str(cell_type_here))
#
def plt_data_noise(time_wn, stimulus_wn, nfft, sampling, mt, b, m, amp=1):
    """Diagnostic figure comparing a noise stimulus with the recorded EOD of one segment.

    Panels show the stimulus trace, its two-sided PSD (``mlab.psd`` with 50 %
    overlap), the EOD linked from multi tag *mt* at index *m*, and the EOD
    overlaid with ``amp * stimulus_wn + 0.4``. The figure is shown with
    ``plt.show()``.

    Parameters
    ----------
    amp : float, optional
        Scale of the stimulus in the overlay panel. The original body
        referenced an undefined name ``amp`` (NameError); it is now an explicit
        keyword argument defaulting to 1.
    """
    plt.subplot(2, 3, 1)
    plt.plot(time_wn, stimulus_wn)
    plt.xlim(0, 0.3)
    plt.subplot(2, 3, 4)
    p, f = ml.psd(stimulus_wn - np.mean(stimulus_wn), Fs=sampling, NFFT=nfft,
                  noverlap=nfft // 2, sides='twosided')
    plt.plot(f, p)
    # NOTE(review): this limits the PSD frequency axis to 0-0.3, copied from
    # the time panels; probably not intended — confirm.
    plt.xlim(0, 0.3)
    eod_mt_test, spikes_mt, sampling_mt = link_arrays(b, mt.positions[:][m],
                                                      mt.extents[:][m])
    time_here = np.arange(0, len(eod_mt_test) / sampling_mt, 1 / sampling_mt)
    time_rel = time_here - np.min(time_here)
    plt.subplot(1, 3, 2)
    plt.plot(time_rel, eod_mt_test)
    plt.xlim(0, 0.3)
    plt.subplot(1, 3, 3)
    plt.plot(time_rel, eod_mt_test)
    plt.plot(time_wn, amp * stimulus_wn + 0.4, color='red')
    plt.xlim(0, 0.3)
    plt.show()
def plt_data_overview_amps(ax):
    """Scatter 'perc95_perc5_fr' against four CV measures, colored by stimulus amplitude.

    One axis per CV measure (``ax[0]`` .. ``ax[3]``). P-units (marker '.') and
    ampullary cells (marker '*') share each axis. Only rows with a positive
    score and a CV below 1.5 are drawn, restricted to the listed stimulus
    files with ``snippets == 9``.

    Returns
    -------
    (cell_type_here, score)
        The values from the last loop iteration (kept for compatibility).
    """
    save_names = [
        'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s',
        'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s__burstIndividual_',
        'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s',
        'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s__burstIndividual_',
    ]
    # Renamed from `x_axis`, which the original rebound inside the loop body.
    cv_names = ["cv_base", "cv_base_w_burstcorr", "cv_stim_wo_burstcorr", "cv_stim_w_burstcorr"]
    save_names_title = ['No burst correction (cv_base)', 'Burst correction (cv_base)', 'No burst correction (cv_stim)',
                        'Burst correction (cv_stim)']
    scores = ['perc95_perc5_fr', ]
    cell_types = [' P-unit', ' Ampullary', ]
    file_names_there = ['gwn150Hz10s0.3',
                        'gwn300Hz10s0.3', 'gwn300Hz10s0.3short', 'gwn300Hz50s0.3', 'InputArr_400hz_30'
                        ]
    # Hoisted: previously only defined inside the amplitude loop, so set_xlim
    # could hit an unbound name when no amplitudes were found.
    max_val = 1.5
    counter = 0
    for cv_n, cv_name in enumerate(cv_names):
        frame_load = load_overview_susept(save_names[cv_n])
        for c, cell_type_here in enumerate(cell_types):
            cell_type = frame_load['celltype']  # 'cell_type_reclassified'
            p_pos = np.where(np.array(cell_type) == cell_type_here)
            frame = frame_load.loc[p_pos]
            marker = '.' if 'P-unit' in cell_type_here else '*'
            for s, score in enumerate(scores):
                frame_file = frame[frame.file_name.isin(file_names_there)]
                frame_file = frame_file[frame_file.snippets == 9]
                amps = np.array(frame_file.amp.unique())
                cmap = rainbow_cmap(amps, nrs=len(amps))
                for a, amp in enumerate(amps):
                    frame_amp = frame_file[frame_file.amp == amp]
                    positive = frame_amp[score] > 0
                    x_vals = frame_amp[cv_name][positive]
                    y_vals = np.array(frame_amp[score])[positive]
                    ax[counter].set_title(save_names_title[cv_n])
                    below = x_vals < max_val
                    try:
                        ax[counter].scatter(x_vals[below], y_vals[below],
                                            color=cmap[a], alpha=0.45, marker=marker,
                                            s=10)
                    except Exception:
                        print('axs thing2')
                        embed()
                ax[counter].set_xlabel(cv_name)
                if counter == 0:
                    ax[counter].set_ylabel(score)
                ax[counter].set_xlim(0, max_val)
        counter += 1
    return cell_type_here, score
def plt_data_overview2(ax, scores=['perc95_perc5_fr']):
    """Scatter response scores against baseline CV, shaded by response-modulation bin.

    One panel per entry of the local ``x_axis`` list (four panels, ``ax[0]`` to
    ``ax[3]``).  For each panel the overview table named in ``save_names`` is
    loaded with ``load_overview_susept``.  It is split by the
    ``cell_type_reclassified`` column into P-units and ampullary cells, and
    restricted to a fixed set of stimulus files and to ``snippets == 9``.  It
    is then binned by ``response_modulation`` using the limits returned by
    ``mod_lims_modulation``.  Each bin is drawn in its own shade of a
    ``rainbow_cmap`` colormap ('Blues' for P-units, 'Greens' otherwise).  Only
    points with ``score > 0`` and a CV below 1.5 are drawn.

    Parameters
    ----------
    ax : sequence of axes
        Panels indexed 0..3; ``set_title``/``scatter``/``set_xlabel`` are
        called on them.
    scores : list of str
        Column names plotted on the y-axis.

    Returns
    -------
    tuple
        The last ``cell_type_here`` and ``score`` visited by the loops.
    """
    ##########################
    # Selection: we use the "mean" files in order to average out noise that does not depend on the stimulus
    save_names = [
        'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s',
        'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s__burstIndividual_',
        'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s',
        'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s__burstIndividual_',
    ]
    # CV column used as x-axis, one per panel.
    x_axis = ["cv_base", "cv_base", "cv_base_w_burstcorr", "cv_base_w_burstcorr"]
    save_names_title = ['No burst correction', 'Burst correction', 'No burst correction',
                        'Burst correction']
    # Panel index into ax; incremented once per entry of x_axis.
    counter = 0
    # NOTE: x_axis is rebound inside the loop below (to the plotted x values);
    # enumerate already holds an iterator over the original list, so the
    # iteration itself is unaffected.
    for cv_n, cv_name in enumerate(x_axis):
        frame_load = load_overview_susept(save_names[cv_n])
        cell_types = [' P-unit', ' Ampullary', ]
        for c, cell_type_here in enumerate(cell_types):
            cell_type = frame_load['cell_type_reclassified']  # 'celltype' 'cell_type_reclassified'
            p_pos = np.where(np.array(cell_type) == cell_type_here)  # ' P-unit'
            frame = frame_load.loc[p_pos]
            for s, score in enumerate(scores):
                # todo: possibly make the transition between modulation bins soft
                mod_limits = mod_lims_modulation(cell_type_here, frame_load, score)
                if cell_type_here == ' P-unit':
                    cm = 'Blues'
                else:
                    cm = 'Greens'
                # Build a longer colormap and keep only part of it, presumably to
                # avoid the palest shades — TODO confirm intent.
                cmap = rainbow_cmap(np.arange(len(mod_limits) * 1.6), nrs=len(mod_limits) * 1.6, cm=cm)[
                       ::-1]  # len(amps)
                cmap = cmap[0:len(mod_limits)][::-1]
                # One iteration per modulation bin (mod_limits[ff], mod_limits[ff + 1]].
                for ff, amp in enumerate(range(len(mod_limits) - 1)):
                    # Only used for the debug print of its length below.
                    file_names = ['InputArr_50to100hz_30',
                                  'InputArr_150to200hz_30', 'InputArr_250to300hz_30',
                                  'InputArr_350to400hz_30',
                                  'InputArr_50hz_30', 'gwn50Hz10s0.3', 'gwn50Hz50s0.3', 'gwn100Hz10s0.3',
                                  'gwn150Hz10s0.3',
                                  'gwn300Hz10s0.3', 'gwn300Hz10s0.3short', 'gwn300Hz50s0.3', 'InputArr_400hz_30'
                                  ]
                    ##############
                    # Exclusion criterion 1: only RAMs that start at zero
                    file_names_there = ['gwn150Hz10s0.3',
                                        'gwn300Hz10s0.3', 'gwn300Hz10s0.3short', 'gwn300Hz50s0.3', 'InputArr_400hz_30'
                                        ]
                    frame_amp = frame[(frame['response_modulation'] > mod_limits[ff]) & (
                            frame['response_modulation'] <= mod_limits[ff + 1])]
                    try:
                        frame_file_ex = frame_amp[frame_amp.file_name.isin(file_names_there)]
                    except:
                        # Debug hook: drops into an IPython shell on failure.
                        print('file thing')
                        embed()
                    ##############
                    # Exclusion criterion 2: at least 9 seconds (9 snippets)
                    frame_file_ex = frame_file_ex[frame_file_ex.snippets == 9]
                    print(len(file_names))
                    frame_file = frame_file_ex  # frame_file_ex[frame_file_ex[var_name] == file]
                    ##############
                    # Exclusion criterion 3: contrast below 10 %
                    # or do not exclude and color by modulation instead!
                    # frame_amp = frame_file[frame_file.amp < 9]
                    cvs = frame_file[cv_name]  #
                    x_axis = cvs[frame_file[score] > 0]
                    y_axis = np.array(frame_file[score])[frame_file[score] > 0]
                    ax[counter].set_title(save_names_title[cv_n])
                    max_val = 1.5
                    try:
                        ax[counter].scatter(x_axis[x_axis < max_val], y_axis[x_axis < max_val],
                                            alpha=1,
                                            s=2.5, color=cmap[
                                ff], )  ##0.45 colors[' P-unit']label = cell_type_here[1]+':'+' '+str(file), marker = marker,
                    except:
                        # Debug hook: drops into an IPython shell on failure.
                        print('axs thing3')
                        embed()
                    ax[counter].set_xlabel(cv_name)
                    # Disabled debug plot of the colormap shades.
                    test = False
                    if test:
                        for ff in range(len(cmap)):
                            plt.plot([1, 2], [1, 2 * ff], color=cmap[ff])
                        plt.show()
                    # embed()
                    if counter == 0:
                        ax[counter].set_ylabel(score)  # cell_type_here+, transform=ax[counter,l].transAxes
                    ax[counter].set_xlim(0, max_val)
        counter += 1
    return cell_type_here, score
def plt_data_overview(ax, scores=['perc95_perc5_fr']):
    """Scatter a response score against four CV measures, one panel per measure.

    Panel ``i`` of ``ax`` shows CV column ``cv_columns[i]`` on the x-axis against
    every score in ``scores`` on the y-axis.  Points come from P-units and
    ampullary cells (column ``celltype``), each drawn in its
    ``colors_overview()`` colour.  Rows are kept only if they use one of the
    accepted stimulus files, have ``snippets == 9`` and have ``amp < 9``.  Only
    points with a positive score and a CV below 1.5 are drawn.

    Returns
    -------
    tuple
        The last visited cell type and score.
    """
    no_burst_name = 'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s'
    burst_name = 'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s__burstIndividual_'
    panel_sources = [no_burst_name, burst_name, no_burst_name, burst_name]
    cv_columns = ["cv_base", "cv_base_w_burstcorr", "cv_stim_wo_burstcorr", "cv_stim_w_burstcorr"]
    panel_titles = ['No burst correction (cv_base)', 'Burst correction (cv_base)', 'No burst correction (cv_stim)',
                    'Burst correction (cv_stim)']
    # Full list of known noise stimulus files; only its length is printed (debug output).
    all_file_names = ['InputArr_50to100hz_30',
                      'InputArr_150to200hz_30', 'InputArr_250to300hz_30',
                      'InputArr_350to400hz_30',
                      'InputArr_50hz_30', 'gwn50Hz10s0.3', 'gwn50Hz50s0.3', 'gwn100Hz10s0.3',
                      'gwn150Hz10s0.3',
                      'gwn300Hz10s0.3', 'gwn300Hz10s0.3short', 'gwn300Hz50s0.3', 'InputArr_400hz_30'
                      ]
    # Exclusion criterion 1: only RAM stimuli that start at zero.
    accepted_files = ['gwn150Hz10s0.3',
                      'gwn300Hz10s0.3', 'gwn300Hz10s0.3short', 'gwn300Hz50s0.3', 'InputArr_400hz_30'
                      ]
    cv_limit = 1.5

    for panel, cv_name in enumerate(cv_columns):
        frame_load = load_overview_susept(panel_sources[panel])
        colors = colors_overview()
        for cell_type_here in [' P-unit', ' Ampullary', ]:
            rows_of_type = np.where(np.array(frame_load['celltype']) == cell_type_here)
            frame = frame_load.loc[rows_of_type]
            for score in scores:
                selected = frame[frame.file_name.isin(accepted_files)]
                # Exclusion criterion 2: at least 9 seconds (9 snippets).
                selected = selected[selected.snippets == 9]
                print(len(all_file_names))
                # Exclusion criterion 3: contrast below 10 %.
                selected = selected[selected.amp < 9]
                positive = selected[score] > 0
                xs = selected[cv_name][positive]
                ys = np.array(selected[score])[positive]
                ax[panel].set_title(panel_titles[panel])
                try:
                    ax[panel].scatter(xs[xs < cv_limit], ys[xs < cv_limit],
                                      color=colors[cell_type_here], alpha=0.45,
                                      s=5)
                except:
                    # Debug hook: drops into an IPython shell on failure.
                    print('axs thing4')
                    embed()
                ax[panel].set_xlabel(cv_name)
                if panel == 0:
                    ax[panel].set_ylabel(score)
                ax[panel].set_xlim(0, cv_limit)
    return cell_type_here, score
def plt_power2(spikes_all_here, axp, color='blue'):
    """Plot the power spectral density of every spike train onto ``axp``.

    Each spike train is divided by 1000 (presumably ms -> s; TODO confirm
    units) and converted with ``cr_spikes_mat`` at 40 kHz.  Its mean-removed
    PSD is estimated with ``matplotlib.mlab.psd`` (NFFT 2**14, 50 % overlap).
    Later trains are drawn with decreasing opacity.

    Parameters
    ----------
    spikes_all_here : sequence of array-like
        One spike-time array per trial; each must be non-empty (``sp[-1]``
        sets the matrix length).
    axp : axes
        Target axes; x-limits are set to 0-1000 Hz.
    color : str
        Line colour for all traces.

    Returns
    -------
    p_array : list
        PSD per trial.
    f_array : ndarray
        Frequencies of the last PSD (empty array if there were no trials).
    """
    sampling_calc = 40000
    nfft = 2 ** 14
    spikes_mat = [[]] * len(spikes_all_here)
    p_array = [[]] * len(spikes_all_here)
    # Defined up front so an empty input returns instead of raising UnboundLocalError.
    f_array = np.array([])
    for s, sp in enumerate(spikes_all_here):
        spikes_mat[s] = cr_spikes_mat(np.array(sp) / 1000, sampling_rate=sampling_calc,
                                      length=int(sampling_calc * np.array(sp[-1]) / 1000))
        p_array[s], f_array = ml.psd(spikes_mat[s] - np.mean(spikes_mat[s]), Fs=sampling_calc, NFFT=nfft,
                                     noverlap=nfft // 2)
        # Fade later trials, but clamp so alpha stays within matplotlib's valid
        # [0, 1] range; the unclamped 1 - 0.05 * s goes negative after 20 trials.
        trace_alpha = max(1 - 0.05 * s, 0.05)
        axp.plot(f_array, p_array[s], alpha=float(trace_alpha), color=color)  # color=colors[str(cell_type)],
    axp.set_xlim(0, 1000)
    axp.set_xlabel('Hz')
    axp.set_ylabel('Hz')
    return p_array, f_array
def find_names_gwn(b):
    """Return the names of the multi-tags in ``b`` that ``find_gwn`` classifies as noise/file stimuli.

    The order of ``b.multi_tags`` is preserved.
    """
    return [tag.name for tag in b.multi_tags if find_gwn(tag)]
def find_gwn(trials):
    """Return True if ``trials.name`` marks a noise / file-based stimulus.

    The name matches if it contains any of 'file', 'noise', 'gwn', 'InputArr'
    or 'FileStimulus-file-gaussian' (substring test).
    """
    markers = ('file', 'noise', 'gwn', 'InputArr', 'FileStimulus-file-gaussian')
    tag_name = trials.name
    return any(marker in tag_name for marker in markers)
def model_and_data_isi_power(nr_clim=10, many=False, width=0.005, row='no', HZ50=True, fs=8, nffts=['whole'],
                             cells=["2013-01-08-aa-invivo-1"], col_desired=2, var_items=['contrasts'],
                             contrasts=[0], noises_added=[''], fft_i='forward', fft_o='forward', spikes_unit='Hz',
                             mV_unit='mV',
                             D_extraction_method=['additiv_visual_d_4_scaled'], internal_noise=['eRAM'],
                             external_noise=['eRAM'], level_extraction=['_RAMdadjusted'], cut_off2=300,
                             receiver_contrast=[1], dendrids=[''], ref_types=[''], adapt_types=[''], c_noises=[0.1],
                             c_signal=[0.9],
                             cut_offs1=[300], clims='all', restrict='restrict'):
    """Figure combining ISI histograms, measured susceptibilities and model susceptibilities.

    Layout: a 1 x 4 ``GridSpec`` with width ratios (0.7, 1, 1, 1).
    - Column 0 holds ISI histograms of the data cells (``plt_isi``).
    - Column 1 holds the data susceptibility matrices (``plt_data_susept``).
    - ``plt_model_part`` places the model panels in column ``2 + t``, one
      column per entry of the local ``repeats`` list.

    A new figure is created for every combination of the list-valued model
    parameters (``D_extraction_method``, ``external_noise``, ``internal_noise``,
    ``nffts``, ``dendrids``, ``cut_offs1``, ``c_signal``, ``c_noises``,
    ``ref_types``, ``adapt_types``, ``noises_added``, ``level_extraction``,
    ``receiver_contrast``, ``contrasts``).  ``save_visualization`` is called
    once after that loop.

    Parameters (selection)
    ----------------------
    nr_clim, clims : forwarded to ``set_clim_same``.
    row : 'no' derives the row count from the number of parameter combinations
        via ``find_row_col``; ``default_settings`` is only called for row 1 or 2.
    cells : only counted when deriving the layout. NOTE(review): the cells
        actually plotted come from ``overlap_cells()``, not from this argument.
    restrict : only appended to the figure title.
    The remaining arguments are forwarded to ``plt_model_part``.
    """
    stimulus_length = 1  # 20#550 # 30 # 15#45#0.5#1.5 15 45 100
    trials_nrs = [1]  # [100, 500, 1000, 3000, 10000, 100000, 1000000]  # 500
    variant = 'sinz'
    mimick = 'no'
    cell_recording_save_name = ''
    trans = 1  # 5
    repeats = [9]  # 30
    powers = [3]  # ,1]
    # aa counts the parameter combinations; it is only used for the automatic layout.
    aa = 0
    good_data, remaining = overlap_cells()
    cells_all = good_data
    for _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, in it.product(
            cells, D_extraction_method, external_noise, internal_noise, powers, nffts, dendrids, cut_offs1, trials_nrs,
            c_signal,
            c_noises,
            ref_types, adapt_types, noises_added, level_extraction, receiver_contrast, contrasts, ):
        aa += 1
    if row == 'no':
        col, row = find_row_col(np.arange(aa), col=col_desired)  # np.arange(
    else:
        pass
    if row == 2:
        default_settings(column=2, length=7.5)  # 2+2.25+2.25
    elif row == 1:
        default_settings(column=2, length=4)
    grid = gridspec.GridSpec(1, 4, wspace=0.6, bottom=0.075,
                             hspace=0.13, left=0.08, right=0.93, top=0.88, width_ratios=[0.7, 1, 1, 1])
    a = 0
    # Accumulators filled by plt_model_part and consumed by set_clim_same; they
    # are shared across all parameter combinations.
    maxs = []
    mins = []
    ims = []
    perc05 = []
    perc95 = []
    iternames = [D_extraction_method, external_noise,
                 internal_noise, powers, nffts, dendrids, cut_offs1, trials_nrs, c_signal,
                 c_noises,
                 ref_types, adapt_types, noises_added, level_extraction, receiver_contrast, contrasts, ]
    nr = '2'
    print(cells_all)
    # NOTE: the loop variable `all` shadows the builtin of the same name in this function.
    for all in it.product(*iternames):
        var_type, stim_type_afe, stim_type_noise, power, nfft, dendrid, cut_off1, trial_nrs, c_sig, c_noise, ref_type, adapt_type, noise_added, extract, a_fr, a_fe = all
        fig = plt.figure()
        hs = 0.45
        #################################
        # model cells
        # plt_model_part may reorder cells_all (sorted by cv_stim) and returns it,
        # so the data and ISI panels below follow the model's cell order.
        adapt_type_name, ax_model, cells_all, dendrid_name, ref_type_name, suptitles, width = plt_model_part(HZ50, a,
                                                                                                          a_fe, a_fr,
                                                                                                          adapt_type,
                                                                                                          c_noise,
                                                                                                          c_sig,
                                                                                                          cell_recording_save_name,
                                                                                                          cells_all,
                                                                                                          cut_off1,
                                                                                                          cut_off2,
                                                                                                          dendrid,
                                                                                                          extract,
                                                                                                          fft_i,
                                                                                                          fft_o, fig,
                                                                                                          fs,
                                                                                                          grid, hs,
                                                                                                          ims,
                                                                                                          mV_unit,
                                                                                                          many, maxs,
                                                                                                          mimick,
                                                                                                          mins, nfft,
                                                                                                          noise_added,
                                                                                                          nr, perc05,
                                                                                                          perc95,
                                                                                                          power,
                                                                                                          ref_type,
                                                                                                          repeats,
                                                                                                          spikes_unit,
                                                                                                          stim_type_afe,
                                                                                                          stim_type_noise,
                                                                                                          stimulus_length,
                                                                                                          trans,
                                                                                                          trial_nrs,
                                                                                                          var_items,
                                                                                                          var_type,
                                                                                                          variant,
                                                                                                          width)
        #################################
        # data cells
        grid_data = gridspec.GridSpecFromSubplotSpec(len(cells_all), 1, grid[1],
                                                     hspace=hs)
        ax_data, stack_spikes_all, eod_frs = plt_data_susept(fig, grid_data, cells_all, cell_type='p-unit', width=width,
                                                             cbar_label=False)
        for ax in ax_data:
            ax.set_ylabel(F2_xlabel())
        #################################
        # plt isi of data
        grid_isi = gridspec.GridSpecFromSubplotSpec(len(cells_all), 1, grid[0],
                                                    hspace=hs)
        ax_isi = plt_isi(cells_all, grid_isi, stack_spikes=stack_spikes_all, eod_frs=eod_frs)
        ax_isi[0].get_shared_x_axes().join(*ax_isi)
    # Title and tagging use the values from the last parameter combination.
    end_name = suptitles + ' a_fr=' + str(a_fr) + ' ' + ' dendride=' + str(
        dendrid_name) + ' refractory period=' + str(ref_type_name) + ' adapt=' + str(
        adapt_type_name) + ' ' + ' cutoff1=' + str(cut_off1) + '' ' stimulus length=' + str(
        stimulus_length) + ' ' + ' power=' + str(
        power) + ' ' + restrict  #
    end_name = cut_title(end_name, datapoints=120)
    name_title = end_name
    plt.suptitle(name_title)  # +' file '
    set_clim_same(ims, perc05, perc95, mins, maxs, nr_clim, clims)
    axes = np.array([np.array(ax_data), np.array(ax_model[0:int(len(ax_model) / 2)]),
                     np.array(ax_model[int(len(ax_model) / 2)::]), np.array(ax_isi)])
    fig.tag(np.transpose(axes), xoffs=-3, yoffs=2.9, minor_index=2)
    save_visualization(pdf=True)
def model_and_data(width=0.005, nffts=['whole'], powers=[1], cells=["2013-01-08-aa-invivo-1"], contrasts=[0],
                   noises_added=[''], D_extraction_method=['additiv_cv_adapt_factor_scaled'],
                   internal_noise=['RAM'], external_noise=['RAM'], level_extraction=[''], receiver_contrast=[1],
                   dendrids=[''], ref_types=[''], adapt_types=[''], c_noises=[0.1], c_signal=[0.9],
                   cut_offs1=[300]):  # ['eRAM']
    """Figure comparing a measured susceptibility matrix with model variants and stimulus flowcharts.

    Layout: a 3 x 4 ``GridSpec``.
    - ``grid[0, 1]`` holds the data susceptibility of the first cell returned
      by ``overlap_cells()``.
    - The flat grid positions 2, 3, 6, 7, 10 and 11 hold six model
      susceptibility matrices.  These are loaded from the hard-coded pickles in
      ``save_names`` for cell '2012-07-03-ak-invivo-1'.
    - ``grid[0, 0]``, ``grid[1, 1]`` and ``grid[2, 1]`` hold three flowchart
      columns with the RAM stimulus and noise time courses produced by
      ``get_flowchart_params``.

    NOTE(review): the list-valued arguments are iterated with
    ``it.product``.  Apart from ``receiver_contrast`` (used as ``a_fr`` in the
    flowcharts), the loop variables are not used in the body, so they mainly
    determine how often the figure is rebuilt.  ``cells`` only matters through
    ``len(cells) > 1``.  ``save_visualization`` is called once after the loop.
    """
    stimulus_length = 1  # 20#550 # 30 # 15#45#0.5#1.5 15 45 100
    trials_nrs = [1]  # [100, 500, 1000, 3000, 10000, 100000, 1000000]  # 500
    good_data, remaining = overlap_cells()
    # Only the first overlapping cell is shown in the data panel.
    cells_all = [good_data[0]]
    plot_style()
    default_settings(column=2, length=4.9)  # 0.75
    grid = gridspec.GridSpec(3, 4, wspace=0.95, bottom=0.07,
                             hspace=0.23, left=0.09, right=0.9, top=0.92)
    a = 0
    maxs = []
    mins = []
    mats = []
    ims = []
    iternames = [D_extraction_method, external_noise,
                 internal_noise, powers, nffts, dendrids, cut_offs1, trials_nrs, c_signal,
                 c_noises,
                 ref_types, adapt_types, noises_added, level_extraction, receiver_contrast, contrasts, ]
    lp = 2
    # NOTE: the loop variable `all` shadows the builtin of the same name in this function.
    for all in it.product(*iternames):
        var_type, stim_type_afe, stim_type_noise, power, nfft, dendrid, cut_off1, trial_nrs, c_sig, c_noise, ref_type, adapt_type, noise_added, extract, a_fr, a_fe = all
        fig = plt.figure()
        hs = 0.45
        #################################
        # data cells
        grid_data = gridspec.GridSpecFromSubplotSpec(1, 1, grid[0, 1],
                                                     hspace=hs)
        ax_data, stack_spikes_all, eod_frs = plt_data_susept(fig, grid_data, cells_all, cell_type='p-unit', width=width,
                                                             cbar_label=True, lp=lp, title=True)
        for ax in ax_data:
            ax.set_xticks_delta(100)
            set_ylabel_arrow(ax, xpos=xpos_y_modelanddata(), ypos=0.87)
            set_xlabel_arrow(ax)
            ax.arrow_spines('lb')
        ##################################
        # model part
        cell = '2012-07-03-ak-invivo-1'
        cells_given = [cell]
        # Precomputed model susceptibilities: pairs of (11 trials, 500000 trials)
        # for three noise/stimulus configurations.
        save_names = [
            'calc_RAM_model-2__nfft_whole_power_1_afe_0.009_RAM_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_11_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV',
            'calc_RAM_model-2__nfft_whole_power_1_afe_0.009_RAM_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_500000_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV',
            'calc_RAM_model-2__nfft_whole_power_1_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_11_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV',
            'calc_RAM_model-2__nfft_whole_power_1_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_500000_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV',
            'calc_RAM_model-2__nfft_whole_power_1_afe_0.009_RAM_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_11_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV',
            'calc_RAM_model-2__nfft_whole_power_1_afe_0.009_RAM_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_500000_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV',
        ]
        # Flat grid indices of the six model panels (columns 2 and 3 of each row).
        nrs_s = [2, 3, 6, 7, 10, 11]
        titles = ['Trials=11 c=0.01', 'Trials=500000 c=0.01', 'Trials=11 \n Intrinsic split',
                  'Trials=500000\n Intrinsic split', 'Trials=11 c=0.01\n Intrinsic split',
                  'Trials=500000 c=0.01\n Intrinsic split']
        ax_model = []
        for s, sav_name in enumerate(save_names):
            try:
                ax = plt.subplot(grid[nrs_s[s]])
            except:
                # Debug hook: drops into an IPython shell on failure.
                print('vers something')
                embed()
            ax_model.append(ax)
            save_name = load_folder_name('calc_model') + '/' + sav_name
            cell_add, cells_save = find_cell_add(cells_given)
            perc = 'perc'
            path = save_name + '.pkl'  # '../'+
            stack = load_model_susept(path, cells_save, save_name.split(r'/')[-1] + cell_add)
            # NOTE(review): `fig` is rebound to the figure returned by plt_single_square_modl.
            add_nonlin_title, cbar, fig, stack_plot, im = plt_single_square_modl(ax, cell, stack, perc, titles[s],
                                                                                 width, titles_plot=True, resize=True)
            ims.append(im)
            mats.append(stack_plot)
            maxs.append(np.max(np.array(stack_plot)))
            mins.append(np.min(np.array(stack_plot)))
            col = 2
            row = 3
            ax.set_xticks_delta(100)
            ax.set_yticks_delta(100)
            cbar.set_label(nonlin_title(' [' + add_nonlin_title), labelpad=lp)  # rotation=270,
            # y tick labels only where the condition below is False (s = 2, 4);
            # x labels only on the last row (s >= 4).
            if (s in np.arange(col - 1, 100, col)) | (s == 0):
                remove_yticks(ax)
            else:
                set_ylabel_arrow(ax, xpos=xpos_y_modelanddata(), ypos=0.87)
            if s >= row * col - col:
                set_xlabel_arrow(ax)
            else:
                remove_xticks(ax)
        if len(cells) > 1:
            a += 1
    # Shared colour scale: 0 up to the largest 95th percentile of all model matrices.
    set_clim_same(ims, mats=mats, lim_type='up', nr_clim='perc', clims='', percnr=95)
    #################################################
    # Flowcharts
    var_types = ['', 'additiv_cv_adapt_factor_scaled', 'additiv_cv_adapt_factor_scaled']
    a_fes = [0.009, 0, 0.009]
    eod_fe = [750, 750, 750]
    ylim = [-0.5, 0.5]
    c_sigs = [0, 0.9, 0.9]
    # NOTE(review): grid_left[0] is an empty placeholder; presumably flowchart 0
    # never enters the "both vary" branch below — confirm.
    grid_left = [[], grid[1, 0], grid[2, 0]]
    ax_ams = []
    for g, grid_here in enumerate([grid[0, 0], grid[1, 1], grid[2, 1]]):
        grid_lowpass = gridspec.GridSpecFromSubplotSpec(3, 1,
                                                        subplot_spec=grid_here, hspace=0.2,
                                                        height_ratios=[1, 1, 0.1])
        models = resave_small_files("models_big_fit_d_right.csv", load_folder='calc_model_core')
        model_params = models[models['cell'] == '2012-07-03-ak-invivo-1'].iloc[0]
        cell = model_params.pop('cell')  # .iloc[0]# look up the values for the paper
        eod_fr = model_params['EODf']  # .iloc[0]
        deltat = model_params.pop("deltat")  # .iloc[0]
        v_offset = model_params.pop("v_offset")  # .iloc[0]
        print(var_types[g] + ' a_fe ' + str(a_fes[g]))
        noise_final_c, spike_times, stimulus, stimulus_here, time, v_dent_output, v_mem_output, frame = get_flowchart_params(
            a_fes, a_fr, g, c_sigs[g], cell, deltat, eod_fr, model_params, stimulus_length, v_offset, var_types,
            eod_fe=eod_fe)
        # Trace colour depends on which stimulus components vary over time.
        # NOTE(review): if neither RAM_afe nor RAM_noise varies, color_timeseries
        # stays undefined (or keeps the previous value) when plot_lowpass2 is called.
        if (len(np.unique(frame.RAM_afe)) > 1) & (len(np.unique(frame.RAM_noise)) > 1):
            grid_lowpass2 = gridspec.GridSpecFromSubplotSpec(2, 1,
                                                             subplot_spec=grid_left[g], hspace=0.2)
            axt_p2 = plt_time_arrays('purple', grid_lowpass2, 1, frame.RAM_noise, time=time, nr=0)
            # NOTE(review): '$\%$' uses an invalid Python escape sequence ('\%'); it still
            # yields a backslash, but a raw string would avoid the warning.
            axt_p2.text(-0.6, 0.5, '$\%$', rotation=90, va='center', transform=axt_p2.transAxes)
            color_timeseries = 'black'
            axt_p2.set_xlabel('Time [ms]')
            axt_p2.text(-0.6, 0.5, '$\%$', rotation=90, va='center', transform=axt_p2.transAxes)
            ax_ams.append(axt_p2)
        elif len(np.unique(frame.RAM_afe)) > 1:
            color_timeseries = 'red'
        elif len(np.unique(frame.RAM_noise)) > 1:
            color_timeseries = 'purple'
        print(str(g) + ' afevar ' + str(np.var(frame.RAM_afe)) + ' afenoise ' + str(np.var(frame.RAM_noise)))
        try:
            ax, ff, pp, ff_am, pp_am = plot_lowpass2([grid_lowpass[0]], time, frame.RAM_afe + frame.RAM_noise,
                                                     deltat, eod_fr,
                                                     color1=color_timeseries, lw=1, extract=False)
        except:
            # Debug hook: drops into an IPython shell on failure.
            print('add up thing')
            embed()
        ax.text(-0.6, 0.5, '$\%$', va='center', rotation=90, transform=ax.transAxes)
        ax_ams.append(ax)
        remove_xticks(ax)
        ax_n, ff, pp, ff_am, pp_am = plot_lowpass2([grid_lowpass[1]], time, noise_final_c, deltat, eod_fr,
                                                   extract=False, color1='grey', lw=1)
        remove_yticks(ax_n)
        if g == 1:
            remove_xticks(ax_n)
        else:
            ax_n.set_xlabel('Time [ms]')
        ax_n.set_ylim(ylim)
    set_same_ylim(ax_ams, up='up')
    # Panel tagging order; requires at least 4 entries in ax_ams (3 flowcharts
    # plus at least one extra noise trace from the "both vary" branch).
    axes = np.concatenate([ax_data, ax_model])
    axes = [ax_ams[0], axes[0], axes[1], axes[2], ax_ams[1], axes[3], axes[4], ax_ams[2], ax_ams[3], axes[5],
            axes[6], ]
    fig.tag(axes, xoffs=-3, yoffs=2)
    save_visualization(pdf=True)
def xpos_y_modelanddata():
    """Return the ``xpos`` offset passed to ``set_ylabel_arrow`` in the model/data figures."""
    label_offset = -0.52
    return label_offset
def F1_xlabel():
    """Return the LaTeX axis label for the first stimulus frequency f_1, in Hz."""
    # Raw string: '\,' (LaTeX thin space) is not a valid Python escape sequence and
    # emits a DeprecationWarning/SyntaxWarning in a normal literal. The value is unchanged.
    return r'$f_{1}$\,[Hz]'
def plt_model_part(HZ50, a, a_fe, a_fr, adapt_type, c_noise, c_sig, cell_recording_save_name, cells_all,
                   cut_off1, cut_off2, dendrid, extract, fft_i, fft_o, fig, fs, grid, hs, ims,
                   mV_unit, many, maxs, mimick, mins, nfft, noise_added, nr, perc05, perc95, power,
                   ref_type, repeats, spikes_unit, stim_type_afe, stim_type_noise, stimulus_length, trans, trial_nrs,
                   var_items, var_type, variant, width, xlabels=True, perc='',
                   label=nonlin_title(), rows=2, title=True):
    """Plot model susceptibility matrices, one grid column per entry of ``repeats``.

    For every ``trials_stim`` in ``repeats`` (index ``t``) a column
    ``grid[2 + t]`` is split into ``rows`` sub-panels.  The model results are
    loaded from the pickle named by ``save_ram_model(...)``.  For each cell,
    the matrix is normalised with ``RAM_norm`` using the noise strength from
    ``models_big_fit_d_right.csv``, drawn with ``plt_RAM_perc``, and annotated
    (triangle, optional 50 Hz lines, colorbar).

    Side effects: appends to the caller's ``ims``, ``maxs``, ``mins``,
    ``perc05`` and ``perc95`` lists (used later for shared colour limits).

    Returns
    -------
    tuple
        (adapt_type_name, ax_model, cells_all, dendrid_name, ref_type_name,
        suptitles, width).  ``cells_all`` may be replaced by the model cells
        ordered by ``cv_stim``.  ``width`` is the value last returned by
        ``colorbar_outside``.

    Notes
    -----
    - ``label`` defaults to ``nonlin_title()``, which is evaluated once when
      the function is defined.
    - NOTE(review): ``suptitles`` is only assigned inside the cell loop; with
      no cells the return statement would raise UnboundLocalError.
    - ``a`` is incremented locally; the caller's value is not changed.
    """
    ax_model = []
    for t, trials_stim in enumerate(repeats):
        grid_model = gridspec.GridSpecFromSubplotSpec(rows, 1, grid[2 + t],
                                                      hspace=hs)
        save_name = save_ram_model(stimulus_length, cut_off1, nfft, a_fe, stim_type_noise, mimick, variant, trials_stim,
                                   power,
                                   cell_recording_save_name, nr=nr, fft_i=fft_i, fft_o=fft_o, Hz=spikes_unit,
                                   mV=mV_unit, stim_type_afe=stim_type_afe, extract=extract,
                                   noise_added=noise_added,
                                   c_noise=c_noise, c_sig=c_sig, ref_type=ref_type, adapt_type=adapt_type,
                                   var_type=var_type, cut_off2=cut_off2, dendrid=dendrid, a_fr=a_fr,
                                   trials_nr=trial_nrs, trans=trans, zeros='ones')
        path = save_name + '.pkl'
        print(t)
        print(path)
        model = load_model_susept(path, cells_all, save_name + 'all')  # cells
        adapt_type_name, ref_type_name, dendrid_name, stim_type_noise_name = define_names(var_type, stim_type_noise,
                                                                                         dendrid, ref_type,
                                                                                         adapt_type)
        if len(model) > 0:
            model = model[model.cell.isin(cells_all)]  # ('cv_stim')
            try:
                # Reorder the cells by their stimulated CV (one cell per cv_stim value).
                cells_all = model.groupby('cv_stim').first().sort_values(by='cv_stim').cell  # ('cv_stim')
            except:
                print('model thing')
        for c, cell in enumerate(cells_all):
            print(c)
            try:
                ax = plt.subplot(grid_model[c])
            except:
                # Debug hook: drops into an IPython shell on failure.
                print('something')
                embed()
            titles = ''
            suptitles = ''
            stim_type_noise_name2, stim_type_afe_name = stim_type_names(a_fe, c_sig, stim_type_afe,
                                                                        stim_type_noise_name)
            suptitles, titles = find_titles_RAM(a_fe, cell, extract, noise_added, stim_type_afe_name,
                                                stim_type_noise_name2, suptitles, titles, trials_stim, var_items,
                                                var_type)
            model_show = model[
                (model.cell == cell)]  # & (model_cell.file_name == file)& (model_cell.power == power)]
            new_keys = model_show.index.unique()  # [0:490]
            # The matrix columns are named like the index values; try string names first.
            try:
                stack_plot = model_show[list(map(str, new_keys))]
            except:
                stack_plot = model_show[new_keys]
            stack_plot = stack_plot.iloc[np.arange(0, len(new_keys), 1)]
            stack_plot.columns = list(map(float, stack_plot.columns))
            ax.set_xlim(0, 300)
            ax.set_ylim(0, 300)
            ax.set_aspect('equal')
            ax.set_xticks_delta(100)
            ax.set_yticks_delta(100)
            ax.arrow_spines('lb')
            model_cells = resave_small_files("models_big_fit_d_right.csv")
            model_params = model_cells[model_cells['cell'] == cell]
            if len(model_show) > 0:
                noise_strength = model_params.noise_strength.iloc[0]  # **2/2
                D = noise_strength  # (noise_strength ** 2) / 2
                D_derived, var, cut_off = D_derive(model_show, save_name, c_sig, D=D, base='', nr=nr)  # var_based
                stack_plot = RAM_norm(stack_plot, trials_stim, D_derived, model_show=model_show)
                if many == True:
                    titles = titles + ' Ef=' + str(int(model_params.EODf.iloc[0]))
                color = title_color(cell)
                # Title with baseline/stimulated firing rate and CV (first column only).
                if title:
                    if t == 0:
                        ax.set_title(
                            titles + ' $fr_{B}$=' + str(int(np.round(model_show.fr.iloc[0]))) + ' $fr_{S}$=' + str(
                                int(np.round(model_show.fr_stim.iloc[0]))) + 'Hz\n $cv_{B}$=' + str(
                                np.round(model_show.cv.iloc[0], 2)) + ' $cv_{S}$=' + str(
                                np.round(model_show.cv_stim.iloc[0], 2)) + ' s=' + str(
                                np.round(model_show.ser_sum_stim.iloc[0], 2)), fontsize=fs,
                            color=color)  # + ' $D_{sig}$=' + str(np.round(D_derived, 5))
                im = plt_RAM_perc(ax, perc, stack_plot)
                ims.append(im)
                maxs.append(np.max(np.array(stack_plot)))
                mins.append(np.min(np.array(stack_plot)))
                perc05.append(np.percentile(stack_plot, 5))
                perc95.append(np.percentile(stack_plot, 95))
                plt_triangle(ax, model_show.fr.iloc[0], np.round(model_show.fr_stim.iloc[0]), 300,
                             model_show.eod_fr.iloc[0])
                if HZ50:
                    plt_50_Hz_noise(ax, 300)
                ax.set_aspect('equal')
                # NOTE: rebinds `width`, which is also returned to the caller.
                cbar, left, bottom, width, height = colorbar_outside(ax, im, fig, add=0, width=width)
                if t == 1:
                    ax.text(1.5, 0.5, label, rotation=90, ha='center', va='center',
                            transform=ax.transAxes)
                remove_yticks(ax)
            if (c == len(cells_all) - 1) & xlabels:
                ax.text(1.05, -0.35, F1_xlabel(), ha='center', va='center',
                        transform=ax.transAxes)
                ax.arrow_spines('lb')
            else:
                remove_xticks(ax)
            print(c)
            ax_model.append(ax)
            a += 1
    print('model done')
    return adapt_type_name, ax_model, cells_all, dendrid_name, ref_type_name, suptitles, width
def find_titles_RAM(a_fe, cell, extract, noise_added, stim_type_afe_name, stim_type_noise_name2, suptitles, titles,
                    trials_stim, var_items, var_type):
    """Extend a panel title and a figure suptitle with the simulation parameters.

    Each parameter fragment is appended to ``titles`` if its key is contained
    in ``var_items`` (presumably the parameters that vary between panels), and
    to ``suptitles`` otherwise.  Fragments are appended in a fixed order:
    cell, intrinsic noise, additive RAM, repeats, contrast, extraction level,
    D extraction method, high-frequency noise.

    Returns
    -------
    tuple of str
        ``(suptitles, titles)``.
    """
    fragments = (
        ('cells', cell[2:13]),
        ('internal_noise', ' intrinsic noise=' + stim_type_noise_name2),
        ('external_noise', ' additive RAM=' + stim_type_afe_name),
        ('repeats', ' $N_{repeat}=$' + str(trials_stim)),
        ('contrasts', ' contrast=' + str(a_fe)),
        ('level_extraction', ' Extract Level=' + str(extract)),
        ('D_extraction_method', str(var_type)),
        ('noises_added', ' high freq noise=' + str(noise_added)),
    )
    for key, fragment in fragments:
        if key in var_items:
            titles += fragment
        else:
            suptitles += fragment
    return suptitles, titles
def set_clim_same(ims, perc05=[], val_chosen=None, percnr=None, perc95=[], mins=[], maxs=[], mats=[],
                  nr_clim='perc', lim_type='',
                  clims='all', same='', mean_type=False, clim_min=[], clim_max=[]):
    """Set colour limits on a collection of images.

    Parameters
    ----------
    ims : sequence
        Objects providing ``set_clim`` (and ``get_clim`` for some modes).
    clims : str
        ``'all'`` gives every image the same limits.  Any other value sets
        per-image limits by indexing ``perc05``/``perc95``/``mins``/``maxs``.
    same : str
        With ``clims == 'all'`` and ``same == 'same'``, limits come from
        ``clim_min``/``clim_max``, or from the images' current limits when
        ``clim_min`` is empty.
    nr_clim : ``'perc'``, ``'None'`` or a number
        ``'perc'`` uses the percentile lists.  A number divides the maxima and
        multiplies the minima.  ``'None'`` (per-image mode without ``mats``)
        only touches the upper limit when ``lim_type == 'up'``.
    lim_type : str
        ``'up'`` pins the lower limit to 0.  In the ``same`` mode, ``'same'``
        makes the limits symmetric around 0.
    mats : sequence
        If non-empty, ``maxs``/``mins``/``perc05``/``perc95`` are recomputed
        from these matrices with ``get_perc_vals(mats, percnr)``.
    mean_type : bool
        Shared percentile mode with ``mats`` and ``lim_type == 'up'``: use the
        mean instead of the maximum of the upper percentiles.
    val_chosen :
        Upper limit for the ``'None'`` / ``'up'`` mode; when falsy, the image's
        current upper limit is kept.
    """
    shared = clims == 'all'

    # Shared limits derived from explicit (or current) colour limits.
    if shared and same == 'same':
        if len(clim_min) < 1:
            clim_min = [image.get_clim()[0] for image in ims]
            clim_max = [image.get_clim()[1] for image in ims]
        low = np.min(clim_min)
        high = np.max(clim_max)
        lim = np.max([np.abs(low), np.abs(high)])
        for image in ims:
            if lim_type == 'same':
                image.set_clim(-lim, lim)
            elif lim_type == 'up':
                image.set_clim(0, lim)
            else:
                image.set_clim(low, high)
        return

    from_mats = len(mats) >= 1
    if from_mats:
        maxs, mins, perc05, perc95 = get_perc_vals(mats, percnr)

    for idx, image in enumerate(ims):
        if nr_clim == 'perc':
            if shared:
                if lim_type == 'up':
                    top = np.mean(perc95) if (from_mats and mean_type) else np.max(perc95)
                    image.set_clim(0, top)
                else:
                    image.set_clim(np.min(perc05), np.max(perc95))
            elif lim_type == 'up':
                image.set_clim(0, perc95[idx])
            else:
                image.set_clim(perc05[idx], perc95[idx])
        elif nr_clim == 'None' and not shared and not from_mats:
            current = image.get_clim()
            if lim_type == 'up':
                image.set_clim(0, val_chosen if val_chosen else current[-1])
        elif shared:
            image.set_clim(np.min(mins) * nr_clim, np.max(maxs) / nr_clim)
        elif lim_type == 'up':
            image.set_clim(0, maxs[idx] / nr_clim)
        else:
            image.set_clim(mins[idx] * nr_clim, maxs[idx] / nr_clim)
def get_perc_vals(mats, percnr):
    """Compute per-matrix extrema and percentiles.

    Parameters
    ----------
    mats : sequence of array-like
        Matrices to summarise (each is reduced over all elements).
    percnr : number or None
        Upper percentile.  The lower percentile is ``100 - percnr``.  When
        falsy (None or 0), the 5th and 95th percentiles are used.

    Returns
    -------
    maxs, mins, perc05, perc95 : list
        One entry per matrix: maximum, minimum, lower and upper percentile.
    """
    upper = percnr if percnr else 95
    lower = 100 - upper
    perc05 = []
    perc95 = []
    mins = []
    maxs = []
    for m in range(len(mats)):
        mat = mats[m]
        mins.append(np.min(mat))
        # Bug fix: this previously appended np.min, so "maxs" held the minima and
        # corrupted the scaled colour limits in set_clim_same.
        maxs.append(np.max(mat))
        perc05.append(np.percentile(mat, lower))
        perc95.append(np.percentile(mat, upper))
    return maxs, mins, perc05, perc95
def plt_isi(cells_all, grid_isi, stack_spikes=[], eod_frs=[]):
    """Plot inter-spike-interval histograms, one subplot per cell.

    Parameters
    ----------
    cells_all : sequence of str
        Cell identifiers; cell ``f`` is drawn into ``grid_isi[f]``.
    grid_isi : gridspec-like
        Indexed by cell position to create each subplot.
    stack_spikes : sequence, optional
        If non-empty, ``stack_spikes[f]`` is indexed with ``0..n-1`` (one entry
        per trial).  ISIs are computed here with ``np.diff`` and expressed in
        multiples of the EOD period ``1 / eod_frs[f]``.  Otherwise spikes are
        taken from the ``spikes`` column of ``load_cv_base_frame(cells_all)``
        and converted by ``load_spikes``.
    eod_frs : sequence of float
        EOD frequency per cell (indexed by cell position in both modes).

    Returns
    -------
    list
        The created axes, in the order of ``cells_all``.
    """
    frame = load_cv_base_frame(cells_all)
    ax_isi = []
    for f, cell in enumerate(cells_all):
        axi = plt.subplot(grid_isi[f])
        frame_cell = frame[(frame['cell'] == cell)]
        # todo: double-check the EODf handling here
        if len(stack_spikes) > 0:
            spikes = []
            hists = []
            for sp in range(len(stack_spikes[f].keys())):
                spikes.append(np.array(stack_spikes[f][sp]))
                hists.append(np.diff(spikes[-1]) / (1 / eod_frs[f]))
        else:
            spikes = frame_cell.spikes.iloc[0]
            spikes_all, hists, frs_calc, cont_spikes = load_spikes(spikes, eod_frs[f])
        remove_yticks(axi)
        axi.spines['left'].set_visible(False)
        for hh, h in enumerate(hists):
            # Fade later trials, but clamp so alpha stays within matplotlib's valid
            # [0, 1] range; the unclamped 1 - 0.05 * hh goes negative after 20 trials.
            trial_alpha = max(1 - 0.05 * hh, 0.05)
            try:
                axi.hist(h, bins=100, color='blue', alpha=float(trial_alpha))
            except:
                # Debug hook: drops into an IPython shell on failure.
                print('hist i')
                embed()
        ax_isi.append(axi)
        axi.spines['left'].set_visible(False)
        remove_yticks(axi)
        if f == len(cells_all) - 1:
            axi.set_xlabel('EODf multiple')
    return ax_isi
def group_the_certain_group(grouped, DF2_desired, DF1_desired):
    """Return the position of the group whose (DF1, DF2) pair is closest to the desired one.

    Each key of ``grouped.groups`` is a tuple whose third element is a
    ``(DF1, DF2)`` pair.  The distance is ``|DF1 - DF1_desired| +
    |DF2 - DF2_desired|``, rounded to two decimals.  The index of the first
    minimum is returned.
    """
    group_keys = list(grouped.groups.keys())
    df1_values = np.array([key[2][0] for key in group_keys])
    df2_values = np.array([key[2][1] for key in group_keys])
    distance = np.abs(df1_values - DF1_desired) + np.abs(df2_values - DF2_desired)
    return np.argmin(np.round(distance, 2))
def extract_waves(variant, cell, stimulus_length, deltat, eod_fr, a_fr, a_fe, eod_fe, e, eod_fj, a_fj, phase_r=0,
                  nfft_for_morph=4068 * 4, phase_e=0):
    """Build receiver, emitter and jammer EOD waveforms and the resulting stimuli.

    For each fish, if ``'receiver'``, ``'emitter'`` or ``'jammer'`` is contained
    in ``variant``, the waveform is loaded via ``load_waves`` (using the EOD of
    ``cell``).  Otherwise it is a sine ``amplitude * sin(2*pi*f*t + phase)``
    sampled every ``deltat`` for ``stimulus_length``.  The emitter frequency
    is ``eod_fe[e]``.  The jammer sine uses ``phase_e`` as its phase.

    Returns
    -------
    tuple
        ``(eod_fish_j, time, time_fish_r, eod_fish_r, time_fish_e, eod_fish_e,
        time_fish_sam, eod_fish_sam, stimulus_am, stimulus_sam)``.
        ``stimulus_am`` is the sum of the three waveforms.  ``stimulus_sam`` is
        the receiver EOD multiplied by ``1 + a_fe * sin(2*pi*(eod_fe[e] -
        eod_fr)*t)``.
    """
    if 'receiver' in variant:
        time, time_fish_r, eod_fish_r, ff_first, eod_fr_data_first, pp_first_not_log, eod_fish_r_first, p_array_new_first, f_new_first = load_waves(
            nfft_for_morph, cell, a_fr=a_fr, stimulus_length=stimulus_length, sampling=1 / deltat, eod_fr=eod_fr)
    else:
        time = np.arange(0, stimulus_length, deltat)
        time_fish_r = time * 2 * np.pi * eod_fr
        eod_fish_r = a_fr * np.sin(time_fish_r + phase_r)
    if 'emitter' in variant:
        time, time_fish_e, eod_fish_e, ff_first, eod_fr_data_first, pp_first_not_log, eod_fish_r_first, p_array_new_first, f_new_first = load_waves(
            nfft_for_morph, cell, a_fr=a_fe, stimulus_length=stimulus_length, sampling=1 / deltat, eod_fr=eod_fe[e])
    else:
        time = np.arange(0, stimulus_length, deltat)
        time_fish_e = time * 2 * np.pi * eod_fe[e]
        eod_fish_e = a_fe * np.sin(time_fish_e + phase_e)
    if 'jammer' in variant:
        # Bug fix: the loaded jammer waveform used to be stored in
        # time_fish_e/eod_fish_e, overwriting the emitter and leaving eod_fish_j
        # undefined (NameError below).
        time, time_fish_j, eod_fish_j, ff_first, eod_fr_data_first, pp_first_not_log, eod_fish_r_first, p_array_new_first, f_new_first = load_waves(
            nfft_for_morph, cell, a_fr=a_fj, stimulus_length=stimulus_length, sampling=1 / deltat, eod_fr=eod_fj)
    else:
        time = np.arange(0, stimulus_length, deltat)
        time_fish_j = time * 2 * np.pi * eod_fj
        eod_fish_j = a_fj * np.sin(time_fish_j + phase_e)
    # Beat (difference frequency) used for the amplitude-modulated stimulus.
    time_fish_sam = time * 2 * np.pi * (eod_fe[e] - eod_fr)
    eod_fish_sam = a_fe * np.sin(time_fish_sam)
    stimulus_am = eod_fish_e + eod_fish_r + eod_fish_j
    stimulus_sam = eod_fish_r * (1 + eod_fish_sam)
    return eod_fish_j, time, time_fish_r, eod_fish_r, time_fish_e, eod_fish_e, time_fish_sam, eod_fish_sam, stimulus_am, stimulus_sam
def plot_shemes3(eod_fish_r, eod_fish_e, eod_fish_j, grid0, time, g=0,
                 waves_present=['receiver', 'emitter', 'jammer', 'all'], sheme_shift=0, title=[]):
    """Draw one stacked panel per wave in column ``g`` of ``grid0``.

    Row ``i + sheme_shift`` shows ``waves_present[i]``:
    - ``'receiver'`` (f0, grey),
    - ``'emitter'`` (f1, orange, preceded by a '+' sign),
    - ``'jammer'`` (f2, purple, preceded by a '+' sign),
    - ``'all'``: the running sum of the waves listed before it (grey,
      preceded by '=').

    All panels show the first 50 ms, without spines or ticks.  In column 0 each
    row gets a y-label ('f0', 'f1', 'f2', 'Stimulus' by row index for the first
    four rows, otherwise by wave name).

    Returns the last axes created, or ``[]`` if ``waves_present`` is empty.
    """
    summed = np.zeros(len(eod_fish_r))
    window = [0, 0.05]
    components = {
        'receiver': (eod_fish_r, 'grey'),
        'emitter': (eod_fish_e, 'orange'),
        'jammer': (eod_fish_j, 'purple'),
    }
    row_labels = {0: 'f0', 1: 'f1', 2: 'f2', 3: 'Stimulus'}
    wave_labels = {'receiver': 'f0', 'emitter': 'f1', 'jammer': 'f2', 'all': 'Stimulus'}
    ax = []
    for row, wave in enumerate(waves_present):
        ax = plt.subplot(grid0[row + sheme_shift, g])
        if wave in components:
            trace, colour = components[wave]
            if wave == 'receiver':
                if title:
                    plt.title(title)
            else:
                ax.text(0.5, 1.01, '$+$', va='center', ha='center', transform=ax.transAxes, fontsize=20)
            plt.plot(time, trace, color=colour)
            summed += trace
            plt.ylim(-1.1, 1.1)
            plt.xlim(window)
            if wave != 'jammer':
                ax.spines['bottom'].set_visible(False)
        elif wave == 'all':
            ax.text(0.5, 1.25, '$=$', va='center', ha='center', transform=ax.transAxes, fontsize=20)
            plt.plot(time, summed, color='grey')
            plt.ylim(-1.2, 1.2)
            plt.xlim(window)
        if g == 0:
            # The row-index label takes precedence over the wave-name label.
            label = row_labels.get(row, wave_labels.get(wave))
            if label is not None:
                plt.ylabel(label)
        ax.show_spines('')
        ax.set_xticks([])
        ax.set_yticks([])
    return ax
def plot_shemes4(eod_fish_r, eod_fish_e, eod_fish_j, grid0, time, ylim=[-1.1, 1.1], g=0,
                 waves_present=['receiver', 'emitter', 'jammer', 'all'], eod_fr=700, xlim=[0, 0.05],
                 color_am2='purple', extracted=False, extracted2=False, color_am='black',
                 title_top=False, title=[]):
    """Plot the summed stimulus of the selected waves into a single panel ``grid0[0, g]``.

    ``'receiver'``, ``'emitter'`` and ``'jammer'`` add their waveform to a
    running sum.  ``'all'`` plots that sum (thin grey line).  With
    ``extracted`` it also plots the ``extract_am`` output of the sum.  With
    ``extracted2`` it also plots ``extract_am`` applied once more to that
    output (presumably the amplitude modulation and its second-order
    envelope; TODO confirm).  x-limits are applied only when ``xlim`` is
    non-empty; y-limits are always ``ylim``.

    Returns the panel's axes.
    """
    summed = np.zeros(len(eod_fish_r))
    ax = plt.subplot(grid0[0, g])
    addends = {'receiver': eod_fish_r, 'emitter': eod_fish_e, 'jammer': eod_fish_j}
    for position, wave in enumerate(waves_present):
        if title_top and position == 0:
            # In matplotlib, ha/va name the side of the text box that sits on the
            # anchor point. Anchored at the axes' top-right corner, ha='right' and
            # va='bottom' place the title just above the axes, right-aligned.
            ax.text(1, 1, title, va='bottom', ha='right', transform=ax.transAxes)
        if wave in addends:
            summed += addends[wave]
        elif wave == 'all':
            sampling = 1 / time[1]
            envelope, _ = extract_am(summed, time, norm=False, sampling=sampling, eodf=eod_fr,
                                     emb=False, extract='')
            plt.plot(time, summed, color='grey', linewidth=0.5)
            if extracted:
                plt.plot(time, envelope, color=color_am, linewidth=1)
            if extracted2:
                envelope2, _ = extract_am(envelope, time, norm=False, sampling=sampling, eodf=eod_fr,
                                          emb=False, extract='')
                plt.plot(time, envelope2, color=color_am2, linewidth=1)
            if len(xlim) > 0:
                plt.xlim(xlim)
            plt.ylim(ylim)
        if g == 0 and (wave == 'all' or position == 3):
            plt.ylabel('stimulus')
        ax.show_spines('')
        ax.set_xticks([])
        ax.set_yticks([])
    return ax
def motivation_small_roc(ylim=[-1.25, 1.25], c1=10, dfs=['m1', 'm2'], mult_type='_multsorted2_', top=0.94, devs=['2'],
                         figsize=None, end='0', cut_matrix='malefemale', chose_score='mean_nrs',
                         detections=['AllTrialsIndex'], sorted_on='LocalReconst0.2Norm'):
    """Plot the motivation figure with stimulus schemes, spike/PSD panels and ROC panels for one cell.

    The cell ('2022-01-28-ah-invivo-1'), the desired frequency differences
    (DF1_desired=[-99], DF2_desired=[-175]) and the second contrast (c2=10)
    are hard-coded below. ``chose_mat_max_value`` may adjust the DF lists. For
    every resulting DF pair and every entry of ``detections`` a new figure is
    created containing:

    * four stimulus schemes (receiver alone, +intruder, +female, +both), drawn
      with ``plot_shemes4``;
    * spike/PSD panels for the two mean types in ``means_here``, drawn with
      ``plot_arrays_ROC_psd_single3`` (only for the first detection, ``d == 0``);
    * two ROC columns drawn with ``roc_part`` (plot groups 0 and 1).

    Each figure is saved with ``save_visualization``.

    Parameters
    ----------
    ylim : list
        y-limits of the stimulus-scheme axes.
    c1 : int
        Contrast of the first stimulus. Trials are filtered on ``c1`` and ``c2``.
    dfs, mult_type, devs, chose_score :
        Forwarded to ``chose_mat_max_value`` and ``roc_part``.
    top : float
        Top margin of the main GridSpec.
    figsize : tuple or None
        Figure size. ``None`` uses matplotlib's current default (presumably
        configured by ``default_settings``).
    end : str
        Only used by ``load_cells_three`` when the hard-coded cell list is empty.
    cut_matrix : str
        Forwarded to ``roc_part``. ``chose_mat_max_value`` always receives 'cut'.
    detections : list of str
        Detection names. Each one produces a figure.
    sorted_on : str
        Forwarded to ``save_arrays_susept`` and ``roc_part``.

    Returns
    -------
    str
        Title string built for the last figure (it is not drawn).
        NOTE(review): raises UnboundLocalError if no figure is produced.
    """
    plot_style()
    default_settings(column=2, length=3.7)  # 3.3ts=12, ls=12, fs=12
    show = True
    # mean_type = '_MeanTrialsIndexPhaseSort_Min0.25sExcluded_'
    datasets, data_dir = find_all_dir_cells()
    # '2022-01-27-ab-invivo-1', ] # ,'2022-01-28-ah-invivo-1', '2022-01-28-af-invivo-1', ]
    autodefine = '_dfchosen_closest_first_'
    # This list is replaced by the hard-coded cell further below.
    cells = ['2021-08-03-ac-invivo-1']  ##'2021-08-03-ad-invivo-1',,[10, ][5 ]
    # c1s = [10] # 1, 10,
    # c2s = [10]
    c2 = 10
    # detections = ['MeanTrialsIndexPhaseSort'] # ['AllTrialsIndex'] # ,'MeanTrialsIndexPhaseSort''DetectionAnalysis''_MeanTrialsPhaseSort'
    # detections = ['AllTrialsIndex'] # ['_MeanTrialsIndexPhaseSort_Min0.25sExcluded_extended_eod_loc_synch']
    # phase_sorting = ''#'PhaseSort'
    eodftype = '_psdEOD_'
    indices = ['_allindices_']
    chirps = [
        '']  # alternatives: '_ChirpsDelete3_', '_ChirpsDelete2_', '_ChirpsDelete_', '_ChirpsCache_'
    extract = ''  # '_globalmax_'
    # Never true here, because `cells` is non-empty above.
    if len(cells) < 1:
        data_dir, cells = load_cells_three(end, data_dir=data_dir, datasets=datasets)
    _, _, _ = restrict_cell_type(cells, 'p-units')
    start = 'min'  #
    cells = ['2022-01-28-ah-invivo-1']
    DF2_desired = [-175]
    DF1_desired = [-99]
    for c, cell in enumerate(cells):
        if not c2:
            contrasts = [10, 5, 3, 1]
        else:
            contrasts = [c2]
        # NOTE(review): `contrasts1` is only bound in the else-branch. If c2
        # were falsy, the loop over `contrasts1` below would raise a NameError
        # (c2 is hard-coded to 10 above).
        if not c2:
            contrasts = [10, 5, 3, 1]
        else:
            contrasts1 = [c1]
        # NOTE(review): this rebinds `c`, the cell index from the outer loop.
        for c, contrast in enumerate(contrasts):
            contrast_small = 'c2'
            contrast_big = 'c1'
            for contrast1 in contrasts1:
                for devname_orig in devs:
                    datapoints = [1000]
                    for _ in datapoints:
                        ################################
                        # prepare DF1 desired
                        # chose_score = 'auci02_012-auci_base_01'
                        # this has to match the selection
                        # here we actually do not want autodefine,
                        # but rather a diagonal selection
                        divergnce, fr, pivot_chosen, max_val, max_x, max_y, mult, DF1_desired, DF2_desired, min_y, min_x, min_val, diff_cut = chose_mat_max_value(
                            DF1_desired, DF2_desired, '', mult_type, eodftype, indices, cell, contrast_small,
                            contrast_big, contrast1, dfs, start, devname_orig, contrast, autodefine=autodefine,
                            cut_matrix='cut', chose_score=chose_score)  # chose_score = 'auci02_012-auci_base_01'
                        DF1_desired = DF1_desired  # [::-1]
                        DF2_desired = DF2_desired  # [::-1]
                        # ROC part
                        # `b` is [] in the public code version or when the nix file is missing.
                        b = load_b_public(c, cell, data_dir)
                        mt_sorted = predefine_grouping_frame(b, eodftype=eodftype, cell_name=cell)
                        mt_sorted = mt_sorted[(mt_sorted['c2'] == c2) & (mt_sorted['c1'] == c1)]
                        for gg in range(len(DF1_desired)):
                            DF1_desired_ROC = [DF1_desired[gg]]
                            DF2_desired_ROC = [DF2_desired[gg]]
                            t3 = time.time()
                            # all trials in one group
                            grouped = mt_sorted.groupby(
                                ['c1', 'c2', 'm1, m2'],
                                as_index=False)
                            grouped_mean = chose_certain_group(DF1_desired[gg],
                                                               DF2_desired[gg], grouped,
                                                               several=True, emb=False,
                                                               concat=True)
                            # groups sorted by repro tag
                            # todo: possibly re-save the tuples from the csv right here
                            grouped = mt_sorted.groupby(
                                ['c1', 'c2', 'm1, m2', 'repro_tag_id'],
                                as_index=False)
                            grouped_orig = chose_certain_group(DF1_desired[gg],
                                                               DF2_desired[gg],
                                                               grouped,
                                                               several=True)
                            # group_mean[0] is the group key (indexed below as c1, c2, (m1, m2));
                            # group_mean[1] is the concatenated trial frame.
                            group_mean = [grouped_orig[0][0], grouped_mean]
                            for d, detection in enumerate(detections):
                                mean_type = '_' + detection  # + '_' + minsetting + '_' + extend_trials + concat
                                if figsize:
                                    fig = plt.figure(figsize=figsize)
                                else:
                                    fig = plt.figure()
                                # Column 0: schemes + spike/PSD panels; columns 1 and 2: ROC panels.
                                grid = gridspec.GridSpec(1, 3, wspace=0.35, hspace=0.5, left=0.05, top=top,
                                                         bottom=0.14,
                                                         right=0.95, width_ratios=[4.2, 1,
                                                                                   1])  # height_ratios = [1,6]bottom=0.25, top=0.8,
                                grid0 = gridspec.GridSpecFromSubplotSpec(3, 1, wspace=0.15, hspace=0.06,
                                                                         subplot_spec=grid[0],
                                                                         height_ratios=[0.4, 3, 3])  # height_ratios=hr,
                                grid_sheme = gridspec.GridSpecFromSubplotSpec(1, 4, wspace=0.15, hspace=0.35,
                                                                              subplot_spec=grid0[0])
                                xlim = [0, 100]
                                fr_end = divergence_title_add_on(group_mean, fr[gg], autodefine)
                                ###########################################
                                # Synthetic stimuli (0.3 s at 40 kHz) built from the group's mean EODf and DFs.
                                stimulus_length = 0.3
                                deltat = 1 / 40000
                                eodf = np.mean(group_mean[1].eodf)
                                eod_fr = eodf
                                a_fr = 1
                                eod_fe = eodf + np.mean(
                                    group_mean[1].DF2)  # data.eodf.iloc[0] + 10 # cell_model.eode.iloc[0]
                                a_fe = group_mean[0][1] / 100
                                eod_fj = eodf + np.mean(
                                    group_mean[1].DF1)  # data.eodf.iloc[0] + 50 # cell_model.eodj.iloc[0]
                                a_fj = group_mean[0][0] / 100
                                variant_cell = 'no'  # 'receiver_emitter_jammer'
                                eod_fish_j, time_array, time_fish_r, eod_fish_r, time_fish_e, eod_fish_e, time_fish_sam, eod_fish_sam, stimulus_am, stimulus_sam = extract_waves(
                                    variant_cell, '',
                                    stimulus_length, deltat, eod_fr, a_fr, a_fe, [eod_fe], 0, eod_fj, a_fj)
                                jammer_name = 'female'
                                titles = ['receiver ',
                                          '+' + 'intruder ',
                                          '+' + jammer_name,
                                          '+' + jammer_name + '+intruder',
                                          []]  ##'receiver + ' + 'receiver + receiver
                                gs = [0, 1, 2, 3, 4]
                                waves_presents = [['receiver', '', '', 'all'],
                                                  ['receiver', 'emitter', '', 'all'],
                                                  ['receiver', '', 'jammer', 'all'],
                                                  ['receiver', 'emitter', 'jammer', 'all'],
                                                  ]  # ['', '', '', ''],['receiver', '', '', 'all'],
                                symbols = ['', '', '', '', '']
                                # seconds -> ms, to match xlim = [0, 100]
                                time_array = time_array * 1000
                                color0_burst = 'darkgreen'
                                color01 = 'blue'
                                color02 = 'red'
                                color012 = 'orange'
                                colors_am = ['black', 'black', 'black', 'black']  # color01, color02, color012]
                                extracted = [False, True, True, True]
                                ax_w = []
                                for i in range(len(waves_presents)):
                                    ax = plot_shemes4(eod_fish_r, eod_fish_e, eod_fish_j, grid_sheme, time_array,
                                                      g=gs[i], title_top=True, eod_fr=eod_fr,
                                                      waves_present=waves_presents[i], ylim=ylim,
                                                      xlim=xlim, color_am=colors_am[i],
                                                      extracted=extracted[i],
                                                      title=titles[i])  # 'intruder','receiver'#jammer_name
                                    ax_w.append(ax)
                                    if ax != []:
                                        ax.text(1.1, 0.45, symbols[i], fontsize=35, transform=ax.transAxes)
                                    # Disabled manual 20 ms scale bar.
                                    bar = False
                                    if bar:
                                        if i == 0:
                                            ax.plot([0, 20], [ylim[0] + 0.01, ylim[0] + 0.01], color='black')
                                            ax.text(0, -0.16, '20 ms', va='center', fontsize=10,
                                                    transform=ax.transAxes)
                                printing = True
                                if printing:
                                    print('time of arrays plotting: ' + str(time.time() - t3))
                                ##########################################
                                # spike response
                                means_here = ['_MeanTrialsIndexPhaseSort', 'AllTrialsIndex']
                                array_chosen = 1
                                # NOTE(review): this loop rebinds `mean_type` from the detection loop,
                                # so `individual_tag` below ends with the last entry of `means_here`.
                                for m, mean_type in enumerate(means_here):
                                    hr = [0.35, 1.2, 0, 3]
                                    grid_psd = gridspec.GridSpecFromSubplotSpec(4, 4, wspace=0.15, hspace=0.35,
                                                                                subplot_spec=grid0[m + 1],
                                                                                height_ratios=hr, )
                                    if d == 0:  #
                                        ##############################################################
                                        # load plotting arrays
                                        arrays, arrays_original, spikes_pure = save_arrays_susept(
                                            data_dir, cell, c, chirps, devs, extract, group_mean, mean_type,
                                            plot_group=0,
                                            rocextra=False, sorted_on=sorted_on)
                                        # Arrays are passed in the order 0, 2, 1, 3.
                                        fr_isi, ax_ps, ax_as = plot_arrays_ROC_psd_single3(
                                            [arrays[0], arrays[2], arrays[1], arrays[3]],
                                            [arrays_original[0], arrays_original[2], arrays_original[1],
                                             arrays_original[3]], spikes_pure, cell, grid_psd, mean_type,
                                            group_mean, xlim=xlim, row=d * 3,
                                            array_chosen=array_chosen, ylim_log=(-50.5, 3),
                                            color0_burst=color0_burst, xlim_psd=[0, 550],
                                            color01=color01, color02=color02, color012=color012, add_burst_corr=True,
                                            log='')
                                ###################################################################
                                # ROC columns 1 and 2 of `grid`, one per entry of means_here.
                                nrs = [1, 2]
                                for n, nr in enumerate(nrs):
                                    grid2 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.3, hspace=0.4,
                                                                             subplot_spec=grid[nr],
                                                                             height_ratios=[1, 1])
                                    subdevision_nr = 3
                                    dev = '05'
                                    datapoints_way = ['absolut']
                                    color = ['red', 'green', 'lightblue', 'pink', ]
                                    fig = plt.gcf()
                                    plot_group = 0
                                    ranges = [plot_group]
                                    _, _, _, _, _ = roc_part(
                                        titles, devs, group_mean, ranges, fig, subdevision_nr, datapoints,
                                        datapoints_way, color, c, chose_score, cell, DF1_desired_ROC,
                                        DF2_desired_ROC, contrast_small,
                                        contrast_big, contrast1, dfs, start, dev, contrast, grid2[0],
                                        plot_group, autodefine2='_dfchosen_', sorted_on=sorted_on,
                                        cut_matrix=cut_matrix,
                                        mean_type=means_here[n], extract=extract, mult_type=mult_type,
                                        eodftype=eodftype,
                                        indices=indices, c1=c1, c2=c2, autodefine=autodefine)
                                    ax = plt.gca()
                                    ax.set_title(means_here[n])
                                    plot_group = 1
                                    ranges = [plot_group]
                                    _, _, _, _, _ = roc_part(
                                        titles, devs, group_mean, ranges, fig,
                                        subdevision_nr, datapoints,
                                        datapoints_way, color, c, chose_score,
                                        cell, DF1_desired_ROC,
                                        DF2_desired_ROC, contrast_small,
                                        contrast_big, contrast1, dfs, start,
                                        dev, contrast, grid2[1],
                                        plot_group, sorted_on=sorted_on, autodefine2='_dfchosen_',
                                        cut_matrix=cut_matrix,
                                        mean_type=means_here[n], extract=extract, mult_type=mult_type,
                                        eodftype=eodftype,
                                        indices=indices, c1=c1, c2=c2, autodefine=autodefine)
                                # NOTE(review): '$\%$' is an invalid escape sequence in a normal
                                # string (SyntaxWarning on Python 3.12+); a raw string would be equivalent.
                                suptitle = cell + ' c1: ' + str(group_mean[0][0]) + '$\%$ m1: ' + str(
                                    group_mean[0][2][0]) + ' DF1: ' + str(
                                    group_mean[1]['DF1, DF2'].iloc[0][0]) + ' c2: ' + str(
                                    group_mean[0][1]) + '$\%$ m2: ' + str(
                                    group_mean[0][2][1]) + ' DF2: ' + str(
                                    group_mean[1]['DF1, DF2'].iloc[0][1]) + ' Trials nr ' + str(
                                    len(group_mean[1])) + fr_end
                                individual_tag = 'DF1' + str(DF1_desired[gg]) + 'DF2' + str(
                                    DF2_desired[gg]) + cell + '_c1_' + str(c1) + '_c2_' + str(c2) + mean_type
                                # `axes` is built but not used further.
                                axes = []
                                axes.append(ax_w)
                                axes.extend(np.transpose(ax_as))
                                axes.append(np.transpose(ax_ps))
                                fig.tag(ax_w, xoffs=-1.5, yoffs=1.4)
                                save_visualization(individual_tag=individual_tag, show=show, pdf=True)
    return suptitle
def load_b_public(c, cell, data_dir):
    """Return the first nix block of ``cell``'s recording, or ``[]`` if it is unavailable.

    Nothing is loaded when ``find_code_vs_not()`` reports the 'public' code
    version, or when the file path from ``find_nix_full_path`` does not exist.
    Otherwise the nix file is opened read-only and its first block is returned.
    The file is not closed here; presumably the returned block needs it open.
    """
    version_comp, _, _, _, _ = find_code_vs_not()
    if version_comp == 'public':
        return []
    full_path = find_nix_full_path(c, cell, data_dir)
    # todo: this maybe also has to be fixed
    if not os.path.exists(full_path):
        return []
    print('do ' + cell)
    nix_file = nix.File.open(full_path, nix.FileMode.ReadOnly)
    return nix_file.blocks[0]
def motivation_all(ylim=[-1.25, 1.25], c1=10, dfs=['m1', 'm2'], mult_type='_multsorted2_', top=0.94, devs=['2'],
                   figsize=None, save=True, end='0', chose_score='mean_nrs',
                   detections=['AllTrialsIndex'], sorted_on='LocalReconst0.2Norm'):
    """Plot the full motivation figure: coherence panels, stimulus schemes and spike/PSD responses.

    The cell '2021-08-03-ac-invivo-1', DF1_desired=[133], DF2_desired=[-33]
    and c2=10 are hard-coded below. ``chose_mat_max_value`` may adjust the DF
    lists. For every resulting DF pair and every entry of ``detections``, the
    spike arrays are loaded and ``check_var_substract_method`` is tried on them.
    Then a figure is created containing:

    * a top row of coherence spectra, drawn with ``plt_coherences``;
    * four stimulus schemes, drawn with ``plot_shemes4``;
    * spike/PSD panels, drawn with ``plot_arrays_ROC_psd_single3`` when ``d == 0``.

    The figure is saved with ``save_visualization`` if ``save`` is true.

    Parameters
    ----------
    ylim : list
        y-limits of the stimulus-scheme axes.
    c1 : int
        First contrast. Trials are filtered on ``c1`` and ``c2``.
    dfs, mult_type, devs, chose_score :
        Forwarded to ``chose_mat_max_value`` / ``save_arrays_susept``.
    top : float
        Top margin of the main GridSpec.
    figsize : tuple or None
        Figure size. ``None`` uses matplotlib's current default.
    save : bool
        Whether to call ``save_visualization``.
    end : str
        Only used by ``load_cells_three`` when the hard-coded cell list is empty.
    detections : list of str
        Detection names. Each one produces a figure.
    sorted_on : str
        Forwarded to ``save_arrays_susept``.

    Returns
    -------
    str
        Title string built for the last figure (it is not drawn).
    """
    plot_style()
    default_settings(column=2, length=6.7)  # 3.3ts=12, ls=12, fs=12
    show = True
    # mean_type = '_MeanTrialsIndexPhaseSort_Min0.25sExcluded_'
    datasets, data_dir = find_all_dir_cells()
    DF2_desired = [-33]
    DF1_desired = [133]
    autodefine = '_dfchosen_closest_first_'
    cells = ['2021-08-03-ac-invivo-1']  ##'2021-08-03-ad-invivo-1',,[10, ][5 ]
    # c1s = [10] # 1, 10,
    # c2s = [10]
    c2 = 10
    # detections = ['MeanTrialsIndexPhaseSort'] # ['AllTrialsIndex'] # ,'MeanTrialsIndexPhaseSort''DetectionAnalysis''_MeanTrialsPhaseSort'
    # detections = ['AllTrialsIndex'] # ['_MeanTrialsIndexPhaseSort_Min0.25sExcluded_extended_eod_loc_synch']
    # phase_sorting = ''#'PhaseSort'
    eodftype = '_psdEOD_'
    indices = ['_allindices_']
    chirps = [
        '']  # alternatives: '_ChirpsDelete3_', '_ChirpsDelete2_', '_ChirpsDelete_', '_ChirpsCache_'
    extract = ''  # '_globalmax_'
    # Never true here, because `cells` is non-empty above.
    if len(cells) < 1:
        data_dir, cells = load_cells_three(end, data_dir=data_dir, datasets=datasets)
    _, _, _ = restrict_cell_type(cells, 'p-units')
    start = 'min'  #
    cells = ['2021-08-03-ac-invivo-1']
    for c, cell in enumerate(cells):
        contrasts = [c2]
        # NOTE(review): this rebinds `c`, the cell index from the outer loop.
        for c, contrast in enumerate(contrasts):
            contrast_small = 'c2'
            contrast_big = 'c1'
            contrasts1 = [c1]
            for contrast1 in contrasts1:
                for devname_orig in devs:
                    datapoints = [1000]
                    for _ in datapoints:
                        ################################
                        # prepare DF1 desired
                        # chose_score = 'auci02_012-auci_base_01'
                        # this has to match the selection
                        # here we actually do not want autodefine,
                        # but rather a diagonal selection
                        divergnce, fr, pivot_chosen, max_val, max_x, max_y, mult, DF1_desired, DF2_desired, min_y, min_x, min_val, diff_cut = chose_mat_max_value(
                            DF1_desired, DF2_desired, '', mult_type, eodftype, indices, cell, contrast_small,
                            contrast_big, contrast1, dfs, start, devname_orig, contrast, autodefine=autodefine,
                            cut_matrix='cut', chose_score=chose_score)  # chose_score = 'auci02_012-auci_base_01'
                        DF1_desired = DF1_desired  # [::-1]
                        DF2_desired = DF2_desired  # [::-1]
                        # embed()
                        #######################################
                        # ROC part
                        # The result of this call is discarded.
                        _, _, _, _, _ = find_code_vs_not()
                        b = load_b_public(c, cell, data_dir)
                        mt_sorted = predefine_grouping_frame(b, eodftype=eodftype, cell_name=cell)
                        mt_sorted = mt_sorted[(mt_sorted['c2'] == c2) & (mt_sorted['c1'] == c1)]
                        for gg in range(len(DF1_desired)):
                            t3 = time.time()
                            ax_w = []
                            ###################
                            # all trials in one group
                            grouped = mt_sorted.groupby(
                                ['c1', 'c2', 'm1, m2'],
                                as_index=False)
                            grouped_mean = chose_certain_group(DF1_desired[gg],
                                                               DF2_desired[gg], grouped,
                                                               several=True, emb=False,
                                                               concat=True)
                            # groups sorted by repro tag
                            # todo: possibly re-save the tuples from the csv right here
                            grouped = mt_sorted.groupby(
                                ['c1', 'c2', 'm1, m2', 'repro_tag_id'],
                                as_index=False)
                            grouped_orig = chose_certain_group(DF1_desired[gg],
                                                               DF2_desired[gg],
                                                               grouped,
                                                               several=True)
                            ###################
                            # group_mean[0] is the group key (indexed below as c1, c2, (m1, m2));
                            # group_mean[1] is the concatenated trial frame.
                            group_mean = [grouped_orig[0][0], grouped_mean]
                            for d, detection in enumerate(detections):
                                mean_type = '_' + detection  # + '_' + minsetting + '_' + extend_trials + concat
                                arrays, arrays_original, spikes_pure = save_arrays_susept(
                                    data_dir, cell, c, chirps, devs, extract, group_mean, mean_type, plot_group=0,
                                    rocextra=False, sorted_on=sorted_on)
                                # Check whether the variance-subtraction method works for this single point.
                                try:
                                    check_var_substract_method(spikes_pure)
                                except:
                                    print('var checking not possible')
                                if figsize:
                                    fig = plt.figure(figsize=figsize)
                                else:
                                    fig = plt.figure()
                                # Row 0: coherence panels; row 1: schemes + spike/PSD panels.
                                grid = gridspec.GridSpec(2, 3, wspace=0.7, hspace=0.35, left=0.075, top=top,
                                                         bottom=0.1, height_ratios=[1, 2],
                                                         right=0.935)  # height_ratios = [1,6]bottom=0.25, top=0.8,
                                hr = [1, 0.35, 1.2, 0, 3, ]  # 1
                                # several coherence plots
                                # NOTE(review): this rebinds the loop index `d`, and also `data_dir`
                                # and `devs`. `d` becomes the last index of plt_coherences' own devs
                                # loop whenever a coherence table was found, so the `if d == 0`
                                # check below and later iterations see the rebound values.
                                ax_w, d, data_dir, devs = plt_coherences(ax_w, d, devs, grid)
                                # part with the power spectra
                                grid0 = gridspec.GridSpecFromSubplotSpec(5, 4, wspace=0.15, hspace=0.35,
                                                                         subplot_spec=grid[1, :],
                                                                         height_ratios=hr)
                                xlim = [0, 100]
                                fr_end = divergence_title_add_on(group_mean, fr[gg], autodefine)
                                ###########################################
                                # Synthetic stimuli (0.3 s at 40 kHz) built from the group's mean EODf and DFs.
                                stimulus_length = 0.3
                                deltat = 1 / 40000
                                eodf = np.mean(group_mean[1].eodf)
                                eod_fr = eodf
                                a_fr = 1
                                eod_fe = eodf + np.mean(
                                    group_mean[1].DF2)  # data.eodf.iloc[0] + 10 # cell_model.eode.iloc[0]
                                a_fe = group_mean[0][1] / 100
                                eod_fj = eodf + np.mean(
                                    group_mean[1].DF1)  # data.eodf.iloc[0] + 50 # cell_model.eodj.iloc[0]
                                a_fj = group_mean[0][0] / 100
                                variant_cell = 'no'  # 'receiver_emitter_jammer'
                                eod_fish_j, time_array, time_fish_r, eod_fish_r, time_fish_e, eod_fish_e, time_fish_sam, eod_fish_sam, stimulus_am, stimulus_sam = extract_waves(
                                    variant_cell, '',
                                    stimulus_length, deltat, eod_fr, a_fr, a_fe, [eod_fe], 0, eod_fj, a_fj)
                                jammer_name = 'female'
                                cocktail_names = False
                                if cocktail_names:
                                    titles = ['receiver ',
                                              '+' + 'intruder ',
                                              '+' + jammer_name,
                                              '+' + jammer_name + '+intruder',
                                              []]  ##'receiver + ' + 'receiver + receiver
                                else:
                                    titles = title_motivation()  ##'receiver + ' + 'receiver + receiver
                                gs = [0, 1, 2, 3, 4]
                                waves_presents = [['receiver', '', '', 'all'],
                                                  ['receiver', 'emitter', '', 'all'],
                                                  ['receiver', '', 'jammer', 'all'],
                                                  ['receiver', 'emitter', 'jammer', 'all'],
                                                  ]  # ['', '', '', ''],['receiver', '', '', 'all'],
                                symbols = ['', '', '', '', '']
                                # seconds -> ms, to match xlim = [0, 100]
                                time_array = time_array * 1000
                                color0_burst = 'darkgreen'
                                color01 = 'green'
                                color02 = 'red'
                                color012 = 'orange'
                                colors_am = ['black', 'black', 'black', 'black']  # color01, color02, color012]
                                extracted = [False, True, True, True]
                                for i in range(len(waves_presents)):
                                    ax = plot_shemes4(eod_fish_r, eod_fish_e, eod_fish_j, grid0, time_array,
                                                      g=gs[i], title_top=True, eod_fr=eod_fr,
                                                      waves_present=waves_presents[i], ylim=ylim,
                                                      xlim=xlim, color_am=colors_am[i],
                                                      extracted=extracted[i],
                                                      title=titles[i])  # 'intruder','receiver'#jammer_name
                                    ax_w.append(ax)
                                    if ax != []:
                                        ax.text(1.1, 0.45, symbols[i], fontsize=35, transform=ax.transAxes)
                                    # Disabled manual 20 ms scale bar.
                                    bar = False
                                    if bar:
                                        if i == 0:
                                            ax.plot([0, 20], [ylim[0] + 0.01, ylim[0] + 0.01], color='black')
                                            ax.text(0, -0.16, '20 ms', va='center', fontsize=10,
                                                    transform=ax.transAxes)
                                printing = True
                                if printing:
                                    print(time.time() - t3)
                                # spike response
                                array_chosen = 1
                                if d == 0:  #
                                    # Arrays are passed in the order 0, 2, 1, 3.
                                    _, _, _ = plot_arrays_ROC_psd_single3(
                                        [arrays[0], arrays[2], arrays[1], arrays[3]],
                                        [arrays_original[0], arrays_original[2], arrays_original[1],
                                         arrays_original[3]], spikes_pure, cell, grid0, mean_type,
                                        group_mean, xlim=xlim, row=1 + d * 3,
                                        array_chosen=array_chosen,
                                        color0_burst=color0_burst, color01=color01, color02=color02,
                                        color012=color012)
                                # NOTE(review): '$\%$' is an invalid escape sequence in a normal
                                # string (SyntaxWarning on Python 3.12+); a raw string would be equivalent.
                                suptitle = cell + ' c1: ' + str(group_mean[0][0]) + '$\%$ m1: ' + str(
                                    group_mean[0][2][0]) + ' DF1: ' + str(
                                    group_mean[1]['DF1, DF2'].iloc[0][0]) + ' c2: ' + str(
                                    group_mean[0][1]) + '$\%$ m2: ' + str(
                                    group_mean[0][2][1]) + ' DF2: ' + str(
                                    group_mean[1]['DF1, DF2'].iloc[0][1]) + ' Trials nr ' + str(
                                    len(group_mean[1])) + fr_end
                                individual_tag = 'DF1' + str(DF1_desired[gg]) + 'DF2' + str(
                                    DF2_desired[gg]) + cell + '_c1_' + str(c1) + '_c2_' + str(c2) + mean_type
                                # `axes` is built but not used further.
                                axes = []
                                axes.append(ax_w)
                                # First three axes (coherence panels) and the rest get separate tag offsets.
                                fig.tag(ax_w[0:3], xoffs=-2.3, yoffs=1.7)
                                fig.tag(ax_w[3::], xoffs=-1.9, yoffs=1.4)
                                if save:
                                    save_visualization(individual_tag=individual_tag, show=show, pdf=True)
    return suptitle
def check_var_substract_method(spikes_pure, sampling=40000, sigma=0.0005):
    """Print the per-trial variance combination ``012 - 01 - 02 + 0`` of smoothed spike trains.

    For every condition in ``spikes_pure``, each trial is handled in two steps:

    1. Its spike times are converted with ``cr_spikes_mat`` into a spike
       matrix of length ``int(last_spike / 1000 * sampling)``.
    2. That matrix is smoothed with a Gaussian kernel and its variance is stored.

    Trial by trial, ``var(012) - var(control_01) - var(control_02) + var(base_0)``
    is then printed together with its mean.

    Parameters
    ----------
    spikes_pure : dict
        Maps condition names (must contain '012', 'control_01', 'control_02'
        and 'base_0') to lists of spike-time arrays. Times are divided by 1000
        before use, so they are presumably in ms.
    sampling : float
        Sampling rate in Hz used for the spike matrix (default 40000, as before).
    sigma : float
        Standard deviation of the Gaussian kernel in seconds (default 0.0005).

    Returns
    -------
    list
        The per-trial variance combinations. The previous implementation
        returned None; callers ignore the result.

    Raises
    ------
    KeyError
        If one of the four required condition keys is missing.
    """
    variances = {}
    for key, trials in spikes_pure.items():
        variances[key] = []
        for spikes in trials:
            spikes_mat = cr_spikes_mat(spikes / 1000, sampling,
                                       int(spikes[-1] / 1000 * sampling))
            smoothed = gaussian_filter(spikes_mat, sigma=sigma * sampling)
            variances[key].append(np.var(smoothed))
    # Combine trial by trial. zip stops at the shortest condition, so unequal
    # trial counts cannot index out of range. The previous loop used the trial
    # count of whichever key happened to be iterated last.
    var_vals = [v012 - v01 - v02 + v0 for v012, v01, v02, v0 in
                zip(variances['012'], variances['control_01'],
                    variances['control_02'], variances['base_0'])]
    # This would be fine if it were stable, but that is not clear.
    print('single var vals:' + str(var_vals))
    print('mean of single var vals:' + str(np.mean(var_vals)))
    return var_vals
def plt_coherences(ax_w, d, devs, grid):
    """Plot coherence spectra of a hard-coded example cell into the top row of ``grid``.

    The coherence table of cell '2021-08-03-ab-invivo-1' (dev '05') is loaded
    with ``load_coherence_file``. For every amplitude (descending) and every
    file name, an axes ``grid[0, a]`` is created. It shows:

    * the squared 'coherence_s' column (label 'SR');
    * 'coherence_r' (label '$\\sqrt{RR}$');
    * 'coherence_r_exp' (label 'RR$_{exp}$').

    Parameters
    ----------
    ax_w : list
        List that the created axes are appended to. Axes are appended once per
        (amplitude, file name) pair, so an axes can appear more than once.
    d : int
        Returned unchanged unless a coherence table was plotted (see Returns).
    devs : list
        Returned unchanged unless a non-empty coherence table was found.
    grid : GridSpec
        Grid whose row 0 receives the coherence axes.

    Returns
    -------
    tuple
        ``(ax_w, d, data_dir, devs)``. Callers rely on the rebinding of ``d``
        and ``devs``:

        * ``devs`` becomes ``['05']`` once a non-empty table is found.
        * ``d`` is the index of the last entry of ``devs`` that was plotted.
        * ``data_dir`` comes from ``find_all_dir_cells()``.
    """
    _, data_dir = find_all_dir_cells()
    # Only this example cell is plotted.
    data_names = ['2021-08-03-ab-invivo-1']
    names = ['coherence_s', 'coherence_r', 'coherence_r_exp']
    # Raw string: '\s' is an invalid escape sequence in a normal string literal.
    labels = ['SR', r'$\sqrt{RR}$', 'RR$_{exp}$']
    colors = ['black', 'grey', 'brown']
    linestyles = ['-', '-', '--']
    for data_name in data_names:
        frame = load_coherence_file(data_name, '05')
        if len(frame) > 0:
            amps = np.sort(frame.amp.unique())[::-1]
            file_names = frame.file_name.unique()
            devs = ['05']  # original
            # Load each deviation's table once instead of once per (amp, file) pair.
            frames_dev = {dev: load_coherence_file(data_name, dev) for dev in devs}
            for a, amp in enumerate(amps):
                for file_name in file_names:
                    ax = plt.subplot(grid[0, a])
                    ax.set_ylim(0, 1)
                    for d, dev in enumerate(devs):
                        frame_dev = frames_dev[dev]
                        if len(frame_dev) > 0:
                            frame_cell = frame_dev[
                                (frame_dev.file_name == file_name) & (frame_dev.amp == amp)]
                            for name, label, color, linestyle in zip(names, labels, colors, linestyles):
                                values = frame_cell[name]
                                if 'coherence_s' in name:
                                    # The SR coherence is squared before plotting.
                                    values = values ** 2
                                ax.plot(frame_cell['f'], values,
                                        label=label, color=color, linestyle=linestyle)
                    amp_name = amp if amp < 1 else int(amp)
                    ax.set_title('Contrast=' + str(amp_name))
                    if a == 0:
                        ax.legend(loc=(0.75, 0.75))
                    ax.set_xlabel('Frequency [Hz]')
                    ax.set_ylabel('Coherence')
                    xlim = ax.get_xlim()
                    ax.set_xlim(0, xlim[-1])
                    ax_w.append(ax)
    return ax_w, d, data_dir, devs
def load_coherence_file(data_name, dev):
    """Load the coherence table of one cell and deviation, keeping a local copy.

    The table is looked up at
    ``load_folder_name('calc_RAM') + '/calc_coherence-coherence__cell_<data_name>_dev_<dev>.csv'``.
    A copy named ``find_load_function() + <file name>`` is read if it exists.
    Otherwise the original file is read and written to that copy.

    Parameters
    ----------
    data_name : str
        Cell name used in the file name.
    dev : str
        Deviation tag used in the file name (e.g. '05').

    Returns
    -------
    pandas.DataFrame
        The table, read with ``index_col=0``.
    """
    save_name = load_folder_name(
        'calc_RAM') + '/calc_coherence-coherence__cell_' + data_name + '_dev_' + dev + '.csv'
    load_function = find_load_function()
    local_copy = load_function + save_name.split('/')[-1]
    if os.path.exists(local_copy):
        return pd.read_csv(local_copy, index_col=0)
    frame = pd.read_csv(save_name, index_col=0)
    frame.to_csv(local_copy)
    # Return the frame just read instead of reading the same file a second time.
    return frame
def motivation_small(ylim=[-1.25, 1.25], c1=10, dfs=['m1', 'm2'],
                     mult_type='_multsorted2_',
                     top=0.94,
                     devs=['2'], figsize=None, save=True, end='0',
                     chose_score='mean_nrs', detections=['AllTrialsIndex'], sorted_on='LocalReconst0.2Norm'):
    """Plot the small motivation figure: stimulus schemes plus spike/PSD responses for one cell.

    The cell '2021-08-03-ac-invivo-1', DF1_desired=[133], DF2_desired=[-33]
    and c2=10 are hard-coded below. ``chose_mat_max_value`` may adjust the DF
    lists. For every resulting DF pair and every entry of ``detections``, the
    spike arrays are loaded with ``save_arrays_susept`` and a figure is created
    containing:

    * four stimulus schemes, drawn with ``plot_shemes4``;
    * spike/PSD panels, drawn with ``plot_arrays_ROC_psd_single3`` (only for ``d == 0``).

    The figure is saved with ``save_visualization`` if ``save`` is true.

    Parameters
    ----------
    ylim : list
        y-limits of the stimulus-scheme axes.
    c1 : int
        First contrast. Trials are filtered on ``c1`` and ``c2``.
    dfs, mult_type, devs, chose_score :
        Forwarded to ``chose_mat_max_value`` / ``save_arrays_susept``.
    top : float
        Top margin of the GridSpec.
    figsize : tuple or None
        Figure size. ``None`` uses matplotlib's current default.
    save : bool
        Whether to call ``save_visualization``.
    end : str
        Only used by ``load_cells_three`` when the hard-coded cell list is empty.
    detections : list of str
        Detection names. Each one produces a figure.
    sorted_on : str
        Forwarded to ``save_arrays_susept``.

    Returns
    -------
    str
        Title string built for the last figure (it is not drawn).
        NOTE(review): if ``detections`` has more than one entry, ``ax_as`` and
        ``ax_ps`` only come from the first (``d == 0``) iteration.
    """
    plot_style()
    default_settings(column=2, length=3.5)  # 3.3ts=12, ls=12, fs=12
    show = True
    datasets, data_dir = find_all_dir_cells()
    # '2022-01-27-ab-invivo-1', ] # ,'2022-01-28-ah-invivo-1', '2022-01-28-af-invivo-1', ]
    DF2_desired = [-33]
    DF1_desired = [133]
    autodefine = '_dfchosen_closest_first_'
    cells = ['2021-08-03-ac-invivo-1']  ##'2021-08-03-ad-invivo-1',,[10, ][5 ]
    c2 = 10
    # phase_sorting = ''#'PhaseSort'
    eodftype = '_psdEOD_'
    indices = ['_allindices_']
    chirps = [
        '']  # alternatives: '_ChirpsDelete3_', '_ChirpsDelete2_', '_ChirpsDelete_', '_ChirpsCache_'
    extract = ''  # '_globalmax_'
    # Never true here, because `cells` is non-empty above.
    if len(cells) < 1:
        data_dir, cells = load_cells_three(end, data_dir=data_dir, datasets=datasets)
    _, _, _ = restrict_cell_type(cells, 'p-units')
    start = 'min'  #
    cells = ['2021-08-03-ac-invivo-1']
    for c, cell in enumerate(cells):
        contrasts = [c2]
        # NOTE(review): this rebinds `c`, the cell index from the outer loop.
        for c, contrast in enumerate(contrasts):
            contrast_small = 'c2'
            contrast_big = 'c1'
            contrasts1 = [c1]
            for contrast1 in contrasts1:
                for devname_orig in devs:
                    datapoints = [1000]
                    for _ in datapoints:
                        ################################
                        # prepare DF1 desired
                        # chose_score = 'auci02_012-auci_base_01'
                        # this has to match the selection
                        # here we actually do not want autodefine,
                        # but rather a diagonal selection
                        divergnce, fr, pivot_chosen, max_val, max_x, max_y, mult, DF1_desired, DF2_desired, min_y, min_x, min_val, diff_cut = chose_mat_max_value(
                            DF1_desired, DF2_desired, '', mult_type, eodftype, indices, cell, contrast_small,
                            contrast_big, contrast1, dfs, start, devname_orig, contrast, autodefine=autodefine,
                            cut_matrix='cut', chose_score=chose_score)  # chose_score = 'auci02_012-auci_base_01'
                        DF1_desired = DF1_desired  # [::-1]
                        DF2_desired = DF2_desired  # [::-1]
                        # ROC part
                        # The result of this call is discarded.
                        _, _, _, _, _ = find_code_vs_not()
                        b = load_b_public(c, cell, data_dir)
                        mt_sorted = predefine_grouping_frame(b, eodftype=eodftype, cell_name=cell)
                        mt_sorted = mt_sorted[(mt_sorted['c2'] == c2) & (mt_sorted['c1'] == c1)]
                        for gg in range(len(DF1_desired)):
                            t3 = time.time()
                            # all trials in one group
                            grouped = mt_sorted.groupby(
                                ['c1', 'c2', 'm1, m2'],
                                as_index=False)
                            grouped_mean = chose_certain_group(DF1_desired[gg],
                                                               DF2_desired[gg], grouped,
                                                               several=True, emb=False,
                                                               concat=True)
                            ###################
                            # groups sorted by repro tag
                            # todo: possibly re-save the tuples from the csv right here
                            grouped = mt_sorted.groupby(
                                ['c1', 'c2', 'm1, m2', 'repro_tag_id'],
                                as_index=False)
                            grouped_orig = chose_certain_group(DF1_desired[gg],
                                                               DF2_desired[gg],
                                                               grouped,
                                                               several=True)
                            ###################
                            # group_mean[0] is the group key (indexed below as c1, c2, (m1, m2));
                            # group_mean[1] is the concatenated trial frame.
                            group_mean = [grouped_orig[0][0], grouped_mean]
                            for d, detection in enumerate(detections):
                                mean_type = '_' + detection  # + '_' + minsetting + '_' + extend_trials + concat
                                ##############################################################
                                # load plotting arrays
                                arrays, arrays_original, spikes_pure = save_arrays_susept(
                                    data_dir, cell, c, chirps, devs, extract, group_mean, mean_type, plot_group=0,
                                    rocextra=False, sorted_on=sorted_on)
                                ####################################################
                                if figsize:
                                    fig = plt.figure(figsize=figsize)
                                else:
                                    fig = plt.figure()
                                grid = gridspec.GridSpec(1, 1, wspace=0.7, hspace=0.5, left=0.05, top=top,
                                                         bottom=0.14,
                                                         right=0.95)  # height_ratios = [1,6]bottom=0.25, top=0.8,
                                # Row 0: schemes; rows 1-4: spike/PSD panels (row 3 has zero height).
                                hr = [1, 0.35, 1.2, 0, 3, ]  # 1
                                grid0 = gridspec.GridSpecFromSubplotSpec(5, 4, wspace=0.15, hspace=0.35,
                                                                         subplot_spec=grid[0],
                                                                         height_ratios=hr, )
                                xlim = [0, 100]
                                fr_end = divergence_title_add_on(group_mean, fr[gg], autodefine)
                                ###########################################
                                # Synthetic stimuli (0.3 s at 40 kHz) built from the group's mean EODf and DFs.
                                stimulus_length = 0.3
                                deltat = 1 / 40000
                                eodf = np.mean(group_mean[1].eodf)
                                eod_fr = eodf
                                a_fr = 1
                                eod_fe = eodf + np.mean(
                                    group_mean[1].DF2)  # data.eodf.iloc[0] + 10 # cell_model.eode.iloc[0]
                                a_fe = group_mean[0][1] / 100
                                eod_fj = eodf + np.mean(
                                    group_mean[1].DF1)  # data.eodf.iloc[0] + 50 # cell_model.eodj.iloc[0]
                                a_fj = group_mean[0][0] / 100
                                variant_cell = 'no'  # 'receiver_emitter_jammer'
                                eod_fish_j, time_array, time_fish_r, eod_fish_r, time_fish_e, eod_fish_e, time_fish_sam, eod_fish_sam, stimulus_am, stimulus_sam = extract_waves(
                                    variant_cell, '',
                                    stimulus_length, deltat, eod_fr, a_fr, a_fe, [eod_fe], 0, eod_fj, a_fj)
                                jammer_name = 'female'
                                titles = ['receiver ',
                                          '+' + 'intruder ',
                                          '+' + jammer_name,
                                          '+' + jammer_name + '+intruder',
                                          []]  ##'receiver + ' + 'receiver + receiver
                                gs = [0, 1, 2, 3, 4]
                                waves_presents = [['receiver', '', '', 'all'],
                                                  ['receiver', 'emitter', '', 'all'],
                                                  ['receiver', '', 'jammer', 'all'],
                                                  ['receiver', 'emitter', 'jammer', 'all'],
                                                  ]  # ['', '', '', ''],['receiver', '', '', 'all'],
                                symbols = ['', '', '', '', '']
                                # seconds -> ms, to match xlim = [0, 100]
                                time_array = time_array * 1000
                                color0_burst = 'darkgreen'
                                color01 = 'green'
                                color02 = 'red'
                                color012 = 'orange'
                                colors_am = ['black', 'black', 'black', 'black']  # color01, color02, color012]
                                extracted = [False, True, True, True]
                                ax_w = []
                                for i in range(len(waves_presents)):
                                    ax = plot_shemes4(eod_fish_r, eod_fish_e, eod_fish_j, grid0, time_array,
                                                      g=gs[i], title_top=True, eod_fr=eod_fr,
                                                      waves_present=waves_presents[i], ylim=ylim,
                                                      xlim=xlim, color_am=colors_am[i],
                                                      extracted=extracted[i],
                                                      title=titles[i])  # 'intruder','receiver'#jammer_name
                                    ax_w.append(ax)
                                    if ax != []:
                                        ax.text(1.1, 0.45, symbols[i], fontsize=35, transform=ax.transAxes)
                                    # Disabled manual 20 ms scale bar.
                                    bar = False
                                    if bar:
                                        if i == 0:
                                            ax.plot([0, 20], [ylim[0] + 0.01, ylim[0] + 0.01], color='black')
                                            ax.text(0, -0.16, '20 ms', va='center', fontsize=10,
                                                    transform=ax.transAxes)
                                printing = True
                                if printing:
                                    print(time.time() - t3)
                                # spike response
                                array_chosen = 1
                                if d == 0:  #
                                    # Arrays are passed in the order 0, 2, 1, 3.
                                    fr_isi, ax_ps, ax_as = plot_arrays_ROC_psd_single3(
                                        [arrays[0], arrays[2], arrays[1], arrays[3]],
                                        [arrays_original[0], arrays_original[2], arrays_original[1],
                                         arrays_original[3]], spikes_pure, cell, grid0, mean_type,
                                        group_mean, xlim=xlim, row=1 + d * 3,
                                        array_chosen=array_chosen,
                                        color0_burst=color0_burst, color01=color01, color02=color02,
                                        color012=color012)
                                # NOTE(review): '$\%$' is an invalid escape sequence in a normal
                                # string (SyntaxWarning on Python 3.12+); a raw string would be equivalent.
                                suptitle = cell + ' c1: ' + str(group_mean[0][0]) + '$\%$ m1: ' + str(
                                    group_mean[0][2][0]) + ' DF1: ' + str(
                                    group_mean[1]['DF1, DF2'].iloc[0][0]) + ' c2: ' + str(
                                    group_mean[0][1]) + '$\%$ m2: ' + str(
                                    group_mean[0][2][1]) + ' DF2: ' + str(
                                    group_mean[1]['DF1, DF2'].iloc[0][1]) + ' Trials nr ' + str(
                                    len(group_mean[1])) + fr_end
                                individual_tag = 'DF1' + str(DF1_desired[gg]) + 'DF2' + str(
                                    DF2_desired[gg]) + cell + '_c1_' + str(c1) + '_c2_' + str(c2) + mean_type
                                # `axes` is built but not used further.
                                axes = []
                                axes.append(ax_w)
                                axes.extend(np.transpose(ax_as))
                                axes.append(np.transpose(ax_ps))
                                fig.tag(ax_w, xoffs=-1.5, yoffs=1.4)
                                if save:
                                    save_visualization(individual_tag=individual_tag, show=show, pdf=True)
    return suptitle
def csvReader(filename):
    """Return a ``csv.reader`` over *filename* using a dialect sniffed from the file.

    The dialect is detected from the first 2048 characters. The file is opened
    only once: after sniffing, the stream is rewound and the same handle is
    given to the reader. The previous version opened the file twice and never
    closed the first handle. The returned reader reads lazily, so the file
    stays open while it is being iterated.

    Raises
    ------
    csv.Error
        If the dialect cannot be determined from the sample (the file is
        closed before the error propagates).
    """
    handle = open(filename)
    try:
        sample = handle.read(2048)
        dialect = csv.Sniffer().sniff(sample)
        handle.seek(0)
    except Exception:
        handle.close()
        raise
    return csv.reader(handle, dialect)
def plot_arrays_ROC_psd_single(arrays, arrays_original, spikes_pure, cell, grid0, mean_type,
                               group_mean, rocextra=False, xlim=[0, 100], row=4, way='absolut', color0='green',
                               color0_burst='darkgreen',
                               color01='blue', ylim_log=(-13.5, 3), add_burst_corr=False, color02='red',
                               array_chosen=1, color012='orange'):
    """Plot spike rasters, one response trace and power spectra for four stimulus conditions.

    For every condition ``j`` (spike keys in the order 'base_0', 'control_02',
    'control_01', '012'), two panels are drawn:

    * an eventplot of ``spikes_pure[key]`` in ``grid0[row, j]``;
    * one trial of ``arrays[j]`` in ``grid0[row + 1, j]``.

    Power spectra of the arrays (``mlab.psd``, Fs=40000, NFFT=2**13) are
    collected per condition and drawn with ``plt_single_pds``. Characteristic
    frequencies (baseline rate, |DF1|, |DF2|, their sum and harmonics) are
    marked with ``plt_peaks_several``.

    Parameters
    ----------
    arrays, arrays_original : list
        Per condition, a list of trial arrays. The time axis
        ``np.arange(0, len / 40, 1 / 40)`` implies 40 samples per ms (40 kHz).
        ``arrays_original`` is only used in a branch that is disabled
        (``p_type = '05'``).
    spikes_pure : dict
        Spike times per condition key. They are divided by 1000 for the ISI
        rate, so presumably in ms.
    cell : str
        Cell name, forwarded to ``find_lim_here`` for burst correction.
    grid0 : GridSpec-like
        Grid receiving the panels.
    mean_type : str
        If it contains '_AllTrialsIndex', only trial ``array_chosen`` is
        plotted and enters the PSD. Otherwise trial 0 is plotted and all trials
        enter the PSD.
    group_mean : list
        ``group_mean[1]`` is a DataFrame providing the DF1, DF2 and EODf columns.
    rocextra, way :
        Only dead branches depend on these.
    xlim : list
        Time limits in ms. An empty list disables clipping.
    row : int
        Grid row of the rasters; traces go to ``row + 1``.
    ylim_log : tuple
        y-limits of the PSD axes when plotted in dB.
    add_burst_corr : bool
        Also mark the burst-corrected baseline rate.
    color0, color0_burst, color01, color02, color012 : str
        Panel and peak colors.
    array_chosen : int
        Trial index used for '_AllTrialsIndex'.

    Returns
    -------
    tuple
        ``(fr_isi, ax_ps, ax_as)``:

        * ``fr_isi``: mean baseline rate from ISIs.
        * ``ax_ps``: PSD axes, presumably filled by ``plt_single_pds``.
        * ``ax_as``: ``[raster_ax, trace_ax]`` per condition.
    """
    # Common y-limit: maximum of the first trial of each condition within xlim, plus 30.
    arrs = []
    for a, arr in enumerate(arrays):
        time_array = np.arange(0, len(arrays[a][0]) / 40, 1 / 40)
        if len(xlim) > 0:
            arrs.append(np.array(arr[0])[(time_array > xlim[0]) & (time_array < xlim[-1])])
        else:
            arrs.append(np.array(arr[0]))
    ylim = [-2, np.max(arrs) + 30]
    ps = {}
    p_means = {}
    p_means_all = {}
    ax_ps = []
    key_names = ['base_0', 'control_02', 'control_01', '012']
    names = ['0', '02', '01', '012']  # color012color0, color02, color01,
    colors = ['grey', 'grey', 'grey', 'grey', color0_burst, color0_burst, color0, color0]
    colors_p = [color0, color02, color01, color012, color02, color01, color0_burst, color0_burst, color0, color0]
    xlim_psd = [0, 1000]
    ylim_psd = []  # [-40, 10]
    color_psd = 'black'
    ax_as = []
    for j in range(len(arrays)):
        ax0 = plt.subplot(grid0[row, j])
        ax_a = []
        ax_a.append(ax0)
        # NOTE(review): the full eventplot is redrawn once per trial, so all
        # trials are drawn len(trials) times on top of each other.
        for i in range(len(spikes_pure[key_names[j]])):
            ax0.eventplot(spikes_pure[key_names[j]],
                          color=colors[j])
        ax0.show_spines('')
        ax0.set_xticks([])
        ax0.set_yticks([])
        if len(xlim) > 0:
            ax0.set_xlim(xlim)
        ax00 = plt.subplot(grid0[row + 1, j])
        ax_a.append(ax00)
        # Only one trial array is plotted here.
        time_array = np.arange(0, len(arrays[j][0]) / 40, 1 / 40)
        # embed()
        # All branches below are no-ops.
        if rocextra:
            if '_AllTrialsIndex' in mean_type:
                pass
            else:
                pass
        else:
            pass
        try:
            if '_AllTrialsIndex' in mean_type:
                ax00.plot(time_array, arrays[j][array_chosen], color=colors[j])
            else:
                ax00.plot(time_array, arrays[j][0], color=colors[j])
        except:
            # Debugging fallback: drops into an IPython shell.
            print('array thing')
            embed()
        if 'mult' in way:  # 'mult_minimum','mult_env', 'mult_f1', 'mult_f2'
            pass
        if len(xlim) > 0:
            ax00.set_xlim(xlim)
        ax00.set_ylim(ylim)
        ax00.show_spines('')
        ax00.set_xticks([])
        ax00.set_yticks([])
        if j == 0:
            # Scale bars on the first trace: 20 ms and 500 Hz. Fall back to
            # manual lines if the scalebar axes extension is unavailable.
            length = 20
            plus_here = 5
            try:
                ax00.xscalebar(0.1, -0.02, length, 'ms', va='right', ha='bottom')  ##ylim[0]
                ax00.yscalebar(-0.02, 0.35, 500, 'Hz', va='center', ha='left')
            except:
                ax00.plot([0, length], [ylim[0] + plus_here, ylim[0] + plus_here], color='black')
                ax00.text(0, -0.2, str(length) + ' ms', va='center', fontsize=10,
                          transform=ax00.transAxes)
                if len(xlim) > 0:
                    ax00.plot([xlim[0] + 0.01, xlim[0] + 0.01], [ylim[0], 500],
                              color='black')
                else:
                    ax00.plot([time_array[0] + 0.01, time_array[0] + 0.01], [ylim[0], 500],
                              color='black')
                ax00.text(-0.1, 0.4, ' 500 Hz', rotation=90, va='center', fontsize=10,
                          transform=ax00.transAxes)
        # Compute the corresponding PSDs.
        # The power spectrum variant can be chosen here.
        nfft = 2 ** 13  # 2 ** 18 # 17#16
        p_mean_all_here = []
        if '_AllTrialsIndex' in mean_type:
            range_here = [array_chosen]
        else:
            range_here = range(len(arrays[j]))
        for i in range_here:
            # p_type is fixed, so the 'original' branch is never taken.
            p_type = '05'
            if 'original' in p_type:
                p_mean_all, f = ml.psd(arrays_original[j][i] - np.mean(arrays_original[j][i]), Fs=40000, NFFT=nfft,
                                       noverlap=nfft // 2)  #
            else:
                p_mean_all, f = ml.psd(arrays[j][i] - np.mean(arrays[j][i]), Fs=40000, NFFT=nfft,
                                       noverlap=nfft // 2)  #
            p_mean_all_here.append(p_mean_all)
        p_means_all[names[j]] = p_mean_all_here
        ax_as.append(ax_a)
    # Second pass over all conditions, now that every PSD is available for a common reference value.
    for j in range(len(arrays)):
        log = 'log'  # '' # 'log'#''#'log'#''#
        ref, ax00 = plt_single_pds(nfft, f, p_means, p_means_all[names[j]], ylim_psd, xlim_psd, color_psd, names,
                                   ps,
                                   arrays, ax_ps, grid0, row + 1, j, p_means_all, psd_type='mean_freq', log=log)
        if j == 0:
            if log == 'log':
                ax00.set_ylabel('dB')
            else:
                ax00.set_ylabel('Hz/Hz$^2$')
        ax00.set_xlim(xlim_psd)
    # NOTE(review): everything below runs once, after the loop. `j`, `ax00`,
    # `ref` and `log` hold the values of the last condition, so peaks are
    # marked only on the last PSD axes and the `j == 0` scale bar below only
    # appears when there is a single condition.
    DF1 = group_mean[1].DF1.iloc[-1]
    DF2 = group_mean[1].DF2.iloc[-1]
    # Baseline firing rate per trial from the mean ISI (spike times / 1000).
    fr_isis = []
    if add_burst_corr:
        frs_burst_corr = []
        for i in range(len(spikes_pure['base_0'])):
            fr_isis.append(1 / np.mean(np.diff(spikes_pure['base_0'][i] / 1000)))  # np.mean(fr), fr_calc,
            lim_here = find_lim_here(cell, 'individual')
            print(lim_here)
            eod_fr = group_mean[1].EODf.iloc[i]
            spikes_all = spikes_pure['base_0'][i]
            isi = calc_isi(spikes_all, eod_fr)
            # Burst correction only when some ISI falls below the cell's limit.
            if np.min(isi) < lim_here:
                hists2, spikes_ex, fr_burst_corr = correct_burstiness(isi, spikes_all,
                                                                      [eod_fr] * len(spikes_all),
                                                                      [eod_fr] * len(spikes_all), lim=lim_here,
                                                                      burst_corr='individual')
                frs_burst_corr.append(fr_burst_corr)
            else:
                frs_burst_corr.append(fr_isis[-1])
    else:
        for i in range(len(spikes_pure['base_0'])):
            fr_isis.append(1 / np.mean(np.diff(spikes_pure['base_0'][i] / 1000)))  # np.mean(fr), fr_calc,
    fr_isi = np.nanmean(fr_isis)
    freqs = [fr_isi, np.abs(DF2), np.abs(DF1),
             np.abs(DF1) + np.abs(DF2), 2 * np.abs(DF2), 2 * np.abs(DF1),
             ]
    try:
        # NOTE(review): labels name DF2 'DF1' and vice versa. This is possibly
        # intentional, since callers pass the arrays reordered (0, 2, 1, 3);
        # confirm.
        labels = ['Baseline=' + str(int(np.round(fr_isi))) + 'Hz',
                  'DF1=' + str(DF2) + 'Hz',
                  'DF2=' + str(DF1) + 'Hz',
                  '$|$DF1+DF2$|$=' + str(np.abs(DF1) + np.abs(DF2)) + 'Hz',
                  'DF1$_{H}$=' + str(DF2 * 2) + 'Hz',
                  'DF2$_{H}$=' + str(DF1 * 2) + 'Hz',
                  'fr_burst_corr_individual',
                  'fr_given_burst_corr_individual', 'highest_fr_burst_corr_individual', 'fr', 'fr_given',
                  'highest_fr']  # '$|$DF1-DF2$|$=' + str(np.abs(np.abs(DF1) - np.abs(DF2))) + 'Hz',
    except:
        # Debugging fallback: drops into an IPython shell.
        print('label thing')
        embed()
    # choice[j]: indices into freqs/labels/colors_p that are marked for condition j.
    if add_burst_corr:
        frs_burst_corr_mean = np.nanmean(frs_burst_corr)
        freqs.extend([
            frs_burst_corr_mean])  # np.abs(np.abs(DF1) - np.abs(DF2)),,np.array(np.nanmax(frame_spikes['highest_fr'])),np.array(np.nanmax(frame_spikes['highest_fr_burst_corr_individual']))
        labels.extend(['Baseline_Burstcorr'])
        colors_p.extend(['pink'])
        choice = [[0], [1, 4], [2], [0, 1, 2, 3, 6]]
    else:
        choice = [[0], [1, 4], [2], [0, 1, 2, 3]]
    if log == 'log':
        pp = 10 * np.log10(p_means_all[names[j]] / ref)
        pp_mean = 10 * np.log10(np.mean(p_means_all[names[j]], axis=0) / ref)
    else:
        pp = p_means_all[names[j]]
        pp_mean = np.mean(p_means_all[names[j]], axis=0)
    try:  # todo: with log scaling something else is needed here, because the log changes the values!
        plt_peaks_several(np.array(freqs)[choice[j]], pp, ax00, pp_mean, f, np.array(labels)[choice[j]], j,
                          np.array(colors_p)[choice[j]], add_log=2.5, exact=False, text_extra=True,
                          perc_peaksize=0.08,
                          ms=14, clip_on=True, log=log)  # True
    except:
        # Debugging fallback: drops into an IPython shell.
        print('peaks thing0')
        embed()
    if log == 'log':
        ax00.set_ylim(ylim_log)
    ax00.show_spines('b')
    if j == 0:
        ax00.yscalebar(-0.02, 0.5, 10, 'dB', va='center', ha='left')
    # NOTE(review): in matplotlib >= 3.6, get_shared_y_axes() returns a read-only
    # view without `join`; Axes.sharey is the supported way to share axes.
    ax00.get_shared_y_axes().join(*ax_ps)
    return fr_isi, ax_ps, ax_as
def plot_arrays_ROC_psd_single4(arrays, arrays_original, spikes_pure, cell, grid0, mean_type,
                                group_mean, names=['0', '02', '01', '012'],
                                xlim=[0, 100], row=4, way='absolut', datapoints=1000,
                                xlim_psd=[0, 235], color0='blue', color0_burst='darkgreen',
                                color01='green', ylim_log=(-15, 3), add_burst_corr=False, color02='red',
                                array_chosen=1, text_extra=True, color012_minus='purple', color012='orange', log='log'):
    """Plot raster, rate trace and power spectrum for each stimulus condition.

    Column ``j`` of ``grid0`` shows the spike raster of ``spikes_pure[key_names[j]]``
    (row ``row``), the rate trace of ``arrays[j]`` (row ``row + 1``) and, through
    ``plt_psds_ROC``, the trial-averaged power spectrum with annotated peaks.

    Parameters (selection)
    ----------------------
    arrays, arrays_original : sequence
        Per condition a list of trials (smoothed / unsmoothed). The time axis is
        built as ``len / 40`` with step ``1 / 40``.
    spikes_pure : dict
        Spike times per condition, keyed 'base_0', 'control_02', 'control_01', '012'.
    xlim : list
        Time window for rasters and traces; an empty list disables cropping.
    datapoints : int
        Unused; kept for backward compatibility.

    Returns
    -------
    fr_isi : float
        Baseline firing rate (inverse mean ISI) as computed by ``plt_psds_ROC``.
    ax_ps : list
        Power-spectrum axes; their y-axes are shared.
    ax_as : list
        ``[raster_axis, trace_axis]`` for every condition.
    """
    # Cropped first trials, only used to find a common upper y-limit for the traces.
    arrs = []
    for arr in arrays:
        first_trial = np.array(arr[0])
        time_array = np.arange(0, len(arr[0]) / 40, 1 / 40)
        if len(xlim) > 0:
            first_trial = first_trial[(time_array > xlim[0]) & (time_array < xlim[-1])]
        arrs.append(first_trial)
    # Maximum of the per-condition maxima: equal to np.max(arrs) for equal lengths,
    # but does not fail when the conditions differ in length (ragged input).
    ylim = [-2, max(np.max(a) for a in arrs) + 30]
    ax_ps = []
    key_names = ['base_0', 'control_02', 'control_01', '012']
    colors = ['grey', 'grey', 'grey', 'grey', color0_burst, color0_burst, color0, color0]
    colors_p = [color0, color02, color01, color012, color02, color01, color012_minus, color0_burst, color0_burst,
                color0, color0]
    ylim_psd = []  # [-40, 10]
    color_psd = 'black'
    ax_as = []
    for j in range(len(arrays)):
        ###################################
        # spike raster
        ax0 = plt.subplot(grid0[row, j])
        plt_spikes_ROC(ax0, colors[j], spikes_pure[key_names[j]], xlim)
        #########################################
        # rate trace
        ax00 = plt.subplot(grid0[row + 1, j])
        time_array = plt_traces_ROC(array_chosen, arrays, ax00, colors, group_mean, j, mean_type,
                                    way, xlim, ylim)
        # Collect this column's axes here; appending in the PSD loop below stored the
        # last column's axes once per condition.
        ax_as.append([ax0, ax00])
    var_val = np.var(arrays[3]) - np.var(arrays[2]) - np.var(arrays[1]) + np.var(arrays[0])
    print('mean var val:' + str(var_val))
    p_means_all = {}
    for j in range(len(arrays)):
        ########################################
        # power spectra of every condition (stored in p_means_all)
        f, nfft = get_psds_ROC(array_chosen, arrays, arrays_original, j, mean_type, names, p_means_all)
    # plot the psds
    ps = {}
    p_means = {}
    ax00, fr_isi = plt_psds_ROC(arrays, ax00, ax_ps, cell, colors_p, f, grid0, group_mean, nfft, p_means, p_means_all,
                                ps, row, spikes_pure,
                                time_array, names=names, color_psd=color_psd, add_burst_corr=add_burst_corr,
                                xlim_psd=xlim_psd,
                                ylim_log=ylim_log, ylim_psd=ylim_psd, log=log, text_extra=text_extra)
    shared_y = ax00.get_shared_y_axes()
    if hasattr(shared_y, 'join'):
        shared_y.join(*ax_ps)
    else:
        # matplotlib >= 3.8 returns a read-only GrouperView without ``join``.
        for ax_p in ax_ps[1:]:
            ax_p.sharey(ax_ps[0])
    return fr_isi, ax_ps, ax_as
def plot_arrays_ROC_psd_single3(arrays, arrays_original, spikes_pure, cell, grid0, mean_type,
                                group_mean, names=['0', '02', '01', '012'],
                                xlim=[0, 100], row=4, way='absolut', datapoints=1000,
                                xlim_psd=[0, 235], color0='blue', color0_burst='darkgreen',
                                color01='green', ylim_log=(-15, 3), add_burst_corr=False, color02='red',
                                array_chosen=1, text_extra=True, color012_minus='purple', color012='orange', log='log'):
    """Plot raster, rate trace and power spectrum for each stimulus condition.

    Column ``j`` of ``grid0`` shows the spike raster of ``spikes_pure[key_names[j]]``
    (row ``row``), the rate trace of ``arrays[j]`` (row ``row + 1``) and, through
    ``plt_psds_ROC``, the trial-averaged power spectrum with annotated peaks.

    Parameters (selection)
    ----------------------
    arrays, arrays_original : sequence
        Per condition a list of trials (smoothed / unsmoothed). The time axis is
        built as ``len / 40`` with step ``1 / 40``.
    spikes_pure : dict
        Spike times per condition, keyed 'base_0', 'control_02', 'control_01', '012'.
    xlim : list
        Time window for rasters and traces; an empty list disables cropping.
    datapoints : int
        Unused; kept for backward compatibility.

    Returns
    -------
    fr_isi : float
        Baseline firing rate (inverse mean ISI) as computed by ``plt_psds_ROC``.
    ax_ps : list
        Power-spectrum axes; their y-axes are shared.
    ax_as : list
        ``[raster_axis, trace_axis]`` for every condition.
    """
    # Cropped first trials, only used to find a common upper y-limit for the traces.
    arrs = []
    for arr in arrays:
        first_trial = np.array(arr[0])
        time_array = np.arange(0, len(arr[0]) / 40, 1 / 40)
        if len(xlim) > 0:
            first_trial = first_trial[(time_array > xlim[0]) & (time_array < xlim[-1])]
        arrs.append(first_trial)
    # Maximum of the per-condition maxima: equal to np.max(arrs) for equal lengths,
    # but does not fail when the conditions differ in length (ragged input).
    ylim = [-2, max(np.max(a) for a in arrs) + 30]
    ax_ps = []
    key_names = ['base_0', 'control_02', 'control_01', '012']
    colors = ['grey', 'grey', 'grey', 'grey', color0_burst, color0_burst, color0, color0]
    colors_p = [color0, color02, color01, color012, color02, color01, color012_minus, color0_burst, color0_burst,
                color0, color0]
    ylim_psd = []  # [-40, 10]
    color_psd = 'black'
    ax_as = []
    for j in range(len(arrays)):
        ###################################
        # spike raster
        ax0 = plt.subplot(grid0[row, j])
        plt_spikes_ROC(ax0, colors[j], spikes_pure[key_names[j]], xlim)
        #########################################
        # rate trace
        ax00 = plt.subplot(grid0[row + 1, j])
        time_array = plt_traces_ROC(array_chosen, arrays, ax00, colors, group_mean, j, mean_type,
                                    way, xlim, ylim)
        # Collect this column's axes here; appending in the PSD loop below stored the
        # last column's axes once per condition.
        ax_as.append([ax0, ax00])
    var_val = np.var(arrays[3]) - np.var(arrays[2]) - np.var(arrays[1]) + np.var(arrays[0])
    print('mean var val:' + str(var_val))
    p_means_all = {}
    for j in range(len(arrays)):
        ########################################
        # power spectra of every condition (stored in p_means_all)
        f, nfft = get_psds_ROC(array_chosen, arrays, arrays_original, j, mean_type, names, p_means_all)
    # plot the psds
    ps = {}
    p_means = {}
    ax00, fr_isi = plt_psds_ROC(arrays, ax00, ax_ps, cell, colors_p, f, grid0, group_mean, nfft, p_means, p_means_all,
                                ps, row, spikes_pure,
                                time_array, names=names, color_psd=color_psd, add_burst_corr=add_burst_corr,
                                xlim_psd=xlim_psd,
                                ylim_log=ylim_log, ylim_psd=ylim_psd, log=log, text_extra=text_extra)
    shared_y = ax00.get_shared_y_axes()
    if hasattr(shared_y, 'join'):
        shared_y.join(*ax_ps)
    else:
        # matplotlib >= 3.8 returns a read-only GrouperView without ``join``.
        for ax_p in ax_ps[1:]:
            ax_p.sharey(ax_ps[0])
    return fr_isi, ax_ps, ax_as
def motivation_all_small_stim(dev_desired = '1',ylim=[-1.25, 1.25], c1=10, dfs=['m1', 'm2'], mult_type='_multsorted2_', top=0.94, devs=['2'],
                              figsize=None, redo=False, save=True, end='0', cut_matrix='malefemale', chose_score='mean_nrs',
                              a_fr=1, restrict='modulation', adapt='adaptoffsetallall2', step=str(30),
                              detections=['AllTrialsIndex'], variant='no', sorted_on='LocalReconst0.2Norm'):
    """Build the motivation figure for one cell and one DF combination.

    The cell ('2021-08-03-ac-invivo-1'), the contrast c2 (= 10) and the DF pair
    (DF1_desired = [1.2], DF2_desired = [0.95]) are hard-coded below. The matching
    recordings are grouped and one figure is drawn with one column per stimulus
    condition (receiver alone, + emitter, + jammer, + emitter and jammer):

    * a stimulus scheme per condition (plot_shemes4),
    * spike raster, smoothed rate trace and power spectrum per condition
      (plot_arrays_ROC_psd_single3).

    Parameters
    ----------
    dev_desired : str
        Smoothing key forwarded to save_arrays_susept. In the disabled
        ``base_several`` branch it is also used as the Gaussian sigma in ms.
    ylim : list
        y-limits of the stimulus scheme axes.
    c1 : int
        Contrast used to filter ``mt_sorted['c1']``; also written into the save tag.
    dfs, mult_type, chose_score
        Only used by chose_mat_max_value, which is disabled
        (``extra_f_calculatoin = False``).
    top : float
        Top margin of the outer GridSpec.
    devs : list of str
        Iterated as ``devname_orig`` and forwarded to save_arrays_susept.
    figsize : tuple or None
        Figure size; the default figure size is used when falsy.
    save : bool
        If True, save_visualization writes the figure (pdf=True).
    end : str
        Only used by load_cells_three when ``cells`` is empty, which cannot
        happen with the hard-coded cell list.
    detections : list of str
        Each entry defines ``mean_type = '_' + detection``.
    sorted_on : str
        Forwarded to save_arrays_susept.

    NOTE(review): redo, cut_matrix, restrict, adapt, step and variant are not read
    in this body, and a_fr is overwritten with 1 before it is used.
    """
    autodefines = [
        'triangle_diagonal_fr']  # ['triangle_fr', 'triangle_diagonal_fr', 'triangle_df2_fr','triangle_df2_eodf''triangle_df1_eodf', ] # ,'triangle_df2_fr''triangle_df1_fr','_triangle_diagonal__fr',]
    cells = ['2021-08-03-ac-invivo-1']  ##'2021-08-03-ad-invivo-1',,[10, ][5 ]
    c1s = [10]  # 1, 10,
    c2s = [10]
    plot_style()
    default_figsize(column=2, length=3.3) #6.7 ts=12, ls=12, fs=12
    show = True
    # Several DF pairs are assigned in sequence; only the last assignment
    # before use (further below) takes effect.
    DF2_desired = [0.8]
    DF1_desired = [0.87]
    DF2_desired = [-0.23]
    DF1_desired = [0.94]
    # mean_type = '_MeanTrialsIndexPhaseSort_Min0.25sExcluded_'
    extract = ''
    datasets, data_dir = find_all_dir_cells()
    cells = ['2022-01-28-ah-invivo-1']  # , '2022-01-28-af-invivo-1', '2022-01-28-ab-invivo-1',
    # '2022-01-27-ab-invivo-1', ] # ,'2022-01-28-ah-invivo-1', '2022-01-28-af-invivo-1', ]
    append_others = 'apend_others'  # '#'apend_others'#'apend_others'#'apend_others'##'apend_others'
    autodefine = '_DFdesired_'
    autodefine = 'triangle_diagonal_fr'  # ['triangle_fr', 'triangle_diagonal_fr', 'triangle_df2_fr','triangle_df2_eodf''triangle_df1_eodf', ] # ,'triangle_df2_fr''triangle_df1_fr','_triangle_diagonal__fr',]
    DF2_desired = [-33]
    DF1_desired = [133]
    autodefine = '_dfchosen_closest_'
    autodefine = '_dfchosen_closest_first_'
    cells = ['2021-08-03-ac-invivo-1']  ##'2021-08-03-ad-invivo-1',,[10, ][5 ]
    # c1s = [10] # 1, 10,
    # c2s = [10]
    minsetting = 'Min0.25sExcluded'
    # contrast of the second stimulus; used to filter mt_sorted['c2'] and in the save tag
    c2 = 10
    # detections = ['MeanTrialsIndexPhaseSort'] # ['AllTrialsIndex'] # ,'MeanTrialsIndexPhaseSort''DetectionAnalysis''_MeanTrialsPhaseSort'
    # detections = ['AllTrialsIndex'] # ['_MeanTrialsIndexPhaseSort_Min0.25sExcluded_extended_eod_loc_synch']
    extend_trials = ''  # alternatives: 'extended' or ''; the original author noted it is unclear what this setting does
    # phase_sorting = ''#'PhaseSort'
    eodftype = '_psdEOD_'
    concat = ''  # 'TrialsConcat'
    indices = ['_allindices_']
    chirps = [
        '']  # '_ChirpsDelete3_',,'_ChirpsDelete3_'','','',''#'_ChirpsDelete3_'#''#'_ChirpsDelete3_'#'#'_ChirpsDelete2_'#''#'_ChirpsDelete_'#''#'_ChirpsDelete_'#''#'_ChirpsDelete_'#''#'_ChirpsCache_'
    extract = ''  # '_globalmax_'
    devs_savename = ['original', '05']  # ['05']#####################
    # control = pd.read_pickle(
    #     load_folder_name(
    #         'calc_model') + '/modell_all_cell_no_sinz3_afe0.1__afr1__afj0.1__length1.5_adaptoffsetallall2___stepefish' + step + 'Hz_ratecorrrisidual35__modelbigfit_nfft4096.pkl')
    if len(cells) < 1:
        data_dir, cells = load_cells_three(end, data_dir=data_dir, datasets=datasets)
        cells, p_units_cells, pyramidals = restrict_cell_type(cells, 'p-units')
    # default_settings(fs=8)
    start = 'min'  #
    cells = ['2021-08-03-ac-invivo-1']
    tag_cells = []
    # NOTE(review): ``c`` is reused by the inner contrast loop below, so the
    # load_b_public / save_arrays_susept calls receive the contrast index (always 0
    # here), not the cell index.
    for c, cell in enumerate(cells):
        counter_pic = 0
        contrasts = [c2]
        tag_cell = []
        for c, contrast in enumerate(contrasts):
            contrast_small = 'c2'
            contrast_big = 'c1'
            contrasts1 = [c1]
            for contrast1 in contrasts1:
                for devname_orig in devs:
                    datapoints = [1000]
                    # NOTE(review): ``d`` is shadowed by the detection loop further down;
                    # ``row=1 + d * 3`` and ``if d == 0`` use the detection index.
                    for d in datapoints:
                        ################################
                        # prepare DF1 desired
                        # chose_score = 'auci02_012-auci_base_01'
                        # this has to match the selection
                        # we do not actually want autodefine here,
                        # but rather a diagonal configuration
                        extra_f_calculatoin = False
                        if extra_f_calculatoin:
                            divergnce, fr, pivot_chosen, max_val, max_x, max_y, mult, DF1_desired, DF2_desired, min_y, min_x, min_val, diff_cut = chose_mat_max_value(
                                DF1_desired, DF2_desired, '', mult_type, eodftype, indices, cell, contrast_small,
                                contrast_big, contrast1, dfs, start, devname_orig, contrast, autodefine=autodefine,
                                cut_matrix='cut', chose_score=chose_score)  # chose_score = 'auci02_012-auci_base_01'
                        # final DF pair used for the figure (overrides all assignments above)
                        DF1_desired = [1.2]#DF1_desired # [::-1]
                        DF2_desired = [0.95]#DF2_desired # [::-1]
                        #embed()
                        #######################################
                        # ROC part
                        # fr, celltype = get_fr_from_info(cell, data_dir[c])
                        version_comp, subfolder, mod_name_slash, mod_name, subfolder_path = find_code_vs_not()
                        b = load_b_public(c, cell, data_dir)
                        mt_sorted = predefine_grouping_frame(b, eodftype=eodftype, cell_name=cell)
                        counter_waves = 0
                        # keep only recordings with the requested contrasts
                        mt_sorted = mt_sorted[(mt_sorted['c2'] == c2) & (mt_sorted['c1'] == c1)]
                        for gg in range(len(DF1_desired)):
                            # try:
                            t3 = time.time()
                            # except:
                            #    print('time thing')
                            #    embed()
                            ax_w = []
                            ###################
                            # all trials in one
                            grouped = mt_sorted.groupby(
                                ['c1', 'c2', 'm1, m2'],
                                as_index=False)
                            # try:
                            grouped_mean = chose_certain_group(DF1_desired[gg],
                                                               DF2_desired[gg], grouped,
                                                               several=True, emb=False,
                                                               concat=True)
                            # except:
                            #    print('grouped thing')
                            #    embed()
                            ###################
                            # groups sorted by repro tag
                            # todo: possibly convert the tuples read from the csv right here
                            grouped = mt_sorted.groupby(
                                ['c1', 'c2', 'm1, m2', 'repro_tag_id'],
                                as_index=False)
                            grouped_orig = chose_certain_group(DF1_desired[gg],
                                                               DF2_desired[gg],
                                                               grouped,
                                                               several=True)
                            gr_trials = len(grouped_orig)
                            ###################
                            groups_variants = [grouped_mean]
                            # [group key of the first repro-tag group, concatenated frame of all matching trials]
                            group_mean = [grouped_orig[0][0], grouped_mean]
                            for d, detection in enumerate(detections):
                                mean_type = '_' + detection  # + '_' + minsetting + '_' + extend_trials + concat
                                ##############################################################
                                # load plotting arrays
                                arrays, arrays_original, spikes_pure = save_arrays_susept(
                                    data_dir, cell, c, chirps, devs, extract, group_mean, mean_type, plot_group=0,
                                    rocextra=False, sorted_on=sorted_on, dev_desired = dev_desired)
                                ####################################################
                                ####################################################
                                # check for this single point whether the variance-subtraction method works
                                try:
                                    check_var_substract_method(spikes_pure)
                                except:
                                    print('var checking not possible')
                                # fig = plt.figure()
                                # grid = gridspec.GridSpec(2, 1, wspace=0.7, left=0.05, top=0.95, bottom=0.15,
                                #                         right=0.98)
                                if figsize:
                                    fig = plt.figure(figsize=figsize)
                                else:
                                    fig = plt.figure()
                                grid = gridspec.GridSpec(1, 1, wspace=0.7, hspace=0.35, left=0.055, top=top,
                                                         bottom=0.15,
                                                         right=0.935)  # height_ratios=[1, 2], height_ratios = [1,6]bottom=0.25, top=0.8,
                                # row heights: schemes, rasters, traces, (empty), power spectra
                                hr = [1, 0.35, 1.2, 0, 3, ]  # 1
                                ##########################################################################
                                # several coherence plot
                                # frame_psd = pd.read_pickle(load_folder_name('calc_RAM')+'/noise_data11_nfft1sec_original__StimPreSaved4__first__CutatBeginning_0.05_s_NeurDelay_0.005_s_2021-08-03-ab-invivo-1.pkl')
                                # frame_psd = pd.read_pickle(load_folder_name('calc_RAM') + '/noise_data11_nfft1sec_original__StimPreSaved4__first__CutatBeginning_0.05_s_NeurDelay_0.005_s_2021-08-03-ab-invivo-1.pkl')
                                coh = False
                                if coh:
                                    ax_w, d, data_dir, devs = plt_coherences(ax_w, d, devs, grid)
                                # ax_cohs = plt.subplot(grid[0,1])
                                ##########################################################################
                                # part with the power spectra
                                grid0 = gridspec.GridSpecFromSubplotSpec(5, 4, wspace=0.15, hspace=0.35,
                                                                         subplot_spec=grid[:, :],
                                                                         height_ratios=hr)
                                xlim = [0, 100]
                                ###########################################
                                # synthetic stimulus for the scheme row: receiver EOD plus the
                                # two stimulus frequencies eodf + mean(DF2) and eodf + mean(DF1)
                                stimulus_length = 0.3
                                deltat = 1 / 40000
                                eodf = np.mean(group_mean[1].eodf)
                                eod_fr = eodf
                                a_fr = 1
                                eod_fe = eodf + np.mean(
                                    group_mean[1].DF2)  # data.eodf.iloc[0] + 10   # cell_model.eode.iloc[0]
                                a_fe = group_mean[0][1] / 100
                                eod_fj = eodf + np.mean(
                                    group_mean[1].DF1)  # data.eodf.iloc[0] + 50  # cell_model.eodj.iloc[0]
                                a_fj = group_mean[0][0] / 100
                                variant_cell = 'no'  # 'receiver_emitter_jammer'
                                print('f0' + str(eod_fr))
                                print('f1'+str(eod_fe))
                                print('f2' + str(eod_fj))
                                eod_fish_j, time_array, time_fish_r, eod_fish_r, time_fish_e, eod_fish_e, time_fish_sam, eod_fish_sam, stimulus_am, stimulus_sam = extract_waves(
                                    variant_cell, '',
                                    stimulus_length, deltat, eod_fr, a_fr, a_fe, [eod_fe], 0, eod_fj, a_fj)
                                jammer_name = 'female'
                                cocktail_names = False
                                if cocktail_names:
                                    titles = ['receiver ',
                                              '+' + 'intruder ',
                                              '+' + jammer_name,
                                              '+' + jammer_name + '+intruder',
                                              []]  ##'receiver + ' + 'receiver + receiver
                                else:
                                    titles = title_motivation()
                                gs = [0, 1, 2, 3, 4]
                                # which waves are drawn in each scheme column
                                waves_presents = [['receiver', '', '', 'all'],
                                                  ['receiver', 'emitter', '', 'all'],
                                                  ['receiver', '', 'jammer', 'all'],
                                                  ['receiver', 'emitter', 'jammer', 'all'],
                                                  ]  # ['', '', '', ''],['receiver', '', '', 'all'],
                                # ['receiver', '', 'jammer', 'all'],
                                # ['receiver', 'emitter', '', 'all'],'receiver', 'emitter', 'jammer',
                                symbols = ['']  # '$+$', '$-$', '$-$', '$=$',
                                symbols = ['', '', '', '', '']
                                # seconds -> milliseconds, to match xlim
                                time_array = time_array * 1000
                                color01, color012, color01_2, color02, color0_burst, color0 = colors_suscept_paper_dots()
                                colors_am = ['black', 'black', 'black', 'black']  # color01, color02, color012]
                                extracted = [False, True, True, True]
                                extracted2 = [False, False, False, False]
                                for i in range(len(waves_presents)):
                                    ax = plot_shemes4(eod_fish_r, eod_fish_e, eod_fish_j, grid0, time_array,
                                                      g=gs[i], title_top=True, eod_fr=eod_fr,
                                                      waves_present=waves_presents[i], ylim=ylim,
                                                      xlim=xlim, color_am=colors_am[i],
                                                      color_am2 = color01_2, extracted=extracted[i], extracted2=extracted2[i],
                                                      title=titles[i])  # 'intruder','receiver'#jammer_name
                                    ax_w.append(ax)
                                    if ax != []:
                                        ax.text(1.1, 0.45, symbols[i], fontsize=35, transform=ax.transAxes)
                                    bar = False
                                    if bar:
                                        if i == 0:
                                            ax.plot([0, 20], [ylim[0] + 0.01, ylim[0] + 0.01], color='black')
                                            ax.text(0, -0.16, '20 ms', va='center', fontsize=10,
                                                    transform=ax.transAxes)
                                printing = True
                                if printing:
                                    print(time.time() - t3)
                                ##########################################
                                # spike response
                                array_chosen = 1
                                if d == 0:  #
                                    #embed()
                                    # spike count / 0.5 per baseline trial; fr is not used further below
                                    frs = []
                                    for i in range(len(spikes_pure['base_0'])):
                                        #duration = spikes_pure['base_0'][i][-1] / 1000
                                        duration = 0.5
                                        fr = len(spikes_pure['base_0'][i])/duration
                                        frs.append(fr)
                                    fr = np.mean(frs)
                                    #embed()
                                    # disabled: rebuild the baseline from 100-unit spike windows and smooth it here
                                    base_several = False
                                    if base_several:
                                        spikes_new = []
                                        for i in range(len(spikes_pure['base_0'])):
                                            duration = 100
                                            duration_full = 101#501
                                            dur = np.arange(0, duration_full, duration)
                                            for d_nr in range(len(dur) - 1):
                                                #embed()
                                                spikes_new.append(np.array(spikes_pure['base_0'][i][
                                                                      (spikes_pure['base_0'][i] > dur[d_nr]) & (
                                                                                  spikes_pure['base_0'][i] < dur[
                                                                              d_nr + 1])])/1000-dur[d_nr]/1000)
                                        # spikes_pure['base_0'] = spikes_new
                                        sampling_rate = 1/np.diff(time_array)
                                        sampling_rate = int(sampling_rate[0]*1000)
                                        spikes_mats = []
                                        smoothed05 = []
                                        for i in range(len(spikes_new)):
                                            spikes_mat = cr_spikes_mat(spikes_new[i], sampling_rate, int(sampling_rate*duration/1000))
                                            spikes_mats.append(spikes_mat)
                                            smoothed05.append(gaussian_filter(spikes_mat, sigma=(float(dev_desired)/1000) * sampling_rate))
                                        smoothed_base = np.mean(smoothed05, axis=0)
                                        mat_base = np.mean(spikes_mats, axis=0)
                                    else:
                                        smoothed_base = arrays[0][0]
                                        mat_base = arrays_original[0][0]
                                    #embed()#arrays[0]v
                                    # column order: baseline, arrays[2], arrays[1], arrays[3]
                                    # (save_arrays_susept returns base_0, control_01, control_02, 012)
                                    fr_isi, ax_ps, ax_as = plot_arrays_ROC_psd_single3(
                                        [[smoothed_base], arrays[2], arrays[1], arrays[3]],
                                        [[mat_base], arrays_original[2], arrays_original[1],
                                         arrays_original[3]], spikes_pure, cell, grid0, mean_type,
                                        group_mean, xlim=xlim, row=1 + d * 3,
                                        array_chosen=array_chosen,
                                        color0_burst=color0_burst, color01=color01, color02=color02,ylim_log=(-15, 3),
                                        color012=color012,color012_minus = color01_2,color0=color0)
                                ##########################################################################
                                individual_tag = 'DF1' + str(DF1_desired[gg]) + 'DF2' + str(
                                    DF2_desired[gg]) + cell + '_c1_' + str(c1) + '_c2_' + str(c2) + mean_type
                                # save_all(individual_tag, show, counter_contrast=0, savename='')
                                # print('individual_tag')
                                axes = []
                                axes.append(ax_w)
                                # axes.extend(np.transpose(ax_as))
                                # axes.append(np.transpose(ax_ps))
                                # np.transpose(axes)
                                #fig.tag(ax_w[0:3], xoffs=-2.3, yoffs=1.7)
                                #fig.tag(ax_w[3::], xoffs=-1.9, yoffs=1.4)
                                # panel labels on the stimulus scheme axes
                                fig.tag(ax_w, xoffs=-1.9, yoffs=1.4)
                                # ax_w, np.transpose(ax_as), ax_ps
                                if save:
                                    save_visualization(individual_tag=individual_tag, show=show, pdf=True)
                                # fig = plt.gcf()
                                # fig.savefig
                                # plt.show()
def plt_spikes_ROC(ax0, colors, spikes_pure, xlim, lw=None):
    """Draw a spike raster into ``ax0`` and strip it down to the bare data.

    Parameters
    ----------
    ax0 : Axes
        Target axis (must provide ``show_spines``, e.g. a plotstyle axis).
    colors : color
        Colour passed to ``eventplot``.
    spikes_pure : sequence
        One sequence of spike times per trial.
    xlim : sequence
        x-limits; an empty sequence leaves the limits untouched.
    lw : float, optional
        Line width of the events; only applied when truthy.
    """
    raster_kwargs = {'color': colors}
    if lw:
        raster_kwargs['linewidth'] = lw
    ax0.eventplot(spikes_pure, **raster_kwargs)
    # no spines and no ticks: the raster is only a visual summary
    ax0.show_spines('')
    ax0.set_xticks([])
    ax0.set_yticks([])
    if len(xlim):
        ax0.set_xlim(xlim)
def plt_psds_ROC(arrays, ax00, ax_ps, cell, colors_p, f, grid0, group_mean, nfft, p_means, p_means_all, ps, row,
                 spikes_pure, time_array, names=['0', '02', '01', '012'], color_psd='black', add_burst_corr=False,
                 xlim_psd=[0, 235], clip_on=True, ms=14, labels=[], ax00s=[], choice=[], marker='o', text_extra=True,
                 alphas=[], range_plot=[], ylim_log=(-15, 3), ylim_psd=[], log='log', ax01=None):
    """Plot the mean power spectrum of each condition and mark its characteristic peaks.

    For every ``j`` in ``range_plot`` (default: all conditions) the trial-averaged
    spectrum ``p_means_all[names[j]]`` is drawn via ``plt_single_pds``
    (psd_type 'mean_freq'). The peaks at ``freqs[choice[j]]`` are then marked via
    ``plt_peaks_several``, where ``freqs`` is ``[fr_isi, |DF2|, |DF1|, |DF1|+|DF2|,
    2|DF2|, 2|DF1|, ||DF1|-|DF2||]``. DF1/DF2 are taken from the last row of
    ``group_mean[1]``.

    With ``add_burst_corr`` the burst-corrected baseline rate
    (``get_burst_corr_peak``) is added as ``freqs[7]``. When ``choice`` is empty,
    ``update_burst_corr_peaks`` also extends ``colors_p`` and the label list in
    place.

    Returns
    -------
    ax00 : Axes
        Axis of the last plotted condition.
    fr_isi : float
        nan-mean over baseline trials of ``1 / mean(ISI)`` (spike times divided
        by 1000), taken from the last plotted condition.
    """
    psd_type = 'mean_freq'
    if not range_plot:
        range_plot = range(len(arrays))
    for j in range_plot:
        try:
            ref, ax00 = plt_single_pds(nfft, f, p_means, p_means_all[names[j]], ylim_psd, xlim_psd, color_psd, names,
                                       ps,
                                       arrays, ax_ps, grid0, row + 1, j, p_means_all, ax00s=ax00s, ax00=ax01,
                                       psd_type=psd_type, log=log)
        except Exception:
            print('ref not working')
            embed()
        if j == 0:
            if log == 'log':
                ax00.set_ylabel('dB')
            else:
                ax00.set_ylabel('Hz/Hz$^2$')
        ax00.set_xlim(xlim_psd)
        DF1 = group_mean[1].DF1.iloc[-1]
        DF2 = group_mean[1].DF2.iloc[-1]
        # Inverse mean ISI of every baseline trial; get_burst_corr_peak fills
        # fr_isis as a side effect.
        fr_isis = []
        if add_burst_corr:
            frs_burst_corr = get_burst_corr_peak(cell, fr_isis, group_mean, spikes_pure)
        else:
            for i in range(len(spikes_pure['base_0'])):
                fr_isis.append(1 / np.mean(np.diff(spikes_pure['base_0'][i] / 1000)))  # np.mean(fr), fr_calc,
        # Computed for both branches. It used to be set only without burst
        # correction, so add_burst_corr=True failed with a NameError right below.
        fr_isi = np.nanmean(fr_isis)
        freqs = [fr_isi, np.abs(DF2), np.abs(DF1),
                 np.abs(DF1) + np.abs(DF2), 2 * np.abs(DF2), 2 * np.abs(DF1),
                 np.abs(np.abs(DF1) - np.abs(DF2)), ]
        try:
            if not labels:
                if j == 3:
                    labels_inside = ['',
                                     '',
                                     '',
                                     fsum_core(DF1, DF2),
                                     '',
                                     '',
                                     fdiff_core(DF1, DF2),
                                     'fr_bc',
                                     'fr_given_burst_corr_individual', 'highest_fr_burst_corr_individual', 'fr',
                                     'fr_given',
                                     'highest_fr']  # '$|$DF1-DF2$|$=' + str(np.abs(np.abs(DF1) - np.abs(DF2))) + 'Hz',
                elif j == 2:
                    labels_inside = ['',
                                     df1_core(DF2),
                                     df2_core(DF1),
                                     fsum_core(DF1, DF2),
                                     f1_core(DF2),
                                     f2_core(DF1),
                                     fdiff_core(DF1, DF2),
                                     'fr_bc',
                                     'fr_given_burst_corr_individual', 'highest_fr_burst_corr_individual', 'fr',
                                     'fr_given',
                                     'highest_fr']  # '$|$DF1-DF2$|$=' + str(np.abs(np.abs(DF1) - np.abs(DF2))) + 'Hz',
                else:
                    labels_inside = labels_all_motivation(DF1, DF2, fr_isi)
            else:
                labels_inside = labels
        except Exception:
            print('label thing2')
            embed()
        if add_burst_corr:
            if not choice:
                # appends the burst-corrected rate as freqs[7] (plus label and colour)
                choice = update_burst_corr_peaks(colors_p, freqs, frs_burst_corr, labels_inside)
            else:
                # freqs is rebuilt for every condition, so the burst-corrected rate must be
                # appended again to keep index 7 of the stored choice valid.
                freqs.append(np.nanmean(frs_burst_corr))
            rots = [[45], [45, 45], [45], [45, 45, 45, 45, 45, 45]]
            extra = []
            left = 40
            rot = rots[j]
        else:
            if not choice:
                choice = [[0], [1, 4], [0, 2], [0, 1, 2, 3, 4, 6]]
            rots = [[0], [0, 0], [0, 0], [55, 55, 57, 45, 45, 45]]  # 45
            lefts = [[10], [25, 3], [0, 105], [10, 10, 10, 13, 12, 40]]  # 40
            extras = [[1], [1, 1], [1, 1], [1, 1, 2.5, 1.7, 4, 4]]  # 4,1
            extra = extras[j]
            # Plain list indexing: np.array() of these ragged nested lists raises
            # ValueError on numpy >= 1.24.
            left = lefts[j]
            rot = rots[j]
        pp, pp_mean = decide_log_ROCs(j, log, names, p_means_all, ref)
        try:  # todo: with log the peak heights change; the log case may need a different input here (2.5)
            plt_peaks_several(np.array(freqs)[choice[j]], pp, ax00, pp_mean, f, np.array(labels_inside)[choice[j]], j,
                              np.array(colors_p)[choice[j]], marker=marker, add_texts=extra, texts_left=left,
                              add_log=1.5, rots=rot, exact=False, text_extra=text_extra,
                              perc_peaksize=0.08,
                              alphas=alphas, ms=ms, clip_on=clip_on, log=log)  # True
        except Exception:  # freqs, p_arrays, axs_p, p0_means, fs, labels=None, j=1, colors=None,
            print('peaks thing2')
            embed()
        if log == 'log':
            ax00.set_ylim(ylim_log)
        ax00.show_spines('b')
        if log == 'log':
            if j == 0:
                ax00.yscalebar(-0.02, 0.5, 10, 'dB', va='center', ha='left')
    return ax00, fr_isi
def labels_all_motivation(DF1, DF2, fr_isi):
    """Return the peak labels for the baseline and single-stimulus spectra.

    The order matches the ``freqs`` list in ``plt_psds_ROC``: baseline rate
    (rounded to Hz), |Δf1|, |Δf2|, their sum, 2|Δf1|, 2|Δf2|, their difference,
    followed by placeholder names for burst-corrected entries.

    Raw strings are used for the LaTeX fragments, so ``\\,`` is no longer an
    invalid escape sequence (SyntaxWarning on Python >= 3.12).
    """
    baseline_label = r'$f' + basename_small() + ('=%s$' % (int(np.round(fr_isi)))) + r'\,Hz'
    labels = [baseline_label,
              df1_core(DF2),
              df2_core(DF1),
              fsum_core(DF1, DF2),
              f1_core(DF2),
              f2_core(DF1),
              fdiff_core(DF1, DF2),
              'fr_bc',
              'fr_given_burst_corr_individual', 'highest_fr_burst_corr_individual', 'fr', 'fr_given',
              'highest_fr']  # '$|$DF1-DF2$|$=' + str(np.abs(np.abs(DF1) - np.abs(DF2))) + 'Hz',
    return labels
def df2_core(DF1):
    """Return the LaTeX label ``|Δf2| = |f2 - f_EOD| = <|DF1|> Hz``.

    The argument is called DF1 but is displayed as Δf_2, matching how
    ``plt_psds_ROC`` calls it. Raw strings avoid the invalid escapes ``\\D`` and
    ``\\,`` (SyntaxWarning on Python >= 3.12); the produced text is unchanged.
    """
    return r'$|\Delta f_{2}|=|f_{2}-$' + f_eod_name_rm() + '$|=%s$' % (np.abs(DF1)) + r'\,Hz'
def df1_core(DF2):
    """Return the LaTeX label ``|Δf1| = |f1 - f_EOD| = <|DF2|> Hz``.

    The argument is called DF2 but is displayed as Δf_1, matching how
    ``plt_psds_ROC`` calls it. Raw strings avoid the invalid escapes ``\\D`` and
    ``\\,`` (SyntaxWarning on Python >= 3.12); the produced text is unchanged.
    """
    return r'$|\Delta f_{1}|=|f_{1}-$' + f_eod_name_rm() + '$|=%s$' % (np.abs(DF2)) + r'\,Hz'
def f2_core(DF1):
    """Return the LaTeX label for the harmonic ``2|Δf2|`` in Hz.

    The argument is called DF1 but is displayed as Δf_2 (same convention as
    ``df2_core``). The absolute value is used, so the label matches the peak
    plotted at ``2 * np.abs(DF1)`` and agrees with ``f1_core``. Previously a
    negative DF1 produced a negative value under a ``|...|`` label.
    """
    return r'$2 |\Delta f_{2}|=%s$' % (np.abs(DF1) * 2) + r'\,Hz'
def f1_core(DF2):
    """Return the LaTeX label for the harmonic ``2|Δf1|`` in Hz.

    The argument is called DF2 but is displayed as Δf_1 (same convention as
    ``df1_core``). Raw strings avoid the invalid escapes ``\\D`` and ``\\,``
    (SyntaxWarning on Python >= 3.12); the produced text is unchanged.
    """
    return r'$2 |\Delta f_{1}|=%s$' % (np.abs(DF2) * 2) + r'\,Hz'
def fdiff_core(DF1, DF2):
    """Return the LaTeX label for the difference frequency ``||DF1| - |DF2||`` in Hz.

    Raw strings avoid the invalid escapes ``\\D`` and ``\\,`` (SyntaxWarning on
    Python >= 3.12); the produced text is unchanged.
    """
    return r'$||\Delta f_{1}|-|\Delta f_{2}||=%s$' % (np.abs(np.abs(DF1) - np.abs(DF2))) + r'\,Hz'
def fsum_core(DF1, DF2):
    """Return the LaTeX label for the sum frequency ``|DF1| + |DF2|`` in Hz.

    Raw strings avoid the invalid escapes ``\\D`` and ``\\,`` (SyntaxWarning on
    Python >= 3.12); the produced text is unchanged.
    """
    return r'$||\Delta f_{1}| + |\Delta f_{2}||=%s$' % (np.abs(DF1) + np.abs(DF2)) + r'\,Hz'
def decide_log_ROCs(j, log, names, p_means_all, ref):
    """Return the per-trial spectra of condition ``j`` and their trial mean, optionally in dB.

    Parameters
    ----------
    j : int
        Index into ``names``.
    log : str
        ``'log'`` converts to ``10 * log10(p / ref)``; anything else keeps linear power.
    names : sequence of str
        Keys of ``p_means_all``.
    p_means_all : dict
        Condition name -> sequence of per-trial spectra.
    ref : float
        Reference power mapped to 0 dB.

    Returns
    -------
    pp, pp_mean
        Per-trial spectra and their mean over axis 0. In the linear case ``pp``
        is the stored object itself.
    """
    spectra = p_means_all[names[j]]
    if log == 'log':
        # The spectra are stored as a list of arrays, and a list cannot be divided
        # by a plain Python float ``ref`` (TypeError); convert once.
        spectra_arr = np.asarray(spectra)
        pp = 10 * np.log10(spectra_arr / ref)
        pp_mean = 10 * np.log10(np.mean(spectra_arr, axis=0) / ref)
    else:
        pp = spectra
        pp_mean = np.mean(spectra, axis=0)
    return pp, pp_mean
def update_burst_corr_peaks(colors_p, freqs, frs_burst_corr, labels):
    """Register the burst-corrected baseline rate as an additional peak.

    Side effects: ``freqs`` receives the nan-mean of ``frs_burst_corr``,
    ``labels`` receives ``'Baseline_Burstcorr'`` and ``colors_p`` receives
    ``'pink'`` (all appended in place).

    Returns
    -------
    list of list of int
        Peak indices to annotate for each of the four conditions.
    """
    mean_burst_rate = np.nanmean(frs_burst_corr)
    # np.abs(np.abs(DF1) - np.abs(DF2)), np.nanmax(frame_spikes['highest_fr']), ... were alternatives here
    freqs.append(mean_burst_rate)
    labels.append('Baseline_Burstcorr')
    colors_p.append('pink')
    peak_selection = [[0, 7], [1, 4], [2], [0, 1, 2, 3, 7]]
    return peak_selection
def get_burst_corr_peak(cell, fr_isis, group_mean, spikes_pure):
    """Return one burst-corrected baseline firing rate per baseline trial.

    For every trial in ``spikes_pure['base_0']`` the plain rate
    ``1 / mean(diff(spikes / 1000))`` is appended to ``fr_isis`` (mutated in
    place). If the smallest ISI from ``calc_isi`` is below the threshold of
    ``find_lim_here(cell, 'individual')``, the rate returned by
    ``correct_burstiness`` is used; otherwise the plain rate is kept.
    """
    base_trials = spikes_pure['base_0']
    corrected_rates = []
    for trial in range(len(base_trials)):
        trial_spikes = base_trials[trial]
        fr_isis.append(1 / np.mean(np.diff(trial_spikes / 1000)))  # np.mean(fr), fr_calc,
        burst_limit = find_lim_here(cell, 'individual')
        receiver_eodf = group_mean[1].EODf.iloc[trial]
        isi = calc_isi(np.array(trial_spikes) / 1000, receiver_eodf)
        if np.min(isi) < burst_limit:
            _, _, rate = correct_burstiness([isi], [trial_spikes],
                                            [receiver_eodf],
                                            [receiver_eodf], lim=burst_limit,
                                            burst_corr='individual')
        else:
            # no burst: fall back to the plain ISI rate just computed
            rate = fr_isis[-1]
        corrected_rates.append(rate)
    return corrected_rates
def get_psds_ROC(array_chosen, arrays, arrays_original, j, mean_type, names, p_means_all, nfft=2 ** 13,
                 p_type='05', fs=40000):
    """Compute the power spectra of the trials of condition ``j``.

    The list of per-trial spectra is stored in ``p_means_all[names[j]]``
    (mutated in place). With ``'AllTrialsIndex'`` in ``mean_type`` only trial
    ``array_chosen`` is analysed, otherwise all trials of ``arrays[j]``. Each
    trial is mean-subtracted before ``ml.psd`` (Welch, ``noverlap = nfft // 2``).

    Parameters
    ----------
    nfft : int
        Segment length of the PSD.
    p_type : str
        If it contains ``'original'``, ``arrays_original[j]`` is analysed,
        otherwise ``arrays[j]``. Previously this was hard-coded to ``'05'``.
    fs : float
        Sampling rate passed to ``ml.psd``. Previously hard-coded to 40000.

    Returns
    -------
    f : ndarray
        Frequencies of the last computed spectrum.
    nfft : int
        The segment length used.
    """
    if 'AllTrialsIndex' in mean_type:  # AllTrialsIndex
        range_here = [array_chosen]
        print('alltrials choice')
    else:
        range_here = range(len(arrays[j]))
    source = arrays_original if 'original' in p_type else arrays
    p_mean_all_here = []
    for i in range_here:
        trial = source[j][i]
        p_mean_all, f = ml.psd(trial - np.mean(trial), Fs=fs, NFFT=nfft,
                               noverlap=nfft // 2)
        p_mean_all_here.append(p_mean_all)
    p_means_all[names[j]] = p_mean_all_here
    return f, nfft
def plt_traces_ROC(array_chosen, arrays, ax00, colors, group_mean, j, mean_type, way, xlim,
                   ylim):
    """Draw one rate trace of condition ``j`` into ``ax00`` and return its time axis.

    The time axis is ``arange(0, len / 40, 1 / 40)``, built from the first trial
    of ``arrays[j]``. With ``'_AllTrialsIndex'`` in ``mean_type`` trial
    ``array_chosen`` is drawn, otherwise trial 0. Spines and ticks are removed.
    For the first column (``j == 0``) a 20 ms / 500 Hz scale bar is added. If
    the axis' ``xscalebar``/``yscalebar`` calls fail, the bar is drawn by hand.
    ``group_mean`` and ``way`` are not used.
    """
    time_array = np.arange(0, len(arrays[j][0]) / 40, 1 / 40)
    try:
        trial_idx = array_chosen if '_AllTrialsIndex' in mean_type else 0
        ax00.plot(time_array, arrays[j][trial_idx], color=colors[j])
    except:
        print('array thing')
        embed()
    if len(xlim) > 0:
        ax00.set_xlim(xlim)
    ax00.set_ylim(ylim)
    ax00.show_spines('')
    ax00.set_xticks([])
    ax00.set_yticks([])
    if j != 0:
        return time_array
    bar_length = 20
    bar_offset = 5
    try:
        ax00.xscalebar(0.1, -0.02, bar_length, 'ms', va='right', ha='bottom')  ##ylim[0]
        ax00.yscalebar(-0.02, 0.35, 500, 'Hz', va='center', ha='left')
    except:
        # axes without working scale-bar helpers: draw both bars manually
        ax00.plot([0, bar_length], [ylim[0] + bar_offset, ylim[0] + bar_offset], color='black')
        ax00.text(0, -0.2, str(bar_length) + ' ms', va='center', fontsize=10,
                  transform=ax00.transAxes)
        bar_x = (xlim[0] if len(xlim) > 0 else time_array[0]) + 0.01
        ax00.plot([bar_x, bar_x], [ylim[0], 500],
                  color='black')
        ax00.text(-0.1, 0.4, ' 500 Hz', rotation=90, va='center', fontsize=10,
                  transform=ax00.transAxes)
    return time_array
def save_arrays_susept(data_dir, cell, c, chirps, devs, extract, group_mean, mean_type, plot_group, rocextra,
                       sorted_on='LocalReconst0.2Norm', dev_desired='1',
                       mean_type0=''): # '_MeanTrialsIndex'
    """Return smoothed rate arrays, unsmoothed arrays and spike times for the four conditions.

    The behaviour depends on ``version_comp`` from find_code_vs_not():

    * 'develop' / 'code': the cell's nix file (find_nix_full_path) is read and
      the trials are cut (cut_spikes_and_eod_three) and sorted (assign_trials,
      once with ``dev_desired`` and once with 'original'). In 'develop' the
      results are additionally written to csv files named
      ``load_name + '_05_' / '_original_' / '_spikes_' + name + '.csv'``.
    * 'public': the same csv files are read back.

    Returns
    -------
    arrays : list
        Arrays from assign_trials with ``dev_desired``, ordered
        base_0, control_01, control_02, 012. On the develop/code path with
        ``rocextra`` set, only the pair selected by ``plot_group`` is returned.
    arrays_original : list
        Arrays from assign_trials with 'original', same order.
    spikes_all_out : dict
        Condition name -> spike times. On the develop/code path each trial is
        multiplied by 1000 and, for trials after the first (non-baseline
        conditions), shifted by a delay taken from ``delays_length`` divided
        by 40 (presumably samples at 40 kHz -> ms; verify).

    NOTE(review): if the nix file does not exist, has no matching multi tags or
    fails check_nix_fish (or version_comp is none of the handled values),
    ``arrays`` is never assigned and the return raises UnboundLocalError.
    """
    version_comp, subfolder, mod_name_slash, mod_name, subfolder_path = find_code_vs_not()
    load_name = find_load_function()
    if (version_comp == 'develop') | (version_comp == 'code'):
        full_path = find_nix_full_path(c, cell, data_dir)
        if os.path.exists(full_path):  # todo: this maybe also has to be fixed
            print('do ' + cell)
            # NOTE(review): the nix file opened here is never closed in this function.
            file = nix.File.open(full_path, nix.FileMode.ReadOnly)
            b = file.blocks[0]
            all_mt_names, mt_names, t_names = get_all_nix_names(b, what='Three')
            if mt_names:
                nix_there = check_nix_fish(b)
                if nix_there:  ##
                    printing = True
                    t3 = time.time()
                    spikes_pure, fish_number_base, chirp, fish_cuts, time_array, fish_number, smoothened2, smoothed05, eod_mt, eod_interp, effective_duration, cut, devname, frame = cut_spikes_and_eod_three(
                        group_mean, b, extract, chirps=chirps, emb=False,
                        mean_type=mean_type, sorted_on=sorted_on,
                        devname_orig=[dev_desired, '05', 'original'])  # times_sort=times_sort,
                    if printing:  # todo: this takes a long time and could be optimized
                        print('arrays1 ' + str(time.time() - t3))
                    try:
                        pass
                    except:
                        print('dev nr thing')
                        embed()
                    # per-trial delays used below to align the trials
                    delays_length = define_delays_trials(mean_type,
                                                         frame, sorted_on=sorted_on)
                    printing = True
                    t3 = time.time()
                    if ('Phase' not in mean_type0) & (mean_type0 != ''):
                        for i in range(len(delays_length['base_0'])):
                            delays_length['base_0'][i] = np.arange(0, delays_length['base_0'][i][-1], 1)
                    try:
                        array0, array01, array02, array012, mean_nrs, array012_all, array01_all, array02_all, array0_all = assign_trials(
                            frame,
                            dev_desired,
                            delays_length,
                            mean_type)
                    except:
                        print('sorting thing')
                        embed()
                    # same sorting for the unsmoothed ('original') traces
                    array0_original, array01_original, array02_original, array012_original, mean_nrs, array012_all, array01_all, array02_all, array0_all = assign_trials(
                        frame,
                        'original',
                        delays_length,
                        mean_type)
                    if printing:
                        print('arrays2 ' + str(time.time() - t3))
                    if 'TrialsConcat' in mean_type:
                        array0 = array0[0]
                        array01 = array01[0]
                        array02 = array02[0]
                        array012 = array012[0]
                    test = False
                    if test:
                        from utils_test import test_arrays
                        test_arrays()  # array0, array01, array02, array012
                    printing = True
                    t3 = time.time()
                    # fallback delays, only consulted below when delays_length is empty
                    if 'Mean' not in mean_type:
                        delays_length_m = define_delays_trials('_MeanTrialsIndexPhaseSort_Min0.25sExcluded_',
                                                               frame, sorted_on=sorted_on)
                    if printing:
                        print('arrays3 ' + str(time.time() - t3))
                    if rocextra:
                        arrays = [[array0, array01], [array02, array012]]
                        arrays = arrays[plot_group]
                    else:
                        arrays = [array0, array01, array02, array012]
                    arrays_original = [array0_original, array01_original, array02_original, array012_original]
                    names = ['base_0', 'control_01', 'control_02', '012']
                    spikes_all_out = {}
                    for n, name in enumerate(names):
                        spikes_all = []
                        for s, sp in enumerate(np.array(spikes_pure[name])):
                            spikes = spikes_pure[name].iloc[s] * 1000
                            # baseline trials and the first trial of each condition are not shifted
                            if name != 'base_0':
                                if s != 0:
                                    if len(delays_length) > 0:
                                        cut = delays_length[name][s - 1][0] / 40
                                        spikes = spikes[spikes > cut] - cut
                                    else:
                                        try:
                                            cut = delays_length_m[name][s - 1][0] / 40
                                        except:
                                            print('delay thing')
                                            embed()
                                        spikes = spikes[spikes > cut] - cut
                            spikes_all.append(np.array(spikes))
                        spikes_all_out[name] = spikes_all
                    if version_comp == 'develop':
                        # cache everything as csv so the 'public' version can run without the nix files
                        for n, name in enumerate(names):
                            save_here = save_csv_to_pandas(arrays[n])
                            save_here.to_csv(load_name + '_05_' + name + '.csv')
                            save_here = save_csv_to_pandas(arrays_original[n])
                            save_here.to_csv(load_name + '_original_' + name + '.csv')
                            save_here = save_csv_to_pandas(spikes_all_out[name])
                            save_here.to_csv(load_name + '_spikes_' + name + '.csv')
                            # todo: also do this with the regular spike re-save functions
                    test = False
                    if test:
                        from utils_test import test_arrays
                        test_arrays()
                    # todo: save arrays, arrays_original and spikes_pure here
    elif version_comp == 'public':
        # read back the csv files written by the 'develop' branch (one column per trial, hence the transpose)
        spikes_all_out = {}
        arrays_original = []
        arrays = []
        names = ['base_0', 'control_01', 'control_02', '012']
        for n, name in enumerate(names):
            spikes = pd.read_csv(load_name + '_spikes_' + name + '.csv', index_col=0)
            array_o = pd.read_csv(load_name + '_original_' + name + '.csv', index_col=0)
            array_05 = pd.read_csv(load_name + '_05_' + name + '.csv', index_col=0)
            spikes_all_out[name] = np.array(np.transpose(spikes))
            arrays_original.append(np.array(np.transpose(array_o)))
            arrays.append(np.array(np.transpose(array_05)))
    return arrays, arrays_original, spikes_all_out
def find_nix_full_path(c, cell, data_dir):
    """Return the path of the nix file of ``cell``.

    The file is ``<cell folder>/<last component of cell>.nix``. The cell folder is
    ``'../data/ThreeFish/' + cell`` when ``data_dir`` is the empty string, and
    ``load_folder_name('data') + data_dir[c] + cell`` otherwise.
    """
    file_name = '%s.nix' % cell.split(os.path.sep)[-1]
    if data_dir != '':
        cell_dir = load_folder_name('data') + data_dir[c] + cell
    else:
        cell_dir = '../data/ThreeFish/' + cell
    return '/'.join([cell_dir, file_name])
def plt_single_pds(nfft, f, p_means, p_mean_all_here, ylim_psd, xlim_psd, color_psd, names, ps, arrays, ax_ps, grid0,
                   row, j, p_means_all, psd_type='mean_freq', ax00s=[], log='log', ax00=None):
    """Plot one power-spectrum panel (column ``j``) and return its dB reference and axis.

    psd_type
        'single'        PSD (nfft = 2**16) of the first trial ``arrays[j][0]``,
                        stored in ``ps[names[j]]``.
        'mean_temporal' PSD of the trial-averaged trace of ``arrays[j]``, stored
                        in ``p_means[names[j]]``.
        'all'           every per-trial spectrum in ``p_mean_all_here``, in grey.
        'mean_freq'     mean over the per-trial spectra in ``p_mean_all_here``.

    When ``ax00`` is not given, the axis is taken from ``ax00s[j]`` ('single'
    only, when ``ax00s`` is non-empty) or created at ``grid0[row + 2, j]``. New
    or taken axes are appended to ``ax_ps``. With ``log == 'log'`` power is
    shown as ``10 * log10(p / ref)``.

    Returns
    -------
    ref : float
        0-dB reference. For 'single' it is the maximum of the first-trial
        spectra of all ``names``; otherwise the maximum of the trial-averaged
        spectra of all ``names``.
    ax00 : Axes

    Raises
    ------
    ValueError
        For an unknown ``psd_type`` (previously an UnboundLocalError on ``ref``).
    """
    def new_axis():
        # spectrum axis of column j, two rows below ``row``
        axis = plt.subplot(grid0[row + 2, j])
        ax_ps.append(axis)
        return axis

    def mean_spectra_ref():
        # maximum of the trial-averaged spectra over all conditions
        return np.max([np.mean(p_means_all[name], axis=0) for name in names])

    if psd_type == 'single':
        ref = np.max([p_means_all[name][0] for name in names])
        if not ax00:
            # Use the provided axes when there are any. The original condition was
            # inverted and indexed the empty list.
            ax00 = ax00s[j] if ax00s else plt.subplot(grid0[row + 2, j])
            ax_ps.append(ax00)
        nfft = 2 ** 16
        p, f = ml.psd(arrays[j][0] - np.mean(arrays[j][0]), Fs=40000, NFFT=nfft,
                      noverlap=nfft // 2)
        ps[names[j]] = p
        ax00.plot(f, log_calc_psd(log, p, ref), color=color_psd)
        ax00.set_xlim(xlim_psd)
        if len(ylim_psd) > 0:
            ax00.set_ylim(ylim_psd)
        remove_xticks(ax00)
        if j != 0:
            remove_yticks(ax00)
    elif psd_type == 'mean_temporal':
        # ``ref`` used to be undefined in this branch (UnboundLocalError)
        ref = mean_spectra_ref()
        if not ax00:
            ax00 = new_axis()
        # PSD of the temporal mean over trials
        trial_mean = np.mean(arrays[j], axis=0)
        p_mean, f_mean = ml.psd(trial_mean - np.mean(trial_mean), Fs=40000, NFFT=nfft,
                                noverlap=nfft // 2)
        # plot against the frequencies of this very PSD
        ax00.plot(f_mean, log_calc_psd(log, p_mean, ref), color=color_psd)
        ax00.set_xlim(xlim_psd)
        if len(ylim_psd) > 0:
            ax00.set_ylim(ylim_psd)
        p_means[names[j]] = p_mean
        remove_xticks(ax00)
        if j != 0:
            remove_yticks(ax00)
    elif psd_type == 'all':
        # ``ref`` used to be undefined in this branch (UnboundLocalError)
        ref = mean_spectra_ref()
        if not ax00:
            ax00 = new_axis()
        for p in p_mean_all_here:
            ax00.plot(f, log_calc_psd(log, p, ref), color='grey')
        ax00.set_xlim(xlim_psd)
        if len(ylim_psd) > 0:
            ax00.set_ylim(ylim_psd)
        remove_xticks(ax00)
        if j != 0:
            remove_yticks(ax00)
    elif psd_type == 'mean_freq':
        ref = mean_spectra_ref()
        if not ax00:
            ax00 = new_axis()
        mean_power = np.mean(p_mean_all_here, axis=0)
        if log == 'log':
            ax00.plot(f, 10 * np.log10(mean_power / ref), color=color_psd)
            if j == 0:
                ax00.set_ylabel('dB')
        else:
            ax00.plot(f, mean_power, color=color_psd)
        ax00.set_xlim(xlim_psd)
        if len(ylim_psd) > 0:
            ax00.set_ylim(ylim_psd)
        ax00.set_xlabel('Frequency [Hz]')
        if j != 0:
            remove_yticks(ax00)
    else:
        raise ValueError('unknown psd_type: ' + str(psd_type))
    return ref, ax00
def log_calc_psd(log, p, ref):
    """Optionally express a power spectrum in dB relative to ``ref``.

    Returns ``10 * log10(p / ref)`` when ``log == 'log'``; for any other value
    the input ``p`` itself is returned unchanged.
    """
    if log != 'log':
        return p
    return 10 * np.log10(p / ref)
def chose_certain_group(DF1_desired, DF2_desired, grouped, concat=False, several=False, emb=False):
    """Select the group(s) of ``grouped`` that match the desired DF1/DF2 values.

    Parameters
    ----------
    DF1_desired, DF2_desired : float, sequence of float, or 'all'
        Desired frequency differences. If either equals ``'all'`` nothing is
        selected and ``None`` is returned. When both are Python ``float`` a
        single combination is selected; otherwise both are indexed in parallel
        (same length assumed) and one selection per pair is collected.
    grouped : pandas GroupBy
        Grouped data; ``.groups``, ``.get_group`` and iteration are used.
    concat : bool
        Only used with ``several``: concatenate the DataFrames of all matching
        groups instead of returning ``(key, DataFrame)`` items.
    several : bool
        Use ``group_the_certain_group_several`` (defined elsewhere; presumably
        returns indices/mask of several matching groups — confirm) instead of
        ``group_the_certain_group``.
    emb : bool
        Drop into an IPython shell right before returning (debugging aid).

    Returns
    -------
    The selection. Its container depends on the branch taken: a concatenated
    DataFrame, an object array / list element of ``(key, DataFrame)`` items,
    or a list of those when sequences of DF values are given.
    """
    if DF1_desired == 'all' or DF2_desired == 'all':
        return
    if several:
        if (type(DF1_desired) == float) & (type(DF2_desired) == float):
            restrict = group_the_certain_group_several(grouped, DF2_desired, DF1_desired, emb=False)
            if concat:
                # keys of the selected groups, then concatenate their DataFrames
                key_list = list(map(tuple, grouped.groups.keys()))
                keys_r = np.array(key_list, dtype=object)[np.array(restrict)]
                groups = []
                for r in keys_r:
                    if len(groups) < 1:
                        groups = grouped.get_group(tuple(r))
                    else:
                        groups = pd.concat([groups, grouped.get_group(tuple(r))])
                    print(len(grouped.get_group(tuple(r))))
                final_grouped = groups
            else:
                try:
                    final_grouped = np.array(list(grouped), dtype=object)[restrict]
                except:
                    # NOTE(review): here ``groups`` starts as a ``[key, DataFrame]``
                    # list and is then passed to ``pd.concat`` together with a
                    # DataFrame; that looks like it fails for a second key — confirm.
                    keys_r = np.array(list(map(tuple, grouped.groups.keys())))[restrict]
                    groups = []
                    for r in keys_r:
                        if len(groups) < 1:
                            groups = [tuple(r), grouped.get_group(tuple(r))]
                        else:
                            groups = pd.concat([groups, grouped.get_group(tuple(r))])
                        print(len(grouped.get_group(tuple(r))))
                    final_grouped = groups
        else:
            restricts = []
            final_grouped = []
            # one selection per (DF1, DF2) pair
            for d in range(len(DF1_desired)):
                restrict = group_the_certain_group_several(grouped, DF2_desired[d], DF1_desired[d], emb=False)
                print(restrict)
                if concat:
                    keys_r = np.array(list(map(tuple, grouped.groups.keys())))[restrict]
                    groups = []
                    for r in keys_r:
                        if len(groups) < 1:
                            groups = grouped.get_group(tuple(r))
                        else:
                            groups = pd.concat([groups, grouped.get_group(tuple(r))])
                    final_grouped.append(groups)
                else:
                    # NOTE(review): np.array over (key, DataFrame) tuples without
                    # dtype=object may fail for ragged content in newer NumPy — confirm.
                    final_grouped.append(np.array(list(grouped))[restrict])
                restricts.append(restrict)
    else:
        if (type(DF1_desired) == float) & (type(DF2_desired) == float):
            restrict = group_the_certain_group(grouped, DF2_desired, DF1_desired)
            # list indexing: assumes ``restrict`` is a single integer index — confirm
            grouped = list(grouped)  # [::-1]
            final_grouped = grouped[restrict]
        else:
            restricts = []
            final_grouped = []
            for d in range(len(DF1_desired)):
                restrict = group_the_certain_group(grouped, DF2_desired[d], DF1_desired[d])
                print(restrict)
                final_grouped.append(list(grouped)[restrict])
                restricts.append(restrict)
    if emb:
        embed()
    return final_grouped
def phase_sort_arrays(f, delays_length, frame_dev):
    """Trim trial ``f`` of every condition so it starts at its aligned offset.

    For each trial after the first, ``delays_length[name][f - 1]`` holds the
    sample indices that remain after aligning that trial to the reference trial
    (as appended by ``find_delays_length``). The trial is cut so that it starts
    at the first of these indices. Empty entries leave the trial untouched.
    Trial ``0`` is the reference and is never modified.

    Parameters
    ----------
    f : int
        Positional row index of the trial in ``frame_dev``.
    delays_length : dict
        Maps '012', 'control_01', 'control_02' (and optionally 'base_0') to a
        list with one index array (or ``[]``) per trial after the first.
    frame_dev : pandas.DataFrame
        Trials of one device: one row per trial, one array per cell.

    Returns
    -------
    pandas.DataFrame
        ``frame_dev`` with row ``f`` trimmed in place.
    """
    if f == 0:
        return frame_dev
    names = ['012', 'control_01', 'control_02']
    if 'base_0' in delays_length.keys():
        names.append('base_0')
    for name in names:
        kept_indices = delays_length[name][f - 1]
        # ``len(...) > 0`` instead of ``!= []``: comparing a numpy array with
        # ``[]`` is element-wise and raises for non-empty arrays in NumPy >= 1.25.
        if len(kept_indices) > 0:
            trial = frame_dev[name].iloc[f]
            column = frame_dev.columns.get_loc(name)
            # Write through ``iat``: the former chained assignment
            # ``frame_dev.iloc[f][name] = ...`` modifies a temporary row and is
            # silently lost under pandas copy-on-write.
            frame_dev.iat[f, column] = trial[np.arange(kept_indices[0], len(trial), 1)]
    return frame_dev
def cut_uneven_trials(frame, devname, mean_type, delays_length, sampling=40000):
    """Collect all trials of one device and cut them to a common length.

    Rows of ``frame`` whose ``'dev'`` entry equals ``devname`` are used. When
    ``mean_type`` contains ``'Phase'`` every trial is first trimmed to its
    aligned start with ``phase_sort_arrays`` (not needed for the three-wave
    paradigm). All trials are then cut to equal length by ``cut_even_arrays``.

    Returns
    -------
    tuple
        ``(array012_all, array01_all, array02_all, array0_all)``
    """
    frame_dev = frame[frame['dev'] == devname]
    phase_sorted = 'Phase' in mean_type
    columns = ('012', 'control_01', 'control_02', 'base_0')
    trial_lengths = []
    for row in range(len(frame_dev['012'])):
        if phase_sorted:
            frame_dev = phase_sort_arrays(row, delays_length, frame_dev)
        trial_lengths.append([len(frame_dev[column].iloc[row]) for column in columns])
    # bring every trial to the same length
    array0_all, array01_all, array02_all, array012_all = cut_even_arrays(trial_lengths, sampling, mean_type,
                                                                         frame_dev)
    return array012_all, array01_all, array02_all, array0_all
def cut_even_arrays(length, sampling, mean_type, frame_dev):
    """Cut every trial of every condition to one common length.

    Parameters
    ----------
    length : list of list of int
        Per trial, the lengths of its '012', 'control_01', 'control_02' and
        'base_0' arrays (in this order).
    sampling : float
        Sampling rate, used to convert the exclusion threshold to samples.
    mean_type : str
        If it contains ``'Min<seconds>sExcluded'`` (default e.g.
        ``'Min0.25sExcluded'``), arrays not longer than ``<seconds>`` are
        dropped and the common length is the shortest remaining array.
        Otherwise nothing is dropped and the overall shortest array sets it.
    frame_dev : pandas.DataFrame
        One row per trial with the condition arrays as cells.

    Returns
    -------
    tuple of lists
        ``(array0_all, array01_all, array02_all, array012_all)``
    """
    if 'Min' in mean_type:  # e.g. 'Min0.25sExcluded' (the default)
        ms_exclude = float(mean_type.split('Min')[1].split('sExcluded')[0])
        lengths = np.array(length)
        keep_mask = lengths > ms_exclude * sampling
        # The common length must come from the arrays that are actually kept.
        # A hard-coded 0.25 s threshold was used here before, so any other
        # 'Min...' value left kept arrays shorter than ``length_min`` (ragged).
        length_min = np.min(lengths[keep_mask])
    else:
        keep_mask = np.ones_like(length)
        length_min = np.min(length)
    array012_all = []
    array01_all = []
    array02_all = []
    array0_all = []
    for f in range(len(frame_dev['012'])):
        if keep_mask[f][0]:
            array012_all.append(frame_dev['012'].iloc[f][0:length_min])
        if keep_mask[f][1]:
            array01_all.append(frame_dev['control_01'].iloc[f][0:length_min])
        if keep_mask[f][2]:
            array02_all.append(frame_dev['control_02'].iloc[f][0:length_min])
        if keep_mask[f][3]:
            array0_all.append(frame_dev['base_0'].iloc[f][0:length_min])
    return array0_all, array01_all, array02_all, array012_all
def plt_phase_sorted_trials(frame, devname, array0_all, array0, array01_all, array01, array02_all, array02,
                            array012_all, array012, ):
    """Diagnostic figure comparing unsorted and sorted trials of one device.

    Right column ('not sorted'): the raw trials of ``frame`` for ``devname``.
    Left column ('sorted'): the cut trials (``*_all``) with the first entry of
    the averaged trace (``array*[0]``) in red. Rows are base_0, control_01,
    control_02 and 012. The x-axis shows samples 0-4000 and a vertical line
    marks sample 1000. The figure is stored via ``save_visualization``.
    """
    fig, ax = plt.subplots(4, 2, sharey=True, sharex=True)
    sorted_sets = [array02_all, array01_all, array0_all, array012_all]
    lmax = np.nanmax([np.nanmax(np.transpose(trials)) for trials in sorted_sets])
    lmin = np.nanmin([np.nanmin(np.transpose(trials)) for trials in sorted_sets])
    names = ['base_0', 'control_01', 'control_02', '012']
    xlim = 4000
    x = 1000
    frame_dev = frame[frame.dev == devname]
    # The per-trial ``ml.psd`` call formerly placed here discarded its result
    # and passed a float ``noverlap`` (8000 / 2), which mlab rejects; removed.
    for nn, n in enumerate(names):
        for i in range(len(frame_dev[n])):
            ax[nn, 1].plot(frame_dev[n].iloc[i])
    ax[0, 1].set_title('not sorted')
    plt.suptitle(devname)
    # (row, y-label, sorted trials, averaged trace)
    sorted_rows = [(3, '012', array012_all, array012),
                   (1, '01', array01_all, array01),
                   (2, '02', array02_all, array02),
                   (0, '0', array0_all, array0)]
    for row, label, trials, mean in sorted_rows:
        ax[row, 0].set_ylabel(label)
        ax[row, 0].plot(np.transpose(trials))
        ax[row, 0].plot(mean[0], color='red')
        ax[row, 0].set_ylim(lmin, lmax)
        ax[row, 0].set_xlim(0, xlim)
        ax[row, 0].axvline(x=x)
    ax[0, 0].set_title('sorted')
    plt.subplots_adjust(hspace=0.4, wspace=0.35)
    save_visualization(show=False)
def assign_trials(frame, devname, delays_length, mean_type, sampling=40000):
    """Gather the trials of one device, cut them evenly and optionally average.

    Returns
    -------
    tuple
        ``(array0, array01, array02, array012, mean_nrs,
        array012_all, array01_all, array02_all, array0_all)`` where the first
        four come from ``assign_trials_mean`` and the ``*_all`` lists hold the
        individual cut trials.
    """
    all_trials = cut_uneven_trials(frame, devname, mean_type, delays_length, sampling=sampling)
    array012_all, array01_all, array02_all, array0_all = all_trials
    averaged = assign_trials_mean(devname, frame, array012_all, mean_type, array01_all, array02_all,
                                  array0_all)
    array0, array012, array01, array02, mean_nrs = averaged
    return (array0, array01, array02, array012, mean_nrs,
            array012_all, array01_all, array02_all, array0_all)
def assign_trials_mean(devname, frame, array012_all, mean_type, array01_all, array02_all, array0_all):
    """Optionally average the cut trials of the four conditions.

    Modes selected by ``mean_type``:

    * without ``'_MeanTrials'``: the trial lists are returned unchanged and
      ``mean_nrs`` is 1;
    * ``'_MeanTrials'`` + ``'TrialsConcat_<k>'``: trials are split into ``k``
      consecutive chunks and each chunk is averaged (``k`` is a single digit);
    * ``'_MeanTrials'`` + ``'<k>snippets'``: two averages, see note below;
    * ``'_MeanTrials'`` otherwise: one average over all trials.

    ``devname`` and ``frame`` are only used by the disabled debug plots.

    Returns
    -------
    tuple
        ``(array0, array012, array01, array02, mean_nrs)``; ``mean_nrs`` is the
        number of trials in ``array012_all`` when averaging, else 1.
    """
    if '_MeanTrials' in mean_type:
        if 'TrialsConcat' in mean_type:
            # only the first character after 'TrialsConcat_' is read
            trial_concats = int(mean_type.split('TrialsConcat_')[1][0])
            length = len(array012_all)
            array012 = []
            array01 = []
            array02 = []
            array0 = []
            # each entry is a one-element list holding the mean of one chunk
            for trial_concat in range(trial_concats):
                array012.append([np.mean(np.array(array012_all[int(length * (trial_concat / trial_concats)):int(
                    length * ((trial_concat + 1) / trial_concats))]), axis=0)])
                array01.append([np.mean(np.array(array01_all[int(length * (trial_concat / trial_concats)):int(
                    length * ((trial_concat + 1) / trial_concats))]), axis=0)])
                array02.append([np.mean(np.array(array02_all[int(length * (trial_concat / trial_concats)):int(
                    length * ((trial_concat + 1) / trial_concats))]), axis=0)])
                array0.append([np.mean(np.array(array0_all[int(length * (trial_concat / trial_concats)):int(
                    length * ((trial_concat + 1) / trial_concats))]), axis=0)])
        else:
            if 'snippets' in mean_type:
                # digit directly before 'snippets'
                snippets = int(mean_type.split('snippets')[0][-1])
                # Two averages only: the first len/snippets trials, then the rest.
                # NOTE(review): the second slice ends at ``-1`` and therefore drops
                # the last trial; the first-slice bound always uses
                # len(array012_all). Looks unintentional — confirm.
                array012 = [np.mean(np.array(array012_all)[0:int(len(array012_all) / snippets)], axis=0),
                            np.mean(np.array(array012_all)[int(len(array012_all) / snippets):-1], axis=0)]
                array01 = [np.mean(np.array(array01_all)[0:int(len(array012_all) / snippets)], axis=0),
                           np.mean(np.array(array01_all)[int(len(array01_all) / snippets):-1], axis=0)]
                array02 = [np.mean(np.array(array02_all)[0:int(len(array012_all) / snippets)], axis=0),
                           np.mean(np.array(array02_all)[int(len(array02_all) / snippets):-1], axis=0)]
                array0 = [np.mean(np.array(array0_all)[0:int(len(array012_all) / snippets)], axis=0),
                          np.mean(np.array(array0_all)[int(len(array0_all) / snippets):-1], axis=0)]
            else:
                # a single mean over all trials, wrapped in a one-element list
                array012 = [np.mean(np.array(array012_all), axis=0)]
                array01 = [np.mean(np.array(array01_all), axis=0)]
                array02 = [np.mean(np.array(array02_all), axis=0)]
                array0 = [np.mean(np.array(array0_all), axis=0)]
        mean_nrs = len(array012_all)
        # disabled debug plots
        test = False
        if test:
            if devname == 'eod':
                plt_phase_sorted_trials(frame, devname, array0_all, array0, array01_all, array01, array02_all, array02,
                                        array012_all, array012, )
                plt.show()
        test = False
        if test == True:
            plot_traces_frame_three_roc(frame, [], show=True)
    else:
        # no averaging: pass the individual trials through
        mean_nrs = 1
        array012 = array012_all  # np.array(frame_dev['012'])
        array01 = array01_all  # np.array(frame_dev['control_01'])
        array02 = array02_all  # np.array(frame_dev['control_02'])
        array0 = array0_all  # np.array(frame_dev['base_0'])
    test = False
    if test == True:
        from utils_test import test_assign_trials
        test_assign_trials(devname, array01, array02, array0, array012)
    return array0, array012, array01, array02, mean_nrs
def plot_traces_frame_three_roc(frame, id_group, show=True, names=['base_0', 'control_01', 'control_02', '012']):
    """Plot all trials of every device (rows) and condition (columns).

    Parameters
    ----------
    frame : pandas.DataFrame
        One row per trial and device; condition columns hold the traces and
        ``'dev'`` names the device.
    id_group : sequence
        Either empty, or a ``(key, DataFrame)`` pair whose frame has ``'DF1'``,
        ``'DF2'`` and ``'eodf'`` columns. In that case the figure title shows
        the mean DF1/DF2 and each panel title shows the frequency of the
        strongest spectral peak below the mean EOD frequency (traces assumed
        to be sampled at 40 kHz).
    show : bool
        Call ``plt.show()`` after saving the figure.
    names : list of str
        Condition columns to plot (not modified).
    """
    devs = frame.dev.unique()
    # squeeze=False keeps ``axis`` 2-D even with a single device or condition.
    fig, axis = plt.subplots(len(devs), len(names), sharex=True, squeeze=False)
    annotate = len(id_group) > 0
    if annotate:
        plt.suptitle('DF1 ' + str(np.mean(id_group[1]['DF1'])) + ' DF2' + str(
            np.mean(id_group[1]['DF2'])))
    for ff, f in enumerate(devs):
        frame_dev = frame[frame.dev == f]
        for nn, n in enumerate(names):
            for i in range(len(frame_dev[n])):
                trace = frame_dev[n].iloc[i]
                axis[ff, nn].plot(trace)
                if annotate:
                    # noverlap must be an integer: mlab uses it as a slice step.
                    p, freq = ml.psd(trace - np.mean(trace),
                                     Fs=40000,
                                     NFFT=8000, noverlap=8000 // 2)
                    max_f = freq[np.argmax(p[freq < np.mean(id_group[1]['eodf'])])]
                    axis[ff, nn].set_title(f + ' ' + n + ' ' + str(max_f))
    plt.subplots_adjust(hspace=0.6, wspace=0.6)
    save_visualization(show=False)
    if show:
        plt.show()
def divergence_title_add_on(group_mean, fr, autodefine):
    """Build a title suffix describing how far a frequency pair lies from ``fr``.

    Parameters
    ----------
    group_mean : tuple
        ``(key, DataFrame)``; the frame needs a ``'DF1, DF2'`` column of pairs
        (only the first row is used) and an ``EODf`` column.
    fr : float
        Reference frequency in Hz.
    autodefine : str
        Returns ``''`` unless it contains ``'triangle'``. Then ``'df1'``,
        ``'df2'`` or ``'diagonal'`` choose which of |DF1|, |DF2| or
        |DF1| + |DF2| is compared with ``fr``. If none matches, the divergence
        is shown as an empty string.

    Returns
    -------
    str
        Title suffix, or ``''``.
    """
    if 'triangle' not in autodefine:
        return ''
    if 'df1' in autodefine:
        try:
            divergence = np.abs(np.abs(group_mean[1]['DF1, DF2'].iloc[0][0]) - fr)
        except:
            print('df1 divergence problems')
            embed()
    elif 'df2' in autodefine:
        divergence = np.abs(np.abs(group_mean[1]['DF1, DF2'].iloc[0][1]) - fr)
    elif 'diagonal' in autodefine:
        try:
            first_pair = group_mean[1]['DF1, DF2'].iloc[0]
            summed = np.abs(np.abs(first_pair[0]) + np.abs(first_pair[1]))
            divergence = np.abs(summed - fr)
        except:
            print('diagonal divergence problems')
            embed()
    else:
        divergence = ''
    fr_m = np.round(np.mean(fr / group_mean[1].EODf + 1), 2)
    pieces = ['\n fr ', str(fr), ' Hz ', ' fr_m ', str(fr_m), ' Hz ',
              'diverge from Fr by ', str(divergence), ' Hz ', autodefine]
    return ''.join(pieces)
def find_env(way, results_diff, position_diff, sampling, f0='f0'):
    """Return the number of samples spanning ``int(way[-1])`` periods of a beat.

    With ``beat1 = |f1 - f0|`` and ``beat2 = |f2 - f0|`` (taken at row
    ``position_diff`` of ``results_diff``), the reference frequency is:

    * ``'mult_minimum'``: the smaller of the two beats,
    * ``'mult_env'``: ``|beat1 - beat2|``, or the smaller beat if that is 0,
    * ``'mult_f1'`` / ``'mult_f2'``: ``beat1`` / ``beat2``,
    * anything else: ``|f1 - f2|``.

    The last character of ``way`` must be a digit giving the number of periods.
    The result is truncated to ``int``.
    """
    beat1 = np.abs(results_diff.loc[position_diff, 'f1'] - results_diff.loc[position_diff, f0])
    beat2 = np.abs(results_diff.loc[position_diff, 'f2'] - results_diff.loc[position_diff, f0])

    def smallest_beat():
        return np.min([np.min(beat1), np.min(beat2)])

    if 'mult_minimum' in way:
        env_f = smallest_beat()
    elif 'mult_env' in way:
        env_f = np.abs(beat1 - beat2)
        if env_f == 0:
            env_f = smallest_beat()
    elif 'mult_f1' in way:
        env_f = beat1
    elif 'mult_f2' in way:
        env_f = beat2
    elif 'f1' in results_diff.keys():
        env_f = np.abs(results_diff.loc[position_diff, 'f1'] - results_diff.loc[position_diff, 'f2'])
    samples = (1 / env_f) * int(way[-1]) * sampling
    return int(samples)
def check_nix_fish(b, with_fish2='with_fish2'):
    """Report whether block ``b`` contains a usable three-fish multi tag.

    Only multi tags whose name contains ``'Three'`` are considered. If
    ``with_fish2`` is truthy, such a tag counts only when one of its features
    whose data name does not contain ``'id'`` has ``'fish2'`` in its name.
    If ``with_fish2`` is falsy, any ``'Three'`` tag counts.

    Returns
    -------
    bool
        ``True`` once a qualifying tag was seen, otherwise ``False``.
    """
    found = False
    for tag in b.multi_tags:
        tag_name = tag.name
        if found or 'Three' not in tag_name:
            continue
        if not with_fish2:
            found = True
            continue
        for feature in tag.features:
            feature_name = feature.data.name
            if not found and 'id' not in feature_name:
                found = 'fish2' in feature_name
    return found
def load_data_arrays(extract, mt_group, sampling_rate, sorted_on, b, mt, mt_nr, delay, printing=False):
    """Load the EOD/stimulus traces and spikes of one multi-tag trial.

    Parameters
    ----------
    extract :
        Passed through to ``extract_am``.
    mt_group : tuple
        ``(key, DataFrame)`` group; ``mt_group[1].eodf[mt_nr]`` is used as the
        EOD frequency for AM extraction.
    sampling_rate : float
        Passed to ``extract_am``.
    sorted_on : str
        Name of the trace later used for sorting. Besides 'LocalReconst' only
        the reconstruction / LocalEOD variants whose names occur in it are built.
    b :
        Recording block the arrays are read from (presumably a nix block — confirm).
    mt :
        Multi tag providing ``positions`` and ``extents``.
    mt_nr : int
        Index of the trial within ``mt``.
    delay : float
        Extra time loaded before and after the trial; ``time_eod`` starts at
        ``-delay`` on a fixed 40 kHz time axis (so presumably seconds).
    printing : bool
        Print timing information.

    Returns
    -------
    dict
        'Global', 'EField', 'spikes_mt', 'time_eod', every
        'LocalReconst{nr}{norm}' and '...Am' variant, 'LocalEOD' and
        'LocalEOD{norm}[Am]'. Variants that are not needed hold ``[]``.
    """
    array_eod = {}
    ############################################
    t1 = time.time()
    # global + Efield
    # needed because they are saved along with the data as the stimulus
    eod_globalEfield, sampling = link_arrays_eod(b, mt.positions[:][mt_nr] - delay,
                                                 mt.extents[:][mt_nr] + delay, array_name='GlobalEFieldStimulus')
    eod_global, sampling = link_arrays_eod(b, mt.positions[:][mt_nr] - delay,
                                           mt.extents[:][mt_nr] + delay, array_name='EOD')
    array_eod['Global'] = eod_global
    array_eod['EField'] = eod_globalEfield
    if printing:
        print('second0 ' + str(time.time() - t1))
    #######################################################
    # always needed, e.g. for plotting
    # 'LocalReconst', 'Global', 'EField'
    t1 = time.time()
    # time axis at a fixed 40 kHz, starting at -delay
    time_eod = np.arange(0, len(eod_global) / 40000, 1 / 40000) - delay
    spikes_mt = link_arrays_spikes(b, first=mt.positions[:][mt_nr] - delay, second=mt.extents[:][mt_nr] + delay,
                                   minus_spikes=mt.positions[:][mt_nr])
    array_eod['spikes_mt'] = spikes_mt
    array_eod['time_eod'] = time_eod
    if printing:
        print('second1 ' + str(time.time() - t1))
    t1 = time.time()
    # '' = plain reconstruction; '0.2'/'0.4' = E-field normalized to that contrast
    nrs = ['', '0.2', '0.4']
    norms = ['', 'Norm']
    # the variants needed at minimum: the plain reconstruction plus whatever sorted_on asks for
    sorted_ons = ['LocalReconst', sorted_on]
    for sorted_on_here in sorted_ons:
        for norm in norms:
            if norm == 'Norm':
                norm_ex = True
            else:
                norm_ex = False
            for nr in nrs:
                # also build the contrast-scaled variants used for sorting
                name = 'LocalReconst' + str(nr) + norm
                name_am = 'LocalReconst' + str(nr) + norm + 'Am'
                if nr != '':
                    if name in sorted_on_here:
                        t1 = time.time()
                        eod_global_norm = zenter_and_normalize(eod_global, 1)
                        eod_globalEfield_norm02 = zenter_and_normalize(eod_globalEfield, float(nr))
                        eod_local = cut_ends(eod_global_norm, eod_globalEfield_norm02)
                        if printing:
                            print('second20 ' + str(time.time() - t1))
                    else:
                        eod_local = []
                else:
                    eod_local = cut_ends(eod_global, eod_globalEfield)
                ########################################
                # The AM is only extracted for the variants that are actually
                # requested, not for every local/global combination.
                # TODO: check that this cannot crash here
                if name_am in sorted_on_here:
                    if len(eod_local) > 0:
                        t1 = time.time()
                        eod_final_am, eod_final = extract_am(
                            eod_local,
                            array_eod['time_eod'],
                            sampling=sampling_rate,
                            eodf=mt_group[1].eodf[
                                mt_nr],
                            emb=False, norm=norm_ex,
                            extract=extract)
                        if printing:
                            print('second21 ' + str(time.time() - t1))
                        # keep only the AM when it is exactly the requested trace
                        if sorted_on_here == name_am:
                            eod_final = []
                        else:
                            eod_final_am = []
                    else:
                        eod_final_am = []
                        eod_final = []
                else:
                    if len(eod_local) > 0:
                        if norm_ex:
                            # NOTE(review): neither eod_final nor eod_final_am is
                            # reassigned here, so the values from the previous
                            # loop iteration are stored under ``name``/``name_am``
                            # below. Possibly unintentional — confirm.
                            pass
                        else:
                            eod_final = eod_local
                            eod_final_am = []
                    else:
                        eod_final_am = []
                        eod_final = []
                t1 = time.time()
                # only fills entries that are missing or still empty
                array_eod = update_array_matrix(array_eod, eod_final, name)
                array_eod = update_array_matrix(array_eod, eod_final_am, name_am)
                if printing:
                    print('second22 ' + str(time.time() - t1))
    if printing:
        print('second2 ' + str(time.time() - t1))
    t1 = time.time()
    # the recorded local EOD is only loaded when sorted_on refers to it
    if 'LocalEOD' in sorted_on:
        eod_local, sampling = link_arrays_eod(b, mt.positions[:][mt_nr] - delay,
                                              mt.extents[:][mt_nr] + delay, array_name='LocalEOD-1')
    else:
        eod_local = []
    if sorted_on == 'LocalEOD':
        array_eod['LocalEOD'] = eod_local
    else:
        array_eod['LocalEOD'] = []
    for norm in norms:
        if norm == 'Norm':
            norm_ex = True
        else:
            norm_ex = False
        if 'LocalEOD' + norm in sorted_on:
            eod_final_am, eod_final = extract_am(
                eod_local,
                array_eod['time_eod'],
                sampling=sampling_rate,
                eodf=mt_group[1].eodf[
                    mt_nr],
                emb=False, norm=norm_ex,
                extract=extract)
        if 'LocalEOD' + norm + 'Am' in sorted_on:
            array_eod['LocalEOD' + norm + 'Am'] = eod_final_am
        else:
            array_eod['LocalEOD' + norm + 'Am'] = []
        if 'LocalEOD' + norm in sorted_on:
            array_eod['LocalEOD' + norm] = eod_final
        else:
            array_eod['LocalEOD' + norm] = []
    if printing:
        print('second3 ' + str(time.time() - t1))
    return array_eod
def update_array_matrix(array_eod, eod_local_reconstruct_big_norm, name):
    """Store an array under ``name`` unless a non-empty entry already exists.

    Parameters
    ----------
    array_eod : dict
        Mapping that is updated in place.
    eod_local_reconstruct_big_norm : sized
        Value to store.
    name : hashable
        Key to fill.

    Returns
    -------
    dict
        The same ``array_eod`` object.
    """
    already_filled = name in array_eod.keys() and len(array_eod[name]) != 0
    if not already_filled:
        array_eod[name] = eod_local_reconstruct_big_norm
    return array_eod
def chirps_delete_analysis(eod_local_interp, eod_norm, fish_cuts, time_eod, cut, fish_number, fish_number_base,
                           chirp=None, mt_idx=None):
    """Replace the snippet names in which a chirp was detected by ``'interspace'``.

    Both traces are cut into per-snippet sequences with ``cut_eod_sequences``.
    ``eod_norm`` provides the instantaneous-frequency trace (``calc_power``
    with 512-sample windows at 40 kHz). ``eod_local_interp`` provides the
    amplitude criterion. A snippet is flagged when all of the following hold:

    * its frequency range exceeds ``chirp_size`` (35) and its amplitude range
      exceeds six standard deviations;
    * the frequency repeatedly exceeds ``median + chirp_size`` in consecutive
      windows;
    * within +-0.04 s of a frequency change, the amplitude minimum lies more
      than 0.25 below the median.

    Parameters
    ----------
    eod_local_interp, eod_norm : array_like
        Traces used for the amplitude and frequency criteria.
    fish_cuts, time_eod, cut, fish_number_base :
        Passed to ``cut_eod_sequences``.
    fish_number : list
        Snippet names; a copy is returned with chirping snippets renamed.
    chirp : dict, optional
        Receives, per snippet name, the ``mt_idx`` values in which a chirp was
        detected. This name used to be referenced without being defined, so
        any detection raised ``NameError``. Pass a dict to keep the record.
    mt_idx : optional
        Identifier of the current multi tag, stored in ``chirp``.

    Returns
    -------
    list
        Copy of ``fish_number`` with chirping snippets replaced by ``'interspace'``.
    """
    if chirp is None:
        chirp = {}
    eods, _ = cut_eod_sequences(eod_norm, fish_cuts, time_eod, cut=cut,
                                rec=False, fish_number=fish_number,
                                fillup=False,
                                fish_number_base=fish_number_base)
    eods_int, _ = cut_eod_sequences(eod_local_interp, fish_cuts, time_eod,
                                    cut=cut, rec=False,
                                    fish_number=fish_number,
                                    fillup=False,
                                    fish_number_base=fish_number_base)
    keys = [k for k in eods]
    # ``* 1`` makes a shallow copy so the caller's list is not modified
    fish_number_final = fish_number * 1
    for e in range(len(eods)):
        if len(eods[keys[e]]) > 0:
            try:
                freq, freq1, freq2, freq3, freq4 = calc_power(
                    eods[keys[e]], nfft=2 ** 9,
                    sampling_rate=40000,
                    shift_by=0.001)
            except:
                print('freq problem')
                embed()
            eods_time = np.arange(0, len(eods[keys[e]]) / 40000, 1 / 40000)
            # centre times of the frequency windows
            time_freq = np.linspace(eods_time[0] + 2 ** 9 / (40000 * 2),
                                    eods_time[-1] - 2 ** 9 / (40000 * 2),
                                    len(freq))
            detection = 'No_Chirp_detected'
            time_detected = []
            chirp_size = 35
            random_data_std = np.std(eods_int[keys[e]])
            random_data_mean = np.mean(eods_int[keys[e]])
            anomaly_cut_off = random_data_std * 3
            lower_limit = random_data_mean - anomaly_cut_off
            upper_limit = random_data_mean + anomaly_cut_off
            range_exeed = upper_limit - lower_limit
            if (np.ptp(freq) > chirp_size) & (
                    np.ptp(eods_int[keys[e]]) > range_exeed):  # > 0.3
                perc95 = np.median(freq) + chirp_size
                pos = np.diff(np.where(freq > perc95))
                if 1 in pos:
                    if 1 in np.diff(np.where(pos == 1)):
                        lim_w = 0.04
                        lower_window = time_freq[
                                           np.where(np.diff(freq))] - lim_w
                        upper_window = time_freq[
                                           np.where(np.diff(freq))] + lim_w
                        time_detected = []
                        for nr_w in range(len(lower_window)):
                            eods_cut = eods_int[keys[e]][
                                (eods_time > lower_window[nr_w]) & (
                                        eods_time < upper_window[nr_w])]
                            diverge = np.median(eods_int[keys[e]]) - np.min(
                                eods_cut)
                            if (diverge > 0.25) & (
                                    detection != 'Chirp_detected'):  #
                                time_detected.append(
                                    lower_window[nr_w] + lim_w)
                                if keys[e] not in chirp.keys():
                                    chirp[keys[e]] = [mt_idx]
                                else:
                                    chirp[keys[e]].append(mt_idx)
                                fish_number_final[fish_number_final.index(
                                    keys[e])] = 'interspace'
                                if len(np.unique(fish_number_final)) == 1:
                                    if np.unique(fish_number_final)[
                                        0] == 'interspace':
                                        print('to many chirps')
                                detection = 'Chirp_detected'
                            elif np.ptp(eods_cut) > range_exeed:  # > 0.55
                                time_detected.append(
                                    lower_window[nr_w] + lim_w)
            test = False
            if test:
                from utils_test import check_chirp_del_directly
                check_chirp_del_directly(time_detected, eods_time, eods,
                                         range_exeed, eods_int, keys, e,
                                         freq,
                                         detection, time_freq, freq1, freq2,
                                         freq3,
                                         freq4)
    return fish_number_final
def cut_spikes_and_eod_three(mt_group, b, extract, cut_nr=0, chirps='', devname_orig=['05'], emb=False, test=False,
                             mean_type='', sorted_on='LocalReconst', sampling_rate=40000, devname=[],
                             done=False,
                             counter=0, printing=False, printing_all=False):
    """Cut spikes and EOD traces of all multi tags of a group into snippets.

    For every multi tag listed in ``mt_group[1]['mt']`` the durations, the EOD
    arrays (``load_data_arrays``) and the spikes are loaded. They are cut into
    the per-condition snippets given by ``fish_cuts``/``fish_number`` and
    appended to a DataFrame (one row per device/trace name and trial, see
    ``transform_dataframe``).

    Returns
    -------
    tuple
        ``(spikes_pure, fish_number_base, chirp, fish_cuts, time_array,
        fish_number_final, smoothened2, smoothed05, LocalEOD, eod_local_am,
        effective_duration, cut, devname, frame)``. If no frame was built,
        returns ``[[]] * 22`` instead.
        NOTE(review): the fallback has 22 entries while the normal return has
        14 — callers unpacking a fixed count may break; confirm.
    """
    # TODO: could be simplified so that only the important outputs are returned
    t1 = time.time()
    mt_list = mt_group[1]['mt']
    frame = []
    chirp = {}
    spikes_pure = []
    print('cut_spikes_and_eod_three is running')
    for mt_idx, mt_nr in enumerate(list(map(int, mt_list))):  # range(start_l, len(mt.positions[:]))
        features, mt, name_here, l = get_mt_features3(b, mt_group, mt_idx)
        # somehow we have mts with negative extend, we exclude these
        t0 = time.time()
        if (mt.extents[:][mt_nr] > 0).any():
            t1 = time.time()
            _, _, _, _, fish_number, fish_cuts, whole_duration, cont = load_durations(mt_nr, mt, mt_group[1], mt_idx,
                                                                                      mean_type=mean_type, emb=False)
            # the first cut may be negative (pre-stimulus part); its magnitude is used as padding
            delay = np.abs(fish_cuts[0])
            if printing:
                print('first ' + str(time.time() - t1))
            if cont:
                # embed()
                ########################################
                # load the according EOD arrays
                # basic traces and reconstructions
                t1 = time.time()
                array_eod = load_data_arrays(extract, mt_group, sampling_rate, sorted_on, b, mt, mt_nr, delay)
                if printing:  # TODO: this takes forever!
                    print('second ' + str(time.time() - t1))
                # the second half of this condition repeats the first
                if (len(array_eod['LocalReconst']) > 0) & (len(array_eod['spikes_mt']) > 0) & (
                        not ((len(array_eod['LocalReconst']) < 1) or (len(array_eod['spikes_mt']) < 1))):
                    ########################################
                    # extract the am of the loaded arrays
                    # phase sorting should not be based on these AMs but on the whole stimulus
                    # if 'PhaseSort' in mean_type:
                    eod_local_am = []
                    cut_edge = [cut_nr]  # 0.02
                    for cut in cut_edge:
                        # require a nearly complete recording and spikes inside the analysed range
                        if (len(array_eod['time_eod']) > 0) & (len(array_eod['spikes_mt']) > 0) & (
                                array_eod['time_eod'][-1] + delay > whole_duration * 0.9) & (
                                array_eod['spikes_mt'][-1] + delay > whole_duration * 0.6) & any_spikes(
                            array_eod['spikes_mt'],
                            minimal=fish_cuts[0] + cut,
                            maximal=fish_cuts[-1] - cut):
                            t1 = time.time()
                            fish_number_base = remove_interspace_fish_nr(fish_number)
                            if 'ChirpsDelete' in chirps:
                                # exclude the snippets in which the fish chirped
                                # NOTE(review): eod_local_recondstruct_am and
                                # eods_local_reconstruct_norm are not defined in this
                                # function, so this branch raises NameError — confirm.
                                fish_number_final = chirps_delete_analysis(eod_local_recondstruct_am,
                                                                           eods_local_reconstruct_norm, fish_cuts,
                                                                           array_eod['time_eod'], cut, fish_number,
                                                                           fish_number_base)
                            else:
                                fish_number_final = fish_number
                            devname, smoothened2, smoothed05, mat, time_array, arrays_calc, effective_duration, spikes_cut = cut_spikes_sequences(
                                delay, array_eod['spikes_mt'], sampling_rate, fish_cuts,
                                cut=cut, fish_number_base=fish_number_base,
                                fish_number=fish_number_final, devname_orig=devname_orig * 1,
                                mean_type=mean_type)
                            lengths = []
                            if printing:
                                print('Forth ' + str(time.time() - t1))
                            # number of entries per non-interspace snippet in the first computed array
                            for name in fish_number_final[::-1]:
                                if 'interspace' not in name:
                                    lengths.append(len(
                                        arrays_calc[0][name]))  # lengths.append(len(np.unique(arrays_calc[0][name])))
                            if np.min(lengths) > 2:
                                # these are the different EOD versions that could be used for sorting
                                # if 'PhaseSort' in mean_type:
                                # eod_arrays = [eod_global, eod_local_am, eods_local_norm, eod_local_recondstruct_am, eod_local_reconstruct, eods_local_reconstruct_norm, eod_local_reconstruct_big_am, eod_local_reconstruct_big_norm]
                                # names = ['global','local_am','local_norm','local_reconst_am','local_reconst_norm','local_reconst','local_reconst_big_am','local_reconst_big_norm']
                                # eod_arrays = [eod_local_reconstruct_norm_huge,eod_global,eod_local_reconstruct, eod_local_reconstruct_big_norm]
                                # names = ['local','global','local_reconst','local_reconst_big_norm']
                                # TODO: take only what sorted_on actually needs
                                # needed later for plotting; keeping the stimulus is always useful
                                # once reconstructed, once scaled to the correct contrast
                                # TODO: redo this properly, computing only what is needed
                                # these are the basic traces that are plotted later
                                ###############################################################
                                t1 = time.time()
                                names = ['LocalReconst', 'Global', 'EField']
                                eod_arrays = [array_eod['LocalReconst'], array_eod['Global'], array_eod['EField']]
                                ###############################################################
                                # and this one is for sorting: additionally add the trace we actually want
                                # TODO: make this systematic and more convenient
                                # eod_arrays_possible = [array_eod['Local'],
                                #                        array_eod['LocalReconst0.4Norm'],
                                #                        eod_local_recondstruct_am, eods_local_reconstruct_norm,
                                #                        eod_local_reconstruct_big_norm,
                                #                        eod_local_reconstruct_big_am]
                                # names_possible = ['Local',
                                #                   'LocalReconst0.4Norm',
                                #                   'LocalReconstAm', 'LocalReconst',
                                #                   'LocalReconst0.2Norm', 'LocalReconst0.2Am']
                                # where_pos = np.where(np.array(names_possible) == sorted_on)[0]
                                if sorted_on in array_eod.keys():  # len(where_pos) > 0:
                                    eod_arrays.append(array_eod[sorted_on])  # eod_arrays_possible[where_pos[0]])
                                    names.append(sorted_on)  # names_possible[where_pos[0]]
                                # cut each EOD trace into the same snippets and add it as a new "device"
                                for e, eod_array in enumerate(eod_arrays):
                                    try:
                                        eods_cut, _ = cut_eod_sequences(eod_array, fish_cuts,
                                                                        cut=cut, rec=False,
                                                                        fish_number=fish_number_final,
                                                                        fillup=True,
                                                                        fish_number_base=fish_number_base)
                                    except:
                                        print('eod problem0')
                                        embed()
                                    arrays_calc.append(eods_cut)
                                    time_array.append(array_eod['time_eod'])
                                    devname.append(names[e])
                                    if names[e] == 'Global':
                                        # NOTE(review): len(arrays_calc) - 1 is already the index of
                                        # the just-appended array; adding ``e`` looks off — confirm.
                                        idx = len(arrays_calc) - 1 + e
                                if printing:
                                    print('Fifth ' + str(time.time() - t1))
                                t1 = time.time()
                                names_synch = ['EodLocSynch', 'EodAmSynch']
                                if 'Synch' in sorted_on:
                                    #####################
                                    # synthesize the stimulus from the global and the ideal stimulus
                                    # this works reasonably well
                                    # this is the same as the global one and the E-field
                                    eods_loc_synch, eods_am_synch = synthetise_eod(mt_nr, extract, sampling_rate,
                                                                                   sampling_rate, mt_idx, idx,
                                                                                   arrays_calc,
                                                                                   mt_group)
                                    eod_arrays = [eods_loc_synch, eods_am_synch]
                                    # names = ['eod_loc_synch', 'eod_am_synch']
                                    for e, eod_array in enumerate(eod_arrays):
                                        # NOTE(review): appends ``eods_cut`` (the last cut EOD
                                        # from the loop above), not ``eod_array`` — looks
                                        # unintentional; confirm.
                                        arrays_calc.append(eods_cut)
                                        time_array.append(array_eod['time_eod'])
                                        devname.append(names_synch[e])
                                        if test:  # eod_local_am, eods_local_norm,'local_am','local_norm',
                                            from utils_test import test_EOD_arrays
                                            arrays_calc, devname, eods_cut, idx, time_array = test_EOD_arrays(cut, e,
                                                                                                              eods_cut,
                                                                                                              eods_loc_synch,
                                                                                                              fish_cuts,
                                                                                                              fish_number_base,
                                                                                                              fish_number_final,
                                                                                                              idx)
                                else:
                                    array_eod[names_synch[0]] = []
                                    array_eod[names_synch[1]] = []
                                if printing:
                                    print('six ' + str(time.time() - t1))
                                ##################################
                                # store in a DataFrame
                                t1 = time.time()
                                frame, spikes_cut, spikes_pure, done = transform_dataframe(frame, spikes_pure, done,
                                                                                           arrays_calc, devname,
                                                                                           spikes_cut)
                                if printing:
                                    print('seventh ' + str(time.time() - t1))
                                counter += 1
                                test = False
                                if test:
                                    from utils_test import compare_chirp_nfft, compare_chirp_nfft_traces
                                    compare_chirp_nfft()  # time_eod, eod_local
                                    compare_chirp_nfft_traces()  # chirp, eod_local, time_eod, time_array[0],smoothed05, fish_cuts
            else:
                print('negative mt')
            if done == False:
                devname = []
                frame = []
        if printing_all:
            print('all ' + str(mt_idx) + ' ' + str(time.time() - t0))
    if emb:
        embed()
    if test == True:
        from utils_test import test_eod_arrays2
        test_eod_arrays2(frame)
    if printing:
        print('a_all ' + str(time.time() - t1))
    if test == True:
        names = []
        for f in fish_number_base:
            if not 'interspace' in f:
                names.append(f)
        overview_of_mt_group(frame, names=names)
    if len(frame) < 1:
        print('devname to short!')
        return [[]] * 22
    else:
        return spikes_pure, fish_number_base, chirp, fish_cuts, time_array, fish_number_final, smoothened2, smoothed05, \
            array_eod['LocalEOD'], eod_local_am, effective_duration, cut, devname, frame
def get_mt_features3(b, mt_group, mt_idx=-1):
    """Look up one multi tag of a group and extract its features.

    Parameters
    ----------
    b :
        Block whose ``multi_tags`` are indexed by name.
    mt_group : tuple
        ``(key, DataFrame)`` with ``'mt_names'`` and ``'mt'`` columns.
    mt_idx : int
        Positional row of the multi tag name in the group (default: last).

    Returns
    -------
    tuple
        ``(features, mt, name_here, l)`` where ``features`` comes from
        ``feature_extract`` and ``l`` is the last entry of the ``'mt'`` column.
    """
    group_frame = mt_group[1]
    name_here = group_frame['mt_names'].iloc[mt_idx]
    mt = b.multi_tags[name_here]
    features, _ = feature_extract(mt)
    last_mt = group_frame['mt'].iloc[-1]
    return features, mt, name_here, last_mt
def transform_dataframe(frame, spikes_pure, done, arrays_calc, devname, spikes_cut):
    """Append the arrays and spikes of one multi tag to the running DataFrames.

    On the first call (``done == False``) new DataFrames are created.
    Afterwards the new rows are concatenated to ``frame`` and ``spikes_pure``.
    The ``'dev'`` column receives ``devname``.

    Returns
    -------
    tuple
        ``(frame, spikes_cut, spikes_pure, done)``. ``spikes_cut`` is returned
        unchanged on the first call and as a DataFrame afterwards; ``done`` is
        ``True`` after the first call.
    """
    new_rows = pd.DataFrame(arrays_calc)
    new_rows['dev'] = devname
    # compare with ``== False`` on purpose to keep the original semantics
    first_call = done == False
    if first_call:
        return new_rows, spikes_cut, pd.DataFrame(spikes_cut), True
    spikes_cut = pd.DataFrame(spikes_cut)
    return pd.concat([frame, new_rows]), spikes_cut, pd.concat([spikes_pure, spikes_cut]), done
def overview_of_mt_group(frame, names=['012', 'control_01', 'control_02', 'base_0']):
    """Show one figure per trial: devices as rows, conditions as columns.

    The number of trials is ``len(frame) / number of devices``. Each figure is
    stored via ``save_visualization`` and then shown.

    Parameters
    ----------
    frame : pandas.DataFrame
        Rows with a ``'dev'`` column and one column per name in ``names``.
    names : list of str
        Condition columns to plot (not modified).
    """
    devs = frame.dev.unique()
    trial_nr = len(frame) / len(devs)
    for i in range(int(trial_nr)):
        # squeeze=False keeps ``ax`` 2-D; with a single device or condition
        # ``plt.subplots`` otherwise returns a 1-D array and ``ax[dd, nn]`` fails.
        fig, ax = plt.subplots(len(devs), len(names), sharex=True, squeeze=False)
        for nn, name in enumerate(names):
            for dd, dev in enumerate(devs):
                dev10 = frame[frame.dev == dev]
                ax[dd, nn].plot(dev10[name].iloc[i])
                ax[dd, nn].set_title(name)
        save_visualization()
        plt.show()
def calc_power(arr, nfft=2 ** 17, sampling_rate=40000, time_show=False, shift_by=0.01):
    """Track the dominant frequency of ``arr`` in sliding windows.

    Windows of ``nfft`` samples start every ``shift_by * sampling_rate``
    samples. Starts are taken from ``np.arange(0, len(arr) - nfft, step)``, so
    the stop value itself is excluded. Each window is analysed with
    ``get_mult_freqs``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(freq, freq1, freq2, freq3, freq4)``: per window, the spectral peak
        frequency and the four harmonic-based estimates.
    """
    started = time.time()
    window_starts = np.arange(0, len(arr) - nfft, shift_by * sampling_rate)
    n_windows = len(window_starts)
    estimates = [np.zeros(n_windows) for _ in range(5)]
    for idx, start in enumerate(window_starts):
        segment = arr[int(start):int(start + nfft)]
        _, peak, h1, h2, h3, h4 = get_mult_freqs(segment, sampling_rate, nfft)
        for target, value in zip(estimates, (peak, h1, h2, h3, h4)):
            target[idx] = value
    if time_show:
        print('calc power' + str(time.time() - started))
    freq, freq1, freq2, freq3, freq4 = estimates
    return freq, freq1, freq2, freq3, freq4
def get_mult_freqs(arr, sampling_rate, nfft, ):
    """Compute the PSD of ``arr`` and derive frequency estimates from it.

    The mean is removed before ``ml.psd`` (``NFFT=nfft``, 50 % overlap).

    Returns
    -------
    tuple
        ``(power, peak_freq, freq1, freq2, freq3, freq4)``: the PSD, the
        frequency of its maximum, and the estimates from ``find_mult_freqs``.
    """
    centered = arr - np.mean(arr)
    power, freqs = ml.psd(centered, Fs=sampling_rate, NFFT=nfft, noverlap=nfft // 2)
    peak_freq = freqs[np.argmax(power)]
    freq1, freq2, freq3, freq4 = find_mult_freqs(power, peak_freq, freqs)
    return power, peak_freq, freq1, freq2, freq3, freq4
def find_mult_freqs(p, freq, f):
    """Estimate the fundamental frequency from the 2nd to 5th harmonic peaks.

    For harmonic ``k`` (2, 3, 4, 5), the largest value of ``p`` inside the open
    window ``(k - 0.2) * freq < f < (k + 0.2) * freq`` is located and its
    frequency is divided by ``k``.

    Parameters
    ----------
    p : array_like
        Power spectrum values.
    freq : float
        Frequency of the main spectral peak.
    f : array_like
        Frequencies belonging to ``p`` (same length).

    Returns
    -------
    tuple of float
        ``(freq1, freq2, freq3, freq4)`` from the 2nd, 3rd, 4th and 5th
        harmonic. An estimate is ``nan`` when its window contains no frequency
        bin, e.g. above the highest frequency in ``f``. Previously ``argmax``
        on the empty selection raised ``ValueError``.
    """
    p = np.asarray(p)
    f = np.asarray(f)
    # (harmonic number, lower factor, upper factor): same bounds as before
    windows = ((2, 1.8, 2.2), (3, 2.8, 3.2), (4, 3.8, 4.2), (5, 4.8, 5.2))
    estimates = []
    for harmonic, low, high in windows:
        in_window = (f > freq * low) & (f < freq * high)
        if not np.any(in_window):
            estimates.append(np.nan)
            continue
        peak = f[in_window][np.argmax(p[in_window])]
        estimates.append(peak / harmonic)
    return tuple(estimates)
def find_corr_time(corr1):
    """Build the lag axis of a full cross-correlation and mask its second half.

    Parameters
    ----------
    corr1 : numpy.ndarray
        Output of a full-mode correlation.

    Returns
    -------
    tuple
        ``(lags, masked, indices)``: ``lags`` are the sample indices relative
        to the centre ``len / 2``, divided by 40000 (seconds at 40 kHz);
        ``masked`` is a copy of ``corr1`` with entries at indices above
        ``len / 2`` set to 0; ``indices`` is ``0 .. len - 1``.
    """
    n_samples = len(corr1)
    indices = np.arange(0, n_samples, 1)
    masked = corr1 * 1  # copy so the caller's array stays untouched
    masked[indices > n_samples / 2] = 0
    lags = (indices - n_samples / 2) / 40000
    return lags, masked, indices
def plt_frame_traces_original(frame_dev_eod, n):
    """Plot every trial of column ``n`` in its own row ('Original Traces').

    The figure is stored via ``save_visualization`` and then shown.
    """
    traces = frame_dev_eod[n]
    # squeeze=False keeps ``ax`` indexable even when there is only one trial
    # (``plt.subplots(1, 1)`` would otherwise return a single Axes).
    fig, ax = plt.subplots(len(traces), 1, sharex=True, squeeze=False)
    plt.title('Original Traces')
    for i in range(len(traces)):
        ax[i, 0].plot(traces.iloc[i])
    save_visualization()
    plt.show()
def plt_delays_pair(i, inputs, outputs, titles, n, autocorr1, frame_dev_eod, shifted_eod, i_nr, delay,
                    corr_time_neg_zent, corr1):
    """Diagnostic figure for aligning trial ``i + 1`` to trial ``i_nr``.

    Panels:

    1. autocorrelation of ``inputs`` vs. the initial cross-correlation,
    2. autocorrelation vs. the cross-correlation after removing ``delay``
       samples of lag,
    3. both EOD traces before the shift,
    4. the reference trace and ``shifted_eod``.

    Traces are assumed to be sampled at 40 kHz (time axes use 40000).
    """
    grid1 = gridspec.GridSpec(4, 1, hspace=0.8,
                              wspace=1.2)  #
    test = False
    if test == True:
        from utils_test import plot_crosscorrelations
        plot_crosscorrelations()
    plt.subplot(grid1[0])
    plt.title(titles + ' ' + n)
    corr_time_neg_zent_a, corr_time_negative, corr_time = find_corr_time(
        autocorr1)
    plt.plot(corr_time_neg_zent_a, autocorr1, label='autocorr',
             color='blue')
    plt.plot(corr_time_neg_zent, corr1, label='corr initial',
             color='red')
    plt.axvline(x=0, color='grey', label='mitte')
    plt.xlim(-0.1, 0.1)
    plt.subplot(grid1[1])
    plt.plot(corr_time_neg_zent_a, autocorr1, label='autocorr',
             color='blue')
    # Slice to ``len - delay`` rather than ``[0:-delay]``: with delay == 0 the
    # latter is an empty slice and the correlation would be empty.
    corr_test1 = scipy.signal.correlate(inputs[:len(inputs) - delay], outputs[delay:])
    corr_time_neg_zent_1, corr_time_negative, corr_time = find_corr_time(
        corr_test1)
    plt.plot(corr_time_neg_zent_1, corr_test1, label='corr after',
             color='red')
    plt.axvline(x=0, color='grey', label='mitte')
    plt.legend(loc=(0, 4), ncol=2)
    plt.xlim(-0.1, 0.1)
    plt.subplot(grid1[2])
    plt.title('eod before')
    plt.plot(
        np.arange(0, len(frame_dev_eod[n].iloc[i_nr]) / 40000, 1 / 40000),
        frame_dev_eod[n].iloc[i_nr], label='i')
    plt.plot(np.arange(0, len(frame_dev_eod[n].iloc[i + 1]) / 40000,
                       1 / 40000),
             frame_dev_eod[n].iloc[i + 1],
             label='i+1', color='red')
    plt.xlim(0, 0.3)
    plt.subplot(grid1[3])
    plt.title('eod after')
    plt.plot(
        np.arange(0, len(frame_dev_eod[n].iloc[i_nr]) / 40000, 1 / 40000),
        frame_dev_eod[n].iloc[i_nr], label='i')
    plt.plot(np.arange(0, len(shifted_eod) / 40000, 1 / 40000),
             shifted_eod, label='after2',
             color='red')
    plt.legend()
    save_visualization()
    plt.show()
def plt_shifted_input(inputs, delay, outputs):
    """Plot ``inputs`` and ``outputs`` after aligning them by ``delay`` samples.

    Rows: the cut ``inputs``, the cut ``outputs``, and both overlaid. The figure
    is saved and shown.
    """
    # ``inputs[0:-delay]`` would be empty for delay == 0; cut with an explicit end index
    inputs_cut = inputs[0:max(len(inputs) - delay, 0)]
    outputs_cut = outputs[delay::]
    plt.subplot(3, 1, 1)
    plt.plot(inputs_cut)
    plt.subplot(3, 1, 2)
    plt.plot(outputs_cut)
    plt.subplot(3, 1, 3)
    plt.plot(inputs_cut)
    plt.plot(outputs_cut)
    save_visualization()
    plt.show()
def find_delays_length(frame_dev_eod, n, name_orig, i, mean_type, delays_length, i_nr=0):
    """Estimate the lag between two trials by cross-correlation and store the aligned index range.

    The reference trial is row ``i_nr`` and the compared trial is row ``i + 1``
    of column ``name_orig``. Both are mean-subtracted and cross-correlated. The
    distance of the correlation maximum from the centre (``delay``, in samples)
    is used to append ``np.arange(len(outputs))[delay:]`` to ``delays_length[n]``.

    If ``mean_type`` contains ``'Min<value>sExcluded_'``, correlation entries with
    ``corr_time < len/2 - value * 40000`` are zeroed before taking the maximum.

    Parameters
    ----------
    frame_dev_eod : pandas.DataFrame
        One trial per row; columns hold the traces.
    n : hashable
        Key in ``delays_length`` to append to.
    name_orig : hashable
        Column the traces are read from.
    i : int
        Row ``i + 1`` is compared against the reference row.
    mean_type : str
        Analysis flags (see above).
    delays_length : dict
        Mapping ``n -> list`` that is appended to in place.
    i_nr : int
        Row of the reference trial.

    Returns
    -------
    delays_length : dict
        The same dict with one new entry in ``delays_length[n]``: an index
        array, or ``[]`` when the compared trial is empty.
    i_nr : int
        Incremented by one only when the reference trial is empty, so the next
        call uses the following row as reference.
    """
    reference = frame_dev_eod[name_orig].iloc[i_nr]
    inputs = reference - np.nanmean(reference)
    # np.size works for arrays, lists and scalars. The former ``inputs != []``
    # compared elementwise, which raises on shape mismatch in NumPy >= 1.25.
    if np.size(inputs) > 0:
        following = frame_dev_eod[name_orig].iloc[i + 1]
        outputs = following - np.nanmean(following)
        titles = 'eod'
        if np.size(outputs) > 0:
            try:
                autocorr1 = scipy.signal.correlate(inputs,
                                                   inputs)
                corr1 = scipy.signal.correlate(inputs, outputs)
            except Exception:
                print('corr1 in utils function')
                embed()
            corr_time_neg_zent, corr_time_negative, corr_time = find_corr_time(
                corr1)
            if 'Min' in mean_type:
                minimum = float(mean_type.split('Min')[1].split('sExcluded_')[0])
                # suppress lags earlier than ``minimum`` seconds (40 kHz sampling)
                corr_time_negative[corr_time < len(corr_time_negative) / 2 - minimum * 40000] = 0
            delay = np.abs(np.argmax(corr_time_negative) - int(len(corr_time_negative) / 2))
            shifted_eod = frame_dev_eod[name_orig].iloc[i + 1][delay::]
            array_length = np.arange(0, len(outputs), 1)
            delays_length[n].append(array_length[delay::])
            plot = False
            if plot:
                # raw traces
                plt_frame_traces_original(frame_dev_eod, name_orig)
                # cross-correlation before/after alignment
                plt_delays_pair(i, inputs, outputs, titles, name_orig, autocorr1, frame_dev_eod, shifted_eod, i_nr,
                                delay,
                                corr_time_neg_zent, corr1)
                # shifted input
                plt_shifted_input(inputs, delay, outputs)
        else:
            delays_length[n].append([])
    else:
        array_length = np.arange(0, len(frame_dev_eod[n].iloc[i_nr + 1]), 1)
        delays_length[n].append(array_length)
        i_nr += 1
    return delays_length, i_nr
def create_arrays(df1, i, j, sampling_rate=10000, duration=30):
    """Build the sum of two unit-amplitude sine waves at frequencies ``df1[i]`` and ``df1[j]``.

    Parameters
    ----------
    df1 : sequence of float
        Frequencies in Hz.
    i, j : int
        Indices into ``df1`` of the two sine waves.
    sampling_rate : float
        Samples per second of the time axis.
    duration : float
        Length of the time axis in seconds (previously fixed at 30).

    Returns
    -------
    stimulus : np.ndarray
        ``eod_fish_e + eod_fish_r``.
    eod_fish_e : np.ndarray
        Sine at ``df1[j]``.
    eod_fish_r : np.ndarray
        Sine at ``df1[i]``.
    time : np.ndarray
        ``np.arange(0, duration, 1 / sampling_rate)``.
    """
    time = np.arange(0, duration, 1 / sampling_rate)
    eod_fish_r = np.sin(time * 2 * np.pi * df1[i])
    eod_fish_e = np.sin(time * 2 * np.pi * df1[j])
    stimulus = eod_fish_e + eod_fish_r
    return stimulus, eod_fish_e, eod_fish_r, time
def exclude_ratios(f3, df, diff_max, integers, i, j, ratio_f, f_max, f_max2, diff_mean, diff_min, bigger, ratio, df1):
    """Overwrite the period estimates of pair ``(i, j)`` in degenerate cases.

    All estimate matrices (``f_max``, ``df``, ``f3``, ``f_max2``, ``diff_mean``,
    ``diff_min``, ``diff_max``) are modified in place at ``[i, j]``:

    * integer period/frequency ratio (``ratio`` or ``ratio_f`` whole) -> ``1 / bigger``;
    * identical frequencies (``df1[i] == df1[j]``) -> ``df1[j]``. This takes
      precedence because it is applied last.

    ``integers[i, j]`` is always reset to False.

    Returns
    -------
    tuple
        ``(integers, f_max, diff_max, diff_min, diff_mean, f_max2)``, the same
        array objects that were passed in.
    """
    estimates = (f_max, df, f3, f_max2, diff_mean, diff_min, diff_max)
    integers[i, j] = False
    whole_ratio = (ratio % 1 == 0) or (ratio_f % 1 == 0)
    if whole_ratio:
        for estimate in estimates:
            estimate[i, j] = 1 / bigger
    if df1[i] == df1[j]:
        for estimate in estimates:
            estimate[i, j] = df1[j]
    return integers, f_max, diff_max, diff_min, diff_mean, f_max2
def do_splits(period_cut, sampling_rate, stimulus, length=0.4):
    """Cut four snippets out of ``stimulus`` at the first cut times in ``period_cut``.

    Parameters
    ----------
    period_cut : np.ndarray
        Cut times in seconds; they are converted to sample indices with
        ``sampling_rate``.
    sampling_rate : float
        Samples per second of ``stimulus``.
    stimulus : np.ndarray
        Signal to cut.
    length : float or 'no'
        Snippet length in seconds. With ``'no'`` each snippet runs up to the
        next cut time, which needs five cut times.

    Returns
    -------
    stim0, stim1, stim2, stim3 : np.ndarray
        The four snippets.
    splits : np.ndarray
        The cut times in samples.
    """
    splits = period_cut * sampling_rate
    if length == 'no':
        pieces = [stimulus[int(splits[k]):int(splits[k + 1])] for k in range(4)]
    else:
        window = length * sampling_rate
        pieces = [stimulus[int(splits[k]):int(splits[k] + window)] for k in range(4)]
    stim0, stim1, stim2, stim3 = pieces
    return stim0, stim1, stim2, stim3, splits
def calc_dist(stim0, stim1):
    """Mean absolute difference of two signals after truncating both to the shorter length.

    Returns
    -------
    dist : float
        ``mean(sqrt((a - b) ** 2))`` over the common samples.
    stim01, stim02 : np.ndarray
        The two truncated signals that were compared. A signal that did not
        need truncation is returned as a copy.
    """
    first = stim0 * 1
    second = stim1 * 1
    len0, len1 = len(stim0), len(stim1)
    if len0 != len1:
        shortest = min(len0, len1)
        if len0 > shortest:
            first = stim0[0:shortest]
        else:
            second = stim1[0:shortest]
    dist = np.mean(np.sqrt((first - second) ** 2))
    return dist, first, second
def get_different_periods(df1, df2):
    """Compare candidate periodicities of every frequency pair ``(df1[i], df2[j])``.

    For each pair the larger period ``bigger`` and the ratio ``ratio`` of the two
    periods are computed. The fractional part of ``k * ratio`` (k = 0..3999)
    marks the times ``k * bigger`` at which both periods nearly re-align (within
    ``limit`` of 0 or 1); these times form ``period_cut``. Several frequency
    estimates of the re-alignment are stored:

    * ``f_max``: PSD peak of the fractional parts;
    * ``f_max2``: PSD peak of the re-alignment indicator;
    * ``diff_mean`` / ``diff_min`` / ``diff_max``: inverse mean/min/max spacing of ``period_cut``;
    * ``f3``: ``1 / |1/f_a + 1/f_b|``;
    * ``df``: ``|f_a - f_b|``.

    For each estimate, ``find_dist_pure`` measures how different two snippets of
    the summed sine stimulus are one period apart (``dist_*``).

    Pairs with an infinite period (zero frequency) get NaN in the frequency
    matrices; their ``dist_*`` and ``var`` entries stay 0.

    NOTE(review): ``exclude_ratios`` still receives ``df1`` for both indices,
    so results are only fully consistent when ``df1`` and ``df2`` hold the same
    frequencies. Confirm with callers.

    Returns
    -------
    tuple of np.ndarray, each of shape ``(len(df1), len(df2))``
        dist_f3, dist_df, dist_fmax, dist_max2, dist_mean, dist_min, dist_max,
        dist_variable, var, diff_min, integers, ratios, size_diffs, diff_max,
        diff_mean, f3, df, f_max2, f_max
    """
    shape = [len(df1), len(df2)]
    f_max = np.zeros(shape)
    df = np.zeros(shape)
    f3 = np.zeros(shape)
    f_max2 = np.zeros(shape)
    diff_mean = np.zeros(shape)
    diff_min = np.zeros(shape)
    diff_max = np.zeros(shape)
    var = np.zeros(shape)
    size_diffs = np.zeros(shape)
    dist_variable = np.zeros(shape)
    dist_max = np.zeros(shape)
    dist_max2 = np.zeros(shape)
    dist_min = np.zeros(shape)
    dist_mean = np.zeros(shape)
    dist_fmax = np.zeros(shape)
    dist_f3 = np.zeros(shape)
    dist_df = np.zeros(shape)
    ratios = np.zeros(shape)
    integers = np.zeros(shape)
    limit = 0.1  # tolerance on the fractional part for "re-aligned"
    plot_type = ''  # '', 'True', 'dist' or 'period': which pairs get a diagnostic plot
    for i in range(len(df1)):
        for j in range(len(df2)):
            print('i ' + str(df1[i]) + ' j ' + str(df2[j]))
            DF1_per = 1 / df1[i]
            DF2_per = 1 / df2[j]
            if not (np.isinf(DF1_per) | np.isinf(DF2_per)):
                bigger = np.max([DF2_per, DF1_per])
                smaller = np.min([DF2_per, DF1_per])
                # frequencies of the same pair (i indexes df1, j indexes df2);
                # previously the indices were swapped (df1[j], df2[i])
                bigger_f = np.max([df1[i], df2[j]])
                smaller_f = np.min([df1[i], df2[j]])
                ratio_f = bigger_f / smaller_f
                ratio = bigger / smaller
                ratios[i, j] = ratio
                dim = 4000
                period = np.arange(0, dim, 1) * bigger  # this is the window we are ready to sacrifice
                time_bigger_f = (np.arange(0, dim, 1) * ratio)
                rests_final = time_bigger_f % 1
                period_interp = np.arange(0, period[-1], 1 / 1000)
                interpolated = interpolate(period, rests_final, period_interp, kind='linear')
                _, _ = ml.psd(interpolated - np.mean(interpolated), Fs=1 / np.diff(period_interp)[0], NFFT=5000,
                              noverlap=5000 / 2)
                test = False
                if test:
                    from utils_test import plot_psd
                    plot_psd()
                #####################################
                # find the right euclidean distance
                sampling_rate = 10000
                # summed sines of this pair: df1[i] and df2[j]
                stimulus, eod_fish_e, eod_fish_r, time = create_arrays([df1[i], df2[j]], 0, 1,
                                                                       sampling_rate=sampling_rate)
                p, f = ml.psd(rests_final - np.mean(rests_final), Fs=1 / np.diff(period)[0], NFFT=5000,
                              noverlap=5000 / 2)
                f_max[i, j] = f[np.argmax(p)]
                f3[i, j] = 1 / np.abs((1 / df1[i] + 1 / df2[j]))
                df[i, j] = np.abs(df1[i] - df2[j])
                # times at which the fractional part is close to 0 or 1
                one_zero = (rests_final < limit) | (rests_final - 1 > -limit)
                period_cut = period[one_zero]
                diff_mean[i, j] = 1 / np.mean(np.diff(period_cut))
                diff_min[i, j] = 1 / np.min(np.diff(period_cut))
                diff_max[i, j] = 1 / np.max(np.diff(period_cut))
                p2, f2 = ml.psd(one_zero - np.mean(one_zero), Fs=1 / np.diff(period)[0], NFFT=20000,
                                noverlap=20000 / 2)
                f_max2[i, j] = f2[np.argmax(p2)]
                #####################
                # period_cut parameters
                size_diff = np.max(np.diff(period_cut)) - np.min(np.diff(period_cut))
                size_diffs[i, j] = size_diff
                integers, f_max, diff_max, diff_min, diff_mean, f_max2 = exclude_ratios(f3, df, diff_max, integers, i,
                                                                                        j,
                                                                                        ratio_f, f_max, f_max2,
                                                                                        diff_mean, diff_min, bigger,
                                                                                        ratio, df1)
                dist_f3[i, j] = find_dist_pure(1 / f3[i, j], sampling_rate, stimulus)
                dist_df[i, j] = find_dist_pure(1 / df[i, j], sampling_rate, stimulus)
                dist_fmax[i, j] = find_dist_pure(1 / f_max[i, j], sampling_rate, stimulus)
                dist_mean[i, j] = find_dist_pure(1 / diff_mean[i, j], sampling_rate, stimulus)
                dist_min[i, j] = find_dist_pure(1 / diff_min[i, j], sampling_rate, stimulus)
                dist_max[i, j] = find_dist_pure(1 / diff_max[i, j], sampling_rate, stimulus)
                dist_max2[i, j] = find_dist_pure(1 / f_max2[i, j], sampling_rate, stimulus)
                var[i, j] = np.std(np.diff(period_cut))
                dist_variable[i, j] = find_dist_pure(period_cut, sampling_rate, stimulus)
                print(dist_min[i, j])
                if plot_type == 'True':
                    test = True
                elif plot_type != '':
                    if plot_type == 'dist':
                        if dist_min[i, j] > 0.2:
                            test = True
                        else:
                            test = False
                    elif plot_type == 'period':
                        if 1 / dist_variable[i, j] > 0.1:
                            test = True
                        else:
                            test = False
                if test:
                    from utils_test import plt_period
                    plt_period()
            else:
                print('inf')
                f_max[i, j] = float('nan')
                f_max2[i, j] = float('nan')
                df[i, j] = float('nan')
                f3[i, j] = float('nan')
                diff_mean[i, j] = float('nan')
                diff_min[i, j] = float('nan')
                diff_max[i, j] = float('nan')
                size_diffs[i, j] = float('nan')
                ratios[i, j] = float('nan')
                integers[i, j] = float('nan')
    return dist_f3, dist_df, dist_fmax, dist_max2, dist_mean, dist_min, dist_max, dist_variable, var, diff_min, integers, ratios, size_diffs, diff_max, diff_mean, f3, df, f_max2, f_max
def find_dist_pure(f3, sampling_rate, stimulus):
    """Mean absolute difference between the first two stimulus snippets at the given cut times.

    Parameters
    ----------
    f3 : scalar or array-like
        A scalar period in seconds, expanded to the cut times
        ``np.arange(0, 20, f3)``, or an explicit array of cut times.
    sampling_rate : float
        Samples per second of ``stimulus``.
    stimulus : np.ndarray
        Signal to cut (see ``do_splits``).

    Returns
    -------
    float
        Distance between the first two snippets (see ``calc_dist``).
    """
    # Any scalar counts as a period. The former ``type(f3) == np.float64``
    # rejected Python floats, ints and other numpy float types, which then
    # crashed when indexed in ``do_splits``.
    if np.ndim(f3) == 0:
        period_cut = np.arange(0, 20, f3)
    else:
        period_cut = f3
    stim0, stim1, stim2, stim3, splits = do_splits(period_cut, sampling_rate, stimulus)
    dist_f3, stim01, stim02 = calc_dist(stim0, stim1)
    return dist_f3
def define_delays_trials(mean_type, frame, sorted_on='local_reconst_big_norm'):
    """Compute per-column trial alignments (see ``find_delays_length``) for phase sorting.

    Only active when ``mean_type`` contains ``'PhaseSort'``; otherwise ``[]`` is
    returned. The rows of ``frame`` with ``frame.dev == sorted_on`` are used. Every
    column except the last (in reverse order) becomes a key of the result.
    With ``'Same'`` in ``mean_type`` all columns are aligned using the
    ``'control_02'`` traces; otherwise each column uses its own traces.

    Returns
    -------
    dict or list
        ``{column: [aligned index arrays, one per compared trial]}``, or ``[]``.
    """
    if 'PhaseSort' not in mean_type:
        return []
    test = False
    frame_dev_eod = frame[frame.dev == sorted_on]
    # every column except the last one, in reverse order
    names = frame_dev_eod.keys()[0:-1][::-1]
    delays_length = {}
    if test:
        from utils_test import test_delays
        test_delays(frame)
    if 'Same' in mean_type:
        # one reference column per name; the old fixed list of four entries
        # raised IndexError for frames with more than four columns
        names_orig = ['control_02'] * len(names)
    else:
        names_orig = names
    for n, name_orig in zip(names, names_orig):
        # 'base' and non-'base' columns are processed identically; the old
        # code recomputed 'base' columns a second time with the same result
        delays_length[n] = []
        i_nr = 0
        for i in range(len(frame_dev_eod) - 1):
            delays_length, i_nr = find_delays_length(frame_dev_eod, n, name_orig, i, mean_type, delays_length,
                                                     i_nr=i_nr)
        if test and ('base' not in n):
            from utils_test import test_delay2
            test_delay2()
    if test:
        from utils_test import test_delays3
        test_delays3()
    return delays_length
def group_the_certain_group_several(grouped, DF2_desired, DF1_desired, emb=False):
    """Boolean mask over the groups whose (mult1, mult2) pair is closest to the desired one.

    The third element of every group key of ``grouped`` is a pair
    ``(mult1, mult2)``, either as a tuple or as its string form ``"(a, b)"``.
    The distance ``|mult1 - DF1_desired| + |mult2 - DF2_desired|`` is rounded to
    2 decimals, and all groups sharing the minimal distance are selected.

    Returns
    -------
    np.ndarray of bool
        One entry per group key, in ``grouped.groups.keys()`` order.
    """
    keys = list(grouped.groups.keys())
    try:
        first = np.array([key[2][0] for key in keys])
        second = np.array([key[2][1] for key in keys])
    except:
        print('tuple problem')
        embed()
    # a first character '(' means the pair was stored as a string
    if str(first[0]) == '(':
        raw_pairs = np.array([key[2] for key in keys])
        try:
            parsed = np.array([ast.literal_eval(text) for text in raw_pairs])
        except:
            print('tuple thing')
            embed()
        first = np.array([pair[0] for pair in parsed])
        second = np.array([pair[1] for pair in parsed])
    try:
        distance = np.round(np.abs(first - DF1_desired) + np.abs((second - DF2_desired)), 2)
    except:
        print('mult tuple problem')
        embed()
    closest = distance[np.argmin(distance)]
    mask = distance == closest
    if emb:
        embed()
    return mask
def calc_mult(freq1, eodf, freq2):
    """Express two stimulus frequencies relative to the fish's EOD frequency.

    Returns
    -------
    mult1, mult2
        ``freq / eodf`` computed as ``1 + (freq - eodf) / eodf``.
    DeltaF2, DeltaF1
        ``freq2 - eodf`` and ``freq1 - eodf`` (note the order).
    """
    delta_f1 = freq1 - eodf
    delta_f2 = freq2 - eodf
    multiple_1 = delta_f1 / eodf + 1
    multiple_2 = delta_f2 / eodf + 1
    return multiple_1, multiple_2, delta_f2, delta_f1
def save_features(features, mt, mt_sorted):
    """Copy the per-position feature values of multi-tag ``mt`` into ``mt_sorted``.

    Column names drop the ``'<mt.name>_'`` prefix of each feature name. Features
    with twice as many values as positions (older recordings converted to nix
    later, typically very low contrasts) are only accepted for ``Frequency`` and
    ``DeltaF``; they are de-interleaved into ``<name>1`` (even entries) and
    ``<name>2`` (odd entries). Any other length mismatch drops into the
    debugger.

    Returns
    -------
    pandas.DataFrame
        ``mt_sorted``, modified in place.
    """
    prefix_length = len(mt.name) + 1
    for feature_name in features:
        column = feature_name[prefix_length::]
        values = np.concatenate(mt.features[feature_name].data[:])
        if len(values) == len(mt.positions[:]):
            mt_sorted[column] = values
        elif column in ('Frequency', 'DeltaF'):
            mt_sorted[column + '1'] = values[0:len(values):2]
            mt_sorted[column + '2'] = values[1:len(values):2]
        else:
            print('mt problems')
            embed()
    return mt_sorted
def load_metadata_infos_three(mt_sorted, mt, ver_here='new'):
    """Read phase, contrasts and frequencies of a three-fish multi-tag.

    Three metadata layouts are recognised:

    * ``'new'``: ``fish1.DeltaF`` column present (and ``ver_here == 'new'``).
      Metadata lives under ``fish1``/``fish1alone``/``fish2alone``.
    * ``'code_old'``: ``fish1.Frequency`` column present, metadata under ``fish2``.
    * ``'code_very_old'``: otherwise; frequencies are in ``Frequency`` /
      ``fish2.Frequency`` and are copied to ``fish1.Frequency`` / ``fish2.Frequency``.

    Returns
    -------
    tuple
        ``(phase, ver_here, contrast1, contrast2, eodf_orig, freq1_orig, freq2_orig)``.
        ``eodf_orig`` is the stored ``EODf`` column when present, otherwise an
        all-NaN series with the shape of ``fish2.Frequency``.
    """
    if ('fish1.DeltaF' in mt_sorted) and (ver_here == 'new'):
        phase = mt.metadata.sections[0]['fish1']['fish2']['Phase']
        #######################################
        # contrasts
        contrast1 = mt.metadata.sections[0]['fish1alone']['Contrast']
        if 'fish2alone' in mt.metadata.sections[0].sections:
            contrast2 = mt.metadata.sections[0]['fish2alone']['Contrast']
        else:
            # one of the contrasts is zero; such cells should not exist,
            # this is only here for completeness
            contrast2 = mt.metadata.sections[0]['fish1']['fish2']['Contrast']
        freq1_orig = mt_sorted['fish1.Frequency']
        freq2_orig = mt_sorted['fish2.Frequency']
        eodf_orig = mt_sorted['EODf']
    elif 'fish1.Frequency' in mt_sorted:
        # older cells with slightly different naming
        phase = mt.metadata.sections[0]['fish2']['Phase']
        ver_here = 'code_old'
        contrast1 = mt.metadata.sections[0]['Contrast']
        contrast2 = mt.metadata.sections[0]['fish2']['Contrast']
        freq1_orig = mt_sorted['fish1.Frequency']  # todo: possibly change this
        freq2_orig = mt_sorted['fish2.Frequency']
        # eodf_orig was never assigned in this branch (UnboundLocalError on return)
        if 'EODf' in mt_sorted:
            eodf_orig = mt_sorted['EODf']
        else:
            eodf_orig = mt_sorted['fish2.Frequency'] * float('nan')
    else:
        # very old cells, e.g. '2021-06-23-ab-invivo-1'
        phase = mt.metadata.sections[0].sections[0]['Phase']
        ver_here = 'code_very_old'
        contrast1 = mt.metadata.sections[0]['Contrast']
        contrast2 = mt.metadata.sections[0].sections[0]['Contrast']
        freq1_orig = np.array(mt_sorted['Frequency'])  # todo: possibly change this
        try:
            mt_sorted['fish1.Frequency'] = freq1_orig
        except Exception:
            print('freq problem1')
            embed()
        freq2_orig = np.array(mt_sorted['fish2.Frequency'])
        mt_sorted['fish2.Frequency'] = freq2_orig
        # the EOD frequency is always determined from the data (pure EODf)
        eodf_orig = mt_sorted['fish2.Frequency'] * float('nan')
    return phase, ver_here, contrast1, contrast2, eodf_orig, freq1_orig, freq2_orig
def feautures_in_mtframe(mt):
    """Build a per-position table of multi-tag ``mt``.

    Columns: ``mt`` (position index), ``mt_names`` (the tag name), plus one
    column per nix feature (see ``save_features``).
    """
    positions = np.arange(0, len(mt.positions[:]))
    table = pd.DataFrame(positions, columns=['mt'])
    table['mt_names'] = mt.name
    # feature names as stored in nix
    extracted = feature_extract(mt)
    feature_names = extracted[0]
    return save_features(feature_names, mt, table)
def find_eodf(times_final, eodf_orig, eodftype, b, mt, mt_idx=[]):
    """Return the EOD frequency per stimulus of multi-tag ``mt``.

    With ``eodftype == '_psdEOD_'`` (default in callers) the frequency is
    measured from a PSD of the global EOD for each stimulus
    (``find_eodf_three``). This happens once per multi-tag, because later
    grouping depends on the resulting frequencies. If any estimate is NaN, all
    entries are replaced by ``times_final['EODf'].iloc[0]``. For any other
    ``eodftype`` the stored ``eodf_orig`` is returned unchanged.
    """
    if eodftype != '_psdEOD_':
        return eodf_orig
    estimated = find_eodf_three(b, mt, eodf_orig, mt_idx=mt_idx)[0]
    if not np.isnan(estimated).any():
        return estimated
    return np.ones(len(estimated)) * times_final['EODf'].iloc[0]
def predefine_grouping_frame(b, redo=False, load=True, eodftype='_psdEOD_', freqtype='', printing=True, ver_here='new',
                             intial=False, cell_name=[]):
    """Return the per-stimulus grouping table of the three-fish multi-tags of a cell.

    Sources, in priority order:

    1. public version and ``redo`` False: the local CSV export
       ``<load_function>-calc_auc_three_...csv``;
    2. the cached pickle under ``load_folder_name('threefish')``, when it exists
       and ``load`` is True (copied locally in the develop version);
    3. otherwise the table is built from every multi-tag of ``b`` whose name
       contains ``'Three'``. Columns include EOD multiples, frequencies, rounded
       DF1/DF2, phase, EODf and contrasts.

    For the cached sources the rows are filtered to ``cell_name`` if given,
    else ``b.name``.

    ``intial`` is kept only for backward compatibility. The table is always
    started from the first 'Three' multi-tag found. Passing ``intial=True``
    previously referenced the table before assignment.

    Returns
    -------
    pandas.DataFrame
        Empty when a freshly built block has no 'Three' multi-tag.
    """
    name = 'calc_auc_three_core-spikes_core_AUCI_multsorted2__psdEOD_all.pkl'
    path = load_folder_name('threefish') + '/' + name
    version_comp, subfolder, mod_name_slash, mod_name, subfolder_path = find_code_vs_not()
    load_function = find_load_function()
    path_local = load_function + '-' + name.replace('.pkl', '.csv')
    path_local_pkl = load_function + '-' + name
    if (version_comp == 'public') & (redo == False):
        # the public version only ships the CSV export of the cached table
        mt_sorted = pd.read_csv(path_local, index_col=0)
        wanted_cell = cell_name if len(cell_name) > 0 else b.name
        times_final = mt_sorted[mt_sorted['cell'] == wanted_cell]
    elif os.path.exists(path) & (load == True):
        mt_sorted = pd.read_pickle(path)
        # honour cell_name here too, consistent with the public branch
        wanted_cell = cell_name if len(cell_name) > 0 else b.name
        times_final = mt_sorted[mt_sorted['cell'] == wanted_cell]
        if version_comp == 'develop':
            mt_sorted.to_pickle(path_local_pkl)
        print('reloaded mt')
    else:
        t1 = time.time()
        times_final = None
        for mt in b.multi_tags:
            t3 = time.time()
            if 'Three' in mt.name:
                t2 = time.time()
                mt_sorted = feautures_in_mtframe(mt)
                phase, ver_here, contrast1, contrast2, eodf_orig, freq1_orig, freq2_orig = load_metadata_infos_three(
                    mt_sorted, mt, ver_here=ver_here)
                if printing:
                    print('metadata: ' + str(time.time() - t2))
                t4 = time.time()
                # the EOD frequency is estimated once per multi-tag
                try:
                    eodf = find_eodf(mt_sorted, eodf_orig, eodftype, b, mt)
                except Exception:
                    print('problem stuff')
                    embed()
                if printing:
                    print('load_eod: ' + str(time.time() - t4))
                if freqtype == '_psd_':
                    # measure the stimulus frequencies from the recorded stimulus
                    freq1, freq1_orig, freq2, freq2_orig = find_freqs_three(b, mt, freq1_orig, freq2_orig, mt_sorted)
                else:  # DEFAULT
                    # the stored frequencies are accurate enough for the figures
                    freq1 = freq1_orig
                    freq2 = freq2_orig
                t5 = time.time()
                mult1, mult2, DeltaF2, DeltaF1 = calc_mult(freq1, eodf, freq2)
                if printing:
                    print('calc_mult: ' + str(time.time() - t5))
                # rounded because the table is grouped on these values
                mt_sorted['EODmult1'] = np.round(mult1, 2)
                mt_sorted['EODmult2'] = np.round(mult2, 2)
                mt_sorted['f1'] = freq1
                mt_sorted['f2'] = freq2
                mt_sorted['f1_orig'] = freq1_orig
                mt_sorted['f2_orig'] = freq2_orig
                # DF1 and DF2 are rounded to integers for grouping
                try:
                    DeltaF1 = np.array(list(map(int, np.round(freq1 - eodf))))
                except Exception:
                    print('eodf thing')
                    embed()
                DeltaF2 = np.array(list(map(int, np.round(freq2 - eodf))))
                mt_sorted['DF2'] = DeltaF2
                mt_sorted['DF1'] = DeltaF1
                mt_sorted['phase'] = phase
                mt_sorted['eodf'] = eodf
                mt_sorted['eodf_orig'] = eodf_orig
                mt_sorted['DF1, DF2'] = list(zip(DeltaF1, DeltaF2))
                mt_sorted['m1, m2'] = list(zip(mt_sorted['EODmult1'], mt_sorted['EODmult2']))
                mt_sorted['c1'] = contrast1
                mt_sorted['c2'] = contrast2
                # selects all positions; kept as the place to restrict stimuli
                restrict = np.arange(0, len(mt.positions[:]))
                mt_sorted = mt_sorted.iloc[restrict]
                if times_final is None:
                    times_final = mt_sorted
                else:
                    times_final = pd.concat([times_final, mt_sorted])
                test = False
                if test:
                    from utils_test import plt_freqs
                    plt_freqs()
                if printing:
                    print('predefine_grouping_frame2: ' + str(time.time() - t2))
            if printing:
                print('predefine_grouping_frame3: ' + str(time.time() - t3))
        if times_final is None:
            # no 'Three' multi-tag in this block
            times_final = pd.DataFrame()
        if printing:
            print('predefine_grouping_frame: ' + str(time.time() - t1))
    return times_final
def find_freqs_three(b, mt, freq1_orig, freq2_orig, mt_sorted):
    """Measure the two stimulus frequencies of every stimulus of a three-fish multi-tag.

    For each position, the global stimulus is cut into its segments. The
    frequencies of the ``control_01`` and ``control_02`` segments are estimated
    from their PSD (``calc_freq_from_psd``). When a segment is empty, the stored
    frequency is used instead. A deviation of more than 25 Hz from the stored
    value drops into the debugger.

    Returns
    -------
    tuple
        ``(freq1, freq1_orig, freq2, freq2_orig)``; ``freq1``/``freq2`` are
        np.ndarray with one entry per position.
    """
    # Previously assigned only after its first use inside the loop, which
    # raised UnboundLocalError.
    # NOTE(review): link_arrays_eod also returns ``sampling``, which may be the
    # intended rate; confirm.
    sampling_rate = 40000
    freq1 = []
    freq2 = []
    for mt_nr_small in range(len(mt.positions[:])):
        ##################
        # get the times where to cut the stimulus
        zeroth_cut, first_cut, second_cut, third_cut, fish_type, fish_cuts, whole_duration, delay, cont = load_four_durations(
            mt, mt_sorted, mt_nr_small, mt_nr_small)
        ##################
        # get the stimulus
        eod_global, sampling = link_arrays_eod(b, mt.positions[:][mt_nr_small] - delay,
                                               mt.extents[:][mt_nr_small] + delay,
                                               array_name='GlobalEFieldStimulus')
        eods_glb, _ = cut_eod_sequences(eod_global, fish_cuts, cut=0, rec=False,
                                        fish_number=fish_type, fillup=True, fish_number_base=fish_type)
        f1_orig = freq1_orig.iloc[mt_nr_small]
        if len(eods_glb['control_01']) > 0:
            f1, p1, f = calc_freq_from_psd(eods_glb['control_01'], sampling_rate)
        else:
            f1 = f1_orig
        freq1.append(f1)
        # compare this stimulus with its own stored frequency; the old check
        # subtracted the full-length ``freq1_orig`` from the growing list
        if np.max(np.abs(f1 - f1_orig)) > 25:
            print('f1 diff too big')
            embed()
        _, _ = nfft_improval(sampling_rate, f1_orig,
                             eods_glb['control_01'], freq1)
        # For small contrasts the PSD estimate is unreliable; the stored
        # frequencies should then be trusted.
        test = False
        if test:
            fig, ax = plt.subplots(2, 1)
            ax[0].plot(eods_glb['control_01'])
            ax[1].plot(f, p1)
            plt.show()
        f2_orig = freq2_orig.iloc[mt_nr_small]
        if len(eods_glb['control_02']) > 0:
            f2, p2, f = calc_freq_from_psd(eods_glb['control_02'], sampling_rate)
        else:
            f2 = f2_orig
        freq2.append(f2)
        if np.max(np.abs(f2 - f2_orig)) > 25:
            print('f2 diff too big')
            embed()
    freq1 = np.array(freq1)
    freq2 = np.array(freq2)
    return freq1, freq1_orig, freq2, freq2_orig
def find_eodf_three(b, mt, eodf_orig, max_eod=False, mt_idx=[], freq_step_nfft_eod=0.6103515625):
    """Estimate the fish's EOD frequency for stimuli of multi-tag ``mt`` from the global EOD.

    Parameters
    ----------
    b : block
        Recording block the EOD traces are read from.
    mt : multi-tag
        Stimulus multi-tag; all positions are processed unless ``mt_idx`` is given.
    eodf_orig : indexable
        Stored EOD frequencies, used when no global EOD is available.
    max_eod : bool
        Use the full EOD length as NFFT instead of the resolution-derived one.
    mt_idx : sequence of int
        Positions to process; empty means all.
    freq_step_nfft_eod : float
        Target frequency resolution in Hz. The analysed EOD is at least
        ``sampling / freq_step_nfft_eod`` samples long.

    Returns
    -------
    eodf : np.ndarray
        One estimate per processed position.
    eodf_orig
        Unchanged input.
    freq_steps : list
        Frequency resolution of each estimate.
    """
    estimates = []
    # either every position of the multi-tag or only the requested ones
    try:
        positions = range(len(mt.positions[:])) if not list(mt_idx) else mt_idx
    except:
        print('mt something wierd')
        embed()
    resolutions = []
    for position in positions:
        # the analysed piece must be long enough for the requested frequency resolution
        duration = mt.extents[:][position]
        sampling = get_sampling(b, 'EOD')
        nfft_eod = int(sampling / freq_step_nfft_eod)
        minimal_duration = nfft_eod / sampling
        if duration < minimal_duration:
            duration = minimal_duration
        global_eod, sampling = get_global_eod_for_eodf(b, duration, mt, position)
        if len(global_eod) == 0:
            eod_fr = eodf_orig[position]
            maximal_nfft = nfft_eod
            freq_step_maximal = get_freq_steps(maximal_nfft, sampling)
        else:
            maximal_nfft = len(global_eod) if max_eod else nfft_eod
            eod_fr = get_eodf_here(b, eodf_orig, global_eod, position, maximal_nfft, sampling)
            try:
                freq_step_maximal = get_freq_steps(maximal_nfft, sampling)
            except:
                print('unclear')
                embed()
        resolutions.append(freq_step_maximal)
        estimates.append(eod_fr)
    return np.array(estimates), eodf_orig, resolutions
def find_all_dir_cells():
    """List the cell folders found in the data directories.

    The data root is ``../../data/`` in the develop version and ``../data/``
    otherwise. Sub-directories whose name contains ``'noise'`` are skipped.
    Entries are visited in reverse ``os.listdir`` order, and duplicates are
    kept only once.

    Returns
    -------
    datasets : list of str
        Cell folder names.
    data_dir : list of str
        Parent folder (``'<dir>/'``) of each entry in ``datasets``.
    """
    datasets = []
    data_dir = []
    for folder in ['cells']:
        version_comp = find_code_vs_not()[0]
        root = '../../data/' if version_comp == 'develop' else '../data/'
        dir_path = root + folder
        if not os.path.exists(dir_path):
            continue
        for entry in os.listdir(dir_path + '/')[::-1]:
            if not os.path.isdir(dir_path + '/' + entry):
                continue
            if 'noise' in entry or entry in datasets:
                continue
            datasets.append(entry)
            data_dir.append(folder + '/')
    return datasets, data_dir
def get_all_nix_names(b, what='Three'):
    """Collect multi-tag and tag names of block ``b``.

    Returns
    -------
    tuple
        ``(all multi-tag names, multi-tag names containing what,
        tag names containing what)``.
    """
    all_mt_names = find_mt_all(b)
    mt_names = find_mt(b, what)
    t_names = [tag.name for tag in b.tags if what in tag.name]
    return all_mt_names, mt_names, t_names
def find_mt(b, what):
    """Return the names of the multi-tags of ``b`` that contain the substring ``what``."""
    return [multi_tag.name for multi_tag in b.multi_tags if what in multi_tag.name]
def find_right_dev(devname, devs):
    """Return the indices of the entries of ``devname`` equal to ``devs[0]``."""
    positions = np.arange(len(devname))
    matches = np.array(devname) == devs[0]
    return positions[matches]
def load_cells_three(end, data_dir=[], datasets=[]):
    """Return the three-fish cells of a named recording selection.

    Parameters
    ----------
    end : str
        ``'v2_2021-07-06'``, ``'v2_2021-07-08'``, ``'v2_2021-08-02'``,
        ``'v2_2021-08-03'``, ``'v2'`` (all of these), or ``'all'`` (use ``datasets``).
    data_dir : list
        Folder per dataset; returned reversed.
    datasets : list
        Cell names used for ``end == 'all'``.

    Returns
    -------
    tuple
        ``(data_dir reversed, cells reversed)``; the inputs are not modified.

    Raises
    ------
    ValueError
        For an unknown ``end`` (previously an UnboundLocalError).
    """
    cells_by_end = {
        # probably four cells from the wrong quadrant
        'v2_2021-07-06': ['2021-07-06-ag-invivo-1', '2021-07-06-ab-invivo-1', '2021-07-06-ac-invivo-1',
                          '2021-07-06-aa-invivo-1'],
        # four cells with only one quadrant
        'v2_2021-07-08': ['2021-07-08-ab-invivo-1', '2021-07-08-aa-invivo-1', '2021-07-08-ac-invivo-1',
                          '2021-07-08-ad-invivo-1'],
        # partly recorded with direct stimulation
        'v2_2021-08-02': ['2021-08-02-ab-invivo-1', '2021-08-02-ac-invivo-1', '2021-08-02-ae-invivo-1'],
        # also recorded with direct stimulation
        'v2_2021-08-03': ['2021-08-03-ac-invivo-1', '2021-08-03-af-invivo-1', '2021-08-03-ad-invivo-1'],
        'v2': ['2021-07-08-aa-invivo-1', '2021-07-08-ab-invivo-1', '2021-07-08-ac-invivo-1', '2021-07-08-ad-invivo-1',
               '2021-08-03-ac-invivo-1', '2021-08-03-af-invivo-1', '2021-08-03-ad-invivo-1',
               '2021-08-02-ab-invivo-1', '2021-08-02-ac-invivo-1', '2021-08-02-ae-invivo-1',
               '2021-07-06-ag-invivo-1', '2021-07-06-ab-invivo-1', '2021-07-06-ac-invivo-1', '2021-07-06-aa-invivo-1',
               ],
    }
    if end == 'all':
        cells = datasets
    elif end in cells_by_end:
        cells = cells_by_end[end]
    else:
        raise ValueError('unknown cell selection: ' + str(end))
    # todo: adjust here, some cells do not go through
    return data_dir[::-1], cells[::-1]
def spikes_for_desired_cells(spikes, data_names=None,
                             names=['intro', 'contrasts', 'bursts', 'ampullary_small', 'model', 'eigen_small',
                                    'eigemania_low_cv', 'eigenmania_high_cv', 'low_cv_punit', 'ampullary',
                                    'bursts_all']):
    """Return the cell names to process.

    If ``spikes`` is non-empty, the cells of every category in ``names`` (see
    ``find_names_cells``) are appended to ``data_names``; otherwise
    ``data_names`` is returned unchanged.

    ``data_names`` defaults to a fresh list per call. The former shared ``[]``
    default accumulated cells across calls.
    """
    if data_names is None:
        data_names = []
    if spikes != '':
        data_names = find_names_cells(names, data_names)
    return data_names
def find_names_cells(names, data_names=None):
    """Append the cells of every category in ``names`` (via ``p_units_to_show``) to ``data_names``.

    ``data_names`` is extended in place when given; otherwise a fresh list is
    created. The former shared ``[]`` default grew with every call.

    Returns
    -------
    list
        ``data_names``.
    """
    if data_names is None:
        data_names = []
    for name in names:
        try:
            data_names.extend(p_units_to_show(type_here=name))
        except Exception:
            print('embed thing')
            embed()
    return data_names
def plt_model_overview2(ax, cells=[], color_special='white', color_all='grey', scores=['perc95_perc5_fr']):
    """Scatter model overview scores against the stimulus CV of each model cell.

    The overview table is recomputed (``redo=True``) from the stored
    RAM-model result. For every score, all cells are drawn in ``color_all``,
    and the example cells (``p_units_to_show(type_here='model')`` plus two
    fixed cells) are overdrawn in ``color_special`` with a black edge.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    cells : list of str
        Cells to include in the overview.
    color_special, color_all : color
        Colours of highlighted and regular cells.
    scores : list of str
        Overview columns to plot on the y-axis.
    """
    nr = '2'
    position = 0
    save_names = [load_folder_name('calc_model') +
                  '/calc_RAM_model-2__nfft_whole_power_1_eRAM_RAMdadjusted_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_9_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV'
                  ]
    # Cells to highlight. Loop-invariant, so fetched once. Copied so the list
    # returned by the helper is not extended in place.
    cells_plot2 = list(p_units_to_show(type_here='model'))
    cells_plot2.extend(["2013-01-08-aa-invivo-1", "2012-12-13-an-invivo-1"])
    # todo: the burst-corrected version has fewer cells, check why
    for save_name in save_names:
        frame = load_model_overview(cells, pd.DataFrame(), nr, position, save_name, redo=True)
        # each frame's CVs are plotted against that frame's own scores
        cv = frame.cv_stim
        highlighted = frame['cell'].isin(cells_plot2)
        for score in scores:
            ax.scatter(cv, frame[score], s=3.5, color=color_all)
            ax.scatter(cv[highlighted], frame[score][highlighted], s=5,
                       edgecolor='black', alpha=0.5, color=color_special)
def plt_model_overview(ax, cells=[], scores=['perc95_perc5_fr']):
    """Scatter overview scores against CV, one panel per (model variant, score).

    Two stored RAM-model results are used: the plain one and the
    ``_burstIndividual_`` variant. The overview table is threaded through
    ``load_model_overview`` from one variant to the next. Panels are filled in
    order: ``ax[0]``, ``ax[1]``, and so on.
    """
    stem = '/calc_RAM_model-2__nfft_whole_power_1_eRAM_RAMdadjusted_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_30_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV'
    save_names = [load_folder_name('calc_model') + stem + suffix for suffix in ('', '_burstIndividual_')]
    nr = '2'
    position = 0
    frame = pd.DataFrame()
    panel = 0
    for save_name in save_names:
        frame = load_model_overview(cells, frame, nr, position, save_name)
        for score in scores:
            axis = ax[panel]
            axis.scatter(frame['cv'], frame[score])
            axis.set_ylabel(score)
            axis.set_xlabel('cv')
            panel += 1
def load_model_overview(cells, frame, nr, position, save_name, redo=False):
    """Compute (or load) a per-cell overview table of RAM-model susceptibility matrices.

    For every cell in ``cells``, the stored model matrix (``<save_name>.pkl``)
    is normalised with ``RAM_norm`` and summarised with
    ``get_overview_scores``, ``signal_to_noise_ratios`` and the cell's
    non-numeric parameters. Each cell becomes one row of ``frame``, starting at
    ``position``. The result is cached as ``<load_function>_model<trials>.csv``.
    It is recomputed when the cache is missing or ``redo`` is True (never in the
    public version), and written only in the develop version.

    Returns
    -------
    pandas.DataFrame
        The filled ``frame``, or the cached table.
    """
    path = save_name + '.pkl'
    model = load_model_susept(path, cells, save_name, save=False)
    version_comp, subfolder, mod_name_slash, mod_name, subfolder_path = find_code_vs_not()
    trials = path.split('TrialsStim_')[1].split('_a_fr_')[0]
    trials_stim = int(trials)
    save_name_final = find_load_function() + '_model' + trials + '.csv'
    recompute = (not os.path.exists(save_name_final)) | (redo == True)
    if recompute & (version_comp != 'public'):
        for cell in cells:
            if len(model) > 0:
                model_show = model[(model.cell == cell)]
                # same conversion as change_model_from_csv_to_plots
                stack_plot = change_model_from_csv_to_plots(model_show)
                model_cells = resave_small_files("models_big_fit_d_right.csv")
                model_params = model_cells[model_cells['cell'] == cell]
                if len(model_show) > 0:
                    # not used further; fails early if the cell has no fit parameters
                    noise_strength = model_params.noise_strength.iloc[0]
                    c_sig = 1
                    D_derived, var, cut_off = D_derive(model_show, save_name, c_sig, nr=nr)
                    stack_plot = RAM_norm(stack_plot, trials_stim, D_derived)
                    diagonal, frame = get_overview_scores('', frame, stack_plot, position)
                    diag, diagonals_prj_l = get_mat_diagonals(np.array(stack_plot))
                    frame = signal_to_noise_ratios(diagonals_prj_l, frame, position, '')
                    frame = fill_frame_with_non_float_vals(frame, position, model_show)
                    position += 1
        if version_comp == 'develop':
            frame.to_csv(save_name_final)
    else:
        frame = pd.read_csv(save_name_final)
    return frame
def get_overview_scores(add, frame, mat, position):
    """Write contrast-like summary scores of matrix ``mat`` into row ``position`` of ``frame``.

    Columns written (each suffixed with ``add``): ``std_mean`` (std / mean),
    ``max_min``, and ``percX_percY``. Each is ``(a - b) / (a + b)`` of the
    maximum/minimum or of the X-th/Y-th percentile.

    Returns
    -------
    diagonal : list or np.ndarray
        ``[]`` unless the disabled entropy analysis is switched on.
    frame : pandas.DataFrame
        ``frame``, modified in place.
    """
    mat = np.array(mat)
    highest = np.max(mat)
    lowest = np.min(mat)
    p1, p5, p10, p70, p80, p90, p95, p99 = np.percentile(mat, [1, 5, 10, 70, 80, 90, 95, 99])

    def contrast(upper, lower):
        return (upper - lower) / (upper + lower)

    # insertion order fixes the column order of newly created columns
    scores = {
        'std_mean': np.std(mat) / np.mean(mat),
        'max_min': contrast(highest, lowest),
        'perc80_perc5': contrast(p80, p5),
        'perc70_perc5': contrast(p70, p5),
        'perc90_perc5': contrast(p90, p5),
        'perc90_perc10': contrast(p90, p10),
        'perc95_perc5': contrast(p95, p5),
        'perc99_perc1': contrast(p99, p1),
    }
    for label, value in scores.items():
        frame.loc[position, label + add] = value
    test = False
    if test:
        from utils_test import test_percentile
        test_percentile()
    extra = False
    if not extra:
        return [], frame
    diagonal = mat.diagonal()
    diagonal_norm = diagonal / np.sum(diagonal)
    mat_norm = mat / np.sum(diagonal)
    frame.loc[position, 'entropy_mat' + add] = scipy.stats.entropy(np.concatenate(mat_norm))
    frame.loc[position, 'entropy_diagonal' + add] = scipy.stats.entropy(diagonal_norm)
    return diagonal, frame
def fill_frame_with_non_float_vals(frame, position, stack_here):
    """Copy the non-matrix columns of ``stack_here`` into row ``position`` of ``frame``.

    Columns whose label is a Python ``float`` are the matrix frequencies and
    are skipped. The remaining columns are taken from the first row of
    ``stack_here``. The bookkeeping entries ``osf``, ``spikes``, ``isf``,
    ``freqs`` and ``freqs_idx`` are dropped when present. ``isf`` and
    ``freqs`` used to be required and raised KeyError when missing.

    Returns
    -------
    pandas.DataFrame
        ``frame``, modified in place.
    """
    keys = stack_here.keys()
    non_float = np.array([type(key) != float for key in keys], dtype=bool)
    keys_else = keys[non_float]
    stack_vals = stack_here[keys_else].iloc[0]
    unwanted = [key for key in ('osf', 'spikes', 'isf', 'freqs', 'freqs_idx') if key in stack_vals.keys()]
    stack_vals = stack_vals.drop(unwanted)
    frame.loc[position, list(np.array(stack_vals.keys()))] = stack_vals
    return frame
def convert_csv_str_to_float(stack_final):
    """Select the square matrix part of a table read back from CSV and make it numeric.

    The matrix columns are the ones whose labels match the row index. After a
    CSV round trip they may be floats, their ``str`` form, or ``'<int>.0'``;
    these forms are tried in that order. The selection is cast to ``complex``
    and its column labels to ``float``.

    Returns
    -------
    new_keys : pandas.Index or list of str
        The column labels that matched.
    stack_plot : pandas.DataFrame
        The complex-valued matrix with float column labels.

    Raises
    ------
    KeyError
        If none of the label forms match.
    """
    stack_plot = stack_final
    new_keys = stack_plot.index
    try:
        stack_plot = stack_plot[new_keys]
    except KeyError:
        new_keys = list(map(str, new_keys))
        try:
            stack_plot = stack_plot[new_keys]
        except KeyError:
            new_keys = np.round(stack_plot.index, 1)
            new_keys = list(map(str, new_keys))
            new_keys = [k + '.0' for k in new_keys]
            # the selection succeeded here; the former embed() opened a
            # debugger on this successful path
            stack_plot = stack_plot[new_keys]
    stack_plot = stack_plot.astype(complex)
    stack_plot.columns = list(map(float, stack_plot.columns))
    return new_keys, stack_plot
def change_model_from_csv_to_plots(model_show):
    """Extract the square susceptibility matrix from a model table.

    The columns named like the (unique) row index are selected. String column
    names are tried first, because tables reloaded from CSV store them as strings;
    otherwise the index labels are used as they are. Only the first
    ``len(unique index)`` rows are kept, and the column labels are converted to
    ``float``.
    """
    unique_index = model_show.index.unique()
    # Column names depend on the folder the table was saved from (string or float).
    try:
        square = model_show[[str(label) for label in unique_index]]
    except:
        square = model_show[unique_index]
    square = square.iloc[np.arange(len(unique_index))]
    square.columns = [float(label) for label in square.columns]
    return square
def file_names_exlude_func(frame):
    """Return the rows of ``frame`` whose ``file_name`` is in a fixed list of stimulus files.

    Despite the function name, the listed files are *kept*, not excluded.
    """
    kept_files = [
        'gwn150Hz10s0.3',
        'gwn300Hz10s0.3',
        'gwn300Hz10s0.3short',
        'gwn300Hz50s0.3',
        'InputArr_400hz_30',
    ]
    in_list = frame.file_name.isin(kept_files)
    return frame[in_list]
def extract_mat(stack_here, keys=None, e=0, sens='', ends_nr=(2000,), abs=True, norm=''):
    """Crop a susceptibility table below an upper label limit and normalise it.

    Parameters
    ----------
    stack_here : pandas.DataFrame
        Table containing the matrix plus metadata columns. ``response_modulation``
        is read here; further columns are read by ``ram_norm_choice``.
    keys : array-like, optional
        Column labels to consider. ``None`` or an empty sequence means
        ``get_float_keys(stack_here)``. Plain lists are accepted as well.
    e : int
        Index into ``ends_nr`` selecting the upper limit.
    sens : str
        If non-empty, the normalised matrix is additionally divided by
        ``stack_here['response_modulation'].iloc[0]``.
    ends_nr : sequence of numbers
        Candidate upper limits. Only column and index labels strictly below
        ``ends_nr[e]`` are kept.
    abs, norm
        Forwarded to ``ram_norm_choice``.

    Returns
    -------
    (numpy.ndarray, pandas.DataFrame)
        The normalised matrix, and the cropped, un-normalised table.
    """
    # None / tuple defaults avoid shared mutable default arguments; [] still works.
    if keys is None or len(keys) < 1:
        keys = get_float_keys(stack_here)
    keys = np.asarray(keys)
    end = ends_nr[e]
    mat_to_div = stack_here[keys[keys < end]]
    mat_to_div = mat_to_div.loc[mat_to_div.index < end]
    mat_val = ram_norm_choice(mat_to_div, norm, stack_here, abs=abs)
    if sens != '':
        # Correct for the cell's sensitivity by dividing by its response modulation.
        # TODO (original author): probably irrelevant, as the subsequent measures
        # account for this anyway.
        spikes_var = stack_here['response_modulation'].iloc[0]
        mat_val = mat_val / spikes_var
    mat = np.array(mat_val)
    return mat, mat_to_div
def ram_norm_choice(mat_to_div, norm, stack_here, abs=True):
    """Normalise ``mat_to_div`` with ``RAM_norm_data`` according to the ``norm`` flags.

    * ``'constnorm'`` in ``norm``: normalise by ``stack_here['d_isf1']``. ``abs`` is
      not forwarded in this mode.
    * otherwise: normalise by ``stack_here['isf']`` with ``power=1`` if ``'old'`` is
      in ``norm``, else ``power=2``.

    In both modes, the first unique value of ``stack_here['snippets']`` is passed
    as the snippet count.
    """
    if 'constnorm' in norm:
        # Author's note (translated): the constant norm based on d_isf1 is
        # generally wrong. It estimated the variance from the zero bin of the
        # spectrum, averaging the absolute values instead of using the power.
        return RAM_norm_data(stack_here['d_isf1'].iloc[0],
                             mat_to_div,
                             stack_here['snippets'].unique()[0])
    power = 1 if 'old' in norm else 2
    return RAM_norm_data(stack_here['isf'].iloc[0], mat_to_div,
                         stack_here['snippets'].unique()[0], abs=abs, power=power,
                         stack_here=stack_here)
def get_mat_diagonals(mat):
    """Return the main diagonal of ``mat`` and its mean along every anti-diagonal.

    Parameters
    ----------
    mat : 2-D array-like

    Returns
    -------
    diag : numpy.ndarray
        ``np.diagonal(mat)``.
    projection : list of float
        Means of ``mat`` along lines of constant ``i + j``. For a square ``n x n``
        matrix, the entries cover ``i + j = 0 .. n-1`` followed by
        ``i + j = n-1 .. 2n-2``. The main anti-diagonal is therefore included
        twice, giving ``2n`` values.

    The original author flagged this projection as possibly wrong (TODO).
    """
    flipped = np.asarray(mat)[:, ::-1]
    n_rows = len(mat)
    # flipped.diagonal(-m) == flipped[m:, :].diagonal(); flipped.diagonal(m) == flipped.T[m:, :].diagonal()
    below = [np.mean(flipped.diagonal(offset=-m)) for m in range(n_rows)]
    above = [np.mean(flipped.diagonal(offset=m)) for m in range(n_rows)]
    return np.diagonal(mat), above[::-1] + below
def axis_projection(mat_val, axis='orig_axis'):
    """Build a half-step axis spanning the labels of ``mat_val``.

    Parameters
    ----------
    mat_val : pandas.DataFrame
        Matrix with numeric index and columns.
    axis : str
        If it contains ``'orig_axis'``, the row index is used as the base axis.
        Otherwise the elementwise sum ``index + columns`` is used.

    Returns
    -------
    numpy.ndarray
        Points from ``base[0] - h`` to ``base[-1]`` in steps of ``h``, where ``h`` is
        half the first spacing of the base axis. For an equally spaced base of ``n``
        labels this gives ``2n`` points.

    The point count is computed explicitly instead of via ``np.arange`` with a
    float step, which can yield an extra point because of rounding.
    """
    if 'orig_axis' in axis:
        base = np.asarray(mat_val.index, dtype=float)
    else:
        base = np.asarray(mat_val.index + mat_val.columns, dtype=float)
    half_step = np.diff(base)[0] / 2
    start = base[0] - half_step
    stop = base[-1] + half_step
    n_points = int(np.round((stop - start) / half_step))
    return start + half_step * np.arange(n_points)
def mod_lims_modulation(cell_type_here, frame_file, score_m, std_est=None):
    """Return bin edges for grouping cells by the score column ``score_m``.

    Parameters
    ----------
    cell_type_here : str
        Cell type label. Any spelling of "P-unit" (case-insensitive, e.g. the
        ``' P-unit'`` labels used elsewhere in this file) selects the wider range.
    frame_file : pandas.DataFrame
        Table that contains the column ``score_m``.
    score_m : str
        Name of the score column.
    std_est : optional
        If falsy, use fixed steps of 5: ``0..95`` for P-units, ``0..55`` otherwise.
        If truthy, use 30 edges from 0 to ``median + 3 * std``, where the std
        estimate comes from thunderfish's ``hist_threshold`` on the non-NaN scores.

    Returns
    -------
    numpy.ndarray
        The edges, with ``np.max`` of the scores appended as the last edge.
    """
    scores = frame_file[score_m]
    if not std_est:
        # The previous check 'P-Unit' in cell_type_here was case-sensitive and
        # never matched the ' P-unit' labels, sending P-units into the 0..55 range.
        upper = 100 if 'p-unit' in str(cell_type_here).lower() else 60
        mod_limits = np.arange(0, upper, 5)
    else:
        nbins = 70
        std_estimate, center = hist_threshold(scores[~np.isnan(scores)],
                                              thresh_fac=1,
                                              nbins=nbins)
        mod_limits = np.linspace(0, np.median(scores) + 3 * std_estimate, 30)
    return np.concatenate([mod_limits, [np.max(scores)]])
def signal_to_noise_ratios(diag_val, frame, position, add):
    """Store percentile-based signal-to-noise scores of ``diag_val`` in ``frame``.

    All columns are written into row ``position`` with the suffix ``'_' + add``:

    * ``perc<p>``: the p-th percentile of ``diag_val``, for p in (99, 99.9);
    * summary statistics:
      * ``med``: the median;
      * ``stdthunder``: the first value returned by thunderfish's
        ``hist_threshold(..., thresh_fac=1)``;
      * ``stdthunder2.576``: ``stdthunder * 2.576``;
      * ``stdsigma``: the 84th percentile minus the median;
      * ``stdorig``: ``np.std``;
      * ``center``: the second value from ``hist_threshold``;
    * ``perc<p><sub_mod>/<div_mod>``: ``(percentile - sub) / div``, for sub_mod in
      ('', '-med', '-center', '-m') and div_mod in ('', 'med', 'stdthunder').
      ``'-m'`` subtracts the divisor itself; ``div_mod == ''`` divides by 1.

    Returns the modified ``frame``.
    """
    sub_mods = ['', '-med', '-center', '-m']
    perc_nrs = [99, 99.9]
    div_mods = ['', 'med', 'stdthunder']
    # Gaussian noise estimate from a histogram fit (original: "Jan's gaussian values").
    nbins = 70
    std_estimate, center = hist_threshold(diag_val,
                                          thresh_fac=1,
                                          nbins=nbins)
    median = np.median(diag_val)
    std_sigma = np.percentile(diag_val, 84) - median
    std_orig = np.std(diag_val)
    # Loop invariants: computed once instead of in every innermost iteration.
    percentiles = {perc_nr: np.percentile(diag_val, perc_nr) for perc_nr in perc_nrs}
    summary_stats = [('med', median),
                     ('stdthunder', std_estimate),
                     ('stdthunder2.576', std_estimate * 2.576),
                     ('stdsigma', std_sigma),
                     ('stdorig', std_orig),
                     ('center', center)]
    divisors = {'': 1,
                'med': median,
                'stdthunder': std_estimate,
                'stdthunder2.576': std_estimate * 2.576,
                'stdsigma': std_sigma,
                'stdorig': std_orig}
    mean_val = np.mean(diag_val)  # fallback divisor for unknown div_mod values
    fixed_subtrahends = {'-med': median, '-center': center}
    for sub_mod in sub_mods:
        for perc_nr in perc_nrs:
            perc_val = percentiles[perc_nr]
            # Written in this order so that the frame's column order matches
            # the order in which the columns were first created before.
            frame.loc[position, 'perc' + str(perc_nr) + '_' + add] = perc_val
            for stat_name, stat_value in summary_stats:
                frame.loc[position, stat_name + '_' + add] = stat_value
            for div_mod in div_mods:
                div = divisors[div_mod] if div_mod in divisors else mean_val
                if sub_mod == '-m':
                    sub = div
                else:
                    sub = fixed_subtrahends.get(sub_mod, 0)
                frame.loc[
                    position, 'perc' + str(perc_nr) + sub_mod + '/' + div_mod + '_' + add] = (perc_val - sub) / div
    return frame
def restrict_base_durationts(duration):
    """Cap ``duration`` at 30.

    Values strictly above 30 become the integer 30; anything else, including 30
    itself, is returned unchanged.
    """
    max_duration = 30
    return max_duration if duration > max_duration else duration
def update_fav_snippet(nfft, fav_snippet=9):
    """Rescale the snippet index ``fav_snippet`` to the window length encoded in ``nfft``.

    ``nfft`` is a string such as ``'1sec'`` or ``'0.5sec'``. ``fav_snippet`` is
    divided by the number parsed from it and rounded with ``np.round``
    (half to even).
    """
    window_seconds = float(nfft.replace('sec', ''))
    scaled = fav_snippet / window_seconds
    return int(np.round(scaled))
def find_cell_cont(redo, cell, frame, saved):
    """Decide whether ``cell`` still has to be computed.

    Returns True when ``redo`` is truthy, when no saved table exists (``saved`` is
    not ``True``), or when ``cell`` does not occur in ``frame.cell``. Returns False
    otherwise.
    """
    if saved == True:
        # Debugging aid: open a shell if the saved table has no usable 'cell' column.
        try:
            np.array(frame.cell.unique())
        except:
            print('problem')
            embed()
        already_done = cell in np.array(frame.cell.unique())
    else:
        already_done = False
    return bool(redo) or not already_done
def load_frame(redo, name):
    """Load a previously saved result table unless a recomputation is requested.

    Parameters
    ----------
    redo : bool
        If not ``False``, the file is ignored.
    name : str
        Path of the table. Paths containing ``'csv'`` are read with
        ``pd.read_csv(index_col=0)``; all others with ``pd.read_pickle``.

    Returns
    -------
    (frame, position, saved)
        The loaded table, its length (the next row position) and True. If
        nothing was loaded, an empty DataFrame, 0 and False.
    """
    if redo == False and os.path.exists(name):
        if 'csv' in name:
            try:
                frame = pd.read_csv(name, index_col=0)
            except:
                print('parse thing')
                embed()
        else:
            frame = pd.read_pickle(name)
        return frame, len(frame), True
    return pd.DataFrame(), 0, False
def find_common_mod(save_names=[
    'calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s__burstIndividual_',
]):
    """Collect response-modulation values from saved RAM overview tables.

    For each name in ``save_names``, ``<calc_RAM folder>/<name>.csv`` is read and a
    ``modulation`` column is set to ``sqrt(var_spikes / snippets)``.

    Returns
    -------
    (mod_limits_a, mod_limits_p, mods_a, mods_p, frame_load)
        ``mod_limits_*`` are 8 linearly spaced edges (0..500 for ampullary cells,
        0..1200 for P-units), with the maximum modulation over *all* cells
        appended. ``mods_a`` and ``mods_p`` are the modulations of the ' Ampullary'
        and ' P-unit' rows. ``frame_load`` is the last table that was read.
    """
    cell_type_type = 'cell_type_reclassified'
    amps, mods, mods_p, mods_a = [], [], [], []
    for save_name in save_names:
        csv_path = load_folder_name('calc_RAM') + '/' + save_name + '.csv'
        frame_load = pd.read_csv(csv_path, index_col=0)
        amps.extend(frame_load.amp.unique())
        frame_load['modulation'] = np.sqrt(frame_load['var_spikes'] / frame_load['snippets'])
        cell_types = frame_load[cell_type_type]
        mods_p.extend(frame_load[cell_types == ' P-unit']['modulation'])
        mods_a.extend(frame_load[cell_types == ' Ampullary']['modulation'])
        mods.extend(frame_load['modulation'])
    overall_max = np.max(mods)
    mod_limits_p = np.concatenate([np.linspace(0, 1200, 8), [overall_max]])
    mod_limits_a = np.concatenate([np.linspace(0, 500, 8), [overall_max]])
    return mod_limits_a, mod_limits_p, mods_a, mods_p, frame_load
def find_norm_compars(isf, isf_mean, osf, deltat, stack_plot, mean=True):
    """Compute per-trial cross-spectral matrices and normalisation factors.

    ``f_range`` is ``np.arange(len(stack_plot))``. For every trial ``t``,
    ``fft_matrix(osf[t], f_range, isf[t], norm=...)`` is evaluated twice:
    unnormalised (``norm=''``) and with ``norm='_normPS_'``.

    Parameters
    ----------
    isf, osf : sequences indexed per trial
        Passed to ``fft_matrix``. Presumably input and output spectra — confirm.
    isf_mean : array-like
        Indexed with ``f_range``.
    deltat :
        Unused; kept for call compatibility.
    stack_plot :
        Its length defines ``f_range``; it is also passed to ``find_norm_susept``.
    mean : bool
        If True, the unnormalised matrices are averaged over trials; otherwise
        they are summed.

    Returns
    -------
    scales
        Trial mean of ``find_norm_susept(f_idx_sum, isf[t][f_range])``.
    cross_norm, f_mat2
        From the ``_normPS_`` call of the last trial.
    mats_all_here
        Mean (or sum) of the unnormalised matrices.
    mats_all_here_norm
        Mean of the ``_normPS_`` matrices.
    norm_char2
        ``1 / find_norm_susept(stack_plot, isf_mean[f_range])``.

    Raises
    ------
    IndexError
        If ``osf`` is empty.
    """
    # Before, two extra fft_matrix calls on trial 0 were made whose results were
    # always overwritten by the loop below; only the empty-input error remains.
    if len(osf) == 0:
        raise IndexError('find_norm_compars needs at least one trial in osf')
    f_range = np.arange(len(stack_plot))
    mats_all = []
    mats_all_norm = []
    scales = []
    for t in range(len(osf)):
        _, _, _, mat_all = fft_matrix(osf[t], f_range, isf[t], norm='')
        _, f_mat2, f_idx_sum, cross_norm = fft_matrix(osf[t], f_range, isf[t], norm='_normPS_')
        mats_all_norm.append(cross_norm)
        mats_all.append(mat_all)
        scales.append(find_norm_susept(f_idx_sum, isf[t][f_range]))
    if mean:
        mats_all_here = np.mean(mats_all, axis=0)
    else:
        mats_all_here = np.sum(mats_all, axis=0)
    mats_all_here_norm = np.mean(mats_all_norm, axis=0)
    scales = np.mean(scales, axis=0)
    norm_char22 = find_norm_susept(stack_plot, isf_mean[f_range])
    norm_char2 = 1 / norm_char22
    return scales, cross_norm, f_mat2, mats_all_here, mats_all_here_norm, norm_char2
def overview_model(individual_tag='', many=False, fs=8, hs=0.39,
                   nffts=['whole'],
                   powers=[1], cells=["2013-01-08-aa-invivo-1"], var_items=['contrasts'], show=False,
                   contrasts=[0], noises_added=[''], fft_i='forward', fft_o='forward', spikes_unit='Hz', mV_unit='mV',
                   D_extraction_method=['additiv_visual_d_4_scaled'], internal_noise=['eRAM'],
                   external_noise=['eRAM'], level_extraction=['_RAMdadjusted'], cut_off2=300,
                   repeats=[1000000],
                   receiver_contrast=[1], dendrids=[''], ref_types=[''], adapt_types=[''], c_noises=[0.1],
                   c_signal=[0.9],
                   cut_offs1=[300], burst_corrs=[''], restrict='restrict'):
    """Scatter summary scores of model susceptibility matrices against the baseline CV.

    For every combination of the list-valued keyword arguments:

    1. The saved model matrices are loaded (``save_ram_model`` / ``load_model_susept``).
    2. The cells are sorted by CV using a fixed reference model file.
    3. Each cell's matrix is normalised with ``RAM_norm`` and summarised by
       ``get_overview_scores``; its metadata is copied into a results frame.
    4. Eight scores are scattered against ``cv`` in a 2 x 4 grid.

    ``var_items`` selects which parameters are shown in the per-plot ``titles``
    rather than in the ``suptitles``. Scores that ``get_overview_scores`` did not
    compute (e.g. the entropy scores, which it only adds when its ``extra`` flag
    is set) are skipped instead of raising ``KeyError``.

    NOTE(review): the outer loop iterates over the product *including* ``cells``,
    but the inner loop rebinds every loop variable. The same figure is therefore
    produced once per outer combination — confirm this is intended.
    """
    stimulus_length = 1
    trials_nrs = [1]
    variant = 'sinz'
    mimick = 'no'
    cell_recording_save_name = ''
    trans = 1
    aa = 0
    for burst_corr, cell, var_type, stim_type_afe, trials_stim, stim_type_noise, power, nfft, dendrid, cut_off1, trial_nrs, c_sig, c_noise, ref_type, adapt_type, noise_added, extract, a_fr, a_fe, in it.product(
            burst_corrs, cells, D_extraction_method, external_noise,
            repeats, internal_noise, powers, nffts, dendrids, cut_offs1, trials_nrs, c_signal,
            c_noises,
            ref_types, adapt_types, noises_added, level_extraction, receiver_contrast, contrasts, ):
        aa += 1
        fig, ax = plt.subplots(2, 4, sharex=True, figsize=(12, 5.5))
        plt.subplots_adjust(wspace=0.8, bottom=0.067, top=0.86, hspace=hs, right=0.88, left=0.075)
        ax = np.concatenate(ax)
        a = 0
        iternames = [burst_corrs, D_extraction_method, external_noise,
                     repeats, internal_noise, powers, nffts, dendrids, cut_offs1, trials_nrs, c_signal,
                     c_noises,
                     ref_types, adapt_types, noises_added, level_extraction, receiver_contrast, contrasts, ]
        nr = '2'
        position = 0
        frame = pd.DataFrame()
        for combination in it.product(*iternames):
            (burst_corr, var_type, stim_type_afe, trials_stim, stim_type_noise, power, nfft, dendrid, cut_off1,
             trial_nrs, c_sig, c_noise, ref_type, adapt_type, noise_added, extract, a_fr, a_fe) = combination
            save_name = save_ram_model(stimulus_length, cut_off1, nfft, a_fe, stim_type_noise, mimick, variant,
                                       trials_stim, power, cell_recording_save_name, nr=nr, fft_i=fft_i,
                                       fft_o=fft_o, Hz=spikes_unit, mV=mV_unit, burst_corr=burst_corr,
                                       stim_type_afe=stim_type_afe, extract=extract, noise_added=noise_added,
                                       c_noise=c_noise, c_sig=c_sig, ref_type=ref_type, adapt_type=adapt_type,
                                       var_type=var_type, cut_off2=cut_off2, dendrid=dendrid, a_fr=a_fr,
                                       trials_nr=trial_nrs, trans=trans, zeros='ones')
            path = save_name + '.pkl'
            model = load_model_susept(path, cells, save_name)
            # Reference model file, used only to order the cells by their CV.
            path_cell_ref = load_folder_name(
                'calc_model') + '/calc_RAM_model-2__nfft_whole_power_1_eRAM_RAMdadjusted_additiv_visual_d_4_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_100000_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV.pkl'
            model_sorting = load_model_susept(path_cell_ref, cells, save_name)
            # Original author note (translated): "something goes wrong in this sorting".
            cells_sorted = model_sorting.cell.iloc[np.argsort(model_sorting.cv)]
            cv_sort = True
            if cv_sort:
                # unique cells, in order of first appearance in the CV-sorted list
                cells = np.array(np.unique(cells_sorted, return_index=True)[0])[
                    np.array(np.argsort(np.unique(cells_sorted, return_index=True)[1]))]
            for cell in cells:
                stim_type_noise_name = stim_type_noise if 'additiv' in var_type else ''
                dendrid_name = 'standard' if dendrid == '' else dendrid
                # Previously fell back to ``dendrid`` here (copy-paste bug).
                ref_type_name = 'standard' if ref_type == '' else ref_type
                adapt_type_name = 'standard' if adapt_type == '' else adapt_type
                if len(model) > 0:
                    titles = ''
                    suptitles = ''
                    stim_type_noise_name2, stim_type_afe_name = stim_type_names(a_fe, c_sig, stim_type_afe,
                                                                                stim_type_noise_name)
                    # each part goes into the plot title if varied, else into the suptitle
                    title_parts = [
                        ('cells', cell[2:13]),
                        ('internal_noise', ' intrinsic noise=' + stim_type_noise_name2),
                        ('external_noise', ' additive RAM=' + stim_type_afe_name),
                        ('repeats', ' $N_{repeat}=$' + str(trials_stim)),
                        ('contrasts', ' contrast=' + str(a_fe)),
                        ('level_extraction', ' Extract Level=' + str(extract)),
                        ('D_extraction_method', str(var_type)),
                        ('noises_added', ' high freq noise=' + str(noise_added)),
                    ]
                    for item, text in title_parts:
                        if item in var_items:
                            titles += text
                        else:
                            suptitles += text
                    model_show = model[(model.cell == cell)]
                    stack_plot = change_model_from_csv_to_plots(model_show)
                    model_fits = resave_small_files("models_big_fit_d_right.csv")
                    model_params = model_fits[model_fits['cell'] == cell]
                    if len(model_show) > 0:
                        noise_strength = model_params.noise_strength.iloc[0]
                        D = noise_strength
                        D_derived, var, cut_off = D_derive(model_show, save_name, c_sig, D=D, base='', nr=nr)
                        stack_plot = RAM_norm(stack_plot, trials_stim, D_derived)
                        if many == True:
                            titles = titles + ' Ef=' + str(int(model_params.EODf.iloc[0]))
                        color = title_color(cell)
                        print(color)
                        _, frame = get_overview_scores('', frame, stack_plot, position)
                        frame = fill_frame_with_non_float_vals(frame, position, model_show)
                        position += 1
                    else:
                        print('len problem2')
                        embed()
                else:
                    print('len problem')
                    embed()
                a += 1
        scores = ['std_mean', 'max_min', 'perc80_perc5', 'perc70_perc5', 'perc90_perc10', 'perc95_perc5', 'entropy_mat',
                  'entropy_diagonal']
        for s, score in enumerate(scores):
            if score in frame.columns:
                ax[s].scatter(frame['cv'], frame[score])
            ax[s].set_title(score)
        end_name = suptitles + ' a_fr=' + str(a_fr) + ' ' + ' dendride=' + str(
            dendrid_name) + ' refractory period=' + str(ref_type_name) + ' adapt=' + str(
            adapt_type_name) + ' ' + ' cutoff1=' + str(cut_off1) + ' stimulus length=' + str(
            stimulus_length) + ' ' + ' power=' + str(
            power) + ' ' + restrict
        end_name = cut_title(end_name, datapoints=120)
        name_title = end_name
        stats_label = ('\n $fr_{B}$=' + str(int(np.round(model_show.fr.iloc[0])))
                       + ' $fr_{S}$=' + str(int(np.round(model_show.fr_stim.iloc[0])))
                       + 'Hz\n $cv_{B}$=' + str(np.round(model_show.cv.iloc[0], 2))
                       + ' $cv_{S}$=' + str(np.round(model_show.cv_stim.iloc[0], 2))
                       + '\n $D_{sig}$=' + str(np.round(D_derived, 5))
                       + ' s=' + str(np.round(model_show.ser_sum_stim.iloc[0], 2)))
        plt.suptitle(name_title + titles + stats_label, fontsize=fs, color=color)
        save_visualization(individual_tag=individual_tag, pdf=True, show=show)
def overview_model_trials(individual_tag='', many=False, row='no', fs=8, hs=0.39,
                          nffts=['whole'],
                          powers=[1], cells=["2013-01-08-aa-invivo-1"], col_desired=8, var_items=['contrasts'],
                          show=False,
                          contrasts=[0], noises_added=[''], fft_i='forward', fft_o='forward', spikes_unit='Hz',
                          mV_unit='mV',
                          D_extraction_method=['additiv_visual_d_4_scaled'], internal_noise=['eRAM'],
                          external_noise=['eRAM'],
                          scores=['std_mean', 'max_min', 'perc80_perc5', 'perc70_perc5', 'perc90_perc10',
                                  'perc95_perc5', 'entropy_mat', 'entropy_diagonal']
                          , level_extraction=['_RAMdadjusted'], cut_off2=300, repeats=[1000000],
                          receiver_contrast=[1], dendrids=[''], ref_types=[''], adapt_types=[''], c_noises=[0.1],
                          c_signal=[0.9],
                          cut_offs1=[300], burst_corrs=[''], restrict='restrict'):
    """Like ``overview_model``, but with one row of score scatter plots per parameter combination.

    For each combination of the list-valued keyword arguments, the model matrices
    of all cells are loaded, normalised with ``RAM_norm`` and summarised by
    ``get_overview_scores``. Each entry of ``scores`` is then scattered against
    ``cv`` in row ``a`` of the subplot grid, and the row is labelled with the
    last cell's firing statistics.

    If ``row == 'no'``, the grid layout is derived with ``find_row_col``;
    otherwise ``row`` x ``col_desired`` axes are used. The code indexes ``ax``
    two-dimensionally, so presumably ``row > 1`` — confirm. Scores that were not
    computed (e.g. the entropy scores, which ``get_overview_scores`` only adds
    when its ``extra`` flag is set) are skipped instead of raising ``KeyError``.
    """
    stimulus_length = 1
    trials_nrs = [1]
    variant = 'sinz'
    mimick = 'no'
    cell_recording_save_name = ''
    trans = 1
    aa = 0
    for burst_corr, cell, var_type, stim_type_afe, trials_stim, stim_type_noise, power, nfft, dendrid, cut_off1, trial_nrs, c_sig, c_noise, ref_type, adapt_type, noise_added, extract, a_fr, a_fe, in it.product(
            burst_corrs, cells, D_extraction_method, external_noise,
            repeats, internal_noise, powers, nffts, dendrids, cut_offs1, trials_nrs, c_signal,
            c_noises,
            ref_types, adapt_types, noises_added, level_extraction, receiver_contrast, contrasts, ):
        aa += 1
        if row == 'no':
            col, row = find_row_col(np.arange(aa * 8 / len(cells)), col=col_desired)
        else:
            col = col_desired
        fig, ax = plt.subplots(row, col, sharex=True, figsize=(12, 5.5))
        plt.subplots_adjust(wspace=0.8, bottom=0.067, top=0.86, hspace=hs, right=0.88, left=0.075)
        a = 0
        iternames = [burst_corrs, D_extraction_method, external_noise,
                     repeats, internal_noise, powers, nffts, dendrids, cut_offs1, trials_nrs, c_signal,
                     c_noises,
                     ref_types, adapt_types, noises_added, level_extraction, receiver_contrast, contrasts, ]
        nr = '2'
        position = 0
        frame = pd.DataFrame()
        for combination in it.product(*iternames):
            (burst_corr, var_type, stim_type_afe, trials_stim, stim_type_noise, power, nfft, dendrid, cut_off1,
             trial_nrs, c_sig, c_noise, ref_type, adapt_type, noise_added, extract, a_fr, a_fe) = combination
            save_name = save_ram_model(stimulus_length, cut_off1, nfft, a_fe, stim_type_noise, mimick, variant,
                                       trials_stim, power, cell_recording_save_name, nr=nr, fft_i=fft_i,
                                       fft_o=fft_o, Hz=spikes_unit, mV=mV_unit, burst_corr=burst_corr,
                                       stim_type_afe=stim_type_afe, extract=extract, noise_added=noise_added,
                                       c_noise=c_noise, c_sig=c_sig, ref_type=ref_type, adapt_type=adapt_type,
                                       var_type=var_type, cut_off2=cut_off2, dendrid=dendrid, a_fr=a_fr,
                                       trials_nr=trial_nrs, trans=trans, zeros='ones')
            path = save_name + '.pkl'
            model = load_model_susept(path, cells, save_name)
            # Reference model file, used only to order the cells by their CV.
            path_cell_ref = load_folder_name(
                'calc_model') + '/calc_RAM_model-2__nfft_whole_power_1_eRAM_RAMdadjusted_additiv_visual_d_4_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_100000_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV.pkl'
            model_sorting = load_model_susept(path_cell_ref, cells, save_name)
            # Original author note (translated): "something goes wrong in this sorting".
            cells_sorted = model_sorting.cell.iloc[np.argsort(model_sorting.cv)]
            cv_sort = True
            if cv_sort:
                # unique cells, in order of first appearance in the CV-sorted list
                cells = np.array(np.unique(cells_sorted, return_index=True)[0])[
                    np.array(np.argsort(np.unique(cells_sorted, return_index=True)[1]))]
            for cell in cells:
                stim_type_noise_name = stim_type_noise if 'additiv' in var_type else ''
                dendrid_name = 'standard' if dendrid == '' else dendrid
                # Previously fell back to ``dendrid`` here (copy-paste bug).
                ref_type_name = 'standard' if ref_type == '' else ref_type
                adapt_type_name = 'standard' if adapt_type == '' else adapt_type
                if len(model) > 0:
                    titles = ''
                    suptitles = ''
                    stim_type_noise_name2, stim_type_afe_name = stim_type_names(a_fe, c_sig, stim_type_afe,
                                                                                stim_type_noise_name)
                    # each part goes into the row label if varied, else into the suptitle
                    title_parts = [
                        ('cells', cell[2:13]),
                        ('internal_noise', ' intrinsic noise=' + stim_type_noise_name2),
                        ('external_noise', ' additive RAM=' + stim_type_afe_name),
                        ('repeats', ' $N_{repeat}=$' + str(trials_stim)),
                        ('contrasts', ' contrast=' + str(a_fe)),
                        ('level_extraction', ' Extract Level=' + str(extract)),
                        ('D_extraction_method', str(var_type)),
                        ('noises_added', ' high freq noise=' + str(noise_added)),
                    ]
                    for item, text in title_parts:
                        if item in var_items:
                            titles += text
                        else:
                            suptitles += text
                    model_show = model[(model.cell == cell)]
                    stack_plot = change_model_from_csv_to_plots(model_show)
                    model_fits = resave_small_files("models_big_fit_d_right.csv")
                    model_params = model_fits[model_fits['cell'] == cell]
                    if len(model_show) > 0:
                        noise_strength = model_params.noise_strength.iloc[0]
                        D = noise_strength
                        D_derived, var, cut_off = D_derive(model_show, save_name, c_sig, D=D, base='', nr=nr)
                        stack_plot = RAM_norm(stack_plot, trials_stim, D_derived)
                        if many == True:
                            titles = titles + ' Ef=' + str(int(model_params.EODf.iloc[0]))
                        color = title_color(cell)
                        print(color)
                        _, frame = get_overview_scores('', frame, stack_plot, position)
                        frame = fill_frame_with_non_float_vals(frame, position, model_show)
                        position += 1
                    else:
                        print('len problem2')
                        embed()
                else:
                    print('len problem')
                    embed()
            for s, score in enumerate(scores):
                if score in frame.columns:
                    ax[a, s].scatter(frame['cv'], frame[score])
                ax[0, s].set_title(score)
                ax[-1, s].set_xlabel('cv')
            row_label = (titles + ' $fr_{B}$=' + str(int(np.round(model_show.fr.iloc[0])))
                         + ' $fr_{S}$=' + str(int(np.round(model_show.fr_stim.iloc[0])))
                         + 'Hz $cv_{B}$=' + str(np.round(model_show.cv.iloc[0], 2))
                         + ' $cv_{S}$=' + str(np.round(model_show.cv_stim.iloc[0], 2))
                         + ' $D_{sig}$=' + str(np.round(D_derived, 5))
                         + ' s=' + str(np.round(model_show.ser_sum_stim.iloc[0], 2)))
            ax[a, 0].text(0, 1.2, row_label, transform=ax[a, 0].transAxes, )
            a += 1
        end_name = suptitles + ' a_fr=' + str(a_fr) + ' ' + ' dendride=' + str(
            dendrid_name) + ' refractory period=' + str(ref_type_name) + ' adapt=' + str(
            adapt_type_name) + ' ' + ' cutoff1=' + str(cut_off1) + ' stimulus length=' + str(
            stimulus_length) + ' ' + ' power=' + str(
            power) + ' ' + restrict
        end_name = cut_title(end_name, datapoints=120)
        name_title = end_name
        plt.suptitle(name_title, fontsize=fs, color=color)
        save_visualization(individual_tag=individual_tag, pdf=True, show=show)
def model_cells(individual_tag='', nr_clim=10, many=False, width=0.02, row='no', HZ50=True, fs=8, hs=0.39,
                nffts=['whole'],
                powers=[1], cells=["2013-01-08-aa-invivo-1"], col_desired=2, var_items=['contrasts'], show=False,
                contrasts=[0], noises_added=[''], fft_i='forward', fft_o='forward', spikes_unit='Hz', mV_unit='mV',
                D_extraction_method=['additiv_visual_d_4_scaled'], internal_noise=['eRAM'],
                external_noise=['eRAM'], level_extraction=['_RAMdadjusted'], cut_off2=300, repeats=[1000000],
                receiver_contrast=[1], dendrids=[''], ref_types=[''], adapt_types=[''], c_noises=[0.1], c_signal=[0.9],
                cut_offs1=[300], burst_corrs=[''], clims='all', restrict='restrict', label=r'$\frac{1}{mV^2S}$'):
    """Plot the normalised model susceptibility matrix of every cell in its own panel.

    For every combination of the list-valued keyword arguments, the saved model
    matrices are loaded and the cells are sorted by CV using a fixed reference
    model file. Each cell's matrix is then:

    * normalised with ``RAM_norm``;
    * drawn with ``plt_RAM_perc`` on axes limited to 0..300, with the
      ``plt_triangle`` overlay (and the 50 Hz marks if ``HZ50``);
    * given a colorbar labelled ``label``.

    Colour limits are shared across panels via ``set_clim_shared(clims, ...)``,
    and the figure is saved with ``save_visualization``.

    If ``row == 'no'``, the grid layout is derived with ``find_row_col``;
    otherwise ``row`` x ``col_desired`` axes are used.
    """
    stimulus_length = 1
    trials_nrs = [1]
    variant = 'sinz'
    mimick = 'no'
    cell_recording_save_name = ''
    trans = 1
    aa = 0
    for burst_corr, cell, var_type, stim_type_afe, trials_stim, stim_type_noise, power, nfft, dendrid, cut_off1, trial_nrs, c_sig, c_noise, ref_type, adapt_type, noise_added, extract, a_fr, a_fe, in it.product(
            burst_corrs, cells, D_extraction_method, external_noise,
            repeats, internal_noise, powers, nffts, dendrids, cut_offs1, trials_nrs, c_signal,
            c_noises,
            ref_types, adapt_types, noises_added, level_extraction, receiver_contrast, contrasts, ):
        aa += 1
        if row == 'no':
            col, row = find_row_col(np.arange(aa), col=col_desired)
        else:
            col = col_desired
        if row == 2:
            default_settings(column=2, length=7.5)
        elif row == 1:
            default_settings(column=2, length=4)
        fig, ax = plt.subplots(row, col, sharex=True, sharey=True)
        if row == 2:
            plt.subplots_adjust(bottom=0.067, wspace=0.45, top=0.81, hspace=hs, right=0.88, left=0.075)
        elif row == 1:
            plt.subplots_adjust(bottom=0.1, wspace=0.45, top=0.81, hspace=hs, right=0.88, left=0.075)
        else:
            plt.subplots_adjust(wspace=0.8, bottom=0.067, top=0.86, hspace=hs, right=0.88, left=0.075)
        if row != 1:
            ax = np.concatenate(ax)
        a = 0
        maxs = []
        mins = []
        ims = []
        perc05 = []
        perc95 = []
        iternames = [burst_corrs, D_extraction_method, external_noise,
                     repeats, internal_noise, powers, nffts, dendrids, cut_offs1, trials_nrs, c_signal,
                     c_noises,
                     ref_types, adapt_types, noises_added, level_extraction, receiver_contrast, contrasts, ]
        nr = '2'
        for combination in it.product(*iternames):
            (burst_corr, var_type, stim_type_afe, trials_stim, stim_type_noise, power, nfft, dendrid, cut_off1,
             trial_nrs, c_sig, c_noise, ref_type, adapt_type, noise_added, extract, a_fr, a_fe) = combination
            save_name = save_ram_model(stimulus_length, cut_off1, nfft, a_fe, stim_type_noise, mimick, variant,
                                       trials_stim, power, cell_recording_save_name, nr=nr, fft_i=fft_i,
                                       fft_o=fft_o, Hz=spikes_unit, mV=mV_unit, burst_corr=burst_corr,
                                       stim_type_afe=stim_type_afe, extract=extract, noise_added=noise_added,
                                       c_noise=c_noise, c_sig=c_sig, ref_type=ref_type, adapt_type=adapt_type,
                                       var_type=var_type, cut_off2=cut_off2, dendrid=dendrid, a_fr=a_fr,
                                       trials_nr=trial_nrs, trans=trans, zeros='ones')
            path = save_name + '.pkl'
            model = load_model_susept(path, cells, save_name)
            # Reference model file, used only to order the cells by their CV.
            path_cell_ref = load_folder_name(
                'calc_model') + '/calc_RAM_model-2__nfft_whole_power_1_eRAM_RAMdadjusted_additiv_visual_d_4_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_100000_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV.pkl'
            model_sorting = load_model_susept(path_cell_ref, cells, save_name)
            cells_sorted = model_sorting.cell.iloc[np.argsort(model_sorting.cv)]
            cv_sort = True
            if cv_sort:
                # unique cells, in order of first appearance in the CV-sorted list
                cells = np.array(np.unique(cells_sorted, return_index=True)[0])[
                    np.array(np.argsort(np.unique(cells_sorted, return_index=True)[1]))]
            for cell in cells:
                stim_type_noise_name = stim_type_noise if 'additiv' in var_type else ''
                dendrid_name = 'standard' if dendrid == '' else dendrid
                # Previously fell back to ``dendrid`` here (copy-paste bug).
                ref_type_name = 'standard' if ref_type == '' else ref_type
                adapt_type_name = 'standard' if adapt_type == '' else adapt_type
                if len(model) > 0:
                    titles = ''
                    suptitles = ''
                    stim_type_noise_name2, stim_type_afe_name = stim_type_names(a_fe, c_sig, stim_type_afe,
                                                                                stim_type_noise_name)
                    # each part goes into the panel title if varied, else into the suptitle
                    title_parts = [
                        ('cells', cell[2:13]),
                        ('internal_noise', ' intrinsic noise=' + stim_type_noise_name2),
                        ('external_noise', ' additive RAM=' + stim_type_afe_name),
                        ('repeats', ' $N_{repeat}=$' + str(trials_stim)),
                        ('contrasts', ' contrast=' + str(a_fe)),
                        ('level_extraction', ' Extract Level=' + str(extract)),
                        ('D_extraction_method', str(var_type)),
                        ('noises_added', ' high freq noise=' + str(noise_added)),
                    ]
                    for item, text in title_parts:
                        if item in var_items:
                            titles += text
                        else:
                            suptitles += text
                    model_show = model[(model.cell == cell)]
                    stack_plot = change_model_from_csv_to_plots(model_show)
                    ax[a].set_xlim(0, 300)
                    ax[a].set_ylim(0, 300)
                    ax[a].set_aspect('equal')
                    # local name differs from this function's name to avoid shadowing it
                    model_fits = resave_small_files("models_big_fit_d_right.csv")
                    model_params = model_fits[model_fits['cell'] == cell]
                    if len(model_show) > 0:
                        noise_strength = model_params.noise_strength.iloc[0]
                        D = noise_strength
                        D_derived, var, cut_off = D_derive(model_show, save_name, c_sig, D=D, base='', nr=nr)
                        stack_plot = RAM_norm(stack_plot, trials_stim, D_derived)
                        if many == True:
                            titles = titles + ' Ef=' + str(int(model_params.EODf.iloc[0]))
                        color = title_color(cell)
                        print(color)
                        panel_title = (titles + '\n $fr_{B}$=' + str(int(np.round(model_show.fr.iloc[0])))
                                       + ' $fr_{S}$=' + str(int(np.round(model_show.fr_stim.iloc[0])))
                                       + 'Hz\n $cv_{B}$=' + str(np.round(model_show.cv.iloc[0], 2))
                                       + ' $cv_{S}$=' + str(np.round(model_show.cv_stim.iloc[0], 2))
                                       + '\n $D_{sig}$=' + str(np.round(D_derived, 5))
                                       + ' s=' + str(np.round(model_show.ser_sum_stim.iloc[0], 2)))
                        ax[a].set_title(panel_title, fontsize=fs, color=color)
                        perc = ''
                        im = plt_RAM_perc(ax[a], perc, stack_plot)
                        ims.append(im)
                        maxs.append(np.max(np.array(stack_plot)))
                        mins.append(np.min(np.array(stack_plot)))
                        perc05.append(np.percentile(stack_plot, 5))
                        perc95.append(np.percentile(stack_plot, 95))
                        plt_triangle(ax[a], model_show.fr.iloc[0], np.round(model_show.fr_stim.iloc[0]), 300,
                                     model_show.eod_fr.iloc[0])
                        if HZ50:
                            plt_50_Hz_noise(ax[a], 300)
                        ax[a].set_aspect('equal')
                        cbar = colorbar_outside(ax[a], im, fig, add=0, width=width)
                        if many == False:
                            cbar[0].set_label(label, labelpad=100)
                        elif a in np.arange(col - 1, 100, col):
                            # only the last panel of each row gets a colorbar label
                            cbar[0].set_label(label, labelpad=100)
                        if a >= row * col - col:
                            # bottom row
                            ax[a].set_xlabel(F1_xlabel(), labelpad=20)
                            ax[0].set_ylabel(F2_xlabel())
                        if a in np.arange(0, 10, 1) * col:
                            # first panel of each row
                            ax[a].set_ylabel(F2_xlabel())
                    else:
                        print('len problem2')
                        embed()
                else:
                    print('len problem')
                    embed()
                a += 1
        end_name = suptitles + ' a_fr=' + str(a_fr) + ' ' + ' dendride=' + str(
            dendrid_name) + ' refractory period=' + str(ref_type_name) + ' adapt=' + str(
            adapt_type_name) + ' ' + ' cutoff1=' + str(cut_off1) + ' stimulus length=' + str(
            stimulus_length) + ' ' + ' power=' + str(
            power) + ' ' + restrict
        end_name = cut_title(end_name, datapoints=120)
        name_title = end_name
        plt.suptitle(name_title)
        set_clim_shared(clims, ims, maxs, mins, nr_clim, perc05, perc95)
        save_visualization(individual_tag=individual_tag, pdf=True, show=show)
def plt_punit(amp_desired=[0.5, 1, 5], xlim=[], cells_plot2=[], show=False, annotate=False):
    """Plot the P-unit/Ampullary overview with per-cell susceptibility panels.

    The top row holds the three baseline scatters created by ``plt_scatter_three2``
    plus a baseline-CV vs. stimulus-CV scatter. The lower part has one row per cell
    in ``cells_plot2`` and ``len(amp_desired) + 2`` columns.

    Parameters
    ----------
    amp_desired : list
        Stimulus amplitudes; one susceptibility column per entry.
    xlim : list
        Forwarded to ``plt_cellbody_punitsingle``.
    cells_plot2 : list
        Cells to plot (one row each); also used to load the baseline frame.
    show : bool
        Forwarded to ``show_func``.
    annotate : bool
        Forwarded to ``plt_scatter_three2``.
    """
    plot_style()
    default_settings(column=2, width=12, length=8)  # ts=10, fs=10, ls=10,
    save_names = ['noise_data8_nfft1sec_original__LocalEOD_CutatBeginning_0.05_s_NeurDelay_0.005_s_burst_corr']
    amps_desired = amp_desired
    cell_type_type = 'cell_type_reclassified'
    frame = load_cv_base_frame(cells_plot2, cell_type_type=cell_type_type)
    cells_plot = cells_plot2
    grid = gridspec.GridSpec(2, 1, wspace=0.1, height_ratios=[1, 4.5], hspace=0.25, top=0.96, left=0.095, bottom=0.07,
                             right=0.92)
    colors = {'unkown': 'grey',
              ' P-unit': 'blue',
              ' Ampullary': 'green',
              'nan': 'grey',
              ' T-unit': 'purple',
              ' E-cell': 'red',
              ' Pyramidal': 'darkred',
              ' I-cell': 'pink',
              ' E-cell superficial': 'orange',
              ' Ovoid': 'cyan'}
    grid2 = gridspec.GridSpecFromSubplotSpec(1, 4, grid[0], wspace=0.5,
                                             hspace=0.2)
    cell_types = [' P-unit', ' Ampullary']
    ax0, ax1, ax2 = plt_scatter_three2(grid2, frame, cell_type_type, annotate, colors)
    ax3 = plt.subplot(grid2[3])
    axs = [ax3]
    burst_name = ['', ' burst corr ']
    save_names1 = [save_names[0]]
    # todo: to use the burst CV of the baseline here, it also has to be loaded from the baseline file
    for s, save_name in enumerate(save_names1):
        load_name = load_folder_name('calc_RAM') + '/' + save_name + '.csv'
        for c, cell_type_it in enumerate(cell_types):
            frame_g = base_to_stim(load_name, frame, cell_type_type, cell_type_it)
            axs[s].scatter(np.array(frame_g['cv']), np.array(frame_g['cv_stim']), alpha=0.5, s=7,
                           color=colors[str(cell_type_it)])
            axs[s].set_xlim(0, 1.5)
            axs[s].set_ylim(0, 1.5)
            axs[s].set_ylabel('CV stim ' + burst_name[s])
            axs[s].set_xlabel('CV ' + burst_name[s])
    # Width ratios: a narrow first column, an empty spacer column and one equal
    # column per amplitude. The previous if/elif chain only handled 2, 3 or 4
    # amplitudes and raised a NameError on ``wr`` for any other count.
    wr = [0.5, 0] + [1] * len(amps_desired)
    grid1 = gridspec.GridSpecFromSubplotSpec(len(cells_plot), len(amps_desired) + 2, grid[1], hspace=0.17,
                                             wspace=0.35, width_ratios=wr)  # ,
    plt_cellbody_punitsingle(grid1, ax0, ax1, ax2, frame, colors, amps_desired, save_names, cells_plot, cell_type_type,
                             ax3=ax3, xlim=xlim, plus=2, burst_corr='_burst_corr_individual')
    save_visualization(pdf=True)
    show_func(show=show)
def plt_sqaure_isf2(grid1, ax0, ax1, ax2, b, cell, frame, colors, amps_desired, save_names, cells_plot,
                    cell_type_type,
                    labeloff=True, predefined_amps2=False, norm=False):
    """Plot the susceptibility square of one cell and highlight it in the overview scatters.

    A fallback suptitle with the baseline EODf, CV and firing rate is set first.
    ``plt_stack_single`` (called only when the cell's ``.pkl`` file exists)
    replaces it with a more detailed title.

    Returns
    -------
    im, axs
        Whatever ``plt_stack_single`` returned, or ``[]``, ``[]`` when no file exists.
    """
    print(cell)
    cell_rows = unify_cell_names(frame[(frame['cell'] == cell)], cell_type=cell_type_type)
    cell_type = cell_rows[cell_type_type].iloc[0]
    base_fr = cell_rows.fr.iloc[0]
    base_cv = cell_rows.cv.iloc[0]
    eod_fr = cell_rows.EODf.iloc[0]
    # fallback title, shown if the square itself is not plotted
    fallback_title = (cell + ' EODf ' + str(np.round(eod_fr)) + ' Hz ' + ' % '
                      + ' Base: cv ' + str(np.round(base_cv, 2)) + ' fr ' + str(np.round(base_fr)) + ' Hz')
    plt.suptitle(fallback_title, fontsize=11, )

    load_name = load_folder_name('calc_RAM') + '/' + save_names[0] + '_' + cell
    im, axs = [], []
    if os.path.exists(load_name + '.pkl'):
        im, axs = plt_stack_single(cell_type, load_name, b, cells_plot, norm, cell, amps_desired, labeloff, grid1,
                                   eod_fr, save_names, predefined_amps2)

    # Mark this cell in each overview scatter that was provided.
    scatter_targets = (
        (ax0, 'cv', 'fr'),
        (ax1, 'cv', 'vs'),
        (ax2, 'cv_burst_corr', 'fr_burst_corr'),
    )
    for target_ax, x_key, y_key in scatter_targets:
        if target_ax != []:
            target_ax.scatter(cell_rows[x_key], cell_rows[y_key], zorder=2, alpha=1,
                              label=cell_type, s=15,
                              color=colors[str(cell_type)], facecolor='white')
    return im, axs
def hist_part2(axi, cell_type, burst_corr, colors, cell, spikes, eod_fr, ):
    """Plot the ISI histograms of a cell on ``axi``.

    When ``burst_corr`` mentions 'burst' and some interval falls below the
    cell-specific burst limit, the burst-corrected histograms are drawn too. The
    uncorrected histograms are then drawn in grey and the corrected ones in the
    cell-type color.
    """
    spikes_all, hists, _frs_calc, _spikes_cont = load_spikes(spikes, eod_fr)
    hists_both = [hists]
    if 'burst' in burst_corr:
        lim_here = find_lim_here(cell, burst_corr)
        print(lim_here)
        if np.min(np.concatenate(hists)) < lim_here:
            hists2, _spikes_ex, _frs_calc2 = correct_burstiness(hists, spikes_all,
                                                                [eod_fr] * len(spikes_all),
                                                                [eod_fr] * len(spikes_all), lim=lim_here,
                                                                burst_corr=burst_corr)
            hists_both.append(hists2)
    cell_color = colors[str(cell_type)]
    colors_hist = ['grey', cell_color] if len(hists_both) > 1 else [cell_color]
    for gg, hists_here in enumerate(hists_both):
        plt_hist2(axi, hists_here, colors_hist, gg)
def vals_modulation():
    """Placeholder without functionality; always returns None."""
    return None
def plt_power3(spikes_all_here, axp, color='blue', only_one=False):
    """Plot the power spectral density of each spike train on ``axp``.

    Spike times are divided by 1000 before being binned at 40 kHz (presumably
    ms -> s; TODO confirm). Each binned train is mean-subtracted, and its PSD is
    estimated with ``matplotlib.mlab.psd`` (NFFT = 2**14, 50 % overlap).

    Parameters
    ----------
    spikes_all_here : sequence of sequences
        Spike-time arrays, one per trial. Empty entries are skipped.
    axp : Axes
        Target axis; x-limits are set to 0..1000.
    color : str
        Line color.
    only_one : bool
        If True, only the first spike train is processed.

    Returns
    -------
    p_array : list
        One PSD per entry of ``spikes_all_here`` (``[]`` for skipped entries).
    f_array : array or list
        Frequencies of the last PSD computed, ``[]`` if none was computed.
    """
    n_trains = len(spikes_all_here)
    sampling_calc = 40000
    nfft = 2 ** 14
    spikes_mat = [[]] * n_trains
    p_array = [[]] * n_trains
    f_array = []
    selected = [spikes_all_here[0]] if only_one else spikes_all_here
    for s, sp in enumerate(selected):
        if not len(sp) > 0:
            continue
        try:
            spikes_mat[s] = cr_spikes_mat(np.array(sp) / 1000, sampling_rate=sampling_calc,
                                          length=int(sampling_calc * np.array(sp[-1]) / 1000))
        except:
            print('spikes_mat[s] =')
            embed()
        centered = spikes_mat[s] - np.mean(spikes_mat[s])
        p_array[s], f_array = ml.psd(centered, Fs=sampling_calc, NFFT=nfft, noverlap=nfft // 2)
        axp.plot(f_array, p_array[s], color=color)
    axp.set_xlim(0, 1000)
    axp.set_xlabel('Hz')
    axp.set_ylabel('Hz')
    return p_array, f_array
def plt_hist2(axi, hists_here, colors_hist, gg):
    """Draw one pooled histogram of all arrays in ``hists_here`` on ``axi``.

    The legend label carries the coefficient of variation (std / mean, rounded to
    three decimals) of the pooled values. ``colors_hist[gg]`` selects the color.
    Nothing is drawn when ``hists_here`` is empty.
    """
    if len(hists_here) <= 0:
        return
    pooled = np.concatenate(hists_here)
    cv_value = np.round(np.std(pooled) / np.mean(pooled), 3)
    axi.hist(pooled, bins=100, color=colors_hist[gg],
             label='CV ' + str(cv_value),
             alpha=0.7)
def plt_stack_single(cell_type, load_name, b, cells_plot, norm, cell, amps_desired, labeloff, grid1, eod_fr, save_names,
                     predefined_amps2):
    """Plot one cell's susceptibility square plus two spectra panels above it.

    Loads ``load_name + '.pkl'`` and selects the first file name. It then takes
    the smallest amplitude, the longest stimulus length and the highest
    ``trial_nr``, and draws:
      * ``grid1[2]``: the square from ``square_func`` with a colorbar,
      * ``grid1[1]``: ``plt_isf_ps_red`` of the ``isf`` column,
      * ``grid1[0]``: ``plt_isf_ps_red`` of the ``osf`` column, titled with the amplitude.
    The ``isf``/``osf`` columns are presumably input/output spectra (TODO confirm).
    Both spectra panels are aligned horizontally with the square, and the
    figure suptitle is replaced with baseline/stimulus statistics.

    Returns
    -------
    im, axs
        Image and axis of the square, or ``[]``, ``[]`` if the selected amplitude
        is missing from the selected file.
    """
    im = []
    axs = []
    stack = pd.read_pickle(load_name + '.pkl')
    # NOTE(review): case-sensitive check, so ' P-unit' labels take the ampullary
    # branch; this only matters if ``fexclude`` below is enabled.
    if 'p-unit' in cell_type:  # == ['p-unit', ' P-unit']:
        file_names_exclude = punit_file_exclude()  #
    else:
        file_names_exclude = ampullary_file_exclude()  #
    files = stack['file_name'].unique()
    # file exclusion is switched off; file_names_exclude is therefore unused
    fexclude = False
    if fexclude:
        if len(files) > 1:
            stack = stack[~stack['file_name'].isin(file_names_exclude)]
    files = stack['file_name'].unique()
    amps = stack['amp'].unique()
    # this loop has no effect (body is ``pass``)
    if predefined_amps2:
        for a, amp in enumerate(amps):
            if amp not in amps_desired:
                pass
    # only the smallest amplitude is plotted
    amps_defined = [np.min(amps)]
    stack_file = stack[stack['file_name'] == files[0]]
    for a, amp in enumerate(amps_defined):
        if amp in np.array(stack_file['amp']):
            stack_amp = stack_file[stack_file['amp'] == amp]
            lengths = stack_file['stimulus_length'].unique()
            length = np.max(lengths)
            stack_final = stack_amp[stack_amp['stimulus_length'] == length]
            trial_nr_double = stack_final.trial_nr.unique()
            # ok, I think this indicates an error (several trial numbers for one setting)
            if len(trial_nr_double) > 1:
                print('trial_nr_double')
            try:
                # keep only the rows with the highest trial number
                stack_final1 = stack_final[stack_final.trial_nr == np.max(trial_nr_double)]
            except:
                print('stack_final1 problem')
                embed()
            axs = plt.subplot(grid1[2])
            osf = stack_final1.osf
            isf = stack_final1.isf
            im, min_lim, max_lim = square_func([axs], stack_final1, norm=norm)
            plt.colorbar(im, ax=axs)
            ax_pos = np.array(axs.get_position())  # [[xmin, ymin], [xmax, ymax]].
            fr = stack_final1.fr.unique()[0]
            snippets = stack_final1['snippets'].unique()[0]
            cv = stack_final1.cv.unique()[0]
            ser = stack_final1.ser.unique()[0]
            cv_stim = stack_final1.cv_stim.unique()[0]
            fr_stim = stack_final1.fr_stim.unique()[0]
            ser_stim = stack_final1.ser_stim.unique()[0]
            plt.suptitle(cell + ' EODf ' + str(np.round(eod_fr)) + ' Hz ' + '' + 'S.Nr ' + str(
                snippets) + ' % ' +
                         ' Base: cv ' + str(np.round(cv, 2)) + ' fr ' + str(
                np.round(fr)) + ' Hz' + ' ser ' + str(np.round(ser))
                         + ' Stim: cv ' + str(np.round(cv_stim, 2)) + ' fr ' + str(
                np.round(fr_stim)) + ' Hz' + ' ser ' + str(np.round(ser_stim)) + ' length ' + str(length)
                         ,
                         fontsize=11, )  # cell[0:13] + color=color+ cell_type
            # colors for the reference frequencies marked in the spectra panels
            eod_fr_half_color = 'purple'
            fr_color = 'red'
            eod_fr_color = 'magenta'
            fr_stim_color = 'darkred'
            # hide x tick labels on all but the last row
            if labeloff:
                if b != len(cells_plot) - 1:
                    remove_xticks(axs)
                    axs.set_xlabel('')
            # plot the input above
            axs2 = plt.subplot(grid1[1])
            ax_pos2 = np.array(axs2.get_position())  # this would also work: .y0, .y1, .x0, .x1, .width
            # give this panel the same horizontal extent as the square below it
            axs2.set_position([ax_pos[0][0], ax_pos2[0][1], ax_pos[1][0] - ax_pos[0][0], ax_pos2[1][1] - ax_pos2[0][1]])
            clip_on = True
            freqs = [fr, fr * 2, fr_stim, fr_stim * 2, eod_fr, eod_fr * 2, eod_fr / 2]
            colors_f = [fr_color, fr_color, fr_stim_color, fr_stim_color, eod_fr_color, eod_fr_color,
                        eod_fr_half_color]
            plt_isf_ps_red(stack_final1, isf, 0, axs2, freqs=freqs, colors=colors_f, clip_on=clip_on)
            axs2.set_xlim(min_lim, max_lim)
            remove_xticks(axs2)
            axs1 = plt.subplot(grid1[0])
            # label for the burst-correction variant encoded in the save name
            if '2.5' in save_names[0]:
                burst_name = '2.5 EOD burst corr'
            elif 'Individual' in save_names[0]:
                burst_name = 'individual burst corr'
            elif 'burst' in save_names[0]:
                burst_name = '1.5 EOD burst corr'
            else:
                burst_name = ''
            axs1.set_title(' std ' + str(amp) + ' ' + burst_name)  # + files[0] + '\n' + names)
            remove_xticks(axs1)
            ax_pos2 = np.array(axs1.get_position())  # this would also work: .y0, .y1, .x0, .x1, .width
            axs1.set_position([ax_pos[0][0], ax_pos2[0][1], ax_pos[1][0] - ax_pos[0][0],
                               ax_pos2[1][1] - ax_pos2[0][1]])
            plt_isf_ps_red(stack_final1, osf, 0, axs1, freqs=freqs, colors=colors_f, clip_on=clip_on)
            axs1.set_xlim(min_lim, max_lim)
    return im, axs
def fft_matrix(osf, f_range, isf, norm='', quadrant=''):  # stimulus,
    """Build frequency-index grids and the row-wise second-order cross matrix.

    ``f_range`` is used both as frequency values and as indices into ``osf`` and
    ``isf`` (presumably integer FFT bin indices; TODO confirm). For each row ``ff``:

    * ``quadrant == ''``: ``osf[f1 + f2] * R1 * R2``,
    * otherwise: ``conj(osf[|f1 - f2|]) * conj(R1) * R2``,

    where ``R1, R2`` come from ``find_isf_matrices``. When ``norm`` is non-empty,
    each row is additionally multiplied by the corresponding row of
    ``find_norm_susept``.

    Returns
    -------
    f_mat1, f_mat2, f_idx_sum, cross : np.ndarray
        x-frequency grid, y-frequency grid (its transpose), their sum, and the
        cross matrix.
    """
    # frequencies along the x axis (every row is a copy of f_range)
    f_mat1 = [f_range] * len(f_range)
    # frequencies along the y axis
    f_mat2 = np.transpose(f_mat1)
    # sum and difference frequencies
    f_idx_sum = f_mat1 + f_mat2
    f_idx_diff = f_mat1 - f_mat2
    rate_matrix1, rate_matrix2 = find_isf_matrices(f_idx_sum, isf[f_range])
    scale = find_norm_susept(f_idx_sum, isf[f_range])
    cross = [[]] * len(f_idx_sum)
    for ff in range(len(f_idx_sum)):
        if quadrant == '':
            row = osf[f_idx_sum[ff]] * rate_matrix1[ff] * rate_matrix2[ff]
        else:
            # conjugate of the output at the difference frequency
            row = np.conj(osf[np.abs(f_idx_diff[ff])]) * np.conj(rate_matrix1[ff]) * rate_matrix2[ff]
        if norm != '':
            row = row * scale[ff]
        cross[ff] = row
    return np.array(f_mat1), np.array(f_mat2), np.array(f_idx_sum), np.array(cross)
def exclude_nans_for_corr(file_here, var_item, x=[], y=[], max_x=None, cv_name='cv_base', score='perc99/med'):
    """Prepare x/y/color data for a correlation scatter.

    Empty ``x``/``y`` default to the ``cv_name``/``score`` columns of ``file_here``.
    Pairs where x or y is NaN/inf are dropped from x, y and the ``var_item`` column.
    If ``max_x`` is truthy and some x exceeds it, only points with ``x < max_x``
    are kept.

    Returns
    -------
    c_axis, x, y
        The filtered color values and coordinates.
    exclude_here
        Boolean NaN/inf mask relative to the unfiltered x/y.
    """
    x_vals = file_here[cv_name] if len(x) == 0 else x
    y_vals = file_here[score] if len(y) == 0 else y
    c_axis = file_here[var_item]
    exclude_here = exclude_nans(x_vals, y_vals)
    finite = ~exclude_here
    x_vals = x_vals[finite]
    y_vals = y_vals[finite]
    c_axis = c_axis[finite]
    if max_x and np.sum(x_vals > max_x) > 0:
        below = x_vals < max_x
        y_vals = y_vals[below]
        try:
            c_axis = c_axis.loc[below]
        except:
            print('c something')
            embed()
        x_vals = x_vals[below]
    return c_axis, x_vals, y_vals, exclude_here
def exclude_nans(x, y):
    """Return a boolean mask that is True where ``x`` or ``y`` is NaN or +/-inf."""
    return ~np.isfinite(x) | ~np.isfinite(y)
def fav_calc_RAM_cell_sorting(save_names_load=[
    'calc_RAM_overview-_simplified_noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s'],
        base_sorted='base_ram_sorted', sorted_cv='cv_base', redo=False, redo_base=False,
        cell_types_sort=[' P-unit', ' Ampullary', ' P-unit_problem', 'unkown', ' unknown_problem',
                         ' Ampullary_problem', 'unkown', ' Pyramidal',
                         ' T-unit']):
    """Return the data names of the RAM overview cells, sorted by baseline CV.

    Every save name is prefixed with ``'calc_RAM_overview-_simplified_'`` unless it
    already contains it. The prefixed names are built in a new list, so neither the
    caller's list nor the shared default argument is modified (the previous
    version rewrote ``save_names_load`` in place).

    Returns
    -------
    data_names
        First return value of ``sort_cells_base``.
    """
    cell_sorted = 'cv_cell_sorted'  # ''#'cv_cell_sorted'#''#'cell_sorted'
    cell_type_type = 'cell_type_reclassified'
    prefix = 'calc_RAM_overview-_simplified_'
    save_names_load = [save if prefix in save else prefix + save for save in save_names_load]
    data_names, frame, cell_types = sort_cells_base(small_cvs_first=True,
                                                    name='calc_base_data-base_frame_overview.pkl',
                                                    cell_sorted=cell_sorted, cell_type_type=cell_type_type,
                                                    save_names=save_names_load, sorted_cv=sorted_cv,
                                                    base_sorted=base_sorted, cell_type_sort=cell_types_sort,
                                                    gwn_filtered=True, redo=redo, redo_base=redo_base)
    return data_names
def version_final():
    """Return the save-name stem of the final analysis version."""
    return 'noise_data12_nfft0.5sec_original__StimPreSaved4__first1_order_'
def find_stimuli(b):
    """Return the names of all tags of block ``b`` containing 'filestimulus' (case-insensitive)."""
    return [tag.name for tag in b.tags if 'filestimulus' in tag.name.lower()]
def pearson_label(corr, p_value, y, n=True):
    """Format a LaTeX label ' $r=..$, $p..$[, $n=..$]' for a Pearson correlation.

    ``r`` is rounded to 3 decimals when ``|r| < 0.01`` and to 2 otherwise. ``p``
    is shown as ``<0.001`` / ``<0.01`` below those thresholds, and otherwise
    rounded to 2 decimals. ``n`` (the sample count ``len(y)``) is appended
    only when ``n`` is truthy.
    """
    n_add = ', $n=%s$' % (len(y)) if n else ''
    if p_value < 0.001:
        p_name = ', $p<0.001$'  # ***
    elif p_value < 0.01:
        p_name = ', $p<0.01$'  # **
    else:
        p_name = ', $p=%s$' % (np.round(p_value, 2))
    decimals = 3 if np.abs(corr) < 0.01 else 2
    r_value = np.round(corr, decimals)
    return ' $r=%s$' % r_value + p_name + n_add
def chose_class_cells(cell_types_sort=[' P-unit_problem', 'unkown', ' unknown_problem', ' Ampullary_problem', 'unkown',
                                       ' Pyramidal', ' T-unit', ' P-unit',
                                       ' Ampullary', ]):
    """Run the stimulus-sorted CV ordering of the RAM overview cells.

    The results of ``sort_cells_base`` are discarded and the function returns
    None. With the hard-coded ``'cv_cell_sorted'`` mode the sorting branch is
    always taken; ``find_all_dir_cells`` is only the fallback for other modes.
    """
    cell_type_type = 'cell_type_reclassified'
    cell_sorted = 'cv_cell_sorted'  # ''#'cv_cell_sorted'#''#'cell_sorted'
    if 'cell_sorted' not in cell_sorted:
        _, _ = find_all_dir_cells()
        return
    save_names_load = [
        'calc_RAM_overview-_simplified_noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s']
    _, _, _ = sort_cells_base(small_cvs_first=True, cell_sorted=cell_sorted,
                              cell_type_type=cell_type_type,
                              save_names=save_names_load, sorted_cv='cv_base',
                              base_sorted='stim_sorted', cell_type_sort=cell_types_sort,
                              gwn_filtered=True)
# we need one cell that has the nix file (new cell) and one whose RAM file we copied
def kernel_scatter(axy, axx, axs, c, cell_type_here, colors, cv_name, frame_file, score, xmin='no', alpha=1, ymin='no',
                   color_given=None, n=True, log=True):
    """Scatter ``score`` vs. ``cv_name`` on ``axs`` with kernel-density marginals.

    The x-marginal goes to ``axx`` and the y-marginal to ``axy``. The main axis
    shares its x axis with ``axx`` and its y axis with ``axy``. With ``log`` the
    y axes are logarithmic.

    Returns
    -------
    axs, x_axis
        The scatter axis and the x values returned by ``plt_overview_scatter``.
    """
    if not color_given:
        color_given = colors[str(cell_type_here)]
    # version comparison with all cells, and no modulation
    plot_kernels_on_side(axx, axy, color_given, cv_name, frame_file, score, xmin=xmin,
                         ymin=ymin)
    x_axis = plt_overview_scatter(axs, c, cell_type_here, colors, cv_name, frame_file,
                                  score, alpha=alpha, n=n, color_text=color_given, color_given=color_given)
    if log:
        for log_ax in (axy, axs):
            log_ax.set_yscale('log')
        axy.set_yticks_blank()
        axy.minorticks_off()
    join_x([axs, axx])
    join_y([axy, axs])
    if log:
        make_log_ticks([axy, axs])
    remove_yticks(axy)
    axy.minorticks_off()
    return axs, x_axis
def plot_kernels_on_side(ax_x, ax_y, color, cv_name, frame_file, score, step_y=0, xmin='no', ymin='no', step_x=0,
                         ymax='no', xlim=None):
    """Draw kernel-density marginals of ``cv_name`` (on ``ax_x``) and ``score`` (on ``ax_y``).

    Only rows with a positive ``score`` are used (see ``get_axis``). With
    ``xlim``, x values at or above ``xlim[1]`` are dropped from the x marginal.
    Ticks of both marginal axes are removed.

    Returns
    -------
    x_axis
        The (possibly xlim-restricted) x values.
    """
    x_axis, y_axis = get_axis(cv_name, frame_file, score)
    if xlim:
        x_axis = x_axis[x_axis < xlim[1]]

    # horizontal marginal: distribution of x
    kernel_histogram(ax_x, color, np.array(x_axis), xmin=xmin, norm='density', step=step_x, alpha=0.5)  # step_x = 0.03
    ax_x.show_spines('b')
    remove_yticks(ax_x)
    remove_xticks(ax_x)

    # vertical marginal: distribution of y
    kernel_histogram(ax_y, color, np.array(y_axis), orientation='vertical', norm=True, step=step_y,
                     alpha=0.5, xmin=ymin, xmax=ymax)
    ax_y.set_yticks_blank()
    ax_y.show_spines('l')
    remove_yticks(ax_y)
    remove_xticks(ax_y)
    return x_axis
def plt_albi(ax, cell_type_here, colors, max_val, species, x_axis, y_axis):
    """Scatter the points with ``x_axis < max_val`` on a log-y axis.

    Also draws a dashed reference line at 2.576 and titles the axis with ``species``.
    """
    try:
        below = x_axis < max_val
        ax.scatter(x_axis[below], y_axis[below],
                   alpha=1,
                   s=2.5, color=colors[str(cell_type_here)],
                   clip_on=False)
        ax.axhline(2.576, color='grey', linestyle='--', linewidth=1)
        ax.set_title(species)
        ax.set_yscale('log')
    except:
        print('axs thing3')
        embed()
def plt_eigen(cv_name, ax, c, cell_type_here, cells_extra, colors, frame_file, max_val, score, species):
    """Scatter ``score`` vs. ``cv_name`` (x below ``max_val``) on a log-y axis and highlight example cells.

    The scatter label contains Pearson's r and the point count; the legend is
    shown only for ``c == 1``. Example cells from
    ``p_units_to_show(type_here='eigen_small')`` are redrawn as black-edged
    diamonds. P-units use all but the first entry, other types the first entry.
    The ``cells_extra`` argument is ignored and recomputed.
    """
    x_axis, y_axis = get_axis(cv_name, frame_file, score)
    below = x_axis < max_val
    x = x_axis[below]
    y = y_axis[below]
    try:
        ax.scatter(x, y,
                   alpha=1,
                   s=2.5, color=colors[str(cell_type_here)],
                   label='r=' + str(np.round(np.corrcoef(x, y)[0][1], 2)) + ' n=' + str(len(y)),
                   clip_on=False)
        ax.set_title(species)
        ax.set_yscale('log')
        ax.axhline(2.576, color='grey', linestyle='--', linewidth=1)
        if c == 1:
            ax.legend()
    except:
        print('axs thing2')
        embed()
    example_cells = p_units_to_show(type_here='eigen_small')
    cells_plot2 = example_cells[1::] if cell_type_here == ' P-unit' else [example_cells[0]]
    try:
        cells_extra = frame_file[frame_file['cell'].isin(cells_plot2)].index
    except:
        print('cells extra here')
        embed()
    ax.scatter(frame_file[cv_name].loc[cells_extra], frame_file[score].loc[cells_extra],
               alpha=1,
               s=2.5, color=colors[str(cell_type_here)], clip_on=False, marker='D', edgecolor='black')
def plt_overview_scatter(ax, c, cell_type_here, colors, cv_name, frame_file, score, x_pos=0, labelpad='no', n=True,
                         alpha=1, color_text='black', legend_spacing = 0.1, y_val=0.9, fs=7.5, ms=2.5, color_given=None, ha='left'):
    """Scatter the finite ``score`` vs. ``cv_name`` values and annotate Pearson statistics.

    The text annotation is placed at axes height ``y_val - legend_spacing * c``, so
    several groups (indexed by ``c``) stack below each other. ``labelpad`` is
    applied to the x label only when it is not a string.

    Returns
    -------
    x_axis
        The plotted x values.
    """
    if not color_given:
        color_given = colors[str(cell_type_here)]
    x_axis, y_axis = x_axis_wo_c(cv_name, frame_file, score)
    try:
        x, y = x_axis, y_axis
        ax.scatter(x, y,
                   alpha=alpha,
                   s=ms, color=color_given,
                   clip_on=False)
        print(' mean(' + str(cv_name) + str(np.mean(x)) + ') ' + ' mean(' + str(score) + str(np.mean(y)) + ') ')
    except:
        print('axs thing1')
        embed()
    legend_wo_dot(ax, y_val - legend_spacing * c, x, y, ha=ha, color=color_text, fs=fs, x_pos=x_pos, n=n)
    if type(labelpad) is str:
        ax.set_xlabel(cv_name)
    else:
        ax.set_xlabel(cv_name, labelpad=labelpad)
    return x_axis
def x_axis_wo_c(cv_name, frame_file, score):
    """Return the x/y values of ``get_axis`` with every NaN/inf pair removed."""
    x_vals, y_vals = get_axis(cv_name, frame_file, score)
    keep = ~exclude_nans(x_vals, y_vals)
    return x_vals[keep], y_vals[keep]
def legend_wo_dot(ax, y_pos, x, y, color='black', x_pos=0.5, ha='left', n=True, fs=7.5):  # , ha = 'right'
    """Write the Pearson r/p(/n) label of ``x`` vs. ``y`` at axes coordinates (x_pos, y_pos)."""
    corr, p_value = stats.pearsonr(x, y)
    label_text = pearson_label(corr, p_value, y, n=n)
    ax.text(x_pos, y_pos, label_text, fontsize=fs, color=color,
            transform=ax.transAxes, ha=ha)
def get_axis(cv_name, frame_file, score):
    """Select the rows of ``frame_file`` with a positive ``score``.

    Returns
    -------
    x_axis : pandas.Series
        The ``cv_name`` column of those rows (original index kept).
    y_axis : np.ndarray
        The matching ``score`` values.
    """
    positive = frame_file[score] > 0
    x_axis = frame_file[cv_name][positive]
    y_axis = np.array(frame_file[score])[positive]
    return x_axis, y_axis
def scatter_with_marginals_colorcoded(var_item_name, ax, cell_type_here, cv_name, frame_file, score, axl=None, axk=None,
                                      ymin='no', xmin='no', ymax='no', top=False,
                                      burst_fraction_reset='burst_fraction_burst_corr_individual_base',
                                      var_item='response_modulation', labelpad=0, max_x=None, n=True, xlim=None,
                                      x_pos=0, fs=7.5, ms=2.5, c=0, burst_fraction=1, sides=True, color_text='black',
                                      ha='left', y_val=0.9, cmap_required=True, color_given=None, cbar_labelpad=0,
                                      legend_spacing=0.1):
    """Scatter ``score`` vs. ``cv_name`` with optional marginals and a colorbar.

    Rows with ``burst_fraction_reset >= burst_fraction`` are dropped first.
    * ``sides``: kernel-density marginals are drawn on ``axk`` (x) and ``axl`` (y).
    * ``var_item != ''``: the points are color-coded by ``var_item`` and a colorbar
      is placed above ``ax`` (``top=True``) or ``axl``; NaN/inf pairs are removed
      and x is optionally cut at ``max_x``.
    * ``var_item == ''``: a single-color scatter via ``plt_overview_scatter``.
    ``axl``/``axk`` (if given) share their y/x axis with ``ax``.

    Returns
    -------
    cmap, x_axis, y_axis
        Discrete colors from ``rainbow_cmap`` (``[]`` unless ``cmap_required``) and
        the plotted coordinates (``[]`` if nothing was plotted).
    """
    cmap = []
    x_axis = []
    y_axis = []
    # Colormap of the color-coded scatter; the same for every cell type
    # ('Blues'/'Greens' were earlier alternatives). It is defined unconditionally:
    # it used to be set only when ``cmap_required`` was True, so
    # ``cmap_required=False`` with a non-empty ``var_item`` raised a NameError.
    cm = 'coolwarm'
    if len(frame_file) > 0:
        if cmap_required:  # pay attention if the cell type is not a cell but a fish this is not working anymore
            mod_limits = mod_lims_modulation(cell_type_here, frame_file, score)
            cmap = rainbow_cmap(np.arange(len(mod_limits) * 1.6), nrs=len(mod_limits) * 1.6, cm=cm)[
                   ::-1]  # len(amps)
            cmap = cmap[0:len(mod_limits)][::-1]
        frame_file = frame_file[frame_file[burst_fraction_reset] < burst_fraction]
        colors = colors_overview()
        if not color_given:
            color_given = colors[cell_type_here]
        if sides:
            x_axis = plot_kernels_on_side(axk, axl, color_given, cv_name, frame_file, score, xmin=xmin,
                                          ymin=ymin, ymax=ymax, xlim=xlim)
        if var_item != '':
            c_axis, x_axis, y_axis, exclude_here = exclude_nans_for_corr(frame_file, var_item, cv_name=cv_name,
                                                                         score=score, max_x=max_x)
            if len(x_axis) > 0:
                im = ax.scatter(x_axis, y_axis,
                                alpha=1,
                                s=2.5, c=c_axis, clip_on=True, cmap=cm)
                legend_wo_dot(ax, 0.9 - legend_spacing * c, x_axis, y_axis, ha=ha, x_pos=x_pos, n=n)
                # the colorbar sits above the main axis or above the y-marginal
                ax_tag = ax if top else axl
                cbar, left, bottom, width, height = colorbar_outside(ax_tag, im, plt.gcf(), width=0.01,
                                                                     pos_axis='top', orientation='bottom',
                                                                     add=cbar_labelpad,
                                                                     top=top)
                # burst fractions live in [0, 1]
                if 'burst_fraction' in cv_name:
                    ax.set_xlim(0, 1.01)
                    ax.set_xticks_delta(0.5)
                if 'burst_fraction' in score:
                    ax.set_ylim(0, 1.01)
                    ax.set_yticks_delta(0.5)
                if 'burst_fraction' in var_item:
                    val_chosen = 1
                else:
                    val_chosen = None
                set_clim_same([im], clims='', val_chosen=val_chosen, lim_type='up', nr_clim='None')
                cbar.set_label(var_item_name)  ##+cell_type_here rotation=270,, labelpad=100
        else:
            x_axis = plt_overview_scatter(ax, c, cell_type_here, colors, cv_name, frame_file, score, x_pos=x_pos, ha=ha,
                                          labelpad=labelpad, y_val=y_val, ms=ms, fs=fs, color_text=color_text,
                                          color_given=color_given, n=n, legend_spacing=legend_spacing)
    if axl:
        axl.get_shared_y_axes().join(*[ax, axl])
        axl.show_spines('')
    if axk:
        axk.get_shared_x_axes().join(*[ax, axk])
        axk.show_spines('')
    return cmap, x_axis, y_axis
def plt_burst_modulation(var_item_name, ax, cell_type_here, cv_name, frame_file, score, var_item='response_modulation'):
    """Scatter ``score`` vs. ``cv_name`` color-coded by ``var_item`` with a vertical colorbar.

    Returns
    -------
    cmap, x_axis, y_axis
        Discrete colors from ``rainbow_cmap`` and the NaN-free plotted coordinates.
    """
    mod_limits = mod_lims_modulation(cell_type_here, frame_file, score)
    cm = 'coolwarm'  # same for all cell types ('Blues'/'Greens' were alternatives)
    n_colors = len(mod_limits) * 1.6
    cmap = rainbow_cmap(np.arange(n_colors), nrs=n_colors, cm=cm)[::-1]
    cmap = cmap[0:len(mod_limits)][::-1]
    c_axis, x_axis, y_axis, _ = exclude_nans_for_corr(frame_file, var_item, cv_name=cv_name, score=score)
    if len(x_axis) > 0:
        im = ax.scatter(x_axis, y_axis, alpha=1, s=2.5, c=c_axis, clip_on=False, cmap=cm)
        legend_wo_dot(ax, 0.9, x_axis, y_axis, x_pos=0)
        cbar = plt.colorbar(im, ax=ax, orientation='vertical')
        cbar.set_label(var_item_name + '\n' + cell_type_here)
    return cmap, x_axis, y_axis
def plt_modulation_overview(ax, cell_type_here, cv_name, frame_file, score, species):
    """Scatter ``score`` (log y) vs. ``cv_name`` color-coded by response modulation.

    The colorbar label includes the cell type and the first five characters of
    ``species``.

    Returns
    -------
    cmap, x_axis, y_axis
        Discrete colors from ``rainbow_cmap`` and the NaN-free plotted coordinates.
    """
    mod_limits = mod_lims_modulation(cell_type_here, frame_file, score)
    cm = 'coolwarm'  # same for all cell types ('Blues'/'Greens' were alternatives)
    n_colors = len(mod_limits) * 1.6
    cmap = rainbow_cmap(np.arange(n_colors), nrs=n_colors, cm=cm)[::-1]
    cmap = cmap[0:len(mod_limits)][::-1]
    c_axis, x_axis, y_axis, _ = exclude_nans_for_corr(frame_file, 'response_modulation',
                                                      cv_name=cv_name, score=score)
    if len(x_axis) > 0:
        ax.set_yscale('log')
        r_label = 'r=' + str(np.round(np.corrcoef(x_axis, y_axis)[0][1], 2))
        im = ax.scatter(x_axis, y_axis, alpha=1, s=2.5, c=c_axis, clip_on=False, cmap=cm, label=r_label)
        legend_wo_dot(ax, 0.98, x_axis, y_axis, x_pos=0.4)
        cbar = plt.colorbar(im, ax=ax, orientation='vertical')
        cbar.set_label('Modulation Depth\n' + cell_type_here + '(' + str(species[0:5]) + '.)')
    return cmap, x_axis, y_axis
def data_overview():
    """Build and save the data-overview figure.

    Layout of ``grid0`` (3 rows):
      * row 0: one panel per entry of ``x_axis`` showing ``score`` vs. CV with
        kernel-density marginals (``kernel_scatter``); P-units and Ampullaries
        are overlaid in the same panel;
      * row 1: ``score`` vs. ``cv_base`` color-coded by response modulation, one
        panel per cell type, with example cells from
        ``p_units_to_show(type_here='contrasts')`` marked;
      * row 2: P-unit burst statistics (2 x 2 panels).
    The model overview is drawn into ``ax_j[1]`` and the figure is saved with
    ``save_visualization(pdf=True)``.
    """
    plot_style()
    default_settings(column=2, length=8.5)
    grid0 = gridspec.GridSpec(3, 1, wspace=0.54, bottom=0.1,
                              hspace=0.25, height_ratios=[1, 1, 2], left=0.1, right=0.87, top=0.95)
    scoreall = 'perc99/med'
    scores = [scoreall + '_diagonal_proj']
    ##########################
    # Selection: we use the mean to average out noise that does not depend on the stimulus
    save_names = [
        'calc_RAM_overview-_simplified_' + version_final(),
    ]  # 'calc_RAM_overview-_simplified_noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s__burstIndividual_','calc_RAM_overview-noise_data9_nfft1sec_original__StimPreSaved4__mean5__CutatBeginning_0.05_s_NeurDelay_0.005_s__burstIndividual_',
    # one top-row column per entry; species_all is indexed in parallel
    x_axis = ["cv_base", "cv_base_w_burstcorr", "cv_base", ]
    cv_name_title = ['CV', 'CV$_{BurstCorr}$', 'CV']
    species_all = [' Apteronotus leptorhynchus', ' Apteronotus leptorhynchus', ' Eigenmannia virescens']
    counter = 0
    cell_types = [' P-unit', ' Ampullary', ]
    colors = colors_overview()
    ax_j = []
    axls = []
    score = scores[0]
    # NOTE: ``x_axis`` is rebound inside this loop by kernel_scatter; enumerate keeps
    # iterating over the original list, so the loop itself is unaffected.
    for cv_n, cv_name in enumerate(x_axis):
        if cv_n == 0:
            pass
        else:
            pass
        redo = False
        frame_load_sp = load_overview_susept(save_names[0], redo=redo, redo_class=redo)
        for c, cell_type_here in enumerate(cell_types):
            species = species_all[cv_n]
            frame_file = setting_overview_score(frame_load_sp, cell_type_here, min_amp='min', species=species)
            grid = gridspec.GridSpecFromSubplotSpec(1, 3, grid0[0],
                                                    hspace=0, wspace=0.15)
            #
            # axes are created once per column (first cell type) and reused for the second
            if c == 0:
                grid_k = gridspec.GridSpecFromSubplotSpec(2, 2, grid[0, cv_n],
                                                          hspace=0.1, wspace=0.1, height_ratios=[0.35, 3],
                                                          width_ratios=[3, 0.5])
                try:
                    axk = plt.subplot(grid_k[0, 0])
                except:
                    print('grid something')
                    embed()
                ax_j.append(axk)
                axs = plt.subplot(grid_k[1, 0])
                ax_j.append(axs)
                axl = plt.subplot(grid_k[1, 1])
                axls.append(axl)
                if c in [0, 2]:
                    axk.set_title(species)
            axs, x_axis = kernel_scatter(axl, axk, axs, c, cell_type_here, colors, cv_name,
                                         frame_file, score)
            axs.set_xlabel(cv_name_title[cv_n])
            if cv_n == 0:
                axs.set_ylabel('Perc(99)/Median')
    grid_lower = gridspec.GridSpecFromSubplotSpec(2, 2, grid0[2], hspace=0.55, wspace=0.5)
    #
    cv_name = "cv_base"
    species = ' Apteronotus leptorhynchus'
    for c, cell_type_here in enumerate(cell_types):
        frame_file = setting_overview_score(frame_load_sp, cell_type_here, min_amp='min', species=species)
        # embed()
        ##############################################
        # now the additional P-unit statistics follow
        if cell_type_here == ' P-unit':
            if c == 0:
                ################################
                # Modulation, cell type comparison
                var_types = ['burst_fraction_burst_corr_base', 'cv_base']
                x_axis = ['cv_base', 'burst_fraction_burst_corr_base', ]
                var_item_names = ['Burst Fraction', 'CV$'+basename()+'$']
                x_axis_names = ['CV$'+basename()+'$', 'Burst Fraction', ]
                for v, var_type in enumerate(var_types):
                    ax = plt.subplot(grid_lower[0, v])
                    cmap, _, y_axis = plt_burst_modulation(var_item_names[v], ax, cell_type_here,
                                                           x_axis[v], frame_file, score,
                                                           var_item=var_type)
                    ax.set_ylabel(score)
                    ax.set_xlabel(x_axis_names[v])
                    ax.set_yscale('log')
                    if v == 0:
                        ############################
                        # scatter of the extra (example) cells
                        # todo: these cells still have to be converted down
                        # todo: write an extra function for cells with more than 9 snippets and handle them separately
                        cells_plot2 = p_units_to_show(type_here='bursts')
                        cells_extra = frame_file[frame_file['cell'].isin(cells_plot2)].index
                        ax.scatter(frame_file[cv_name].loc[cells_extra], frame_file[score].loc[cells_extra],
                                   s=5, color='white', edgecolor='black', alpha=0.5,
                                   clip_on=False)  # colors[str(cell_type_here)]
                ##########################################
                # burst vs. CV
                var_types = ['burst_fraction_burst_corr_base', 'response_modulation']
                var_item_names = ['Burst Fraction', 'Modulatoin']
                x_axis = ['cv_base', 'burst_fraction_burst_corr_base']
                x_axis_names = ['Burst Fraction$'+basename()+'$', 'Burst Fraction$'+basename()+'$']  # 'CV$'+basename()+'$'
                scores_here = ['coherence_', 'burst_fraction_burst_corr_stim']  # 'wo_burstcorr'
                for v, var_type in enumerate(var_types):
                    if scores_here[v] in frame_file.keys():
                        ax = plt.subplot(grid_lower[1, v])
                        cmap, _, y_axis = plt_burst_modulation(var_item_names[v], ax, cell_type_here,
                                                               x_axis[v], frame_file, scores_here[v],
                                                               var_item=var_type)
                        # identity line for burst fraction (baseline) vs. burst fraction (stimulus)
                        if v == 1:
                            ax.plot([0, 1], [0, 1], color='grey', linewidth=0.5)
                        ax.set_xlabel(x_axis_names[v])
                        ax.set_ylabel(scores_here[v])
                    else:
                        embed()
    grid_lower_lower = gridspec.GridSpecFromSubplotSpec(1, 2, grid0[1], wspace=0.5,
                                                        hspace=0.55)  # , height_ratios = [1,3]
    for c, cell_type_here in enumerate(cell_types):
        frame_file = setting_overview_score(frame_load_sp, cell_type_here, min_amp='range', species=species)
        ##############################################
        # modulation comparison for both cell types
        ################################
        # Modulation, cell type comparison
        # todo: add the diff values across the cells here
        axs = plt.subplot(grid_lower_lower[c])
        cmap, _, y_axis = plt_modulation_overview(axs, cell_type_here,
                                                  cv_name, frame_file, score,
                                                  species)
        axs.set_ylabel(score)
        axs.set_xlabel(cv_name)
        # axs.get_shared_x_axes().join(*[axs, axd])
        ######################################################
        # here the contrast example cells are added
        # burst correction does not matter for these cells
        if cell_type_here == ' P-unit':
            cells_plot2 = p_units_to_show(type_here='contrasts')[1::]
        else:
            cells_plot2 = [p_units_to_show(type_here='contrasts')[0]]
        # for cell_plt in cells_plot2:
        cells_extra = frame_file[frame_file['cell'].isin(cells_plot2)].index
        # ax = plt.subplot(grid[1, cv_n])
        axs.scatter(frame_file[cv_name].loc[cells_extra], frame_file[score].loc[cells_extra],
                    s=5, color='white', edgecolor='black', alpha=0.5, clip_on=False)  # colors[str(cell_type_here)]
        counter += 1
    ########################
    # model
    model = resave_small_files("models_big_fit_d_right.csv", load_folder='calc_model_core')
    cells = model.cell.unique()
    plt_model_overview2(ax_j[1], cells, scores=[scoreall + '_'])
    plt.subplots_adjust(left=0.07, right=0.95, top=0.98, bottom=0.05, wspace=0.45, hspace=0.55)
    # NOTE(review): Grouper.join is deprecated in recent Matplotlib versions
    # (sharex/sharey is the replacement) — confirm against the installed version.
    ax_j[0].get_shared_y_axes().join(*[ax_j[1], ax_j[3], ax_j[5], axls[0], axls[1], axls[2]])
    ax_j[0].get_shared_x_axes().join(*ax_j)
    save_visualization(pdf=True)
def calc_averag_spike_per_burst_package(burst_corr, h, lim, spikes_all):
    """Count the spikes in each burst package.

    Spike 0 always opens a package. Spike ``i + 1`` opens a new package when
    interval ``h[i]`` is above ``lim``. If ``burst_corr`` contains 'inverse', the
    logic flips: spike 0 does not open a package, and a package starts where
    ``h[i]`` is below ``lim``. The package sizes are the differences between
    consecutive package-start indices; the trailing package is not counted.

    Returns
    -------
    burst_nr : np.ndarray
        Number of spikes per (complete) package.
    test : bool
        Always False.
    """
    if 'inverse' in burst_corr:
        opens_package = [False] + list(h < lim)
    else:
        opens_package = [True] + list(h > lim)
    spike_indices = np.arange(0, len(spikes_all), 1)
    package_starts = spike_indices[np.array(opens_package)]
    return np.diff(package_starts), False
def get_float_keys(stack_here):
    """Return the float-typed keys of ``stack_here``, or its index as a fallback.

    The float keys are returned only if there are exactly as many of them as
    ``len(stack_here)``; otherwise ``stack_here.index`` is returned.
    """
    all_keys = stack_here.keys()
    key_types = np.array([type(key) for key in all_keys])
    float_keys = all_keys[np.where(key_types == float)]
    if len(stack_here) == len(float_keys):
        return float_keys
    return stack_here.index
def calc_serial(isi):
    """Compute serial correlation coefficients of an interspike-interval sequence.

    For each lag ``l`` the correlation between ``isi[:-l]`` and ``isi[l:]`` is
    computed. For sequences longer than 100 intervals the last 50 lags are
    skipped. Lags whose overlap is a single pair are always skipped: they make
    ``np.corrcoef`` return NaN. Previously that happened for every sequence of
    100 intervals or fewer, turning the mean and the sum into NaN.

    Parameters
    ----------
    isi : sequence of float
        Interspike intervals.

    Returns
    -------
    corr : float
        Mean correlation over the lags used.
    first : float
        Lag-1 correlation.
    sum_corr : float
        Sum of the correlations over the lags used.
    All three are NaN when no lag can be evaluated (fewer than 3 intervals).
    """
    isi = np.asarray(isi)
    n_isi = len(isi)
    # skip the last 50 lags of long sequences (too little overlap)
    if n_isi > 100:
        length = n_isi - 50
    else:
        length = n_isi
    # the lag n_isi - 1 leaves only one overlapping pair -> undefined correlation
    max_lag = min(length, n_isi - 1)
    corrs2 = [np.corrcoef(isi[:-lag], isi[lag:])[0][1] for lag in range(1, max_lag)]
    if not corrs2:
        return np.nan, np.nan, np.nan
    return np.mean(corrs2), corrs2[0], np.sum(corrs2)
def roc_part(titles, devs, group_mean, ranges, fig, subdevision_nr, datapoints, datapoints_way, color, c, chose_score,
             cell, DF1_desired_ROC, DF2_desired_ROC, contrast_small,
             contrast_big, contrast1, dfs, start, dev, contrast, grid2, plot_group, autodefine2='_dfchosen_',
             sorted_on='eod_loc_synch', c1=10, c2=10, cut_matrix='malefemale', autodefine='_dfchosen_closest_first_',
             chirps='', data_dir='', mean_type='MeanTrialsIndexPhaseSort', extract='', mult_type='_multsorted2_',
             indices=['_allindices_'], eodftype='_psdEOD_', titles_up=['Without female', 'With female']):
    """Load one three-fish recording and plot its ROC curves into ``grid2``.

    The DF1/DF2 combinations are chosen with ``chose_mat_max_value``. For each
    combination the trials with contrasts ``c1``/``c2`` are grouped. The
    chunk-wise ROC curves (``plt_error_bar``) and the mean ROC
    (``plt_only_roc_repetitive``) are drawn into ``ax1_3[plot_group]``, and the
    panel is annotated with contrasts, multipliers and DFs.

    The recording is opened from ``<threefish folder>/<cell>/<last path part
    of cell>.nix`` (or ``../data/<data_dir[c]><cell>/...`` when ``data_dir`` is
    given). It falls back to ``../data/cells/<cell>/<cell>.nix``.

    Returns
    -------
    frame, devname, spikes_pure
        From the last mean-ROC computation.
    spike_pures_split, delays_split
        Chunk-wise spike data / delays from the last ``plt_error_bar`` call.

    NOTE(review): the return values are only assigned when the file has
    three-fish multi-tags and at least one DF combination is processed;
    otherwise the final ``return`` raises ``UnboundLocalError``.
    """
    _, fr, pivot_chosen, max_val, max_x, max_y, mult, DF1_desired_ROC_exact, DF2_desired_ROC_exact, min_y, min_x, min_val, diff_cut = chose_mat_max_value(
        DF1_desired_ROC, DF2_desired_ROC, extract, mult_type, eodftype, indices, cell,
        contrast_small,
        contrast_big, contrast1, dfs, start, dev, contrast,
        autodefine=autodefine2,
        cut_matrix=cut_matrix,
        chose_score=chose_score, mean_type=mean_type)  # chose_score = 'auci02_012-auci_base_01'
    colors = ['orange', 'green']
    base = cell.split(os.path.sep)[-1] + ".nix"
    if data_dir == '':
        path = load_folder_name('threefish') + '/' + cell
    else:
        path = '../data/' + data_dir[c] + cell
    full_path = path + '/' + base
    try:
        file = nix.File.open(full_path, nix.FileMode.ReadOnly)
    except Exception:
        full_path = '../data/cells/' + cell + '/' + cell + ".nix"
        file = nix.File.open(full_path, nix.FileMode.ReadOnly)
        print('load extra' + full_path)
    b = file.blocks[0]
    all_mt_names, mt_names, t_names = get_all_nix_names(b, what='Three')
    if mt_names:
        nix_there = check_nix_fish(b)
        if nix_there:
            times_sort = predefine_grouping_frame(b, eodftype=eodftype)
            counter_waves = 0
            times_sort = times_sort[(times_sort['c2'] == c2) & (times_sort['c1'] == c1)]
            for gg in range(len(DF1_desired_ROC_exact)):
                ax1_3 = {}
                ###################
                # all trials in one group
                grouped = times_sort.groupby(['c1', 'c2', 'm1, m2'], as_index=False)
                grouped_mean = chose_certain_group(DF1_desired_ROC_exact[gg],
                                                   DF2_desired_ROC_exact[gg], grouped,
                                                   several=True, emb=False,
                                                   concat=True)
                ###################
                # groups sorted by repro tag
                grouped = times_sort.groupby(['c1', 'c2', 'm1, m2', 'repro_tag_id'], as_index=False)
                grouped_orig = chose_certain_group(DF1_desired_ROC_exact[gg],
                                                   DF2_desired_ROC_exact[gg],
                                                   grouped,
                                                   several=True)
                ###################
                # other group variants
                colors_groups = ['black', 'brown', 'red', 'pink', 'orange',
                                 'yellow',
                                 'lightgreen', 'green', 'darkgreen',
                                 'lightblue', 'blue', 'navy', 'purple']
                groups_variants = [[grouped_mean]]
                ax1_3[plot_group] = plt.subplot(grid2, aspect='auto')
                for g, grouped2 in enumerate(groups_variants):
                    results_diff = grouped2[0].copy()
                    cv0, spike_pures_split, delays_split = plt_error_bar(plot_group, group_mean, extract, ax1_3,
                                                                         subdevision_nr, groups_variants.copy(), b,
                                                                         chirps, mean_type, devs, counter_waves,
                                                                         results_diff, datapoints, datapoints_way,
                                                                         grouped_orig, sorted_on=sorted_on,
                                                                         color=color)
                    frame, devname, spikes_pure, group_name, auc_names_condition, auc_names_control = plt_only_roc_repetitive(
                        extract, ax1_3, fig, grouped2, g,
                        b,
                        chirps,
                        mean_type, devs,
                        counter_waves,
                        results_diff, datapoints,
                        datapoints_way,
                        grouped_orig,
                        colors_groups, ranges=ranges, sorted_on=sorted_on, lw=1.5)
                    fr_end = divergence_title_add_on(group_mean, fr[gg], autodefine)
                    df_pair = grouped_mean['DF1, DF2'].iloc[0]
                    plt.suptitle(cell + ' c1: ' + str(group_name[0]) + '% m1: ' + str(group_name[2][0])
                                 + ' DF1: ' + str(df_pair[0]) + ' c2: ' + str(group_name[1])
                                 + '% m2: ' + str(group_name[2][1]) + ' DF2: ' + str(df_pair[1])
                                 + '\n Trials nr ' + str(len(grouped_mean)) + ' sorted on ' + sorted_on + ' '
                                 + mean_type + ' cv ' + str(np.round(cv0, 2)) + ' ' + fr_end)
                    try:
                        mt_group1 = grouped2[0][1]
                    except Exception:
                        mt_group1 = grouped2[0]
                    try:
                        eodf = np.mean(mt_group1.eodf)
                    except Exception:
                        print('eod problem4')
                        embed()
                    _, _ = find_length_of_all_trials(grouped2, group_name)
                    if g == 0:
                        if len(auc_names_control) > 0:
                            control_label = auc_names_control[0][0] + '-' + auc_names_control[0][1]
                        else:
                            # No control pair available: use the default label and do not
                            # index the (empty) auc_names_control list.
                            control_label = 'base' + '-' + '01'
                        ax1_3[plot_group].text(0.5, 2, control_label, va='center', ha='center',
                                               transform=ax1_3[plot_group].transAxes)
                        ax1_3[plot_group].text(0.7, 1.5, 'm1: ' + str(group_name[2][0]) + ' /DF1: ' + str(
                            int((group_name[2][0] - 1) * eodf)) + '[Hz]', va='center', ha='center',
                                               transform=ax1_3[plot_group].transAxes,
                                               color=colors[gg])
                        ax1_3[plot_group].text(0.7, 1.7, ' m2: ' + str(group_name[2][1]) + '/ DF2: ' + str(
                            int((group_name[2][1] - 1) * eodf)) + '[Hz] ', va='center', ha='center',
                                               transform=ax1_3[plot_group].transAxes)
                    if (gg == 0) & (g == len(groups_variants) - 1):
                        ax1_3[plot_group].set_ylabel('Correct-Detection Rate: ' + titles[plot_group][1])
                        ax1_3[plot_group].set_xlabel('False-Positive Rate: ' + titles[plot_group][0])
                        ax1_3[plot_group].set_title(titles_up[plot_group])
    return frame, devname, spikes_pure, spike_pures_split, delays_split
def plt_error_bar(plot_group, group_mean, extract, ax1_3, subdevision_nr, groups_variants, b, chirps, mean_type, devs,
                  counter_waves, results_diff, datapoints, datapoints_way, grouped_orig, sorted_on='eod_loc_synch',
                  color=None):
    """Plot one ROC curve per chunk of the trials in ``group_mean``.

    With ``'_AllTrialsIndex'`` in ``mean_type`` the work is delegated to
    ``plt_error_bar_trials_roc``. Otherwise the trials ``group_mean[1]`` are
    split into chunks of ``len(group_mean[1]) // subdevision_nr`` trials (an
    incomplete trailing chunk is dropped). For every chunk the spikes are cut,
    the baseline ISI CV is computed, the trials are sorted into the four
    conditions, and a ROC curve is drawn via ``plt_error_bar0``.

    Parameters
    ----------
    color : sequence of colours, optional
        One colour per chunk; defaults to twelve times ``'grey'``.

    Returns
    -------
    cv0 : float
        Baseline ISI CV of the last processed chunk, NaN if no chunk was
        processed (previously an ``UnboundLocalError``).
    spike_pures : list
        Cut spike data of every chunk.
    delays_split : list
        Trial delays of every chunk.
    """
    if color is None:
        color = ['grey'] * 12
    spike_pures = []
    delays_split = []
    cv0 = np.nan
    if '_AllTrialsIndex' in mean_type:
        plt_error_bar_trials_roc(counter_waves, results_diff, mean_type, extract, chirps, ax1_3, plot_group, group_mean,
                                 datapoints_way, b, datapoints, devs, groups_variants, grouped_orig, test=False)
    else:
        range_nr = int(len(group_mean[1]) / subdevision_nr)
        groups_variants = find_group_variants(group_mean[1], [], start=1, steps=1, ranges=[range_nr])
        for g, grouped in enumerate(groups_variants):
            print('group_variants' + str(g))
            group_name = grouped_orig[0][0]
            if type(list(grouped)[0]) != str:
                grouped = list(grouped)
            tp = {}
            fp = {}
            for ggg in range(len(grouped)):
                try:
                    grouped2 = [group_name, grouped[ggg]]
                    spikes_pure, fish_number_base, chirp, fish_cuts, time_array, fish_number, smoothened2, smoothed05, eod_mt, eod_interp, effective_duration, cut, devname, frame = cut_spikes_and_eod_three(
                        grouped2, b, extract, chirps=chirps,
                        emb=False, mean_type=mean_type, sorted_on=sorted_on)
                    _, _, _, _ = get_mt_features3(b, grouped2)
                except Exception:
                    grouped2 = grouped[ggg]
                    spikes_pure, fish_number_base, chirp, fish_cuts, time_array, fish_number, smoothened2, smoothed05, eod_mt, eod_interp, effective_duration, cut, devname, frame = cut_spikes_and_eod_three(
                        grouped2, b, extract, chirps=chirps,
                        emb=False, mean_type=mean_type)
                    _, _, _, _ = get_mt_features3(b, grouped2)
                    print('grouped2 problem')
                    embed()
                spike_pures.append(spikes_pure)
                mean_isi, std_isi, fr, isi, cv0, ser0, ser_first, sum_corr = calc_baseline_char(
                    np.array(spikes_pure.base_0), np.abs(fish_cuts[0]), len(spikes_pure.base_0))
                dev_nrs = find_right_dev(devname, devs)
                t = dev_nrs[0]
                frame_dev = frame[frame['dev'] == devname[t]]
                delays_length = define_delays_trials(mean_type, frame, sorted_on=sorted_on)
                delays_split.append(delays_length)
                t1 = time.time()
                array0, array01, array02, array012, mean_nrs, array012_all, array01_all, array02_all, array0_all = assign_trials(
                    frame,
                    devname[t],
                    delays_length,
                    mean_type)
                t2 = time.time() - t1
                print('array' + str(t2))
                plt_error_bar0(array0, array01, array012, array02, ax1_3, color, counter_waves, datapoints,
                               datapoints_way, frame_dev, g, ggg, group_mean,
                               plot_group, results_diff, tp=tp, fp=fp)
    return cv0, spike_pures, delays_split
def plt_error_bar0(array0, array01, array012, array02, ax1_3, color, counter_waves, datapoints, datapoints_way,
                   frame_dev, g, ggg, group_mean, plot_group, results_diff, tp=None, fp=None):  # mt, name_here, mt
    """Compute and draw the ROC curve of chunk ``ggg`` of group variant ``g``.

    The curve pair (control FP curve, condition TP curve) selected by
    ``plot_group`` is stored in ``fp[g]`` / ``tp[g]``. A new list is started
    when ``ggg == 0``; otherwise the curve is appended. It is then plotted
    into ``ax1_3[plot_group]`` with ``color[ggg]``.

    ``tp``/``fp`` are updated in place. When omitted, fresh dicts are used per
    call; the former ``{}`` defaults were shared between calls.
    """
    if tp is None:
        tp = {}
    if fp is None:
        fp = {}
    threshhold, roc_0, roc_02, roc_012, tp_012_all, tp_01_all, fp_all, tp_02_all, roc_01, results_diff, counter_savename, counter_waves = calc_auci_values(
        array0, array01, array02, array012, datapoints_way[0], datapoints[0], results_diff, counter_waves=counter_waves,
        id_group=group_mean)
    arrow, second, counter1, counter2, auc_names_condition, roc_array_eod_control, auc_tp_condition, auc_names_control, auc_fp_control, roc_array_control, roc_array1_eod_condition, roc_array_condition = define_arrays_for_roc_plotting(
        [],
        [], roc_01, roc_012,
        tp_01_all, tp_012_all, frame_dev, 0,
        fp_all, tp_02_all, roc_0, roc_02,
        [], [])
    if ggg == 0:
        tp[g] = [auc_tp_condition[0][plot_group]]
        fp[g] = [auc_fp_control[0][plot_group]]
    else:
        tp[g].append(auc_tp_condition[0][plot_group])
        fp[g].append(auc_fp_control[0][plot_group])
    try:
        # Drawn once (the identical line used to be plotted twice).
        ax1_3[plot_group].plot(np.transpose(fp[g][ggg]), np.transpose(tp[g][ggg]),
                               color=color[ggg])
        print(color[ggg])
    except Exception:
        print('ggg problem')
        embed()
    test = False
    if test:
        some_roc_test(fp, tp)
def some_roc_test(fp, tp):
    """Debug plot: ROC curves per group plus a 5th/95th-percentile envelope.

    ``fp`` and ``tp`` map a group index ``0 .. len(tp) - 1`` to a list of
    curves (as filled by ``plt_error_bar0``). At most nine groups fit the
    3x3 grid.
    """
    fig, ax = plt.subplots(3, 3, sharex=True, sharey=True)
    ax = np.concatenate(ax)
    for g in range(len(tp)):
        ax[g].plot(np.transpose(fp[g]), np.transpose(tp[g]))
        # Pessimistic bound: high false-positive rate, low hit rate.
        ax[g].plot(np.percentile(fp[g], 95, axis=0), np.percentile(tp[g], 5, axis=0),
                   color='grey', alpha=0.5)
        # Optimistic bound: low false-positive rate, high hit rate
        # (this line previously reused the 5th TP percentile).
        ax[g].plot(np.percentile(fp[g], 5, axis=0), np.percentile(tp[g], 95, axis=0),
                   color='grey', alpha=0.5)
def define_arrays_for_roc_plotting(roc_01_eod, roc_012_eod, roc_01, roc_012, tp_01_all, tp_012_all, frame_dev, d,
                                   fp_all, tp_02_all, roc_0, roc_02, roc_02_eod, base_here_eod):
    """Arrange ROC inputs into (control, condition) pairs for plotting.

    ``second`` becomes ``'first_sw'`` when ``frame_dev['control_02'].iloc[d]``
    compares unequal to ``[]``. This is presumably "a 02 control exists";
    TODO confirm, because if that cell holds a numpy array the ``!=`` is
    elementwise. In that case the control pair is (base, 02), the condition
    pair is (01, 012) and ``arrow`` is False. Otherwise the final ``else``
    branch pairs 01 and 02 each with base, and 012 with 02 and with 01.

    NOTE(review): ``second`` is only ever assigned ``[]`` or ``'first_sw'``
    here, so the ``'first'`` and ``'second'`` branches are unreachable from
    this function. ``counter1``/``counter2`` stay ``[]`` when the 02-control
    check is False.

    Returns
    -------
    arrow, second, counter1, counter2, auc_names_condition, roc_array_eod_control,
    auc_array_condition, auc_names_control, auc_array_control, roc_array_control,
    roc_array1_eod_condition, roc_array_condition
        The ``*_names_*`` / ``*_array_*`` values are lists of pairs: ``[0]``
        selects the first pair, ``[0][k]`` one member of it. Callers in this
        file unpack ``auc_array_condition`` as ``auc_tp_condition`` and
        ``auc_array_control`` as ``auc_fp_control``. ``roc_array_condition`` and
        ``roc_array1_eod_condition`` stay ``[]`` in the ``'first_sw'`` case when
        ``roc_01_eod`` is empty.
    """
    roc_array1_eod_condition = []
    roc_array_condition = []
    second = []
    counter1 = []
    counter2 = []
    auc_names_condition = []
    roc_array_eod_control = []
    auc_array_condition = []
    # Choose the pairing mode depending on whether the 02-control entry is non-empty.
    if frame_dev['control_02'].iloc[d] != []:
        second = 'first_sw'
        counter1 = 0
        counter2 = 2
    arrow = True
    if second == 'first_sw':
        ##################################
        # DON'T GET CONFUSED; READ TOP TO BOTTOM; here BASE and 01 are the first panel!
        auc_names_control = [['base', '02', ]]
        auc_array_control = [[fp_all, tp_02_all, ]]
        roc_array_control = [[roc_0, roc_02, ]]
        roc_array_eod_control = [[base_here_eod, roc_02_eod, ]]
        arrow = False
        auc_names_condition = [['01', '012']]
        auc_array_condition = [[tp_01_all, tp_012_all]]
        if len(roc_01_eod) > 0:
            roc_array_condition = [[roc_01, roc_012]]
            roc_array1_eod_condition = [[roc_01_eod, roc_012_eod]]
        counter1 = 1
        counter2 = 3
    elif second == 'first':
        auc_names_control = [['02', 'base']]
        auc_array_control = [[tp_02_all, fp_all]]
        roc_array_control = [[roc_02, roc_0]]
        roc_array_eod_control = [
            [roc_02_eod, base_here_eod]]
        auc_names_condition = [['012', '01']]
        auc_array_condition = [[tp_012_all, tp_01_all]]
        roc_array_condition = [[roc_012, roc_01]]
        if len(roc_01_eod) > 0:
            roc_array1_eod_condition = [
                [roc_012_eod, roc_01_eod]]
    elif second == 'second':
        ##################################
        auc_names_control = [['01', 'base']]
        auc_array_control = [[tp_01_all, fp_all]]
        roc_array_control = [[roc_01, roc_0]]
        if len(roc_01_eod) > 0:
            roc_array_eod_control = [
                [roc_01_eod, base_here_eod]]
        auc_names_condition = [['012', '02']]
        auc_array_condition = [[tp_012_all, tp_02_all]]
        roc_array_condition = [[roc_012, roc_02]]
        roc_array1_eod_condition = [[roc_012_eod, roc_02_eod]]
    else:
        ##################################
        # This plots only the two combinations: once contrast 01 vs. base and once contrast 02 vs. base.
        auc_names_control = [['01', 'base'], ['02', 'base']]
        auc_array_control = [[tp_01_all, fp_all],
                             [tp_02_all, fp_all]]
        roc_array_control = [[roc_01, roc_0],
                             [roc_02, roc_0]]
        if len(roc_01_eod) > 0:
            roc_array_eod_control = [
                [roc_01_eod, base_here_eod],
                [roc_02_eod, base_here_eod]]
        auc_names_condition = [['012', '02'], ['012', '01']]
        auc_array_condition = [[tp_012_all, tp_02_all],
                               [tp_012_all, tp_01_all]]
        if len(roc_012) > 0:
            roc_array_condition = [[roc_012, roc_02],
                                   [roc_012, roc_01]]
            roc_array1_eod_condition = [[roc_012_eod, roc_02_eod],
                                        [roc_012_eod, roc_01_eod]]
        else:
            # Without 012 traces the control entries are overwritten with a base/01 pair plus the 012 entry.
            auc_names_control = [['base', '01'], '012', ]
            auc_array_control = [[fp_all, tp_01_all], tp_012_all]
            roc_array_control = [[roc_0, roc_01], roc_012]
    return arrow, second, counter1, counter2, auc_names_condition, roc_array_eod_control, auc_array_condition, auc_names_control, auc_array_control, roc_array_control, roc_array1_eod_condition, roc_array_condition
def find_length_of_all_trials(grouped, group_name):
    """Return the length of the data part of every entry in ``grouped``.

    Entries with exactly two elements are treated as ``(name, data)`` pairs
    and their second element is measured. Any other entry is itself the data
    (it would be paired with ``group_name``, which does not affect its length).

    Returns the same list object twice, as ``(sum_trials, lengths)``.
    NOTE(review): despite its name, ``sum_trials`` is the list, not a sum.
    """
    def _data_part(entry):
        # A 2-element entry is a (name, data) pair; anything else is the data itself.
        return entry[1] if len(entry) == 2 else entry

    lengths = [len(_data_part(grouped[i])) for i in range(len(grouped))]
    return lengths, lengths
def plt_error_bar_trials_roc(counter_waves, results_diff, mean_type, extract, chirps, ax1_3, plot_group, group_mean,
                             datapoints_way, b, datapoints, devs, groups_variants, grouped_orig, test=False):
    """Plot ROC curves of trial chunks (thirds of all trials) of ``group_mean``.

    All trials of ``group_mean`` are cut and sorted into the four conditions
    (base, 01, 02, 012). Each condition is split into chunks of
    ``len(array0) // 3`` trials, a ROC curve is computed per chunk, and the
    curves of every split variant are drawn in grey into
    ``ax1_3[plot_group]``.
    """
    for g, grouped in enumerate(groups_variants):
        print('group_variants' + str(g))
        if type(list(grouped)[0]) != str:
            grouped = list(grouped)
        for ggg in range(len(grouped)):
            spikes_pure, fish_number_base, chirp, fish_cuts, time_array, fish_number, smoothened2, smoothed05, eod_mt, eod_interp, effective_duration, cut, devname, frame = cut_spikes_and_eod_three(
                group_mean, b, extract, chirps=chirps,
                emb=False, mean_type=mean_type)
            # Features of the same group that was cut above; this previously referenced
            # the undefined name ``grouped_mean`` and raised NameError.
            features, mt, name_here, l = get_mt_features3(b, group_mean)
            dev_nrs = find_right_dev(devname, devs)
            t = dev_nrs[0]
            frame_dev = frame[frame['dev'] == devname[t]]
            delays_length = define_delays_trials(mean_type, frame)
            array0, array01, array02, array012, mean_nrs, array012_all, array01_all, array02_all, array0_all = assign_trials(
                frame,
                devname[t],
                delays_length,
                mean_type)
            # Chunks of one third of the trials per condition.
            range_nr = int(len(array0) / 3)
            array0_gr = find_group_variants(array0, [], start=1, steps=1, ranges=[range_nr])
            array01_gr = find_group_variants(array01, [], start=1, steps=1, ranges=[range_nr])
            array02_gr = find_group_variants(array02, [], start=1, steps=1, ranges=[range_nr])
            array012_gr = find_group_variants(array012, [], start=1, steps=1, ranges=[range_nr])
            tp = {}
            fp = {}
            # Separate loop name so the outer group index ``g`` is not shadowed.
            for g_split in range(len(array012_gr)):
                print(g_split)
                for g_in in range(len(array012_gr[g_split])):
                    threshhold, roc_0, roc_02, roc_012, tp_012_all, tp_01_all, fp_all, tp_02_all, roc_01, results_diff, counter_savename, counter_waves = calc_auci_values(
                        array0_gr[g_split][g_in], array01_gr[g_split][g_in], array02_gr[g_split][g_in],
                        array012_gr[g_split][g_in],
                        datapoints_way[0], datapoints[0], results_diff, mean_nrs, l, features, name_here, mt,
                        counter_waves=counter_waves, id_group=group_mean)
                    arrow, second, counter1, counter2, auc_names_condition, roc_array_eod_control, auc_tp_condition, auc_names_control, auc_fp_control, roc_array_control, roc_array1_eod_condition, roc_array_condition = define_arrays_for_roc_plotting(
                        [],
                        [], roc_01, roc_012,
                        tp_01_all, tp_012_all, frame_dev, 0,
                        fp_all, tp_02_all, roc_0, roc_02,
                        [], [])
                    if g_in == 0:
                        tp[g_split] = [auc_tp_condition[0][plot_group]]
                        fp[g_split] = [auc_fp_control[0][plot_group]]
                    else:
                        tp[g_split].append(auc_tp_condition[0][plot_group])
                        fp[g_split].append(auc_fp_control[0][plot_group])
                    print(g_in)
                # Drawn once (the identical call used to be repeated).
                ax1_3[plot_group].plot(np.transpose(fp[g_split]), np.transpose(tp[g_split]), color='grey', alpha=0.5)
    if test:
        from utils_test import test_groups
        test_groups()
def plt_only_roc_repetitive(extract, ax1_3, fig, grouped, g, b, chirps, mean_type,
                            devs, counter_waves, results_diff, datapoints, datapoints_way,
                            grouped_orig, colors_groups, sorted_on='eod_loc_synch', ranges=[], lw=0.4):
    """Draw the ROC curve of every entry of ``grouped`` in colour ``colors_groups[g]``.

    Entries that are not ``(name, data)`` pairs are paired with the group name
    ``grouped_orig[0][0]`` before being passed to ``plt_only_roc_plot``.

    Returns the results of the last ``plt_only_roc_plot`` call (frame, device
    names, cut spikes, condition/control AUC names) together with the group
    name.

    TODO (from the original author): this function really only works for the
    mean.
    """
    print('group_variants' + str(g))
    group_name = grouped_orig[0][0]
    if type(list(grouped)[0]) != str:
        grouped = list(grouped)
    line_color = colors_groups[g]
    for entry_idx in range(len(grouped)):
        entry = grouped[entry_idx]
        named_entry = entry if len(entry) == 2 else [group_name, entry]
        frame, devname, spikes_pure, auc_names_condition, auc_names_control = plt_only_roc_plot(
            extract, counter_waves, results_diff, datapoints, datapoints_way, ax1_3, fig, named_entry, b,
            chirps, mean_type, devs, roc_color=line_color, sorted_on=sorted_on, range_roc=ranges, lw=lw)
    return frame, devname, spikes_pure, group_name, auc_names_condition, auc_names_control
def plt_only_roc_plot(extract, counter_waves, results_diff, datapoints, datapoints_way, ax1_3, fig, group_mean, b,
                      chirps, mean_type, devs, roc_color='black', sorted_on='eod_loc_synch', range_roc=[], lw=0.7):
    """Cut the trials of ``group_mean`` and draw their ROC curve(s).

    The spikes are cut with ``cut_spikes_and_eod_three``. If any device
    names were found, the trials of the selected device are sorted into the
    four conditions and plotted via ``plt_only_roc_plot0``.

    Returns the cut frame, the device names, the cut spikes and the
    condition/control AUC names (both ``[]`` when nothing was plotted).
    """
    spikes_pure, _, _, _, _, _, _, _, _, _, _, _, devname, frame = cut_spikes_and_eod_three(
        group_mean, b, extract, chirps=chirps, emb=False, mean_type=mean_type, sorted_on=sorted_on)
    features, mt, name_here, l = get_mt_features3(b, group_mean)
    _, _, _, _, _, _ = get_fish_number(b, group_mean, mean_type)
    auc_names_condition = []
    auc_names_control = []
    if not len(devname) > 0:
        return frame, devname, spikes_pure, auc_names_condition, auc_names_control

    chosen_dev = devname[find_right_dev(devname, devs)[0]]
    frame_dev = frame[frame['dev'] == chosen_dev]
    delays_length = define_delays_trials(mean_type, frame, sorted_on=sorted_on)
    if len(delays_length) > 1 and not delays_length['012']:
        print('DEBUGG: add sorted_on=sorted_on in cut_spikes_and_eod_three!!!')
    array0, array01, array02, array012, mean_nrs, array012_all, array01_all, array02_all, array0_all = assign_trials(
        frame, chosen_dev, delays_length, mean_type)
    auc_names_condition, auc_names_control = plt_only_roc_plot0(array0, array01, array012, array02, ax1_3,
                                                                counter_waves, datapoints, datapoints_way, features,
                                                                fig, frame_dev, group_mean, l, lw, mean_nrs,
                                                                mt, name_here, range_roc, results_diff, roc_color)
    return frame, devname, spikes_pure, auc_names_condition, auc_names_control
def plt_only_roc_plot0(array0, array01, array012, array02, ax1_3, counter_waves, datapoints, datapoints_way, features,
                       fig, frame_dev, group_mean, l, lw, mean_nrs, mt, name_here, range_roc, results_diff,
                       roc_color):
    """Compute the ROC curves of the four condition arrays and draw them with ``plot_rocs``.

    ``range_roc`` selects which comparisons of the first control/condition
    pair are drawn (e.g. only base-01, or also 01-012). When it is empty, all
    of them are drawn. When ``ax1_3`` is a list or dict, comparison ``a`` goes
    into ``ax1_3[a]``; otherwise every comparison goes into ``ax1_3`` itself.

    Returns the condition and control AUC names.
    """
    threshhold, roc_0, roc_02, roc_012, tp_012_all, tp_01_all, fp_all, tp_02_all, roc_01, results_diff, counter_savename, counter_waves = calc_auci_values(
        array0, array01, array02, array012, datapoints_way[0], datapoints[0], results_diff, mean_nrs, l, features,
        name_here, mt, counter_waves=counter_waves, id_group=group_mean)
    arrow, second, counter1, counter2, auc_names_condition, roc_array_eod_control, auc_tp_condition, auc_names_control, auc_fp_control, roc_array_control, roc_array1_eod_condition, roc_array_condition = define_arrays_for_roc_plotting(
        [], [], roc_01, roc_012, tp_01_all, tp_012_all, frame_dev, 0,
        fp_all, tp_02_all, roc_0, roc_02, [], [])
    pair_idx = 0
    roc_counter = 0
    if len(range_roc) < 1:
        range_roc = range(len(auc_fp_control[0]))
    for panel in range_roc:
        target_ax = ax1_3[panel] if type(ax1_3) in (list, dict) else ax1_3
        control_names = auc_names_control[pair_idx]
        condition_names = auc_names_condition[pair_idx]
        try:
            plot_rocs(fig, target_ax, roc_counter, control_names[panel],
                      panel,
                      condition_names,
                      auc_fp_control[pair_idx],
                      auc_tp_condition[pair_idx], results_diff,
                      control_names[panel],
                      condition_names[panel],
                      pos=[0, -0.35], legend=False, arrow=arrow, add=0.2, alpha=1,
                      counter1=counter1, counter2=counter2, roc_color=roc_color, emb=False, second_roc=False, lw=lw)
        except:
            print('ax something')
            embed()
    return auc_names_condition, auc_names_control
def traces_new(array012, position_diff, array01, way, array02, array0, results_diff=None):
    """Debug figure: ROC traces for several window sizes and recording lengths.

    For every window size in ``datapoints_all`` and every truncation length in
    ``restricts``, the first trial of each condition is cut to its first
    ``restrict`` samples and passed through ``calc_auci_pd``. The result is
    drawn into its own cell of a grid: per-window traces of base/01 and 02/012
    on the left, the base-01 and 02-012 ROC curves on the right.

    Parameters
    ----------
    results_diff : pandas.DataFrame, optional
        Frame ``calc_auci_pd`` writes its scores into. A new empty DataFrame is
        used when omitted. Previously the function read an unassigned local
        ``results_diff``, so every computation failed silently and the figure
        stayed empty.
    """
    if results_diff is None:
        results_diff = pd.DataFrame()
    datapoints_all = [250, 500, 750, 1000, 1500]
    restricts = np.arange(1000, len(array012[0]), 4000)
    grid0 = gridspec.GridSpec(len(datapoints_all), len(restricts),
                              bottom=0.07, top=0.93, wspace=0.24,
                              left=0.06, right=0.92)
    for d, datapoints in enumerate(datapoints_all):
        for r, restrict in enumerate(restricts):
            print(len(array012[0][0:restrict]))
            # Sub-grid for this (window size, length) cell, created before plotting into it.
            grid1 = gridspec.GridSpecFromSubplotSpec(2, 2, hspace=0.4, wspace=0.2, subplot_spec=grid0[d, r])
            try:
                trials, results_diff, tp_012_all, tp_01_all, tp_02_all, fp_all, roc_01, roc_0, roc_02, roc_012, threshhold = calc_auci_pd(
                    results_diff, position_diff, [array012[0][0:restrict]], [array01[0][0:restrict]],
                    [array02[0][0:restrict]], [array0[0][0:restrict]], t_off=10, way=way, emb=False,
                    datapoints=datapoints)
            except Exception as err:
                # Best effort: e.g. windows longer than the truncated trace; report and skip this cell.
                print('traces_new: ROC failed for ' + str(restrict) + ' dp / ' + str(datapoints)
                      + ' dp window: ' + repr(err))
                continue
            # The axes are not shared across cells (the former sharing branch could never run).
            ax = plt.subplot(grid1[0, 0])
            ax.set_title(str(restrict) + 'dp at all, ' + str(datapoints) + 'dp window', fontsize=7)
            plt.plot(np.transpose(roc_0), color='orange')
            plt.plot(np.transpose(roc_01), color='green')
            plt.subplot(grid1[1, 0], sharey=ax, sharex=ax)
            plt.plot(np.transpose(roc_02), color='orange')
            plt.plot(np.transpose(roc_012), color='blue')
            ax = plt.subplot(grid1[:, 1])
            ax.plot(fp_all, tp_01_all, label='base-01', color='green')
            ax.plot(tp_02_all, tp_012_all, label='02-012', color='blue')
    plt.legend()
    save_visualization()
    plt.show()
def calc_auci_values(array0, array01, array02, array012, way, datapoints, results_diff, mean_nrs='', l=[], features=[],
                     name_here=[], mt=[], counter_waves=[], t_off=10, sampling=40000, position_diff=[],
                     time_sacrifice=0, id_group=[]):
    """Run the ROC/AUC analysis for one group and store its metadata in ``results_diff``.

    Scores are written by ``calc_auci_pd`` into row ``position_diff`` of
    ``results_diff``, or into the new row ``len(results_diff)`` when
    ``position_diff`` is left at ``[]``. Multi-tag features are added when
    ``features`` is non-empty. The window settings and trial count are then
    stored, together with ``id_group[1]['mt']`` (as ``'mt'``) and
    ``id_group[0]`` (as ``'g_idx'``).

    Returns
    -------
    tuple
        ``(threshhold, roc_0, roc_02, roc_012, tp_012_all, tp_01_all, fp_all,
        tp_02_all, roc_01, results_diff, [], counter_waves)``. The ``roc_*``
        entries are the windowed traces, ``tp_*``/``fp_all`` the ROC curves,
        ``[]`` is a placeholder and ``counter_waves`` is passed through.
    """
    row = len(results_diff) if position_diff == [] else position_diff
    # TODO (original): also handle multiples here.
    results_diff.loc[row, 'time_sacrifice'] = time_sacrifice
    (trials, results_diff, tp_012_all, tp_01_all, tp_02_all, fp_all, roc_01, roc_0, roc_02, roc_012,
     threshhold) = calc_auci_pd(results_diff, row, array012, array01, array02, array0, t_off=t_off, way=way,
                                datapoints=datapoints)
    if len(features) > 0:
        results_diff = feature_extract_cut(mt, l, name_here, results_diff, row, features)
    settings = (
        ('datapoints', datapoints),
        ('datapoints_time', np.round(datapoints / sampling, 3)),
        ('datapoints_way', way),
        ('trial_nrs', len(roc_01)),
        ('mean_nrs', mean_nrs),
        ('t_off', t_off),
    )
    for column, value in settings:
        results_diff.loc[row, column] = value
    results_diff = save_structure_to_frame(row, results_diff, np.array(id_group[1]['mt']), name='mt')
    results_diff = save_structure_to_frame(row, results_diff, id_group[0], name='g_idx')
    return threshhold, roc_0, roc_02, roc_012, tp_012_all, tp_01_all, fp_all, tp_02_all, roc_01, results_diff, [], counter_waves
def concat_rocs(control_02, base_orig, control_01, array_012, datapoints, t_off):
    """Cut every trial of every condition into equal analysis windows.

    Each trial is split every ``datapoints + t_off`` samples. The trailing
    incomplete piece is discarded and the last ``t_off`` samples of each
    window are dropped, leaving windows of ``datapoints`` samples. The
    windows of all trials of one condition are pooled.

    Returns
    -------
    trials : numpy.ndarray
        Split positions of the last trial processed. Conditions are processed
        in the order 012, 02, 01, base.
    roc2_there : bool
        Whether the 02 condition produced any windows.
    roc_02_con, roc_012_con, roc_01_con, base_con
        Pooled windows per condition: a 2-D array for a single trial, a list
        of 1-D windows when several trials were pooled, ``[]`` if none.
    """
    window_len = datapoints + t_off
    conditions = (('012', array_012), ('02', control_02), ('01', control_01), ('base', base_orig))
    pooled = {}
    for label, condition_trials in conditions:
        is_first = True
        for trial_idx in range(len(condition_trials)):
            trace = condition_trials[trial_idx]
            trials = np.arange(window_len, len(trace), window_len)
            if not len(trace) > 0:
                continue
            windows = np.split(trace, trials)[0:-1]
            windows = np.array(windows)[:, 0:-t_off]
            try:
                if len(windows) != 1 and len(windows[-1]) != len(windows[-2]):
                    windows = windows[0:-1]
            except:
                print('utils func roc to short')
                embed()
            if is_first:
                pooled[label] = windows
                is_first = False
            else:
                try:
                    merged = list(pooled[label])
                    merged.extend(windows)
                    pooled[label] = merged
                except:
                    print('array append problem')
                    embed()
    roc2_there = '02' in pooled
    return (trials, roc2_there, pooled.get('02', []), pooled.get('012', []),
            pooled.get('01', []), pooled.get('base', []))
def calc_auci_pd(results_diff, position_diff, array_012, control_01, control_02, base_orig, add='', t_off=5, way='',
                 emb=[], printing=False, datapoints=[], threshhold_step=50, f0='EODf', sampling=40000):
    """Compute ROC curves and AUC scores for the four stimulus conditions.

    The traces are cut into windows (``concat_rocs``) and thresholded
    (``threshold_roc``). The fraction of windows above each threshold gives the
    ROC coordinates ``fp_base_all`` (base) and ``tp_01_all``/``tp_02_all``/
    ``tp_012_all``. The pairwise AUCs (``calc_auc_diff``) and interpolated
    curves are stored in row ``position_diff`` of ``results_diff``, with the
    column suffix ``add``. Finally, differences of selected AUC pairs are
    stored.

    When ``way`` contains ``'mult'``, ``datapoints`` is replaced by the value
    returned by ``find_env``; the frequency column names ``f0``, ``'f0'``
    and ``'EODf'`` are tried in turn.

    Returns
    -------
    tuple
        ``(trials, results_diff, tp_012_all, tp_01_all, tp_02_all,
        fp_base_all, roc_01_con, base_con, roc_02_con, roc_012_con,
        threshhold)``.
    """
    # Avoid converting to pandas too often, especially with many columns: that can be very slow.
    if 'mult' in way:  # e.g. 'mult_minimum', 'mult_env', 'mult_f1', 'mult_f2'
        try:
            datapoints = find_env(way, results_diff, position_diff, sampling, f0=f0)
        except Exception:
            try:
                f0 = 'f0'
                datapoints = find_env(way, results_diff, position_diff, sampling, f0=f0)
            except Exception:
                f0 = 'EODf'
                datapoints = find_env(way, results_diff, position_diff, sampling, f0=f0)
    t1 = time.time()
    trials, roc2_there, roc_02_con, roc_012_con, roc_01_con, base_con = concat_rocs(control_02, base_orig, control_01,
                                                                                    array_012, datapoints, t_off)
    if printing:
        print('ROC0' + str(time.time() - t1))
    t1 = time.time()
    tp_02, tp_01, tp_012, fp_base, threshhold = threshold_roc(threshhold_step, roc2_there, base_con, roc_01_con,
                                                              roc_02_con,
                                                              roc_012_con)
    if printing:
        print('ROC1' + str(time.time() - t1))
    t1 = time.time()
    # Fraction of windows exceeding each threshold -> one ROC coordinate per threshold.
    tp_012_all = np.mean(tp_012, axis=0)
    tp_01_all = np.mean(tp_01, axis=0)
    fp_base_all = np.mean(fp_base, axis=0)
    if roc2_there == True:
        tp_02_all = np.mean(tp_02, axis=0)
    else:
        tp_02_all = []
    results_diff, names_present, names_present_real = calc_auc_diff(tp_02_all, add, results_diff, position_diff,
                                                                    tp_012_all, fp_base_all, tp_01_all, roc2_there)
    if printing:
        print('ROC2' + str(time.time() - t1))
    t1 = time.time()
    if roc2_there == True:
        results_diff.loc[position_diff, 'auc_' + '02' + '_' + '01' + add] = metrics.auc(tp_02_all, tp_01_all)
        results_diff.loc[position_diff, 'auci_' + '02' + '_' + '01' + add] = np.abs(
            np.asarray(results_diff.loc[position_diff, 'auc_' + '02' + '_' + '01' + add]) - 0.5)
        names_present_real.append('02' + '_' + '01')
        try:
            _, interp = interp_arrays(fp_base_all, tp_02_all, step=0.05)
        except Exception:
            print('Interp line 6662')
            embed()
        results_diff = save_structure_to_frame(position_diff, results_diff, interp, name='base_02' + add, double=False)
        _, interp = interp_arrays(tp_02_all, tp_012_all, step=0.05)
        results_diff = save_structure_to_frame(position_diff, results_diff, interp, name='02_012' + add, double=False)
    try:
        _, interp = interp_arrays(fp_base_all, tp_01_all, step=0.05)
    except Exception:
        print('interp fp_base_all in utils_func')
        embed()
    results_diff = save_structure_to_frame(position_diff, results_diff, interp, name='base_01' + add, double=False)
    time_array, interp = interp_arrays(tp_01_all, tp_012_all, step=0.05)
    results_diff = save_structure_to_frame(position_diff, results_diff, interp, name='01_012' + add, double=False)
    # Store the first return value of interp_arrays under 'time_array' (this entry
    # previously received the 01_012 curve by mistake).
    results_diff = save_structure_to_frame(position_diff, results_diff, time_array, name='time_array' + add,
                                           double=False)
    if printing:
        print('ROC3' + str(time.time() - t1))
    t1 = time.time()
    diff_tuples = [
        ['base_012', 'base_02'],
        ['base_012', 'base_01'],
        ['02_012', 'base_02'],
        ['01_012', 'base_01'],
        ['02_012', 'base_01'],
        ['01_012', 'base_02'],
        ['01_02', 'base_01'],
        ['02_01', 'base_02']
    ]
    for first_name, second_name in diff_tuples:
        if ('auc_' + first_name + add in results_diff.keys()) and (
                'auc_' + second_name + add in results_diff.keys()):
            results_diff.loc[position_diff, 'auc_' + first_name + '-' + 'auc_' + second_name + add] = \
                results_diff.loc[position_diff, 'auc_' + first_name + add] - results_diff.loc[
                    position_diff, 'auc_' + second_name + add]
            # NOTE: 'auci' without underscore is intentional; scores are read back under
            # this name (e.g. 'auci02_012-auci_base_01').
            results_diff.loc[position_diff, 'auci' + first_name + '-' + 'auci_' + second_name + add] = \
                results_diff.loc[position_diff, 'auci_' + first_name + add] - results_diff.loc[
                    position_diff, 'auci_' + second_name + add]
    if printing:
        print('ROC4' + str(time.time() - t1))
    if emb:
        embed()
    return trials, results_diff, tp_012_all, tp_01_all, tp_02_all, fp_base_all, roc_01_con, base_con, roc_02_con, roc_012_con, threshhold
def threshold_roc(threshhold_step, roc2_there, base_con, roc_01_con, roc_02_con, roc_012_con):
    """Turn per-window traces into detection decisions over a range of thresholds.

    The maximum of every window (row) is compared with ``threshhold_step``
    thresholds. They are evenly spaced from 0.97 x the smallest to 1.02 x the
    largest window maximum over all conditions (NaNs ignored for that range).

    Returns
    -------
    tp_02, tp_01, tp_012, fp_base : numpy.ndarray of bool
        Shape ``(n_windows, threshhold_step)``. ``[i, k]`` is True when the
        maximum of window ``i`` exceeds threshold ``k``. ``tp_02`` is ``[]``
        when ``roc2_there`` is False.
    threshhold : numpy.ndarray
        The thresholds used.

    Invalid inputs now raise their original error. Previously a bare
    ``except`` only printed a message and the function then failed with
    ``UnboundLocalError``.
    """
    window_sets = [base_con, roc_012_con, roc_01_con]
    if roc2_there:
        window_sets.append(roc_02_con)
    max_arrays = np.concatenate([np.nanmax(windows, axis=1) for windows in window_sets])
    higher_max = np.nanmax(max_arrays)
    lower_max = np.nanmin(max_arrays)
    threshhold = np.linspace(0.97 * lower_max,
                             1.02 * higher_max,
                             threshhold_step)

    def _exceeds(windows):
        # (n_windows, 1) > (1, n_thresholds) broadcasts to (n_windows, n_thresholds).
        return np.max(windows, axis=1)[:, np.newaxis] > threshhold[np.newaxis, :]

    tp_012 = _exceeds(roc_012_con)
    tp_01 = _exceeds(roc_01_con)
    fp_base = _exceeds(base_con)
    tp_02 = _exceeds(roc_02_con) if roc2_there else []
    return tp_02, tp_01, tp_012, fp_base, threshhold
def calc_auc_diff(tp_02_all, add, results_diff, position_diff, tp_012_all, fp_base_all, tp_01_all, roc2_there):
    """Store the AUC of every pair of ROC curves in a copy of ``results_diff``.

    The curves are base, 01, (02 if ``roc2_there``) and 012. Each unordered
    pair is handled once, in list order. For the pair ``X_Y``, the column
    ``'auc_X_Y' + add`` receives ``metrics.auc(X, Y)`` and ``'auci_X_Y' + add``
    receives its distance from 0.5.

    Returns
    -------
    results_diff : pandas.DataFrame
        A copy of the input with the new columns.
    names_present : list of str
        Both orderings ``X_Y`` and ``Y_X`` of every handled pair.
    names_present_real : list of str
        The stored ordering ``X_Y`` of every pair.
    """
    if roc2_there == True:
        curves = [('base', fp_base_all), ('01', tp_01_all), ('02', tp_02_all), ('012', tp_012_all)]
    else:
        curves = [('base', fp_base_all), ('01', tp_01_all), ('012', tp_012_all)]
    names_present = []
    names_present_real = []
    for first_idx in range(len(curves)):
        name_a, curve_a = curves[first_idx]
        for name_b, curve_b in curves[first_idx + 1:]:
            pair = name_a + '_' + name_b
            names_present.extend([pair, name_b + '_' + name_a])
            names_present_real.append(pair)
            results_diff = results_diff.copy()
            auc_column = 'auc_' + pair + add
            results_diff.loc[position_diff, auc_column] = metrics.auc(curve_a, curve_b)
            results_diff.loc[position_diff, 'auci_' + pair + add] = np.abs(
                np.asarray(results_diff.loc[position_diff, auc_column]) - 0.5)
    return results_diff, names_present, names_present_real
def plot_roc_in_function(tp_02_all, array_all, t, eod_fe, e, eod_fr, eod_fj, j, fpr, tpr, tp_012_all, fp_base_all,
                         tp_01_all):
    """Debug plots for trial ``t``, followed by a second figure of summary ROC curves.

    The first figure shows the 012 and control-01 traces, the ROC ``(fpr, tpr)``,
    their histograms and their sorted values. ``array_all`` is indexed with the
    keys ``'012'`` and ``'control_01'``; ``eod_fe[e]``, ``eod_fj[j]`` and
    ``eod_fr`` only label the figure. The second figure plots 012, 01 and 02
    against the base false-positive curve.
    """
    # Figure-level title. Calling plt.title() before any subplot exists would attach the
    # label to an implicitly created full-figure axes that the subplots then overlap.
    plt.suptitle('fe' + str(eod_fe[e]) + 'Hz fj' + str(
        eod_fj[j]) + 'Hz fr ' + str(eod_fr) + 'Hz')
    plt.subplot(2, 3, 1)
    plt.title('012')
    plt.plot(array_all['012'][t], color='red')
    plt.subplot(2, 3, 2)
    plt.title('01')
    plt.plot(array_all['control_01'][t], color='blue')
    plt.subplot(2, 3, 3)
    plt.plot(array_all['012'][t], color='red')
    plt.plot(array_all['control_01'][t], color='blue')
    plt.subplot(2, 3, 4)
    plt.plot(fpr, tpr)
    plt.subplot(2, 3, 5)
    plt.hist(array_all['012'][t], bins=100, color='red')
    plt.hist(array_all['control_01'][t], bins=100, color='blue')
    plt.subplot(2, 3, 6)
    plt.plot(np.sort(array_all['012'][t]))
    plt.plot(np.sort(array_all['control_01'][t]))
    plt.show()
    plt.subplot(3, 1, 1)
    plt.plot(fp_base_all, tp_012_all)
    plt.subplot(3, 1, 2)
    plt.plot(fp_base_all, tp_01_all)
    plt.subplot(3, 1, 3)
    plt.plot(fp_base_all, tp_02_all)
    plt.show()
def calc_baseline_char(spike_adapted, stimulus_length, trials_nr_base, data_restrict=[], emb=False):
    """Compute baseline firing statistics from a set of spike trains.

    Args:
        spike_adapted: sequence of per-trial spike-time arrays.
        stimulus_length: duration of one trial (same time unit as the spikes).
        trials_nr_base: number of trials used to normalize the firing rate.
        data_restrict: optional values, one per trial; if given, the ISIs are taken
            from the trial with the largest value, otherwise from the first trial.
        emb: drop into an IPython shell at the start (debugging aid).

    Returns:
        (mean_isi, std_isi, fr, isi, cv0, ser0, ser_first, sum_corr). Every entry
        except ``fr`` and ``isi`` is NaN when fewer than two intervals are available.
    """
    if emb:
        embed()
    fr = len(np.concatenate(spike_adapted)) / (stimulus_length * trials_nr_base)
    if len(data_restrict) > 0:
        isi = np.diff(spike_adapted[np.argmax(data_restrict)])
    else:
        isi = np.diff(spike_adapted[0])
    if len(isi) < 3:
        # Too few intervals in the chosen trial: fall back to the last trial
        # that contains more than two spikes.
        for trial in spike_adapted:
            if len(trial) > 2:
                isi = np.diff(trial)
    if len(isi) > 1:
        std_isi = np.std(isi)
        mean_isi = np.mean(isi)
        cv0 = std_isi / mean_isi
        try:
            ser0, ser_first, sum_corr = calc_serial(isi)
        except Exception:
            print('ser problem')
            embed()
            # the serial-correlation values are undefined here; surface the real error
            raise
    else:
        # np.float was removed in NumPy 1.24; use the builtin float instead
        nan = float('nan')
        cv0 = nan
        ser0 = nan
        std_isi = nan
        mean_isi = nan
        ser_first = nan
        sum_corr = nan
    return mean_isi, std_isi, fr, isi, cv0, ser0, ser_first, sum_corr
def find_group_variants(grouped_mean, groups_variants, start=15, steps=10, ranges=[]):
    """Split ``grouped_mean`` into equally sized chunks for several chunk sizes.

    For every chunk size in ``ranges`` (default: ``np.arange(start, len(grouped_mean), steps)``)
    the array is split along its first axis into consecutive chunks of that size.
    An incomplete trailing chunk is dropped, and the resulting list of chunks is
    appended to ``groups_variants``. For example, with 50 entries and sizes
    15/25/35 this yields 3, 2 and 1 chunks.

    Returns:
        ``groups_variants`` (the same list object, extended in place).
    """
    if not len(ranges):
        ranges = np.arange(start, len(grouped_mean), steps)
    n_entries = len(grouped_mean)
    for chunk_size in ranges:
        boundaries = np.arange(chunk_size, n_entries, chunk_size)
        chunks = np.split(grouped_mean, boundaries)
        # keep only complete chunks
        if len(chunks[-1]) != chunk_size:
            chunks = chunks[0:-1]
        groups_variants.append(chunks)
    return groups_variants
def plot_second_roc(ax1, fig, array1, array2_0, array2_1, results_diff, names1_0, names1_1, names2_0,
                    names2_1, add_name='', arrow=True, arrow2=True, pos=[1, -0.45], add=0.1, linewidth=None):
    """Overlay a second ROC curve on ``ax1`` and shade the area between both curves.

    ``array2_0``/``array2_1`` (reversed, then interpolated via ``interp_arrays``) give
    the second curve; ``array1`` holds the first curve's values on the same grid.
    The last row of ``results_diff`` provides the 'auc_'/'auci_' columns named
    ``'<prefix><name_0>_<name_1><add_name>'`` that are printed as text at ``pos``.
    Red shading/text marks a positive AUC difference, blue a non-positive one.

    Args:
        arrow: draw a horizontal arrow next to the axes (moved to the figure).
        arrow2: draw a small arrow between the two curves when they are far apart.
        linewidth: line width of the black second curve (matplotlib default if None).
            Accepted because ``plot_rocs`` passes it.
    """
    if arrow:
        ax1.annotate('', ha='center',
                     xy=(1, 0.5),
                     xytext=(1.4, 0.5),
                     arrowprops={"arrowstyle": "->",
                                 "linestyle": "-",
                                 "linewidth": 3,
                                 "color": 'black'},
                     zorder=1)
        # move the arrow annotation from the axes to the figure
        fig.texts.append(ax1.texts.pop())
    time_interp, array2 = interp_arrays(array2_0[::-1], array2_1[::-1], step=0.01)
    last_row = results_diff.iloc[-1]

    def _rounded(prefix, name_0, name_1):
        # value of the requested column, rounded to two decimals for display
        return np.round(last_row[prefix + str(name_0) + '_' + name_1 + add_name] * 100) / 100

    auc1 = _rounded('auci_', names1_0, names1_1)
    auc2 = _rounded('auci_', names2_0, names2_1)
    auci_diff = np.round((auc1 - auc2) * 100) / 100
    auci_label = 'auci ' + str(auc1) + '-' + str(auc2) + '=' + str(auci_diff)
    auc1 = _rounded('auc_', names1_0, names1_1)
    auc2 = _rounded('auc_', names2_0, names2_1)
    auc_diff = np.round((auc1 - auc2) * 100) / 100
    auc_label = 'auc ' + str(auc1) + '-' + str(auc2) + '=' + str(auc_diff)

    mid = int(len(time_interp) / 2)
    mid_left = int(len(time_interp) / 2 - 5)
    if auc_diff > 0:
        color = 'red'
        ypos1 = array2[mid]
        ypos2 = array1[mid_left]
        arrowstyle = "<-"
    else:
        color = 'blue'
        ypos1 = array1[mid]
        ypos2 = array2[mid_left]
        arrowstyle = "->"
    ax1.text(pos[0], pos[1] - add, auc_label, fontsize=10, transform=ax1.transAxes, color=color)
    # draw on ax1 explicitly instead of whatever axes is current
    ax1.fill_between(time_interp, array2, array1, color=color, alpha=0.5)
    xpos1 = time_interp[mid]
    xpos2 = time_interp[mid_left]
    mod = np.sqrt((ypos2 - ypos1) ** 2 + (xpos2 - xpos1) ** 2)
    # only draw the connecting arrow when the curves are visibly apart
    if arrow2 and mod > 0.1:
        ax1.annotate('', ha='center', xy=(xpos1, ypos1), xytext=(xpos2, ypos2),
                     arrowprops={"arrowstyle": arrowstyle,
                                 "linestyle": "-",
                                 "linewidth": 1,
                                 "color": 'black'}, zorder=1)
    auci_color = 'red' if auci_diff > 0 else 'blue'
    ax1.text(pos[0], pos[1], auci_label, fontsize=10, transform=ax1.transAxes, color=auci_color)
    ax1.plot(time_interp, array2, color='black', linewidth=linewidth)
def plot_rocs(fig, ax, counter_a, auc_names, a, auc_names1, fp_arrays, tp_arrays, results_diff, names, names1,
              counter1=0, counter2=2, legend=True, roc_color=[], alpha=1, lw=0.3, emb=False, second_roc=True,
              arrow=True, pos=[1, -0.45], add=-0.1, add_name=''):
    """Plot ROC curve number ``a`` on ``ax`` and optionally compare it with its neighbour.

    The curve is built from ``fp_arrays[a]``/``tp_arrays[a]`` (reversed, interpolated
    via ``interp_arrays``). It is labelled with the rounded 'auci_' value from the
    last row of ``results_diff``. When ``counter_a`` equals ``counter1`` or
    ``counter2``, the neighbouring curve (``a + 1`` for counters 0 and 2, else
    ``a - 1``) is overlaid with ``plot_second_roc``. A grey diagonal is drawn
    last.
    """
    if emb:
        embed()
    fp_array = fp_arrays[a]
    if len(tp_arrays) > 0:
        fp, tp = interp_arrays(fp_array[::-1], tp_arrays[a][::-1], step=0.01)
    else:
        fp = []
        tp = []
    if (counter_a == counter1) or (counter_a == counter2):
        # compare with the following array for counters 0 and 2, otherwise with the preceding one
        partner = a + 1 if (counter_a == 0) or (counter_a == 2) else a - 1
        array2_0 = fp_arrays[partner]
        array2_1 = tp_arrays[partner] if len(tp_arrays) > 0 else []
        color = roc_color if len(roc_color) > 0 else 'purple'
        label = np.round(results_diff.iloc[-1][
                             'auci_' + str(auc_names[a]) + '_' + auc_names1[a] + add_name] * 100) / 100
        ax.plot(fp, tp, color=color, linewidth=lw, alpha=alpha, label=label)
        if second_roc:
            # plot_second_roc does not take a linewidth argument in every version; passing it raised TypeError
            plot_second_roc(ax, fig, tp, array2_0, array2_1, results_diff, auc_names[a], auc_names1[a],
                            auc_names[partner], auc_names1[partner], pos=pos, add=add, add_name=add_name,
                            arrow=arrow)
    else:
        color = roc_color if len(roc_color) > 0 else 'black'
        ax.plot(fp, tp,
                label=(np.round(results_diff.iloc[-1][
                                    'auci_' + str(names) + '_' + names1 + add_name] * 100) / 100),
                color=color, linewidth=lw, alpha=alpha)
    if legend:
        # legend of the axes we drew on, not of whatever axes is current
        ax.legend()
    ax.plot([0, 1], [0, 1], color='grey',
            linestyle='--', linewidth=0.5)
def find_c_unique(name0, contrastc2, contrastc1, ):
    """Read the contrast values stored in the pickled frame ``name0``.

    Returns:
        (c2_unique, c1_unique, combinations). The first two are the distinct values
        of columns ``contrastc2``/``contrastc1``, sorted in descending order. The
        third is the keys of grouping by both columns. If the file does not
        exist, the arrays are empty and ``combinations`` is ``[]``.
    """
    combinations = []
    collected_c1 = []
    collected_c2 = []
    if os.path.exists(name0):
        frame = pd.read_pickle(name0)
        combinations = frame.groupby([contrastc1, contrastc2]).groups.keys()
        collected_c2.append(frame[contrastc2].unique())
        collected_c1.append(frame[contrastc1].unique())
    # np.unique sorts ascending; reverse for descending order
    c1_unique = np.unique(collected_c1)[::-1]
    c2_unique = np.unique(collected_c2)[::-1]
    return c2_unique, c1_unique, combinations
def find_all_threewave_versions():
    """List 'AllTrials' result entries of the 'threefish' folder, largest first.

    Entries are skipped if their name contains 'invivo', 'DetectionAnalysis',
    'pdf' or 'png'.

    Returns:
        (dir_version, sizes): entry names as a numpy array and their sizes in
        bytes, both sorted by decreasing size.
    """
    # resolve the folder once instead of once per entry
    folder = load_folder_name('threefish')
    dir_version = []
    sizes = []
    for entry in os.listdir(folder):
        if 'invivo' in entry:
            continue
        if ('DetectionAnalysis' not in entry) and ('pdf' not in entry) and ('png' not in entry) \
                and ('AllTrials' in entry):
            dir_version.append(entry)
            sizes.append(os.path.getsize(folder + '/' + entry))
    dir_version = np.array(dir_version)
    dir_version = dir_version[np.argsort(sizes)[::-1]]
    sizes = np.sort(sizes)[::-1]
    return dir_version, sizes
def get_fish_number(b, mt_group, mean_type):
    """Collect timing and contrast information for one group of multi-tags.

    Iterates over the multi-tag numbers in ``mt_group[1]['mt']``. For each tag
    with a positive extent it loads the durations via ``load_durations``. The
    values of the last such tag are returned, and the contrasts come from the
    last tag for which ``load_durations`` reported ``cont``.

    Returns:
        (fish_cuts, whole_duration, delay, contrast1, contrast2, repro_position),
        where ``delay = abs(fish_cuts[0])`` and ``repro_position`` is the last
        multi-tag number visited.

    Raises:
        ValueError: if no multi-tag has a positive extent or no contrast could be
            read. Previously this surfaced as an UnboundLocalError.
    """
    mt_list = mt_group[1]['mt']
    fish_cuts = None
    whole_duration = None
    delay = None
    contrast1 = None
    contrast2 = None
    repro_position = None
    found_cuts = False
    found_contrast = False
    # TODO: the loop could probably be removed
    for mt_idx, mt_nr in enumerate([int(value) for value in mt_list]):
        repro_position = mt_nr
        features, mt, name_here, l = get_mt_features3(b, mt_group, mt_idx)
        # some multi-tags have a negative extent; exclude them
        if (mt.extents[:][mt_nr] > 0).any():
            _, _, _, _, fish_number, fish_cuts, whole_duration, cont = load_durations(mt_nr, mt, mt_group[1], mt_idx,
                                                                                      mean_type=mean_type, emb=False)
            delay = np.abs(fish_cuts[0])
            found_cuts = True
            if cont:
                contrast1 = mt_group[1]['c1'].iloc[mt_idx]
                contrast2 = mt_group[1]['c2'].iloc[mt_idx]
                found_contrast = True
    if not (found_cuts and found_contrast):
        raise ValueError('get_fish_number: no multi-tag with positive extent and valid contrast found')
    return fish_cuts, whole_duration, delay, contrast1, contrast2, repro_position
def model_and_data_isi(nr_clim=10, many=False, width=0.005, HZ50=True, fs=8, nffts=['whole'],
                       powers=[1], var_items=['contrasts'], contrasts=[0], noises_added=[''], fft_i='forward',
                       fft_o='forward', spikes_unit='Hz',
                       mV_unit='mV',
                       D_extraction_method=['additiv_cv_adapt_factor_scaled'], internal_noise=['RAM'],
                       external_noise=['RAM'], level_extraction=[''], cut_off2=300,
                       receiver_contrast=[1], dendrids=[''], ref_types=[''], adapt_types=[''], c_noises=[0.1],
                       c_signal=[0.9],
                       cut_offs1=[300], clims='all', restrict='restrict'):  # ['eRAM']
    """Build the overview figure of baseline ISIs, measured and modelled susceptibilities.

    For every combination of the list-valued parameters one figure is created.
    The columns of its 1x4 grid hold:
      * column 0: baseline ISI histograms of the data cells,
      * column 1: measured susceptibility matrices (``plt_data_susept``),
      * columns 1-3 (model part and bottom row): model nonlinearities
        (``plt_model_part``) and model matrices for three stimulus contrasts
        (``plt_squares_special``).
    Colour limits are equalized with ``set_clim_same``, panels are tagged, and
    the figure is written with ``save_visualization(pdf=True)``.
    """
    stimulus_length = 1
    trials_nrs = [1]
    variant = 'sinz'
    mimick = 'no'
    cell_recording_save_name = ''
    trans = 1
    rep = 500000
    repeats = [20, rep]
    good_data, remaining = overlap_cells()
    cells_all = good_data
    default_settings(column=2, length=4.9)
    grid = gridspec.GridSpec(1, 4, wspace=0.95, bottom=0.115,
                             hspace=0.13, left=0.04, right=0.9, top=0.92, width_ratios=[0.7, 1, 1, 1])
    # Index handed to plt_model_part. It used to share its name with the loop
    # counters below, so later parameter combinations received 2 instead of 0.
    model_index = 0
    maxs = []
    mins = []
    ims = []
    perc05 = []
    perc95 = []
    iternames = [D_extraction_method, external_noise,
                 internal_noise, powers, nffts, dendrids, cut_offs1, trials_nrs, c_signal,
                 c_noises,
                 ref_types, adapt_types, noises_added, level_extraction, receiver_contrast, contrasts, ]
    nr = '2'
    cell_contrasts = ["2013-01-08-aa-invivo-1"]
    cells_triangl_contrast = np.concatenate([cells_all, cell_contrasts])
    rows = len(good_data) + len(cell_contrasts)
    perc = 'perc'
    lp = 10
    label_model = r'Nonlinearity $\frac{1}{S}$'
    for combination in it.product(*iternames):
        (var_type, stim_type_afe, stim_type_noise, power, nfft, dendrid, cut_off1, trial_nrs, c_sig, c_noise,
         ref_type, adapt_type, noise_added, extract, a_fr, a_fe) = combination
        fig = plt.figure()
        hs = 0.45
        #################################
        # model cells
        adapt_type_name, ax_model, cells_all, dendrid_name, ref_type_name, suptitles, width = plt_model_part(
            HZ50, model_index, a_fe, a_fr, adapt_type, c_noise, c_sig, cell_recording_save_name, cells_all,
            cut_off1, cut_off2, dendrid, extract, fft_i, fft_o, fig, fs, grid, hs, ims, mV_unit, many, maxs,
            mimick, mins, nfft, noise_added, nr, perc05, perc95, power, ref_type, repeats, spikes_unit,
            stim_type_afe, stim_type_noise, stimulus_length, trans, trial_nrs, var_items, var_type, variant,
            width, label=label_model, rows=rows, perc=perc, xlabels=False, title=False)
        #################################
        # data cells
        grid_data = gridspec.GridSpecFromSubplotSpec(rows, 1, grid[1],
                                                     hspace=hs)
        print('here')
        ax_data, stack_spikes_all, eod_frs = plt_data_susept(fig, grid_data, cells_all, cell_type='p-unit',
                                                             width=width, cbar_label=False, lp=lp, title=False)
        for ax in ax_data:
            remove_xticks(ax)
            ax.set_xticks_delta(100)
            ax.text(-0.42, 0.87, F2_xlabel(), ha='center', va='center',
                    transform=ax.transAxes, rotation=90)
            ax.text(1.66, 0.5, nonlin_title(), rotation=90, ha='center', va='center',
                    transform=ax.transAxes)
            ax.arrow_spines('lb')
        #################################
        # ISI histograms of the data
        grid_isi = gridspec.GridSpecFromSubplotSpec(rows, 1, grid[0],
                                                    hspace=hs)
        spikes_type = 'base'
        if spikes_type == 'base':
            ax_isi = []
            for f, cell in enumerate(cells_triangl_contrast):
                # NOTE(review): the arguments do not depend on ``cell``; this load could
                # probably be hoisted out of the loop if get_base_params does not modify frame.
                frame, frame_spikes = load_cv_vals_susept(cells_triangl_contrast, EOD_type='synch',
                                                          names_keep=['spikes', 'gwn', 'fs', 'EODf', 'cv', 'fr',
                                                                      'width_75', 'vs',
                                                                      'cv_burst_corr_individual',
                                                                      'fr_burst_corr_individual',
                                                                      'width_75_burst_corr_individual',
                                                                      'vs_burst_corr_individual',
                                                                      'cell_type_reclassified',
                                                                      'cell'],
                                                          path_sp='/calc_base_data-base_frame_overview.pkl',
                                                          frame_general=False)
                cell_type, eod_fr, fr, frs_calc, isi, spikes, spikes_all = get_base_params(cell,
                                                                                           'cell_type_reclassified',
                                                                                           frame)
                spikes_base, isi, frs_calc, cont_spikes = load_spikes(spikes, eod_fr)
                ax = plt.subplot(grid_isi[f])
                colors = colors_overview()
                plt_susept_isi_base(colors[cell_type], ax, isi, delta=5, xlim=[0, 15],
                                    ypos=-0.15, clip_on=True)
                ax_isi.append(ax)
        else:
            ax_isi = plt_isi(cells_all, grid_isi, stack_spikes=stack_spikes_all, eod_frs=eod_frs)
        ######################################################################
        print('started model contrasts')
        # model matrices for increasing stimulus contrast, shown in the bottom row
        params_c = {'contrasts': [0, 0.01, 0.025]}
        def_repeats = [rep]
        params = [{'level_extraction': level_extraction, 'repeats': def_repeats, 'contrasts': [contrast],
                   'D_extraction_method': D_extraction_method} for contrast in params_c['contrasts']]
        axes_contrast = []
        for col in range(3):
            grid_model = gridspec.GridSpecFromSubplotSpec(rows, 1, grid[1 + col], hspace=hs)
            axes_contrast.append(plt.subplot(grid_model[-1]))
        plt_squares_special(params, col_desired=3, var_items=['contrasts', 'repeats', 'level_extraction'], clims='',
                            show=False, width=width, share=False, cells_given=[cells_all.iloc[0]], perc=perc,
                            internal_noise=internal_noise, external_noise=external_noise, lp=lp, ax=axes_contrast,
                            label='', new_plot=False,
                            titles_plot=False)
        print('finished model contrasts')
        for col, ax in enumerate(axes_contrast):
            if col == 0:
                ax_data.append(ax)
                ax.text(-0.42, 0.87, F2_xlabel(), ha='center', va='center',
                        transform=ax.transAxes, rotation=90)
            elif col == 1:
                ax_model.insert(2, ax)
            else:
                ax_model.append(ax)
            if col != 0:
                remove_yticks(ax)
            ax.text(1.05, -0.25, F1_xlabel(), ha='center', va='center',
                    transform=ax.transAxes)
            ax.arrow_spines('lb')
            ax.set_xlabel('')
            ax.set_ylabel('')
            if col == 2:
                ax.text(1.5, 0.5, label_model, rotation=90, ha='center', va='center',
                        transform=ax.transAxes)
        # share the x axis of all ISI panels (get_shared_x_axes().join was removed in matplotlib 3.8)
        for ax in ax_isi[1:]:
            ax.sharex(ax_isi[0])
        set_clim_same(ims, perc05=perc05, perc95=perc95, lim_type='up', nr_clim=nr_clim, clims=clims)
        half = int(len(ax_model) / 2)
        axes = np.array([np.array(ax_isi),
                         np.array(ax_data),
                         np.array(ax_model[0:half]),
                         np.array(ax_model[half::])])
        axes = np.transpose(axes)
        for row in axes[:3]:
            fig.tag([list(row)], xoffs=-3, yoffs=2)
        # ATTENTION: never use the minor_index option of fig.tag here
        save_visualization(pdf=True)
def create_full_matrix2(chi_2_ur_numeric, chi_2_ul_numeric):
    """
    Creates all areas in the frequency plot from two reduced numeric matrices.

    :param chi_2_ur_numeric: square numeric matrix for the upper right quadrant; its
        upper triangle is mirrored into the lower one.
    :param chi_2_ul_numeric: square numeric matrix for the upper left quadrant. It is
        transposed, its upper triangle (incl. diagonal) is replaced by the complex
        conjugate of the lower triangle, and it is flipped left-right.
    :return: The full complex matrix of shape (2 * steps, 2 * steps); the quadrants
        are arranged as [[lower left, lower right], [upper left, upper right]] in
        row order, where lower left/right are conjugated point reflections of
        upper right/left.
    :raises ValueError: if the inputs are not square matrices of the same shape.
    """
    ur_numeric = np.asarray(chi_2_ur_numeric)
    ul_numeric = np.asarray(chi_2_ul_numeric)
    steps = ur_numeric.shape[0]
    if ur_numeric.shape != (steps, steps) or ul_numeric.shape != (steps, steps):
        raise ValueError('create_full_matrix2 expects two square matrices of equal shape, got %s and %s'
                         % (ur_numeric.shape, ul_numeric.shape))
    # matrix for upper right corner: mirror the upper triangle into the lower one
    chi_2_ur = ur_numeric.copy()
    lower = np.tril_indices(steps, -1)
    chi_2_ur[lower] = chi_2_ur.T[lower]
    # matrix for lower left corner
    chi_2_ll = np.conj(np.flip(chi_2_ur))
    # The upper-left input could be stored in the final orientation directly,
    # but transposing here works as well.
    chi_2_ul = np.transpose(ul_numeric).copy()
    upper = np.triu_indices(steps)
    chi_2_ul[upper] = np.conj(chi_2_ul.T[upper])
    chi_2_ul = np.flip(chi_2_ul, 1)
    # matrix for lower right corner
    chi_2_lr = np.conj(np.flip(chi_2_ul))
    # put all domains together (rows < steps: lower half, columns < steps: left half)
    return np.block([[chi_2_ll, chi_2_lr], [chi_2_ul, chi_2_ur]]).astype(complex)
def create_full_matrix(chi_2_numeric):
    """
    Creates all areas in the frequency plot from the reduced numeric matrix.

    :param chi_2_numeric: The square numeric matrix (upper right quadrant). Its upper
        triangle is mirrored for the upper right quadrant. The upper left quadrant
        uses the matrix with its upper triangle (incl. diagonal) replaced by the
        complex conjugate of its transpose, flipped left-right.
    :return: The full complex matrix of shape (2 * steps, 2 * steps), arranged as
        [[lower left, lower right], [upper left, upper right]] in row order.
    :raises ValueError: if the input is not a square matrix.
    """
    numeric = np.asarray(chi_2_numeric)
    steps = numeric.shape[0]
    if numeric.shape != (steps, steps):
        raise ValueError('create_full_matrix expects a square matrix, got shape %s' % (numeric.shape,))
    # matrix for upper right corner: mirror the upper triangle into the lower one
    chi_2_ur = numeric.copy()
    lower = np.tril_indices(steps, -1)
    chi_2_ur[lower] = chi_2_ur.T[lower]
    # matrix for lower left corner
    chi_2_ll = np.conj(np.flip(chi_2_ur))
    # matrix for upper left corner
    chi_2_ul = numeric.copy()
    upper = np.triu_indices(steps)
    chi_2_ul[upper] = np.conj(numeric.T[upper])
    chi_2_ul = np.flip(chi_2_ul, 1)
    # matrix for lower right corner
    chi_2_lr = np.conj(np.flip(chi_2_ul))
    # put all domains together (rows < steps: lower half, columns < steps: left half)
    return np.block([[chi_2_ll, chi_2_lr], [chi_2_ul, chi_2_ur]]).astype(complex)
def get_axis_on_full_matrix(full_matrix, stack_final):
    """Wrap ``full_matrix`` in a DataFrame with mirrored integer frequency axes.

    Each axis of ``stack_final`` (index and columns) is extended by its negated,
    reversed copy in front, and every value is truncated to ``int``.
    """
    def _mirrored(axis_values):
        # negated reversed values followed by the original values, as ints
        both_sides = np.concatenate([-axis_values[::-1], axis_values])
        return np.array([int(value) for value in both_sides])

    return pd.DataFrame(full_matrix,
                        index=_mirrored(stack_final.index),
                        columns=_mirrored(stack_final.columns))
def get_stack_one_quadrant(cell, cell_add, cells_save, path1, save_name_rev, direct_load=False, redo=False,
                           creation_time_update=False, size_update=True):
    """Load the stored model/data rows of ``cell`` and normalize them for plotting.

    Returns:
        (stack_plot1, stack_saved): the output of ``RAM_norm`` applied to the
        plot-converted rows, and the raw rows from ``get_stack_initial``.
    """
    stack_saved = get_stack_initial(cell, cell_add, cells_save, path1,
                                    save_name_rev, direct_load=direct_load, redo=redo,
                                    creation_time_update=creation_time_update, size_update=size_update)
    stack_plot = change_model_from_csv_to_plots(stack_saved)
    try:
        stack_plot1 = RAM_norm(stack_plot, model_show=stack_saved)
    except Exception:
        print('model thing2')
        embed()
        # stack_plot1 is undefined at this point; re-raise instead of failing on the return
        raise
    return stack_plot1, stack_saved
def get_stack_initial(cell, cell_add, cells_save, path1, save_name_rev, direct_load=False, redo=False,
                      creation_time_update=False, size_update=True):
    """Return the rows of the stored model/data frame that belong to ``cell``.

    With ``direct_load`` the frame is read straight from ``path1`` (pickle if the
    path contains '.pkl', otherwise CSV with the first column as index).
    Otherwise it comes from ``load_model_susept``, using the last path component
    of ``save_name_rev`` plus ``cell_add`` as the name.
    """
    if not direct_load:
        model = load_model_susept(path1, cells_save, save_name_rev.split(r'/')[-1] + cell_add, redo=redo,
                                  creation_time_update=creation_time_update, size_update=size_update)
    elif '.pkl' in path1:
        model = pd.read_pickle(path1)
    else:
        model = pd.read_csv(path1, index_col=0)
    return model[model.cell == cell]
def find_noise_names(b, base=''):
    """Return the noise-stimulus multi-tag names of ``b`` and an availability flag.

    With ``base == ''`` the names come from ``find_names_gwn(b)``; otherwise from
    ``find_mt(b, 'init')``. The flag is always True: the original availability
    probe could never fail.
    """
    noise_there = True
    # NOTE(review): the original author doubted that the find_mt(b, 'init') branch is correct
    names_mt_gwns = find_names_gwn(b) if base == '' else find_mt(b, 'init')
    return names_mt_gwns, noise_there
def get_groups_ram(base_properties, file_name, data_between_2017_2018=''):
    """Group the recording properties of one file by contrast ``c``.

    Rows are first sorted by descending contrast. Unless
    ``data_between_2017_2018 == 'all'``, they are restricted to ``file_name``.
    They are then ordered by ``start`` (descending sort, reversed) and grouped by
    ``'c'``.

    Returns:
        The pandas GroupBy object.
    """
    try:
        by_contrast = base_properties.sort_values(by='c', ascending=False)
    except:
        print('contrast problem sorting')
        embed()
        by_contrast = base_properties
    # filter by file name unless all files were requested
    if data_between_2017_2018 == 'all':
        selection = by_contrast
    else:
        selection = by_contrast[by_contrast.file_name == file_name]
    if len(selection) < 1:
        print('file_name problem')
        embed()
    selection = selection.sort_values(by='start', ascending=False)[::-1]
    return selection.groupby('c')
def calc_abs_power(isfs):
    """Sum the squared magnitudes of several equally long spectra.

    Args:
        isfs: non-empty sequence of 1-D spectra of equal length.

    Returns:
        Complex array (imaginary part zero) with sum_k |isfs[k]|**2. It is kept
        complex for callers that combine it with cross spectra.
    """
    # np.complex_ was removed in NumPy 2.0; the builtin complex maps to complex128
    cross = np.zeros(len(isfs[0]), dtype=complex)
    for isf in isfs:
        cross += np.abs(isf) ** 2
    return cross
def get_osf_restricted(deltat, max_f, spikes_hann, stimulus_hann, fft_i='forward', fft_o='forward'):
    """Fourier-transform a windowed response/stimulus pair and restrict it to ``max_f``.

    Args:
        deltat: sampling interval used for the frequency axis.
        max_f: cut-off frequency handed to ``restrict_freqs``.
        spikes_hann, stimulus_hann: windowed response and stimulus snippets.
        fft_i, fft_o: ``norm`` argument of ``np.fft.fft`` for stimulus and response.

    Returns:
        (d_isf, d_isf1, d_osf, d_osf1, f, f_orig, f_range, f_same, isf, osf, restrict).
        ``d_*`` is the mean magnitude over the 10 bins below the zero-frequency
        bin of the frequency-sorted spectrum; ``d_*1`` is the magnitude of the bin
        just below it. ``isf``/``osf`` are restricted with ``restrict``.
    """
    f_orig = np.fft.fftfreq(len(stimulus_hann), deltat)
    f, restrict = restrict_freqs(f_orig, max_f)
    f_range = np.arange(0, len(f), 1)
    f_same = f_orig[restrict]
    # norm='forward' (1/n scaling) is the inverse of how the stimulus was generated,
    # which is why it is the default here (also confirmed by Mascha).
    try:
        osf = np.fft.fft(spikes_hann - np.mean(spikes_hann), norm=fft_o)
        isf = np.fft.fft(stimulus_hann - np.mean(stimulus_hann), norm=fft_i)
    except (ValueError, TypeError):
        # Fallback for NumPy versions that do not accept the requested norm.
        # NOTE(review): this path neither removes the mean nor uses 1/n scaling for
        # 'forward' (it multiplies by deltat), so its values differ from the primary path.
        if fft_o == 'backward':
            osf = np.fft.fft(spikes_hann)
        elif fft_o == 'forward':
            osf = np.fft.fft(spikes_hann) * deltat
        else:
            raise
        if fft_i == 'backward':
            isf = np.fft.fft(stimulus_hann)
        elif fft_i == 'forward':
            isf = np.fft.fft(stimulus_hann) * deltat
        else:
            raise
    order = np.argsort(f_orig)
    zero_pos = np.argmin(np.abs(f_orig[order]))
    left = zero_pos - 10
    abs_isf = np.abs(isf)[order]
    abs_osf = np.abs(osf)[order]
    d_isf = np.mean(abs_isf[left:zero_pos])
    # NOTE: d_osf looks at the magnitude around zero frequency rather than the
    # variance (and is not squared)
    d_osf = np.mean(abs_osf[left:zero_pos])
    d_osf1 = abs_osf[zero_pos - 1]
    d_isf1 = abs_isf[zero_pos - 1]
    isf = isf[restrict]
    osf = osf[restrict]
    return d_isf, d_isf1, d_osf, d_osf1, f, f_orig, f_range, f_same, isf, osf, restrict
def get_psds_for_coherence(amp, b, cut_off, file_name_save, indices, max_val, mt, names_mt_gwn, nfft, nr_snippets,
                           sampling, p11s=[], stimulus_given=[], p12s=[], give_stimulus=False, p22s=[],
                           spikes_mat_given=[], overlapp='',
                           dev='original', mean=''):
    """Accumulate stimulus/response (cross) spectra over trials and nfft-long snippets.

    For every trial in ``indices`` with more than two spikes, the stimulus and the
    binarized (optionally Gaussian-smoothed, see ``dev``) spike train are cut into
    snippets starting at ``range_with_overlap(nfft, overlapp, max_val * sampling)``.
    Each snippet is Hann-windowed and transformed with ``get_osf_restricted``.
    ``crossSpectrum`` results are added to ``p11s`` (stimulus), ``p12s`` (cross)
    and ``p22s`` (response). If ``p11s`` is empty, all three start from complex
    zero arrays of length ``nfft / 2 - 1``.

    Returns:
        (count_final, stims, mats, a_mi, f_same, length, p11s, p12s, p22s,
        length_array_isf, length_array_osf, length_array_spikes, length_array_stimulus)
    """
    if overlapp == '':
        nr_snippets = nr_snippets * 2
    n_freq = int(nfft / 2 - 1)
    if len(p11s) < 1:
        # np.complex_ was removed in NumPy 2.0; complex maps to complex128
        p22s = np.zeros(n_freq, dtype=complex)
        p12s = np.zeros(n_freq, dtype=complex)
        p11s = np.zeros(n_freq, dtype=complex)
    length = range_with_overlap(nfft, overlapp, max_val * sampling)
    # The shared empty lists are never mutated: the first write to a slot replaces it.
    length_array_isf = [[]] * len(length)
    length_array_osf = [[]] * len(length)
    length_array_stimulus = [[]] * len(length)
    length_array_spikes = [[]] * len(length)
    a_mi = 0
    mats = []
    stims = []
    count_final = 0
    # Gaussian smoothing width (seconds) of the spike train for each ``dev`` option
    smoothing_widths = {'05': 0.0005, '2': 0.002, '5': 0.005, '10': 0.01}
    fft_i = 'forward'
    fft_o = 'forward'
    # NOTE(review): f_same is only bound once a full nfft-long snippet was processed
    for mm, m in enumerate(indices):
        print(mm)
        first, minus, second, stimulus_length = find_first_second(b,
                                                                  names_mt_gwn,
                                                                  m, mt,
                                                                  mm=mm)
        if len(spikes_mat_given) > 0:
            spikes_mt = spikes_mat_given[m]
        else:
            spikes_mt = link_arrays_spikes(b, first,
                                           second, minus)
        if len(spikes_mt) <= 2:
            continue
        if 'first' in mean:
            if second > 10:
                second = 10
        if len(stimulus_given) > 0:
            eod_interp = stimulus_given
            deltat = 1 / sampling
        else:
            deltat, eod_interp, eodf_size, sampling, time_array = load_presaved(b, amp,
                                                                                file_name_save,
                                                                                first,
                                                                                m, mt,
                                                                                sampling,
                                                                                second)
        if len(spikes_mat_given) > 0:
            spikes_mat = spikes_mat_given[m]
        else:
            spikes_mat = cr_spikes_mat(spikes_mt, sampling, len(eod_interp))
        if dev in smoothing_widths:
            spikes_mat = gaussian_filter(spikes_mat, sigma=smoothing_widths[dev] * sampling)
        mats.append(spikes_mat)
        stims.append(eod_interp)
        isfs = []
        osfs = []
        for aa, a in enumerate(length):
            stimulus_array = eod_interp[int(0 + a):int(a + nfft)]
            spikes_array = spikes_mat[int(0 + a):int(a + nfft)]
            hann = np.hanning(len(stimulus_array))
            try:
                stimulus_hann = (stimulus_array - np.mean(stimulus_array)) * hann
            except:
                print('stimulus something')
                embed()
            hann = np.hanning(len(spikes_array))
            spikes_hann = (spikes_array - np.mean(spikes_array)) * hann
            if (len(stimulus_hann) == nfft) and (len(spikes_hann) == nfft):
                d_isf, d_isf1, d_osf, d_osf1, f, f_orig, f_range, f_same, isf, osf, restrict = get_osf_restricted(
                    deltat,
                    cut_off,
                    spikes_hann,
                    stimulus_hann,
                    fft_i,
                    fft_o)
                isfs.append(isf)
                osfs.append(osf)
                a_mi += 1
            else:
                # incomplete snippet at the end of the trial: fill with NaNs
                isf = float('nan') * np.zeros(n_freq, dtype=complex)
                isfs.append(isf)
                osf = float('nan') * np.zeros(n_freq, dtype=complex)
                osfs.append(osf)
                spikes_hann = float('nan') * np.zeros(int(nfft))
                stimulus_hann = float('nan') * np.zeros(int(nfft))
            try:
                length_array_isf[aa]
            except:
                print('length vals')
                embed()
            if len(length_array_isf[aa]) < 1:
                length_array_isf[aa] = [isf]
                if give_stimulus:
                    length_array_spikes[aa] = [spikes_hann]
                    length_array_stimulus[aa] = [stimulus_hann]
            else:
                try:
                    length_array_isf[aa].append(isf)
                except:
                    print('append thing')
                    embed()
                if give_stimulus:
                    length_array_spikes[aa].append(spikes_hann)
                    length_array_stimulus[aa].append(stimulus_hann)
            if len(length_array_osf[aa]) < 1:
                length_array_osf[aa] = [osf]
            else:
                try:
                    length_array_osf[aa].append(osf)
                except:
                    print('append thing')
                    embed()
        p22, count_final = crossSpectrum(osfs, osfs)
        p12, count_final = crossSpectrum(isfs, osfs)
        p11, count_final = crossSpectrum(isfs, isfs)
        p22s += p22
        p12s += p12
        p11s += p11
        count_final += 1
    if len(length_array_osf) > nr_snippets:
        print('length something0')
        embed()
    length_array_osf = np.array(length_array_osf)
    length_array_isf = np.array(length_array_isf)
    if give_stimulus:
        length_array_stimulus = np.array(length_array_stimulus)
        length_array_spikes = np.array(length_array_spikes)
    print('done mm')
    return (count_final, stims, mats, a_mi, f_same, length, p11s, p12s, p22s, length_array_isf, length_array_osf,
            length_array_spikes, length_array_stimulus)
def coherence_and_mutual_response_wo_sqrt(a_mir, a_mir2, cut_vals, p12_rrs, p22_rrs):
    """Response-response coherence and the corresponding mutual-information density.

    The cross spectrum ``p12_rrs`` is normalized by ``a_mir`` and the response
    auto spectrum (real part of ``p22_rrs``) by ``a_mir2``. The coherence is
    |cross|**2 / auto**2, and the information is -log2(1 - coherence), evaluated
    at ``cut_vals``.

    Returns:
        (coh_resp, mutual_information_resp)
    """
    normalized_cross = p12_rrs / a_mir
    normalized_auto = p22_rrs.real / a_mir2
    coh_resp = np.abs(normalized_cross) ** 2 / normalized_auto ** 2
    mutual_information_resp = -np.log2(1 - coh_resp[cut_vals])
    return coh_resp, mutual_information_resp
def prepeare_test_arrays(indices, length, length_array_isf, length_array_osf, nr_snippets):
    """Build test arrays that repeat the very first stimulus/response spectrum.

    For every entry of ``indices`` and every snippet position in ``length``, the
    spectra ``length_array_isf[0][0]`` and ``length_array_osf[0][0]`` are appended
    to that position's list. There are ``nr_snippets`` positions in total; unused
    ones stay empty.

    Returns:
        (length_array_osf_test, length_array_isf_test) as numpy arrays.
    """
    isf_slots = [[]] * nr_snippets
    osf_slots = [[]] * nr_snippets
    for _ in indices:
        for slot, _unused in enumerate(length):
            if not isf_slots[slot]:
                # first write replaces the shared empty list with a fresh one
                isf_slots[slot] = [length_array_isf[0][0]]
                osf_slots[slot] = [length_array_osf[0][0]]
            else:
                isf_slots[slot].append(length_array_isf[0][0])
                osf_slots[slot].append(length_array_osf[0][0])
    return np.array(osf_slots), np.array(isf_slots)
def rescale_colorbar_and_values(abs_matrix, add_nonlin_title=None, resize_val=None):
    """Scale a matrix by an SI-prefix factor and return the matching prefix.

    Intended to be applied to the final matrix. If ``add_nonlin_title`` (an SI
    prefix such as 'k' or 'm') is given, its factor is used. Otherwise, unless a
    ``resize_val`` is passed, the factor is chosen from the matrix maximum: the
    first of 1e12, 1e9, 1e6, 1e3 that the maximum exceeds, else the first of
    1e-12, 1e-9, 1e-6, 1e-3 it falls below, else 1.

    Returns:
        (scaled matrix, prefix string, scale factor)
    """
    if add_nonlin_title:
        resize_val = find_resize(add_nonlin_title)
    peak = np.max(np.max(abs_matrix, axis=0), axis=0)
    if not resize_val:
        resize_val = 1
        for factor in (1000000000000, 1000000000, 1000000, 1000):
            if peak > factor:
                resize_val = factor
                break
        else:
            # pico, nano, micro, milli
            for factor in (0.000000000001, 0.000000001, 0.000001, 0.001):
                if peak < factor:
                    resize_val = factor
                    break
    try:
        abs_matrix = abs_matrix / resize_val
    except:
        print('resize thing')
        embed()
    add_nonlin_title = find_add_title(resize_val)
    return abs_matrix, add_nonlin_title, resize_val
def find_resize(add_nonlin_title):
    """Return the scale factor belonging to an SI prefix.

    Known prefixes: 'k', 'M', 'G', 'T', 'P', 'E', 'p', 'n', '$\\mu$', 'm', 'c',
    'd' and '' (factor 1).

    Raises:
        ValueError: for an unknown prefix (previously an UnboundLocalError).
    """
    prefix_factors = (
        ('k', 1000),
        ('M', 1000000),
        ('G', 1000000000),
        ('T', 1000000000000),
        ('P', 1000000000000000),
        ('E', 1000000000000000000),
        ('p', 0.000000000001),
        ('n', 0.000000001),
        # raw string: '\m' is an invalid escape sequence in a normal literal
        (r'$\mu$', 0.000001),
        ('m', 0.001),
        ('c', 0.01),
        ('d', 0.1),
        ('', 1),
    )
    for prefix, factor in prefix_factors:
        if add_nonlin_title == prefix:
            return factor
    raise ValueError('unknown SI prefix: %r' % (add_nonlin_title,))
def find_add_title(resize_val):
    """Return the SI prefix string belonging to a scale factor.

    Inverse of ``find_resize``: 1000 -> 'k', ..., 0.000001 -> '$\\mu$', 1 -> ''.
    Factors are compared with ``==``.

    Raises:
        ValueError: for a factor without prefix (previously an UnboundLocalError).
    """
    factor_prefixes = (
        (1000, 'k'),
        (1000000, 'M'),
        (1000000000, 'G'),
        (1000000000000, 'T'),
        (1000000000000000, 'P'),  # peta
        (1000000000000000000, 'E'),  # exa
        (0.000000000001, 'p'),  # pico
        (0.000000001, 'n'),  # nano
        # raw string: '\m' is an invalid escape sequence in a normal literal
        (0.000001, r'$\mu$'),  # micro
        (0.001, 'm'),  # milli
        (0.01, 'c'),  # centi
        (0.1, 'd'),  # deci
        (1, ''),
    )
    for factor, prefix in factor_prefixes:
        if resize_val == factor:
            return prefix
    raise ValueError('no SI prefix for scale factor %r' % (resize_val,))
def range_with_overlap(nfft, overlap, lenth_here):
    """Start indices of nfft-long snippets within ``lenth_here`` samples.

    With ``overlap == '_nooverlap_'`` the snippets are adjacent (step ``nfft``,
    returned as a numpy array). Otherwise they overlap by half (step ``nfft / 2``,
    returned as a list of ints).
    """
    if overlap != '_nooverlap_':
        return [int(start) for start in np.arange(0, lenth_here, nfft / 2)]
    return np.arange(0, lenth_here, nfft)
def gaussKernel(sigma, dt):
    """ Creates a Gaussian kernel with a given standard deviation and an integral of 1.
    Parameters
    ----------
    sigma : float
        The standard deviation of the kernel in seconds
    dt : float
        The temporal resolution of the kernel, given in seconds.
    Returns:
    np.array
        The kernel sampled on [-4 sigma, 4 sigma) with step dt; its sum times dt
        is approximately 1.
    """
    x = np.arange(-4. * sigma, 4. * sigma, dt)
    y = np.exp(-0.5 * (x / sigma) ** 2) / np.sqrt(2. * np.pi) / sigma
    return y


def firing_rate(spikes, duration, sigma=0.005, dt=1. / 20000.):
    """Convert spike times to a firing rate estimated by kernel convolution with a Gaussian kernel.
    Args:
        spikes (array-like): the spike times; times outside [0, duration) are ignored
        duration (float): the trial duration
        sigma (float, optional): standard deviation of the Gaussian kernel. Defaults to 0.005.
        dt (float, optional): desired temporal resolution of the firing rate. Defaults to 1./20000..
    Returns:
        np.array: the firing rate, always of length round(duration / dt)
    """
    binary = np.zeros(int(np.round(duration / dt)))
    indices = np.asarray(np.round(np.asarray(spikes, dtype=float) / dt), dtype=int)
    # negative indices would otherwise wrap around to the end of the trial
    indices = indices[(indices >= 0) & (indices < len(binary))]
    binary[indices] = 1
    kernel = gaussKernel(sigma, dt)
    # Centered part of the full convolution. This equals np.convolve(..., mode="same")
    # whenever the kernel is shorter than the trial, but keeps the trial length
    # when the kernel is longer ('same' would return the kernel length then).
    full = np.convolve(binary, kernel, mode="full")
    start = (len(kernel) - 1) // 2
    return full[start:start + len(binary)]
def get_rates(rr, time, dt, sigma):
    """Firing rates of all stimulus presentations of ``rr`` that last ``time[-1]``.

    Presentations with a different duration are skipped. Each remaining one is
    converted with ``firing_rate`` into one row of the returned array of shape
    (number of matching presentations, len(time)).
    """
    duration = time[-1]
    n_matching = sum(1 for stim in rr if stim.duration == duration)
    rates = np.zeros((n_matching, len(time)))
    row = 0
    for trial in range(rr.stimulus_count):
        if rr[trial].duration == duration:
            rates[row, :] = firing_rate(rr.spikes(trial), duration, sigma, dt=dt)
            row += 1
    return rates
def coherence(rates, stim, nperseg, noverlap, dt):
    """Stimulus-response coherence from segment-wise spectra.

    Args:
        rates: 2-D array, one firing-rate trace per row; each row must be as long
            as ``stim``.
        stim: stimulus trace.
        nperseg, noverlap: segment length and overlap passed to ``get_segments``.
        dt: sampling interval.

    Returns:
        (f, coh, all_rate_spectra, all_stim_spectra) with ``f`` and ``coh``
        restricted to non-negative frequencies. ``coh = csd / (stim_asd * resp_asd)``;
        NOTE(review): whether this is normalized coherence depends on what
        ``auto_spectrum`` returns — confirm.

    Raises:
        ValueError: if the rate traces and the stimulus differ in length.
    """
    # explicit check instead of assert, which is stripped under ``python -O``
    if rates.shape[1] != len(stim):
        raise ValueError('coherence: rate traces have %d samples but the stimulus has %d'
                         % (rates.shape[1], len(stim)))
    all_rate_spectra, all_stim_spectra, f = get_rates_stacked(dt, noverlap, nperseg, rates, stim)
    csd = cross_spectrum(all_stim_spectra, all_rate_spectra)
    stimasd = auto_spectrum(all_stim_spectra)
    respasd = auto_spectrum(all_rate_spectra)
    coh = csd / (stimasd * respasd)
    positive = f >= 0
    return f[positive], coh[positive], all_rate_spectra, all_stim_spectra
def get_rates_stacked(dt, noverlap, nperseg, rates, stim):
    """Compute segment-wise spectra of every response trial and of the stimulus.

    The stimulus is segmented and transformed once. Its spectra are repeated
    once per trial so that the rows of ``all_stim_spectra`` line up with the
    rows of ``all_rate_spectra``.

    Parameters
    ----------
    dt : float
        Sampling interval passed to ``spectra``.
    noverlap, nperseg : int
        Segment overlap and length passed to ``get_segments``.
    rates : np.ndarray
        2-D array with one response trial per row (``rates[i, :]``).
    stim : np.ndarray
        Stimulus trace shared by all trials.

    Returns
    -------
    all_rate_spectra, all_stim_spectra, f
        Stacked response spectra, stimulus spectra repeated per trial, and the
        frequency axis from ``spectra``. With a single trial the spectra are
        returned unstacked, exactly as ``spectra`` produced them.

    Raises
    ------
    ValueError
        If ``rates`` has no rows (previously this ended in an
        UnboundLocalError).
    """
    n_trials = rates.shape[0]
    if n_trials == 0:
        raise ValueError('get_rates_stacked: rates must contain at least one trial')
    stim_segments = get_segments(stim, nperseg, noverlap)
    f, stim_spectra = spectra(stim_segments, dt)
    # Collect every trial first and stack once. Calling np.vstack inside the
    # loop copied the growing array on each iteration (quadratic in trials).
    trial_spectra = []
    for i in range(n_trials):
        rate_segments = get_segments(rates[i, :], nperseg, noverlap)
        _, rate_spectra = spectra(rate_segments, dt)
        trial_spectra.append(rate_spectra)
    if n_trials == 1:
        all_rate_spectra = trial_spectra[0]
        all_stim_spectra = stim_spectra
    else:
        all_rate_spectra = np.vstack(trial_spectra)
        all_stim_spectra = np.vstack([stim_spectra] * n_trials)
    return all_rate_spectra, all_stim_spectra, f
def exp_coherence(rates, nperseg, noverlap, dt):
    """Coherence between each single-trial response and the trial-averaged response.

    The mean response (``np.mean(rates, axis=0)``) and every trial are
    segmented and transformed with ``spectra``. The coherence is then computed
    as ``cross_spectrum(mean, trial) / (auto_spectrum(trial) *
    auto_spectrum(mean))``.

    Parameters
    ----------
    rates : np.ndarray
        2-D array with one response trial per row.
    nperseg, noverlap : int
        Segment length and overlap.
    dt : float
        Sampling interval.

    Returns
    -------
    f, c
        Frequencies and coherence restricted to ``f >= 0``.

    Raises
    ------
    ValueError
        If ``rates`` has no rows (previously an UnboundLocalError).
    """
    n_trials = rates.shape[0]
    if n_trials == 0:
        raise ValueError('exp_coherence: rates must contain at least one trial')
    mrate = np.mean(rates, axis=0)
    mrate_segments = get_segments(mrate, nperseg, noverlap)
    f, mrate_spectra = spectra(mrate_segments, dt)
    # Gather all trials, then stack once instead of re-copying with vstack
    # in every loop iteration.
    trial_spectra = []
    for i in range(n_trials):
        rate_segments = get_segments(rates[i, :], nperseg, noverlap)
        _, rate_spectra = spectra(rate_segments, dt)
        trial_spectra.append(rate_spectra)
    if n_trials == 1:
        all_mrate_spectra = mrate_spectra
        all_rate_spectra = trial_spectra[0]
    else:
        all_mrate_spectra = np.vstack([mrate_spectra] * n_trials)
        all_rate_spectra = np.vstack(trial_spectra)
    csd = cross_spectrum(all_mrate_spectra, all_rate_spectra)
    mrateasd = auto_spectrum(all_mrate_spectra)
    rateasd = auto_spectrum(all_rate_spectra)
    c = csd / (rateasd * mrateasd)
    return f[f >= 0], c[f >= 0]
def coherences(rates, s, dt, nperseg=2 ** 14):
    """Compute stimulus, expected and response-response coherences in one go.

    All three estimates use segments of length ``nperseg`` with 50 % overlap.

    Returns
    -------
    f, gamma, exp_gamma, rr_gamma, all_rate_spectra, all_stim_spectra
        ``gamma`` from ``coherence`` (stimulus vs. response), ``exp_gamma``
        from ``exp_coherence`` and ``rr_gamma`` from ``rr_coherence``, all on
        the frequency axis ``f`` of ``coherence``.
    """
    overlap = nperseg // 2
    f, gamma, all_rate_spectra, all_stim_spectra = coherence(rates, s, nperseg, overlap, dt)
    _exp_freqs, exp_gamma = exp_coherence(rates, nperseg, overlap, dt)
    _rr_freqs, rr_gamma = rr_coherence(rates, nperseg, overlap, dt)
    return f, gamma, exp_gamma, rr_gamma, all_rate_spectra, all_stim_spectra
def plt_cohs_ich(ax, coh, coh_resp, coh_resp_mean, coh_resp_directs, coh_resp_restrict, coh_s_directs, cut_off, f_same):
    """Plot several coherence estimates below ``cut_off`` into ``ax[1]``.

    The response-based estimates (restricted, plain, direct) and
    ``coh_s_directs`` are drawn as square roots. ``coh_resp_mean`` and
    ``coh`` are drawn unchanged.
    """
    below = f_same < cut_off
    freqs = f_same[below]
    panel = ax[1]
    # (values, plot keyword arguments) in drawing order
    curves = (
        (np.sqrt(coh_resp_restrict[below]), dict(label='coherence_r_restrict', color='purple')),
        (np.sqrt(coh_resp[below]), dict(label='coherence_r', color='orange')),
        (np.sqrt(coh_resp_directs[below]), dict(label='coherence_r_direct', color='orange', linestyle='--')),
        (np.sqrt(coh_s_directs[below]), dict(label='coh_s_directs', color='blue', linestyle='--')),
        (coh_resp_mean[below], dict(label='coherence_r_expected', color='green')),
        (coh[below], dict(label='coherence_s', color='blue')),
    )
    for values, style in curves:
        panel.plot(freqs, values, **style)
def plt_cohs_jan(ax, cut_off, exp_gamma, f_jan, gamma, rr_gamma):
    """Plot gamma, exp_gamma and rr_gamma below ``cut_off`` into ``ax[0]`` and add a legend."""
    visible = f_jan < cut_off
    panel = ax[0]
    for name, values in (('gamma', gamma), ('exp_gamma', exp_gamma), ('rr_gamma', rr_gamma)):
        panel.plot(f_jan[visible], values[visible], label=name)
    panel.legend()
def get_mats_same_shape(indices, mats, mt, sampling, stims):
    """Keep only the matrices whose length equals the longest selected extent.

    The target length is ``max(mt.extents[:][indices]) * sampling``. Every
    entry of ``mats`` with exactly that length is converted with ``np.array``
    and kept.

    Returns
    -------
    mats_jan : np.ndarray
        The kept matrices stacked with ``np.array``.
    stim_jan
        The ``stims`` entry belonging to the LAST kept matrix. If nothing
        matches, the function fails with UnboundLocalError.
    """
    target_length = np.max(mt.extents[:][indices]) * sampling
    kept = []
    for position, candidate in enumerate(mats):
        if len(candidate) != target_length:
            continue
        kept.append(np.array(candidate))
        stim_jan = stims[position]
    return np.array(kept), stim_jan
def tranfer_xlabel():
    """Return the LaTeX x-label for frequency divided by the EOD frequency symbol."""
    return ''.join(['$f/', f_eod_name_core_rm(), '$'])  # unit-free; no [Hz] suffix
def tranfer_xlabel_hz():
    """Return the LaTeX x-label for absolute frequency in Hz."""
    label = '$f$ [Hz]'
    return label
def diagonal_xlabel():
    """Return the LaTeX x-label for the summed frequency f1 + f2 in Hz.

    A raw string is used so that the LaTeX thin space (backslash followed by a
    comma) is not read as an invalid Python escape sequence, which triggers a
    SyntaxWarning on recent Python versions. The returned text is unchanged.
    """
    return r'$f_{1}+f_{2}$\,[Hz]'
def diagonal_xlabel_nothz():
    """Return the LaTeX x-label for (f1 + f2) divided by the EOD frequency symbol."""
    return ''.join(['$(f_{1}+f_{2})/', f_eod_name_core_rm(), '$'])
def NLI_scorename2():
    """Return the LaTeX name of the PNL score evaluated at the base frequency from ``basename()``."""
    return ''.join(['PNL$(f', basename(), ')$'])
def NLI_name2():
    """Return the short name of the nonlinearity score used in labels."""
    short_name = 'PNL'
    return short_name
def NLI_scorename():
    """Return the LaTeX name of the NLI score evaluated at the base frequency from ``basename()``."""
    return ''.join(['NLI$(f', basename(), ')$'])
def join_x(axts_all):
    """Share the x-axis of all given axes with the first one.

    On older matplotlib the shared-axes grouper offers ``join``, which links
    all axes at once. From matplotlib 3.8 on, ``get_shared_x_axes()`` returns a
    read-only view without ``join``; then every axis that is not yet linked is
    attached with ``Axes.sharex``. Fewer than two axes is a no-op.
    """
    if len(axts_all) < 2:
        return
    reference = axts_all[0]
    grouper = reference.get_shared_x_axes()
    if hasattr(grouper, 'join'):
        grouper.join(*axts_all)
        return
    for ax in axts_all[1:]:
        if ax is not reference and not grouper.joined(reference, ax):
            ax.sharex(reference)
def join_y(axts_all, mult_val=1.00):
    """Share the y-axis of all given axes and give them one common y-range.

    The axes are linked to the first one: through the grouper's ``join`` on
    older matplotlib, or through ``Axes.sharey`` where ``get_shared_y_axes()``
    returns a read-only view (matplotlib >= 3.8). If the axes showed different
    upper y-limits beforehand, every axis is set to the union of all original
    y-ranges, with the upper limit multiplied by ``mult_val`` once. Fewer than
    two axes is a no-op.
    """
    if len(axts_all) < 2:
        return
    # Read the limits before linking, because sharey() copies the reference
    # limits onto the linked axis.
    limits = [ax.get_ylim() for ax in axts_all]
    reference = axts_all[0]
    grouper = reference.get_shared_y_axes()
    if hasattr(grouper, 'join'):
        grouper.join(*axts_all)
    else:
        for ax in axts_all[1:]:
            if ax is not reference and not grouper.joined(reference, ax):
                ax.sharey(reference)
    tops = {lim[-1] for lim in limits}
    if len(tops) > 1:
        try:
            # Use the union over ALL axes (a pairwise union missed ranges of
            # later axes) and apply mult_val exactly once.
            lower = np.min([lim[0] for lim in limits])
            upper = np.max([lim[1] for lim in limits])
            for ax in axts_all:
                ax.set_ylim(lower, upper * mult_val)
        except Exception:
            print('joiny something')
def find_peaks_simple(eodf, freq1, freq2, name, color1, color2):
    """Return the peak frequencies to annotate in a spectrum of condition ``name``.

    ``freq1``/``freq2`` are difference frequencies relative to ``eodf``, so the
    stimulus frequencies are ``freq + eodf``. Conditions '0' and '01' are
    annotated identically. For '012' the colours are fixed to blue/red rather
    than ``color1``/``color2``. Any other ``name`` fails with
    UnboundLocalError.

    Returns
    -------
    labels, alpha, colors_peaks, freqs : list
    """
    if name in ('01', '0'):
        labels, alpha, colors_peaks, freqs = (['DF1', 'EODf', 'F1'],
                                              [1, 1, 0.5],
                                              [color1, 'black', color1],
                                              [np.abs(freq1), eodf, freq1 + eodf])
    elif name == '02':
        labels, alpha, colors_peaks, freqs = (['DF2', 'EODf', 'F2'],
                                              [1, 1, 0.5],
                                              [color2, 'black', color2],
                                              [np.abs(freq2), eodf, freq2 + eodf])
    elif name == '012':
        labels, alpha, colors_peaks, freqs = (['DF1', 'DF2', 'EODf', 'F1', 'F2'],
                                              [1, 1, 1, 0.5, 0.5],
                                              ['blue', 'red', 'black', 'blue', 'red'],
                                              [np.abs(freq1), np.abs(freq2), eodf, freq1 + eodf, freq2 + eodf])
    return labels, alpha, colors_peaks, freqs
def plt_single_contrast(f_counter, axts, axps, f, grid_ll, c_nn, freq1, freq2, eodf, auci_wo, auci_w,
                        results_diff, a_f2s, fish_jammer, a,
                        trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing, stimulus_length,
                        model_cells, position_diff, colors_array, reshuffled, sampling, cell_here, c_nr, n,
                        dev_name, c_nn_nr=1, xpos=1, ypos=1.35, small_peaks=True, second=True, ws=0.6,
                        start=1, legend=False, v_mem_choice=True):
    """Simulate one contrast of the three-wave paradigm and plot traces plus power spectra.

    The model cell ``cell_here`` is simulated twice with
    ``calc_roc_amp_core_cocktail_for_plot``:
    - once with ``dev=0.0005``; its ``arrays`` give the length of the time axis
      and the traces when ``v_mem_choice`` is False;
    - once with ``dev='original'``; its ``arrays_orig`` feed the power spectra.
      All other outputs of the second call overwrite those of the first.

    For every condition from index ``start`` on (names '0', '01', '02', '012'),
    a time-trace axis and a PSD axis are created side by side in cell
    ``grid_ll[a, c_nn]``. They are appended to ``axts`` and ``axps``.

    Parameters (selection)
    ----------------------
    f_counter : int
        Running counter. The condition name is set as y-label only while it is 0.
    axts, axps : list
        Lists the created time / spectrum axes are appended to (mutated).
    f : int
        y-ticks of the time axes are removed when ``f != 0``.
    a : int
        Index into ``arrays`` used only to size the time axis. It is shadowed
        by the plotting loop afterwards.
    c_nn : int
        Column of ``grid_ll``. y-ticks are removed for ``c_nn != 0``.
    start : int
        First condition to plot (1 skips the baseline '0').
    v_mem_choice : bool
        Plot membrane voltages (True) or the simulation ``arrays`` (False).

    Returns
    -------
    int
        ``f_counter + 1``.
    """
    # First simulation: dev=0.0005 is passed positionally as ``dev``.
    array_mat, v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays, names = calc_roc_amp_core_cocktail_for_plot(
        [freq1 + eodf], [freq2 + eodf], auci_wo, auci_w, results_diff, a_f2s, fish_jammer, trials_nr, nfft, us_name,
        gain, runs, a_fr, nfft_for_morph, beat, printing, stimulus_length,
        model_cells, position_diff, 0.0005, cell_here, dev_name=dev_name, a_f1s=[c_nr], n=n,
        reshuffled=reshuffled)
    # Second simulation with dev='original'; only ``arrays_orig`` is new, the
    # remaining outputs replace those of the first call.
    array_mat, v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays_orig, names = calc_roc_amp_core_cocktail_for_plot(
        [freq1 + eodf], [freq2 + eodf], auci_wo, auci_w, results_diff, a_f2s, fish_jammer, trials_nr, nfft, us_name,
        gain, runs, a_fr, nfft_for_morph, beat, printing, stimulus_length,
        model_cells, position_diff, 'original', cell_here, dev_name=dev_name, a_f1s=[c_nr], n=n,
        reshuffled=reshuffled)
    # Time axis in seconds; ``a`` is still the function argument here.
    time = np.arange(0, len(arrays[a][0]) / sampling, 1 / sampling)
    # Traces for the left-hand (time) axes.
    if v_mem_choice:
        arrays_here = v_mems[start::]
    else:
        if start == 1:
            arrays_here = [arrays[start::][0][0], arrays[start::][1][0], arrays[start::][2][0]]
        else:
            arrays_here = [arrays[start::][0][0], arrays[start::][1][0], arrays[start::][2][0], arrays[start::][3][0]]
    # Inputs for the right-hand (power spectrum) axes.
    if start == 1:
        arrays_here_psd = [arrays_orig[start::][0][0], arrays_orig[start::][1][0], arrays_orig[start::][2][0]]
    else:
        arrays_here_psd = [arrays_orig[start::][0][0], arrays_orig[start::][1][0], arrays_orig[start::][2][0],
                           arrays_orig[start::][3][0]]
    names = np.array(['0', '01', '02', '012'])[start::]
    spikes_here = arrays_spikes[start::]
    colors_array_here = colors_array[start::]
    pps = []
    # One row per condition: time trace (left) and power spectrum (right).
    for a in range(len(arrays_here)):
        grid_pt = gridspec.GridSpecFromSubplotSpec(1, 2,
                                                   hspace=0.45,
                                                   wspace=ws, width_ratios=[1, 1.2],
                                                   subplot_spec=grid_ll[a, c_nn])  # hspace=0.4,wspace=0.2,len(chirps)
        axt = plt.subplot(grid_pt[0])
        axp = plt.subplot(grid_pt[1])
        # Spike raster drawn at the top of the trace.
        if v_mem_choice:
            axt.eventplot(spikes_here[a], lineoffsets=np.max(arrays_here[a]), color='black')  # np.max(v1)*
        else:
            axt.eventplot(spikes_here[a], lineoffsets=np.max(arrays_here[a]), linelengths=np.max(arrays_here[a]) * 0.1,
                          color='black')  # np.max(v1)*
        axts.append(axt)
        axps.append(axp)
        if f != 0:
            remove_yticks(axt)
        if a != len(arrays_here) - 1:
            remove_xticks(axt)
        if f_counter == 0:
            axt.set_ylabel(names[a])
        # Sets the amplitude title (only for start == 1) and plots the trace.
        set_amplitude_titles(a, a_f2s, arrays_here, axt, c_nr, colors_array_here, start, time)
        axt.set_xlim(0.1, 0.22)
        # mlab.psd returns (power spectral density, frequencies) of the
        # mean-subtracted signal.
        try:
            pp, ff = ml.psd(arrays_here_psd[a] - np.mean(arrays_here_psd[a]), Fs=sampling, NFFT=nfft,
                            noverlap=nfft // 2)
        except:
            print('pp problems')
            embed()
        pps.append(pp)
        axp.plot(ff, pp, color=colors_array_here[a])  # colors_contrasts[c_nn]
        maxx = 1000
        axp.set_xlim(-10, maxx)
        # Frequencies to mark in the spectrum; note the two helpers return
        # their values in a different order.
        if small_peaks:
            labels, alpha, colors_peaks, freqs = find_peaks_simple(eodf, freq1, freq2, names[a], colors_array[1],
                                                                   colors_array[2])
        else:
            alpha, labels, colors_peaks, freqs = mult_beat_freqs(eodf, maxx, freq1, color_df_mult=colors_array[1],
                                                                 color_eodf='black',
                                                                 color_stim='orange',
                                                                 color_stim_mult='pink', )
        set_titles_freqs(a, axt, c_nn, c_nn_nr, eodf, freq1, freq2, second, start, xpos, ypos)
        plt_peaks_several(freqs, [pp], axp, pp, ff, labels, 0, colors_peaks, alphas=alpha)
        # NOTE(review): hard-coded 2 assumes three rows (start == 1); with
        # start == 0 the x-ticks of the last spectrum are removed as well.
        if a != 2:
            remove_xticks(axp)
        if c_nn != 0:
            remove_yticks(axt)
            remove_yticks(axp)
        else:
            # Overrides the condition-name y-label set above.
            axt.set_ylabel('Hz')
            axp.set_ylabel('Hz')
        if legend:
            axp.legend(loc=(-0.3, 1.2), ncol=3)
    # x-labels only on the axes created last (bottom row).
    axt.set_xlabel('Time [s]')
    axp.set_xlabel('Frequency [Hz]')
    try:
        f_counter += 1
    except:
        print('counter thing')
        embed()
    return f_counter
def set_amplitude_titles(a, a_f2s, arrays_here, axt, c_nr, colors_array_here, start, time):
    """Title ``axt`` with the stimulus amplitudes of condition ``a`` and plot its trace.

    A title is set only when ``start == 1``, i.e. when row 0 is condition '01',
    row 1 is '02' and row 2 is '012'. Amplitude 1 is ``c_nr`` (0 for row 1).
    Amplitude 2 is ``a_f2s[0]`` (0 for row 0). The trace ``arrays_here[a]`` is
    always plotted against ``time`` in ``colors_array_here[a]``.
    """
    if start == 1:
        amplitude_1 = '0' if a == 1 else str(c_nr)
        amplitude_2 = '0' if a == 0 else str(a_f2s[0])
        axt.set_title(' Amplitude 1 = ' + amplitude_1 + ', Amplitude 2 = ' + amplitude_2)
    try:
        axt.plot(time, arrays_here[a], color=colors_array_here[a])  # colors_contrasts[c_nn]
    except:
        print('time something')
        embed()
def set_titles_freqs(a, axt, c_nn, c_nn_nr, eodf, freq1, freq2, second, start, xpos, ypos):
    """Write a bold heading describing the stimulus frequencies above axis ``axt``.

    Only acts in column ``c_nn == c_nn_nr`` and when ``start == 1`` (rows
    0/1/2 = F1 only / F2 only / F1 + F2). With ``second`` the heading gets a
    second line with the absolute (``freq + eodf``) and difference
    frequencies. The text is placed at ``(xpos, ypos)`` in axes coordinates.
    """
    if not (c_nn == c_nn_nr and start == 1):
        return
    if a == 0:
        heading = 'Only Frequency 1 (F1): \n'
        detail = ('F1=' + str(np.round(int(freq1 + eodf))) + 'Hz' + ' DF1=F1-EODf=' + str(
            int(np.round(freq1))) + 'Hz') if second else ''
    elif a == 1:
        heading = 'Only Frequency 2 (F2): \n'
        detail = ('F2=' + str(np.round(int(freq2 + eodf))) + 'Hz ' + 'F2-EODf=' + str(
            int(np.round(freq2))) + ' Hz ') if second else ''
    else:
        heading = 'Frequency 1 (F1) + Frequency 2 (F2): \n'
        detail = ('F1=' + str(int(np.round(freq1 + eodf))) + 'Hz' + ' F1-EODf=' + str(
            int(np.round(freq1))) + 'Hz' + ' F2=' + str(int(freq2 + eodf)) + 'Hz ' + 'DF2=F2-EODf=' + str(
            int(np.round(freq2))) + ' Hz ') if second else ''
    axt.text(xpos, ypos, heading + detail, fontweight='bold', ha='center', fontsize=10,
             transform=axt.transAxes)
def find_diffs(c, frame_cell, diffs, add=''):
    """Store the extra beat amplitude caused by the second fish in ``frame_cell['diff']``.

    For ``c == 'c1'`` the column is
    ``sqrt(amp_B1_012_mean**2 - amp_B1_01_mean**2) * diffs``. For any other
    ``c`` the B2 columns (``amp_B2_012_mean`` / ``amp_B2_02_mean``) are used.
    ``diffs`` is presumably the step width per row (e.g. from
    ``find_deltas``). TODO confirm with callers.

    Older files stored the columns with a different suffix, so the suffixes
    ``add``, ``'_original'`` and ``''`` are tried in this order.

    Returns
    -------
    The same DataFrame ``frame_cell``, modified in place.

    Raises
    ------
    KeyError
        If the required columns exist under none of the suffixes.
    """
    if c == 'c1':  # 'B1_diff'
        both_key, single_key = 'amp_B1_012_mean', 'amp_B1_01_mean'
    else:  # 'B2_diff'
        both_key, single_key = 'amp_B2_012_mean', 'amp_B2_02_mean'
    last_error = None
    for suffix in (add, '_original', ''):
        # Only a missing column triggers the fallback; any other error (the
        # previous bare except also swallowed those) now propagates.
        try:
            both = frame_cell[both_key + suffix]
            single = frame_cell[single_key + suffix]
        except KeyError as err:
            last_error = err
            continue
        frame_cell['diff'] = np.sqrt(both ** 2 - single ** 2) * diffs
        return frame_cell
    raise last_error
def find_deltas(frame_cell, c):
    """Return the step size between consecutive values of column ``c``.

    The result has the same length as the column: the last step is repeated
    once at the end. Fewer than two rows raise IndexError.
    """
    steps = np.diff(frame_cell[c])
    return np.append(steps, steps[-1])
def find_dfs(frame_cell):
    """Collect the unique stimulus frequencies and (re)compute the difference frequencies.

    The columns ``df1 = f1 - f0`` and ``df2 = f2 - f0`` are always recomputed,
    in case they were stored incorrectly. ``df1s``/``df2s`` subtract the
    unique ``f0`` values from the unique ``f1``/``f2`` values. This broadcasts
    as intended only when ``f0`` has a single unique value.

    Returns
    -------
    frame_cell, df1s, df2s, f1s, f2s
    """
    f1s = np.unique(frame_cell.f1)
    f2s = np.unique(frame_cell.f2)
    eodf_values = frame_cell.f0.unique()
    df1s = f1s - eodf_values
    df2s = f2s - eodf_values
    # recomputed in case df was calculated incorrectly before
    frame_cell['df1'] = frame_cell.f1 - frame_cell.f0
    frame_cell['df2'] = frame_cell.f2 - frame_cell.f0
    return frame_cell, df1s, df2s, f1s, f2s
def plt_single_trace(ax_upper, ax_u1, frame_cell_orig, freq1, freq2, add='', B_replace='DF', sum=True,
                     add_label='',
                     alpha=(1, 1, 1, 1), linewidths=(), xscale='', c_dist_recalc=True, lw=1.5,
                     linestyles=('-', '-', '-', '-', '-'), nr=4, labels=(),
                     scores=('amp_B1_01_mean_original', 'amp_B1_012_mean_original', 'amp_B2_02_mean_original',
                             'amp_B2_012_mean_original'),
                     colors=('green', 'blue', 'orange', 'red', 'grey'), lim_recalc=(0, 70), default_colors=True,
                     delta=True):
    """Plot beat amplitudes of one (df1, df2) combination against contrast or distance.

    Rows of ``frame_cell_orig`` with ``df1 == freq1`` and ``df2 == freq2`` are
    selected. Every column in ``scores`` is drawn on ``ax_u1`` over the x
    values from ``c_dist_recalc_here``.

    Parameters (selection)
    ----------------------
    ax_upper : list
        List the axis ``ax_u1`` is appended to (mutated and returned).
    add : str
        Column suffix of the summed curves when ``default_colors`` is False
        (columns ``'amp_*_mean' + add``).
    sum : bool
        Also draw the root-sum-square of the B1 and B2 amplitudes. The name
        shadows the builtin; it is kept for backward compatibility.
    default_colors : bool
        Take labels, colours, line styles, scores and widths from
        ``colors_susept`` instead of the keyword arguments.
    c_dist_recalc : bool
        x axis is the contrast converted to a distance in cm (True) or the
        contrast in percent (False).

    Returns
    -------
    list
        ``ax_upper`` with ``ax_u1`` appended.
    """
    # Column suffix used for the summed curves further below.
    sum_suffix = add
    if default_colors:
        if 'amp_B1_01_mean_original' in frame_cell_orig.keys():
            add = '_mean_original'
        else:
            add = '_mean'
        labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept(
            add=add, nr=nr)
        # ``add`` now already contains '_mean'. Previously the summed curves
        # looked up a doubled name such as 'amp_B2_012_mean_mean_original'.
        sum_suffix = add[len('_mean'):]
    frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)]
    if len(labels) < 1:
        labels = scores  # .replace('amp_','')
    c1 = c_dist_recalc_here(c_dist_recalc, frame_cell)
    for sss, score in enumerate(scores):
        if len(linewidths) > 1:
            lw = linewidths[sss]
        try:
            label = labels[sss].replace('_mean', '').replace('amp_', '').replace('B', B_replace).replace(
                'original', '').replace('distance', 'dist').replace('power', '') + add_label
            ax_u1.plot(c1, frame_cell[score], zorder=100, color=colors[sss],
                       alpha=alpha[sss],
                       label=label,
                       linestyle=linestyles[sss], linewidth=lw)
        except Exception:
            print('linestyle problem')
            embed()
    ax_upper.append(ax_u1)
    if sum:
        ax_u1.plot(c1, np.sqrt(frame_cell['amp_B2_012_mean' + sum_suffix] ** 2 + frame_cell[
            'amp_B1_012_mean' + sum_suffix] ** 2),
                   zorder=100, color='grey',
                   label='B1+B2_012', linestyle='--')
        ax_u1.plot(c1, np.sqrt(frame_cell['amp_B2_02_mean' + sum_suffix] ** 2 + frame_cell[
            'amp_B1_01_mean' + sum_suffix] ** 2),
                   zorder=100,
                   color='black',
                   label='B1_01+B2_02', linestyle='-')
    if c_dist_recalc:
        if lim_recalc:
            ax_u1.set_xlim(lim_recalc)
        ax_u1.set_xlabel('C1 Distance [cm]')
    else:
        # raw string: backslash-percent is not a valid Python escape sequence
        ax_u1.set_xlabel(r'Contrast$_{1}$ [$\%$]')
    if xscale == 'log':
        ax_u1.set_xscale('log')
    ax_u1.set_ylabel(representation_ylabel(delta=delta))
    return ax_upper
def c_dist_recalc_here(c_dist_recalc, frame_cell):
    """Return the x values for ``frame_cell``: distances, or contrasts in percent.

    The first unique ``cell`` of the frame is passed to ``c_dist_recalc_func``.
    If ``c_dist_recalc`` is falsy, its contrasts are multiplied by 100.
    """
    cell_name = frame_cell.cell.unique()[0]
    x_values = c_dist_recalc_func(frame_cell, cell=cell_name, c_dist_recalc=c_dist_recalc)
    if c_dist_recalc:
        return x_values
    return np.array(x_values) * 100
def representation_ylabel(delta=True):
    """Return the y-label for beat amplitudes in Hz.

    With ``delta`` the amplitude is written as a function of the difference
    frequency, otherwise of the frequency itself. A raw string is used because
    the LaTeX command for Delta would otherwise be an invalid Python escape
    sequence (SyntaxWarning); the returned text is unchanged.
    """
    if delta:
        return r'Amplitude $A(\Delta f)$ [Hz]'
    return 'Amplitude $A(f)$ [Hz]'
def c_dist_recalc_func(frame_cell=[], mult_eod=0.5, c_nrs=[], cell=[], eod_size_change=True, c_dist_recalc=True,
                       recalc_contrast_in_perc=1):
    """Convert contrasts to distances (or just rescale them).

    Contrasts are ``c_nrs`` if given, otherwise the ``c1`` column of
    ``frame_cell``.

    - ``c_dist_recalc`` falsy: return ``np.array(contrasts) *
      recalc_contrast_in_perc``.
    - Otherwise: return ``c_to_dist(eod_size * mult_eod * contrasts)``. The
      factor ``mult_eod`` (0.5) scales the EOD size (from ``load_eod_size``) to
      the maximal value; the conversion is based on Henninger 2020.

    If ``load_eod_size`` fails, ``update_ssh_file()`` is called and the load
    is retried once. If that also fails, ``eod_size = 1`` is used and a
    warning is printed. With ``eod_size_change`` False, ``eod_size`` is 1.
    NOTE(review): it is unclear whether this is meant for mV or contrast
    values.
    """
    contrasts = frame_cell.c1 if len(c_nrs) < 1 else np.array(c_nrs)
    if not c_dist_recalc:
        return np.array(contrasts) * recalc_contrast_in_perc  # frame_cell.c1
    eod_size = 1
    if eod_size_change:
        try:
            _, _, eod_size = load_eod_size(cell, max='perc')
        except:
            try:
                update_ssh_file()
                _, _, eod_size = load_eod_size(cell, max='perc')
            except:
                print('EODF SIZE ESTIMATION BIASED')
                eod_size = 1
    print('eodfsize' + str(eod_size))
    return c_to_dist(eod_size * mult_eod * contrasts)
def calc_cv_three_wave(results_diff, position_diff, arrays=(), adds=()):
    """Store ISI statistics of the first trial of each condition in ``results_diff``.

    For each suffix ``add`` in ``adds`` (e.g. '_0', '_01', '_02', '_012'), the
    inter-spike intervals of the first trial ``arrays[a][0]`` are summarised
    into row ``position_diff``. The columns are ``'cv'``, ``'std_isi'``,
    ``'mean_isi'``, ``'median_isi'``, ``'burst_1'`` and ``'burst_2'``, each
    with ``add`` appended. The burst values come from ``calc_burst_perc``,
    which gets the row's ``'f0'`` value.

    Returns
    -------
    The same DataFrame ``results_diff``, modified in place.
    """
    for a, add in enumerate(adds):
        # Only the first trial is analysed, regardless of the trial count
        # (the former single-/multi-trial branches were identical).
        isi = np.diff(arrays[a][0])
        try:
            results_diff.loc[position_diff, 'cv' + add] = np.std(isi) / np.mean(isi)
        except Exception:
            print('ROC problem')
            embed()
        results_diff.loc[position_diff, 'std_isi' + add] = np.std(isi)
        results_diff.loc[position_diff, 'mean_isi' + add] = np.mean(isi)
        # The median used to be written to 'mean_isi' too, overwriting the mean
        # stored on the previous line; it now has its own column.
        results_diff.loc[position_diff, 'median_isi' + add] = np.median(isi)
        burst_1, burst_2 = calc_burst_perc(results_diff.loc[position_diff, 'f0'], isi)
        results_diff.loc[position_diff, 'burst_1' + add] = burst_1
        results_diff.loc[position_diff, 'burst_2' + add] = burst_2
    return results_diff
def deltat_sampling_factor(sampling_factor, deltat, eod_fr):
    """Resolve the simulation time step.

    ``'EODmult'`` gives ``1 / (2 * eod_fr)``. Any other non-empty
    ``sampling_factor`` is used as the time step itself. ``''`` keeps
    ``deltat``.
    """
    if sampling_factor == 'EODmult':
        return 1 / (eod_fr * 2)
    if sampling_factor != '':
        return sampling_factor
    return deltat
def eod_fish_e_generation(time_array, a_fe=0.2, eod_fe=[750], e=0, phaseshift_fr=0, sampling=20000, stimulus_length=1,
                          nfft_for_morph=2 ** 14, cell_recording='', fish_morph_harmonics_var='analyze', zeros='zeros',
                          mimick='no', fish_emitter='Alepto', fish_jammer='Alepto', thistype='emitter'):
    """Generate the EOD of a stimulating fish (emitter or jammer).

    The frequency is ``eod_fe[e]``. By default the EOD is a sine of amplitude
    ``a_fe``. If ``mimick`` contains 'Emitter' (for ``thistype='emitter'``) or
    'Jammer' (for ``thistype='jammer'``), a realistic waveform is used instead:
    - with 'Wavemorph', the waveform from ``thunder_morph_func``;
    - otherwise the ``fakefish.wavefish_eods`` waveform of ``fish_emitter`` /
      ``fish_jammer``, centred and normalised to ``a_fe`` when ``mimick``
      contains 'Zenter' but not 'NotZentered'.

    An absent fish (``a_fe == 0``) yields ones when ``zeros != 'zeros'``.
    Otherwise the absolute value of the generated signal is used, so no
    negative zeros remain.

    Returns
    -------
    eod_fish_e : np.ndarray
        The EOD signal.
    time_fish_e : np.ndarray
        The phase argument ``time_array * 2 * pi * eod_fe[e]`` of the sine model.
    """
    frequency = eod_fe[e]
    time_fish_e = time_array * 2 * np.pi * frequency
    if (a_fe == 0) and (zeros != 'zeros'):
        return np.ones(len(time_array)), time_fish_e

    mimic_emitter = ('Emitter' in mimick) and (thistype == 'emitter')
    mimic_jammer = ('Jammer' in mimick) and (thistype == 'jammer')
    if mimic_emitter or mimic_jammer:
        species = fish_emitter if mimic_emitter else fish_jammer
        if 'Wavemorph' in mimick:
            (_input, _eod_fr_data, _data_array_eod, _time_data_eod, eod_fish_e, _pp, _ff, _p_array_new, _f_new,
             _amp, _phase, _b, _t) = thunder_morph_func(
                phaseshift_fr, cell_recording, frequency, sampling, stimulus_length, a_fe, nfft_for_morph,
                fish_morph_harmonics_var=fish_morph_harmonics_var)
        else:
            eod_fish_e = fakefish.wavefish_eods(species, frequency=frequency, samplerate=sampling,
                                                duration=stimulus_length, phase0=0.0, noise_std=0.00)
            # NOTE(review): centring applies to the fakefish waveform only,
            # not to the Wavemorph output.
            if ('Zenter' in mimick) and ('NotZentered' not in mimick):
                eod_fish_e = zenter_and_normalize(eod_fish_e, a_fe)
    else:
        eod_fish_e = a_fe * np.sin(time_fish_e)
    # some of the zeros can be negative (-0.0), so make them all positive
    if (a_fe == 0) and (zeros == 'zeros'):
        eod_fish_e = np.abs(eod_fish_e)
    return eod_fish_e, time_fish_e
def deltat_choice(model_params, sampling_factor, eod_fr, stimulus_length, deltat=None):
    """Pick the simulation time step and build the matching time axis.

    If ``deltat`` is falsy (None or 0), it is popped from ``model_params``, so
    the key is removed from the caller's object. ``deltat_sampling_factor``
    may then override the step.

    Returns
    -------
    time_array, sampling, deltat
        ``np.arange(0, stimulus_length, deltat)``, ``1 / deltat`` and the step.
    """
    step = deltat if deltat else model_params.pop("deltat")
    step = deltat_sampling_factor(sampling_factor, step, eod_fr)
    return np.arange(0, stimulus_length, step), 1 / step, step
def get_arrays_for_three(cell, a_f2, a_f1, SAM, eod_stimulus, eod_fish_r, freq2, eod_fish1, eod_fish2, stimulus_length,
                         offset, model_params, n, variant, adapt_offset, deltat, f2, trials_nr, time_array, f1,
                         freq1, eod_fr, reshuffle='reshuffled', length_adapt=True, dev=0.0005, zeros='', a_fr=1,
                         params_dict=None, redo_stim=True, nfft='', cell_recording='', phaseshift_fr=0,
                         fish_emitter='', fish_receiver='', beat='', fish_jammer='', fish_morph_harmonics_var='',
                         mimick='', nfft_for_morph='', phase_right='', sampling=''):
    """Simulate the model cell for the three stimulus conditions '01', '02' and '012'.

    ``do_array_for_three`` is called once per condition with identical
    arguments apart from ``stim_type``. ``params_dict`` (passed on as
    ``dict_here``) gets the key ``'eod_fish1'`` set to ``eod_fish1``. A dict
    supplied by the caller is modified in place. Without one, a fresh
    ``{'burst_corr': ''}`` is created per call.

    Returns
    -------
    tuple
        ``(v_mems, offset_new, mat01, mat02, mat012, smoothed01, smoothed02,
        smoothed012, stimulus_01, stim_02, stimulus_012, meansmoothed05_01,
        spikes_01, meansmoth05_02, spikes_02, meansmoothed05_012, spikes_012)``.
        ``v_mems`` is an array of the three membrane-voltage outputs.
        ``offset_new`` is the one returned for condition '012'.
    """
    if params_dict is None:
        # A fresh dict per call: the former mutable default was modified below
        # and so carried eod_fish1 over from one call to the next.
        params_dict = {'burst_corr': ''}
    params_dict['eod_fish1'] = eod_fish1

    def _simulate(stim_type):
        # All three conditions share every argument except stim_type.
        return do_array_for_three(
            nfft, a_fr, fish_receiver, beat, zeros, cell_recording, fish_emitter, fish_jammer,
            fish_morph_harmonics_var, mimick, nfft_for_morph, phase_right, sampling, eod_fish2, SAM,
            eod_stimulus, eod_fish_r, freq2, a_f1, a_f2, cell, stimulus_length, offset, model_params, n,
            adapt_offset, deltat, f2, trials_nr, time_array, f1, freq1, [], phaseshift_fr, variant, eod_fr,
            length_adapt=length_adapt, dict_here=params_dict, redo_stim=redo_stim, stim_type=stim_type,
            reshuffle=reshuffle, dev=dev)

    stimulus_01, meansmoothed05_01, spikes_01, smoothed01, mat01, _, v_mem_output_01 = _simulate('01')
    stim_02, meansmoth05_02, spikes_02, smoothed02, mat02, _, v_mem_02 = _simulate('02')
    stimulus_012, meansmoothed05_012, spikes_012, smoothed012, mat012, offset_new, v_mem_output_012 = _simulate(
        '012')
    print(offset)
    return np.array([v_mem_output_01, v_mem_02,
                     v_mem_output_012]), offset_new, mat01, mat02, mat012, smoothed01, smoothed02, smoothed012, stimulus_01, stim_02, stimulus_012, meansmoothed05_01, spikes_01, meansmoth05_02, spikes_02, meansmoothed05_012, spikes_012
def calc_roc_amp_core_cocktail_for_plot(freq1, freq2, auci_wo, auci_w, results_diff, a_f2s, fish_jammer, trials_nr,
                                        nfft,
                                        us_name, gain, runs, a_fr, nfft_for_morph, beat,
                                        printing, stimulus_length, model_cells, position_diff, dev, cell_here,
                                        a_f1s=[], dev_name='05', n=1, reshuffled='',
                                        test=False, SAM='_SAM_'):
    """Simulate a model cell for baseline, single-fish (01, 02) and two-fish (012) stimuli.

    The model parameters of ``cell_here`` are taken from ``model_cells``.
    First a baseline (receiver EOD only) is simulated for ``trials_nr`` trials.
    The offset found in the first trial is reused for all later trials and
    conditions. Then, for every combination of ``a_f2s`` and ``a_f1s``, the
    three stimulus conditions are simulated with ``get_arrays_for_three``. The
    firing rate, frequencies, contrasts and ISI statistics are written to row
    ``position_diff`` of ``results_diff``.

    ``freq1``/``freq2`` are lists of absolute stimulus frequencies; element 0
    is stored.

    Returns
    -------
    array_mat, v_mems, arrays_spikes, arrays_stim, results_diff, position_diff,
    auci_wo, auci_w, arrays, names
        Values from the last loop iteration. ``auci_wo``/``auci_w`` are
        returned unchanged.

    NOTE(review): ``test`` is reset to False inside the trial loop, so the
    later ``if test:`` plots never run. ``dev_name`` is iterated element-wise;
    for a string such as the default '05' that yields '0' and '5', which match
    neither '05' nor 'original'. The block structure here was reconstructed
    from unindented source; verify the loop nesting.
    """
    model_params = model_cells[model_cells['cell'] == cell_here].iloc[0]
    eod_fr = model_params['EODf']  # .iloc[0]
    # pop() removes these keys so the remaining entries can be passed as
    # **model_params to simulate().
    offset = model_params.pop('v_offset')
    cell = model_params.pop('cell')
    print(cell)
    f1 = 0
    f2 = 0
    sampling_factor = ''
    phaseshift_fr = 0
    cell_recording = ''
    mimick = 'no'
    zeros = 'zeros'
    fish_morph_harmonics_var = 'harmonic'
    fish_emitter = 'Alepto'  # ['Sternarchella', 'Sternopygus']
    fish_receiver = 'Alepto'  #
    phase_right = '_phaseright_'
    damping = 0.45  # 0.65,0.2,0.5,0.2,0.6,0.45,0.6,0.35
    damping_type = ''
    exponential = ''
    # in case you want a different sampling, it can be adjusted here
    # (deltat_choice also pops 'deltat' from model_params)
    time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr,
                                                 stimulus_length)
    # generate the eod_fish_r in the four mimick variants (copy, thunderfish, mimick, just sinus)
    eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr, stimulus_length,
                                                                   phaseshift_fr, cell_recording, zeros, mimick,
                                                                   sampling, fish_receiver, deltat, nfft,
                                                                   nfft_for_morph,
                                                                   fish_morph_harmonics_var=fish_morph_harmonics_var,
                                                                   beat=beat)
    sampling = 1 / deltat
    variant = 'sinz'
    if exponential == '':
        pass  # no-op
    # prepare for adapting offset due to baseline modification
    _, _ = prepare_baseline_array(time_array, eod_fr, nfft_for_morph,
                                  phaseshift_fr, mimick, zeros,
                                  cell_recording, sampling,
                                  stimulus_length, fish_receiver, deltat,
                                  nfft, damping_type, damping, us_name,
                                  gain, beat=beat,
                                  fish_morph_harmonics_var=fish_morph_harmonics_var)
    # The entries alias one list object, but each is replaced by assignment
    # below, so the aliasing is harmless.
    spikes_base = [[]] * trials_nr
    for run in range(runs):
        print(run)
        # Baseline: receiver EOD only.
        for t in range(trials_nr):
            stimulus_0 = eod_fish_r
            adapt_offset = 'adaptoffset_bisecting'
            # NOTE(review): ``alpha`` is not a local here; it resolves to the
            # module-level name (``scipy.stats.alpha`` from the file imports
            # unless a star import shadows it). Verify power_alpha is intended.
            cvs, adapt_output, baseline_after_b, _, rate_adapted_b, rate_baseline_before_b, rate_baseline_after_b, \
            spikes_base[t], _, _, offset_new, v_mem0, noise_final = simulate(cell, offset, stimulus_0, deltat=deltat,
                                                                             adaptation_variant=adapt_offset,
                                                                             adaptation_yes_j=f2, adaptation_yes_e=f1,
                                                                             adaptation_yes_t=t, power_alpha=alpha,
                                                                             power_nr=n, reshuffle=reshuffled,
                                                                             **model_params)
            print(' offset orig ' + str(offset))
            test = False
            if test:
                from utils_test import test_cvs2
                test_cvs2()
            if t == 0:
                # here we record the changes in the offset due to the adaptation
                # and we subsequently reset the offset to be the new adapted for all subsequent trials
                offset = offset_new * 1
                print(' Base ' + str(adapt_offset) + ' offset ' + str(offset))
        # NOTE(review): on the first run ``t1`` has not been assigned yet (it
        # is set in the amplitude loop below), so this raises NameError when
        # ``printing`` is truthy, unless a global ``t1`` exists.
        if printing:
            print('Baseline time' + str(time.time() - t1))
        _, _ = find_base_fr(spikes_base, deltat, stimulus_length, time_array, dev=dev)
        length_adapt = False
        base_cut, mat_base, smoothed0, mat0 = find_base_fr2(spikes_base, deltat, stimulus_length,
                                                            dev=dev, length_adapt=length_adapt)
        fr = np.mean(base_cut)
        _, _ = ISI_frequency(time_array, spikes_base[0], fill=0.0)
        for aaa, a_f2 in enumerate(a_f2s):  # [0]
            for aa, a_f1 in enumerate(a_f1s):  # [0]
                t1 = time.time()
                phaseshift_f1, phaseshift_f2 = get_phaseshifts(a_f1, a_f2, phase_right, phaseshift_fr)
                eod_fish1, time_fish_e = eod_fish_e_generation(time_array, a_f1, freq1, f1, phaseshift_f1, sampling,
                                                               stimulus_length, nfft_for_morph, cell_recording,
                                                               fish_morph_harmonics_var, zeros, mimick, fish_emitter,
                                                               thistype='emitter')
                # NOTE(review): positionally, fish_jammer lands in the
                # ``fish_emitter`` parameter of eod_fish_e_generation; the
                # jammer branch there reads its own ``fish_jammer`` default.
                # Irrelevant while mimick == 'no'.
                eod_fish2, time_fish_j = eod_fish_e_generation(time_array, a_f2, freq2, f2, phaseshift_f2, sampling,
                                                               stimulus_length, nfft_for_morph, cell_recording,
                                                               fish_morph_harmonics_var, zeros, mimick, fish_jammer,
                                                               thistype='jammer')
                eod_stimulus = eod_fish1 + eod_fish2
                if test:
                    eod_stimulus_d = eod_fish1 + eod_fish2
                    fig, ax = plt.subplots(2, 1, sharex=True, sharey=True)
                    ax[0].plot(time_array, eod_stimulus_d)
                    ax[1].plot(time_array, eod_stimulus)
                    plt.show()
                if test:
                    fig, ax = plt.subplots(4, 1, sharex=True, sharey=True)
                    ax[0].plot(time_array, eod_fish_r)
                    ax[1].plot(time_array, eod_fish2)
                    ax[2].plot(time_array, eod_stimulus)
                    ax[3].plot(time_array, eod_stimulus)
                    ax[0].set_xlim(0, 0.1)
                    plt.show()
                adapt_offset_later = ''
                # Simulate conditions 01, 02 and 012 with the adapted offset.
                v_mem, offset_new, mat01, mat02, mat012, smoothed01, smoothed02, smoothed012, stimulus_01, stimulus_02, stimulus_012, mat05_01, spikes_01, mat05_02, spikes_02, mat05_012, spikes_012 = get_arrays_for_three(
                    cell, a_f2, a_f1, SAM, eod_stimulus, eod_fish_r, freq2, eod_fish1, eod_fish2, stimulus_length,
                    offset, model_params, n, variant, adapt_offset_later, deltat, f2, trials_nr, time_array, f1,
                    freq1, eod_fr, reshuffle=reshuffled, length_adapt=length_adapt, dev=dev)
                if test:
                    fig, ax = plt.subplots(4, 1, sharex=True, sharey=True)
                    ax[0].plot(stimulus_0)
                    ax[1].plot(stimulus_01)
                    ax[2].plot(stimulus_02)
                    ax[3].plot(stimulus_012)
                    plt.show()
                if printing:
                    print('Generation process' + str(time.time() - t1))
                # Membrane voltages in the order 0, 01, 02, 012.
                v_mems = np.concatenate([[v_mem0], v_mem])
                array0 = [mat_base]
                array01 = [mat05_01]
                array02 = [mat05_02]
                array012 = [mat05_012]
                for dev_n in dev_name:
                    results_diff.loc[position_diff, 'fr'] = fr
                    results_diff.loc[position_diff, 'f1'] = freq1[0]
                    results_diff.loc[position_diff, 'f2'] = freq2[0]
                    results_diff.loc[position_diff, 'f0'] = eod_fr
                    results_diff.loc[position_diff, 'df1'] = np.abs(eod_fr - freq1)
                    results_diff.loc[position_diff, 'df2'] = np.abs(eod_fr - freq2)
                    results_diff.loc[position_diff, 'cell'] = cell
                    results_diff.loc[position_diff, 'c1'] = a_f1
                    results_diff.loc[position_diff, 'c2'] = a_f2
                    results_diff.loc[position_diff, 'trial_nr'] = trials_nr
                    ####################################
                    # calc cvs
                    results_diff = calc_cv_three_wave(results_diff, position_diff,
                                                      arrays=[spikes_base, spikes_01, spikes_02, spikes_012],
                                                      adds=['_0', '_01', '_02', '_012'])
                    if dev_n == '05':
                        dev = 0.0005
                    # tp_02_all, tp_012_all
                    # the means here are just a test of how std, var and psd
                    # should actually be grouped
                    if dev_n == 'original':
                        array0 = [np.mean(mat0, axis=0)]
                        array01 = [np.mean(mat01, axis=0)]
                        array02 = [np.mean(mat02, axis=0)]
                        array012 = [np.mean(mat012, axis=0)]
                    elif dev_n == '05':
                        array0 = [np.mean(smoothed0, axis=0)]
                        array01 = [np.mean(smoothed01, axis=0)]
                        array02 = [np.mean(smoothed02, axis=0)]
                        array012 = [np.mean(smoothed012, axis=0)]
                    ####################################################################
                    arrays_stim = [stimulus_0, stimulus_01, stimulus_02, stimulus_012]
                    arrays = [array0, array01, array02, array012]
                    arrays_spikes = [spikes_base, spikes_01, spikes_02, spikes_012]
                    array_mat = [mat0, mat01, mat02, mat012]
                    ######################################
                    # compute upper and lower bound
                    names = ['0', '01', '02', '012']
                    position_diff += 1
    return array_mat, v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays, names
def huxley():
    """Unfinished Hodgkin-Huxley stub.

    Seeds NumPy's global random generator with 1000, then defines the classic
    squid-axon constants, rate functions, steady-state functions and the
    injected-current protocol as locals. Nothing is integrated and the
    function returns None, so the seeding is its only effect visible to
    callers.
    """
    np.random.seed(1000)

    # Simulation window in milliseconds (unused so far).
    t_start, t_stop = 0.0, 50.0
    # Mean channel conductances per unit area (mS/cm^2): potassium, sodium, leak.
    g_potassium, g_sodium, g_leak = 36.0, 120.0, 0.3
    # Membrane capacitance per unit area (uF/cm^2).
    c_membrane = 1.0
    # Reversal potentials (mV): potassium, sodium, leak.
    e_potassium, e_sodium, e_leak = -12.0, 115.0, 10.613

    # Potassium activation (n) rate functions.
    def alpha_n(v):
        return (0.01 * (10.0 - v)) / (np.exp(1.0 - (0.1 * v)) - 1.0)

    def beta_n(v):
        return 0.125 * np.exp(-v / 80.0)

    # Sodium activation (m) rate functions.
    def alpha_m(v):
        return (0.1 * (25.0 - v)) / (np.exp(2.5 - (0.1 * v)) - 1.0)

    def beta_m(v):
        return 4.0 * np.exp(-v / 18.0)

    # Sodium inactivation (h) rate functions.
    def alpha_h(v):
        return 0.07 * np.exp(-v / 20.0)

    def beta_h(v):
        return 1.0 / (np.exp(3.0 - (0.1 * v)) + 1.0)

    # Steady-state values of the gating variables.
    def _steady_state(opening, closing, v):
        return opening(v) / (opening(v) + closing(v))

    def n_inf(Vm=0.0):
        return _steady_state(alpha_n, beta_n, Vm)

    def m_inf(Vm=0.0):
        return _steady_state(alpha_m, beta_m, Vm)

    def h_inf(Vm=0.0):
        return _steady_state(alpha_h, beta_h, Vm)

    # Injected current: 150 during (0, 1) ms, 50 during (10, 11) ms, else 0.
    def Id(t):
        if 0.0 < t < 1.0:
            return 150.0
        if 10.0 < t < 11.0:
            return 50.0
        return 0.0
# Compute derivatives
def damping_kashimori(stimulus, time, GbCa_=15, GbKCa_=500):
@jit() # (nopython=True)
def func1(x, t, R, F, stimulus, dt, GbCa, GbKCa, T, CiCa, convert_micro):
volt_to_milli = 1
volt_to_milli2 = 1
convert_milli = 0.001
# valance
zCl = -1
zNa = 1
zCa = 2
zK = 1
# area
Sb = 1
Sa = 20
St = Sb
# capacitance F/m^2
C = 0.01
# Concentration milli Mol/l
ClNa = 15 # * convert_milli
ClK = 0.000000001 # * convert_milli
ClCa = 0.000000001 # * convert_milli
CcNa = 5 # * convert_milli
CcK = 150 # * convert_milli
CcCa = 0.01 # * convert_milli
CiNa = 150 # * convert_milli
CiK = 5 # * convert_milli
# Permeability m/s
PaNa = 1 * (10 ** -11)
PaK = 9.8 * (10 ** -11)
PaCl = 0.5 * (10 ** -11)
PbNa = 2 * (10 ** -9)
PbK = 5 * (10 ** -9)
PbCl = 1 * (10 ** -9)
PtNa = 5 * (10 ** -11)
PtK = 5 * (10 ** -11)
PtCl = 5 * (10 ** -11)
# Initialize
Fa = x[0]
Fb = x[1]
mc = x[2]
CcCaI = x[3]
# CcCa = CcCaI
C0 = x[4]
C1 = x[5]
C2 = x[6]
O2 = x[7]
O3 = x[8]
V = x[9]
m = x[10]
h = x[11]
n = x[12]
# concentration S/m^2
ClCl = ClNa + ClK + 2 * ClCa
CcCl = CcNa + CcK + 2 * CcCa
CiCl = CiNa + CiK + 2 * CiCa
# area
ra = Sa / Sb
rt = St / Sb
Rat = ra + rt + rt * ra
Ft = Fa - Fb
# The equillimbium potential: V
FaNa = ((R * T) / (zNa * F)) * np.log(ClNa / CcNa) * volt_to_milli
FaK = ((R * T) / (zK * F)) * np.log(ClK / CcK) * volt_to_milli
FaCl = ((R * T) / (zCl * F)) * np.log(ClCl / CcCl) * volt_to_milli
FbNa = ((R * T) / (zNa * F)) * np.log(CiNa / CcNa) * volt_to_milli
FbK = ((R * T) / (zK * F)) * np.log(CiK / CcK) * volt_to_milli
FbCa = ((R * T) / (zCa * F)) * np.log(CiCa / CcCa) * volt_to_milli
FbCl = ((R * T) / (zCl * F)) * np.log(CiCl / CcCl) * volt_to_milli
FtNa = ((R * T) / (zNa * F)) * np.log(ClNa / CiNa) * volt_to_milli
FtK = ((R * T) / (zK * F)) * np.log(ClK / CiK) * volt_to_milli
FtCl = ((R * T) / (zCl * F)) * np.log(ClCl / CiCl) * volt_to_milli
# equation 14 : Conductivity leaky channels S/m^2
etaa = (F * Fa / volt_to_milli) / (R * T) # no unit, Fa has to be in Volt for cancellation
faNa = ((zNa ** 2) * (F ** 2) * Fa * PaNa * (ClNa - CcNa * np.exp(zNa * etaa))) / (
(R * T * (Fa - FaNa)) * (1 - np.exp(zNa * etaa)))
faK = ((zK ** 2) * (F ** 2) * Fa * PaK * (ClK - CcK * np.exp(zK * etaa))) / (
(R * T * (Fa - FaK)) * (1 - np.exp(zK * etaa)))
faCl = ((zCl ** 2) * (F ** 2) * Fa * PaCl * (ClCl - CcCl * np.exp(zCl * etaa))) / (
(R * T * (Fa - FaCl)) * (1 - np.exp(zCl * etaa)))
etat = (F * Ft / volt_to_milli) / (R * T)
ftNa = ((zNa ** 2) * (F ** 2) * Ft * PtNa * (ClNa - CiNa * np.exp(zNa * etat))) / (
(R * T * (Ft - FtNa)) * (1 - np.exp(zNa * etat)))
ftK = ((zK ** 2) * (F ** 2) * Ft * PtK * (ClK - CiK * np.exp(zK * etat))) / (
(R * T * (Ft - FtK)) * (1 - np.exp(zK * etat)))
ftCl = ((zCl ** 2) * (F ** 2) * Ft * PtCl * (ClCl - CiCl * np.exp(zCl * etat))) / (
(R * T * (Ft - FtCl)) * (1 - np.exp(zCl * etat)))
etab = (F * Fb / volt_to_milli) / (R * T)
fbNa = ((zNa ** 2) * (F ** 2) * Fb * PbNa * (CiNa - CcNa * np.exp(zNa * etab))) / (
(R * T * (Fb - FbNa)) * (1 - np.exp(zNa * etab)))
fbK = ((zK ** 2) * (F ** 2) * Fb * PbK * (CiK - CcK * np.exp(zK * etab))) / (
(R * T * (Fb - FbK)) * (1 - np.exp(zK * etab)))
fbCl = ((zCl ** 2) * (F ** 2) * Fb * PbCl * (CiCl - CcCl * np.exp(zCl * etab))) / (
(R * T * (Fb - FbCl)) * (1 - np.exp(zCl * etab)))
# A5,6,7,8: Conductivity K and Ca
k_1 = 6 * (10 ** 3)
k_2 = 100 * (10 ** 3)
k_3 = 30 * (10 ** 3)
betac = 1000
delta1 = 0.2
delta2 = 0
delta3 = 0.2
K10 = 6 * (10 ** -3) # mmol/l
K20 = 45 * (10 ** -3) # mmol/l
K30 = 20 * (10 ** -3) # mmol/l
Valpha = 33 * (10 ** -3) * volt_to_milli # V
alphac0 = 450 # -s
Ks = 28000 # -s
U = 0.2
k1 = (k_1 / K10) * np.exp(
(-2 * delta1 * F * Fb / volt_to_milli) / (R * T)) # Fb has to be in V for cancellation
k2 = (k_2 / K20) * np.exp((-2 * delta2 * F * Fb / volt_to_milli) / (R * T))
k3 = (k_3 / K30) * np.exp((-2 * delta3 * F * Fb / volt_to_milli) / (R * T))
alphac = alphac0 * np.exp(-Fb / Valpha)
GbCa_var = GbCa * (mc ** 3)
dC0 = k_1 * C1 - k1 * CcCaI * C0
dC1 = k1 * CcCaI * C0 + k_2 * C2 - (k_1 + k2 * CcCaI) * C1
dC2 = k2 * CcCaI * C1 + alphac * O2 - (k_2 + betac) * C2
dO2 = betac * C2 + k_3 * O3 - (alphac + k3 * CcCaI) * O2
dO3 = k3 * CcCaI * O2 - k_3 * O3
GbKCa_var = GbKCa * (O2 + O3)
# GbKCa_var = GbKCa_ * ((np.abs(O2) + np.abs(O3))/(np.abs(C0)+np.abs(C1)+np.abs(C2)+np.abs(O2)+np.abs(O3)))
Ga = faNa * (Fa - FaNa) + faK * (Fa - FaK) + faCl * (
Fa - FaCl)
Gb = fbNa * (Fb - FbNa) + fbCl * (
Fb - FbCl) + fbK * (Fb - FbK) + GbCa_var * (Fb - FbCa) + GbKCa_var * (
Fb + FbK)
# print('Na:')
Gt = ftNa * (Fa - Fb - FtNa) + ftK * (Fa - Fb - FtK) + ftCl * (
Fa - Fb - FtCl)
try:
stimulus1 = stimulus[int(np.round(t / dt))] * convert_micro * 10000 * volt_to_milli
except:
stimulus1 = stimulus[0] * convert_micro * 10000 * volt_to_milli
print('error in indexing')
print(int(np.round(t / dt)))
dFa = (1 / (C * Rat)) * (-(ra + rt) * stimulus1 - ra * (
1 + rt) * Ga - rt * Gb - rt * Gt)
dFb = (1 / (C * Rat)) * (ra * (ra + rt) * stimulus1 - ra * rt * Ga - (
ra + rt) * Gb + ra * rt * Gt)
dFb_rest = (1 / (C * Rat)) * (- ra * rt * Ga - (
ra + rt) * Gb + ra * rt * Gt)
V0 = 70 * convert_milli * volt_to_milli
Vb = 6.17 * convert_milli * volt_to_milli
beta0 = 0.97
Kb = 940
alpha0 = 22800
Va = 8.01 * convert_milli * volt_to_milli
Ka = 510
beta = beta0 * np.exp((Fb + V0) / Vb) + Kb # s^-1
alpha = alpha0 * np.exp(-(Fb + V0) / Va) + Ka # s^-1
dmc = beta * (1 - mc) - alpha * mc # s^-1
IbCa = GbCa * (mc ** 3) * (Fb - FbCa)
if IbCa == -0:
IbCa = 0
l = 25 * (10 ** -6) # m
Xi = 3.4 * (10 ** -6) # fraction of the volume that aborbs Ca the bigger the fraction the smaller Ca should be
const = U / (2 * Xi * l * F) # * 1000
dCcCa = const * IbCa - Ks * CcCaI
Cn = 0.01 # S/m^2
gNa_ = 1200 # S/m^2
gK_ = 400 # S/m^2
gl = 2.4 # S/m^2
VNa = 0.056 * volt_to_milli2 # V
VK = -0.093 * volt_to_milli2 # V
Vl = -0.03 * volt_to_milli2 # V
w = 4.7
eta = 0.150 * volt_to_milli2 # A/m^2
e = 0.050 * volt_to_milli2 # A/m^2
Ips = (w / (1 + np.exp(-(np.abs(IbCa * volt_to_milli2) - eta) / e))) * volt_to_milli2
dV = (-gNa_ * (m ** 3) * h * (V - VNa) - gK_ * (n ** 4) * (V - VK) - gl * (V - Vl) + Ips) / Cn # Volt
betahinf = 1.8 # *(10**3)#ms
mV = V * 1000 / volt_to_milli2 # mv
beta_ending = 1000 # tranfer ms and mv back in v and s
alpham = (-0.1 * (mV + 40) / (np.exp(-(mV + 40) / 10)) - 1) * beta_ending # s
betam = (4 * np.exp(-(mV + 65) / 18)) * beta_ending # s
alphah = (0.07 * np.exp(-(mV + 65) / 20)) * beta_ending # s
betah = (betahinf / (np.exp(-(mV + 35) / 10) + 1)) * beta_ending # s
alphan = (-0.01 * (mV + 55) / (np.exp(-(mV + 55) / 10) - 1)) * beta_ending # s
betan = (0.125 * np.exp(-(mV + 65) / 80))
dm = alpham - (alpham + betam) * m # s^-1
dh = alphah - (alphah + betah) * h # s^-1
dn = alphan - (alphan + betan) * n # s^-1
vars = (dFa, dFb, dmc, dCcCa, dC0, dC1, dC2, dO2, dO3, dV, dm, dh, dn, dFb_rest)
return vars
convert_mili = 0.001
convert_micro = 0.000001
R = 8.31446261815324
F = 96485.3329
x = [[]] * 14
x[0] = -0.03 # *volt_to_milli # Fa =
x[1] = -0.050 # *volt_to_milli # = Fb
x[2] = 0 # = mc
x[3] = 0.01 * (10 ** -3) # = CcCa
x[4] = 0.05 # = C0
x[5] = 0.05 # = C1
x[6] = 0.1 # = C2
x[7] = 0.4 # = O2
x[8] = 0.4 # = O3
x[9] = -0.08 # *volt_to_milli # = V
x[10] = 1 # = m
x[11] = 0 # = h
x[12] = 0.5 # = n
x[13] = -0.08
CiCa = 1 # *(10**-3) #milli mol/l
T = 298.5 # K in Koshimori
us = odeint(func1, x, time, args=(
R, F, stimulus, np.abs(time[0] - time[1]), GbCa_, GbKCa_, T, CiCa, convert_mili,
convert_micro)) # ,hmin = np.abs(time[0]-time[1]),hmax = np.abs(time[0]-time[1])
mc = us[:, 2]
Fb = us[:, 1]
CcCa = us[:, 3]
zCa = 2
FbCa = ((R * T) / (zCa * F)) * np.log(CiCa / CcCa)
IbCa = GbCa_ * (mc ** 3) * (Fb - FbCa)
return np.std(IbCa), np.abs(np.min(IbCa)) - np.abs(np.max(IbCa)), IbCa, us
def damping_hundspet(mechanical, stimulus, time, damping_type):
    """Integrate the hair-cell basolateral membrane / Ca / K(Ca) channel model.

    The eight state variables are, in order: Fb (membrane potential, V),
    mc (Ca-channel activation), CcCa (Ca concentration, mol/l) and the
    K(Ca)-channel kinetic states C0, C1, C2, O2, O3.

    Parameters
    ----------
    mechanical : str
        If equal to 'mechanical', the transduction current gt * (Fb - Fbt) is
        included in the voltage equation; any other value disables it.
    stimulus : array-like
        Input current trace. It is multiplied by 20 * 1e-12 below and sampled
        at index round(t / dt) during integration.
    time : array-like
        Time points handed to odeint; dt is taken as |time[0] - time[1]|.
    damping_type : str
        If it contains 'nieman', the Nieman variant is used: C0 follows from
        conservation of the channel states and a fixed Ca constant is used.

    Returns
    -------
    tuple
        The (solution, infodict) pair returned by
        ``scipy.integrate.odeint(..., full_output=True)``.
    """
    from scipy.integrate import odeint

    convert_SI = 1
    convert_SI_minus = 1 / convert_SI
    convert_mili = 0.001 * convert_SI
    convert_micro = 0.000001 * convert_SI
    convert_nano = 0.000000001 * convert_SI
    convert_pico = 0.000000000001 * convert_SI

    # initial state in the order described in the docstring
    x0 = [-0.05,  # Fb [V]
          0.27,  # mc
          0.1,  # CcCa
          0.05,  # C0
          0.05,  # C1
          0.1,  # C2
          0.4,  # O2
          0.4]  # O3

    @jit()  # (nopython=True)
    def func1(x, t, R, damping_type, F, convert_SI_minus, stimulus, dt, T, convert_milli, convert_micro, convert_nano,
              convert_pico, mechanical):
        """Right-hand side of the ODE system for odeint (state x at time t)."""
        Fb = x[0]
        mc = x[1]
        CcCaI = x[2]
        C0 = x[3]
        C1 = x[4]
        C2 = x[5]
        O2 = x[6]
        O3 = x[7]
        GbKCa_ = 16.8 * convert_nano  # S
        FbK = -80 * convert_milli  # V
        FbCa = 100 * convert_milli  # V
        GbCa_ = 4.14 * convert_nano  # S
        # parameters of the K(Ca) channel kinetic scheme
        k_1 = 300 * convert_SI_minus  # s^-1
        k_2 = 5000 * convert_SI_minus  # s^-1
        k_3 = 1500 * convert_SI_minus  # s^-1
        betac = 1000 * convert_SI_minus  # s^-1
        alphac0 = 450 * convert_SI_minus  # s^-1
        delta1 = 0.2
        delta2 = 0
        delta3 = 0.2
        K10 = 6 * convert_micro  # mol/l
        K20 = 45 * convert_micro  # mol/l
        K30 = 20 * convert_micro  # mol/l
        Valpha = 33 * convert_milli  # V
        alphac = alphac0 * np.exp(-Fb / Valpha)  # s^-1
        GbCa = GbCa_ * (mc ** 3)  # S
        if 'nieman' in damping_type:
            K1 = (K10 * np.exp((-1 * delta1 * F * Fb) / (R * T)))
            k1 = (k_1 / K1)
            K2 = (K20 * np.exp((-1 * delta2 * F * Fb) / (R * T)))
            k2 = (k_2 / K2)
            K3 = (K30 * np.exp((-1 * delta3 * F * Fb) / (R * T)))
            k3 = (k_3 / K3)
            GbKCa = GbKCa_ * (O2 + O3)
            dC0 = 0.0
            # C0 follows from conservation of the channel state occupancies
            C0 = 1 - (C1 + C2 + O2 + O3)
        else:
            K1 = (K10 * np.exp((-2 * delta1 * F * Fb) / (R * T)))  # mol/l
            k1 = (k_1 / K1)  # l/(s*mol)
            K2 = (K20 * np.exp((-2 * delta2 * F * Fb) / (R * T)))  # mol/l
            k2 = (k_2 / K2)  # l/(s*mol)
            K3 = (K30 * np.exp((-2 * delta3 * F * Fb) / (R * T)))  # mol/l
            k3 = (k_3 / K3)  # l/(s*mol)
            dC0 = k_1 * C1 - k1 * CcCaI * C0  # CcCaI here is mol/l
            GbKCa = GbKCa_ * (O2 + O3)  # S
        dC1 = k1 * CcCaI * C0 + k_2 * C2 - (k_1 + k2 * CcCaI) * C1
        dC2 = k2 * CcCaI * C1 + alphac * O2 - (k_2 + betac) * C2
        dO2 = betac * C2 + k_3 * O3 - (alphac + k3 * CcCaI) * O2
        dO3 = k3 * CcCaI * O2 - k_3 * O3
        try:
            stimulus1 = stimulus[int(np.round(t / dt))]
        except Exception:
            # presumably guards against t evaluated beyond the last stimulus sample
            stimulus1 = stimulus[0]
            print('error in indexing')
        FbL = -30 * convert_milli  # V
        fbL = 1 * convert_nano  # S
        C = 15 * convert_pico  # F
        # bundle displacement; separate name so the state vector x is not rebound
        x_bundle = 20 * convert_nano  # m
        G1 = 0.75 * 1000  # kcal/mol = 1000 cal/mol
        G2 = 0.25 * 1000  # kcal/mol = 1000 cal/mol
        Z1 = 10 * 1000 * 10 ** 6  # kcal/mol microm = 1000 10**6 cal/mol* m
        Z2 = 2 * 1000 * 10 ** 6  # kcal/mol microm = 1000 10**6 cal/mol* m
        A = (G1 - Z1 * x_bundle) / (R * T)
        B = (G2 - Z2 * x_bundle) / (R * T)
        Fbt = 0 * convert_milli  # V
        gt_ = 3 * convert_nano  # S
        gt = gt_ / (1 + np.exp(B) * (1 + np.exp(A)))  # S
        # separate name so the string flag `mechanical` is not rebound to a float
        if mechanical == 'mechanical':
            I_mech = gt * (Fb - Fbt)
        else:
            I_mech = 0.0
        dFb = -(fbL * (Fb - FbL) + GbCa * (Fb - FbCa) + GbKCa * (
                Fb - FbK) + I_mech - stimulus1) / C
        alpha0 = 22800 * convert_SI_minus  # s^-1
        V0 = 70 * convert_milli  # V
        VA = 8.01 * convert_milli  # V
        Ka = 510 * convert_SI_minus  # s^-1
        beta0 = 0.97 * convert_SI_minus  # s^-1
        VB = 6.17 * convert_milli  # V
        Kb = 940 * convert_SI_minus  # s^-1
        beta = beta0 * np.exp((Fb + V0) / VB) + Kb
        alpha = alpha0 * np.exp(-(Fb + V0) / VA) + Ka
        dm = beta * (1 - mc) - alpha * mc
        IbCa = GbCa_ * (mc ** 3) * (Fb - FbCa)  # A
        U = 0.02
        # fraction of volume that binds Ca; the bigger it is, the smaller the Ca increase
        Xi = 3.4 * (10 ** -5)
        Cvol = 1.25 * convert_pico  # l
        Ks = 2800 * convert_SI_minus  # s^-1
        if 'nieman' in damping_type:
            const = -0.00061
        else:
            const = (U / (2 * F * Cvol * Xi))
        dCcCa = (const * IbCa) - Ks * CcCaI  # concentration is mol/l
        return (dFb, dm, dCcCa, dC0, dC1, dC2, dO2, dO3)

    R = 8.31446261815324
    F = 96485.336
    T = 298.5
    stimulus = stimulus * 20 * convert_pico
    dt = np.abs(time[0] - time[1])
    # args must match func1's parameters after (x, t) one-to-one
    us = odeint(func1, x0, time,
                args=(R, damping_type, F, convert_SI_minus, stimulus, dt, T, convert_mili, convert_micro,
                      convert_nano, convert_pico, mechanical),
                full_output=True)
    return us
def all_damping_variants(stimulus, time_array, damping_type='', eod_fr=750, damping_gain=1, damping='',
                         damping_element='',
                         damping_output=[], plot=False, std_dump=0, max_dump=0, range_dump=0):
    """Dispatch the stimulus to one of several damping models.

    Only ``damping_type == 'damping'`` is fully worked out; the other variants
    would need further modifications. For an unrecognized ``damping_type`` the
    inputs are returned unchanged.

    Returns
    -------
    tuple
        (std_dump, max_dump, range_dump, stimulus, damping_output)
    """
    if damping_type == 'damping':
        damping_output, stimulus, std_dump, max_dump, range_dump = damping_func(
            stimulus, time_array, eod_fr, damping, damping_element, damping_gain)
    elif damping_type in ('damping_hundspeth', 'damping_nieman'):
        # NOTE(review): damping_hundspet is defined as
        # damping_hundspet(mechanical, stimulus, time, damping_type) and returns
        # the odeint output; this 3-argument call and 4-value unpacking do not
        # match that signature — confirm the intended call.
        std_dump, extent, IbCa, damping_output = damping_hundspet(stimulus, time_array, damping_type)
    elif damping_type == 'damping_kashimori':
        std_dump, extent, IbCa, damping_output = damping_kashimori(stimulus, time_array)
        if plot == True:
            from utils_test import plot_kashimori
            plot_kashimori(stimulus, time_array, damping_output, IbCa)
    elif damping_type == 'damping_huxley':
        huxley()
    result = (std_dump, max_dump, range_dump, stimulus, damping_output)
    return result
def a2():
    """Placeholder with no behavior; always returns None."""
    return None
def do_array_for_three(nfft, a_fr, fish_receiver, beat, zeros, cell_recording, fish_emitter, fish_jammer,
                       fish_morph_harmonics_var, mimick, nfft_for_morph, phase_right, sampling, eod_fish2, SAM,
                       eod_stimulus, eod_fish_r, freq2, a_f1, a_f2, cell, stimulus_length, offset, model_params, n,
                       adapt_offset, deltat, f2, trials_nr, time_array, f1, freq1, stimulus, phaseshift_fr=0,
                       variant='sinz', eod_fr=750, length_adapt=True, dict_here=None, redo_stim=False,
                       stim_type='01', reshuffle='reshuffled', dev=0.0005, damping='',
                       damping_type='damping', damping_gain=1, damping_variant=''):
    """Simulate ``trials_nr`` trials of a model cell and return the smoothed response.

    On the first trial (or on every trial when ``phaseshift_fr`` is a string)
    the stimulus is (optionally) regenerated and built according to
    ``stim_type`` ('01', '02' or '012') via ``create_stimulus_SAM``, and
    optionally passed through ``all_damping_variants`` when ``damping != ''``.
    Each trial is simulated with ``simulate`` and burst-corrected with
    ``spikes_after_burst_corr``.

    Parameters (selected)
    ---------------------
    dict_here : dict
        Must provide 'burst_corr', and 'eod_fish1' when ``stim_type == '01'``.
    dev : float or 'original'
        Gaussian smoothing width in seconds; 'original' disables smoothing.
    damping, damping_type, damping_gain, damping_variant
        Forwarded positionally to ``all_damping_variants`` as
        (damping_type, eod_fr, damping_gain, damping, damping_element).

    Returns
    -------
    tuple
        (stimulus, mean_smoothed, spikes, smoothed, spikes_mat, offset_new,
        v_mem_output); offset_new and v_mem_output come from the last trial.
    """
    spikes = [[]] * trials_nr
    spikes_bef = [[]] * trials_nr
    for t in range(trials_nr):
        if (t == 0) or isinstance(phaseshift_fr, str):
            if redo_stim:
                eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr,
                                                                               stimulus_length, phaseshift_fr,
                                                                               cell_recording, zeros, mimick,
                                                                               sampling,
                                                                               fish_receiver, deltat, nfft,
                                                                               nfft_for_morph,
                                                                               fish_morph_harmonics_var=fish_morph_harmonics_var,
                                                                               beat=beat)
                # NOTE(review): the regenerated eod_fish1 is not used below; the '01'
                # stimulus reads dict_here['eod_fish1'] instead — confirm intended.
                eod_fish1, eod_fish2, eod_stimulus, _ = stimulus_threefish(a_f1, a_f2, cell_recording, f1, f2,
                                                                           fish_emitter, fish_jammer, freq1, freq2,
                                                                           fish_morph_harmonics_var, mimick,
                                                                           nfft_for_morph, phase_right,
                                                                           phaseshift_fr,
                                                                           sampling, stimulus_length, time_array,
                                                                           zeros)
            # if we need a new stimulus each time we generate it here each time
            if stim_type == '02':
                stimulus, eod_fish_sam = create_stimulus_SAM(SAM, eod_fish2, eod_fish_r, freq2, f2,
                                                             eod_fr,
                                                             time_array, a_f2, eod_fj=freq2, j=f2,
                                                             a_fj=a_f2, )
            elif stim_type == '012':
                stimulus, eod_fish_sam = create_stimulus_SAM(SAM, eod_stimulus, eod_fish_r, freq1, f1,
                                                             eod_fr,
                                                             time_array, a_f1, eod_fj=freq2, j=f2,
                                                             a_fj=a_f2,
                                                             three='Three')
            elif stim_type == '01':
                stimulus, eod_fish_sam = create_stimulus_SAM(SAM, dict_here['eod_fish1'], eod_fish_r, freq1, f1,
                                                             eod_fr,
                                                             time_array, a_f1, eod_fj=freq1, j=f2,
                                                             a_fj=a_f2, )
            # damping variants (previously an interactive embed() followed by
            # references to undefined names; the names are now parameters)
            if damping != '':
                _, _, _, stimulus, _ = all_damping_variants(stimulus, time_array,
                                                            damping_type, eod_fr,
                                                            damping_gain,
                                                            damping, damping_variant,
                                                            plot=False,
                                                            std_dump=0, max_dump=0,
                                                            range_dump=0)
        (_, _, _, _, _, _, _, spikes_bef[t], _, _,
         offset_new, v_mem_output, _) = simulate(cell, offset, stimulus, deltat=deltat,
                                                 adaptation_variant=adapt_offset,
                                                 adaptation_yes_j=f2, adaptation_yes_e=f1,
                                                 adaptation_yes_t=t, power_variant=variant,
                                                 power_nr=n,
                                                 reshuffle=reshuffle, **model_params)
        isi = calc_isi(spikes_bef[t], eod_fr)
        spikes[t] = spikes_after_burst_corr(spikes_bef[t], isi, dict_here['burst_corr'], cell, eod_fr,
                                            model_params=model_params)
    if not length_adapt:
        spikes_mat = [cr_spikes_mat(spikes_trial, 1 / deltat, int(stimulus_length * 1 / deltat))
                      for spikes_trial in spikes]
    else:
        spikes_mat = spikes_mat_depending_on_length(spikes, deltat, stimulus_length)
    sampling_rate = 1 / deltat
    if dev != 'original':
        smoothed = gaussian_filter(spikes_mat, sigma=dev * sampling_rate)
    else:
        smoothed = spikes_mat
    mean_smoothed = np.mean(smoothed, axis=0)
    return stimulus, mean_smoothed, spikes, smoothed, spikes_mat, offset_new, v_mem_output
def stimulus_threefish(a_f1, a_f2, cell_recording, f1, f2, fish_emitter, fish_jammer, freq1, freq2,
                       fish_morph_harmonics_var, mimick, nfft_for_morph, phase_right, phaseshift_fr, sampling,
                       stimulus_length, time_array, zeros):
    """Generate the emitter and jammer EODs and their summed stimulus.

    Phase shifts come from ``get_phaseshifts``; when ``phaseshift_fr`` is
    'randALL' both are then replaced by uniformly random phases in [0, 2*pi).

    Returns
    -------
    tuple
        (eod_fish1, eod_fish2, eod_fish1 + eod_fish2, t1) where t1 is the
        ``time.time()`` value taken before generation.
    """
    start_time = time.time()
    shift_emitter, shift_jammer = get_phaseshifts(a_f1, a_f2, phase_right, phaseshift_fr)
    if phaseshift_fr == 'randALL':
        shift_emitter = np.random.rand() * 2 * np.pi
        shift_jammer = np.random.rand() * 2 * np.pi
    # emitter first, then jammer (same call order as before)
    fish_specs = ((a_f1, freq1, f1, shift_emitter, fish_emitter, 'emitter'),
                  (a_f2, freq2, f2, shift_jammer, fish_jammer, 'jammer'))
    generated = []
    for amp, freq, f_idx, shift, species, role in fish_specs:
        eod, _ = eod_fish_e_generation(time_array, amp, freq, f_idx, shift, sampling,
                                       stimulus_length, nfft_for_morph, cell_recording,
                                       fish_morph_harmonics_var, zeros, mimick, species,
                                       thistype=role)
        generated.append(eod)
    eod_emitter, eod_jammer = generated
    return eod_emitter, eod_jammer, eod_emitter + eod_jammer, start_time
def check_peak_overlap_only_stim_big_final(stimulus_length_data=0.5, dev=0.001,
                                            reshuffled='reshuffled',
                                            printing=False, show=False, beat='', nfft_for_morph=4096 * 4,
                                            gain=1,
                                            fish_jammer='Alepto', us_name=''):
    """Plot a single-cell overview figure (amplitude traces plus per-contrast panels) and save it.

    Loads the model parameters from 'models_big_fit_d_right.csv' and a
    precomputed cocktail-party frame. For cell '2012-06-27-an-invivo-1', it
    picks the df1 closest to freq_example (30 Hz) and the first df2. It then
    plots the amplitude scores in the upper panel and, for each contrast in
    c_nrs, a column of panels via plt_single_contrast_ps_isi_01. Finally it
    calls save_visualization with a name built from the cell and the
    frequencies.
    """
    runs = 1
    n = 1
    default_settings()  # ts=13, ls=13, fs=13, lw = 0.7
    # extra combination with female small
    # standard combination with intruder small
    a_f2s = [0.1]
    min_amps = '_minamps_'
    dev_name = '05'
    model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv")
    a_fr = 1
    a = 0
    trials_nrs = [1]
    datapoints = 1000
    results_diff = pd.DataFrame()
    position_diff = 0
    default_settings(column=2, length=8.5)
    for trials_nr in trials_nrs:  # +[trials_nrs[-1]]
        # things I want to vary
        ###########################################
        auci_wo = []
        auci_w = []
        nfft = 32768  # 2**16##6#32768#2**12#32768
        cells_here = ['2012-06-27-an-invivo-1']
        for cell_here in cells_here:
            full_names = [
                'calc_model_amp_freqs-F1_250-1325-25_F2_500-525-25_C2_0.1_C1Len_25_FirstC1_0.0001_LastC1_1.0_mult__StimLen_5_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_not_log_temporal']
            c_grouped = ['c1']  # , 'c2']
            # adds = [-150, -50, -10, 10, 50, 150]
            # fig, ax = plt.subplots(4, len(adds), constrained_layout=True, figsize=(12, 5.5))
            frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv')
            frame_cell_orig = frame[(frame.cell == cell_here)]
            frame_cell_orig_orignal = frame[(frame.cell == cell_here)]
            if len(frame_cell_orig) > 0:
                # (135.5, 625.0), (110.5, 650.0), (85.5, 675.0),(60.5, 700.0), (35.5, 725.0), (10.5, 750.0),(151.07000000000005, 675.0)
                new_f2_tuple = frame_cell_orig[['df2', 'f2']].apply(tuple, 1).unique()
                dfs = [tup[0] for tup in new_f2_tuple]
                # NOTE: this local name shadows the builtin sorted() for the rest of the function
                sorted = np.argsort(np.abs(dfs))
                new_f2_tuple = new_f2_tuple[sorted]
                frame_cell = frame[(frame.cell == cell_here)]  # & (frame[c_here] == c_h)]
                frame_cell, df1s, df2s, f1s, f2s = find_dfs(frame_cell)
                diffs = find_deltas(frame_cell, c_grouped[0])
                frame_cell = find_diffs(c_grouped[0], frame_cell, diffs, add='_original')
                new_frame = frame_cell.groupby(['df1', 'df2'], as_index=False).sum()  # ['score']
                freq1s = np.unique(new_frame.df1)
                freq2s = np.unique(new_frame.df2)
                freq_example = 30  # 65
                # keep only the df1 closest to freq_example and the smallest df2
                freq1s = [freq1s[np.argmin(np.abs(freq1s - freq_example))]]
                freq2s = [freq2s[0]]
            else:
                freq_example = 30  # 65
                freq1s = [freq_example]
                freq2s = [10]
        for freq1 in freq1s:
            for freq2 in freq2s:
                c_nrs = [0.0002, 0.2, 0.8]  # 0.0002, , 0.50.01,0.075,0.1,
                grid0 = gridspec.GridSpec(1, 1, bottom=0.1, top=0.85, left=0.09,
                                          right=0.95,
                                          wspace=0.04)  #
                grid00 = gridspec.GridSpecFromSubplotSpec(2, 1,
                                                          wspace=0.15, hspace=0.27, subplot_spec=grid0[0],
                                                          )  #
                grid_u = gridspec.GridSpecFromSubplotSpec(1, 1,
                                                          hspace=0.7,
                                                          wspace=0.3,
                                                          subplot_spec=grid00[
                                                              0])  # hspace=0.4,wspace=0.2,len(chirps)
                grid_l = gridspec.GridSpecFromSubplotSpec(1, 1,
                                                          hspace=0.7,
                                                          wspace=0.1,
                                                          subplot_spec=grid00[
                                                              1])  # hspace=0.4,wspace=0.2,len(chirps)
                # NOTE(review): f, eodf, grid_ll, colors_array, axts_all, axps_all and
                # axis_all are only assigned inside this branch, so the code after it
                # fails when the cell has no rows in the frame.
                if len(frame_cell_orig) > 0:
                    # here I implement this for one cell,
                    # where we vary the single point and the contrasts
                    c_here = 'c1'
                    if c_here == 'c1':  # 'B1_diff'
                        pass
                    if c_here == 'c2':  # 'B1_diff'
                        pass
                    f_counter = 0
                    ax_upper = []
                    frame_cell_orig_orignal, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig_orignal)
                    frame_cell_orig, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig)
                    eodf = frame_cell_orig.f0.unique()[0]
                    # f indexes the row of grid_l; it ends up as 0 here
                    f = -1
                    axts_all = []
                    axps_all = []
                    axis_all = []
                    f += 1
                    # plot the baseline Peak above
                    linestyles = [['-', '-', '-', '-'],
                                  ['-', '-', '-', '-']
                        , ['--', '-.', '--'],
                                  ['-', '-', '-', '-'],
                                  ]
                    scores = [['amp_B1_01_mean_original', 'amp_f1_01_mean_original', 'amp_f0_01_mean_original'],
                              ]  # ['cv_01']
                    colors = [['green', 'pink', 'black', 'red'],
                              ['grey'], ]
                    alpha = [[1, 0.5, 1, 1],
                             [0.5, 0.5, 0.5, 0.5],
                             [1, 1, 1, 1],
                             [1, 1, 1, 1], ]
                    axs = []
                    for s in range(len(scores)):
                        ax_u1 = plt.subplot(grid_u[s])
                        ax_upper = plt_single_trace(ax_upper, ax_u1, frame_cell_orig, freq1, freq2,
                                                    alpha=alpha[s], sum=False, linestyles=linestyles[s],
                                                    scores=scores[s], colors=colors[s])
                        ax_u1.legend(loc=(0, 1), ncol=3)
                        if 'cv' not in scores[s][0]:
                            axs.append(ax_u1)
                        join_x(axs)
                    join_x(axs)
                    join_y(axs)
                    # frame_cell_orig_orignal
                    # mark the contrasts that get their own column below
                    for ax in ax_upper:
                        ax.scatter(c_nrs, np.zeros(len(c_nrs)), marker='^', color='black')
                    grid_ll = gridspec.GridSpecFromSubplotSpec(1, len(c_nrs),
                                                               hspace=0.55,
                                                               wspace=0.2,
                                                               subplot_spec=grid_l[
                                                                   f])  # hspace=0.4,wspace=0.2,len(chirps)
                    colors_array = ['grey', 'green', 'orange', 'purple']
                print(cell_here + ' F1' + str(freq1) + ' F2 ' + str(freq2))
                sampling = 20000
                cvs = pd.read_csv(load_folder_name('calc_base') + '/csv_model_data.csv')
                cv = np.round(cvs[cvs['cell'] == cell_here].cv_model.iloc[0], 3)
                fr = np.round(cvs[cvs['cell'] == cell_here].fr_model.iloc[0])
                plt.suptitle(cell_here + ' EODf=' + str(np.round(eodf)) + ' Hz' + ' cv=' + str(cv) + ' fr=' + str(
                    np.round(fr))
                             + 'Hz F1=' + str(np.round(freq1 + eodf)) + ' Hz' + ' F1-EODf=' + str(
                    np.round(freq1)) + ' Hz' + ' F2=' + str(np.round(freq2 + eodf)) + ' Hz ' + 'F2-EODf=' + str(
                    np.round(freq2)) + ' Hz ')
                axts = []
                axps = []
                axis = []
                for c_nn, c_nr in enumerate(c_nrs):
                    # NOTE(review): freq_example (not the loop variable freq1) is passed as
                    # DF1 here, while the suptitle and saved file name use freq1; the two
                    # can differ when no df1 in the frame equals freq_example — confirm intended.
                    f_counter = plt_single_contrast_ps_isi_01(axis, f_counter, axts, axps, f, grid_ll, c_nn,
                                                              freq_example, freq2, eodf, datapoints, auci_wo,
                                                              auci_w, results_diff, a_f2s, fish_jammer, a,
                                                              trials_nr, nfft, us_name, gain, runs, a_fr,
                                                              nfft_for_morph, beat, printing,
                                                              stimulus_length_data,
                                                              model_cells, position_diff, colors_array, reshuffled,
                                                              dev, sampling,
                                                              cell_here, c_nr, n, [dev_name], min_amps, extend=True,
                                                              first=False, ypos=1.65,
                                                              c_nn_nr=0, xpos=1, second=False)
                axts_all.extend(axts)
                axps_all.extend(axps)
                axis_all.extend(axis)
                # share axis ranges across all columns of each panel type
                axts_all[0].get_shared_y_axes().join(*axts_all)
                axts_all[0].get_shared_x_axes().join(*axts_all)
                axps_all[0].get_shared_y_axes().join(*axps_all)
                axps_all[0].get_shared_x_axes().join(*axps_all)
                axis_all[0].get_shared_x_axes().join(*axis_all)
                save_visualization(cell_here + '_freq1_' + str(freq1) + '_freq2_' + str(freq2), show)
def plt_single_contrast_ps_isi_01(axis, f_counter, axts, axps, f, grid_ll, c_nn, freq1, freq2, eodf, datapoints,
                                  auci_wo, auci_w, results_diff, a_f2s, fish_jammer, a,
                                  trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing,
                                  stimulus_length,
                                  model_cells, position_diff, colors_array, reshuffled, dev, sampling, cell_here, c_nr,
                                  n, dev_name, min_amps, extend=True, c_nn_nr=1, first=True, xpos=1, second=True,
                                  val=1.5, ypos=1.35):
    """Plot one contrast column: stimulus, voltage trace, power spectrum and ISI histogram.

    Runs calc_roc_amp_core_cocktail twice for the frequencies freq1 + eodf and
    freq2 + eodf at contrast c_nr: once with the smoothing `dev` and once with
    'original'. It then draws four stacked panels into column c_nn of grid_ll
    for entry 1 of the returned arrays (the '01' condition):
    - the stimulus with a spike raster;
    - the v_mems trace (labelled 'mV') with a spike raster;
    - the power spectrum of the 'original' array with annotated peaks;
    - the inter-spike-interval histogram in multiples of the EOD period.

    The new axes are appended to axts (voltage), axps (spectrum) and axis (ISI).
    Returns f_counter + 1.
    """
    v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays_05, names, p_arrays, ff = calc_roc_amp_core_cocktail(
        [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff, a_f2s, fish_jammer, trials_nr, nfft,
        us_name, gain, runs, a_fr, nfft_for_morph, beat, printing, stimulus_length,
        model_cells, position_diff, dev, cell_here, dev_name=dev_name, a_f1s=[c_nr], n=n,
        reshuffled=reshuffled, min_amps=min_amps)
    # second run with dev='original'; its v_mems, spikes and stimuli overwrite the
    # first run's, whose arrays_05 is only used below for the time-axis length
    v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays_original, names, p_arrays, ff = calc_roc_amp_core_cocktail(
        [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff, a_f2s, fish_jammer, trials_nr, nfft,
        us_name, gain, runs, a_fr, nfft_for_morph, beat, printing, stimulus_length,
        model_cells, position_diff, 'original', cell_here, dev_name=dev_name, a_f1s=[c_nr], n=n,
        reshuffled=reshuffled, min_amps=min_amps)
    # NOTE(review): `a` here is the function argument; the loop below rebinds it
    time = np.arange(0, len(arrays_05[a][0]) / sampling, 1 / sampling)
    arrays_here = [v_mems[1]]  #::]#arrays_05[1::]#arrays_original[1::]#
    arrays_here_original = [arrays_original[1]]  #::]
    spike_here = [arrays_spikes[1]]  #::]
    stim_here = [arrays_stim[1]]  #::]
    names = ['0', '01', '02', '012']
    names_here = [names[1]]  # extend=True
    for a in range(len(arrays_here)):
        grid_pt = gridspec.GridSpecFromSubplotSpec(4, 1,
                                                   hspace=0.65,
                                                   wspace=0.2,
                                                   subplot_spec=grid_ll[c_nn])  # hspace=0.4,wspace=0.2,len(chirps)
        # panels from top to bottom: stimulus, voltage, spectrum, ISI histogram
        axs = plt.subplot(grid_pt[0])
        axt = plt.subplot(grid_pt[1])
        axp = plt.subplot(grid_pt[2])
        axi = plt.subplot(grid_pt[3])
        axts.append(axt)
        axps.append(axp)
        axis.append(axi)
        if f != 0:
            remove_yticks(axt)
            remove_yticks(axs)
        if a != len(arrays_here) - 1:
            remove_xticks(axt)
            remove_xticks(axs)
        if f_counter == 0:
            axt.set_ylabel(names[a])
            axs.set_ylabel(names[a])
        if a == 0:
            axs.set_title(' a1=' + str(c_nr) + ', a2=0')
        elif a == 1:
            axs.set_title(' a1=0,' + ' a2=' + str(a_f2s[0]))
        else:
            axs.set_title(' a1=' + str(c_nr) + ', a2=' + str(a_f2s[0]))
        # show `val` beat cycles starting at 0.1 s
        xlim = [0.1, 0.1 + val / freq1]
        axs.plot(time, stim_here[a], color='grey')  # color=colors_array_here[a],colors_contrasts[c_nn]
        axs.eventplot(spike_here[a][0], lineoffsets=np.mean(stim_here[a]), color='black')  # np.max(v1)*
        axs.set_xlim(xlim)
        axt.plot(time, arrays_here[a], color='grey')  # colors_array_here[a]colors_contrasts[c_nn]
        axt.eventplot(spike_here[a][0], lineoffsets=np.max(arrays_here[a]), color='black')  # np.max(v1)*
        axt.set_xlim(xlim)  # 1.5
        # ISIs expressed in multiples of the EOD period (1 / eodf)
        axi.hist(np.diff(spike_here[a][0]) / (1 / eodf),
                 bins=np.arange(0, np.max(np.diff(spike_here[a][0]) / (1 / eodf)), 0.1),
                 color='grey')  # colors_array_here[a]
        axi.axvline(x=1, color='black', linestyle='--', linewidth=0.5, zorder=100)  # color = 'grey',
        axi.axvline(x=2, color='black', linestyle='--', linewidth=0.5, zorder=100)  # color = 'grey',
        axi.axvline(x=3, color='black', linestyle='--', linewidth=0.5, zorder=100)  # color = 'grey',
        axi.axvline(x=4, color='black', linestyle='--', linewidth=0.5, zorder=100)  # color = 'grey',
        # NOTE(review): set_xticks_delta is not a standard matplotlib Axes method;
        # presumably provided by the project's plotting setup, otherwise the except branch always runs
        try:
            axi.set_xticks_delta(1)
        except:
            print('problem something')
        axi.set_xlim(0, 8)
        # power spectrum of the mean-subtracted 'original' response
        pp, ff = ml.psd(arrays_here_original[a][0] - np.mean(arrays_here_original[a][0]), Fs=sampling, NFFT=nfft,
                        noverlap=nfft // 2)
        axp.plot(ff, pp, color='grey')  # colors_contrasts[c_nn]#colors_array_here[a]
        maxx = 900
        axp.set_xlim(0, maxx)
        # only the designated column (c_nn_nr) gets the bold condition header
        if c_nn == c_nn_nr:
            if a == 0:  #
                if second:
                    second_part = 'F1=' + str(np.round(freq1 + eodf)) + 'Hz' + ' F1-EODf=' + str(
                        freq1) + 'Hz'
                else:
                    second_part = ''
                if first:
                    first_part = 'only Frequency 1: '
                else:
                    first_part = ''
                axt.text(xpos, ypos, first_part + second_part, fontweight='bold', ha='center', fontsize=10,
                         transform=axt.transAxes, )
            elif a == 1:
                if second:
                    second_part = 'F2=' + str(np.round(freq2 + eodf)) + 'Hz ' + 'F2-EODf=' + str(
                        freq2) + ' Hz '
                else:
                    second_part = ''
                if first:
                    first_part = 'only Frequency 2: '
                else:
                    first_part = ''
                axt.text(xpos, ypos, first_part + second_part, fontweight='bold', ha='center', fontsize=10,
                         transform=axt.transAxes, )
            else:
                if second:
                    second_part = 'F1=' + str(np.round(freq1 + eodf)) + 'Hz' + ' F1-EODf=' + str(
                        freq1) + 'Hz' + ' F2=' + str(freq2 + eodf) + 'Hz ' + 'F2-EODf=' + str(freq2) + ' Hz '
                else:
                    second_part = ''
                if first:
                    first_part = 'Frequency 1 + Frequency 2: '
                else:
                    first_part = ''
                axt.text(xpos, ypos,
                         first_part + second_part,
                         fontweight='bold', ha='center', fontsize=10, transform=axt.transAxes, )
        # annotate the characteristic peaks of the '01' condition in the spectrum
        freqs, colors_peaks, labels, alphas = chose_all_freq_combos(freq2, colors_array, freq1, maxx, eodf,
                                                                    color_eodf='black', name=names_here[0],
                                                                    color_stim='pink', color_stim_mult='pink')
        plt_peaks_several(freqs, [pp], axp, pp, ff, labels, 0, colors_peaks, alphas=alphas, extend=extend, ms=18,
                          clip_on=True)
        # only the first column keeps y tick labels and y axis labels
        if c_nn != 0:
            remove_yticks(axt)
            remove_yticks(axs)
            remove_yticks(axp)
            remove_yticks(axi)
        else:
            axt.set_ylabel('mV')
            axp.set_ylabel('Hz')
            axi.set_ylabel('Nr')
        axt.set_xlabel('Time [s]')
        axi.set_xlabel('EOD mult')
        remove_xticks(axs)
        axp.set_xlabel('Frequency [Hz]')
    f_counter += 1
    return f_counter
def chose_all_freq_combos(freq2, colors_array, freq1, maxx, eodf, color_eodf='blue', stim_thing=True,
                          color_stim='orange', name='01',
                          color_stim_mult='orange'):
    """Return the peak frequencies, colors, labels and alphas to annotate a spectrum.

    Parameters
    ----------
    freq2, freq1 : float
        Difference frequencies DF2 and DF1.
    colors_array : sequence
        Color palette; entry 1 is used for '01' and entry 2 for '02' peaks.
    maxx : float
        Upper frequency limit forwarded to ``mult_beat_freqs`` ('01' only).
    eodf : float
        EOD frequency of the receiver.
    name : str
        '01' (multiples of |DF1| via ``mult_beat_freqs``), '02' (first ten
        harmonics of |DF2| plus EODf), 'eodf' (EODf only) or anything else
        (DF1, |DF2| and EODf).

    Returns
    -------
    tuple
        (freqs, colors_peaks, labels, alphas); element i of each list describes
        the same peak.
    """
    if name == '01':
        alphas, labels, colors_peaks, freqs = mult_beat_freqs(eodf, maxx, np.abs(freq1), color_df_mult=colors_array[1],
                                                              color_eodf=color_eodf,
                                                              color_stim=color_stim, stim_thing=stim_thing,
                                                              color_stim_mult=color_stim_mult, )
    elif name == '02':
        # harmonics 1..10 of |DF2| followed by EODf; previously 2*|DF2| was listed
        # twice and the color/label/alpha lists had different lengths than freqs
        n_harmonics = 10
        df2 = np.abs(freq2)
        freqs = [df2 * k for k in range(1, n_harmonics + 1)] + [eodf]
        colors_peaks = [colors_array[2]] * n_harmonics + ['black']
        labels = ['DF2'] * n_harmonics + ['EODF']
        alphas = [1] * n_harmonics + [0.5]
    elif name == 'eodf':
        freqs = [eodf]
        colors_peaks = ['black']
        labels = ['EODF']
        alphas = [1]
    else:
        freqs = [freq1, np.abs(freq2), eodf]
        colors_peaks = ['blue', 'red', 'black']
        labels = ['DF1', 'DF2', 'EODF']
        alphas = [1, 0.2, 0.5]
    return freqs, colors_peaks, labels, alphas
def find_double_spikes(eod_fr, arrays_spikes, names, results_diff, position_diff, add=''):
    """Summarize how many spikes fall into each EOD cycle.

    For each entry of ``arrays_spikes``, the spike times in its first element
    (``sp_array[0]``) are binned into bins one EOD period (1 / eod_fr) wide,
    starting at 0. Over the non-empty bins, the 95th percentile, maximum and
    mean count are written to ``results_diff`` at row ``position_diff`` under
    'dsp_perc95_', 'dsp_max_' and 'dsp_mean_' + names[a] + add. A spike train
    with no spikes in any bin yields NaN for all three.

    Returns the (mutated) ``results_diff``.
    """
    period = 1 / eod_fr
    for a, sp_array in enumerate(arrays_spikes):
        spike_times = np.asarray(sp_array[0])
        suffix = names[a] + add
        if spike_times.size == 0:
            counts = np.array([], dtype=int)
        else:
            # Extend the edges past the last spike: np.histogram ignores values
            # beyond the last edge, so the final spike would otherwise be dropped.
            n_bins = max(int(np.max(spike_times) // period) + 2, 1)
            edges = np.arange(n_bins + 1) * period
            hist, _ = np.histogram(spike_times, bins=edges)
            counts = hist[hist > 0]
        if counts.size == 0:
            perc95 = max_count = mean_count = np.nan
        else:
            perc95 = np.percentile(counts, 95)
            max_count = np.max(counts)
            mean_count = np.mean(counts)
        results_diff.loc[position_diff, 'dsp_perc95_' + suffix] = perc95
        results_diff.loc[position_diff, 'dsp_max_' + suffix] = max_count
        results_diff.loc[position_diff, 'dsp_mean_' + suffix] = mean_count
    return results_diff
def upper_and_lower_fr(array_smoothed, results_diff, position_diff, eod_fr, names, add=''):
    """Store the fraction of samples near zero, near eod_fr and above eod_fr.

    For each entry of ``array_smoothed``, its first element (presumably a
    firing-rate trace, compared with eod_fr) is split with a tolerance of
    10 into three fractions, written to ``results_diff`` at row ``position_diff``:
    'lb_'       — values below 10,
    'ub_'       — values strictly within (eod_fr - 10, eod_fr + 10),
    'ub_above_' — values at or above eod_fr + 10,
    each suffixed with names[a] + add. Returns the (mutated) ``results_diff``.
    """
    lim = 10
    for idx, array in enumerate(array_smoothed):
        rate = array[0]
        total = len(rate)
        key = names[idx] + add
        below = rate < lim
        around_eod = (rate < eod_fr + lim) & (rate > eod_fr - lim)
        above_eod = rate >= eod_fr + lim
        results_diff.loc[position_diff, 'lb_' + key] = len(rate[below]) / total
        results_diff.loc[position_diff, 'ub_' + key] = len(rate[around_eod]) / total
        results_diff.loc[position_diff, 'ub_above_' + key] = len(rate[above_eod]) / total
    return results_diff
def calc_vs_amps(results_diff, stim, eod_fr, arrays_spikes, position_diff, names, add=''):
    """Store vector strengths of spike trains at eod_fr and stim and at their second harmonics.

    For each frequency (eod_fr -> '_f0', stim -> '_f1') and each spike array
    (first element ``array[0]``), ``calc_vectorstrength`` is evaluated at the
    period 1 / f and at the second-harmonic period 1 / (2 f). The first
    element of each result is written to ``results_diff`` at row
    ``position_diff`` under 'vs_' and 'vs_harm_' + freq label + names[a] + add.
    Returns the (mutated) ``results_diff``.
    """
    freq_comp = ['_f0', '_f1']
    freqs = [eod_fr, stim]
    for freq_label, freq in zip(freq_comp, freqs):
        period = 1 / freq
        # Second harmonic (2 f) has half the period. The former expression
        # `1 / f * 2` parsed as (1 / f) * 2, i.e. twice the period.
        harmonic_period = 1 / (freq * 2)
        for a, array in enumerate(arrays_spikes):
            vs = calc_vectorstrength(array[0], period)
            results_diff.loc[position_diff, 'vs_' + freq_label + names[a] + add] = vs[0]
            vs = calc_vectorstrength(array[0], harmonic_period)
            results_diff.loc[position_diff, 'vs_harm_' + freq_label + names[a] + add] = vs[0]
    return results_diff
def plt_subpart_cocktail(results_diff, fs, p01, p02, p012, p0):
    """Plot the 0/01/02/012 power spectra in four stacked panels and show the figure.

    Characteristic peaks (|DF1|, |DF2|, their difference and sum, the mean
    baseline rate and the EODf, all taken from the last row of
    ``results_diff``) are marked with ``plt_peaks_several`` on the last
    panel only.
    """
    _, axes = plt.subplots(4, 1, sharex=True, sharey=True)  #
    spectra = [p0[0], p01[0], p02[0], p012[0]]
    for panel, spectrum in zip(axes, spectra):
        panel.plot(fs, spectrum)
    abs_b1 = np.abs(results_diff.df1.iloc[-1])
    abs_b2 = np.abs(results_diff.df2.iloc[-1])
    fr = results_diff.fr.iloc[-1]
    f0 = results_diff.f0.iloc[-1]
    peak_freqs = [abs_b1, abs_b2,
                  np.abs(abs_b1 - abs_b2),
                  abs_b1 + abs_b2, np.mean(fr), f0]
    peak_colors = ['blue', 'green', 'purple', 'orange', 'red', 'black']
    peak_labels = ['DF1', 'DF2', '|DF1-DF2|', '|DF1+DF2|', 'Baseline', 'eod_fr']
    plt_peaks_several(peak_freqs, spectra, axes[-1], spectra[-1], fs, peak_labels, 0, peak_colors)
    axes[-1].legend()
    plt.show()
def calc_roc_amp_core_cocktail(freq1, freq2, datapoints, auci_wo, auci_w, results_diff, a_f2s, fish_jammer, trials_nr,
nfft,
us_name, gain, runs, a_fr, nfft_for_morph, beat, printing,
stimulus_length, model_cells, position_diff, dev, cell_here,
params_dict={'burst_corr': ''}, stimulus_length_first=0, p_xlim=0, a_f1s=[],
dev_name=['05'], phaseshift_fr=0, min_amps='', n=1, reshuffled='', way_all='',
test=False, AUCI='AUCI', phase_right='_phaseright_', SAM='', points=5,
means_different=''): # '_means_'
model_params = model_cells[model_cells['cell'] == cell_here].iloc[0]
eod_fr = model_params['EODf'] # .iloc[0]
offset = model_params.pop('v_offset')
cell = model_params.pop('cell')
print(cell)
f1 = 0
f2 = 0
sampling_factor = ''
cell_recording = ''
mimick = 'no'
zeros = 'zeros'
fish_morph_harmonics_var = 'harmonic'
fish_emitter = 'Alepto' # ['Sternarchella', 'Sternopygus']
fish_receiver = 'Alepto' #
adapt_offset = 'adaptoffset_bisecting'
lower_tol = 0.995
upper_tol = 1.005
damping = 0.45 # 0.65,0.2,0.5,0.2,0.6,0.45,0.6,0.35
damping_type = ''
exponential = ''
time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr,
stimulus_length)
sampling = 1 / deltat
variant = 'sinz'
if exponential == '':
pass
_, _ = prepare_baseline_array(time_array, eod_fr, nfft_for_morph,
phaseshift_fr, mimick, zeros,
cell_recording, sampling,
stimulus_length, fish_receiver, deltat,
nfft, damping_type, damping, us_name,
gain, beat=beat,
fish_morph_harmonics_var=fish_morph_harmonics_var)
spikes_base = [[]] * trials_nr
spikes_bef = [[]] * trials_nr
for run in range(runs):
print(run)
stim_lengths = []
for t in range(trials_nr):
if (stimulus_length_first != 0) & (t == 0):
stimulus_length_here = stimulus_length_first
else:
stimulus_length_here = stimulus_length
stim_lengths.append(stimulus_length_here)
if (t == 0) | (type(phaseshift_fr) == str):
time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr,
stimulus_length_here, deltat=deltat)
eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr,
stimulus_length_here, phaseshift_fr,
cell_recording, zeros, mimick, sampling,
fish_receiver, deltat, nfft,
nfft_for_morph,
fish_morph_harmonics_var=fish_morph_harmonics_var,
beat=beat)
if (stimulus_length_first != 0) & (t == 1):
time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr,
stimulus_length_here, deltat=deltat)
eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr,
stimulus_length_here, phaseshift_fr,
cell_recording, zeros, mimick, sampling,
fish_receiver, deltat, nfft,
nfft_for_morph,
fish_morph_harmonics_var=fish_morph_harmonics_var,
beat=beat)
# baseline_after,spikes_base,rate_adapted, rate_baseline_before, rate_baseline_after, np.array(spike_times), stimulus_power, v_dent_output[int(0.05 / deltat):-1], offset, v_mem_output
stimulus_0 = eod_fish_r
power_here = 'sinz'
adapt_offset = 'adaptoffset_bisecting'
cvs, adapt_output, baseline_after_b, _, rate_adapted_b, rate_baseline_before_b, rate_baseline_after_b, \
spikes_bef[t], _, _, offset_new, v_mem0, noise_final = simulate(cell, offset, stimulus_0, deltat=deltat,
adaptation_variant=adapt_offset,
adaptation_yes_j=f2, adaptation_yes_e=f1,
adaptation_yes_t=t,
adaptation_upper_tol=upper_tol,
adaptation_lower_tol=lower_tol,
power_variant=power_here, power_alpha=alpha,
power_nr=n, reshuffle=reshuffled,
**model_params)
isi = calc_isi(spikes_bef[t], eod_fr)
try:
spikes_base[t] = spikes_after_burst_corr(spikes_bef[t], isi, params_dict['burst_corr'], cell, eod_fr,
model_params=model_params)
except:
print('assing spikes problem')
print(' offset orig ' + str(offset))
test = False
if test:
from utils_test import test_cvs3
test_cvs3()
if t == 0:
# here we record the changes in the offset due to the adaptation
# and we subsequently reset the offset to be the new adapted for all subsequent trials
offset = offset_new * 1
print(' Base ' + str(adapt_offset) + ' offset ' + str(offset))
if printing:
print('Baseline time' + str(time.time() - t1))
if test:
fig, ax = plt.subplots(2, 1, sharex=True)
ax[0].eventplot(spikes_bef[t], color='red')
ax[1].eventplot(spikes_base[t])
ax[1].eventplot(spikes_bef[t], color='red')
ax[1].set_xlim(0, 0.1)
plt.show()
sampling_rate = 1 / deltat
stim_length_max = np.max([stimulus_length_first, stimulus_length])
base_cut, mat_base, smoothed0, mat0 = find_base_fr2(spikes_base, deltat, stim_length_max,
dev=dev)
fr = np.mean(base_cut)
_, _ = ISI_frequency(time_array, spikes_base[0], fill=0.0)
for aaa, a_f2 in enumerate(a_f2s): # [0]
for aa, a_f1 in enumerate(a_f1s): # [0]
eod_fish1, eod_fish2, eod_stimulus, t1 = stimulus_threefish(a_f1, a_f2, cell_recording, f1, f2,
fish_emitter, fish_jammer, freq1, freq2,
fish_morph_harmonics_var, mimick,
nfft_for_morph, phase_right, phaseshift_fr,
sampling, stimulus_length, time_array,
zeros)
if test:
from utils_test import test_timearray
test_timearray()
adapt_offset_later = ''
v_mem, offset_new, mat01, mat02, mat012, smoothed01, smoothed02, smoothed012, stimulus_01, stimulus_02, stimulus_012, mat05_01, spikes_01, mat05_02, spikes_02, mat05_012, spikes_012 = get_arrays_for_three(
cell, a_f2, a_f1, SAM, eod_stimulus, eod_fish_r, freq2, eod_fish1, eod_fish2, stimulus_length,
offset, model_params, n, variant, adapt_offset_later, deltat, f2, trials_nr, time_array, f1,
freq1, eod_fr, reshuffle=reshuffled, dev=dev, zeros=zeros, a_fr=a_fr, params_dict=params_dict,
redo_stim=True, nfft=nfft, cell_recording=cell_recording, phaseshift_fr=phaseshift_fr,
fish_emitter=fish_emitter, fish_receiver=fish_receiver, beat=beat, fish_jammer=fish_jammer,
fish_morph_harmonics_var=fish_morph_harmonics_var, mimick=mimick, nfft_for_morph=nfft_for_morph,
phase_right=phase_right, sampling=sampling)
if test:
fig, ax = plt.subplots(4, 1, sharex=True, sharey=True)
ax[0].plot(stimulus_0)
ax[1].plot(stimulus_01)
ax[2].plot(stimulus_02)
ax[3].plot(stimulus_012)
plt.show()
if printing:
print('Generation process' + str(time.time() - t1))
v_mems = np.concatenate([[v_mem0], v_mem])
for dev_n in dev_name:
results_diff.loc[position_diff, 'fr'] = fr
results_diff.loc[position_diff, 'f1'] = freq1[0]
results_diff.loc[position_diff, 'f2'] = freq2[0]
results_diff.loc[position_diff, 'f0'] = eod_fr
results_diff.loc[position_diff, 'df1'] = np.abs(eod_fr - freq1)
results_diff.loc[position_diff, 'df2'] = np.abs(eod_fr - freq2)
results_diff.loc[position_diff, 'cell'] = cell
results_diff.loc[position_diff, 'c1'] = a_f1
results_diff.loc[position_diff, 'c2'] = a_f2
results_diff.loc[position_diff, 'trial_nr'] = trials_nr
results_diff = calc_cv_three_wave(results_diff, position_diff,
arrays=[spikes_base, spikes_01, spikes_02, spikes_012],
adds=['_0', '_01', '_02', '_012'])
if dev_n == '05':
dev = 0.0005
# test wie ich die std und var und psd eigentlich gruppieren müsste
if dev_n == 'original':
array0 = [np.mean(mat0, axis=0)]
array01 = [np.mean(mat01, axis=0)]
array02 = [np.mean(mat02, axis=0)]
array012 = [np.mean(mat012, axis=0)]
elif dev_n == '05':
array0 = [np.mean(smoothed0, axis=0)]
array01 = [np.mean(smoothed01, axis=0)]
array02 = [np.mean(smoothed02, axis=0)]
array012 = [np.mean(smoothed012, axis=0)]
p0, p02, p01, p012, ff = calc_ps(nfft, [array012[0][time_array > p_xlim]],
[array01[0][time_array > p_xlim]],
[array02[0][time_array > p_xlim]],
[array0[0][time_array > p_xlim]],
sampling_rate=sampling_rate)
if test:
plt_subpart_cocktail(results_diff, ff, p01, p02, p012, p0)
########################################################
# also hier sind eben diese ganzen Amplituden
# 'amp_max_012_mean' ,'amp_max_02_mean', 'amp_max_01_mean', 'amp_max_0_mean' ist zwangsweise das was wir suchen
# 'amp_max_012_mean', 'amp_max_02_mean', 'amp_max_01_mean'
# 'amp_B2_012_mean','amp_B1_012_mean' 'amp_B2_02_mean','amp_B1_01_mean'
# das ganze ist zwangsweise auf dem gemittelten Arrays
# in dieser einen Version mache ich das ohne sqrt hier und dann
if np.isnan(results_diff.loc[position_diff, 'f0']):
print('isnan thing4')
embed()
results_diff = calc_amps(ff, p0, p02, p01, p012, position_diff,
[dev], 0, results_diff, results_diff,
add='_mean' + '_' + dev_n, timesstamp=False, min_amps=min_amps,
points=points)
printing = False
if printing:
print(' a_f1 ' + str(aa) + ' ' + str(adapt_offset) + ' offset ' + str(offset) + ' time ' + str(
time.time() - t1))
#######################################
# here calculate the fft
# die arrays hier sind immer eindimmensional deswegen muss man hier nicht auf trials achten!
# embed()
# das ist das was wir vergleichen
ffts_right1, freq = calc_fft(array0, array01, array012, array02, deltat, sampling)
results_diff.loc[position_diff, 'diff_fft' + '_' + str(dev_n)] = np.sum(ffts_right1) * \
freq[1]
####################################################################
arrays_stim = [stimulus_0, stimulus_01, stimulus_02, stimulus_012]
arrays = [array0, array01, array02, array012]
arrays_spikes = [spikes_base, spikes_01, spikes_02, spikes_012]
names = ['0', '01', '02', '012']
for a, array in enumerate(arrays):
results_diff.loc[position_diff, 'std_' + names[a] + '_' + dev_n] = np.std(array)
results_diff.loc[position_diff, 'var_' + names[a] + '_' + dev_n] = np.var(array)
names_saved = ['var', 'std']
for name_saved in names_saved:
results_diff = calculate_the_difference(position_diff, results_diff, name_saved, dev_n,
results_diff.loc[
position_diff, name_saved + '_012' + '_' + dev_n],
results_diff.loc[
position_diff, name_saved + '_01' + '_' + dev_n],
results_diff.loc[
position_diff, name_saved + '_02' + '_' + dev_n],
results_diff.loc[
position_diff, name_saved + '_0' + '_' + dev_n])
test = False
if test:
fig, ax = plt.subplots(4, 1)
ax[0].plot(time,
results_diff.loc[position_diff, name_saved + '_0' + '_' + dev_n])
ax[1].plot(time,
results_diff.loc[position_diff, name_saved + '_01' + '_' + dev_n])
ax[2].plot(time,
results_diff.loc[position_diff, name_saved + '_02' + '_' + dev_n])
ax[3].plot(time, results_diff.loc[position_diff, name_saved + '_012' + '_' + dev_n])
plt.show()
##########################################
# hier ist eine extra kondition falls wir das ohne vorheriges Mitteln vergleichen wollen würden!
# die brauchen wir im default nicht aber zum abgleichen was was ist ist
# d asmanchmal ganz nett
# if (trials != 1 ) & ('_means_' & means_differnet): # für die verschiedenen Trials wollen wir verschiedene Konditions einführen
# embed()
if (trials_nr != 1) & (
'_means_' in means_different): # calc_model_amp_freqs_param. # für die verschiedenen Trials wollen wir verschiedene Konditions einführen
if dev_n == 'original':
array0 = mat0
array01 = mat01
array02 = mat02
array012 = mat012 # , axis = 0
elif dev_n == '05':
array0 = smoothed0
array01 = smoothed01
array02 = smoothed02
array012 = smoothed012
p0, p02, p01, p012, ff = calc_ps(nfft, array012,
array01,
array02,
array0,
sampling_rate=sampling_rate)
results_diff = calc_amps(ff, p0, p02, p01, p012, position_diff,
[dev], 0, results_diff, results_diff,
add='' + '_' + dev_n, timesstamp=False, min_amps=min_amps,
points=points)
#######################################
# here calculate the fft
# die arrays hier sind immer eindimmensional deswegen muss man hier nicht auf trials achten!
# embed()
ffts_right1, freq = calc_fft(array0, array01, array012, array02, deltat, sampling)
results_diff.loc[position_diff, 'diff_mean(fft)' + '_' + str(dev_n)] = np.sum(ffts_right1) * \
freq[1]
arrays_stim = [stimulus_0, stimulus_01, stimulus_02, stimulus_012]
arrays = [array0, array01, array02, array012]
arrays_spikes = [spikes_base, spikes_01, spikes_02, spikes_012]
names = [cl_3names.c0, cl_3names.c01, cl_3names.c02, cl_3names.c012, ]
for a, array in enumerate(arrays):
results_diff.loc[position_diff, 'mean(std)_' + names[a] + '_' + str(dev_n)] = np.mean(
np.std(array, axis=1))
results_diff.loc[position_diff, 'mean(var)_' + names[a] + '_' + str(dev_n)] = np.mean(
np.var(array, axis=1))
names_saved = ['mean(var)', 'mean(std)']
for name_saved in names_saved:
results_diff = calculate_the_difference(position_diff, results_diff, name_saved, dev_n,
results_diff.loc[
position_diff, name_saved + '_012' + '_' + dev_n],
results_diff.loc[
position_diff, name_saved + '_01' + '_' + dev_n],
results_diff.loc[
position_diff, name_saved + '_02' + '_' + dev_n],
results_diff.loc[
position_diff, name_saved + '_0' + '_' + dev_n])
##################################################
# für den Fall dass ich das Testen will und das zurück transferieren will
# zum testen von dem phase sorting algorithmus für die Daten!
if params_dict['phase_undo'] == True:
embed()
mean_type = 'MeanTrialsIndexPhaseSort_Min0.25sExcluded'
_, _, _, _ = phase_sort_and_cut(mean_type, frame,
synaptic_flt_analysis,
t, sampling_time,
sorted_on=sorted_on) # not there
# todo: der Teil ist halt noch nicht fürs Mitteln ausgelegt, aber vielleich tbrauche ich das auch nicht
else:
p_array = []
##########################################
# diese Mehrfachen berechnen für eine
# das hier ist so eine Analyse um, die Zenter zu berechnen und zwar ohne die Beats erstmal
# ich denke 4 Harmonische reichen da schon nicht whar?
# hier muss man noch einbau dass das davon abhängt ob c1 oder c2 variert!
if dev_n == 'original':
array0 = [np.mean(mat0, axis=0)]
array01 = [np.mean(mat01, axis=0)]
array02 = [np.mean(mat02, axis=0)]
array012 = [np.mean(mat012, axis=0)]
elif dev_n == '05':
array0 = [np.mean(smoothed0, axis=0)]
array01 = [np.mean(smoothed01, axis=0)]
array02 = [np.mean(smoothed02, axis=0)]
array012 = [np.mean(smoothed012, axis=0)]
B1 = results_diff.loc[position_diff, 'B1']
if B1 != 0:
try:
beats_range = np.arange(np.abs(B1), 1000,
np.abs(B1))
except:
print('B1 thing')
embed()
idx = []
for beat in beats_range:
idx.append(np.argmin(np.abs(ff - beat)))
names = ['0', '01', '02', '012']
p_array = [p0, p01, p02, p012]
for n_nr, name in enumerate(names):
vals = np.sqrt(np.mean(p_array[n_nr][0][idx] * ff[1]))
results_diff.loc[position_diff, 'B1_harms_all_mean_' + name + '_' + dev_n] = vals
vals = np.sqrt(np.sum(p_array[n_nr][0][idx] * ff[1]))
results_diff.loc[position_diff, 'B1_harms_all_sum_' + name + '_' + dev_n] = vals
nrs_excluded = [0, 1, 2, 3, 4]
for nr_excluded in nrs_excluded:
if len(idx) > nr_excluded:
vals_bef = p_array[n_nr][0][idx[nr_excluded::]] * ff[1]
vals = np.sqrt(np.mean(vals_bef))
results_diff.loc[position_diff, 'B1_harms_' + str(
nr_excluded) + '_all_mean_' + name + '_' + dev_n] = vals
vals = np.sqrt(np.sum(vals_bef))
results_diff.loc[position_diff, 'B1_harms_' + str(
nr_excluded) + '_all_mean_' + name + '_' + dev_n] = vals
results_diff.loc[position_diff, 'B1_harms_' + str(
nr_excluded) + '_all_center_' + name + '_' + dev_n] = ff[
idx[nr_excluded::][np.argmax(vals_bef)]]
# nochmal die Vector Strength berechnen
results_diff = calc_vs_amps(results_diff, freq1[0], eod_fr, arrays_spikes, position_diff, names,
add='_' + dev_n)
######################################
# upper and lower bound berechnen
array_smoothed = [smoothed0, smoothed01, smoothed02, smoothed012]
names = ['0', '01', '02', '012']
results_diff = upper_and_lower_fr(array_smoothed, results_diff, position_diff, eod_fr, names,
add='') # ''
######################################
# hist für doppelte spikes vom Phase Locking
try:
results_diff = find_double_spikes(eod_fr, arrays_spikes, names, results_diff, position_diff,
add='_' + dev_n)
except:
print('double spikes problem')
embed()
######################################
# phasen zu dem EOD
if 'AUCI' in AUCI:
add = '_' + way_all + str(datapoints) + '_' + dev_n
trials, results_diff, tp_012_all, tp_01_all, tp_02_all, fp_all, roc_01, base_here, roc_02, roc_012, threshhold = calc_auci_pd(
results_diff, position_diff, array012, array01, array02, array0, add=add, t_off=5,
way=way_all, datapoints=datapoints, f0='f0')
position_diff += 1
try:
v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays, names, p_array, ff
except:
print('missing')
embed()
return v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays, names, p_array, ff
def calc_fft(array0, array01, array012, array02, deltat, sampling):
    """Reduce the FFTs of the four condition responses to one spectrum and return it with its frequency axis.

    Note the unusual positional order of the arrays: 0, 01, 012, 02.

    The responses are handed to ``calc_FFT3`` (labelled with the ``cl_3names`` constants, in the order
    012, 01, 02, 0) and the per-condition result is collapsed by ``equal_to_temporal_mean``. What exactly that
    combination is (presumably the 012 - 01 - 02 + 0 difference that callers integrate into a 'diff_fft'
    score) is defined in those helpers -- TODO confirm there.

    Returns
    -------
    ffts_right1
        The combined spectrum returned by ``equal_to_temporal_mean``.
    freq : numpy.ndarray
        ``np.fft.fftfreq(len(ffts_right1), d=1 / sampling)``; this treats ``sampling`` as a sampling rate.
    """
    labelled = [(cl_3names.c012, array012),
                (cl_3names.c01, array01),
                (cl_3names.c02, array02),
                (cl_3names.c0, array0)]
    labels = [label for label, _ in labelled]
    signals = [signal for _, signal in labelled]
    per_condition = calc_FFT3(signals, deltat, {}, labels)
    combined = equal_to_temporal_mean(per_condition)
    return combined, np.fft.fftfreq(len(combined), d=1 / sampling)
def phase_sort_and_cut(mean_type, frame, synaptic_flt_orig, t, sampling_time, sorted_on='local_reconst_big_norm'):
    """Cut the trials of the four stimulus conditions and wrap them into groups for averaging.

    Parameters
    ----------
    mean_type : str
        Analysis variant. If it contains ``'DetectionAnalysis'`` nothing is cut and empty placeholders are
        returned. If it contains ``'extended'`` the trials are additionally regrouped with ``find_group_variants``.
    frame, sorted_on
        Passed to ``define_delays_trials`` / ``cut_uneven_trials``.
    synaptic_flt_orig, t
        ``synaptic_flt_orig[t]`` is the signal that gets cut.
    sampling_time
        Forwarded as ``sampling=`` to ``cut_uneven_trials``.

    Returns
    -------
    tuple of four lists, in the order (012, 02, 01, 0).
        NOTE: this is not the (012, 01, 02, 0) order in which the ``cut_uneven_trials`` result is unpacked;
        callers must unpack accordingly.
    """
    if 'DetectionAnalysis' in mean_type:
        # no cutting for detection analyses: one empty placeholder group per condition
        return [[[]]], [[[]]], [[[]]], [[[]]]

    delays_length = define_delays_trials(mean_type, frame, sorted_on=sorted_on)
    cut_012, cut_01, cut_02, cut_0 = cut_uneven_trials(
        frame, synaptic_flt_orig[t], mean_type, delays_length, sampling=sampling_time)

    if 'extended' not in mean_type:
        # DEFAULT: a single group holding all trials of each condition
        return [[cut_012]], [[cut_02]], [[cut_01]], [[cut_0]]

    # 'extended': the same trials are regrouped several times (per the original note e.g. 2*15, 3*10 and
    # finally 1*30), mainly for the ROC analysis and for averaging.
    groups_0 = find_group_variants(cut_0, [])
    groups_01 = find_group_variants(cut_01, [])
    groups_02 = find_group_variants(cut_02, [])
    groups_012 = find_group_variants(cut_012, [])
    print('extended ' + str(len(groups_0)))
    if len(groups_0) > 0:
        # debugging hook left in by the author
        print('extended thing')
        embed()
    # append the complete set of trials as one more grouping variant
    for groups, trials in ((groups_0, cut_0), (groups_01, cut_01), (groups_02, cut_02), (groups_012, cut_012)):
        groups.append([trials])
    return groups_012, groups_02, groups_01, groups_0
def calc_amp_value(names, freq_step, ff, ps, position_diff, fs, devname, t, results_diff, name, add, fish, points=5):
    """Average one spectral feature over trials and store it in a copy of ``results_diff``.

    For condition ``ff`` (index into ``ps``, labelled ``fish``), every trial spectrum ``ps[ff][trial]``
    that contains no NaN contributes one value. The values are averaged over the valid trials and written to
    row ``position_diff``. The feature depends on ``name``:

    * ``name == ''``: the integrated spectrum ``sum(spectrum) * fs[1]`` -> column ``'amp_' + fish + add``.
    * ``'max' in name``: the frequency of the spectral maximum -> column ``'f_' + name + fish + add``
      (``'amp_' + name + fish + add`` is set to 0). For ``devname[t]`` in ``('original', '_original', '_eod')``
      the search is restricted to ``fs < 0.5 * EODf + fs[1]`` (column ``EODf``, falling back to ``f0``).
      Otherwise, for names containing ``'harm'``, the argmax index is doubled. An index beyond ``fs`` contributes NaN.
    * otherwise: the power in a window of ``points`` bins around the bin closest to ``names[name]``, multiplied
      by ``freq_step`` -> column ``'amp_' + name + fish + add``. The window is clipped at the low edge of the
      spectrum (a negative start index would otherwise wrap around and give an empty slice).

    If ``str(names[name]) == 'no'`` (the peak is not available), nothing is written.

    Parameters
    ----------
    names : dict
        Peak label -> target frequency (e.g. from ``peaks_of_interest``).
    freq_step : float
        Bin width used to scale the windowed power.
    ff : int
        Index of the condition in ``ps``.
    ps : sequence
        ``ps[ff]`` is a sequence of per-trial spectra aligned with ``fs``.
    position_diff
        Row label in ``results_diff``.
    fs : numpy.ndarray
        Frequency axis of the spectra.
    devname, t
        ``devname[t]`` selects the search range in the ``'max'`` mode.
    results_diff : pandas.DataFrame
        Results table; not modified (a copy is returned).
    name, add, fish : str
        Parts of the output column names.
    points : int
        Width of the summation window in bins (default 5; 1, 3 and 5 reproduce the previous behaviour).

    Returns
    -------
    pandas.DataFrame
        The updated copy of ``results_diff``.
    """
    results_diff = results_diff.copy()
    printing = False
    if str(names[name]) == 'no':
        # e.g. the baseline condition is not always available: leave the table untouched
        return results_diff

    amp_col = 'amp_' + name + fish + add
    f_col = 'f_' + name + fish + add
    is_max = 'max' in name
    n_valid = 0
    for spectrum in ps[ff]:
        if np.isnan(spectrum).any():
            # trials containing NaN are excluded from the average
            continue
        if n_valid == 0:
            # initialise the accumulators at the first *usable* trial (trial 0 may have been skipped)
            if printing:
                print('started')
            results_diff.loc[position_diff, amp_col] = 0.0
            if is_max:
                results_diff.loc[position_diff, f_col] = 0.0
        n_valid += 1

        if name == '':
            if printing:
                print('name == ')
            results_diff.loc[position_diff, amp_col] += np.sum(spectrum) * fs[1]
        elif is_max:
            if printing:
                print('max')
            if devname[t] in ('original', '_original', '_eod'):
                try:
                    arg = np.argmax(spectrum[fs < 0.5 * results_diff.EODf.loc[position_diff] + fs[1]])
                except Exception:
                    try:
                        arg = np.argmax(spectrum[fs < 0.5 * results_diff.f0.loc[position_diff] + fs[1]])
                    except Exception:
                        print('arg stuff')
                        embed()
            elif 'harm' in name:
                arg = np.argmax(spectrum) * 2
            else:
                arg = np.argmax(spectrum)
            if arg < len(fs):
                try:
                    results_diff.loc[position_diff, f_col] += fs[arg]
                except Exception:
                    print('results diff problems')
                    embed()
            else:
                # out of range in any trial: poison the average with NaN rather than a wrong frequency
                if printing:
                    print('else')
                results_diff.loc[position_diff, f_col] += float('nan')
        else:
            try:
                arg = np.argmin(np.abs(fs - names[name]))
            except Exception:
                print('arg something')
                embed()
            try:
                if printing:
                    print('val')
                # For the phase-locking peaks only one window per peak is used to avoid overlapping
                # neighbouring peaks; with a fine enough resolution this is sufficient.
                if arg < len(spectrum):
                    start = arg - (points - 1) // 2
                    stop = start + points
                    # clip at 0: a negative start would wrap around and produce an empty slice
                    val = np.sum(spectrum[max(start, 0):stop] * freq_step)
                    results_diff.loc[position_diff, amp_col] += val
            except Exception:
                print('calc_ amp')
                embed()

    # the values above were summed; now divide by the number of valid trials
    if n_valid:
        if printing:
            print('div')
        try:
            results_diff.loc[position_diff, amp_col] = results_diff.loc[position_diff, amp_col] / n_valid
        except Exception:
            print('amp problem')
            embed()
        if is_max:
            results_diff.loc[position_diff, f_col] = results_diff.loc[position_diff, f_col] / n_valid
    return results_diff
def peaks_of_interest(df1, df2, beat1, beat2, fr, f1, f2, eod_fr, min_amps=''):
    """Return an ordered dict mapping a peak label to the frequency at which that peak is expected.

    If ``min_amps`` contains ``'min'`` only a reduced set is returned: the beats ``B1_``/``B2_``, their
    sum/difference, ``f0_`` and its harmonic, ``f1_``/``f2_``, ``env_beat_`` and ``fr_``. Otherwise the full
    catalogue of difference frequencies, harmonics and combinations with ``fr`` and ``eod_fr`` is returned.
    The label order is preserved because downstream code creates columns in this order.
    """

    def envelope_beat():
        # |beat(f0 - f1)| - |beat(f0 - f2)| (absolute), both beats computed with create_beat_corr
        first = np.abs(create_beat_corr(eod_fr - f1, np.array([eod_fr])))
        second = np.abs(create_beat_corr(eod_fr - f2, np.array([eod_fr])))
        return np.abs(first - second)

    if 'min' in min_amps:
        # reduced set: in most analyses only the beat peaks B1 and B2 are of interest
        return {
            'B1_': np.abs(beat1),
            'B2_': np.abs(beat2),
            'B1-B2_': np.abs(beat1 - beat2),
            'B2-B1_': np.abs(beat2 - beat1),
            'B2+B1_': np.abs(beat2 + beat1),
            'B1+B2_': np.abs(beat2 + beat1),
            'f0_': np.abs(eod_fr),
            'f0_harm_': np.abs(eod_fr) * 2,
            'f1_': f1,
            'f2_': f2,
            'env_beat_': envelope_beat(),
            'fr_': fr,
        }

    peaks = {}
    # difference frequencies and their first harmonics
    peaks.update({'DeltaF1_': np.abs(df1),
                  'DeltaF2_': np.abs(df2),
                  'DeltaF1_harm_': np.abs(df1) * 2,
                  'DeltaF2_harm_': np.abs(df2) * 2})
    # beat harmonics (2nd, 3rd, 4th multiple)
    peaks.update({'B1_harm_': np.abs(beat1) * 2,
                  'B2_harm_': np.abs(beat2) * 2,
                  'B1_2harm_': np.abs(beat1) * 3,
                  'B2_2harm_': np.abs(beat2) * 3,
                  'B1_3harm_': np.abs(beat1) * 4,
                  'B2_3harm_': np.abs(beat2) * 4})
    # sum and differences of the two stimulus frequencies
    peaks.update({'F2+F1_': np.abs(f2 + f1),
                  'F1-F2_': np.abs(f1 - f2),
                  'F2-F1_': np.abs(f2 - f1)})
    # beats and their combinations
    peaks.update({'B1_': np.abs(beat1),
                  'B2_': np.abs(beat2),
                  'B1-B2_': np.abs(beat1 - beat2),
                  'B2-B1_': np.abs(beat2 - beat1),
                  'B2+B1_': np.abs(beat2 + beat1),
                  'B1+B2_': np.abs(beat2 + beat1)})
    # firing rate shifted by the beats
    peaks.update({'fr-B2_': np.abs(fr - beat2),
                  'fr-B1_': np.abs(fr - beat1),
                  'fr-(B2+B1)_': np.abs(fr - (beat2 + beat1)),
                  'fr-(B1-B2)_': np.abs(fr - np.abs(beat1 - beat2)),
                  'fr-(B2-B1)_': np.abs(fr - np.abs(beat2 - beat1)),
                  'fr+B2_': np.abs(fr + beat2),
                  'fr+B1_': np.abs(fr + beat1),
                  'fr+(B2+B1)_': np.abs(fr + (beat2 + beat1)),
                  'fr+(B1-B2)_': np.abs(fr + np.abs(beat1 - beat2)),
                  'fr+(B2-B1)_': np.abs(fr + np.abs(beat2 - beat1))})
    # EOD frequency shifted by the beats ('f0_' itself sits in the middle of this group)
    peaks.update({'f0-B2_': np.abs(eod_fr - beat2),
                  'f0-B1_': np.abs(eod_fr - beat1),
                  'f0-(B2+B1)_': np.abs(eod_fr - (beat2 + beat1)),
                  'f0-(B1-B2)_': np.abs(eod_fr - np.abs(beat1 - beat2)),
                  'f0-(B2-B1)_': np.abs(eod_fr - np.abs(beat2 - beat1)),
                  'f0+B2_': np.abs(eod_fr + beat2),
                  'f0+B1_': np.abs(eod_fr + beat1),
                  'f0_': np.abs(eod_fr),
                  'f0+(B2+B1)_': np.abs(eod_fr + (beat2 + beat1)),
                  'f0+(B1-B2)_': np.abs(eod_fr + np.abs(beat1 - beat2)),
                  'f0+(B2-B1)_': np.abs(eod_fr + np.abs(beat2 - beat1))})
    # stimulus frequencies and their harmonics
    peaks.update({'f1_': f1,
                  'f2_': f2,
                  'f1_harm_': f1 * 2,
                  'f2_harm_': f2 * 2})
    # envelopes
    peaks.update({'env_': np.abs(np.abs(df1) - np.abs(df2)),
                  'env_beat_': envelope_beat(),
                  'env_beat_beatf0_': create_beat_corr(envelope_beat(), np.array([eod_fr]))})
    # firing rate and its harmonic
    peaks.update({'fr_': fr,
                  'fr_harm_': fr * 2})
    return peaks
def plt_calc_amps(results, p0, p01, p02, p012, frame, fs):
    """Debug plot: the four condition spectra and their differences, with the expected beat peaks marked.

    Eight panels (2 x 4): 0, 01, 02, 012, 012-02, 012-01, 012-01-02 and 012-01-02+0, each using the first
    entry of the given spectra. Peak positions are taken from the first row of ``results`` (columns
    ``df1``, ``df2``, ``fr``); the figure title is ``frame.cell.iloc[0]``. Blocks on ``plt.show()``.
    """
    fig, axes = plt.subplots(2, 4, sharex=True, sharey=True)
    axes = axes.flatten()
    base, single1, single2, combined = p0[0], p01[0], p02[0], p012[0]
    spectra = [base, single1, single2, combined,
               combined - single2,
               combined - single1,
               combined - single2 - single1,
               combined - single2 - single1 + base]
    titles = ['0', '01', '02', '012', '012-02', '012-01', '012-01-02', '012-01-02+0', ]
    plt.suptitle(frame.cell.iloc[0])

    # peak positions do not depend on the panel, so read them once
    beat1 = results.df1.iloc[0]
    beat2 = results.df2.iloc[0]
    rate = results.fr.iloc[0]

    for axis, spectrum, title in zip(axes, spectra, titles):
        axis.plot(fs, spectrum, color='black')
        axis.set_title(title)
        axis.set_xlim(0, 1000)
        peak_freqs = [np.abs(beat2),
                      np.abs(np.abs(np.abs(beat1) - np.abs(beat2))),
                      np.abs(np.abs(beat1) + np.abs(beat2)),
                      np.mean(rate),
                      np.abs(beat1),
                      np.abs(beat1) * 2,
                      np.abs(beat1) * 3,
                      np.abs(beat1) * 4, ]
        peak_colors = ['green', 'purple', 'orange', 'red', 'blue', 'blue', 'blue', 'blue', ]
        peak_labels = ['DF2', '|DF1-DF2|', '|DF1+DF2|', 'Baseline', 'DF1', 'DF1', 'DF1', 'DF1']
        plt_peaks_several(peak_freqs, spectra, axis, spectrum, fs, peak_labels, 0, peak_colors, alpha=0.5)
        axis.set_xlim([0, 400])
    axes[-1].legend(loc=(1, 0), ncol=1)
    plt.subplots_adjust(right=0.8)
    plt.show()
def calc_pure_amps_diffs(frame, pos, names, fishes, freq_step, ps, fs, devname, t, add, points=5):
    """Compute per-condition peak amplitudes and their difference scores for every peak in ``names``.

    For each peak label ``name`` and each condition ``fish`` in ``fishes``, ``calc_amp_value`` stores
    ``'amp_' + name + fish + add``. Then the following differences are written to row ``pos``, each only when
    every column it reads exists (a peak marked ``'no'`` or a missing condition is skipped instead of raising
    a KeyError):

    * ``012-01``    = 012 - 01
    * ``012-02``    = 012 - 02
    * ``diff``      = 012 - 01 - 02 + 0
    * ``012-02-01`` = 012 - 02 - 01

    Only peak amplitudes that are variances can be summed meaningfully; in the default case these are the
    values returned by ``calc_amp_value``.

    Returns
    -------
    pandas.DataFrame
        The updated frame (``calc_amp_value`` returns copies, so the caller's object is not modified in place).
    """
    for name in names:
        # single-condition amplitudes (also the std-like values, in Hz)
        for ff, fish in enumerate(fishes):
            frame = calc_amp_value(names, freq_step, ff, ps, pos, fs, devname, t, frame, name, add, fish,
                                   points=points)

        col = {cond: 'amp_' + name + cond + add for cond in ('012', '01', '02', '0')}
        available = {cond for cond, column in col.items() if column in frame.columns}

        # difference scores; each requires all of the columns it reads
        if {'012', '01'} <= available:
            frame.loc[pos, 'amp_' + name + '012-01' + add] = frame.loc[pos, col['012']] - frame.loc[pos, col['01']]
        if {'012', '02'} <= available:
            frame.loc[pos, 'amp_' + name + '012-02' + add] = frame.loc[pos, col['012']] - frame.loc[pos, col['02']]
        if {'012', '01', '02', '0'} <= available:
            frame.loc[pos, 'amp_' + name + 'diff' + add] = (frame.loc[pos, col['012']]
                                                          - frame.loc[pos, col['01']]
                                                          - frame.loc[pos, col['02']]
                                                          + frame.loc[pos, col['0']])
        if {'012', '01', '02'} <= available:
            frame.loc[pos, 'amp_' + name + '012-02-01' + add] = (frame.loc[pos, col['012']]
                                                                - frame.loc[pos, col['02']]
                                                                - frame.loc[pos, col['01']])
    return frame
def find_B1B2_norm_amp_diffs(frame, norms_name, norms, pos, add):
    """Store normalised combination scores of the beat peaks in row ``pos`` of ``frame``.

    Reads the amplitude columns ``'amp_' + peak + cond + add`` for the peaks ``B1_``, ``B2_``, their harmonics
    (``harm_``, ``2harm_``, ``3harm_``), ``B1-B2_`` and ``B1+B2_`` in the conditions ``'012'``, ``'01'``,
    ``'02'`` and ``'0'`` (presumably both fish, fish 1 only, fish 2 only, baseline). For every normalisation
    ``norms[n]`` (column suffix ``norms_name[n]``) it writes scores such as

    * ``..._012-01_012-02``: (012-01) of B1 plus (012-02) of B2,
    * ``..._012-01-02``:     012 - 01 - 02,
    * ``..._012-01-02+0``:   012 - 01 - 02 + 0,
    * ``..._mean(012-0102)``: mean of (012-01) and (012-02),

    plus sums of these scores, each also divided by the number of summed terms (``'/2'``, ``'/4'``, ...).

    Parameters
    ----------
    frame : pandas.DataFrame
        Must contain all amplitude columns listed above. Modified in place for the first columns of the first
        normalisation, then internally copied (to de-fragment), so use the returned frame.
    norms_name : list of str
        Column suffix for each entry of ``norms``.
    norms : list of float
        Divisors; an empty list leaves ``frame`` unchanged.
    pos
        Row label.
    add : str
        Suffix of the input and output columns.

    Returns
    -------
    pandas.DataFrame
    """

    def amp(peak, cond):
        # amplitude of ``peak`` (e.g. 'B1_') in condition ``cond`` ('012', '01', '02' or '0')
        return frame.loc[pos, 'amp_' + peak + cond + add]

    halves = (('/2', 2), ('', 1))
    quarters = (('/4', 4), ('', 1))
    eighths = (('/8', 8), ('', 1))
    tenths = (('/10', 10), ('', 1))

    for n, norm in enumerate(norms):
        suffix = norms_name[n] + add
        ##################
        # B1 & B2
        # IMPORTANT: each beat compared with the single-fish condition that contains it
        B1_B2 = ((amp('B1_', '012') - amp('B1_', '01')) +
                 (amp('B2_', '012') - amp('B2_', '02'))) / norm
        B1_0102 = (amp('B1_', '012') - amp('B1_', '01') - amp('B1_', '02')) / norm
        B2_0102 = (amp('B2_', '012') - amp('B2_', '02') - amp('B2_', '01')) / norm
        B1_01020 = (amp('B1_', '012') - amp('B1_', '01') - amp('B1_', '02') + amp('B1_', '0')) / norm
        B2_01020 = (amp('B2_', '012') - amp('B2_', '02') - amp('B2_', '01') + amp('B2_', '0')) / norm
        B1_B2_0102 = ((amp('B1_', '012') - amp('B1_', '01') - amp('B1_', '02')) +
                      (amp('B2_', '012') - amp('B2_', '02') - amp('B2_', '01'))) / norm
        for div_label, div in halves:
            frame.loc[pos, 'amp_B1&B2' + div_label + '_012-01_012-02' + suffix] = B1_B2 / div
            # the '_012-01-02' column holds the 012-01-02 score (previously B1_B2 was stored here by mistake)
            frame.loc[pos, 'amp_B1&B2' + div_label + '_012-01-02' + suffix] = B1_B2_0102 / div
            frame.loc[pos, 'amp_B1' + div_label + '_012-01-02' + suffix] = B1_0102 / div
            frame.loc[pos, 'amp_B2' + div_label + '_012-01-02' + suffix] = B2_0102 / div
            frame.loc[pos, 'amp_B1' + div_label + '_012-01-02+0' + suffix] = B1_01020 / div
            frame.loc[pos, 'amp_B2' + div_label + '_012-01-02+0' + suffix] = B2_01020 / div

        ##################
        # harmonics of B1 and B2; each beat is paired with its own single-fish condition
        for B, own in (('B1_', '01'), ('B2_', '02')):
            # IMPORTANT
            B_harm = ((amp(B + 'harm_', '012') - amp(B + 'harm_', own))
                      + (amp(B + '2harm_', '012') - amp(B + '2harm_', own))
                      + (amp(B + '3harm_', '012') - amp(B + '3harm_', own))) / norm
            # 012 - 01 - 02 for every harmonic (the 3rd-harmonic term previously had a '- -' sign error)
            B_harm_0102 = ((amp(B + 'harm_', '012') - amp(B + 'harm_', '01') - amp(B + 'harm_', '02'))
                           + (amp(B + '2harm_', '012') - amp(B + '2harm_', '01') - amp(B + '2harm_', '02'))
                           + (amp(B + '3harm_', '012') - amp(B + '3harm_', '01') - amp(B + '3harm_', '02'))) / norm
            B_harm_01020 = ((amp(B + 'harm_', '012') - amp(B + 'harm_', '01') - amp(B + 'harm_', '02')
                             + amp(B + 'harm_', '0'))
                            + (amp(B + '2harm_', '012') - amp(B + '2harm_', '01') - amp(B + '2harm_', '02')
                               + amp(B + '2harm_', '0'))
                            + (amp(B + '3harm_', '012') - amp(B + '3harm_', '01') - amp(B + '3harm_', '02')
                               + amp(B + '3harm_', '0'))) / norm
            if 'B1_' in B:
                B1_harm = B_harm
                B1_harm_0102 = B_harm_0102
                B1_harm_01020 = B_harm_01020
            else:
                B2_harm = B_harm
                B2_harm_0102 = B_harm_0102
                B2_harm_01020 = B_harm_01020
            for div_label, div in quarters:
                frame.loc[pos, 'amp_' + B + 'harms_' + div_label + '_012-' + own + suffix] = B_harm / div
                frame.loc[pos, 'amp_' + B + 'harms_' + div_label + '_012-01-02' + suffix] = B_harm_0102 / div
                frame.loc[pos, 'amp_' + B + 'harms_' + div_label + '_012-01-02+0' + suffix] = B_harm_01020 / div

        ##################
        # B1-B2 & B1+B2 for the different control pairings (the mean control is probably the best one)
        for part_a, part_b in (('02', '02'), ('01', '01'), ('02', '01'), ('01', '02')):
            prev = ((amp('B1-B2_', '012') - amp('B1-B2_', part_a)) +
                    (amp('B1+B2_', '012') - amp('B1+B2_', part_b))) / norm
            for div_label, div in halves:
                frame.loc[pos, 'amp_B1-B2&B1+B2' + div_label + '_012-' + part_a + '_012-' + part_b
                          + suffix] = prev / div
            # B1 & B2 & B1-B2 & B1+B2
            for div_label, div in quarters:
                frame.loc[pos, 'amp_B1&B2&B1-B2&B1+B2' + div_label + '_012-' + part_a + '_012-' + part_b
                          + suffix] = (prev + B1_B2) / div

        ##################
        # B1-B2
        B1_minus_B2_0102 = (amp('B1-B2_', '012') - amp('B1-B2_', '02') - amp('B1-B2_', '01')) / norm
        B1_minus_B2_01020 = (amp('B1-B2_', '012') - amp('B1-B2_', '02') - amp('B1-B2_', '01')
                             + amp('B1-B2_', '0')) / norm
        B1_minus_B2 = (amp('B1-B2_', '012') - amp('B1-B2_', '02')
                       + amp('B1-B2_', '012') - amp('B1-B2_', '01')) / (2 * norm)
        for div_label, div in halves:
            frame.loc[pos, 'amp_B1-B2' + div_label + '_mean(012-0102)' + suffix] = B1_minus_B2 / div
            frame.loc[pos, 'amp_B1-B2' + div_label + '_012-01-02' + suffix] = B1_minus_B2_0102 / div
            frame.loc[pos, 'amp_B1-B2' + div_label + '_012-01-02+0' + suffix] = B1_minus_B2_01020 / div

        # B1+B2
        B1_plus_B2_0102 = (amp('B1+B2_', '012') - amp('B1+B2_', '01') - amp('B1+B2_', '02')) / norm
        B1_plus_B2_01020 = (amp('B1+B2_', '012') - amp('B1+B2_', '01') - amp('B1+B2_', '02')
                            + amp('B1+B2_', '0')) / norm
        B1_plus_B2 = (amp('B1+B2_', '012') - amp('B1+B2_', '01')
                      + amp('B1+B2_', '012') - amp('B1+B2_', '02')) / (2 * norm)
        for div_label, div in halves:
            frame.loc[pos, 'amp_B1+B2' + div_label + '_mean(012-0102)' + suffix] = B1_plus_B2 / div
            frame.loc[pos, 'amp_B1+B2' + div_label + '_012-01-02' + suffix] = B1_plus_B2_0102 / div
            frame.loc[pos, 'amp_B1+B2' + div_label + '_012-01-02+0' + suffix] = B1_plus_B2_01020 / div

        ##################
        # IMPORTANT: the same combination using the mean of both controls
        B1_minus_B2_B1_plus_B2 = (amp('B1-B2_', '012') - amp('B1-B2_', '02')
                                  + amp('B1-B2_', '012') - amp('B1-B2_', '01')
                                  + amp('B1+B2_', '012') - amp('B1+B2_', '01')
                                  + amp('B1+B2_', '012') - amp('B1+B2_', '02')) / (2 * norm)
        for div_label, div in halves:
            frame.loc[pos, 'amp_B1-B2&B1+B2' + div_label + '_mean(012-0102_012-0102)'
                      + suffix] = B1_minus_B2_B1_plus_B2 / div
        for div_label, div in quarters:
            frame.loc[pos, 'amp_B1&B2&B1-B2&B1+B2' + div_label + '_mean(012-0102_012-0102)'
                      + suffix] = (B1_minus_B2_B1_plus_B2 + B1_B2) / div

        # VERY IMPORTANT: B1-B2 & B1+B2 & harmonics, WITHOUT B1 & B2
        for div_label, div in eighths:
            frame.loc[pos, 'amp_B1Harm&B2Harm&B1-B2&B1+B2' + div_label + '_mean(012-0102_012-0102)'
                      + suffix] = (B1_minus_B2_B1_plus_B2 + B1_harm + B2_harm) / div
        # ... WITH B1 & B2
        frame = frame.copy()
        for div_label, div in tenths:
            frame.loc[pos, 'amp_B1&B1Harm&B2&B2Harm&B1-B2&B1+B2' + div_label + '_mean(012-0102_012-0102)'
                      + suffix] = (B1_minus_B2_B1_plus_B2 + B1_harm + B2_harm + B1_B2) / div

        # VERY IMPORTANT: all terms as 012-01-02, WITHOUT B1 & B2
        for div_label, div in eighths:
            frame.loc[pos, 'amp_B1Harm&B2Harm&B1-B2&B1+B2' + div_label + '_012-01-02' + suffix] = (
                B1_minus_B2_0102 + B1_plus_B2_0102 + B1_harm_0102 + B2_harm_0102) / div
        # ... WITH B1 & B2
        frame = frame.copy()
        for div_label, div in tenths:
            frame.loc[pos, 'amp_B1&B1Harm&B2&B2Harm&B1-B2&B1+B2' + div_label + '_012-01-02' + suffix] = (
                B1_minus_B2_0102 + B1_plus_B2_0102 + B1_harm_0102 + B2_harm_0102 + B1_B2_0102) / div

        # VERY IMPORTANT: all terms as 012-01-02+0, WITHOUT B1 & B2
        for div_label, div in eighths:
            frame.loc[pos, 'amp_B1Harm&B2Harm&B1-B2&B1+B2' + div_label + '_012-01-02+0' + suffix] = (
                B1_minus_B2_01020 + B1_plus_B2_01020 + B1_harm_01020 + B2_harm_01020) / div
        # ... WITH B1 & B2
        frame = frame.copy()
        for div_label, div in tenths:
            frame.loc[pos, 'amp_B1&B1Harm&B2&B2Harm&B1-B2&B1+B2' + div_label + '_012-01-02+0' + suffix] = (
                B1_minus_B2_01020 + B1_plus_B2_01020 + B1_harm_01020 + B2_harm_01020
                + B1_01020 + B2_01020) / div

        test = False
        if test:
            from utils_test import test_calc_amps2
            test_calc_amps2(frame, B1_B2)
    return frame
def find_norm_amp_diff(norms, pos, frame, names, norms_name, add):
    """Add normalised '012-01' and '012-02' amplitude differences for every peak name.

    For each ``peak`` in ``names`` and each normalisation ``norms[n]`` the row ``pos``
    receives

    * ``'amp_<peak>012-01<norms_name[n]><add>'`` = (amp_012 - amp_01) / norm
    * ``'amp_<peak>012-02<norms_name[n]><add>'`` = (amp_012 - amp_02) / norm

    The input ``frame`` is not modified; an updated copy is returned.
    """
    result = frame.copy()
    for peak in names:
        # A fresh copy per peak, kept from the original implementation
        # (presumably to avoid pandas' fragmented-frame slowdown).
        result = result.copy()
        for n, norm in enumerate(norms):
            label = norms_name[n]
            for reference in ('01', '02'):
                combined = result.loc[pos, 'amp_' + peak + '012' + add]
                single = result.loc[pos, 'amp_' + peak + reference + add]
                result.loc[pos, 'amp_' + peak + '012-' + reference + label + add] = (combined - single) / norm
    return result
def find_norms(min_amps, pos, frame, add):
    """Return the amplitude normalisations used for the difference scores.

    Parameters
    ----------
    min_amps : str
        If it contains ``'min'`` but not ``'norm'`` no normalisation is requested
        and two empty lists are returned.
    pos : index label
        Row of ``frame`` to read the amplitudes from.
    frame : pandas.DataFrame
        Must hold ``amp_B1_01``, ``amp_B2_02``, ``amp_f0_01`` and ``amp_f0_02``
        (each with suffix ``add``) unless normalisation is disabled.
    add : str
        Column suffix.

    Returns
    -------
    (norms, norms_name) : tuple of lists
        Values B1(01), B2(02), B1(01) + B2(02) and the f0 term, together with
        their labels.
    """
    if 'min' in min_amps and 'norm' not in min_amps:
        return [], []
    amp_b1 = frame.loc[pos, 'amp_B1_01' + add]
    amp_b2 = frame.loc[pos, 'amp_B2_02' + add]
    amp_f0_sum = frame.loc[pos, 'amp_f0_01' + add] + frame.loc[pos, 'amp_f0_02' + add]
    labels = ['_norm_01B1', '_norm_02B2', '_norm_01B1+02B2', '_norm_eodf']
    # NOTE(review): np.mean of a scalar sum is just the sum; if the average of the
    # two f0 amplitudes was intended this would be np.mean([a, b]) -- confirm.
    values = [amp_b1, amp_b2, amp_b1 + amp_b2, np.mean(amp_f0_sum)]
    return values, labels
def calc_var_psd_same_score(fs, p012, p01, p02, p0, pos, frame, add):
    """Store integrated power-spectrum differences and their signed square roots.

    Three scores are written into row ``pos`` (each name followed by ``add``):

    * ``power_distance_int``: integral of mean(P012 - P02 - P01 + P0), the score of
      main interest (the same value could also be computed for the data).
    * ``012-02_power_distance_int``: integral of mean(P012 - P02), one ROC condition.
    * ``01-0_power_distance_int``: integral of mean(P01 - P0), the other ROC condition.

    The trial mean is taken over axis 0 and the integral is ``sum * fs[1]``; this
    presumably assumes ``fs[0] == 0`` so that ``fs[1]`` is the frequency step.
    For each score a ``..._sqrt`` column holds sqrt(x) for x > 0 and -sqrt(-x)
    otherwise.
    """
    # Differences are built lazily, one per score, in the original order.
    score_parts = [
        ('power_distance_int',
         lambda: np.array(p012) - np.array(p02) - np.array(p01) + np.array(p0)),
        ('012-02_power_distance_int',
         lambda: np.array(p012) - np.array(p02)),
        ('01-0_power_distance_int',
         lambda: np.array(p01) - np.array(p0)),
    ]
    for key, spectrum_difference in score_parts:
        integral = np.sum(np.mean(spectrum_difference(), axis=0)) * fs[1]
        frame.loc[pos, key + add] = integral
        if integral > 0:
            signed_root = np.sqrt(integral)
        else:
            signed_root = -np.sqrt(-integral)
        frame.loc[pos, key + '_sqrt' + add] = signed_root
    return frame
def find_norms_euc(min_amps, frame, pos, p012, p01, p02, p0, add):
    """Return the normalisations used for the Euclidean-distance scores.

    Same selection rule as ``find_norms``: ``min_amps`` containing ``'min'`` but not
    ``'norm'`` yields two empty lists.  Otherwise the four amplitude normalisations
    (B1(01), B2(02), their sum, the f0 term) are followed by the per-trial spectra
    ``p012``, ``p01``, ``p02`` and ``p0`` themselves.

    Returns
    -------
    (norms_name, norms) : tuple of lists
        Note the order: labels first, values second (reversed w.r.t. ``find_norms``).
    """
    if 'min' in min_amps and 'norm' not in min_amps:
        return [], []
    amp_b1 = frame.loc[pos, 'amp_B1_01' + add]
    amp_b2 = frame.loc[pos, 'amp_B2_02' + add]
    amp_f0_sum = frame.loc[pos, 'amp_f0_01' + add] + frame.loc[pos, 'amp_f0_02' + add]
    labels = ['_norm_01B1', '_norm_02B2', '_norm_01B1+02B2', '_norm_eodf',
              '_norm_p012', '_norm_p01', '_norm_p02', '_norm_p0', ]
    values = [amp_b1, amp_b2, amp_b1 + amp_b2, np.mean(amp_f0_sum),
              p012, p01, p02, p0, ]
    return labels, values
def calc_euc_amp_norm(fs, diff_parts_names, add, norms_name, frame, pos, norms, diff_parts):
    """Store normalised Euclidean distances between the spectra of two conditions.

    For every normalisation ``norms[n]`` (label ``norms_name[n]``) and every pair
    ``diff_parts[dd] = (first, second)`` (names ``diff_parts_names[dd]``) row ``pos``
    receives

    * ``'euclidean_all_<a>-<b><label><add>'``: mean Euclidean distance over all pairs
      (trial i of ``first``, trial j of ``second``);
    * ``'euclidean_<a>-<b><label><add>'``: norm of the trial-averaged difference
      (trial-by-trial comparison).

    Labels containing ``'B'`` or ``'eod'`` are scalar normalisations.  All other
    labels are per-trial spectra, each integrated as ``sum(norm[k] * fs[1])``.
    Per label the mean of the ``012-01`` and ``012-02`` scores is stored as well.
    Finally the ``_norm_p01p02`` means are stored, so ``norms_name`` must contain
    ``'_norm_p01'`` and ``'_norm_p02'``.  ``frame`` is updated in place and returned.
    """
    freq_res = fs[1]
    for n, norm in enumerate(norms):
        label = norms_name[n]
        scalar_norm = ('B' in label) | ('eod' in label)
        for dd in range(len(diff_parts)):
            first = diff_parts[dd][0]
            second = diff_parts[dd][1]
            pair_name = diff_parts_names[dd][0] + '-' + diff_parts_names[dd][1]
            # All-vs-all: every trial of ``first`` against every trial of ``second``.
            # The inner loop must run over ``second``; looping over len(first) there
            # raised IndexError or dropped pairs when the trial counts differed.
            diffs_norm = []
            for i, trial_a in enumerate(first):
                scale = norm if scalar_norm else np.sum(norm[i] * freq_res)
                for trial_b in second:
                    diffs_norm.append(trial_a / scale - trial_b / scale)
            frame.loc[pos, 'euclidean_all_' + pair_name + label + add] = np.mean(
                np.linalg.norm(diffs_norm, axis=1))
            # Trial-by-trial: norm of the mean normalised difference.
            if scalar_norm:
                value = np.linalg.norm(
                    np.mean(np.array(first) / norm - np.array(second) / norm, axis=0))
            else:
                integrals = np.sum(np.array(norm) * freq_res, axis=1)
                value = np.linalg.norm(np.mean(
                    np.transpose(np.array(first)) / integrals - np.transpose(
                        np.array(second)) / integrals, axis=1))
            frame.loc[pos, 'euclidean_' + pair_name + label + add] = value
        frame.loc[pos, 'euclidean_all_' + 'mean(012-01_012-02)' + label + add] = np.mean(
            [frame.loc[pos, 'euclidean_all_' + '012' + '-' + '01' + label + add],
             frame.loc[pos, 'euclidean_all_' + '012' + '-' + '02' + label + add]])
        try:
            frame.loc[pos, 'euclidean_' + 'mean(012-01_012-02)' + label + add] = np.mean(
                [frame.loc[pos, 'euclidean_' + '012' + '-' + '01' + label + add],
                 frame.loc[pos, 'euclidean_' + '012' + '-' + '02' + label + add]])
        except KeyError:
            print('problem euclidean')
            embed()
    frame.loc[pos, 'euclidean_all_' + 'mean(012-01_012-02)' + '_norm_p01p02' + add] = np.mean(
        [frame.loc[pos, 'euclidean_all_' + '012' + '-' + '01' + '_norm_p01' + add],
         frame.loc[pos, 'euclidean_all_' + '012' + '-' + '02' + '_norm_p02' + add]])
    frame.loc[pos, 'euclidean_' + 'mean(012-01_012-02)' + '_norm_p01p02' + add] = np.mean(
        [frame.loc[pos, 'euclidean_' + '012' + '-' + '01' + '_norm_p01' + add],
         frame.loc[pos, 'euclidean_' + '012' + '-' + '02' + '_norm_p02' + add]])
    return frame
def calc_euc_amp(add, frame, diff_parts, pos, diff_parts_names):
    """Store un-normalised Euclidean distances between the spectra of two conditions.

    For every pair ``diff_parts[dd] = (first, second)`` (names ``diff_parts_names[dd]``)
    row ``pos`` receives

    * ``'euclidean_all_<a>-<b><add>'``: mean distance over all pairs of trials
      (each trial of ``first`` against each trial of ``second``);
    * ``'euclidean_<a>-<b><add>'``: norm of the trial-averaged difference
      (trial-by-trial, so the trial counts must match or broadcast).

    The means of the ``012-01`` and ``012-02`` scores are stored as well, so
    ``diff_parts_names`` must contain both pairs.  Works on a copy of ``frame``.
    """
    frame = frame.copy()
    for dd in range(len(diff_parts)):
        first = diff_parts[dd][0]
        second = diff_parts[dd][1]
        pair_name = diff_parts_names[dd][0] + '-' + diff_parts_names[dd][1]
        # All-vs-all over the trials of BOTH conditions; the former loop bound
        # len(first) for the inner index broke or truncated unequal trial counts.
        diffs = [trial_a - trial_b for trial_a in first for trial_b in second]
        frame.loc[pos, 'euclidean_all_' + pair_name + add] = np.mean(np.linalg.norm(diffs, axis=1))
        frame.loc[pos, 'euclidean_' + pair_name + add] = np.linalg.norm(
            np.mean(np.array(first) - np.array(second), axis=0))
    frame.loc[pos, 'euclidean_all_' + 'mean(012-01_012-02)' + add] = np.mean(
        [frame.loc[pos, 'euclidean_all_' + '012' + '-' + '01' + add],
         frame.loc[pos, 'euclidean_all_' + '012' + '-' + '02' + add]])
    frame.loc[pos, 'euclidean_' + 'mean(012-01_012-02)' + add] = np.mean(
        [frame.loc[pos, 'euclidean_' + '012' + '-' + '01' + add],
         frame.loc[pos, 'euclidean_' + '012' + '-' + '02' + add]])
    return frame
def calc_amps(fs, p0, p02, p01, p012, pos, devname, t, frame, results, timesstamp=False, add='', min_amps='', points=5,
              printing=False):
    """Compute amplitude, power-distance and Euclidean scores for one stimulus row.

    Parameters
    ----------
    fs : array
        Frequency axis of the spectra; ``fs[1] - fs[0]`` is passed on as frequency
        step and ``fs[1]`` is used by the helpers as integration step.
    p0, p02, p01, p012 : list of arrays
        Per-trial power spectra of the conditions '0', '02', '01' and '012'.
    pos : index label
        Row of ``frame`` (written) and ``results`` (read).
    devname, t, points :
        Forwarded to ``calc_pure_amps_diffs``.
    frame : pandas.DataFrame
        Output table.
    results : pandas.DataFrame
        Source of fr, f1, f2 and either DeltaF1/DeltaF2/EODf or df1/df2/f0.
    timesstamp : bool
        Unused.
    add : str
        Suffix appended to the written column names.
    min_amps : str
        'min' without 'norm' disables all normalised scores.
    printing : bool
        Print the run time of the individual stages (0.0 for skipped stages).

    Returns
    -------
    pandas.DataFrame
        Updated frame; amplitude and Euclidean columns containing ``add`` are
        finally replaced by their signed square root (``sqrt_values``).
    """
    fishes = ['012', '01', '02', '0']
    ps = [p012, p01, p02, p0]
    test = False
    if test:
        plt_calc_amps(results, p0, p01, p02, p012, frame, fs)
    freq_step = np.abs(fs[1] - fs[0])
    try:
        fr = results.fr.loc[pos]
    except:
        print('fr prob')
        embed()
    f2 = results.f2.loc[pos]
    f1 = results.f1.loc[pos]
    try:
        df1 = results.DeltaF1.loc[pos]
        df2 = results.DeltaF2.loc[pos]
        eod_fr = results.EODf.loc[pos]
    except (AttributeError, KeyError):
        # older result tables use the column names f0/df1/df2
        eod_fr = results.f0.loc[pos]
        df1 = results.df1.loc[pos]
        df2 = results.df2.loc[pos]
    try:
        beat1 = create_beat_corr(np.array([np.abs(df1)]), np.array([eod_fr]))
    except:
        print('beat 1 problem')
        embed()
    beat2 = create_beat_corr(np.array([np.abs(df2)]), np.array([eod_fr]))
    names = peaks_of_interest(df1, df2, beat1, beat2, fr, f1, f2, eod_fr, min_amps=min_amps)
    for name in names:
        frame.loc[pos, name[0:-1]] = names[name]
    names[''] = ''
    names['max_'] = ''
    names['max_harm_'] = ''
    # Planned score families:
    # 1) changes of B1 and B2 alone
    # 2) changes of B1, B2 together
    # 3) changes of B1+B2 and B1-B2
    # 4) B1, B2, B1-B2, B1+B2
    # 5) Euclidean distance
    # 6) normalisation (by B1, B2, B1+B2)
    # Variables: norm, mean, what the difference is taken against, 1-6

    # Stage timings; stages skipped because there are no norms report 0.0 so
    # that the ``printing`` branch never hits an undefined name.
    time_second = 0.0
    time_third = 0.0

    # single frequencies and their differences
    t1 = time.time()
    if np.isnan(frame.loc[pos, 'f0']):
        print('isnan thing2')
        embed()
    frame = calc_pure_amps_diffs(frame, pos, names, fishes, freq_step, ps, fs, devname, t, add, points=points)
    time_first = time.time() - t1

    # Normalised amplitude differences (by B1, B2, both, or the EODf, which makes
    # beat-frequency dependent responses comparable).  With min_amps containing
    # 'min' but not 'norm' this is skipped entirely.
    norms, norms_name = find_norms(min_amps, pos, frame, add)
    if norms:
        t1 = time.time()
        frame = find_norm_amp_diff(norms, pos, frame, names, norms_name, add)
        time_second = time.time() - t1
        # changes of B1 & B2, B1-B2 & B1+B2
        t1 = time.time()
        frame = find_B1B2_norm_amp_diffs(frame, norms_name, norms, pos, add)
        time_third = time.time() - t1

    # Second score (December 2022), equivalent to the variance-based score.
    if len(p02) > 0 and len(p01) > 0:  # TODO: different trial counts
        frame = calc_var_psd_same_score(fs, p012, p01, p02, p0, pos, frame, add)

    # Euclidean distances: all-vs-all and trial-by-trial comparisons, with several
    # normalisations (the spectral ones keep values comparable across cells).
    t1 = time.time()
    norms_name, norms = find_norms_euc(min_amps, frame, pos, p012, p01, p02, p0, add)
    diff_parts_names = [('012', '02'), ('012', '01')]
    diff_parts = [(p012, p02), (p012, p01)]
    if len(p02) > 0:
        frame = calc_euc_amp(add, frame, diff_parts, pos, diff_parts_names)
        # The normalised variant also compares against p02 and cannot run
        # (it raised) without trials of that condition.
        if norms:
            frame = calc_euc_amp_norm(fs, diff_parts_names, add, norms_name, frame, pos, norms, diff_parts)
    time_forht = time.time() - t1
    if printing:
        print(time_first)
        print(time_second)
        print(time_third)
        print(time_forht)
    # Take the square root so that the final values are standard deviations
    # (in Hz) rather than variances.
    frame = sqrt_values(pos, frame, add)
    test = False
    if test:
        plt.plot(fs, p02[0])
        plt.scatter(frame['f0'], frame['amp_f0_02_original'])
        plt.scatter(frame['f1'], frame['amp_f1_02_original'])
        plt.scatter(frame['f2'], frame['amp_f2_02_original'])
        plt.show()
    if test:
        embed()
    return frame
def sqrt_values(pos, frame, add=''):
    """Replace selected entries of row ``pos`` by their signed square root, in place.

    A column is affected when its name contains ``'amp'`` or ``'euclidean'`` and
    also contains ``add``.  Negative values x become -sqrt(-x), so difference
    scores keep their sign; all other values become sqrt(x).

    Returns the same (mutated) ``frame``.
    """
    selected = [column for column in frame
                if ('amp' in column or 'euclidean' in column) and add in column]
    for column in selected:
        value = frame.loc[pos, column]
        frame.loc[pos, column] = -np.sqrt(-value) if value < 0 else np.sqrt(value)
    return frame
def calc_ps(nfft, array012, array01, array02, array0, sampling_rate=40000, log = '', xlim = []):
    """Compute per-trial power spectra for the four conditions.

    Each argument ``array*`` is a sequence of trials handed to ``calc_ps_single``
    (processed in the order 012, 01, 02, 0).

    Returns
    -------
    tuple
        ``(p0, p02, p01, p012, f)`` where ``f`` is the frequency axis returned
        for ``array0``.
    """
    spectra = [calc_ps_single(trials, nfft, sampling_rate, log=log, xlim=xlim)
               for trials in (array012, array01, array02, array0)]
    (p012, _), (p01, _), (p02, _), (p0, f) = spectra
    return p0, p02, p01, p012, f
def calc_ps_single(array012, nfft, sampling_rate, log = '', xlim = []):
    """Compute a Welch power spectrum (``ml.psd``) for every trial in ``array012``.

    Each trial is mean-subtracted and analysed with ``NFFT=nfft`` and 50 % overlap.
    With ``log == 'log'`` the spectrum is optionally cut to frequencies below
    ``xlim[-1]`` and then passed through ``calc_log``.

    Returns
    -------
    (list, array)
        The list of spectra and the frequency axis of the last trial.
    """
    spectra = []
    for trial in range(len(array012)):
        signal = array012[trial]
        power, f = ml.psd(signal - np.mean(signal), Fs=sampling_rate,
                          NFFT=nfft, noverlap=nfft // 2)
        if log == 'log':
            if len(xlim) > 0:
                keep = f < xlim[-1]
                power = power[keep]
                f = f[keep]
            power = calc_log(power)
        spectra.append(power)
    return spectra, f
def calculate_the_difference(position_diff, results_diff, name_saved, title, contdition12, control_01, control_02,
                             base_0, base_1=[], base_2=[]):
    """Store mean condition differences of one measure in ``results_diff``.

    Written into row ``position_diff`` (suffix ``'_<name_saved>_<title>'``):

    * ``diff``: mean(012 - 01 - 02 + 0)
    * ``012-01``: detection against the first wave alone
    * ``012-02``: detection against the second wave alone
    * ``012-0-1-2``: mean(012 - 0 - 1 - 2), which is not possible experimentally;
      only when ``base_1`` and ``base_2`` are non-empty.

    For ``name_saved == 'var'`` the signed square root of ``diff`` (score 1 for the
    differences) is stored as ``'diff_var_sqrt_<title>'``.  The frame is modified
    in place and returned.
    """
    suffix = '_' + name_saved + '_' + title
    diff_key = 'diff' + suffix
    results_diff.loc[position_diff, diff_key] = np.mean(contdition12 - control_01 - control_02 + base_0)
    try:
        results_diff.loc[position_diff, '012-01' + suffix] = np.mean(contdition12 - control_01)
    except:
        print('results diff in uilts_func')
        embed()
    results_diff.loc[position_diff, '012-02' + suffix] = np.mean(contdition12 - control_02)
    if len(base_1) > 0 and len(base_2) > 0:
        results_diff.loc[position_diff, '012-0-1-2' + suffix] = np.mean(contdition12 - base_0 - base_1 - base_2)
    if name_saved != 'var':
        return results_diff
    var_val = results_diff.loc[position_diff, diff_key]
    sqrt_key = 'diff_var_sqrt_' + title
    # keep the sign: positive values are kept, negative ones become -sqrt(-x)
    results_diff.loc[position_diff, sqrt_key] = np.sqrt(var_val) if var_val > 0 else -np.sqrt(-var_val)
    return results_diff
def equal_to_temporal_mean(ffts_all):
    """Return |<F012>|^2 - |<F01>|^2 - |<F02>|^2 + |<F0>|^2 of trial-averaged FFTs.

    ``ffts_all`` is either

    * a 3-D array.  Axis 0 is averaged (presumably trials) and axis 1 holds the four
      conditions: index 3 is added, 2 and 1 are subtracted and 0 is added.  This
      assumes the order (0, 02, 01, 012), as returned by ``calc_ps`` (TODO confirm);
    * a mapping keyed by the ``cl_3names`` condition names, each value averaged
      over axis 0.

    The former test ``np.shape(ffts_all) == 3`` compared a tuple with an int, so it
    was always False and array input failed in the mapping branch.
    """
    if np.ndim(ffts_all) == 3:
        mean_fft = np.mean(ffts_all, axis=0)
        fft_val = np.abs(mean_fft[3]) ** 2 - np.abs(mean_fft[2]) ** 2 - np.abs(
            mean_fft[1]) ** 2 + np.abs(mean_fft[0]) ** 2
    else:
        fft_val = np.abs(np.mean(ffts_all[cl_3names.c012], axis=0)) ** 2 - np.abs(
            np.mean(ffts_all[cl_3names.c01], axis=0)) ** 2 - np.abs(
            np.mean(ffts_all[cl_3names.c02], axis=0)) ** 2 + np.abs(np.mean(ffts_all[cl_3names.c0], axis=0)) ** 2
    return fft_val
class cl_3names:
    """Keys of the four stimulation conditions of a three-fish analysis.

    Used to index FFT/spectrum dictionaries (e.g. in ``equal_to_temporal_mean``):
    ``c012`` for '012', ``c02`` for '02', ``c01`` for '01' and ``c0`` for '0'.
    """
    # defined in the order 012, 02, 01, 0
    c012, c02, c01, c0 = '012', '02', '01', '0'
def calc_FFT3(arrays, deltat, fft, names):
    """Store the FFT of each mean-subtracted array in ``fft`` under ``names[a]``.

    With numpy >= 1.20 the transform uses ``norm='forward'`` (scaled by 1/N).  Older
    numpy rejects that value with a ValueError; the fallback then multiplies the
    unscaled FFT by ``deltat``.  Only this expected error is caught now, instead of
    a bare ``except`` that also swallowed unrelated failures.

    Returns the same (mutated) ``fft`` mapping.
    """
    for a, array in enumerate(arrays):
        centred = array - np.mean(array)
        try:
            fft[names[a]] = np.fft.fft(centred, norm='forward')
        except ValueError:
            fft[names[a]] = np.fft.fft(centred) * deltat
    return fft
def data_tuning(show=True):
    """Plot three-fish score tuning curves of a recorded cell for a few stable beats.

    For each cell in ``cells`` the pickled score table is looked up under two
    candidate names: the cell name inserted before ``'_nfft'`` or before ``'_dev'``.
    The rows with c1 == c2 == 10 and dev == '05' are pivoted over ('DF1', 'DF2')
    with ``get_data_pivot_three`` for each score returned by ``colors_susept``.
    The pivot rows listed in ``p_nrs`` are drawn as curves in neighbouring panels.

    Parameters
    ----------
    show : bool, optional
        Passed on to ``save_visualization``.
    """
    cells = ['2021-08-03-ac-invivo-1']
    _, _ = find_all_threewave_versions()
    save_name_alls = [
        'calc_auc_three_AllTrialsIndexEodLocSynch_Min0.25sExcluded__multsorted2__psdEOD__minindices___nfft_32768three_AUCI_sqrt__points1.pkl']
    plot_style()
    default_figsize(column=2, length=2)  # ts=12, ls=12, fs=12
    for save_name_all0 in save_name_alls:
        for cell in cells:
            save_name_all = load_folder_name('threefish') + '/' + save_name_all0
            name0 = save_name_all.split('_nfft')[0] + cell + '_nfft' + save_name_all.split('_nfft')[1]
            if '_dev' in save_name_all:
                name1 = save_name_all.split('_dev')[0] + cell + '_dev' + save_name_all.split('_dev')[1]
            else:
                name1 = 'xyo'
            if os.path.exists(name0):
                print(name0 + 'exists')
                name = name0
            elif os.path.exists(name1):
                print(name1 + 'exists')
                # name0 does not exist here: load the alternative file that was found
                name = name1
            else:
                print('PROBLEM ' + str(save_name_all))
                name = name0
            if os.path.exists(name):
                frame_orig = pd.read_pickle(name)
                contrasts = [10]  # frame_orig.c2.unique()
                for contrast2 in contrasts:
                    contrasts1 = [10]  # frame_orig.c1.unique()
                    for contrast1 in contrasts1:
                        if len(frame_orig) > 0:
                            frame = frame_orig[(frame_orig['cell'] == cell) & (
                                    frame_orig['c2'] == contrast2) & (
                                                   frame_orig['c1'] == contrast1) & (frame_orig['dev'] == '05')]
                            print(np.mean(np.mean(frame.EODf.unique())))
                            labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept(
                                nr=2)
                            # layout: one row of panels, one panel per pivot row in p_nrs
                            gridspacing = []  # 0.02
                            grid0 = gridspec.GridSpec(1, 1, bottom=0.25, top=0.75, left=0.09,
                                                      right=0.98,
                                                      wspace=0.04)
                            axs = []
                            p_nrs = [7, 4, 2]  # np.arange(0, len(pivot), 1)
                            grid1 = gridspec.GridSpecFromSubplotSpec(1, len(p_nrs), wspace=0.35, hspace=0.35,
                                                                     subplot_spec=grid0[0])
                            dfs = ['DF1', 'DF2']
                            pivot, _, indexes, resorted, orientation, cut_type = get_data_pivot_three(frame, scores[0],
                                                                                                      orientation=[],
                                                                                                      gridspacing=gridspacing,
                                                                                                      dfs=dfs,
                                                                                                      matrix_sorted='grid_sorted')
                            if '2' in pivot.columns.name:
                                scores = [scores[2], scores[3], scores[0], scores[1]]
                            for s, score in enumerate(scores):
                                pivot, _, indexes, resorted, orientation, cut_type = get_data_pivot_three(frame, score,
                                                                                                          orientation=[],
                                                                                                          gridspacing=gridspacing,
                                                                                                          dfs=dfs,
                                                                                                          matrix_sorted='grid_sorted')
                                print('min f0 ' + str(np.min(frame.f0)))
                                print('min f1 ' + str(np.min(frame.f1)))
                                print('max f1 ' + str(np.max(frame.f1)))
                                print('min f2 ' + str(np.min(frame.f2)))
                                print('max f2 ' + str(np.max(frame.f2)))
                                if len(pivot) > 0:
                                    if s == 0:
                                        _, _ = find_row_col(pivot)
                                    for pp, p in enumerate(p_nrs):  # range(len(pivot)):
                                        ax = plt.subplot(grid1[pp])
                                        if 'm' in dfs[0]:
                                            try:
                                                ax.set_title(pivot.index.name + ' ' + str(pivot.index[p]))
                                            except:
                                                print('ax something')
                                                embed()
                                        else:
                                            if s == 0:
                                                ax.text(1, 1.05, r'$\Delta f_{' + stable_val() + '}=%s$' % (
                                                    int(pivot.index[p])) + r'\,Hz', ha='right', transform=ax.transAxes)
                                        ax.plot(pivot.columns, pivot.iloc[p], color=colors[s], label=labels[s],
                                                linestyle=linestyles[s], linewidth=linewidths[s])
                                        ax.set_xlabel(xlabel_vary())  # pivot.columns.name
                                        if pp != 0:
                                            remove_yticks(ax)
                                        else:
                                            ax.set_ylabel(representation_ylabel())
                                        axs.append(ax)
                            join_y(axs)
                            fig = plt.gcf()
                            fig.tag(axs[0:3], xoffs=-3, yoffs=1)
                            if len(pivot) > 0:
                                axs[0].legend(loc=(0, 1.2), ncol=2)
                            individual_tag = save_name_all0 + '_' + cell + '_c1_' + str(contrast1) + '_c2_' + str(
                                contrast2) + '_gridpsacing_' + str(gridspacing)
                            save_visualization(individual_tag, show=show)
def xlabel_vary():
    """Return the TeX axis label of the varied difference frequency, in Hz.

    The subscript is taken from ``vary_val()``.  Raw strings give the same text as
    before without the invalid escape sequences ``\\D`` and ``\\,`` (a SyntaxWarning
    on recent Python versions).
    """
    return r'$\Delta f_{' + vary_val() + r'}$\,[Hz]'
def tuning_f(freqs=[(39.5, -135.5)],
             cells_here='2011-10-25-ad-invivo-1'):
    """Plot contrast and frequency tuning of model cells for selected beat pairs.

    For every cell and every ``(freq1, freq2)`` pair, the difference frequencies are
    snapped to the closest values present in the data.  The upper panel is drawn by
    ``plt_single_trace`` (x range 0-35, markers at the contrasts ``c_heres``).  The
    two lower panels are tuning curves (``plt_tuning_twobeat``) from the tables
    ``calc_model_amp_freqs-F1_500-1495-5_F2_<f2>...`` with c1 = 0.1 and 0.25.

    Parameters
    ----------
    freqs : list of tuple
        ``(df1, df2)`` pairs, one figure column each.
    cells_here : str or sequence of str
        One cell name or several.  An empty value selects all cells listed in
        ``models_big_fit_d_right.csv``.
    """
    plot_style()
    model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv")
    if len(cells_here) < 1:
        cells_here = np.array(model_cells.cell)
    if isinstance(cells_here, str):
        # A single name (the default) must not be iterated character by character.
        cells_here = [cells_here]
    trials_nrs = [1]
    plot_style()
    default_figsize(column=2, length=5.35)  # 5.5)#7.5 5.75
    default_figsize(column=2, length=7.5)
    default_figsize(width=cm_to_inch(33.6), length=cm_to_inch(17.2))
    default_ticks_talks()
    for _ in trials_nrs:  # +[trials_nrs[-1]]
        scatter_extra = False
        for cell_here in cells_here:
            full_names = [
                'calc_model_amp_freqs-F1_750-975-75_F2_500-725-75_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_1.0_mult__start_0.0001_end_1_StimLen_25_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal']
            # full_names = ['calc_model_amp_freqs-F1_0.5-1.45-0.05_F2_0.5-1.45-0.05_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_1.0_mult__start_0.0001_end_1_StimLen_25_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal']
            # full_names = ['calc_model_amp_freqs-F1_250-1325-25_F2_720_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_2.0_mult__start_0.0001_end_2_StimLen_5_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal']
            # full_names = ['calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_25_FirstC1_0.0001_LastC1_1.0_StimLen_25_nfft_32768_trialsnr_1_absolut_power_1_minamps__dev_originalAUCItemporal']
            # 'calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2Len_25_FirstC2_0.0001_LastC2_1.0_C1_0.1_StimLen_2_nfft_32768_trialsnr_1_absolut_power_1_minamps__dev_05temporal']
            c_grouped = ['c1']  # , 'c2']
            frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv')
            frame_cell_orig = frame[(frame.cell == cell_here)]
            if len(frame_cell_orig) > 0:
                grid0 = gridspec.GridSpec(1, len(freqs), bottom=0.13, top=0.85, left=0.1,
                                          right=0.975,
                                          wspace=0.15)  # top=0.895
                ###################################################
                squares = False
                if squares:
                    # NOTE(review): grid_s is not defined at this point (it is only
                    # created inside the freqs loop below), so this disabled branch
                    # would raise a NameError as written.
                    full_names_square = [
                        'calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_25_FirstC1_0.0001_LastC1_1.0_StimLen_2_nfft_32768_trialsnr_1_absolut_power_1_minamps__dev_05temporal', ]
                    frame_square = pd.read_csv(
                        load_folder_name('calc_cocktailparty') + '/' + full_names_square[0] + '.csv')
                    frame_cell_square = frame_square[(frame_square.cell == cell_here)]
                    axes = []
                    axes.append(plt.subplot(grid_s[0]))
                    axes.append(plt.subplot(grid_s[1]))
                    axes.append(plt.subplot(grid_s[2]))
                    frame_cell_square = single_frame_processing(c_grouped, frame_cell_square)
                    lim, matrix, ss, ims = plt_matrix_saturation_loss(axes, frame_cell_square, add='_05')
                    plt_cross(matrix, axes[-1])
                #################################################################
                show = True
                # one cell, a single point and varied contrasts
                ax_upper = []
                frame_cell_orig, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig)
                ####################################################
                # amplitude tuning curve
                frame_cell = single_frame_processing(c_grouped, frame_cell_orig)
                c_heres = [0.1, 0.25]  # 0.03,
                c_colors = ['dimgrey', 'darkgrey']  # ,'black', ],'silver'
                freq1s = np.unique(frame_cell_orig.df1)
                freq2s = np.unique(frame_cell_orig.df2)
                f_counter = 0
                ax_uss = []
                letters_all = [[r'$\mathrm{A_{ii}}$', r'$\mathrm{A_{iii}}$'],
                               [r'$\mathrm{B_{ii}}$', r'$\mathrm{B_{iii}}$']]
                letters_all2 = [r'$\mathrm{A_{i}}$', r'$\mathrm{B_{i}}$']
                for freq1, freq2 in freqs:
                    grid00 = gridspec.GridSpecFromSubplotSpec(2, 1,
                                                              wspace=0.15, hspace=0.73, subplot_spec=grid0[f_counter],
                                                              height_ratios=[1.4, 2.35])  # hspace=0.35 1, 2.55
                    grid_u = gridspec.GridSpecFromSubplotSpec(1, 1,
                                                              hspace=0.7,
                                                              wspace=0.25,
                                                              subplot_spec=grid00[
                                                                  0])  # hspace=0.4,wspace=0.2,len(chirps)
                    grid_r = gridspec.GridSpecFromSubplotSpec(1, 2,
                                                              hspace=0.3,
                                                              wspace=0.1,
                                                              subplot_spec=grid00[1])
                    ################################################
                    grid_s = gridspec.GridSpecFromSubplotSpec(1, 3,
                                                              hspace=0.7,
                                                              wspace=0.45,
                                                              subplot_spec=grid00[-1])
                    # snap the requested difference frequencies to the closest ones in the data
                    freq1_here = freq1s[np.argmin(np.abs(freq1s - freq1))]
                    freq2_here = freq2s[np.argmin(np.abs(freq2s - freq2))]
                    print(cell_here + ' F1' + str(freq1_here) + ' F2 ' + str(freq2_here))
                    ax_u1_upper = plt.subplot(grid_u[0])
                    c_dist_recalc = dist_recalc_phaselockingchapter()
                    ax_upper = plt_single_trace(ax_upper, ax_u1_upper, frame_cell_orig, freq1_here, freq2_here,
                                                sum=False, nr=2, c_dist_recalc=c_dist_recalc,
                                                linestyles=['-', '--', '-', '--', '-'])
                    ax_u1_upper.set_yticks_delta(100)  # set_xticks_delta
                    ax_u1_upper.set_xlim(0, 35)
                    c_nrs_here_cm = c_dist_recalc_func(frame_cell, c_nrs=c_heres, cell=cell_here,
                                                       c_dist_recalc=c_dist_recalc)
                    height = 355  # 0
                    letter_plus = 30
                    if not c_dist_recalc:
                        c_nrs_here_cm = np.array(c_nrs_here_cm) * 100
                    try:
                        ax_u1_upper.scatter(c_nrs_here_cm, height * np.ones(len(c_nrs_here_cm)), color=c_colors,
                                            marker='v',
                                            clip_on=False, s=7)
                    except:
                        print('embed something')
                        embed()
                    for cn, cnr in enumerate(c_nrs_here_cm):
                        ax_u1_upper.text(cnr, height + letter_plus, letters_all[f_counter][cn], ha='center',
                                         va='center', color=c_colors[cn])
                        ax_u1_upper.plot([cnr, cnr], [0, height], color=c_colors[cn], linewidth=lw_tuning(), zorder=100)
                    labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept(
                        add='_mean_original', nr=2)
                    color_first = 'black'  # 'red' # color01
                    color_second = 'black'  # color02
                    color_third = 'black'
                    ax_u1_upper.text(0, 1.2, r' $\Delta f_{' + vary_val() + '}=%s$' % freq1_here + ' Hz' +
                                     '\n $\\Delta f_{' + stable_val() + '}=%s$' % (
                                         freq2_here) + r'\,Hz, ' + '$c_{' + stable_val() + r'}=10\,\%$',
                                     transform=ax_u1_upper.transAxes)
                    if f_counter != 0:
                        ax_u1_upper.set_ylabel('')
                        remove_yticks(ax_u1_upper)
                    frame_cell_chosen = frame_cell_orig[
                        (frame_cell_orig.df1 == freq1_here) & (frame_cell_orig.df2 == freq2_here)]
                    print('Tuning curve needed for F1' + str(frame_cell_chosen.f1.unique()) + ' F2' + str(
                        frame_cell_chosen.f2.unique()) + ' for cell ' + str(cell_here))
                    ##################################################
                    # tuning curves; the absolute f2 of the chosen pair selects the files.
                    # .item() requires exactly one f2 value (as int() on a one-element
                    # array did) without numpy's array-to-scalar deprecation.
                    freq2_here_abs = str(int(frame_cell_chosen.f2.unique().item()))
                    length = '2'
                    nfft = '4096'
                    full_names_tunings = [
                        'calc_model_amp_freqs-F1_500-1495-5_F2_' + freq2_here_abs + '_C2_0.1_C1_0.1_StimLen_' + length + '_nfft_' + nfft + '_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal',
                        'calc_model_amp_freqs-F1_500-1495-5_F2_' + freq2_here_abs + '_C2_0.1_C1_0.25_StimLen_' + length + '_nfft_' + nfft + '_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal',
                    ]
                    ax_us = []
                    for ft_nr, full_names_tuning in enumerate(full_names_tunings):
                        if os.path.exists(load_folder_name('calc_cocktailparty') + '/' + full_names_tuning + '.csv'):
                            frame_tuning = pd.read_csv(
                                load_folder_name('calc_cocktailparty') + '/' + full_names_tuning + '.csv')
                            print(full_names_tuning)
                            frame_cell_orig_tuning = frame_tuning[(frame_tuning.cell == cell_here)]
                            try:
                                frame_cell_orig_tuning = single_frame_processing(c_grouped, frame_cell_orig_tuning)
                            except:
                                print('something')
                                embed()
                            f_fixes = ['df2']  # 'df1',
                            f_variables = ['df1']  # 'df2',
                            freqs_fixed = [freq2_here]  # freq1_here,
                            for f_nr, f_fixed in enumerate(freqs_fixed):
                                indexes = [[0, 1, 2, 3]]  # [0, 2],
                                for i, idx in enumerate(indexes):
                                    grid_rr = gridspec.GridSpecFromSubplotSpec(1, 1,
                                                                               hspace=0.15,
                                                                               wspace=0.15,
                                                                               subplot_spec=grid_r[ft_nr])
                                    plt_tuning_twobeat(idx, ax_u1_upper, ax_us, c_colors, c_heres, cell_here,
                                                       color_first,
                                                       color_second, color_third, f_counter, f_fixed, f_fixes, f_nr,
                                                       f_variables, frame_cell_orig_tuning, freq1_here, freq2_here,
                                                       ft_nr, grid_rr, height, i, letter_plus, letters_all2,
                                                       scatter_extra, scores, xlabel_pos=0)
                    f_counter += 1
                    if len(ax_us) > 0:
                        join_x(ax_us)
                        join_y(ax_us)
                    ax_uss.append(ax_us)
                #########################################################
                if squares:
                    set_clim_same(ims, clims='all', same='same')
                join_y(ax_upper)
                join_x(ax_upper)
                join_y(ax_upper)
                # Save only when this cell was plotted; ``show`` is defined in this branch,
                # and saving outside it raised a NameError (or re-saved a stale figure)
                # for cells without data.
                save_visualization(cell_here, show)
def vary_contrasts_big_with_tuning3_several0(freqs=((39.5, -135.5),),
                                             cells_here='2011-10-25-ad-invivo-1'):
    """Plot contrast traces and two-beat tuning curves for one or more model cells.

    For every cell with rows in the ``calc_model_amp_freqs-F1_750-975-75...``
    table, one figure is built with one column per (df1, df2) pair in
    ``freqs``. Each column contains:

    * an upper panel drawn by ``plt_single_trace``, with markers, vertical
      lines and letters at the contrasts in ``c_heres``;
    * one tuning column per contrast in ``c_heres``, read from the matching
      ``calc_model_amp_freqs-F1_500-1495-5_F2_<f2>...`` CSV file (if present).
      Each tuning column holds two panels drawn by ``plt_tuning_twobeat``,
      showing score indices ``[0, 2]`` and ``[0, 1, 2, 3]``.

    The figure is written with ``save_visualization``.

    Parameters
    ----------
    freqs : sequence of (float, float)
        Requested (df1, df2) pairs in Hz. The closest values present in the
        data are used. At most two pairs are supported, because only two
        sets of panel letters are defined.
    cells_here : str or sequence of str
        A single cell name, a sequence of cell names, or an empty sequence to
        plot every cell listed in ``models_big_fit_d_right.csv``.
    """
    plot_style()
    model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv")
    if len(cells_here) < 1:
        cells_here = np.array(model_cells.cell)
    elif isinstance(cells_here, str):
        # A bare string would otherwise be iterated character by character
        # below, so no cell would ever match and nothing would be plotted.
        cells_here = [cells_here]
    trials_nrs = [1]
    plot_style()
    default_figsize(column=2, length=5.35)  # 5.5)#7.5 5.75
    for _ in trials_nrs:
        scatter_extra = False
        for cell_here in cells_here:
            full_names = [
                'calc_model_amp_freqs-F1_750-975-75_F2_500-725-75_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_1.0_mult__start_0.0001_end_1_StimLen_25_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal']
            c_grouped = ['c1']  # , 'c2']
            frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv')
            frame_cell_orig = frame[(frame.cell == cell_here)]
            if len(frame_cell_orig) > 0:
                grid0 = gridspec.GridSpec(1, len(freqs), bottom=0.095, top=0.95, left=0.1,
                                          right=0.975,
                                          wspace=0.15)  # top=0.895
                ###################################################
                # Optional saturation/loss matrices (disabled).
                # NOTE(review): this branch uses ``grid_s`` before it is created
                # inside the frequency loop below; enabling it as written
                # raises NameError.
                squares = False
                if squares:
                    full_names_square = [
                        'calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_25_FirstC1_0.0001_LastC1_1.0_StimLen_2_nfft_32768_trialsnr_1_absolut_power_1_minamps__dev_05temporal', ]
                    frame_square = pd.read_csv(
                        load_folder_name('calc_cocktailparty') + '/' + full_names_square[0] + '.csv')
                    frame_cell_square = frame_square[(frame_square.cell == cell_here)]
                    axes = []
                    axes.append(plt.subplot(grid_s[0]))
                    axes.append(plt.subplot(grid_s[1]))
                    axes.append(plt.subplot(grid_s[2]))
                    frame_cell_square = single_frame_processing(c_grouped, frame_cell_square)
                    lim, matrix, ss, ims = plt_matrix_saturation_loss(axes, frame_cell_square, add='_05')
                    plt_cross(matrix, axes[-1])
                #################################################################
                show = True
                # This figure shows one cell: the single point (df1, df2)
                # with varied contrasts.
                ax_upper = []
                frame_cell_orig, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig)
                ####################################################
                # Amplitude (contrast) tuning curve.
                frame_cell = single_frame_processing(c_grouped, frame_cell_orig)
                c_heres = [0.1, 0.25]  # 0.03,
                c_colors = ['dimgrey', 'darkgrey']  # ,'black', ],'silver'
                freq1s = np.unique(frame_cell_orig.df1)
                freq2s = np.unique(frame_cell_orig.df2)
                f_counter = 0
                ax_uss = []
                letters_all = [['$\mathrm{A_{ii}}$', '$\mathrm{A_{iii}}$'], ['$\mathrm{B_{ii}}$', '$\mathrm{B_{iii}}$']]
                letters_all2 = ['$\mathrm{A_{i}}$', '$\mathrm{B_{i}}$']
                for freq1, freq2 in freqs:
                    grid00 = gridspec.GridSpecFromSubplotSpec(2, 1,
                                                              wspace=0.15, hspace=0.53, subplot_spec=grid0[f_counter],
                                                              height_ratios=[1, 2.35])  # hspace=0.35 1, 2.55
                    grid_u = gridspec.GridSpecFromSubplotSpec(1, 1,
                                                              hspace=0.7,
                                                              wspace=0.25,
                                                              subplot_spec=grid00[
                                                                  0])  # hspace=0.4,wspace=0.2,len(chirps)
                    grid_r = gridspec.GridSpecFromSubplotSpec(1, 2,
                                                              hspace=0.3,
                                                              wspace=0.1,
                                                              subplot_spec=grid00[1])
                    ################################################
                    grid_s = gridspec.GridSpecFromSubplotSpec(1, 3,
                                                              hspace=0.7,
                                                              wspace=0.45,
                                                              subplot_spec=grid00[-1])
                    # Snap the requested frequencies to the closest ones in the data.
                    freq1_here = freq1s[np.argmin(np.abs(freq1s - freq1))]
                    freq2_here = freq2s[np.argmin(np.abs(freq2s - freq2))]
                    print(cell_here + ' F1' + str(freq1_here) + ' F2 ' + str(freq2_here))
                    ax_u1_upper = plt.subplot(grid_u[0])
                    c_dist_recalc = dist_recalc_phaselockingchapter()
                    ax_upper = plt_single_trace(ax_upper, ax_u1_upper, frame_cell_orig, freq1_here, freq2_here,
                                                sum=False, nr=2, c_dist_recalc=c_dist_recalc,
                                                linestyles=['-', '--', '-', '--', '-'])
                    ax_u1_upper.set_yticks_delta(100)  # set_xticks_delta
                    ax_u1_upper.set_xlim(0, 35)
                    c_nrs_here_cm = c_dist_recalc_func(frame_cell, c_nrs=c_heres, cell=cell_here,
                                                       c_dist_recalc=c_dist_recalc)
                    height = 355  # 0
                    letter_plus = 30
                    if not c_dist_recalc:
                        # Contrasts are plotted in percent.
                        c_nrs_here_cm = np.array(c_nrs_here_cm) * 100
                    try:
                        ax_u1_upper.scatter(c_nrs_here_cm, height * np.ones(len(c_nrs_here_cm)), color=c_colors,
                                            marker='v',
                                            clip_on=False, s=7)
                    except Exception:
                        print('embed something')
                        embed()
                    for cn, cnr in enumerate(c_nrs_here_cm):
                        ax_u1_upper.text(cnr, height + letter_plus, letters_all[f_counter][cn], ha='center',
                                         va='center', color=c_colors[cn])
                        ax_u1_upper.plot([cnr, cnr], [0, height], color=c_colors[cn], linewidth=lw_tuning(), zorder=100)
                    labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept(
                        add='_mean_original', nr=2)
                    color_first = 'black'  # 'red' # color01
                    color_second = 'black'  # color02
                    color_third = 'black'
                    ax_u1_upper.text(0, 1, ' $\Delta f_{' + vary_val() + '}=%s$' % freq1_here + ' Hz' +
                                     '\n $\Delta f_{' + stable_val() + '}=%s$' % (
                                         freq2_here) + '\,Hz, ' + '$c_{' + stable_val() + '}=10\,\%$',
                                     transform=ax_u1_upper.transAxes)  # transform
                    if f_counter != 0:
                        ax_u1_upper.set_ylabel('')
                        remove_yticks(ax_u1_upper)
                    frame_cell_chosen = frame_cell_orig[
                        (frame_cell_orig.df1 == freq1_here) & (frame_cell_orig.df2 == freq2_here)]
                    print('Tuning curve needed for F1' + str(frame_cell_chosen.f1.unique()) + ' F2' + str(
                        frame_cell_chosen.f2.unique()) + ' for cell ' + str(cell_here))
                    ##################################################
                    # Tuning curves: one pre-computed file per contrast in c_heres.
                    freq2_here_abs = str(int(frame_cell_chosen.f2.unique()))
                    length = '2'
                    nfft = '4096'
                    full_names_tunings = [
                        'calc_model_amp_freqs-F1_500-1495-5_F2_' + freq2_here_abs + '_C2_0.1_C1_0.1_StimLen_' + length + '_nfft_' + nfft + '_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal',
                        'calc_model_amp_freqs-F1_500-1495-5_F2_' + freq2_here_abs + '_C2_0.1_C1_0.25_StimLen_' + length + '_nfft_' + nfft + '_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal',
                    ]
                    ax_us = []
                    for ft_nr, full_names_tuning in enumerate(full_names_tunings):
                        if os.path.exists(load_folder_name('calc_cocktailparty') + '/' + full_names_tuning + '.csv'):
                            frame_tuning = pd.read_csv(
                                load_folder_name('calc_cocktailparty') + '/' + full_names_tuning + '.csv')
                            print(full_names_tuning)
                            frame_cell_orig_tuning = frame_tuning[(frame_tuning.cell == cell_here)]
                            try:
                                frame_cell_orig_tuning = single_frame_processing(c_grouped, frame_cell_orig_tuning)
                            except Exception:
                                print('something')
                                embed()
                            f_fixes = ['df2']  # 'df1',
                            f_variables = ['df1']  # 'df2',
                            freqs_fixed = [freq2_here]  # freq1_here,
                            for f_nr, f_fixed in enumerate(freqs_fixed):
                                indexes = [[0, 2], [0, 1, 2, 3]]
                                for i, idx in enumerate(indexes):
                                    grid_rr = gridspec.GridSpecFromSubplotSpec(2, 1,
                                                                               hspace=0.15,
                                                                               wspace=0.15,
                                                                               subplot_spec=grid_r[ft_nr])
                                    plt_tuning_twobeat(idx, ax_u1_upper, ax_us, c_colors, c_heres, cell_here,
                                                       color_first,
                                                       color_second, color_third, f_counter, f_fixed, f_fixes, f_nr,
                                                       f_variables, frame_cell_orig_tuning, freq1_here, freq2_here,
                                                       ft_nr, grid_rr, height, i, letter_plus, letters_all2,
                                                       scatter_extra, scores)
                    f_counter += 1
                    if len(ax_us) > 0:
                        join_x(ax_us)
                        join_y(ax_us)
                    ax_uss.append(ax_us)
                #########################################################
                if squares:
                    set_clim_same(ims, clims='all', same='same')
                join_y(ax_upper)
                join_x(ax_upper)
                join_y(ax_upper)
                yoffs = np.array([3, 2.5, 2.5, 2.5])
                x = -3.5
                tag2(plt.gcf(), [[ax_upper[0], ax_uss[0][0], ax_uss[0][2]]], xoffs=np.array([x, x, x, x]),
                     yoffs=yoffs)
                # The second letter row only exists when a second (df1, df2) pair
                # was plotted; with a single pair, ax_uss[1] raised IndexError
                # before the figure could be saved.
                if len(ax_uss) > 1:
                    tag2(plt.gcf(), [[ax_upper[4], ax_uss[1][0], ax_uss[1][2]]], xoffs=np.array([x, x, x, x]),
                         yoffs=yoffs)
                save_visualization(cell_here, show)
def plt_tuning_twobeat(idx, ax_u1_upper, ax_us, c_colors, c_heres, cell_here, color_first, color_second, color_third,
                       f_counter, f_fixed, f_fixes, f_nr, f_variables, frame_cell_orig_tuning, freq1_here, freq2_here,
                       ft_nr, grid_rr, height, i, letter_plus, letters_all2, scatter_extra, scores, xlabel_pos=1):
    """Draw one tuning-curve panel and append its axes to ``ax_us``.

    A new axes is created at ``grid_rr[i]``. It is filled by
    ``plt_tuning_curve`` with the score curves selected by ``idx``, for
    contrast ``c_heres[ft_nr]``. Column ``f_fixes[f_nr]`` is held at
    ``f_fixed`` and ``f_variables[f_nr]`` is varied. A vertical line, a
    triangle and the letter ``letters_all2[f_counter]`` mark ``freq1_here``
    (shifted 5 Hz to the left).

    Parameters
    ----------
    idx : sequence of int
        Indices into the score/label/colour lists returned by ``colors_susept``.
    ax_u1_upper : matplotlib axes
        Upper contrast panel. Only drawn into when ``scatter_extra`` is True.
    ax_us : list
        Collects the created axes (appended in place).
    c_colors, c_heres : sequence
        Colour and contrast per tuning file, indexed with ``ft_nr``.
    cell_here : str
        Cell name.
    color_first, color_second, color_third : colour
        Edge colours of the optional extra scatter points.
    f_counter : int
        Index of the (df1, df2) pair. Selects the letter and whether the y
        label is kept.
    f_fixed : float
        Value of the fixed frequency column.
    f_fixes, f_variables : sequence of str
        Names of the fixed and varied columns, indexed with ``f_nr``.
    f_nr : int
        Index into ``f_fixes`` / ``f_variables``.
    frame_cell_orig_tuning : pandas.DataFrame
        Processed tuning data of the cell.
    freq1_here, freq2_here : float
        Selected df1 and df2.
    ft_nr : int
        Index of the tuning file / contrast.
    grid_rr : gridspec
        Grid that receives the new axes.
    height, letter_plus : float
        Top of the marker line, and the extra vertical offset of its letter.
    i : int
        Row of ``grid_rr``. Also decides legend, contrast label and x label.
    letters_all2 : sequence of str
        Panel letters, indexed with ``f_counter``.
    scatter_extra : bool
        If True, ``scores[0]`` and ``scores[2]`` at ``freq1_here`` are also
        marked in this panel and in ``ax_u1_upper``.
    scores : sequence of str
        Score column names from ``colors_susept``.
    xlabel_pos : int, optional
        Row ``i`` that keeps x label and ticks (default 1).
    """
    ax_u1 = plt.subplot(grid_rr[i])
    # plt_tuning_curve returns frame_f_reference(c_heres[ft_nr], cell_here,
    # f_fixes[f_nr], frame_cell_orig_tuning, f_fixed). It is reused below
    # instead of filtering the same frame again.
    frame_f = plt_tuning_curve(c_heres[ft_nr], ax_u1, frame_cell_orig_tuning, cell_here,
                               f_fixed,
                               f_fixed,
                               f_fixed=f_fixes[f_nr], index=idx,
                               f_variable=f_variables[f_nr])
    if i == 1 and ft_nr == 0 and f_counter == 0:
        ax_u1.legend(loc=(0, 2.325), ncol=2)
    ax_u1.set_title('')
    ax_u1.set_yticks_delta(100)
    if i == 0:
        ax_u1.text(0, 1.1, '$c_{1}=%s$' % (str(int(c_heres[ft_nr] * 100))) + '$\%$',
                   color=c_colors[ft_nr], ha='left',
                   va='top',
                   transform=ax_u1.transAxes)
    ax_u1.set_xlim(-265, 265)
    ax_u1.set_ylim(0, 420)
    if scatter_extra:
        s_big = 25
        s_small = 20
        score_both = frame_f[(frame_f['df2'] == freq2_here) & (frame_f['df1'] == freq1_here)][scores[2]]
        score_first = frame_f[frame_f['df1'] == freq1_here][scores[0]]
        ax_u1.scatter(freq1_here, score_both, edgecolor=color_second, facecolor='white',
                      s=s_big, alpha=0.5, marker='o', clip_on=False, zorder=100)
        ax_u1.scatter(freq1_here, score_first, edgecolor=color_first, marker='o', zorder=120,
                      facecolor='white', s=s_small, alpha=0.5, clip_on=False)
        if ft_nr == 0:
            ax_u1.scatter(freq2_here, score_both, edgecolor=color_third, facecolor='white', alpha=0.5,
                          marker='o', clip_on=False, zorder=120, s=s_small)
        # Same points in the upper panel, at x = contrast in percent.
        ax_u1_upper.scatter(c_heres[ft_nr] * 100, score_both, edgecolor=color_second, facecolor='white',
                            s=s_big, alpha=0.5, marker='o', clip_on=False, zorder=100)
        ax_u1_upper.scatter(c_heres[ft_nr] * 100, score_first, edgecolor=color_first, marker='o',
                            zorder=120, facecolor='white', s=s_small, alpha=0.5, clip_on=False)
    # Marker for the selected df1, shifted 5 Hz to the left.
    marker_x = freq1_here - 5
    ax_u1.scatter(marker_x, height,
                  color=c_colors[ft_nr], marker='v', clip_on=False, s=7)
    ax_u1.plot([marker_x, marker_x], [0, height], color=c_colors[ft_nr],
               linewidth=lw_tuning(), zorder=100)
    ax_u1.text(marker_x, height + letter_plus, letters_all2[f_counter],
               ha='center', color=c_colors[ft_nr],
               va='center')
    # Only the first panel of the first pair keeps its y label and ticks.
    if f_counter == 0 and ft_nr == 0:
        ax_u1.set_ylabel(representation_ylabel())
    else:
        ax_u1.set_ylabel('')
        remove_yticks(ax_u1)
    if i == xlabel_pos:
        ax_u1.set_xlabel(xlabel_vary())
    else:
        ax_u1.set_xlabel('')
        remove_xticks(ax_u1)
    ax_us.append(ax_u1)
def lw_tuning():
    """Return the line width of the vertical marker lines in the tuning figures."""
    marker_line_width = 0.55
    return marker_line_width
def vary_contrasts_big_with_tuning3_several(freqs=((39.5, -135.5),),
                                            cells_here='2011-10-25-ad-invivo-1'):
    """Plot contrast traces and df1 tuning curves for three contrasts per cell.

    For every cell with rows in the ``calc_model_amp_freqs-F1_750-975-75...``
    table, one figure is built with one column per (df1, df2) pair in
    ``freqs``. Each column contains:

    * an upper panel drawn by ``plt_single_trace``, with vertical lines at the
      contrasts in ``c_heres``;
    * three stacked tuning panels, one per contrast, read from the
      ``calc_model_amp_freqs-F1_500-1495-5_F2_<f2>...`` CSV files (if present)
      and drawn by ``plt_tuning_curve``.

    The figure is written with ``save_visualization``.

    Parameters
    ----------
    freqs : sequence of (float, float)
        Requested (df1, df2) pairs in Hz. The closest values present in the
        data are used.
    cells_here : str or sequence of str
        A single cell name, a sequence of cell names, or an empty sequence to
        plot every cell listed in ``models_big_fit_d_right.csv``.
    """
    default_settings()  # ts=13, ls=13, fs=13, lw = 0.7
    plot_style()
    model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv")
    if len(cells_here) < 1:
        cells_here = np.array(model_cells.cell)
    elif isinstance(cells_here, str):
        # A bare string would otherwise be iterated character by character.
        cells_here = [cells_here]
    trials_nrs = [1]
    plot_style()
    default_settings(column=2, length=7.5)
    for _ in trials_nrs:
        for cell_here in cells_here:
            full_names = [
                'calc_model_amp_freqs-F1_750-975-75_F2_500-725-75_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_1.0_mult__start_0.0001_end_1_StimLen_25_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal']
            c_grouped = ['c1']  # , 'c2']
            frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv')
            frame_cell_orig = frame[(frame.cell == cell_here)]
            if len(frame_cell_orig) > 0:
                grid0 = gridspec.GridSpec(1, len(freqs), bottom=0.1, top=0.87, left=0.09,
                                          right=0.95,
                                          wspace=0.3)
                ###################################################
                # Optional saturation/loss matrices (disabled).
                # NOTE(review): this branch uses ``grid_s`` before it is created
                # inside the frequency loop below; enabling it as written
                # raises NameError.
                squares = False
                if squares:
                    full_names_square = [
                        'calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_25_FirstC1_0.0001_LastC1_1.0_StimLen_2_nfft_32768_trialsnr_1_absolut_power_1_minamps__dev_05temporal', ]
                    frame_square = pd.read_csv(
                        load_folder_name('calc_cocktailparty') + '/' + full_names_square[0] + '.csv')
                    frame_cell_square = frame_square[(frame_square.cell == cell_here)]
                    axes = []
                    axes.append(plt.subplot(grid_s[0]))
                    axes.append(plt.subplot(grid_s[1]))
                    axes.append(plt.subplot(grid_s[2]))
                    frame_cell_square = single_frame_processing(c_grouped, frame_cell_square)
                    lim, matrix, ss, ims = plt_matrix_saturation_loss(axes, frame_cell_square, add='_05')
                    plt_cross(matrix, axes[-1])
                #################################################################
                show = True
                # This figure shows one cell: the single point (df1, df2)
                # with varied contrasts.
                ax_upper = []
                frame_cell_orig, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig)
                ####################################################
                # Amplitude (contrast) tuning curve.
                frame_cell = single_frame_processing(c_grouped, frame_cell_orig)
                c_heres = [0.1, 0.25, 0.5]  # 0.03,
                c_colors = ['black', 'darkgrey', 'silver']
                freq1s = np.unique(frame_cell_orig.df1)
                freq2s = np.unique(frame_cell_orig.df2)
                f_counter = 0
                ax_uss = []
                for freq1, freq2 in freqs:
                    grid00 = gridspec.GridSpecFromSubplotSpec(2, 1,
                                                              wspace=0.35, hspace=0.35, subplot_spec=grid0[f_counter],
                                                              height_ratios=[1, 3.55])
                    grid_u = gridspec.GridSpecFromSubplotSpec(1, 1,
                                                              hspace=0.7,
                                                              wspace=0.25,
                                                              subplot_spec=grid00[
                                                                  0])  # hspace=0.4,wspace=0.2,len(chirps)
                    grid_r = gridspec.GridSpecFromSubplotSpec(3, 1,
                                                              hspace=0.3,
                                                              wspace=0.25,
                                                              subplot_spec=grid00[1])
                    ################################################
                    grid_s = gridspec.GridSpecFromSubplotSpec(1, 3,
                                                              hspace=0.7,
                                                              wspace=0.45,
                                                              subplot_spec=grid00[-1])
                    # Snap the requested frequencies to the closest ones in the data.
                    freq1_here = freq1s[np.argmin(np.abs(freq1s - freq1))]
                    freq2_here = freq2s[np.argmin(np.abs(freq2s - freq2))]
                    print(cell_here + ' F1' + str(freq1_here) + ' F2 ' + str(freq2_here))
                    ax_u1_upper = plt.subplot(grid_u[0])
                    c_dist_recalc = dist_recalc_phaselockingchapter()
                    ax_upper = plt_single_trace(ax_upper, ax_u1_upper, frame_cell_orig, freq1_here, freq2_here,
                                                sum=False, c_dist_recalc=c_dist_recalc,
                                                linestyles=['-', '--', '-', '--', '-'])
                    c_nrs_here_cm = c_dist_recalc_func(frame_cell, c_nrs=c_heres, cell=cell_here,
                                                       c_dist_recalc=c_dist_recalc)
                    lw = 0.75
                    if not c_dist_recalc:
                        # Contrasts are plotted in percent.
                        c_nrs_here_cm = np.array(c_nrs_here_cm) * 100
                    try:
                        ax_u1_upper.scatter(c_nrs_here_cm, np.zeros(len(c_nrs_here_cm)), color=c_colors, marker='^',
                                            clip_on=False, s=5)
                    except Exception:
                        print('embed something')
                        embed()
                    for c_nr_cm, c_color in zip(c_nrs_here_cm, c_colors):
                        ax_u1_upper.axvline(c_nr_cm, color=c_color, linewidth=lw)
                    labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept(
                        add='_mean_original')
                    color_first = 'red'  # color01
                    color_second = 'purple'  # color02
                    color_third = 'darkblue'
                    rainbow_title(plt.gcf(), ax_u1_upper, [' $\Delta f_{s}=%s$' % freq2_here + ' Hz',
                                                           ' $\Delta f_{p}=%s$' % freq1_here + ' Hz',
                                                           '$c_{2}=10\%$'],
                                  [[color_second, color_first, 'black']], start_xpos=0, ha='left', y_pos=1.02)
                    if f_counter != 0:
                        ax_u1_upper.set_ylabel('')
                        remove_yticks(ax_u1_upper)
                    if f_counter == 0:
                        ax_u1_upper.legend(loc=(0, 1.25), ncol=2)
                    frame_cell_chosen = frame_cell_orig[
                        (frame_cell_orig.df1 == freq1_here) & (frame_cell_orig.df2 == freq2_here)]
                    print('Tuning curve needed for F1' + str(frame_cell_chosen.f1.unique()) + ' F2' + str(
                        frame_cell_chosen.f2.unique()) + ' for cell ' + str(cell_here))
                    # Tuning curves: one pre-computed file per contrast in c_heres.
                    freq2_here_abs = str(int(frame_cell_chosen.f2.unique()))
                    length = '2'
                    nfft = '4096'
                    full_names_tunings = [
                        'calc_model_amp_freqs-F1_500-1495-5_F2_' + freq2_here_abs + '_C2_0.1_C1_0.1_StimLen_' + length + '_nfft_' + nfft + '_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal',
                        'calc_model_amp_freqs-F1_500-1495-5_F2_' + freq2_here_abs + '_C2_0.1_C1_0.25_StimLen_' + length + '_nfft_' + nfft + '_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal',
                        'calc_model_amp_freqs-F1_500-1495-5_F2_' + freq2_here_abs + '_C2_0.1_C1_0.5_StimLen_' + length + '_nfft_' + nfft + '_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal',
                    ]
                    ax_us = []
                    for ft_nr, full_names_tuning in enumerate(full_names_tunings):
                        if os.path.exists(load_folder_name('calc_cocktailparty') + '/' + full_names_tuning + '.csv'):
                            frame_tuning = pd.read_csv(
                                load_folder_name('calc_cocktailparty') + '/' + full_names_tuning + '.csv')
                            print(full_names_tuning)
                            frame_cell_orig_tuning = frame_tuning[(frame_tuning.cell == cell_here)]
                            try:
                                frame_cell_orig_tuning = single_frame_processing(c_grouped, frame_cell_orig_tuning)
                            except Exception:
                                print('something')
                                embed()
                            f_fixes = ['df2']  # 'df1',
                            f_variables = ['df1']  # 'df2',
                            freqs_fixed = [freq2_here]  # freq1_here,
                            for f_nr, f_fixed in enumerate(freqs_fixed):
                                ax_u1 = plt.subplot(grid_r[ft_nr])
                                # Draw the tuning curves themselves; the panel
                                # otherwise only held the scatter markers.
                                # The returned frame is the filtered reference
                                # frame used for the markers below.
                                frame_f = plt_tuning_curve(c_heres[ft_nr], ax_u1, frame_cell_orig_tuning, cell_here,
                                                           f_fixed, f_fixed, f_fixed=f_fixes[f_nr],
                                                           f_variable=f_variables[f_nr])
                                ax_u1.set_title('')
                                ax_u1.text(1, 1.15, '$c_{1}=%s$' % (str(int(c_heres[ft_nr] * 100))) + '$\%$',
                                           color=color01, ha='right',
                                           va='top',
                                           transform=ax_u1.transAxes)
                                if ft_nr == 0:
                                    ax_u1.text(0, 1.15, ' $\Delta ' + f_stable_name() + '=%s$' % (
                                        freq2_here) + '\,Hz ' + '$ ' + c_stable_name() + '=10 \%$',
                                               color=color_third, ha='left',
                                               va='top',
                                               transform=ax_u1.transAxes)
                                ax_u1.set_xlim(-300, 300)
                                ax_u1.set_ylim(0, 420)
                                s_big = 25
                                s_small = 20
                                score_both = frame_f[(frame_f['df2'] == freq2_here) & (frame_f['df1'] == freq1_here)][
                                    scores[2]]
                                score_first = frame_f[frame_f['df1'] == freq1_here][scores[0]]
                                ax_u1.scatter(freq1_here, score_both, edgecolor=color_second, facecolor='white',
                                              s=s_big, alpha=0.5, marker='o', clip_on=False, zorder=100)
                                ax_u1.scatter(freq1_here, score_first,
                                              edgecolor=color_first, marker='o', zorder=120, facecolor='white',
                                              s=s_small, alpha=0.5, clip_on=False)
                                if ft_nr == 0:
                                    ax_u1.scatter(freq2_here, score_both,
                                                  edgecolor=color_third, facecolor='white', alpha=0.5, marker='o',
                                                  clip_on=False, zorder=120, s=s_small)
                                ####################
                                # Same points in the upper panel, at x = contrast in percent.
                                ax_u1_upper.scatter(c_heres[ft_nr] * 100, score_both,
                                                    edgecolor=color_second, facecolor='white', s=s_big, alpha=0.5,
                                                    marker='o',
                                                    clip_on=False, zorder=100)
                                ax_u1_upper.scatter(c_heres[ft_nr] * 100, score_first,
                                                    edgecolor=color_first,
                                                    marker='o', zorder=120, facecolor='white', s=s_small, alpha=0.5,
                                                    clip_on=False)
                                #############################
                                ax_u1.axvline(freq1_here, color=c_colors[ft_nr], linewidth=lw)
                                ax_u1.scatter(freq1_here, [0], color=c_colors[ft_nr],
                                              marker='^',
                                              clip_on=False, s=5)
                                if f_counter == 0:
                                    ax_u1.set_ylabel(representation_ylabel())
                                else:
                                    ax_u1.set_ylabel('')
                                    remove_yticks(ax_u1)
                                if ft_nr in [2]:
                                    ax_u1.set_xlabel(xlabel_vary())
                                else:
                                    ax_u1.set_xlabel('')
                                    remove_xticks(ax_u1)
                                ax_us.append(ax_u1)
                    f_counter += 1
                    if len(ax_us) > 0:
                        join_x(ax_us)
                        join_y(ax_us)
                    ax_uss.append(ax_us)
                #########################################################
                if squares:
                    set_clim_same(ims, clims='all', same='same')
                join_y(ax_upper)
                join_x(ax_upper)
                join_y(ax_upper)
                yoffs = np.array([2, 2.5, 2.5, 2.5])
                x = -3
                tag2(plt.gcf(), [[ax_upper[0], ax_uss[0][0], ax_uss[0][1], ax_uss[0][2]]],
                     xoffs=np.array([x, x, x, x]),
                     yoffs=yoffs)
                # A second letter row only exists when a second (df1, df2)
                # pair was plotted; with one pair, ax_uss[1] raised IndexError.
                if len(ax_uss) > 1:
                    tag2(plt.gcf(), [[ax_upper[4], ax_uss[1][0], ax_uss[1][1], ax_uss[1][2]]],
                         xoffs=np.array([x, x, x, x]),
                         yoffs=yoffs)
                save_visualization(cell_here, show)
def dist_recalc_phaselockingchapter():
    """Return the ``c_dist_recalc`` flag used by the phase-locking figures.

    The flag is always ``False``. Callers pass it to ``c_dist_recalc_func``,
    and when it is False they scale the returned contrasts by 100 themselves.
    """
    return False
def vary_contrasts_big_with_tuning3(freqs=((39.5, -135.5),), cells_here='2011-10-25-ad-invivo-1'):
    """Plot contrast traces plus df1 tuning curves for four contrasts per cell.

    For every cell with rows in the ``calc_model_amp_freqs-F1_750-975-75...``
    table, one figure is built. It has an upper row with one panel per
    (df1, df2) pair and a 2x2 grid of tuning panels. Both are filled by
    ``amplitude_tuning_curve``. The figure is written with
    ``save_visualization``.

    Parameters
    ----------
    freqs : sequence of (float, float)
        Requested (df1, df2) pairs in Hz. The closest values present in the
        data are used.
    cells_here : str or sequence of str
        A single cell name, a sequence of cell names, or an empty sequence to
        plot every cell listed in ``models_big_fit_d_right.csv``.
    """
    default_settings()  # ts=13, ls=13, fs=13, lw = 0.7
    plot_style()
    model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv")
    if len(cells_here) < 1:
        cells_here = np.array(model_cells.cell)
    elif isinstance(cells_here, str):
        # A bare string would otherwise be iterated character by character.
        cells_here = [cells_here]
    trials_nrs = [1]
    plot_style()
    default_settings(column=2, length=6.5)
    for _ in trials_nrs:
        for cell_here in cells_here:
            full_names = [
                'calc_model_amp_freqs-F1_750-975-75_F2_500-725-75_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_1.0_mult__start_0.0001_end_1_StimLen_25_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal']
            c_grouped = ['c1']  # , 'c2']
            frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv')
            frame_cell_orig = frame[(frame.cell == cell_here)]
            if len(frame_cell_orig) > 0:
                grid0 = gridspec.GridSpec(1, 1, bottom=0.1, top=0.87, left=0.09,
                                          right=0.95,
                                          wspace=0.04)
                grid00 = gridspec.GridSpecFromSubplotSpec(2, 1,
                                                          wspace=0.2, hspace=0.6, subplot_spec=grid0[0],
                                                          height_ratios=[1, 2])
                grid_u = gridspec.GridSpecFromSubplotSpec(1, len(freqs),
                                                          hspace=0.7,
                                                          wspace=0.25,
                                                          subplot_spec=grid00[0])  # hspace=0.4,wspace=0.2,len(chirps)
                grid_r = gridspec.GridSpecFromSubplotSpec(2, 2,
                                                          hspace=0.15,
                                                          wspace=0.25,
                                                          subplot_spec=grid00[1])
                grid_s = gridspec.GridSpecFromSubplotSpec(1, 3,
                                                          hspace=0.7,
                                                          wspace=0.45,
                                                          subplot_spec=grid00[-1])
                ###################################################
                # Optional saturation/loss matrices (disabled).
                squares = False
                if squares:
                    full_names_square = [
                        'calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_25_FirstC1_0.0001_LastC1_1.0_StimLen_2_nfft_32768_trialsnr_1_absolut_power_1_minamps__dev_05temporal', ]
                    frame_square = pd.read_csv(
                        load_folder_name('calc_cocktailparty') + '/' + full_names_square[0] + '.csv')
                    frame_cell_square = frame_square[(frame_square.cell == cell_here)]
                    axes = []
                    axes.append(plt.subplot(grid_s[0]))
                    axes.append(plt.subplot(grid_s[1]))
                    axes.append(plt.subplot(grid_s[2]))
                    frame_cell_square = single_frame_processing(c_grouped, frame_cell_square)
                    lim, matrix, ss, ims = plt_matrix_saturation_loss(axes, frame_cell_square, add='_05')
                    plt_cross(matrix, axes[-1])
                #################################################################
                show = True
                # This figure shows one cell: the single point (df1, df2)
                # with varied contrasts.
                ax_upper = []
                frame_cell_orig, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig)
                f = -1
                ####################################################
                # Amplitude (contrast) tuning curve and df1 tuning curves.
                ax_upper, nfft, ax_us = amplitude_tuning_curve(ax_upper, c_grouped, cell_here, f, frame_cell_orig,
                                                               freqs, grid_r, grid_u)
                ax_upper[0].legend(loc=(0, 1.4), ncol=6)
                if squares:
                    set_clim_same(ims, clims='all', same='same')
                join_y(ax_upper)
                join_x(ax_upper)
                join_y(ax_upper)
                save_visualization(cell_here, show)
def amplitude_tuning_curve(ax_upper, c_grouped, cell_here, f, frame_cell_orig,
                           freqs, grid_r, grid_u):
    """Draw contrast traces and df1 tuning curves for one cell.

    For each requested (df1, df2) pair:

    * an upper panel is created in column ``f + 1`` of ``grid_u`` and filled by
      ``plt_single_trace``;
    * for each contrast in ``c_heres``, the matching pre-computed tuning CSV
      (if present) is plotted with ``plt_tuning_curve`` into ``grid_r[ft_nr]``.
      df2 is held at the selected value and df1 is varied.

    Parameters
    ----------
    ax_upper : list
        Upper axes collected so far; extended by ``plt_single_trace``.
    c_grouped : list of str
        Contrast column(s) passed to ``single_frame_processing``.
    cell_here : str
        Cell name used to filter the data frames.
    f : int
        Column index before the first pair (incremented once per pair).
    frame_cell_orig : pandas.DataFrame
        Rows of the amplitude/frequency sweep for ``cell_here``.
    freqs : sequence of (float, float)
        Requested (df1, df2) pairs. The closest available values are used.
    grid_r, grid_u : gridspec
        Grids for the tuning panels and for the upper panels.

    Returns
    -------
    ax_upper : list
        Updated list of upper axes.
    nfft : str or None
        FFT-length tag used in the tuning file names. None if ``freqs`` is
        empty.
    ax_us : list
        Tuning-curve axes of the last processed pair.
    """
    frame_cell = single_frame_processing(c_grouped, frame_cell_orig)
    c_heres = [0.03, 0.1, 0.25, 0.5]
    c_colors = ['black', 'darkgrey', 'silver', 'lightgrey']
    freq1s = np.unique(frame_cell_orig.df1)
    freq2s = np.unique(frame_cell_orig.df2)
    # Defaults so that the return statement also works for an empty ``freqs``.
    nfft = None
    ax_us = []
    for freq1, freq2 in freqs:
        # Snap the requested frequencies to the closest ones in the data.
        freq1_here = freq1s[np.argmin(np.abs(freq1s - freq1))]
        freq2_here = freq2s[np.argmin(np.abs(freq2s - freq2))]
        f += 1
        print(cell_here + ' F1' + str(freq1_here) + ' F2 ' + str(freq2_here))
        ax_u1 = plt.subplot(grid_u[0, f])
        ax_upper = plt_single_trace(ax_upper, ax_u1, frame_cell_orig, freq1_here, freq2_here,
                                    sum=False, linestyles=['-', '--', '-', '--', '-'])
        c_nrs_here_cm = c_dist_recalc_func(frame_cell, c_nrs=c_heres, cell=cell_here)
        ax_u1.scatter(c_nrs_here_cm, np.zeros(len(c_nrs_here_cm)), color=c_colors, marker='^', clip_on=False)
        plt.suptitle(cell_here)
        ax_u1.set_title(' $\Delta f_{1}=%s$' % freq1_here + ' Hz $\Delta f_{2}=%s$' % freq2_here + ' Hz')
        ax_upper[-1].legend(loc=(0, 0.9), ncol=4)
        frame_cell_chosen = frame_cell_orig[(frame_cell_orig.df1 == freq1_here) & (frame_cell_orig.df2 == freq2_here)]
        print('Tuning curve needed for F1' + str(frame_cell_chosen.f1.unique()) + ' F2' + str(
            frame_cell_chosen.f2.unique()) + ' for cell ' + str(cell_here))
        # Tuning curves: one pre-computed file per contrast in c_heres.
        freq2_here_abs = str(int(frame_cell_chosen.f2.unique()))
        length = '2'
        nfft = '4096'
        full_names_tunings = [
            'calc_model_amp_freqs-F1_500-1495-5_F2_' + freq2_here_abs + '_C2_0.1_C1_0.03_StimLen_' + length + '_nfft_' + nfft + '_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal',
            'calc_model_amp_freqs-F1_500-1495-5_F2_' + freq2_here_abs + '_C2_0.1_C1_0.1_StimLen_' + length + '_nfft_' + nfft + '_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal',
            'calc_model_amp_freqs-F1_500-1495-5_F2_' + freq2_here_abs + '_C2_0.1_C1_0.25_StimLen_' + length + '_nfft_' + nfft + '_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal',
            'calc_model_amp_freqs-F1_500-1495-5_F2_' + freq2_here_abs + '_C2_0.1_C1_0.5_StimLen_' + length + '_nfft_' + nfft + '_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal',
        ]
        f_fixes = ['df2']
        f_variables = ['df1']
        freqs_fixed = [freq2_here]
        ax_us = []
        for ft_nr, full_names_tuning in enumerate(full_names_tunings):
            path_tuning = load_folder_name('calc_cocktailparty') + '/' + full_names_tuning + '.csv'
            if not os.path.exists(path_tuning):
                continue
            frame_tuning = pd.read_csv(path_tuning)
            print(full_names_tuning)
            frame_cell_orig_tuning = frame_tuning[(frame_tuning.cell == cell_here)]
            # Same preprocessing that the sibling two-beat figures apply to
            # their tuning frames before plotting.
            frame_cell_orig_tuning = single_frame_processing(c_grouped, frame_cell_orig_tuning)
            for f_nr, f_fixed in enumerate(freqs_fixed):
                ax_u1 = plt.subplot(grid_r[ft_nr])
                # Hold df2 at f_fixed and vary df1. The call previously passed
                # only (ax, f_fixed), which does not match plt_tuning_curve's
                # signature.
                plt_tuning_curve(c_heres[ft_nr], ax_u1, frame_cell_orig_tuning, cell_here,
                                 f_fixed, f_fixed, f_fixed=f_fixes[f_nr],
                                 f_variable=f_variables[f_nr])
                ax_u1.set_title('')
                ax_u1.text(1, 1, '$c=%s$' % (c_heres[ft_nr]), color=c_colors[ft_nr], ha='right', va='top',
                           transform=ax_u1.transAxes)
                ax_u1.set_xlabel(f_variables[f_nr])
                ax_u1.set_xlim(-300, 300)
                ax_u1.scatter(freq1_here, 1, color='green', marker='^')
                ax_u1.scatter(freq1_here, 1, color='red', marker='^')
                # Left column of the 2x2 grid keeps the y label.
                if ft_nr in [0, 2]:
                    ax_u1.set_ylabel('Peak Amp. [Hz]')
                else:
                    ax_u1.set_ylabel('')
                    remove_yticks(ax_u1)
                # Bottom row of the 2x2 grid keeps the x label.
                if ft_nr in [2, 3]:
                    ax_u1.set_xlabel('$\Delta f_{1}$ [Hz]')
                else:
                    ax_u1.set_xlabel('')
                    remove_xticks(ax_u1)
                ax_us.append(ax_u1)
        if len(ax_us) > 0:
            join_x(ax_us)
            join_y(ax_us)
    return ax_upper, nfft, ax_us
def single_frame_processing(c_grouped, frame_cell):
    """Prepare one cell's data frame for the contrast/tuning plots.

    The frame is passed, in order, through ``area_vs_single_peaks_frame``,
    ``find_dfs`` (only the returned frame is kept) and then ``find_deltas`` /
    ``find_diffs``. The last two use the first column name in ``c_grouped``,
    and ``find_diffs`` is called with ``add='_original'``.

    Parameters
    ----------
    c_grouped : sequence of str
        Contrast column names; only the first entry is used.
    frame_cell : pandas.DataFrame
        Rows belonging to one cell.

    Returns
    -------
    pandas.DataFrame
        The processed frame returned by ``find_diffs``.
    """
    peaks_frame = area_vs_single_peaks_frame(frame_cell)
    dfs_frame, _df1s, _df2s, _f1s, _f2s = find_dfs(peaks_frame)
    contrast_column = c_grouped[0]
    deltas = find_deltas(dfs_frame, contrast_column)
    return find_diffs(contrast_column, dfs_frame, deltas, add='_original')
def plt_tuning_curve(c_here, ax, frame_cell, cell, freq2, dfs, f_fixed='f2', f_variable='f1', index=(0, 1, 2, 3)):
    """Plot selected score columns of one cell against the varied frequency.

    Rows are selected with ``frame_f_reference(c_here, cell, f_fixed,
    frame_cell, freq2)``. For each position ``sss`` in ``index``, column
    ``scores[sss]`` (from ``colors_susept(add='_mean_original', nr=1)``) is
    drawn against column ``f_variable``, using the matching label, colour,
    line style, width and alpha.

    Parameters
    ----------
    c_here : float
        Contrast value matched against the ``c1`` column.
    ax : matplotlib axes
        Target axes.
    frame_cell : pandas.DataFrame
        Tuning data (see ``frame_f_reference`` for the required columns).
    cell : str
        Cell name.
    freq2 : float
        Value that column ``f_fixed`` must equal.
    dfs : number
        Shown in the title as ``' DF2=<int(dfs)> Hz'``. If it cannot be
        converted to int, the title is skipped.
    f_fixed : str, optional
        Column held fixed (default ``'f2'``).
    f_variable : str, optional
        Column used on the x axis (default ``'f1'``).
    index : sequence of int, optional
        Which score/label/colour entries to plot (default ``(0, 1, 2, 3)``).

    Returns
    -------
    pandas.DataFrame
        The filtered frame that was plotted.
    """
    labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept(
        add='_mean_original', nr=1)
    frame_f = frame_f_reference(c_here, cell, f_fixed, frame_cell, freq2)
    try:
        ax.set_title(' DF2=' + str(int(dfs)) + ' Hz', fontsize=10)
    except (TypeError, ValueError, OverflowError):
        # dfs is not convertible to a single integer; leave the title unset.
        print('f1 f2 problem')
    for sss in index:
        test = False
        if test:
            # Debug view: scatter instead of lines.
            print('example')
            ax.scatter(np.array(frame_f[f_variable]), frame_f[scores[sss]], zorder=100,
                       linestyle=np.array(linestyles)[sss], color=np.array(colors)[sss],
                       label=np.array(labels)[sss], alpha=np.array(alpha)[sss], s=3,
                       linewidths=np.array(linewidths)[sss])
        try:
            ax.plot(np.array(frame_f[f_variable]), frame_f[scores[sss]], zorder=100,
                    linestyle=np.array(linestyles)[sss], color=np.array(colors)[sss],
                    linewidth=np.array(linewidths)[sss],
                    label=np.array(labels)[sss], alpha=np.array(alpha)[sss])
        except Exception:
            # Interactive debugging hook, kept from the original workflow.
            print('f1 thing')
            embed()
    return frame_f
def frame_f_reference(c_here, cell, f_fixed, frame_cell, freq2):
    """Select the rows of one cell at one fixed frequency.

    A row is kept when all of the following hold:

    * ``c1 == c_here``, ``cell == cell`` and ``frame_cell[f_fixed] == freq2``;
    * ``f1 != f2`` and ``|f1| != |f2|``;
    * ``|df1| != |df2|``;
    * ``||df1| - |df2|| > 10``.

    Returns
    -------
    pandas.DataFrame
        The matching rows, in their original order and with the original index.
    """
    # Pairs whose |df1| and |df2| lie closer than this are dropped: otherwise
    # resonance produces very strong peaks there.
    confidence = 10
    abs_df1 = np.abs(frame_cell.df1)
    abs_df2 = np.abs(frame_cell.df2)
    keep = ((frame_cell['c1'] == c_here)
            & (frame_cell.cell == cell)
            & (frame_cell[f_fixed] == freq2)
            & (frame_cell.f1 != frame_cell.f2)
            & (np.abs(frame_cell.f1) != np.abs(frame_cell.f2))
            & (abs_df1 != abs_df2)
            & (np.abs(abs_df1 - abs_df2) > confidence))
    return frame_cell[keep]
def plt_show_nonlin_effect_didactic_final2_only(min=0.2, cells=[], single_waves=['_SingleWave_', '_SeveralWave_', ],
                                                 cell_start=13,
                                                 a_f1s=[0, 0.005, 0.01, 0.05, 0.1, 0.2, ], a_frs=[1],
                                                 add_half=0, show=False, nfft=int(2 ** 15), gain=1, us_name=''):
    """Plot the didactic time-domain nonlinearity figure for model cells.

    For every model cell, wave configuration, receiver amplitude ``a_fr`` and
    stimulus amplitude ``a_f1``, the model is set up with ``outputmodel`` and
    simulated for 15 trials. Three stacked panels are drawn:

    * an idealised cosine beat ``s(t)`` at ``f1 - EODf``;
    * the spike raster;
    * the trial-averaged firing rate, smoothed with a Gaussian of
      ``gaussian_intro() * sampling_rate`` samples.

    One figure is saved per ``a_f1`` with ``save_visualization``.

    Parameters
    ----------
    min : float
        Offset subtracted from the time axis before it is scaled by 1000 (ms).
        Shadows the builtin ``min`` inside this function.
    cells : list
        Cell names. If empty, all cells of ``models_big_fit_d_right.csv`` from
        row ``cell_start`` on are used.
    single_waves : list of str
        Wave configurations. ``'_SingleWave_'`` uses ``a_f2 = 0``; any other
        value uses ``a_f2 = 0.1`` and is handed to ``outputmodel``.
    cell_start : int
        First row of the model table used when ``cells`` is empty.
    a_f1s : list of float
        Stimulus amplitudes; one figure per value.
    a_frs : list of float
        Receiver EOD amplitudes (also index ``titles_amp``).
    add_half : float
        Passed to ``outputmodel``.
    show : bool
        Passed to ``save_visualization``.
    nfft : int
        Segment length for the power spectra.
    gain, us_name :
        Passed to ``all_damping_variants``.
    """
    model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv")
    if len(cells) < 1:
        cells = model_cells.cell.loc[range(cell_start, len(model_cells))]
    plot_style()
    for cell in cells:
        for single_wave in single_waves:
            # A single wave has no second stimulus; otherwise the second wave gets amplitude 0.1.
            if single_wave == '_SingleWave_':
                a_f2s = [0]  # , 0,0.2
            else:
                a_f2s = [0.1]
            for a_f2 in a_f2s:
                trials_nr = 15  # 150
                titles_amp = ['base eodf', 'baseline to Zero', ]
                for a, a_fr in enumerate(a_frs):
                    default_figsize(column=2, length=3.5)
                    grid = gridspec.GridSpec(3, 1, wspace=0.35, left=0.095, hspace=0.3, top=0.95, bottom=0.15,
                                             right=0.98)
                    ax = {}
                    # Debug toggle for a membrane-voltage panel. NOTE: ax[0] is never created in
                    # this function, so enabling it also requires adding that panel.
                    vmem = False
                    for aa, a_f1 in enumerate(a_f1s):
                        SAM, cell, damping, damping_type, deltat, eod_fish_r, eod_fr, f1, f2, freqs1, freqs2, model_params, offset, phase_right, phaseshift_fr, rate_adapted, rate_baseline_after, rate_baseline_before, sampling, spike_adapted, spikes, stimuli, stimulus_altered, stimulus_length, time_array, v_dent_output, v_mem_output = outputmodel(
                            a_fr, add_half, cell, model_cells, single_wave, trials_nr)
                        # Panels: beat s(t), spike raster, smoothed firing rate.
                        ax[1] = plt.subplot(grid[0])
                        ax[1].show_spines('l')
                        ax[1].set_ylabel('$s(t)$')
                        ax[2] = plt.subplot(grid[1])
                        ax[2].show_spines('l')
                        ax[2].set_ylabel('Repeat Nr.')
                        ax[3] = plt.subplot(grid[2])
                        ax[3].show_spines('lb')
                        power_extra = False  # toggle for an additional spectrum panel
                        if power_extra:
                            ax[4] = plt.subplot(grid[:, 1])
                            ax[4].show_spines('lb')
                        ax[1].set_xlim(0, xlim_here())
                        ax[2].set_xlim(0, xlim_here())
                        ax[3].set_xlim(0, xlim_here())
                        _, _ = find_base_fr(spike_adapted, deltat, stimulus_length, time_array)
                        _, _ = ISI_frequency(time_array, spike_adapted[0], fill=0.0)
                        # Coefficient of variation of the first trial's inter-spike intervals.
                        isi = np.diff(spike_adapted[0])
                        cv0 = np.std(isi) / np.mean(isi)
                        for ff, freq1 in enumerate(freqs1):
                            print('freq1' + str(freq1 - eod_fr))
                            print('freq2' + str(freqs2[ff] - eod_fr))
                            print('a_f1' + str(a_f1))
                            print('a_f2' + str(a_f2))
                            freq1 = [freq1]
                            freq2 = [freqs2[ff]]
                            beat1 = freq1 - eod_fr
                            titles = False
                            if titles:
                                plt.suptitle('diverging from half fr by ' + str(add_half) + ' f1:' + str(
                                    np.round(freq1)[0]) + ' f2:' + str(np.round(freq2)[0]) + ' Hz \n' + str(
                                    beat1) + ' Hz Beat\n' + titles[ff] + titles_amp[a] + ' ' + cell + ' cv ' + str(
                                    np.round(cv0, 3)) + '_a_f0_' + str(a_fr) + '_a_f1_' + str(a_f1) + '_a_f2_' + str(
                                    a_f2) + ' tr_nr ' + str(trials_nr))
                            _, _ = get_phaseshifts(a_f1, a_f2, phase_right, phaseshift_fr)
                            eod_fish1, time_fish_e = eod_fish_e_generation(time_array, a_f1, freq1, f1)
                            eod_fish2, time_fish_j = eod_fish_e_generation(time_array, a_f2, freq2, f2)
                            eod_stimulus = eod_fish1 + eod_fish2
                            for t in range(trials_nr):
                                stimulus, eod_fish_sam = create_stimulus_SAM(SAM, eod_stimulus, eod_fish_r, freq1, f1,
                                                                             eod_fr,
                                                                             time_array, a_f1)
                                # damping variants
                                std_dump, max_dump, range_dump, stimulus, damping_output = all_damping_variants(
                                    stimulus, time_array, damping_type, eod_fr, gain, damping, us_name, plot=False,
                                    std_dump=0, max_dump=0, range_dump=0)
                                stimuli.append(stimulus)
                                (cvs, adapt_output, baseline_after, _, rate_adapted[t], rate_baseline_before[t],
                                 rate_baseline_after[t], spikes[t], stimulus_altered[t],
                                 v_dent_output[t], offset_new, v_mem_output[t], noise_final) = simulate(
                                    cell, offset, stimulus, adaptation_yes_e=f1, **model_params)
                            # Binary spike matrices and per-trial power spectra.
                            spikes_mat = [[]] * len(spikes)
                            pps = [[]] * len(spikes)
                            for s in range(len(spikes)):
                                spikes_mat[s] = cr_spikes_mat(spikes[s], 1 / deltat, int(stimulus_length * 1 / deltat))
                                pps[s], f = ml.psd(spikes_mat[s] - np.mean(spikes_mat[s]), Fs=1 / deltat, NFFT=nfft,
                                                   noverlap=nfft // 2)
                            pp_mean = np.mean(pps, axis=0)
                            sampling_rate = 1 / deltat
                            smoothed05 = gaussian_filter(spikes_mat, sigma=gaussian_intro() * sampling_rate)
                            mat05 = np.mean(smoothed05, axis=0)
                            beat1 = (freq1 - eod_fr)[0]
                            beat2 = (freq2 - eod_fr)[0]
                            if 'Several' in single_wave:
                                freqs_beat = [np.abs(beat1), np.abs(beat2), np.abs(beat2 + beat1),
                                              ]  # np.abs(beat2 - beat1)
                                colors_w, colors_wo, color_base, color_01, color_02, color_012 = colors_cocktailparty_all()
                                colors = [color_01, color_02, color_012]
                                labels = ['intruder', 'female', 'intruder+female']
                            else:
                                freqs_beat = [np.abs(beat1), np.abs(beat1) * 2, np.abs(beat1 * 3),
                                              np.abs(beat1 * 4)]
                                colors = colors_didactic()
                                labels = labels_didactic()
                            color_beat = 'black'
                            # An idealised cosine at the beat frequency is drawn as s(t); set stim_redo
                            # to False to extract the AM from the stimulus with extract_am instead.
                            stim_redo = True
                            if stim_redo:
                                eod_interp = np.cos(time_array * beat1 * 2 * np.pi) + 1
                            else:
                                eod_interp, eod_norm = extract_am(stimulus, time_array, sampling=sampling_rate,
                                                                  eodf=eod_fr,
                                                                  emb=False,
                                                                  extract='', norm=False)
                            # Skip the beat trace for the pure-baseline configuration.
                            if (titles_amp[a] != 'baseline to Zero') and not (
                                    (a_f2 == 0) & (a_fr == 1) & (a_f1 == 0)):
                                ax[1].plot((time_array - min) * 1000, eod_interp - 1, color=color_beat,
                                           clip_on=True)
                                ax[1].set_ylim(np.min(eod_interp - 1) * 1.05, np.max(eod_interp - 1) * 1.05)
                            # Shift spike times by min and convert them to ms, like the time axis.
                            for l in range(len(spikes)):
                                spikes[l] = (spikes[l] - min) * 1000
                            if vmem:
                                ax[0].plot((time_array - min) * 1000, v_mem_output[0], color='black')
                                ax[0].eventplot(np.array(spikes[0]), lineoffsets=np.max(v_mem_output[0]), color='black')
                                ax[0].set_xlim([0, 350])
                            ax[2].eventplot(np.array(spikes), color='black')
                            ax[3].plot((time_array - min) * 1000, mat05, color='black')
                            power_extra = False
                            if power_extra:
                                pp, f = ml.psd(mat05 - np.mean(mat05), Fs=1 / deltat, NFFT=nfft,
                                               noverlap=nfft // 2)
                                log = 'log'
                                if log:
                                    pp_mean = calc_log(pp_mean)
                                plt_peaks_several(freqs_beat, pp_mean, ax[4], pp_mean, f, labels, 0, colors,
                                                  add_log=2.5, exact=False, text_extra=True, perc_peaksize=0.2,
                                                  rel='rel', ms=14,
                                                  clip_on=True, log=log)
                                ax[4].plot(f, pp_mean, color='black', zorder=0)
                                ax[4].set_xlim([0, 350])
                            test = False
                            if test:
                                from utils_test import test_spikes_clusters
                                test_spikes_clusters(eod_fish_r, spikes, mat05, sampling, s_name='ms', resamp_fact=1000)
                        ax[1].set_xticks([])
                        ax[2].set_xticks([])
                        ax[3].set_ylabel('Firing Rate [Hz]')
                        ax[3].set_xlabel('Time [ms]')
                        ax[1].set_xticks([])
                        ax[2].set_xticks([])
                        fig = plt.gcf()
                        fig.tag(fig.axes, xoffs=-6, yoffs=1.3)
                        plt.subplots_adjust(top=0.7, left=0.15, right=0.95, hspace=0.75, wspace=0.1)
                        individual_tag = titles_amp[a] + ' ' + cell + ' cv ' + str(cv0) + single_wave + '_a_f0_' + str(
                            a_fr) + '_a_f1_' + str(a_f1) + '_a_f2_' + str(a_f2) + '_diverge_from_base_half' + str(add_half)
                        save_visualization(individual_tag, show, counter_contrast=0, savename='')
def calc_log(pp_mean):
    """Convert a power spectrum to decibels relative to its maximum.

    Returns ``10 * log10(pp_mean / max(pp_mean))``. The largest value maps to
    0 dB; for positive input all other values are negative.
    """
    peak_power = np.max(pp_mean)
    relative_power = pp_mean / peak_power
    return 10 * np.log10(relative_power)
def plt_show_nonlin_effect_didactic_final2(min=0.2, cells=[], single_waves=['_SingleWave_', '_SeveralWave_', ],
                                           cell_start=13,
                                           a_f1s=[0, 0.005, 0.01, 0.05, 0.1, 0.2, ], a_frs=[1],
                                           add_half=0, xlim=[0, 350], show=False, nfft=int(2 ** 15), gain=1,
                                           us_name=''):
    """Plot the didactic nonlinearity figure with time-domain panels and a spectrum.

    For every model cell, wave configuration, receiver amplitude ``a_fr`` and
    stimulus amplitude ``a_f1``, the model is set up with ``outputmodel`` and
    simulated for 15 trials.

    * The left column shows the AM extracted with ``extract_am`` ("Beat"),
      the spike raster and the Gaussian-smoothed, trial-averaged firing rate.
    * The right column shows the trial-averaged spike-train power spectrum in
      dB relative to its maximum, with the beat-related peaks marked by
      ``plt_peaks_several``.

    One figure is saved per ``a_f1`` via ``save_visualization``.

    Parameters
    ----------
    min : float
        Offset subtracted from the time axis before it is scaled by 1000 (ms).
        Shadows the builtin ``min`` inside this function.
    cells : list
        Cell names. If empty, all cells of ``models_big_fit_d_right.csv`` from
        row ``cell_start`` on are used.
    single_waves : list of str
        Wave configurations. ``'_SingleWave_'`` uses ``a_f2 = 0``; otherwise
        ``a_f2 = 0.1``.
    cell_start : int
        First row of the model table used when ``cells`` is empty.
    a_f1s, a_frs : lists of float
        Stimulus and receiver amplitudes.
    add_half : float
        Passed to ``outputmodel``.
    xlim : sequence of two floats
        x-limits of the spectrum panel.
    show : bool
        Passed to ``save_visualization``.
    nfft : int
        Segment length for the power spectra.
    gain, us_name :
        Passed to ``all_damping_variants``.
    """
    model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv")
    if len(cells) < 1:
        cells = model_cells.cell.loc[range(cell_start, len(model_cells))]
    plot_style()
    for cell in cells:
        for single_wave in single_waves:
            if single_wave == '_SingleWave_':
                a_f2s = [0]  # , 0,0.2
            else:
                a_f2s = [0.1]
            for a_f2 in a_f2s:
                trials_nr = 15  # 150
                titles_amp = ['base eodf', 'baseline to Zero', ]
                for a, a_fr in enumerate(a_frs):
                    default_figsize(column=2, length=2.3)  # 3
                    grid = gridspec.GridSpec(3, 2, wspace=0.35, left=0.095, hspace=0.2, top=0.94, bottom=0.25,
                                             right=0.95)
                    ax = {}
                    for aa, a_f1 in enumerate(a_f1s):
                        SAM, cell, damping, damping_type, deltat, eod_fish_r, eod_fr, f1, f2, freqs1, freqs2, model_params, offset, phase_right, phaseshift_fr, rate_adapted, rate_baseline_after, rate_baseline_before, sampling, spike_adapted, spikes, stimuli, stimulus_altered, stimulus_length, time_array, v_dent_output, v_mem_output = outputmodel(
                            a_fr, add_half, cell, model_cells, single_wave, trials_nr)
                        # Left column: beat, spike raster, firing rate; right column: spectrum.
                        ax[1] = plt.subplot(grid[0])
                        ax[1].show_spines('')
                        ax[2] = plt.subplot(grid[2])
                        ax[2].show_spines('')
                        ax[3] = plt.subplot(grid[4])
                        ax[3].show_spines('lb')
                        ax[4] = plt.subplot(grid[:, 1])
                        ax[4].show_spines('lb')
                        ax[1].set_xlim(0, xlim_here())
                        ax[2].set_xlim(0, xlim_here())
                        ax[3].set_xlim(0, xlim_here())
                        _, _ = find_base_fr(spike_adapted, deltat, stimulus_length, time_array)
                        _, _ = ISI_frequency(time_array, spike_adapted[0], fill=0.0)
                        # Coefficient of variation of the first trial's inter-spike intervals.
                        isi = np.diff(spike_adapted[0])
                        cv0 = np.std(isi) / np.mean(isi)
                        for ff, freq1 in enumerate(freqs1):
                            freq1 = [freq1]
                            freq2 = [freqs2[ff]]
                            beat1 = freq1 - eod_fr
                            titles = False
                            if titles:
                                plt.suptitle('diverging from half fr by ' + str(add_half) + ' f1:' + str(
                                    np.round(freq1)[0]) + ' f2:' + str(np.round(freq2)[0]) + ' Hz \n' + str(
                                    beat1) + ' Hz Beat\n' + titles[ff] + titles_amp[a] + ' ' + cell + ' cv ' + str(
                                    np.round(cv0, 3)) + '_a_f0_' + str(a_fr) + '_a_f1_' + str(a_f1) + '_a_f2_' + str(
                                    a_f2) + ' tr_nr ' + str(trials_nr))
                            _, _ = get_phaseshifts(a_f1, a_f2, phase_right, phaseshift_fr)
                            eod_fish1, time_fish_e = eod_fish_e_generation(time_array, a_f1, freq1, f1)
                            eod_fish2, time_fish_j = eod_fish_e_generation(time_array, a_f2, freq2, f2)
                            eod_stimulus = eod_fish1 + eod_fish2
                            for t in range(trials_nr):
                                stimulus, eod_fish_sam = create_stimulus_SAM(SAM, eod_stimulus, eod_fish_r, freq1, f1,
                                                                             eod_fr,
                                                                             time_array, a_f1)
                                std_dump, max_dump, range_dump, stimulus, damping_output = all_damping_variants(
                                    stimulus, time_array, damping_type, eod_fr, gain, damping, us_name, plot=False,
                                    std_dump=0, max_dump=0, range_dump=0)
                                stimuli.append(stimulus)
                                (cvs, adapt_output, baseline_after, _, rate_adapted[t], rate_baseline_before[t],
                                 rate_baseline_after[t], spikes[t], stimulus_altered[t],
                                 v_dent_output[t], offset_new, v_mem_output[t], noise_final) = simulate(
                                    cell, offset, stimulus, adaptation_yes_e=f1, **model_params)
                            # Binary spike matrices and per-trial power spectra; ``f`` is the common
                            # frequency axis (same Fs and NFFT for every trial).
                            spikes_mat = [[]] * len(spikes)
                            pps = [[]] * len(spikes)
                            for s in range(len(spikes)):
                                spikes_mat[s] = cr_spikes_mat(spikes[s], 1 / deltat, int(stimulus_length * 1 / deltat))
                                pps[s], f = ml.psd(spikes_mat[s] - np.mean(spikes_mat[s]), Fs=1 / deltat, NFFT=nfft,
                                                   noverlap=nfft // 2)
                            pp_mean = np.mean(pps, axis=0)
                            sampling_rate = 1 / deltat
                            smoothed05 = gaussian_filter(spikes_mat, sigma=gaussian_intro() * sampling_rate)
                            mat05 = np.mean(smoothed05, axis=0)
                            beat1 = (freq1 - eod_fr)[0]
                            beat2 = (freq2 - eod_fr)[0]
                            nr = 2
                            if 'Several' in single_wave:
                                freqs_beat = [np.abs(beat1), np.abs(beat2), np.abs(np.abs(beat2) + np.abs(beat1))
                                              ]  # np.abs(beat2 - beat1),np.abs(beat2 + beat1),
                                colors_w, colors_wo, color_base, color_01, color_02, color_012 = colors_cocktailparty_all()
                                colors = [color_01, color_02, color_012]
                                # Raw strings keep the LaTeX thin space '\,' without an invalid escape.
                                labels = ['$f_{1}=%d$' % beat1 + r'\,Hz', '$f_{2}=%d$' % beat2 + r'\,Hz',
                                          '$f_{1} + f_{2}=f' + basename() + '=%d$' % (
                                              beat1 + beat2 - 1) + r'\,Hz']
                                add_texts = [nr, nr + 0.35, nr + 0.2]
                                texts_left = [-7, -7, -7, -7]
                            else:
                                freqs_beat = [np.abs(beat1), np.abs(beat1) * 2, np.abs(beat1 * 3),
                                              np.abs(beat1 * 4)]
                                colors = colors_didactic()
                                add_texts = [nr + 0.1, nr + 0.1, nr + 0.1, nr + 0.1]
                                texts_left = [3, 0, 0, 0]
                                labels = labels_didactic2()
                            if 'Several' in single_wave:
                                color_beat = 'black'
                            else:
                                color_beat = colors[0]
                            # Extract the amplitude modulation unless the stimulus mean is exactly 0 or 1.
                            if (np.mean(stimulus) != 0) & (np.mean(stimulus) != 1):
                                eod_interp, eod_norm = extract_am(stimulus, time_array, sampling=sampling_rate,
                                                                  eodf=eod_fr,
                                                                  emb=False,
                                                                  extract='', norm=False)
                            # Skip the beat trace for the pure-baseline configuration.
                            if (titles_amp[a] != 'baseline to Zero') and not (
                                    (a_f2 == 0) & (a_fr == 1) & (a_f1 == 0)):
                                ax[1].plot((time_array - min) * 1000, eod_interp, color=color_beat, clip_on=True)
                                ax[1].set_ylim(np.min(eod_interp) * 0.98, np.max(eod_interp) * 1.02)
                            # Shift spike times by min and convert them to ms, like the time axis.
                            for l in range(len(spikes)):
                                spikes[l] = (spikes[l] - min) * 1000
                            ax[2].eventplot(np.array(spikes), color='black')
                            ax[3].plot((time_array - min) * 1000, mat05, color='black')
                            log = 'log'
                            if log:
                                pp_mean = calc_log(pp_mean)
                            print(freqs_beat)
                            print(labels)
                            plt_peaks_several(freqs_beat, pp_mean, ax[4], pp_mean, f, labels, 0, colors, ha='center',
                                              add_texts=add_texts, texts_left=texts_left, add_log=2.5,
                                              rots=[0, 0, 0, 0], exact=False, text_extra=True, perc_peaksize=5,
                                              rel='rel', ms=14,
                                              clip_on=True, several_peaks=True, log=log)
                            ax[4].plot(f, pp_mean, color='black', zorder=0)
                            ax[4].set_xlim(xlim)
                            test = False
                            if test:
                                from utils_test import test_spikes_clusters
                                test_spikes_clusters(eod_fish_r, spikes, mat05, sampling, s_name='ms', resamp_fact=1000)
                        ax[1].set_xticks([])
                        ax[2].set_xticks([])
                        ax[1].set_ylabel('Beat')
                        ax[2].set_ylabel('Spikes')
                        ax[3].set_ylabel('Firing Rate [Hz]')
                        if log == 'log':
                            ax[4].set_ylabel('dB')
                        else:
                            ax[4].set_ylabel('Amplitude [Hz]')
                        ax[4].set_xlabel('Frequency [Hz]')
                        ax[3].set_xlabel('Time [ms]')
                        ax[1].set_xticks([])
                        ax[2].set_xticks([])
                        fig = plt.gcf()
                        tag2(fig=fig, xoffs=[-4.5, -4.5, -4.5, -5.5], yoffs=1.25)
                        plt.subplots_adjust(top=0.6, left=0.15, right=0.95, hspace=0.5, wspace=0.1)
                        individual_tag = titles_amp[a] + ' ' + cell + ' cv ' + str(cv0) + single_wave + '_a_f0_' + str(
                            a_fr) + '_a_f1_' + str(a_f1) + '_a_f2_' + str(a_f2) + '_diverge_from_base_half' + str(add_half)
                        save_visualization(individual_tag, show, counter_contrast=0, savename='')
def outputmodel(a_fr, add_half, cell, model_cells, single_wave, trials_nr, freqs_mult1=None, freqs_mult2=None):
    """Prepare a model cell for simulation and choose the stimulus frequencies.

    The function proceeds in four steps:

    1. Look up ``cell`` in ``model_cells``.
    2. Build the receiver EOD with ``eod_fish_r_generation``.
    3. Adapt the model offset on the baseline with ``calc_the_model_spikes``.
    4. Derive the stimulus frequencies from ``fr``, the mean of
       ``find_base_fr``'s first output:

       * ``freqs_mult1`` given: ``f1 = EODf + fr * freqs_mult1`` and
         ``f2 = EODf + fr * freqs_mult2``;
       * ``'Several'`` and ``'Sum'`` in ``single_wave``: ``f1 = EODf + 0.3 fr``
         and ``f2 = EODf + 0.7 fr``;
       * otherwise: ``f1 = EODf - fr / 2 + add_half`` and ``f2 = 0``.

    Parameters
    ----------
    a_fr : float
        Receiver EOD amplitude.
    add_half : float
        Offset added to ``EODf - fr / 2`` in the default frequency rule.
    cell : str
        Model cell name (row of ``model_cells`` with ``cell == cell``).
    model_cells : pandas.DataFrame
        Table of fitted model parameters.
    single_wave : str
        Wave configuration name (see the frequency rules above).
    trials_nr : int
        Number of trials; sizes the per-trial placeholder lists.
    freqs_mult1, freqs_mult2 : float or None
        Optional multipliers of ``fr`` that override the frequency rules.

    Returns
    -------
    tuple
        ``(SAM, cell, damping, damping_type, deltat, eod_fish_r, eod_fr, f1,
        f2, freqs1, freqs2, model_params, offset, phase_right, phaseshift_fr,
        rate_adapted, rate_baseline_after, rate_baseline_before, sampling,
        spike_adapted, spikes, stimuli, stimulus_altered, stimulus_length,
        time_array, v_dent_output, v_mem_output)``.

        * ``f1``/``f2`` are the adaptation flags (0), not frequencies.
        * The per-trial lists contain ``trials_nr`` placeholders for the
          caller to fill.
        * ``stimuli`` is an empty list.

    Raises
    ------
    ValueError
        If ``freqs_mult1`` is None and ``single_wave`` contains ``'Several'``
        without ``'Sum'``; no frequency rule exists for that case.
    """
    # Fail fast, before the expensive baseline simulation, for an unsupported configuration.
    if freqs_mult1 is None and 'Several' in single_wave and 'Sum' not in single_wave:
        raise ValueError("outputmodel: no frequency rule for single_wave=%r "
                         "(a 'Several' wave needs 'Sum' or explicit freqs_mult1/freqs_mult2)" % (single_wave,))
    try:
        model_params = model_cells[model_cells['cell'] == cell].iloc[0]
    except (IndexError, KeyError):
        # Debug hook: the cell is missing from the model table.
        print('model extract something')
        embed()
    eod_fr = model_params['EODf']
    offset = model_params.pop('v_offset')
    cell = model_params.pop('cell')
    print(cell)
    f1 = 0
    f2 = 0
    sampling_factor = ''
    stimulus_length = 1
    phaseshift_fr = 0
    phase_right = '_phaseright_'
    adapt_offset = 'adaptoffsetallall2'
    SAM = ''
    damping = 0.45  # 0.65,0.2,0.5,0.2,0.6,0.45,0.6,0.35
    damping_type = ''
    # A different sampling could be selected here via sampling_factor.
    time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr,
                                                 stimulus_length)
    # Generate the receiver EOD (eod_fish_r).
    eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr)
    sampling = 1 / deltat
    # Per-trial placeholders. The entries are replaced by index (never mutated),
    # so sharing the same empty list is harmless.
    rate_adapted = [[]] * trials_nr
    rate_baseline_before = [[]] * trials_nr
    rate_baseline_after = [[]] * trials_nr
    spikes = [[]] * trials_nr
    v_dent_output = [[]] * trials_nr
    stimulus_altered = [[]] * trials_nr
    v_mem_output = [[]] * trials_nr
    spike_adapted = [[]] * trials_nr
    stimuli = []
    offset, spike_adapted = calc_the_model_spikes(a_fr, adapt_offset, cell, deltat, eod_fish_r, f1,
                                                  f2, model_params, offset, spike_adapted,
                                                  trials_nr)
    base_cut, mat_base = find_base_fr(spike_adapted, deltat, stimulus_length, time_array)
    fr = np.mean(base_cut)
    if freqs_mult1 is not None:
        freqs1 = [eod_fr + fr * freqs_mult1]
        freqs2 = [eod_fr + fr * freqs_mult2]
    elif 'Several' in single_wave:
        # 'Sum' is guaranteed here by the check at the top.
        freqs1 = [eod_fr + fr * 0.3]
        freqs2 = [eod_fr + fr * 0.7]
    else:
        freqs1 = [eod_fr - fr / 2 + add_half]
        freqs2 = [0] * len(freqs1)
    return SAM, cell, damping, damping_type, deltat, eod_fish_r, eod_fr, f1, f2, freqs1, freqs2, model_params, offset, phase_right, phaseshift_fr, rate_adapted, rate_baseline_after, rate_baseline_before, sampling, spike_adapted, spikes, stimuli, stimulus_altered, stimulus_length, time_array, v_dent_output, v_mem_output
def xlim_here():  # 075
    """Return the upper x-limit, in ms, of the didactic time-domain panels (0.1 s)."""
    window_seconds = 0.1
    return window_seconds * 1000
def calc_the_model_spikes(a_fr, adapt_offset, cell, deltat, eod_fish_r, f1, f2, model_params, offset, spike_adapted,
                          trials_nr, add=0, dent_tau_change=1, constant_reduction=1, n=1, exp_tau=1, exponential='',
                          lower_tol=0.995, plus=1, sig_val=1, slope=1, v_exp=1, zeros='zeros', upper_tol=1.005):
    """Simulate ``trials_nr`` trials driven by ``eod_fish_r`` alone and adapt the offset.

    Each trial calls ``simulate`` with the receiver EOD as the only input.
    The offset that ``simulate`` returns for the first trial replaces
    ``offset`` for all later trials. ``spike_adapted[t]`` is overwritten in
    place with the spike output of trial ``t`` (presumably spike times) and
    the same list is returned.

    The power variant is ``'sinz'``, or ``'sinz_' + zeros`` when ``a_fr == 0``.
    The remaining keyword arguments are forwarded to ``simulate``.

    NOTE(review): ``power_alpha`` receives the module-level name ``alpha``
    (imported from ``scipy.stats`` at the top of the file), not a number;
    confirm this is intended.

    Returns
    -------
    offset :
        The offset after adaptation on the first trial (unchanged if
        ``trials_nr == 0``).
    spike_adapted : list
        The input list, filled per trial.
    """
    for trial in range(trials_nr):
        power_here = 'sinz' + '_' + zeros if a_fr == 0 else 'sinz'
        (_cvs, _adapt_output, _baseline_after, _unused3, _rate_adapted, _rate_before, _rate_after,
         trial_spikes, _unused8, _unused9, adapted_offset, _unused11, _noise) = simulate(
            cell, offset, eod_fish_r, deltat=deltat,
            adaptation_variant=adapt_offset,
            adaptation_yes_j=f2, adaptation_yes_e=f1,
            adaptation_yes_t=trial,
            adaptation_upper_tol=upper_tol,
            adaptation_lower_tol=lower_tol,
            power_variant=power_here, power_alpha=alpha,
            power_nr=n, tau_change_choice=constant_reduction,
            tau_change_val=dent_tau_change, sigmoidal_mult=1,
            sigmoidal_plus=plus, sigmoidal_slope=slope,
            sigmoidal_add=add,
            sigmoidal_sigmoidal_val=sig_val,
            LIF_exponential=exponential,
            LIF_exponential_tau=exp_tau,
            LIF_expontential__v=v_exp, **model_params)
        spike_adapted[trial] = trial_spikes
        if trial == 0:
            # The first trial adapts the offset; all later trials start from it.
            offset = adapted_offset * 1
    return offset, spike_adapted
def plt_show_nonlin_effect_didactic(min=0.2, text='text', cells=[], add_pp=50,
                                    single_waves=['_SingleWave_', '_SeveralWave_', ],
                                    cell_start=13, zeros='zeros', a_f1s=[0, 0.005, 0.01, 0.05, 0.1, 0.2, ]
                                    , a_frs=[1], add_half=0, show=False, nfft=int(2 ** 15)):
    """Plot the (older) didactic nonlinearity figure for model cells.

    For every model cell, wave configuration, receiver amplitude ``a_fr`` and
    stimulus amplitude ``a_f1``:

    * The model offset is adapted on 150 baseline trials.
    * The stimulus frequency is chosen from the baseline rate ``fr``.
    * The stimulated model is simulated for 150 trials.

    The left column shows the stimulus with its AM, the AM alone, the spike
    raster and the firing rate (Gaussian sigma 0.5 ms). The right column
    shows the trial-averaged spike-train power spectrum with the beat-related
    frequencies marked. One figure is saved per ``a_f1``.

    Parameters
    ----------
    min : float
        Offset subtracted from all time axes (including spike times) before
        they are scaled by 1000 (ms). Shadows the builtin ``min`` here.
    text : str
        ``'text'`` annotates the marked peaks in the plot; any other value
        adds a legend instead.
    cells : list
        Cell names. If empty, all cells of ``models_big_fit_d_right.csv`` from
        row ``cell_start`` on are used.
    add_pp : float
        Vertical offset of the peak annotations.
    single_waves : list of str
        Wave configurations. ``'_SingleWave_'`` uses ``a_f2 = 0``; a
        ``'Several...Sum'`` name uses ``a_f2 = 0.1`` with f1/f2 at 0.3/0.7 fr.
    cell_start : int
        First row of the model table used when ``cells`` is empty.
    zeros : str
        Suffix of the power variant when ``a_fr == 0``.
    a_f1s, a_frs : lists of float
        Stimulus and receiver amplitudes.
    add_half : float
        Offset added to ``EODf - fr / 2`` for the single-wave frequency.
    show : bool
        Passed to ``save_visualization``.
    nfft : int
        Segment length for the power spectra.

    Raises
    ------
    ValueError
        For a ``single_wave`` containing ``'Several'`` but not ``'Sum'``; no
        frequency rule exists for it.
    """
    model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv")
    if len(cells) < 1:
        cells = model_cells.cell.loc[range(cell_start, len(model_cells))]
    for cell in cells:
        for single_wave in single_waves:
            if ('Several' in single_wave) and ('Sum' not in single_wave):
                raise ValueError("plt_show_nonlin_effect_didactic: no frequency rule for single_wave=%r "
                                 "(a 'Several' wave needs 'Sum')" % (single_wave,))
            if single_wave == '_SingleWave_':
                a_f2s = [0]  # , 0,0.2
            else:
                a_f2s = [0.1]
            for a_f2 in a_f2s:
                trials_nr = 150
                titles_amp = ['base eodf', 'baseline to Zero', ]
                for a, a_fr in enumerate(a_frs):
                    grid = gridspec.GridSpec(4, 2, wspace=0.2, left=0.05, top=0.8, bottom=0.15,
                                             right=0.98)
                    ax = {}
                    for aa, a_f1 in enumerate(a_f1s):
                        # ---- model set-up -------------------------------------------------
                        try:
                            model_params = model_cells[model_cells['cell'] == cell].iloc[0]
                        except (IndexError, KeyError):
                            # Debug hook: the cell is missing from the model table.
                            print('model extract something')
                            embed()
                        eod_fr = model_params['EODf']
                        offset = model_params.pop('v_offset')
                        cell = model_params.pop('cell')
                        print(cell)
                        f1 = 0
                        f2 = 0
                        sampling_factor = ''
                        stimulus_length = 1
                        phaseshift_fr = 0
                        phase_right = '_phaseright_'
                        adapt_offset = 'adaptoffsetallall2'
                        n = 1
                        lower_tol = 0.995
                        upper_tol = 1.005
                        SAM = ''
                        # A different sampling could be selected here via sampling_factor.
                        time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr,
                                                                     stimulus_length)
                        # Generate the receiver EOD (eod_fish_r).
                        eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr,
                                                                                       stimulus_length)
                        sampling = 1 / deltat
                        variant = 'sinz'
                        _, _ = prepare_baseline_array(time_array, eod_fr)
                        # Per-trial placeholders; entries are replaced by index, never mutated.
                        rate_adapted = [[]] * trials_nr
                        rate_baseline_before = [[]] * trials_nr
                        rate_baseline_after = [[]] * trials_nr
                        spikes = [[]] * trials_nr
                        v_dent_output = [[]] * trials_nr
                        stimulus_altered = [[]] * trials_nr
                        v_mem_output = [[]] * trials_nr
                        spike_adapted = [[]] * trials_nr
                        stimuli = []
                        # ---- baseline: adapt the offset on the first trial ---------------
                        for t in range(trials_nr):
                            if a_fr == 0:
                                power_here = 'sinz' + '_' + zeros
                            else:
                                power_here = 'sinz'
                            cvs, adapt_output, baseline_after_b, _, rate_adapted_b, rate_baseline_before_b, rate_baseline_after_b, \
                            spike_adapted[t], _, _, offset_new, _, noise_final = simulate(cell, offset, eod_fish_r,
                                                                                          deltat=deltat,
                                                                                          adaptation_variant=adapt_offset,
                                                                                          adaptation_yes_j=f2,
                                                                                          adaptation_yes_e=f1,
                                                                                          adaptation_yes_t=t,
                                                                                          adaptation_upper_tol=upper_tol,
                                                                                          adaptation_lower_tol=lower_tol,
                                                                                          power_variant=power_here,
                                                                                          power_alpha=alpha, power_nr=n,
                                                                                          **model_params)
                            if t == 0:
                                # Record the adapted offset and use it for all subsequent trials.
                                offset = offset_new * 1
                        base_cut, mat_base = find_base_fr(spike_adapted, deltat, stimulus_length, time_array)
                        fr = np.mean(base_cut)
                        titles = ['']
                        if 'Several' in single_wave:
                            # 'Sum' is guaranteed by the check at the top of the single_wave loop.
                            freqs1 = [eod_fr + fr * 0.3]
                            freqs2 = [eod_fr + fr * 0.7]
                        else:
                            freqs1 = [eod_fr - fr / 2 + add_half]
                            freqs2 = [0] * len(freqs1)
                        # ---- figure layout ------------------------------------------------
                        ax[0] = plt.subplot(grid[0])
                        ax[1] = plt.subplot(grid[2])
                        ax[2] = plt.subplot(grid[4])
                        ax[3] = plt.subplot(grid[6])
                        ax[4] = plt.subplot(grid[:, 1])
                        ax[0].set_xlim(0, 0.125 * 1000)  # 0.1 * 1000
                        ax[1].set_xlim(0, 0.125 * 1000)
                        ax[2].set_xlim(0, 0.125 * 1000)
                        ax[3].set_xlim(0, 0.125 * 1000)
                        _, _ = find_base_fr(spike_adapted, deltat, stimulus_length, time_array)
                        _, _ = ISI_frequency(time_array, spike_adapted[0], fill=0.0)
                        # Coefficient of variation of the first baseline trial's inter-spike intervals.
                        isi = np.diff(spike_adapted[0])
                        cv0 = np.std(isi) / np.mean(isi)
                        fs = 11
                        # ---- stimulation ----------------------------------------------------
                        for ff, freq1 in enumerate(freqs1):
                            freq1 = [freq1]
                            freq2 = [freqs2[ff]]
                            beat1 = freq1 - eod_fr
                            plt.suptitle('diverging from half fr by ' + str(add_half) + ' f1:' + str(
                                np.round(freq1)[0]) + ' f2:' + str(np.round(freq2)[0]) + ' Hz \n' + str(
                                beat1) + ' Hz Beat\n' + titles[ff] + titles_amp[a] + ' ' + cell + ' cv ' + str(
                                np.round(cv0, 3)) + '_a_f0_' + str(a_fr) + '_a_f1_' + str(a_f1) + '_a_f2_' + str(
                                a_f2) + ' tr_nr ' + str(trials_nr))
                            _, _ = get_phaseshifts(a_f1, a_f2, phase_right, phaseshift_fr)
                            eod_fish1, time_fish_e = eod_fish_e_generation(time_array, a_f1, freq1, f1)
                            eod_fish2, time_fish_j = eod_fish_e_generation(time_array, a_f2, freq2, f2)
                            eod_stimulus = eod_fish1 + eod_fish2
                            for t in range(trials_nr):
                                stimulus, eod_fish_sam = create_stimulus_SAM(SAM, eod_stimulus, eod_fish_r, freq1, f1,
                                                                             eod_fr,
                                                                             time_array, a_f1)
                                # NOTE(review): other callers also pass damping_type, eod_fr, gain, damping
                                # and us_name; confirm the remaining parameters have defaults.
                                std_dump, max_dump, range_dump, stimulus, damping_output = all_damping_variants(
                                    stimulus, time_array)
                                stimuli.append(stimulus)
                                cvs, adapt_output, baseline_after, _, rate_adapted[t], rate_baseline_before[t], \
                                rate_baseline_after[t], spikes[t], \
                                stimulus_altered[t], \
                                v_dent_output[t], offset_new, v_mem_output[t], noise_final = simulate(cell, offset,
                                                                                                      stimulus,
                                                                                                      deltat=deltat,
                                                                                                      adaptation_variant=adapt_offset,
                                                                                                      adaptation_yes_j=f2,
                                                                                                      adaptation_yes_e=f1,
                                                                                                      adaptation_yes_t=t,
                                                                                                      adaptation_upper_tol=upper_tol,
                                                                                                      adaptation_lower_tol=lower_tol,
                                                                                                      power_variant=variant,
                                                                                                      power_alpha=alpha,
                                                                                                      power_nr=n,
                                                                                                      **model_params)
                            # Binary spike matrices and per-trial power spectra; ``f`` is the common
                            # frequency axis (same Fs and NFFT for every trial).
                            spikes_mat = [[]] * len(spikes)
                            pps = [[]] * len(spikes)
                            for s in range(len(spikes)):
                                spikes_mat[s] = cr_spikes_mat(spikes[s], 1 / deltat, int(stimulus_length * 1 / deltat))
                                pps[s], f = ml.psd(spikes_mat[s] - np.mean(spikes_mat[s]), Fs=1 / deltat, NFFT=nfft,
                                                   noverlap=nfft // 2)
                            pp_mean = np.mean(pps, axis=0)
                            sampling_rate = 1 / deltat
                            smoothed05 = gaussian_filter(spikes_mat, sigma=0.0005 * sampling_rate)
                            mat05 = np.mean(smoothed05, axis=0)
                            ax[0].set_title('a_f1 ' + str(a_f1), fontsize=fs)
                            ax[0].plot((time_array - min) * 1000, stimulus, color='grey', linewidth=0.5)
                            # Extract the amplitude modulation unless the stimulus mean is exactly 0 or 1.
                            if (np.mean(stimulus) != 0) & (np.mean(stimulus) != 1):
                                eod_interp, eod_norm = extract_am(stimulus, time_array, sampling=sampling_rate,
                                                                  eodf=eod_fr,
                                                                  emb=False,
                                                                  extract='', norm=False)
                            # Skip the AM trace for the pure-baseline configuration.
                            if (titles_amp[a] != 'baseline to Zero') and not (
                                    (a_f2 == 0) & (a_fr == 1) & (a_f1 == 0)):
                                ax[1].plot((time_array - min) * 1000, eod_interp, color='red', clip_on=True)
                                ax[0].plot((time_array - min) * 1000, eod_interp, color='red', clip_on=True)
                            # Shift spike times by min and convert them to ms so the raster lines up
                            # with the other panels, whose time axis is (time_array - min) * 1000.
                            for l in range(len(spikes)):
                                spikes[l] = (spikes[l] - min) * 1000
                            ax[2].eventplot(spikes, color='black')
                            ax[3].plot((time_array - min) * 1000, mat05, color='black')
                            beat1 = (freq1 - eod_fr)[0]
                            beat2 = (freq2 - eod_fr)[0]
                            if 'Several' in single_wave:
                                freqs_beat = [np.abs(beat1), np.abs(beat2), np.abs(beat2 + beat1),
                                              np.abs(beat2 - beat1)]
                                colors = ['red', 'green', 'orange', 'blue']
                                labels = ['B1', 'B2', 'B1+B2', '|B1-B2|']
                            else:
                                freqs_beat = [np.abs(beat1) / 2, np.abs(beat1), np.abs(beat1) * 2, np.abs(beat1 * 3),
                                              np.abs(beat1 * 4)]
                                colors = ['grey', 'red', 'orange', 'blue', 'purple']
                                labels = ['', 'S1', 'S2 / B1', 'S3', 'S4 / B2']
                            # Mark the spectrum value at the bin closest to each beat-related frequency.
                            for f_nr, freq_beat in enumerate(freqs_beat):
                                f_pos = f[np.argmin(np.abs(f - np.abs(freq_beat)))]
                                pp_pos = pp_mean[np.argmin(np.abs(f - np.abs(freq_beat)))]
                                ax[4].scatter(f_pos, pp_pos, color=colors[f_nr], label=labels[f_nr])
                                if text == 'text':
                                    ax[4].text(f_pos - 15, pp_pos + add_pp, labels[f_nr], color=colors[f_nr],
                                               fontsize=15, rotation=65)
                            if text != 'text':
                                plt.legend()
                            ax[4].plot(f, pp_mean, color='black')
                            ax[4].set_xlim([0, 700])
                            test = False
                            if test:
                                from utils_test import test_spikes_clusters
                                test_spikes_clusters(eod_fish_r, spikes, mat05, sampling, s_name='ms', resamp_fact=1000)
                        ax[0].set_xticks([])
                        ax[1].set_xticks([])
                        ax[2].set_xticks([])
                        ax[0].set_ylabel('Amplitude')
                        ax[1].set_ylabel('Beat')
                        ax[2].set_ylabel('Spikes')
                        ax[3].set_ylabel('Fr [Hz]')
                        ax[4].set_ylabel('Amplitude [Hz]')
                        ax[4].set_xlabel('f [Hz]')
                        ax[3].set_xlabel('Time [ms]')
                        ax[0].set_xticks([])
                        ax[1].set_xticks([])
                        ax[2].set_xticks([])
                        plt.subplots_adjust(top=0.7, left=0.15, right=0.95, hspace=0.5, wspace=0.1)
                        individual_tag = titles_amp[a] + ' ' + cell + ' cv ' + str(cv0) + single_wave + '_a_f0_' + str(
                            a_fr) + '_a_f1_' + str(a_f1) + '_a_f2_' + str(a_f2) + '_diverge_from_base_half' + str(add_half)
                        save_visualization(individual_tag, show, counter_contrast=0, savename='')
def get_phaseshifts(a_f1, a_f2, phase_right, phaseshift_fr):
    """Return the phase shifts ``(phaseshift_f1, phaseshift_f2)`` of the two stimulus waves.

    * If ``phase_right`` is not ``'_phaseright_'``, both shifts equal
      ``phaseshift_fr``.
    * Otherwise both shifts are ``2*pi/4`` (pi/2) when both amplitudes
      ``a_f1`` and ``a_f2`` are non-zero, and 0 when either is zero.
    """
    if phase_right != '_phaseright_':
        return phaseshift_fr, phaseshift_fr
    both_waves_present = (a_f1 != 0) and (a_f2 != 0)
    shift = 2 * np.pi / 4 if both_waves_present else 0
    return shift, shift
def plt_serach_nonlinearity_cell2(color=['red', 'blue', 'orange', 'purple'], log='', show=True, cells=[], add_half=0):
    """Plot harmonic amplitudes against beat frequency from a precomputed pickle.

    One single-panel figure is drawn and saved per entry of ``cells``. It
    shows the columns ``c_0_all_dt`` .. ``c_3_all_dt`` of
    ``calc_model/<save_name>.pkl`` against ``|f1 - eod_fr|``, plus a grey
    vertical line at ``fr / 2``. The file name is
    ``<cell>_AddHalf_<add_half>_<log>``.

    NOTE(review): the plotted data is not filtered by ``cell``; every figure
    shows the same curves and only the file name differs.

    Parameters
    ----------
    color : list
        Line colours of the four curves.
    log : str
        Only used in the saved file name.
    show : bool
        Passed to ``save_visualization``; ``plt.show()`` is called at the end if true.
    cells : list
        Cell names (one figure each).
    add_half : float
        Only used in the saved file name.
    """
    stimulus_lengths = [100]  # , 10, 30, ] # [10, 10, 100, 10, 100]#, 10, 100, 10
    save_name = 'calc_nonlinearity_contrasts-_beat__AddToHalfFr_frange_from_10_to_400_in_0.3_afr_1_zeros_trNr_500_fft5__dev_original_len_100_adaptoffset_bisecting__transient_50s__until_0.03'
    # The pickle path does not depend on the cell, so it is read once (on first use).
    frame = None
    for _, _ in enumerate(stimulus_lengths):
        _, _ = find_row_col(cells)
        for t, cell in enumerate(cells):
            plot_style()
            default_figsize(column=2, length=1.95)
            grid = gridspec.GridSpec(1, 1, left=0.09, bottom=0.27, hspace=0.3, top=0.97, wspace=0.27,
                                     right=0.97)
            ffts = ['fft4']
            sampling = '_dt'
            for _fft in ffts:
                labels = labels_didactic2()
                names = ['_all']  # '_mean', '_one',
                axes = []
                for n, name in enumerate(names):
                    names_key = ['c_0' + name + sampling, 'c_1' + name + sampling, 'c_2' + name + sampling,
                                 'c_3' + name + sampling]
                    if frame is None:
                        frame = pd.read_pickle(load_folder_name('calc_model') + '/' + save_name + '.pkl')
                    ax = plt.subplot(grid[0])
                    axes.append(ax)
                    ax.axvline(frame.fr.unique() / 2, color='grey', linewidth=0.5)
                    for s, score_name in enumerate(names_key):
                        ax.plot(np.abs(frame.f1 - frame.eod_fr), frame[score_name], color=color[s],
                                label=labels[s].replace('\n', ','))
                    ax.set_xlim(0, 250)
                    ax.set_xlabel('Frequency [Hz]')
                    ax.set_ylabel('Power [Hz]')
                    ax.show_spines('lb')
                    try:
                        ax.legend(loc=(0.7, 0.5), prop={'size': 9})
                    except Exception:
                        # The legend is cosmetic; do not abort the figure if it cannot be drawn.
                        pass
                individual_tag = cell + '_AddHalf_' + str(add_half) + '_' + log
                fig = plt.gcf()
                fig.tag(axes, xoffs=-7.5)
                save_visualization(individual_tag, show, counter_contrast=0, savename='')
    if show:
        plt.show()
def plt_serach_nonlinearity_cell(color=['red', 'blue', 'orange', 'purple'], log='', show=True, cells=[], add_half=0):
    """Plot harmonic amplitudes vs. beat frequency (left) and vs. contrast (right).

    For every cell a two-panel figure is drawn:

    * Left: the columns ``c_0_all_dt`` .. ``c_3_all_dt`` of the beat-sweep
      pickle against ``|f1 - eod_fr|``, with a grey line at ``fr / 2``.
    * Right: only drawn if the contrast-sweep pickle exists. It shows the same
      columns of this cell's rows against the contrast ``a_f1 * 100`` (in %),
      skipping the first row. Both axes are logarithmic when ``log == 'log'``.

    The figure is saved as ``<cell>_AddHalf_<add_half>_<log>``.

    Parameters
    ----------
    color : list
        Line colours of the four curves.
    log : str
        ``'log'`` switches the right panel to log-log axes; also used in the file name.
    show : bool
        Passed to ``save_visualization``; ``plt.show()`` is called at the end if true.
    cells : list
        Cell names (one figure each).
    add_half : float
        Only used in the saved file name.
    """
    trials_nr = [500]  # , 500, 500, ] # [1, 100, 10, 150, 15]#, 300, 30, 500
    stimulus_lengths = [100]  # , 10, 30, ] # [10, 10, 100, 10, 100]#, 10, 100, 10
    for _, _ in enumerate(stimulus_lengths):
        _, _ = find_row_col(cells)
        for t, cell in enumerate(cells):
            plot_style()
            default_settings(column=2, length=3.5)
            grid = gridspec.GridSpec(1, 2, width_ratios=[1.7, 1], left=0.11, bottom=0.15, hspace=0.3, wspace=0.27,
                                     right=0.99)
            ffts = ['fft4']
            sampling = '_dt'
            for f, fft in enumerate(ffts):
                labels = labels_didactic()
                names = ['_all']  # '_mean', '_one',
                axes = []
                # Left panel: amplitudes as a function of the beat frequency.
                for n, name in enumerate(names):
                    names_key = ['c_0' + name + sampling, 'c_1' + name + sampling, 'c_2' + name + sampling,
                                 'c_3' + name + sampling]
                    save_name = 'calc_nonlinearity_contrasts-_beat__AddToHalfFr_frange_from_10_to_400_in_0.3_afr_1_zeros_trNr_500_fft5__dev_original_len_100_adaptoffset_bisecting__transient_50s__until_0.03'
                    frame = pd.read_pickle(load_folder_name(
                        'calc_model') + '/' + save_name + '.pkl')
                    ax = plt.subplot(grid[0])
                    axes.append(ax)
                    ax.axvline(frame.fr.unique() / 2, color='grey', linewidth=0.5)
                    for s, score_name in enumerate(names_key):
                        ax.plot(np.abs(frame.f1 - frame.eod_fr), frame[score_name], color=color[s], label=labels[s])
                    ax.set_xlim(0, 250)
                    ax.set_xlabel('Beat [Hz]')
                    ax.set_ylabel('Signal amplitude [Hz]')
                    ax.show_spines('lb')
                    try:
                        ax.legend(loc=(0.7, 0.8))
                    except Exception:
                        # The legend is cosmetic; do not abort the figure if it cannot be drawn.
                        pass
                # Right panel: amplitudes as a function of the stimulus contrast.
                save_name = load_folder_name(
                    'calc_model') + '/calc_nonlinearity_contrasts-_beat__AddToHalfFr_0_afr_1_zeros_trNr_500_fft5__dev_original_len_100_adaptoffset_bisecting__transient_50s__until_0.5.pkl'
                ax = plt.subplot(grid[1])
                axes.append(ax)
                # Everything below needs frame_cell, so it is only drawn when the pickle exists.
                if os.path.exists(save_name):
                    frame = pd.read_pickle(save_name)
                    frame_cell = frame[frame['cell'] == cell]
                    if fft == 'psd':
                        plt_nonlin(ax, frame_cell)
                        ax.set_title('trNr ' + str(
                            trials_nr[0]) + ' len ' + str(stimulus_lengths[0]) + ' ' + fft + ' FinalTr ' + str(
                            np.round(stimulus_lengths[0] * trials_nr[0] * np.mean(frame_cell.fr.unique()) / 2)))
                    else:
                        for n, name in enumerate(names):
                            title = False
                            if title:
                                plt.suptitle('trNr ' + str(
                                    trials_nr[0]) + ' len ' + str(
                                    stimulus_lengths[0]) + ' ' + fft + ' ' + name + ' Sampling ' + str(
                                    sampling) + ' FinalTr ' + str(
                                    np.round(
                                        stimulus_lengths[0] * trials_nr[0] * np.nanmean(frame_cell.fr.unique()) / 2)))
                            ax.set_title(cell + ' CV ' + str(np.mean(frame_cell.cv.unique())))
                            ax_axis = frame_cell['a_f1'] * 100
                            for s, score_name in enumerate(names_key):
                                ax.plot(ax_axis[1::], frame_cell[score_name][1::],
                                        color=color[s], label=labels[s])
                            ax.set_ylabel('Signal Amplitude [Hz]')
                            ax.set_aspect('equal')
                            ax.show_spines('lb')
                            ax.set_xlabel('contrast [%]')
                            if log == 'log':
                                ax.set_yscale('log')
                                ax.set_xscale('log')
                individual_tag = cell + '_AddHalf_' + str(add_half) + '_' + log
                ax = make_simple_tags(axes)
                save_visualization(individual_tag, show, counter_contrast=0, savename='')
    if show:
        plt.show()
def labels_didactic():
    """Return a fresh list of legend labels for the first four beat harmonics."""
    return ['Beat ', '2 Beat / Baseline Fr ', '3 Beat', '4 Beat / 2 Baseline Fr']
def labels_didactic2():
    """Return LaTeX legend labels for the stimulus harmonics f_Stim .. 4 f_Stim.

    The 2nd and 4th labels also name the baseline frequency, whose subscript
    comes from ``basename()``. It is called twice, once per label.
    """
    second = '$2f_{Stim}$, $f' + basename() + '$ '
    fourth = ' $4f_{Stim}$, $2 f' + basename() + '$'
    return [r' $f_{Stim}$ ', second, r' $3f_{Stim}$ ', fourth]
def make_simple_tags(axes, xpos=-0.03, ypos=1.02, letters=None):
    """Write a panel letter next to the top-left corner of each axis.

    Parameters
    ----------
    axes : iterable of matplotlib Axes
        Axes to label, in order.
    xpos, ypos : float
        Text position in axes coordinates (``ax.transAxes``).
    letters : sequence of str, optional
        One label per axis. Defaults to ``'A', 'B', 'C', ...``; for two axes
        this gives the same labels as the former default ``['A', 'B']``.

    Returns
    -------
    The last labelled axis, or ``None`` if *axes* is empty.
    """
    import string

    if letters is None:
        letters = string.ascii_uppercase
    fig = plt.gcf()
    ppi = 72.0  # points per inch
    # Scale the rc font size by the figure dpi relative to 72 points per inch.
    fs = plt.rcParams['font.size'] * fig.dpi / ppi
    ax = None
    for aa, ax in enumerate(axes):
        ax.text(xpos, ypos, letters[aa], transform=ax.transAxes, ha='right', va='bottom', fontsize=fs)
    return ax
def make_tags(axes=None, xoffs=-3, yoffs=1.2):
    """Tag axes with panel labels through the current figure's ``tag`` method.

    ``fig.tag`` is not stock Matplotlib; it presumably comes from the
    ``plotstyle`` setup imported at module level — verify.

    Parameters
    ----------
    axes : sequence of Axes, optional
        Axes to tag. When ``None`` or empty, the current axis (``plt.gca()``)
        is tagged instead.
    xoffs, yoffs : float
        Offsets forwarded unchanged to ``fig.tag``.

    Returns
    -------
    The axes that were passed to ``fig.tag``.
    """
    fig = plt.gcf()
    # Use None instead of a shared mutable [] default; an explicit [] still works.
    if axes is None or len(axes) < 1:
        axes = plt.gca()
    fig.tag(axes, xoffs=xoffs, yoffs=yoffs)
    return axes
def plt_nonlin(ax, frame_cell, first='c_0_all_dt', second='c_1_all_dt', third='c_2_all_dt',
               forth='c_3_all_dt'):  # first = 'a_fundamental_original', second = 'a_h1_original', third = 'a_h2_original', forth = 'a_h3_original'
    """Plot four response columns of *frame_cell* against contrast in percent.

    The x values are ``frame_cell['a_f1'] * 100``. The four columns named by
    *first* ... *forth* are drawn in blue, orange, green and red with the
    legend labels ``'S1'``, ``'S2 / B1 '``, ``'S3'`` and ``'S4 / B2'``.
    """
    contrast_percent = frame_cell['a_f1'] * 100
    traces = (
        (first, 'blue', 'S1'),
        (second, 'orange', 'S2 / B1 '),  # Baseline f [B1] / Stimulus [S2]
        (third, 'green', 'S3'),
        (forth, 'red', 'S4 / B2'),
    )
    for column, colour, legend_label in traces:
        ax.plot(contrast_percent, frame_cell[column], color=colour, label=legend_label)
def save_name_nonlinearity(add_half, a_f1_end=0.2, transient_s=0, adapt_offset='', n=1, stimulus_length=2, freq_type='',
                           adapt='', a_f2s=(0,), freqs2=(0,), dev='original', fft='fft', a_fr=1, trials_nr=150,
                           zeros='zeros'):
    """Build the ``.pkl`` file path for results of ``calc_nonlinearity_contrasts``.

    Every argument contributes a tagged fragment to the file name. Some
    fragments are omitted when the argument has its default value
    (``a_f1_end == 0.2``, ``transient_s == 0``, ``n == 1``, ``adapt_offset == ''``).
    ``freq_type`` is ignored whenever ``'psd'`` occurs in *fft*.

    The trailing frequency fragment depends on *freqs2* and *a_f2s*:

    * several ``freqs2`` and one ``a_f2s``:
      ``frange_from_<first>_to_<last>_in_<step>_af2_<a_f2s[0]>``
    * several of both: the frequency range followed by
      ``af2range_from_<first>_to_<last>_in_<step>``
    * otherwise: ``freq2_<freqs2[0]>_af2_<a_f2s[0]>``
    * if either sequence is empty, the fragment is left out.

    The step is the first element of ``np.diff`` of the sequence.
    """
    dev_name = '_dev_' + str(dev)
    version_name = '_' + fft + '_'
    end_name = '' if a_f1_end == 0.2 else '_until_' + str(a_f1_end)
    trials_nr_name = '_trNr_' + str(trials_nr)
    transient_s_name = '_transient_' + str(transient_s) + 's_' if transient_s != 0 else ''
    n_name = '_power' + str(n) if n != 1 else ''
    a_fr_name = '_afr_' + str(a_fr) + '_' + zeros
    if 'psd' in fft:
        freq_type = ''
    add_half_name = '_AddToHalfFr_' + str(add_half)
    adapt_offset_name = '_' + adapt_offset + '_' if adapt_offset != '' else ''

    # The matching compute function is calc_nonlinearity_contrasts, NOT calc_nonlinearity_contrasts_fft.
    # Default to no fragment so empty freqs2/a_f2s no longer leave freq_afname unbound.
    freq_afname = ''
    if len(freqs2) != 0 and len(a_f2s) != 0:
        if len(freqs2) > 1 and len(a_f2s) == 1:
            freq_afname = 'frange_from_' + str(freqs2[0]) + '_to_' + str(freqs2[-1]) + '_in_' + str(
                np.diff(freqs2)[0]) + '_af2_' + str(a_f2s[0])
        elif len(freqs2) > 1 and len(a_f2s) > 1:
            freq_afname = 'frange_from_' + str(freqs2[0]) + '_to_' + str(freqs2[-1]) + '_in_' + str(
                np.diff(freqs2)[0]) + 'af2range_from_' + str(a_f2s[0]) + '_to_' + str(a_f2s[-1]) + '_in_' + str(
                np.diff(a_f2s)[0])
        else:
            freq_afname = 'freq2_' + str(freqs2[0]) + '_af2_' + str(a_f2s[0])

    save_name = load_folder_name(
        'calc_model') + '/' + calc_nonlinearity_contrasts.__name__ + '-' + freq_type + add_half_name + a_fr_name + trials_nr_name + version_name + dev_name + '_len_' + str(
        stimulus_length) + adapt + adapt_offset_name + n_name + transient_s_name + end_name + freq_afname + '.pkl'
    return save_name
def plt_single_phaselockloss(colors, frame_cell, df, scores, cell, ax, df_name='df'):
    """Plot score-vs-contrast summaries for one cell at one difference frequency.

    Only rows of *frame_cell* whose ``df_name`` column equals *df*, or is NaN,
    are used. For each score column and each ``mt_type``:

    * the per-contrast mean is drawn as scatter markers (``'*'`` if ``'base'``
      is in the mt_type, ``'^'`` for chirp, ``'.'`` for SAM, ``'o'`` otherwise);
    * the type ``'base'`` is summarised as a single point at contrast 0.

    Across all mt_types of a score, the means are joined by a line. The
    25-75 % quantile band is drawn dark (alpha 0.6) and the min-max band light
    (alpha 0.2). The quantiles are computed from the raw rows of each contrast.

    Parameters
    ----------
    colors : sequence
        One matplotlib colour per entry of *scores*.
    frame_cell : pandas.DataFrame
        Needs the columns ``df_name``, ``'mt_type'``, ``'contrast'`` and every
        name in *scores*. ``'contrast'`` values must be convertible to float.
    df : float
        Difference frequency to select; also shown in the title.
    scores : sequence of str
        Columns to plot.
    cell : str
        Cell name; its first 14 characters go into the title.
    ax : matplotlib Axes
        Target axis.
    df_name : str
        Column that holds the difference frequency.
    """
    frame_df = frame_cell[(frame_cell[df_name] == df) | (np.isnan(frame_cell[df_name]))]
    mt_types = frame_cell.mt_type.unique()
    if len(mt_types) < 1:
        print('plt_single_phaselockloss: no mt_type entries for ' + str(cell))
    for s, score in enumerate(scores):
        ax.set_title(cell[0:14] + ' DF=' + str(df), fontsize=8)
        score_vals = []
        score_vals05 = []
        score_vals25 = []
        score_vals75 = []
        score_vals95 = []
        contrasts_here = []
        for mt_type in mt_types:
            if 'base' in mt_type:
                marker = '*'
            elif 'chirp' in mt_type:
                marker = '^'
            elif 'SAM' in mt_type:
                marker = '.'
            else:
                marker = 'o'
            frame_type = frame_df[frame_df.mt_type == mt_type]
            if len(frame_type) == 0:
                # This mt_type only occurs at other DFs; nothing to summarise
                # (np.percentile would fail on an empty selection).
                continue
            if mt_type != 'base':
                # Aggregate only the score column: the quantiles must come from the
                # raw rows, not from the per-contrast means.
                grouped = frame_type.groupby('contrast')[score]
                score_mean = grouped.mean()
                contrasts = np.array(list(map(float, score_mean.index)))
                order = np.argsort(contrasts)
                score_val = score_mean.iloc[order]
                score_val75 = grouped.quantile(0.75).iloc[order]
                score_val25 = grouped.quantile(0.25).iloc[order]
                score_val95 = grouped.quantile(1).iloc[order]  # maximum
                score_val05 = grouped.quantile(0).iloc[order]  # minimum
                contrast_here = contrasts[order]
                nr = 1
            else:
                values = frame_type[score]
                score_val = [np.mean(values)]
                score_val75 = [np.percentile(values, 75)]
                score_val25 = [np.percentile(values, 25)]
                score_val95 = [np.percentile(values, 100)]
                score_val05 = [np.percentile(values, 0)]
                contrast_here = [0]  # the baseline is shown at contrast 0
                nr = 2
            try:
                ax.scatter(contrast_here, score_val, marker=marker, color=colors[s], zorder=100 * nr, alpha=0.5,
                           s=8.5)
            except Exception as err:
                # Best effort: report and continue with the summary lines.
                print('axis problem')
                print(err)
            score_vals.extend(np.array(score_val))
            score_vals05.extend(np.array(score_val05))
            score_vals95.extend(np.array(score_val95))
            score_vals75.extend(np.array(score_val75))
            score_vals25.extend(np.array(score_val25))
            contrasts_here.extend(np.array(contrast_here))
        order_all = np.argsort(contrasts_here)
        x_sorted = np.array(contrasts_here)[order_all]
        ax.fill_between(x_sorted, np.array(score_vals05)[order_all],
                        np.array(score_vals95)[order_all], color=colors[s], alpha=0.2)
        ax.fill_between(x_sorted, np.array(score_vals25)[order_all],
                        np.array(score_vals75)[order_all], color=colors[s], alpha=0.6)
        ax.plot(x_sorted, np.array(score_vals)[order_all], color=colors[s], label=score)
def plt_beats_modulation_several_with_overview_nice_from_three_final(only_first=True, limit=1,
                                                                      duration_exclude=0.45, nfft=int(4096), show=False):
    """Plot example stimulus/voltage traces of one recorded cell, one figure per ``c2``.

    Loads the hard-coded dataset ``'2023-05-12-ar-invivo-1'`` from its ``.nix``
    file and the precomputed peak table from the ``calc_cocktailparty`` folder.
    For the first DF1 and first DF2 of that cell, it makes one figure per
    value of ``c2``:

    * right column: score-vs-``c1`` lines via ``show_lines_several_plots``;
    * left grid: for the ``c1`` values picked by ``choice_specific_indices``,
      the stimulus and voltage of trial 0 for key ``'control_01'``.

    Each figure is saved with ``save_visualization``.

    Parameters
    ----------
    only_first : bool
        Unused: both branches of ``if only_first`` are ``pass``, and only the
        first DF1/DF2 are plotted in any case.
    limit, duration_exclude :
        Only used to build ``save_name``, which is printed and otherwise unused.
    nfft : int
        Forwarded to ``load_spikes_eods_phaselock``.
    show : bool
        Forwarded to ``save_visualization``.

    NOTE(review): the function ends with ``embed()``, which opens an interactive
    IPython shell, and the opened nix file is never closed.
    """
    # Function to load the experimental data
    save_name = 'beat_results_smoothed_limit' + str(limit) + '_minimalduration_' + str(duration_exclude) + '_all'
    print(save_name)
    datas_new = []
    old_cells = False
    if old_cells:
        # This branch is for inspecting the old datasets (disabled: old_cells is hard-coded to False).
        _, _ = find_all_dir_cells()
        frame_desired = pd.read_csv('../code/calc_base/find_contrasts_SAMs-SAM_amplitudes.csv')
        frame_big = frame_desired[(frame_desired.contrast > 29) | (frame_desired.contrast_true > 29)]
    else:
        pass
    # path_new2 = load_folder_name('calc_cocktailparty') + '/calc_data_peaks_threewave-spikes_all_psdEOD_1_nfft_16384[05,original]_psdEOD__sqrt__points_5_ALL_.pkl'
    # But we want the new datasets: the result of find_cells_for_phaselocking is
    # discarded and a single hard-coded cell is used instead.
    _, _, _ = find_cells_for_phaselocking()
    datasets = ['2023-05-12-ar-invivo-1']
    frame_all = pd.read_pickle(load_folder_name(
        'calc_cocktailparty') + '/calc_data_peaks_threewave-spikes_all_psdEOD_1_nfft_16384[05,original]_psdEOD__sqrt__points_5_ALL_.pkl')  # pd.read_pickle('../code/calc_phaselocking/calc_phaselocking-phaselocking5_big.pkl')
    plot_style()
    default_settings(column=2, length=6)
    for i, cell in enumerate(datasets):
        # NOTE(review): no '/' between load_folder_name('data') and 'cells/'; this
        # presumably relies on load_folder_name returning a trailing separator — confirm.
        path = load_folder_name('data') + 'cells/' + cell + '/' + cell + ".nix"  # cell.split(os.path.sep)[-1]
        print(cell)
        if os.path.exists(path):
            # NOTE(review): the nix file is opened here but never closed.
            file = nix.File.open(path, nix.FileMode.ReadOnly)
            b = file.blocks[0]
            cont2 = False
            names = []
            names_dataarrays = []
            for stims in b.data_arrays:
                # This seems reasonable because the only data with a higher sinewave number were not useful
                # (this only concerns 2019-05-07-ab and 2019-05-07-ac; on closer inspection these seem to be
                # odd trials that we do not really need).
                if 'sinewave-1_Contrast' in stims.name:
                    names.append(stims.name)
                    names_dataarrays.append(stims.name)
            'sinewave''SAM'  # no-op string expression (leftover note)
            sam = find_mt(b, 'SAM')
            sine = find_mt(b, 'sine')
            # Only plot when find_mt finds 'SAM' or 'sine' entries in this block.
            if (len(sine) > 0) or (len(sam) > 0):
                cont2 = True
            test = False
            if test:
                from utils_test import test_in_plot_phaselocking
                test_in_plot_phaselocking(b, path)
            if cont2 == True:
                counter = 0
                DF1s, DF2s, frame_data_cell, c2_unique, c2_unique_big, c1_unique_big, c1_unique = get_dfs_and_contrasts_from_calccocktailparty(
                    cell, frame_all)
                if only_first:
                    pass
                else:
                    pass
                # Only the first DF1 and the first DF2 are plotted, regardless of only_first.
                DF1s_here = [DF1s[0]]
                DF2s_here = [DF2s[0]]
                for d1, DF1 in enumerate(DF1s_here):
                    for d2, DF2 in enumerate(DF2s_here):
                        # This is clumsy; one should use the saved Ms instead.
                        # Rows whose m1/m2 lie within 0.01 of DF1/DF2.
                        frame_df0 = frame_data_cell[(np.abs(frame_data_cell.m1 - DF1) < 0.01) & (
                                np.abs(frame_data_cell.m2 - DF2) < 0.01)]
                        contrasts_2_chosen = np.sort(np.unique(frame_df0.c2))  # c1_unique_big[0:5]
                        for c_nr2, c2 in enumerate(contrasts_2_chosen):
                            # NOTE(review): suptitle is called before plt.figure(), so it goes to the
                            # previously current figure rather than the one created next.
                            plt.suptitle(cell + ' DF1=' + str(DF1) + ' DF2=' + str(DF2))
                            plt.figure()
                            # Left cell (width 4): example traces; right cell (width 1): summary lines.
                            gs0 = gridspec.GridSpec(1, 2, width_ratios=[4, 1], hspace=0.4, left=0.045,
                                                    right=0.97)  #
                            frame_df = frame_df0[(frame_df0.c2 == c2)]
                            # Comparison plot
                            grid1 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.2, hspace=0.2,
                                                                     subplot_spec=gs0[1])
                            if len(frame_df) > 0:
                                frame_df_mean = frame_df  # .groupby(['c1']).mean()#, 'c2'])
                                cs = {}
                                means = {}
                                scores_datas = [['amp_f0_01_original', 'amp_f0_012_original',
                                                 'amp_B1_01_original', 'amp_B1_012_original'], ['amp_f0_02_original',
                                                                                                'amp_f0_0_original',
                                                                                                'amp_B2_02_original',
                                                                                                'amp_B2_012_original']]
                                colorss = [['green', 'purple', 'green', 'blue'], ['orange', 'black', 'orange', 'red']]
                                linestyless = [['--', '--', '-', '-'], ['--', '--', '-', '-']]
                                show_lines_several_plots(colorss, cs, frame_df_mean, grid1, linestyless, means,
                                                         scores_datas)
                                find_mt_all(b)  # return value is not used here
                                contrasts = np.unique(frame_df.c1)
                                if len(contrasts) > 0:
                                    contrasts_1_chosen, indeces_show = choice_specific_indices(contrasts, negativ='positiv',
                                                                                               units=5, cut_val=1)
                                    nr_col = int(len(contrasts_1_chosen))
                                    # Rows: stimulus, voltage, spacer, two PSD rows; one column per chosen c1.
                                    grid2 = gridspec.GridSpecFromSubplotSpec(5, nr_col, height_ratios=[1, 1, 0.5, 1, 1],
                                                                             wspace=0.2, hspace=0.2,
                                                                             subplot_spec=gs0[0])
                                    axts = []
                                    axps = []
                                    for c_nr, c1 in enumerate(contrasts_1_chosen):
                                        frame_c1 = frame_df[(frame_df.c1 == c1)]  # | (frame_df.mt_type == 'base')
                                        print(c_nr)
                                        mt_types = frame_c1.mt_type.unique()
                                        for mt_type in mt_types:
                                            frame_type = frame_c1[
                                                (frame_c1.mt_type == mt_type)]  # | (frame_df.mt_type == 'base')
                                            V_1, eods_all, f, mts, sampling_rate, spike_times, spikes_mat, spikes_mats = load_spikes_eods_phaselock(
                                                b, cell, datas_new, frame_type, names, nfft)
                                            key = 'control_01'
                                            plt.suptitle('data ' + cell + ' ' + mts.name + ' ' + key)
                                            xlim = [200, 250]
                                            nr_example = 0
                                            ###########################################
                                            # time spikes
                                            axt = plt.subplot(grid2[1, c_nr])
                                            time = plt_voltage_phaselock(V_1, axt, axts, counter, nr_example, sampling_rate,
                                                                         spike_times, xlim, key=key)
                                            ##########################################
                                            axt = plt.subplot(grid2[0, c_nr])
                                            plt_stimulus_phaselock(axt, axts, counter, eods_all, frame_type, nr_example,
                                                                   sampling_rate, spike_times, time, xlim, key=key)
                                            ##########################################
                                            # time psd
                                            # NOTE(review): these PSD axes are created but no data is plotted on them here.
                                            axp = plt.subplot(grid2[3, c_nr])
                                            axp2 = plt.subplot(grid2[4, c_nr])
                                            axps.append(axp)
                                            axps.append(axp2)
                                    # Share y between every second axis of axts; axts is presumably filled
                                    # alternately by plt_voltage_phaselock / plt_stimulus_phaselock — confirm.
                                    # NOTE(review): get_shared_y_axes().join is deprecated in newer Matplotlib.
                                    axts[0].get_shared_y_axes().join(*axts[0::2])
                                    axts[1].get_shared_y_axes().join(*axts[1::2])
                                    join_x(axts)
                                    join_y(axps)
                                    set_same_ylim(axps)
                                    join_x(axps)
                                    individual_tag = 'data' + cell + '_DF1_' + str(DF1) + '_DF2_' + str(DF2) + '_c2_' + str(
                                        c2)
                                    save_visualization(individual_tag, show)
    print('finished examples')
    # NOTE(review): drops into an interactive IPython shell; blocks non-interactive runs.
    embed()
def plt_beats_modulation_several_with_overview_nice_from_three(only_first=True, limit=1,
                                                                duration_exclude=0.45, nfft=int(4096), show=False):
    """Plot example stimulus/voltage traces for every phase-locking cell, one figure per (DF1, DF2, c2).

    Cells come from ``find_cells_for_phaselocking()`` and are read from
    ``'../data/cells/<cell>/<cell>.nix'``. The precomputed peak table is loaded
    from the ``calc_cocktailparty`` folder. For every DF1 (only the first when
    *only_first*), every DF2 and every ``c2`` value, one figure is made:

    * right column: score-vs-``c1`` lines via ``show_lines_several_plots``;
    * left grid: for the ``c1`` values picked by ``choice_specific_indices``,
      the stimulus and voltage of trial 0 for key ``'control_01'``.

    Each figure is saved with ``save_visualization``.

    Parameters
    ----------
    only_first : bool
        Restrict the loop to the first DF1.
    limit, duration_exclude :
        Only used to build ``save_name``, which is printed and otherwise unused.
    nfft : int
        Forwarded to ``load_spikes_eods_phaselock``.
    show : bool
        Forwarded to ``save_visualization``.

    NOTE(review): the function ends with ``embed()``, which opens an interactive
    IPython shell, and the opened nix files are never closed.
    """
    # Function to load the experimental data
    save_name = 'beat_results_smoothed_limit' + str(limit) + '_minimalduration_' + str(duration_exclude) + '_all'
    print(save_name)
    datas_new = []
    old_cells = False
    if old_cells:
        # This branch is for inspecting the old datasets (disabled: old_cells is hard-coded to False).
        _, _ = find_all_dir_cells()
        frame_desired = pd.read_csv('../code/calc_base/find_contrasts_SAMs-SAM_amplitudes.csv')
        frame_big = frame_desired[(frame_desired.contrast > 29) | (frame_desired.contrast_true > 29)]
    else:
        pass
    # But we want the new datasets
    datasets, loss, gain = find_cells_for_phaselocking()
    frame_all = pd.read_pickle(load_folder_name(
        'calc_cocktailparty') + '/calc_data_peaks_threewave-spikes_all_psdEOD_1_nfft_16384[05,original]_psdEOD__sqrt__points_5_ALL_.pkl')  # pd.read_pickle('../code/calc_phaselocking/calc_phaselocking-phaselocking5_big.pkl')
    # cells_for_phaselocking, loss, gain = find_cells_for_phaselocking()
    plot_style()
    default_settings()
    for i, cell in enumerate(datasets):
        path = '../data/cells/' + cell + '/' + cell + ".nix"  # cell.split(os.path.sep)[-1]
        print(cell)
        if os.path.exists(path):
            # NOTE(review): the nix file is opened here but never closed.
            file = nix.File.open(path, nix.FileMode.ReadOnly)
            b = file.blocks[0]
            cont2 = False
            names = []
            names_dataarrays = []
            for stims in b.data_arrays:
                # This seems reasonable because the only data with a higher sinewave number were not useful
                # (this only concerns 2019-05-07-ab and 2019-05-07-ac; on closer inspection these seem to be
                # odd trials that we do not really need).
                if 'sinewave-1_Contrast' in stims.name:
                    names.append(stims.name)
                    names_dataarrays.append(stims.name)
            'sinewave''SAM'  # no-op string expression (leftover note)
            sam = find_mt(b, 'SAM')
            sine = find_mt(b, 'sine')
            # Only plot when find_mt finds 'SAM' or 'sine' entries in this block.
            if (len(sine) > 0) or (len(sam) > 0):
                cont2 = True
            test = False
            if test:
                from utils_test import test_in_plot_phaselocking
                test_in_plot_phaselocking(b, path)
            if cont2 == True:
                counter = 0
                DF1s, DF2s, frame_data_cell, c2_unique, c2_unique_big, c1_unique_big, c1_unique = get_dfs_and_contrasts_from_calccocktailparty(
                    cell, frame_all)
                if only_first:
                    DF1s_here = [DF1s[0]]
                else:
                    DF1s_here = DF1s  # [0]]
                for d1, DF1 in enumerate(DF1s_here):
                    for d2, DF2 in enumerate(DF2s):
                        # This is clumsy; one should use the saved Ms instead.
                        # NOTE(review): exact equality after rounding m1/m2 to 2 decimals; this
                        # presumably assumes DF1/DF2 are themselves rounded to 2 decimals — confirm.
                        frame_df0 = frame_data_cell[
                            (np.round(frame_data_cell.m1, 2) == DF1) & (np.round(frame_data_cell.m2, 2) == DF2)]
                        contrasts_2_chosen = np.sort(np.unique(frame_df0.c2))  # c1_unique_big[0:5]
                        for c_nr2, c2 in enumerate(contrasts_2_chosen):
                            # NOTE(review): suptitle is called before plt.figure(), so it goes to the
                            # previously current figure rather than the one created next.
                            plt.suptitle(cell + ' DF1=' + str(DF1) + ' DF2=' + str(DF2))
                            plt.figure(figsize=(15, 9))
                            # Left cell (width 4): example traces; right cell (width 1): summary lines.
                            gs0 = gridspec.GridSpec(1, 2, width_ratios=[4, 1], hspace=0.4, left=0.045,
                                                    right=0.97)  #
                            frame_df = frame_df0[(frame_df0.c2 == c2)]
                            # Comparison plot
                            grid1 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.2, hspace=0.2,
                                                                     subplot_spec=gs0[1])
                            if len(frame_df) > 0:
                                # frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv')
                                frame_df_mean = frame_df  # .groupby(['c1']).mean()#, 'c2'])
                                cs = {}
                                means = {}
                                scores_datas = [['amp_f0_01_original', 'amp_f0_012_original',
                                                 'amp_B1_01_original', 'amp_B1_012_original'], ['amp_f0_02_original',
                                                                                                'amp_f0_0_original',
                                                                                                'amp_B2_02_original',
                                                                                                'amp_B2_012_original']]
                                colorss = [['green', 'purple', 'green', 'blue'], ['orange', 'black', 'orange', 'red']]
                                linestyless = [['--', '--', '-', '-'], ['--', '--', '-', '-']]
                                show_lines_several_plots(colorss, cs, frame_df_mean, grid1, linestyless, means,
                                                         scores_datas)
                                find_mt_all(b)  # return value is not used here
                                contrasts = np.unique(frame_df.c1)
                                if len(contrasts) > 0:
                                    contrasts_1_chosen, indeces_show = choice_specific_indices(contrasts, negativ='positiv',
                                                                                               units=5, cut_val=1)
                                    nr_col = int(len(contrasts_1_chosen))
                                    # Rows: stimulus, voltage, spacer, two PSD rows; one column per chosen c1.
                                    grid2 = gridspec.GridSpecFromSubplotSpec(5, nr_col, height_ratios=[1, 1, 0.5, 1, 1],
                                                                             wspace=0.2, hspace=0.2,
                                                                             subplot_spec=gs0[0])
                                    axts = []
                                    axps = []
                                    for c_nr, c1 in enumerate(contrasts_1_chosen):
                                        frame_c1 = frame_df[(frame_df.c1 == c1)]  # | (frame_df.mt_type == 'base')
                                        print(c_nr)
                                        mt_types = frame_c1.mt_type.unique()
                                        for mt_type in mt_types:
                                            frame_type = frame_c1[
                                                (frame_c1.mt_type == mt_type)]  # | (frame_df.mt_type == 'base')
                                            V_1, eods_all, f, mts, sampling_rate, spike_times, spikes_mat, spikes_mats = load_spikes_eods_phaselock(
                                                b, cell, datas_new, frame_type, names, nfft)
                                            key = 'control_01'
                                            plt.suptitle('data ' + cell + ' ' + mts.name + ' ' + key)
                                            xlim = [200, 250]
                                            nr_example = 0
                                            ###########################################
                                            # time spikes
                                            axt = plt.subplot(grid2[1, c_nr])
                                            time = plt_voltage_phaselock(V_1, axt, axts, counter, nr_example, sampling_rate,
                                                                         spike_times, xlim, key=key)
                                            ##########################################
                                            axt = plt.subplot(grid2[0, c_nr])
                                            plt_stimulus_phaselock(axt, axts, counter, eods_all, frame_type, nr_example,
                                                                   sampling_rate, spike_times, time, xlim, key=key)
                                            # NOTE(review): these PSD axes are created but no data is plotted on them here.
                                            axp = plt.subplot(grid2[3, c_nr])
                                            axp2 = plt.subplot(grid2[4, c_nr])
                                            axps.append(axp)
                                            axps.append(axp2)
                                    # Share y between every second axis of axts; axts is presumably filled
                                    # alternately by plt_voltage_phaselock / plt_stimulus_phaselock — confirm.
                                    # NOTE(review): get_shared_y_axes().join is deprecated in newer Matplotlib.
                                    axts[0].get_shared_y_axes().join(*axts[0::2])
                                    axts[1].get_shared_y_axes().join(*axts[1::2])
                                    join_x(axts)
                                    join_y(axps)
                                    set_same_ylim(axps)
                                    join_x(axps)
                                    individual_tag = 'data' + cell + '_DF1_' + str(DF1) + '_DF2_' + str(DF2) + '_c2_' + str(
                                        c2)
                                    save_visualization(individual_tag, show)
    print('finished examples')
    # NOTE(review): drops into an interactive IPython shell; blocks non-interactive runs.
    embed()
def show_lines_several_plots(colorss, cs, frame_df_mean, grid1, linestyless, means, scores_datas):
    """Plot groups of score columns against contrast ``c1``, one subplot per group.

    Parameters
    ----------
    colorss, linestyless : sequence of sequences
        ``colorss[i][j]`` and ``linestyless[i][j]`` style the line for score
        ``scores_datas[i][j]``.
    cs, means : dict
        Mutated in place: for every score position ``j`` not yet present in
        *means*, ``means[j]`` and ``cs[j]`` are set to empty lists.
    frame_df_mean : pandas.DataFrame
        Must contain ``'c1'`` and every score column.
    grid1 : indexable subplot specs
        ``grid1[i]`` hosts the subplot for group ``i``.
    scores_datas : sequence of sequences of str
        Groups of score column names.

    Returns
    -------
    The last subplot created, or ``None`` if *scores_datas* is empty.
    """
    ax = None
    for s_nr, scores_data in enumerate(scores_datas):
        linestyles = linestyless[s_nr]
        colors = colorss[s_nr]
        ax = plt.subplot(grid1[s_nr])
        # Sort once per subplot: x is c1 ascending and each score is reordered to match.
        c1_values = frame_df_mean['c1']
        order = np.argsort(c1_values)
        c1_sorted = np.sort(c1_values)
        for sss, score in enumerate(scores_data):
            ax.plot(c1_sorted, frame_df_mean[score].iloc[order],
                    color=colors[sss], linestyle=linestyles[sss])
            if sss not in means:
                means[sss] = []
                cs[sss] = []
            # The y label ends up naming the last score of the group.
            ax.set_ylabel(score.replace('_mean', '').replace('amp_', '') + '[Hz]',
                          fontsize=8)
        ax.set_xlabel('Contrast small')
    return ax
def color_three(name):
    """Return the plot colour for a stimulus-condition key.

    Each single-stimulus key and its ``control_`` / ``base_`` variant share a
    colour. Unknown keys raise ``KeyError``.
    """
    key_colour_pairs = (
        ('0', 'grey'), ('base_0', 'grey'),
        ('01', 'green'), ('control_01', 'green'),
        ('02', 'blue'), ('control_02', 'blue'),
        ('012', 'purple'),
    )
    return dict(key_colour_pairs)[name]
def plt_beats_modulation_several_with_overview_nice_from_three_contorol_compar_final(only_first=True, limit=1,
                                                                                     duration_exclude=0.45,
                                                                                     nfft=int(4096), show=False):
    """Compare the four stimulus conditions of one recorded cell, one figure per trial.

    Loads the hard-coded dataset ``'2023-05-12-ar-invivo-1'`` and the
    precomputed peak table from the ``calc_cocktailparty`` folder. It loops
    over every DF1 (only the first when *only_first*), every DF2, every
    ``c2``, every ``c1`` (in descending order), every ``mt_type`` and every
    example index. For each combination it makes one figure:

    * right column: summary lines via ``plt_lines_phaselockingloss``;
    * left grid: one column per key ``'base_0'``, ``'control_01'``,
      ``'control_02'``, ``'012'`` with stimulus, voltage and two PSD panels.

    Each figure is saved with ``save_visualization``.

    Parameters
    ----------
    only_first : bool
        Restrict the loop to the first DF1.
    limit, duration_exclude :
        Only used to build ``save_name``, which is printed and otherwise unused.
    nfft : int
        Forwarded to ``load_spikes_eods_phaselock``.
    show : bool
        Forwarded to ``save_visualization``.

    NOTE(review): the function ends with ``embed()``, which opens an interactive
    IPython shell, and the opened nix file is never closed.
    """
    # Function to load the experimental data
    save_name = 'beat_results_smoothed_limit' + str(limit) + '_minimalduration_' + str(duration_exclude) + '_all'
    print(save_name)
    datas_new = []
    old_cells = False
    if old_cells:
        # This branch is for inspecting the old datasets (disabled: old_cells is hard-coded to False).
        _, _ = find_all_dir_cells()
        frame_desired = pd.read_csv('../code/calc_base/find_contrasts_SAMs-SAM_amplitudes.csv')
        frame_big = frame_desired[(frame_desired.contrast > 29) | (frame_desired.contrast_true > 29)]
    else:
        pass
    # The result of find_cells_for_phaselocking is discarded; a single hard-coded cell is used.
    _, _, _ = find_cells_for_phaselocking()
    datasets = ['2023-05-12-ar-invivo-1']
    frame_all = pd.read_pickle(
        load_folder_name(
            'calc_cocktailparty') + '/calc_data_peaks_threewave-spikes_all_psdEOD_1_nfft_16384[05,original]_psdEOD__sqrt__points_5_ALL_.pkl')  # pd.read_pickle('../code/calc_phaselocking/calc_phaselocking-phaselocking5_big.pkl')
    plot_style()
    default_settings()
    for i, cell in enumerate(datasets):
        path = load_folder_name('data') + '/cells/' + cell + '/' + cell + ".nix"  # cell.split(os.path.sep)[-1]
        print(cell)
        if os.path.exists(path):
            # NOTE(review): the nix file is opened here but never closed.
            file = nix.File.open(path, nix.FileMode.ReadOnly)
            b = file.blocks[0]
            cont2 = False
            names = []
            names_dataarrays = []
            for stims in b.data_arrays:
                # This seems reasonable because the only data with a higher sinewave number were not useful
                # (this only concerns 2019-05-07-ab and 2019-05-07-ac; on closer inspection these seem to be
                # odd trials that we do not really need).
                if 'sinewave-1_Contrast' in stims.name:
                    names.append(stims.name)
                    names_dataarrays.append(stims.name)
            'sinewave''SAM'  # no-op string expression (leftover note)
            sam = find_mt(b, 'SAM')
            sine = find_mt(b, 'sine')
            # Only plot when find_mt finds 'SAM' or 'sine' entries in this block.
            if (len(sine) > 0) or (len(sam) > 0):
                cont2 = True
            test = False
            if test:
                from utils_test import test_in_plot_phaselocking
                test_in_plot_phaselocking(b, path)
            if cont2 == True:
                counter = 0
                DF1s, DF2s, frame_data_cell, c2_unique, c2_unique_big, c1_unique_big, c1_unique = get_dfs_and_contrasts_from_calccocktailparty(
                    cell, frame_all)
                if only_first:
                    DF1s_here = [DF1s[0]]
                else:
                    DF1s_here = DF1s  # [0]]
                for d1, DF1 in enumerate(DF1s_here):
                    for d2, DF2 in enumerate(DF2s):
                        # This is clumsy; one should use the saved Ms instead.
                        # NOTE(review): exact equality after rounding m1/m2 to 2 decimals.
                        frame_df0 = frame_data_cell[
                            (np.round(frame_data_cell.m1, 2) == DF1) & (np.round(frame_data_cell.m2, 2) == DF2)]
                        contrasts_2_chosen = np.sort(np.unique(frame_df0.c2))  # c1_unique_big[0:5]
                        for c_nr2, c2 in enumerate(contrasts_2_chosen):
                            frame_df = frame_df0[(frame_df0.c2 == c2)]
                            # c1 values in descending order.
                            contrasts = np.unique(frame_df.c1)[::-1]
                            if len(contrasts) > 0:
                                contrasts_1_chosen = contrasts  # , indeces_show = choice_specific_indices(contrasts, negativ = 'positiv', units = 5, cut_val = 1)
                                for c_nr, c1 in enumerate(contrasts_1_chosen):
                                    if len(frame_df) > 0:
                                        frame_df_mean = frame_df  # .groupby(['c1']).mean()#, 'c2'])
                                        frame_c1 = frame_df[(frame_df.c1 == c1)]  # | (frame_df.mt_type == 'base')
                                        print(c_nr)
                                        mt_types = frame_c1.mt_type.unique()
                                        for mt_type in mt_types:
                                            find_mt_all(b)  # return value is not used here
                                            keys = ['base_0', 'control_01', 'control_02', '012']  # ]
                                            nr_col = len(keys)
                                            # NOTE(review): axts/axps are reset per mt_type, not per trial, so axes of
                                            # earlier trial figures accumulate and axts[0]/axts[1] stay on the first figure.
                                            axts = []
                                            axps = []
                                            frame_type = frame_c1[
                                                (frame_c1.mt_type == mt_type)]  # | (frame_df.mt_type == 'base')
                                            V_1, eods_all, f, mts, sampling_rate, spike_times, spikes_mat, spikes_mats = load_spikes_eods_phaselock(
                                                b, cell, datas_new, frame_type, names, nfft)
                                            # Presumably one iteration per trial — confirm the structure of spike_times.
                                            for nr_example in range(len(spike_times)):
                                                ###################################
                                                # NOTE(review): suptitle is called before plt.figure(), so it goes to the
                                                # previously current figure rather than the one created next.
                                                plt.suptitle(cell + ' DF1=' + str(DF1) + ' DF2=' + str(DF2))
                                                plt.figure(figsize=(20, 14))
                                                # Left cell (width 4): per-key traces; right cell (width 1): summary lines.
                                                gs0 = gridspec.GridSpec(1, 2, width_ratios=[4, 1], hspace=0.4,
                                                                        left=0.045,
                                                                        right=0.97)  #
                                                # Comparison plot
                                                grid1 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.2, hspace=0.2,
                                                                                         subplot_spec=gs0[1])
                                                plt_lines_phaselockingloss(frame_df_mean, grid1)
                                                # Rows: stimulus, voltage, spacer, two PSD rows; one column per key.
                                                grid2 = gridspec.GridSpecFromSubplotSpec(5, nr_col,
                                                                                         height_ratios=[1, 1, 0.5, 1,
                                                                                                        1],
                                                                                         wspace=0.2, hspace=0.2,
                                                                                         subplot_spec=gs0[0])
                                                plt.suptitle('data ' + cell + ' ' + mts.name + ' ' + '_DF1_' + str(
                                                    DF1) + '_DF2_' + str(DF2) + '\n_c2_' + str(c2) + '_c1_' + str(
                                                    c1) + ' Trial ' + str(nr_example))
                                                xlim = [200, 250]
                                                ###########################################
                                                # time spikes
                                                for k, key in enumerate(keys):
                                                    axt = plt.subplot(grid2[1, k])
                                                    time = plt_voltage_phaselock(V_1, axt, axts, counter, nr_example,
                                                                                 sampling_rate,
                                                                                 spike_times, xlim, key=key,
                                                                                 color=color_three(key))
                                                    ##########################################
                                                    axt = plt.subplot(grid2[0, k])
                                                    plt_stimulus_phaselock(axt, axts, counter, eods_all, frame_type,
                                                                           nr_example,
                                                                           sampling_rate, spike_times, time, xlim,
                                                                           key=key, color=color_three(key))
                                                    axt.set_title(key)
                                                    axt.set_ylabel('')
                                                    ##########################################
                                                    # time psd
                                                    axp = plt.subplot(grid2[3, k])
                                                    axp2 = plt.subplot(grid2[4, k])
                                                    # spikes_mat is replaced by the return value and fed into the next key's call.
                                                    spikes_mat = plt_psds_phaselock(axp, axp2, counter, f,
                                                                                    nr_example, sampling_rate,
                                                                                    spikes_mat, spikes_mats, key=key,
                                                                                    color=color_three(key))
                                                    axps.append(axp)
                                                    axps.append(axp2)
                                                # NOTE(review): get_shared_y_axes().join is deprecated in newer Matplotlib.
                                                axts[0].get_shared_y_axes().join(*axts[0::2])
                                                axts[1].get_shared_y_axes().join(*axts[1::2])
                                                join_x(axts)
                                                join_y(axps)
                                                set_same_ylim(axps)
                                                join_x(axps)
                                                individual_tag = 'data' + cell + '_DF1_' + str(DF1) + '_DF2_' + str(
                                                    DF2) + '_c2_' + str(c2) + '_c1_' + str(c1) + '_trial_' + str(
                                                    nr_example)
                                                save_visualization(individual_tag, show)
    print('finished examples')
    # NOTE(review): drops into an interactive IPython shell; blocks non-interactive runs.
    embed()
def plt_beats_modulation_several_with_overview_nice_from_three_contorol_compar(only_first=True, limit=1,
                                                                               duration_exclude=0.45, nfft=int(4096),
                                                                               show=False):
    """Compare the four stimulus conditions of every phase-locking cell, one figure per trial.

    Cells come from ``find_cells_for_phaselocking()`` and are read from
    ``'../data/cells/<cell>/<cell>.nix'``. The precomputed peak table is loaded
    from the ``calc_cocktailparty`` folder. The function loops over every DF1
    (only the first when *only_first*), every DF2, every ``c2``, every ``c1``
    (in descending order), every ``mt_type`` and every example index. For each
    combination it makes one figure:

    * right column: summary lines via ``plt_lines_phaselockingloss``;
    * left grid: one column per key ``'base_0'``, ``'control_01'``,
      ``'control_02'``, ``'012'`` with stimulus, voltage and two PSD panels.

    Each figure is saved with ``save_visualization``.

    Parameters
    ----------
    only_first : bool
        Restrict the loop to the first DF1.
    limit, duration_exclude :
        Only used to build ``save_name``, which is printed and otherwise unused.
    nfft : int
        Forwarded to ``load_spikes_eods_phaselock``.
    show : bool
        Forwarded to ``save_visualization``.

    NOTE(review): the function ends with ``embed()``, which opens an interactive
    IPython shell, and the opened nix files are never closed.
    """
    # Function to load the experimental data
    save_name = 'beat_results_smoothed_limit' + str(limit) + '_minimalduration_' + str(duration_exclude) + '_all'
    print(save_name)
    datas_new = []
    old_cells = False
    if old_cells:
        # This branch is for inspecting the old datasets (disabled: old_cells is hard-coded to False).
        _, _ = find_all_dir_cells()
        frame_desired = pd.read_csv('../code/calc_base/find_contrasts_SAMs-SAM_amplitudes.csv')
        frame_big = frame_desired[(frame_desired.contrast > 29) | (frame_desired.contrast_true > 29)]
    else:
        pass
    datasets, loss, gain = find_cells_for_phaselocking()
    frame_all = pd.read_pickle(
        load_folder_name(
            'calc_cocktailparty') + '/calc_data_peaks_threewave-spikes_all_psdEOD_1_nfft_16384[05,original]_psdEOD__sqrt__points_5_ALL_.pkl')  # pd.read_pickle('../code/calc_phaselocking/calc_phaselocking-phaselocking5_big.pkl')
    # cells_for_phaselocking, loss, gain = find_cells_for_phaselocking()
    plot_style()
    default_settings()
    for i, cell in enumerate(datasets):
        path = '../data/cells/' + cell + '/' + cell + ".nix"  # cell.split(os.path.sep)[-1]
        print(cell)
        if os.path.exists(path):
            # NOTE(review): the nix file is opened here but never closed.
            file = nix.File.open(path, nix.FileMode.ReadOnly)
            b = file.blocks[0]
            cont2 = False
            names = []
            names_dataarrays = []
            for stims in b.data_arrays:
                # This seems reasonable because the only data with a higher sinewave number were not useful
                # (this only concerns 2019-05-07-ab and 2019-05-07-ac; on closer inspection these seem to be
                # odd trials that we do not really need).
                if 'sinewave-1_Contrast' in stims.name:
                    names.append(stims.name)
                    names_dataarrays.append(stims.name)
            'sinewave''SAM'  # no-op string expression (leftover note)
            sam = find_mt(b, 'SAM')
            sine = find_mt(b, 'sine')
            # Only plot when find_mt finds 'SAM' or 'sine' entries in this block.
            if (len(sine) > 0) or (len(sam) > 0):
                cont2 = True
            test = False
            if test:
                from utils_test import test_in_plot_phaselocking
                test_in_plot_phaselocking(b, path)
            if cont2 == True:
                counter = 0
                DF1s, DF2s, frame_data_cell, c2_unique, c2_unique_big, c1_unique_big, c1_unique = get_dfs_and_contrasts_from_calccocktailparty(
                    cell, frame_all)
                if only_first:
                    DF1s_here = [DF1s[0]]
                else:
                    DF1s_here = DF1s  # [0]]
                for d1, DF1 in enumerate(DF1s_here):
                    for d2, DF2 in enumerate(DF2s):
                        # This is clumsy; one should use the saved Ms instead.
                        # NOTE(review): exact equality after rounding m1/m2 to 2 decimals.
                        frame_df0 = frame_data_cell[
                            (np.round(frame_data_cell.m1, 2) == DF1) & (np.round(frame_data_cell.m2, 2) == DF2)]
                        contrasts_2_chosen = np.sort(np.unique(frame_df0.c2))  # c1_unique_big[0:5]
                        for c_nr2, c2 in enumerate(contrasts_2_chosen):
                            frame_df = frame_df0[(frame_df0.c2 == c2)]
                            # c1 values in descending order.
                            contrasts = np.unique(frame_df.c1)[::-1]
                            if len(contrasts) > 0:
                                contrasts_1_chosen = contrasts  # , indeces_show = choice_specific_indices(contrasts, negativ = 'positiv', units = 5, cut_val = 1)
                                for c_nr, c1 in enumerate(contrasts_1_chosen):
                                    if len(frame_df) > 0:
                                        frame_df_mean = frame_df  # .groupby(['c1']).mean()#, 'c2'])
                                        frame_c1 = frame_df[(frame_df.c1 == c1)]  # | (frame_df.mt_type == 'base')
                                        print(c_nr)
                                        mt_types = frame_c1.mt_type.unique()
                                        for mt_type in mt_types:
                                            find_mt_all(b)  # return value is not used here
                                            # TODO: one could also simply store the mt and the mt name here
                                            keys = ['base_0', 'control_01', 'control_02', '012']  # ]
                                            nr_col = len(keys)
                                            # NOTE(review): axts/axps are reset per mt_type, not per trial, so axes of
                                            # earlier trial figures accumulate and axts[0]/axts[1] stay on the first figure.
                                            axts = []
                                            axps = []
                                            frame_type = frame_c1[
                                                (frame_c1.mt_type == mt_type)]  # | (frame_df.mt_type == 'base')
                                            V_1, eods_all, f, mts, sampling_rate, spike_times, spikes_mat, spikes_mats = load_spikes_eods_phaselock(
                                                b, cell, datas_new, frame_type, names, nfft)
                                            # Presumably one iteration per trial — confirm the structure of spike_times.
                                            for nr_example in range(len(spike_times)):
                                                # NOTE(review): suptitle is called before plt.figure(), so it goes to the
                                                # previously current figure rather than the one created next.
                                                plt.suptitle(cell + ' DF1=' + str(DF1) + ' DF2=' + str(DF2))
                                                plt.figure(figsize=(20, 14))
                                                # Left cell (width 4): per-key traces; right cell (width 1): summary lines.
                                                gs0 = gridspec.GridSpec(1, 2, width_ratios=[4, 1], hspace=0.4,
                                                                        left=0.045,
                                                                        right=0.97)  #
                                                # Comparison plot
                                                grid1 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.2, hspace=0.2,
                                                                                         subplot_spec=gs0[1])
                                                plt_lines_phaselockingloss(frame_df_mean, grid1)
                                                # Rows: stimulus, voltage, spacer, two PSD rows; one column per key.
                                                grid2 = gridspec.GridSpecFromSubplotSpec(5, nr_col,
                                                                                         height_ratios=[1, 1, 0.5, 1,
                                                                                                        1],
                                                                                         wspace=0.2, hspace=0.2,
                                                                                         subplot_spec=gs0[0])
                                                plt.suptitle('data ' + cell + ' ' + mts.name + ' ' + '_DF1_' + str(
                                                    DF1) + '_DF2_' + str(DF2) + '\n_c2_' + str(c2) + '_c1_' + str(
                                                    c1) + ' Trial ' + str(nr_example))
                                                xlim = [200, 250]
                                                ###########################################
                                                # time spikes
                                                for k, key in enumerate(keys):
                                                    axt = plt.subplot(grid2[1, k])
                                                    time = plt_voltage_phaselock(V_1, axt, axts, counter, nr_example,
                                                                                 sampling_rate,
                                                                                 spike_times, xlim, key=key,
                                                                                 color=color_three(key))
                                                    axt = plt.subplot(grid2[0, k])
                                                    plt_stimulus_phaselock(axt, axts, counter, eods_all, frame_type,
                                                                           nr_example,
                                                                           sampling_rate, spike_times, time, xlim,
                                                                           key=key, color=color_three(key))
                                                    axt.set_title(key)
                                                    axt.set_ylabel('')
                                                    ##########################################
                                                    # time psd
                                                    axp = plt.subplot(grid2[3, k])
                                                    axp2 = plt.subplot(grid2[4, k])
                                                    # spikes_mat is replaced by the return value and fed into the next key's call.
                                                    spikes_mat = plt_psds_phaselock(axp, axp2, counter, f,
                                                                                    nr_example, sampling_rate,
                                                                                    spikes_mat, spikes_mats, key=key,
                                                                                    color=color_three(key))
                                                    axps.append(axp)
                                                    axps.append(axp2)
                                                # NOTE(review): get_shared_y_axes().join is deprecated in newer Matplotlib.
                                                axts[0].get_shared_y_axes().join(*axts[0::2])
                                                axts[1].get_shared_y_axes().join(*axts[1::2])
                                                join_x(axts)
                                                join_y(axps)
                                                set_same_ylim(axps)
                                                join_x(axps)
                                                individual_tag = 'data' + cell + '_DF1_' + str(DF1) + '_DF2_' + str(
                                                    DF2) + '_c2_' + str(c2) + '_c1_' + str(c1) + '_trial_' + str(
                                                    nr_example)
                                                save_visualization(individual_tag, show)
    print('finished examples')
    # NOTE(review): drops into an interactive IPython shell; blocks non-interactive runs.
    embed()
def plt_beats_modulation_several_with_overview_nice_from_three_contorol_compar_single_pdf(only_first=True, limit=1,
                                                                                         duration_exclude=0.45,
                                                                                         nfft=int(4096), show=False):
    """Save one overview figure (as PDF) per trial of the phase-locking cells.

    For every cell from ``find_cells_for_phaselocking`` whose ``.nix`` file exists and that contains
    SAM or sine multi-tags, the cell's rows of the cocktail-party frame are split by ``DF1``/``DF2``
    (compared against ``m1``/``m2`` rounded to two decimals), by ``c2``, by ``c1`` (descending) and
    by ``mt_type``. For every trial one figure is created with, per key in ``keys``:
    - the stimulus together with the ``extract_am`` trace;
    - the voltage trace with spike markers;
    - the spectrum of this trial;
    - the spectra of all trials with their mean;
    - an ISI histogram.
    The right column shows the peak amplitudes versus ``c1`` (``plt_lines_phaselockingloss``).
    Each cell's file is closed after it has been processed. An IPython shell is opened at the end.

    Parameters
    ----------
    only_first : bool
        If True only the first ``DF1`` value of every cell is used.
    limit, duration_exclude :
        Only used to build the printed ``save_name``.
    nfft : int
        Passed on to ``load_spikes_eods_phaselock``.
    show : bool
        Passed on to ``save_visualization``.
    """
    save_name = 'beat_results_smoothed_limit' + str(limit) + '_minimalduration_' + str(duration_exclude) + '_all'
    print(save_name)
    datas_new = []
    # only the new data sets are analysed here
    datasets, _, _ = find_cells_for_phaselocking()
    frame_all = pd.read_pickle(
        load_folder_name(
            'calc_cocktailparty') + '/calc_data_peaks_threewave-spikes_all_psdEOD_1_nfft_16384[05,original]_psdEOD__sqrt__points_5_ALL_.pkl')
    plot_style()
    default_settings()
    for cell in datasets:
        path = '../data/cells/' + cell + '/' + cell + ".nix"
        print(cell)
        if not os.path.exists(path):
            continue
        file = nix.File.open(path, nix.FileMode.ReadOnly)
        try:
            b = file.blocks[0]
            # Only the 'sinewave-1' contrast arrays are used. The only data with a higher sine wave
            # number (2019-05-07-ab and 2019-05-07-ac) contained odd trials that are not needed.
            names = [stims.name for stims in b.data_arrays if 'sinewave-1_Contrast' in stims.name]
            sam = find_mt(b, 'SAM')
            sine = find_mt(b, 'sine')
            if (len(sine) == 0) and (len(sam) == 0):
                continue
            counter = 0
            DF1s, DF2s, frame_data_cell, c2_unique, c2_unique_big, c1_unique_big, c1_unique = get_dfs_and_contrasts_from_calccocktailparty(
                cell, frame_all)
            if only_first:
                DF1s_here = [DF1s[0]]
            else:
                DF1s_here = DF1s
            for DF1 in DF1s_here:
                for DF2 in DF2s:
                    # NOTE: using the stored m values directly would be cleaner than rounding here
                    frame_df0 = frame_data_cell[
                        (np.round(frame_data_cell.m1, 2) == DF1) & (np.round(frame_data_cell.m2, 2) == DF2)]
                    for c2 in np.sort(np.unique(frame_df0.c2)):
                        frame_df = frame_df0[(frame_df0.c2 == c2)]
                        # c1 contrasts in descending order (frame_df is non-empty whenever this loops)
                        for c_nr, c1 in enumerate(np.unique(frame_df.c1)[::-1]):
                            frame_c1 = frame_df[(frame_df.c1 == c1)]
                            print(c_nr)
                            for mt_type in frame_c1.mt_type.unique():
                                find_mt_all(b)
                                keys = ['control_01']  # ,'base_0', 'control_02', '012']
                                nr_col = len(keys)
                                # NOTE(review): axts/axps collect the axes of all trials of this mt_type,
                                # so axes of earlier trial figures are joined as well
                                axts = []
                                axps = []
                                frame_type = frame_c1[(frame_c1.mt_type == mt_type)]
                                V_1, eods_all, f, mts, sampling_rate, spike_times, spikes_mat, spikes_mats = load_spikes_eods_phaselock(
                                    b, cell, datas_new, frame_type, names, nfft)
                                for nr_example in range(len(spike_times)):
                                    # Create the figure first. The title is set on it below; a title set
                                    # before plt.figure() ended up on a stale or orphaned figure.
                                    plt.figure(figsize=(20, 14))
                                    gs0 = gridspec.GridSpec(1, 2, width_ratios=[4, 1], hspace=0.4,
                                                            left=0.045, right=0.97)
                                    # comparison plot (right column)
                                    grid1 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.2, hspace=0.2,
                                                                             subplot_spec=gs0[1])
                                    plt_lines_phaselockingloss(frame_df, grid1)
                                    grid2 = gridspec.GridSpecFromSubplotSpec(6, nr_col,
                                                                             height_ratios=[1, 1, 0.5, 1, 1, 1],
                                                                             wspace=0.2, hspace=0.2,
                                                                             subplot_spec=gs0[0])
                                    plt.suptitle('data ' + cell + ' ' + mts.name + ' ' + '_DF1_' + str(
                                        DF1) + '_DF2_' + str(DF2) + '\n_c2_' + str(c2) + '_c1_' + str(
                                        c1) + ' Trial ' + str(nr_example))
                                    xlim = []
                                    for k, key in enumerate(keys):
                                        # voltage trace with spikes
                                        axt = plt.subplot(grid2[1, k])
                                        time = plt_voltage_phaselock(V_1, axt, axts, counter, nr_example,
                                                                     sampling_rate, spike_times, xlim, key=key,
                                                                     color=color_three(key))
                                        # stimulus
                                        axt = plt.subplot(grid2[0, k])
                                        plt_stimulus_phaselock(axt, axts, counter, eods_all, frame_type,
                                                               nr_example, sampling_rate, spike_times, time, xlim,
                                                               key=key, color=color_three(key))
                                        axt.set_title(key)
                                        axt.set_ylabel('')
                                        # power spectra: this trial / all trials
                                        axp = plt.subplot(grid2[3, k])
                                        axp2 = plt.subplot(grid2[4, k])
                                        spikes_mat = plt_psds_phaselock(axp, axp2, counter, f, nr_example,
                                                                        sampling_rate, spikes_mat, spikes_mats,
                                                                        key=key, color=color_three(key))
                                        axps.append(axp)
                                        axps.append(axp2)
                                        # ISI histogram
                                        axi = plt.subplot(grid2[5, k])
                                        isi = calc_isi(spike_times[nr_example][key],
                                                       frame_type.iloc[nr_example].EODf)
                                        axi.hist(np.concatenate(isi), bins=100)
                                        axi.axvline(x=1, color='black', linestyle='--')
                                    if len(axts) > 0:
                                        axts[0].get_shared_y_axes().join(*axts[0::2])
                                        axts[1].get_shared_y_axes().join(*axts[1::2])
                                        join_x(axts)
                                    join_y(axps)
                                    set_same_ylim(axps)
                                    join_x(axps)
                                    individual_tag = 'data' + cell + '_DF1_' + str(DF1) + '_DF2_' + str(
                                        DF2) + '_c2_' + str(c2) + '_c1_' + str(c1) + '_trial_' + str(
                                        nr_example)
                                    save_visualization(individual_tag, show, pdf=True)
        finally:
            file.close()
    print('finished examples')
    embed()
def plt_beats_modulation_several_with_overview_nice_from_three_contorol_compar_single(only_first=True, limit=1,
                                                                                     duration_exclude=0.45,
                                                                                     nfft=int(4096), show=False):
    """Save one overview figure per trial of the phase-locking cells (time axes limited to 200-250 ms).

    For every cell from ``find_cells_for_phaselocking`` whose ``.nix`` file exists and that contains
    SAM or sine multi-tags, the cell's rows of the cocktail-party frame are split by ``DF1``/``DF2``
    (compared against ``m1``/``m2`` rounded to two decimals), by ``c2``, by ``c1`` (descending) and
    by ``mt_type``. For every trial one figure is created with, per key in ``keys``:
    - the stimulus together with the ``extract_am`` trace;
    - the voltage trace with spike markers;
    - the spectrum of this trial;
    - the spectra of all trials with their mean;
    - an ISI histogram.
    The right column shows the peak amplitudes versus ``c1`` (``plt_lines_phaselockingloss``).
    Each cell's file is closed after it has been processed. An IPython shell is opened at the end.

    Parameters
    ----------
    only_first : bool
        If True only the first ``DF1`` value of every cell is used.
    limit, duration_exclude :
        Only used to build the printed ``save_name``.
    nfft : int
        Passed on to ``load_spikes_eods_phaselock``.
    show : bool
        Passed on to ``save_visualization``.
    """
    save_name = 'beat_results_smoothed_limit' + str(limit) + '_minimalduration_' + str(duration_exclude) + '_all'
    print(save_name)
    datas_new = []
    # only the new data sets are analysed here
    datasets, _, _ = find_cells_for_phaselocking()
    frame_all = pd.read_pickle(
        load_folder_name(
            'calc_cocktailparty') + '/calc_data_peaks_threewave-spikes_all_psdEOD_1_nfft_16384[05,original]_psdEOD__sqrt__points_5_ALL_.pkl')
    plot_style()
    default_settings()
    for cell in datasets:
        path = '../data/cells/' + cell + '/' + cell + ".nix"
        print(cell)
        if not os.path.exists(path):
            continue
        file = nix.File.open(path, nix.FileMode.ReadOnly)
        try:
            b = file.blocks[0]
            # Only the 'sinewave-1' contrast arrays are used. The only data with a higher sine wave
            # number (2019-05-07-ab and 2019-05-07-ac) contained odd trials that are not needed.
            names = [stims.name for stims in b.data_arrays if 'sinewave-1_Contrast' in stims.name]
            sam = find_mt(b, 'SAM')
            sine = find_mt(b, 'sine')
            if (len(sine) == 0) and (len(sam) == 0):
                continue
            counter = 0
            DF1s, DF2s, frame_data_cell, c2_unique, c2_unique_big, c1_unique_big, c1_unique = get_dfs_and_contrasts_from_calccocktailparty(
                cell, frame_all)
            if only_first:
                DF1s_here = [DF1s[0]]
            else:
                DF1s_here = DF1s
            for DF1 in DF1s_here:
                for DF2 in DF2s:
                    # NOTE: using the stored m values directly would be cleaner than rounding here
                    frame_df0 = frame_data_cell[
                        (np.round(frame_data_cell.m1, 2) == DF1) & (np.round(frame_data_cell.m2, 2) == DF2)]
                    for c2 in np.sort(np.unique(frame_df0.c2)):
                        frame_df = frame_df0[(frame_df0.c2 == c2)]
                        # c1 contrasts in descending order (frame_df is non-empty whenever this loops)
                        for c_nr, c1 in enumerate(np.unique(frame_df.c1)[::-1]):
                            frame_c1 = frame_df[(frame_df.c1 == c1)]
                            print(c_nr)
                            for mt_type in frame_c1.mt_type.unique():
                                find_mt_all(b)
                                keys = ['control_01']  # ,'base_0', 'control_02', '012']
                                nr_col = len(keys)
                                # NOTE(review): axts/axps collect the axes of all trials of this mt_type,
                                # so axes of earlier trial figures are joined as well
                                axts = []
                                axps = []
                                frame_type = frame_c1[(frame_c1.mt_type == mt_type)]
                                V_1, eods_all, f, mts, sampling_rate, spike_times, spikes_mat, spikes_mats = load_spikes_eods_phaselock(
                                    b, cell, datas_new, frame_type, names, nfft)
                                for nr_example in range(len(spike_times)):
                                    # Create the figure first. The title is set on it below; a title set
                                    # before plt.figure() ended up on a stale or orphaned figure.
                                    plt.figure(figsize=(20, 14))
                                    gs0 = gridspec.GridSpec(1, 2, width_ratios=[4, 1], hspace=0.4,
                                                            left=0.045, right=0.97)
                                    # comparison plot (right column)
                                    grid1 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.2, hspace=0.2,
                                                                             subplot_spec=gs0[1])
                                    plt_lines_phaselockingloss(frame_df, grid1)
                                    grid2 = gridspec.GridSpecFromSubplotSpec(6, nr_col,
                                                                             height_ratios=[1, 1, 0.5, 1, 1, 1],
                                                                             wspace=0.2, hspace=0.2,
                                                                             subplot_spec=gs0[0])
                                    plt.suptitle('data ' + cell + ' ' + mts.name + ' ' + '_DF1_' + str(
                                        DF1) + '_DF2_' + str(DF2) + '\n_c2_' + str(c2) + '_c1_' + str(
                                        c1) + ' Trial ' + str(nr_example))
                                    # time window in ms
                                    xlim = [200, 250]
                                    for k, key in enumerate(keys):
                                        # voltage trace with spikes
                                        axt = plt.subplot(grid2[1, k])
                                        time = plt_voltage_phaselock(V_1, axt, axts, counter, nr_example,
                                                                     sampling_rate, spike_times, xlim, key=key,
                                                                     color=color_three(key))
                                        # stimulus
                                        axt = plt.subplot(grid2[0, k])
                                        plt_stimulus_phaselock(axt, axts, counter, eods_all, frame_type,
                                                               nr_example, sampling_rate, spike_times, time, xlim,
                                                               key=key, color=color_three(key))
                                        axt.set_title(key)
                                        axt.set_ylabel('')
                                        # power spectra: this trial / all trials
                                        axp = plt.subplot(grid2[3, k])
                                        axp2 = plt.subplot(grid2[4, k])
                                        spikes_mat = plt_psds_phaselock(axp, axp2, counter, f, nr_example,
                                                                        sampling_rate, spikes_mat, spikes_mats,
                                                                        key=key, color=color_three(key))
                                        axps.append(axp)
                                        axps.append(axp2)
                                        # ISI histogram
                                        axi = plt.subplot(grid2[5, k])
                                        isi = calc_isi(spike_times[nr_example][key],
                                                       frame_type.iloc[nr_example].EODf)
                                        axi.hist(np.concatenate(isi), bins=100)
                                        axi.axvline(x=1, color='black', linestyle='--')
                                    if len(axts) > 0:
                                        axts[0].get_shared_y_axes().join(*axts[0::2])
                                        axts[1].get_shared_y_axes().join(*axts[1::2])
                                        join_x(axts)
                                    join_y(axps)
                                    set_same_ylim(axps)
                                    join_x(axps)
                                    individual_tag = 'data' + cell + '_DF1_' + str(DF1) + '_DF2_' + str(
                                        DF2) + '_c2_' + str(c2) + '_c1_' + str(c1) + '_trial_' + str(
                                        nr_example)
                                    save_visualization(individual_tag, show)
        finally:
            file.close()
    print('finished examples')
    embed()
def plt_lines_phaselockingloss(frame_df_mean, grid1):
    """Plot peak amplitudes against the contrast ``c1`` in two stacked panels.

    Panel ``grid1[0]`` shows ``amp_f0_01_original``, ``amp_f0_012_original``,
    ``amp_B1_01_original`` and ``amp_B1_012_original``. Panel ``grid1[1]`` shows
    ``amp_f0_02_original``, ``amp_f0_0_original``, ``amp_B2_02_original`` and
    ``amp_B2_012_original``. Every column is drawn as one line over the ``c1`` values in ascending
    order.

    Parameters
    ----------
    frame_df_mean : pandas.DataFrame
        Must contain ``c1`` and all columns listed above.
    grid1 : matplotlib GridSpec (or sub-spec) with at least two cells.
    """
    scores_datas = [['amp_f0_01_original', 'amp_f0_012_original',
                     'amp_B1_01_original', 'amp_B1_012_original'],
                    ['amp_f0_02_original',
                     'amp_f0_0_original', 'amp_B2_02_original',
                     'amp_B2_012_original']]
    colorss = [['green', 'purple', 'green', 'blue'],
               ['orange', 'red', 'orange', 'red']]
    linestyless = [['--', '--', '-', '-'], ['--', '--', '-', '-']]
    # The x values and the row order used for every y column do not depend on the column,
    # so they are computed once.
    c1_sorted = np.sort(frame_df_mean['c1'])
    c1_order = np.argsort(frame_df_mean['c1'])
    for s_nr, (scores_data, colors, linestyles) in enumerate(zip(scores_datas, colorss, linestyless)):
        ax = plt.subplot(grid1[s_nr])
        for score, color, linestyle in zip(scores_data, colors, linestyles):
            label = score.replace('_mean', '').replace('amp_', '') + '[Hz]'
            ax.plot(c1_sorted, frame_df_mean[score].iloc[c1_order], color=color, linestyle=linestyle,
                    label=label)
        ax.set_ylabel('Peak Amplitude', fontsize=8)
        ax.set_xlabel('Contrast small')
        ax.legend()
def find_cells_for_phaselocking():
    """Return the cells used for the phase-locking plots and two example cells.

    Returns
    -------
    cells : list of str
        Names of all cells that are analysed.
    loss : str
        Example cell labelled as showing a loss.
    gain : str
        Example cell labelled as showing a gain.
    """
    loss_cell = '2023-05-12-ap-invivo-1'
    gain_cell = '2023-05-03-aa-invivo-1'
    cells = [
        '2023-05-12-aq-invivo-1',
        '2023-05-03-aa-invivo-1',
        '2023-05-12-al-invivo-1',
        '2023-05-12-ai-invivo-1',
        '2023-05-24-ac-invivo-1',
        '2023-05-12-at-invivo-1',
        '2023-05-12-as-invivo-1',
        '2023-05-12-ap-invivo-1',
        '2023-05-12-af-invivo-1',
        '2023-05-12-ae-invivo-1',
        '2023-05-12-ar-invivo-1',
    ]
    return cells, loss_cell, gain_cell
def load_spikes_eods_phaselock(b, cell, datas_new, frame_type, names, nfft):
    """Load and cut the local EOD, voltage trace and spikes of every trial listed in ``frame_type``.

    The multi-tag is taken from the first entry of ``frame_type.mt_name``; the trials are the
    ``mt_idx`` values of ``frame_type``. For every trial ``load_eod_for_three`` loads three arrays:
    - the local EOD (``LocalEOD-1``);
    - the voltage (``V-1``);
    - the global EOD (``EOD``).
    Spikes are cut with ``cut_spikes_sequences``. The local EOD and the voltage are cut with
    ``cut_eod_sequences``, using ``cut=0.05``.

    Parameters
    ----------
    b : nix Block of the recording.
    cell : str
        Cell name; printed and appended to ``datas_new``.
    datas_new : list
        Mutated: ``cell`` is appended.
    frame_type : pandas.DataFrame
        Rows with ``mt_name`` and ``mt_idx`` columns (assumed to belong to one multi-tag).
    names : list of str
        Data-array names passed on to ``get_eod_fr_simple``.
    nfft : int
        NFFT of the EOD power spectrum. The spectrum is only computed when ``get_eod_fr_simple``
        reports that the EOD frequency has to be redone.

    Returns
    -------
    V_1, eods_all, f, mts, sampling_rate, spike_times, spikes_mat, spikes_mats
        ``V_1``, ``eods_all``, ``spike_times`` and ``spikes_mats`` hold one cut entry per trial.
        ``f`` is the frequency axis of the last EOD spectrum, or ``None`` if none was computed.
        ``spikes_mat`` is ``spikes_mats[0]``.
    """
    mt_idxs = list(map(int, np.array(frame_type.mt_idx)))
    mt_names = frame_type.mt_name.unique()
    mts = b.multi_tags[mt_names[0]]
    print(mts.name)
    eod_frs, eod_redo = get_eod_fr_simple(b, names)
    print(cell + ' Beat calculation')
    datas_new.append(cell)
    sampling_rate = get_sampling(b, 'EOD')
    cut = 0.05
    # f is only computed when the EOD frequency is redone. Initialise it so that the return
    # statement does not raise an UnboundLocalError otherwise.
    f = None
    eods_all = []
    V_1 = []
    spike_times = []
    spikes_mats = []
    for m in mt_idxs:
        frame_features = feature_extract_cut(mts, m)
        zeroth_cut, first_cut, second_cut, third_cut, fish_number, fish_cuts, whole_duration, delay, cont = load_four_durations(
            mts, frame_features, 0, m)
        try:
            eods, spikes_mt = load_eod_for_three(b, delay, mts, m, load_eod_array='LocalEOD-1')
        except Exception:
            print('eods thing')
            embed()
            # do not continue with the EOD of the previous trial (or an undefined one)
            raise
        if eod_redo == True:
            p, f = ml.psd(eods - np.mean(eods), Fs=sampling_rate, NFFT=nfft,
                          noverlap=nfft // 2)
        eod_mt, spikes_mt, time_eod, time_laod_eods, timepoint = spike_times_cocktailparty(b, delay, mts, m)
        v_1, spikes_mt = load_eod_for_three(b, delay, mts, m, load_eod_array='V-1')
        # the spikes returned together with the global EOD are the ones cut below
        eods_g, spikes_mt = load_eod_for_three(b, delay, mts, m, load_eod_array='EOD')
        devname, smoothened2, smoothed05, mat, time_here, arrays_calc, effective_duration, spikes_cut = cut_spikes_sequences(
            delay, spikes_mt, sampling_rate, fish_cuts, cut=cut,
            fish_number=fish_number, cut_compensate=True, devname_orig=['original'], cut_length=False)
        spike_times.append(spikes_cut)
        v_1_cut, _ = cut_eod_sequences(v_1, fish_cuts,
                                       time_eod, cut=cut,
                                       rec=False,
                                       fish_number=fish_number)
        eods_cut, _ = cut_eod_sequences(eods, fish_cuts,
                                        time_eod, cut=cut,
                                        rec=False,
                                        fish_number=fish_number)
        spikes_mats.append(arrays_calc[0])
        eods_all.append(eods_cut)
        V_1.append(v_1_cut)
    return V_1, eods_all, f, mts, sampling_rate, spike_times, spikes_mats[0], spikes_mats
def plt_voltage_phaselock(V_1, axt, axts, counter, nr_example, sampling_rate, spike_times, xlim, key='01',
                          color='purple'):
    """Plot the voltage trace of one trial and mark its spikes at the trace maximum.

    Parameters
    ----------
    V_1 : list of dict
        ``V_1[nr_example][key]`` is the voltage trace that is plotted.
    axt : matplotlib axes to draw on; appended to ``axts``.
    axts : list
        Mutated: ``axt`` is appended.
    counter : int
        Column index; y ticks and the y label are removed for every column but the first.
    nr_example : int
        Index of the trial.
    sampling_rate : float
        Samples per second of the trace.
    spike_times : list of dict
        ``spike_times[nr_example][key][0]`` are the spike times in seconds (list or array).
    xlim : sequence
        Applied with ``set_xlim`` when non-empty.
    key : str
        Key selecting the stimulus segment.
    color : str
        Color of the trace.

    Returns
    -------
    numpy.ndarray
        Time axis of the trace in milliseconds, with the same length as the trace.
    """
    voltage = V_1[nr_example][key]
    axt.set_ylabel('local')
    # Index-based time axis: np.arange with a float step can yield one sample more than the trace.
    time = np.arange(len(voltage)) / sampling_rate * 1000
    axt.plot(time, voltage, color=color, linewidth=0.5)
    spikes = spike_times[nr_example][key][0]
    if len(spikes) > 0 and len(voltage) > 0:
        try:
            # asarray: multiplying a plain list by 1000 would repeat it instead of scaling it
            spikes_ms = np.asarray(spikes) * 1000
            axt.scatter(spikes_ms, np.max(voltage) * np.ones(len(spikes_ms)),
                        color='black', s=10, marker='|')
        except Exception:
            print('spikes something')
            embed()
    if len(xlim) > 0:
        axt.set_xlim(xlim)
    axt.set_xlabel('Time [ms]')
    if counter != 0:
        remove_yticks(axt)
        axt.set_ylabel('')
    axts.append(axt)
    return time
def plt_psds_phaselock(axp, axp2, counter, f, nr_example, sampling_rate, spikes_mat, spikes_mats, color='purple',
                       key='01'):
    """Plot the spike-train power spectra of all trials for one key.

    For every entry of ``spikes_mats`` the spectrum of the mean-subtracted ``spikes_mat[key]`` is
    computed with ``ml.psd`` (NFFT 2**13, noverlap of half a segment).
    - Trial ``nr_example`` is drawn in ``color`` on ``axp`` and on top of ``axp2``.
    - All other trials are drawn in grey on ``axp2``, together with the mean spectrum (black, dashed).
    Both axes are limited to 0-1000 Hz.

    Parameters
    ----------
    axp, axp2 : matplotlib axes for the selected trial and for all trials.
    counter : int
        Column index; y ticks and y labels are removed for every column but the first.
    f :
        Overwritten by the frequencies returned from ``ml.psd``; kept for compatibility.
    nr_example : int
        Index of the highlighted trial.
    sampling_rate : float
    spikes_mat :
        Only returned unchanged when ``spikes_mats`` is empty.
    spikes_mats : list of dict
        One entry per trial, indexed with ``key``.
    color : str
        Color of the highlighted trial.
    key : str

    Returns
    -------
    The last entry of ``spikes_mats`` (or ``spikes_mat`` when ``spikes_mats`` is empty).
    """
    nfft_spikes = 2 ** 13
    ps = []
    for s, spikes_mat in enumerate(spikes_mats):
        try:
            # integer noverlap: a float overlap can break the segment slicing of mlab
            p, f = ml.psd(spikes_mat[key] - np.mean(spikes_mat[key]), Fs=sampling_rate,
                          NFFT=nfft_spikes,
                          noverlap=nfft_spikes // 2)
        except Exception:
            print('p something')
            embed()
            # do not plot a stale (or undefined) spectrum
            raise
        ps.append(p)
        # A separate variable keeps ``color`` intact. Reassigning it made the highlighted trial
        # grey whenever nr_example > 0.
        if s == nr_example:
            trace_color = color
            zorder = 100
            axp.plot(f, p, color=trace_color, zorder=zorder)
        else:
            trace_color = 'grey'
            zorder = 1
        axp2.plot(f, p, color=trace_color, zorder=zorder)
    axp2.set_xlim(0, 1000)
    axp.set_xlim(0, 1000)
    remove_xticks(axp)
    axp2.plot(f, np.mean(ps, axis=0), color='black', zorder=2, linestyle='--')
    axp2.set_xlabel('Frequency [Hz]')
    if counter != 0:
        remove_yticks(axp2)
        axp2.set_ylabel('')
        remove_yticks(axp)
        axp.set_ylabel('')
    return spikes_mat
def plt_stimulus_phaselock(axt, axts, counter, eods_all, frame_type, nr_example, sampling_rate, spike_times, time,
                           xlim, key='01', color='red'):
    """Plot the stimulus of one trial with the ``extract_am`` trace and the spike times.

    The output of ``extract_am`` is drawn as two traces:
    - the normalised trace in grey;
    - the second trace in ``color``.
    Spikes are marked at the mean of the normalised trace.

    Parameters
    ----------
    axt : matplotlib axes to draw on; appended to ``axts``.
    axts : list
        Mutated: ``axt`` is appended.
    counter : int
        Column index; y ticks and the y label are removed for every column but the first.
    eods_all : list of dict
        ``eods_all[nr_example][key]`` is the stimulus.
    frame_type : pandas.DataFrame
        Its unique ``c1``/``c2`` values are used for the title (if the stimulus is non-empty).
    nr_example : int
        Index of the trial.
    sampling_rate : float
    spike_times : list of dict
        ``spike_times[nr_example][key][0]`` are the spike times in seconds.
    time :
        Recomputed from the stimulus length; kept for compatibility.
    xlim : sequence
        Applied with ``set_xlim`` when non-empty.
    key : str
    color : str
        Color of the ``extract_am`` trace.
    """
    stimulus = eods_all[nr_example][key]
    if len(stimulus) > 0:
        axt.set_title(' c1' + str(np.unique(frame_type.c1)) + ' c2' + str(np.unique(frame_type.c2)))
    axts.append(axt)
    # Index-based time axis in ms: np.arange with a float step can be one sample longer.
    time = np.arange(len(stimulus)) / sampling_rate * 1000
    try:
        eods_am, eod_norm = extract_am(stimulus, time, norm=False)
    except Exception:
        print('am something')
        # without a result there is nothing to plot; keep the original error
        raise
    axt.plot(time, eod_norm, color='grey', linewidth=0.5)
    axt.plot(time, eods_am, color=color)
    spikes = np.asarray(spike_times[nr_example][key][0])
    axt.scatter(spikes * 1000, np.mean(eod_norm) * np.ones(len(spikes)),
                color='black', s=10, marker='|')
    if len(xlim) > 0:
        axt.set_xlim(xlim)
    remove_xticks(axt)
    if counter != 0:
        remove_yticks(axt)
        axt.set_ylabel('')
def get_features_and_info(mts, dfs=None, contrasts=None):
    """Sort the feature data names of a multi-tag into id, contrast, DeltaF and other features.

    Each feature's data name is checked in this order:
    - a name containing ``'id'`` becomes the id name;
    - a name containing ``'Contrast'`` provides the contrast values;
    - a name containing ``'DeltaF'`` provides the DeltaF values;
    - all other names are collected in ``features``.

    Parameters
    ----------
    mts : nix MultiTag-like object
        ``mts.features`` is iterated and also indexed by data name.
    dfs, contrasts : optional
        Returned unchanged when no ``DeltaF``/``Contrast`` feature exists. Default: a new empty
        list per call.

    Returns
    -------
    features, dfs, contrasts, id
        ``id`` is the data name of the id feature, or ``[]`` if there is none.
    """
    # None defaults: a shared default list would be handed to (and possibly mutated by) every caller
    if dfs is None:
        dfs = []
    if contrasts is None:
        contrasts = []
    features = []
    feature_id = []
    for feature in mts.features:
        data_name = feature.data.name
        if 'id' in data_name:
            feature_id = data_name
        elif 'Contrast' in data_name:
            contrasts = mts.features[data_name].data[:]
        elif 'DeltaF' in data_name:
            dfs = mts.features[data_name].data[:]
        else:
            features.append(data_name)
    return features, dfs, contrasts, feature_id
def get_most_similiar_spikes(all_spikes, am_corr_cut, beat_cut, error, maxima, spikes_cut, n_best=6):
    """Select the snippets with the smallest ``error``.

    The chosen indices are those whose error is strictly smaller than ``np.sort(error)[n_best]``.
    This gives at most ``n_best`` snippets, or fewer when errors tie at that threshold. When there
    are ``n_best`` or fewer errors, all snippets are chosen. The indices are in ascending order.

    Parameters
    ----------
    all_spikes, am_corr_cut, beat_cut, spikes_cut : sequences
        Indexed with the chosen indices.
    error : array-like
        One error value per snippet.
    maxima :
        Not used; kept for compatibility.
    n_best : int
        Maximal number of snippets to choose (default 6).

    Returns
    -------
    am_final, beat_final, most_similiar, spike, spike_sm
        The selected entries of ``am_corr_cut``, ``beat_cut``, the index array, and the selected
        entries of ``all_spikes`` and ``spikes_cut``.
    """
    error = np.asarray(error)
    if len(error) > n_best:
        threshold = np.sort(error)[n_best]
        most_similiar = np.where(error < threshold)[0]
    else:
        # not enough snippets for a threshold: use all of them
        most_similiar = np.arange(len(error))
    beat_final = [beat_cut[idx] for idx in most_similiar]
    spike_sm = [spikes_cut[idx] for idx in most_similiar]
    spike = [all_spikes[idx] for idx in most_similiar]
    am_final = [am_corr_cut[idx] for idx in most_similiar]
    return am_final, beat_final, most_similiar, spike, spike_sm
def plt_beats_modulation_several_with_overview_nice_big_final3(contrasts_given=[], datasets=['2020-10-20-ad-invivo-1'],
                                                               dfs_all_unique_given=[25], limit=1,
                                                               duration_exclude=0.45, nfft=int(4096), show=False):
    """Plot, per cell, difference frequency and ``mt_type``, one column of panels per contrast.

    A cell is used when it is not one of two hard-coded excluded cells, its ``.nix`` file exists
    and it contains SAM or sine multi-tags. Its rows of ``calc_phaselocking-phaselocking5_big.pkl``
    are used; the older ``calc_phaselocking-phaselocking_big.pkl`` is read once as a fallback for
    cells missing there. For every non-'base' ``mt_type`` one figure is drawn with one column per
    contrast, and each column contains:
    - the stimulus (grey);
    - a spike raster of the first four snippets cut by ``cut_spike_snippets``;
    - their mean smoothed rate;
    - the mean spike-train power spectrum with the peaks marked by ``plt_peaks_several``.
    Each cell's file is closed after it has been processed.

    Parameters
    ----------
    contrasts_given : list
        Contrasts (columns) to plot; empty means all contrasts of the frame.
    datasets : list of str
        Cell names.
    dfs_all_unique_given : list
        Difference frequencies to plot. If empty, all values of ``df_sign`` are used, or of ``df``
        when ``df_sign`` has fewer than two distinct values.
    limit, duration_exclude :
        Only used to build the printed ``save_name``.
    nfft : int
        NFFT of the EOD spectrum, computed when ``get_eod_fr_simple`` requests a redo.
    show : bool
        Passed on to ``save_visualization``.
    """
    save_name = 'beat_results_smoothed_limit' + str(limit) + '_minimalduration_' + str(duration_exclude) + '_all'
    print(save_name)
    frame_all = pd.read_pickle(load_folder_name('calc_phaselocking') + '/calc_phaselocking-phaselocking5_big.pkl')
    # The older frame is only read (once) when a cell is missing in frame_all. It is kept in its
    # own variable so that the following cells are still looked up in frame_all.
    frame_fallback = None
    plot_style()
    cells_exclude = ['2020-10-29-af-invivo-1', '2019-05-07-cb-invivo-1']
    for cell in datasets:
        path = load_folder_name('data') + '/cells/' + cell + '/' + cell + ".nix"
        print(cell)
        df_pos = False
        if (cell in cells_exclude) or (not os.path.exists(path)):
            continue
        print('exists')
        file = nix.File.open(path, nix.FileMode.ReadOnly)
        try:
            b = file.blocks[0]
            # Data arrays of the first sine wave. This list is passed unchanged to
            # get_eod_fr_simple for every multi-tag.
            names = [stims.name for stims in b.data_arrays if 'sinewave-1_Contrast' in stims.name]
            sam = find_mt(b, 'SAM')
            sine = find_mt(b, 'sine')
            if (len(sine) == 0) and (len(sam) == 0):
                continue
            print('cont2')
            frame_cell = frame_all[frame_all['cell'] == cell]
            if len(frame_cell) < 1:
                if frame_fallback is None:
                    frame_fallback = pd.read_pickle('../code/calc_phaselocking/calc_phaselocking-phaselocking_big.pkl')
                frame_cell = frame_fallback[frame_fallback['cell'] == cell]
            if len(frame_cell) == 0:
                continue
            if len(dfs_all_unique_given) < 1:
                dfs_all_unique = list(np.unique(frame_cell.df_sign.dropna())[::-1])
                df_name = 'df_sign'
                df_pos = ''  # 'min_df' restricts the plot to the df closest to zero
                # TODO: this selection still has problems
                if len(np.unique(np.array(dfs_all_unique))) < 2:
                    df_name = 'df'
                    dfs_all_unique = list(np.unique(frame_cell.df.dropna())[::-1])
            else:
                dfs_all_unique = dfs_all_unique_given
                df_name = 'df_sign'
            if len(dfs_all_unique) == 0:
                continue
            if df_pos == 'min_df':
                dfs_all_unique = [dfs_all_unique[np.argmin(np.abs(dfs_all_unique))]]
            for df_chosen in dfs_all_unique:
                if np.isnan(df_chosen):
                    continue
                frame_df = frame_cell[frame_cell[df_name] == df_chosen]
                contrasts_all_unique = np.unique(frame_df.contrast)
                contrasts_all_unique = contrasts_all_unique[~np.isnan(contrasts_all_unique)]
                if len(contrasts_given) > 0:
                    contrasts_all_unique = contrasts_given
                if len(contrasts_all_unique) < 2:
                    continue
                for mt_type in frame_df.mt_type.unique():
                    if 'base' in mt_type:
                        continue
                    frame_type = frame_df[(frame_df.mt_type == mt_type)]
                    default_figsize(column=2, length=3.5)
                    gs0 = gridspec.GridSpec(1, 1, hspace=0.4, bottom=0.18,
                                            left=0.1, top=0.94,
                                            right=0.97)
                    # Number of leading contrast columns to drop. It was meant for
                    # 2020-10-20-ad-invivo-1 at df 50 (the first contrast is missing there),
                    # but it is 0 in all cases.
                    reduce = 0
                    nr_col = int(len(np.unique(contrasts_all_unique))) - reduce
                    grid2 = gridspec.GridSpecFromSubplotSpec(5, nr_col,
                                                             height_ratios=[1, 0.5, 1, 0.1,
                                                                            1, ],
                                                             wspace=0.2, hspace=0.2,
                                                             subplot_spec=gs0[0])
                    axts = []
                    axfs = []
                    axps = []
                    counters = []
                    for mt_name in frame_type.mt_name.unique():
                        frame_name = frame_type[frame_type.mt_name == mt_name]
                        mt_idxs = list(map(int, np.array(frame_name.mt_idx)))
                        mts = b.multi_tags[mt_name]
                        print(mts.name)
                        # contrast as written in the multi-tag name: '...=<contrast>%...'
                        contrast = mts.name.split('=')[1].split('%')[0]
                        print(contrast)
                        # column of this contrast (a NaN contrast never matches)
                        matches = np.where(np.round(contrasts_all_unique, 2) == np.round(float(contrast), 2))[0]
                        if len(matches) == 0:
                            continue
                        counter = matches[0] - reduce
                        counters.append(counter)
                        try:
                            dfs = [mts.metadata[mts.name]['DeltaF']] * len(mts.positions[:])
                        except Exception:
                            dfs = mts.metadata['DeltaF']
                        features, dfs, contrasts, _ = get_features_and_info(mts, dfs=dfs)
                        eod_frs, eod_redo = get_eod_fr_simple(b, names)
                        print(cell + ' Beat calculation')
                        # read positions and extents once instead of on every access
                        positions = mts.positions[:]
                        extents = mts.extents[:]
                        eods_all = []
                        V_1 = []
                        spike_times = []
                        for mt_idx in mt_idxs:
                            try:
                                eods, _ = link_arrays_eod(b, positions[mt_idx], extents[mt_idx], 'LocalEOD-1')
                            except Exception:
                                print('eods thing')
                                embed()
                                # do not continue with the EOD of the previous trial (or an undefined one)
                                raise
                            eods_all.append(eods)
                            # the sampling rate returned for the voltage is used for everything below
                            v_1, sampling_rate = link_arrays_eod(b, positions[mt_idx], extents[mt_idx], 'V-1')
                            V_1.append(v_1)
                            if eod_redo == True:
                                p, f = ml.psd(eods - np.mean(eods),
                                              Fs=sampling_rate,
                                              NFFT=nfft,
                                              noverlap=nfft // 2)
                                eod_fr = f[np.argmax(p)]
                            else:
                                eod_fr = eod_frs[mt_idx]
                            print('EODF' + str(eod_fr))
                            # spike times relative to the start of the trial, in ms
                            spike_times.append(
                                (mts.retrieve_data(mt_idx, 'Spikes-1')[:] - positions[mt_idx]) * 1000)
                        print(len(spike_times))
                        min_extent = np.min(extents)
                        smooth = []
                        spikes_mats = []
                        for s, spikes in enumerate(spike_times):
                            spikes_mat = cr_spikes_mat(spikes / 1000,
                                                       sampling_rate,
                                                       int(extents[mt_idxs[s]] * sampling_rate))
                            spikes_mats.append(spikes_mat)
                            # for the later mean all rates are cut to the shortest segment of the multi-tag
                            smooth.append(gaussian_filter(
                                spikes_mat[0:int(min_extent * sampling_rate)],
                                sigma=0.002 * sampling_rate))
                        skip_nr = 2
                        # dfs is indexed with the index of the last loaded trial
                        df_trial = dfs[mt_idxs[-1]]
                        xlim = [0, 1000 * skip_nr / np.abs(df_trial)]
                        nr_example = 0
                        ##########################################
                        # stimulus
                        axt = plt.subplot(grid2[0, counter])
                        axts.append(axt)
                        stimulus = eods_all[nr_example]
                        # index-based time axis in ms (a float-step np.arange can be one sample too long)
                        time = np.arange(len(stimulus)) / sampling_rate * 1000
                        eods_am, eod_norm = extract_am(stimulus, time, norm=False,
                                                       kind='linear')
                        axt.plot(time, eod_norm, color='grey', linewidth=0.5)
                        am = False
                        if am:
                            axt.plot(time, eods_am, color='red')
                        scatter_extra = False
                        if scatter_extra:
                            axt.scatter(spike_times[nr_example],
                                        np.mean(eod_norm) * np.ones(len(spike_times[nr_example])),
                                        color='black', s=10, marker='|')
                        axt.show_spines('')
                        axt.set_xlim(xlim)
                        axt.set_xlabel('Time [ms]')
                        if counter != 0:
                            remove_yticks(axt)
                            axt.set_ylabel('')
                        axt.text(1, 1, '$c_{' + vary_val() + '}=%s$' % contrast + r'$\%$, ' + r' $\Delta f_{'
                                 + vary_val() + r'}= %s$\,Hz' % int(df_trial), ha='right',
                                 transform=axt.transAxes)
                        ###########################################
                        # spike raster
                        axt = plt.subplot(grid2[1, counter])
                        axt.set_ylabel('local')
                        axt.show_spines('')
                        time = np.arange(len(V_1[nr_example])) / sampling_rate * 1000
                        # A fixed window is used, so there is a shift that advances in very small steps;
                        # a 'period 2' would only hold if the window always had the same length.
                        umstuelp = False
                        if umstuelp:
                            # alternative representation ("umstuelpen") from the susceptibility appendix;
                            # it needs the complete local EOD recording
                            eod = b.data_arrays['LocalEOD-1'][:]
                            spikes_umstuelpen(eod, sampling_rate, time)
                        eods_cut, spikes_cut, times_cut, cut_next, smoothed_cut = cut_spike_snippets(
                            spike_times[nr_example], period_based=True,
                            array_cut2=np.arange(0, len(eods_all[nr_example]) / sampling_rate,
                                                 skip_nr / np.abs(df_trial)),
                            end=2000, smoothened=smooth[nr_example],
                            time_eod=time / 1000, norming=False)
                        # list of arrays: eventplot accepts rows of different length, whereas
                        # np.array() of ragged lists fails with recent numpy
                        axt.eventplot([np.asarray(snippet) * 1000 for snippet in spikes_cut[0:4]],
                                      color='black')
                        axt.set_xlim(xlim)
                        remove_xticks(axt)
                        if counter != 0:
                            remove_yticks(axt)
                            axt.set_ylabel('')
                        axt.show_spines('')
                        axts.append(axt)
                        ###########################################
                        # convolved firing rate
                        axf = plt.subplot(grid2[2, counter])
                        mean_firing = True
                        if mean_firing:
                            # mean over the first four smoothed snippets, cut to the shortest one
                            min_len = np.min([len(sm) for sm in smoothed_cut[0:4]])
                            sms = [sm[0:min_len] for sm in smoothed_cut[0:4]]
                            axf.plot(time[0:min_len], np.mean(sms, axis=0), color='grey')
                        else:
                            axf.plot(time[0:len(smooth[nr_example])], smooth[nr_example], color='grey')
                        axf.show_spines('')
                        axf.set_xlim(xlim)
                        axfs.append(axf)
                        ##########################################
                        # mean power spectrum of the spike trains
                        axp2 = plt.subplot(grid2[4, counter])
                        maxx = 1000
                        nfft_spikes = 2 ** 13
                        ps = []
                        for spikes_mat in spikes_mats:
                            # integer noverlap: a float overlap can break the segment slicing of mlab
                            p, f = ml.psd(spikes_mat - np.mean(spikes_mat),
                                          Fs=sampling_rate,
                                          NFFT=nfft_spikes,
                                          noverlap=nfft_spikes // 2)
                            ps.append(p)
                        axp2.set_xlim(0, maxx)
                        pp = np.mean(ps, axis=0)
                        axp2.plot(f, pp, color='black', zorder=2, linestyle='-')
                        eodf = np.mean(frame_name.eod_fr)
                        peak_names = ['0', '01', '02', '012']
                        names_here = [peak_names[1]]
                        extend = False
                        labels, _alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept()
                        colors_array = ['pink', color01]
                        # name passed to chose_all_freq_combos: '01' above 2 % contrast, 'eodf' otherwise
                        if float(contrast) > 2:
                            peak_name = names_here[0]
                        else:
                            peak_name = 'eodf'
                        freqs, colors_peaks, labels, alphas = chose_all_freq_combos(
                            [],
                            colors_array,
                            df_chosen,
                            maxx,
                            eodf,
                            color_eodf=coloer_eod_fr_core(),
                            name=peak_name,
                            color_stim=color_stim_core(),
                            color_stim_mult=color_stim_core())
                        plt_peaks_several(freqs, [pp], axp2, pp, f, labels, 0,
                                          colors_peaks, limit=1200,
                                          perc_peaksize=0.15,
                                          alphas=alphas, extend=extend, ms=18,
                                          clip_on=False)
                        legend_here = False
                        if legend_here:
                            if (counter == 2) & (peak_name != 'eodf'):
                                handles, labels = axp2.get_legend_handles_labels()
                                reorder_legend_handles(axp2,
                                                       order=[len(labels) - 3,
                                                              len(labels) - 2,
                                                              len(labels) - 1],
                                                       loc=(-2.5, 1), fs=9,
                                                       handlelength=1, ncol=3)
                        axp2.set_xlabel('Frequency [Hz]')
                        if counter != 0:
                            remove_yticks(axp2)
                        else:
                            axp2.set_ylabel(power_spectrum_name())
                        axps.append(axp2)
                        #############################
                        isis = False
                        if isis:
                            axi = plt.subplot(grid2[-1, counter])
                            plt_isis_phaselocking(axi, frame_name,
                                                  spike_times)
                            axi.set_xticks_delta(2)
                            axi.set_xlim(0, 13)
                    # stimulus axes and raster axes alternate in axts
                    axts[0].get_shared_y_axes().join(*axts[0::2])
                    axts[1].get_shared_y_axes().join(*axts[1::2])
                    axts[0].get_shared_x_axes().join(*axts)
                    join_y(axfs)
                    join_y(axps)
                    join_x(axps)
                    fig = plt.gcf()
                    fig.tag([axts[4], axts[2], axts[0], ], xoffs=-2, yoffs=1)
                    firing_rate_scalebars(axfs[np.where(np.array(counters) == 0)[0][0]],
                                          length=10)
                    individual_tag = 'data ' + cell + '_DF_chosen_' + str(
                        df_chosen) + mt_type
                    save_visualization(individual_tag, show)
                    print('plotted')
        finally:
            file.close()
    print('finished examples')
def reorder_legend_handles(ax1, order=[0, 2, 4, 1, 3, 5], ncol=None, rev=False, loc=(0.65, 0.6), fs=9,
                           handlelength=0.5):
    """Redraw the legend of ``ax1`` with its entries picked and reordered by ``order``.

    Parameters
    ----------
    ax1 : matplotlib axes whose current legend entries are reordered.
    order : list of int
        Positions (into the current handles/labels) in the wanted order.
    ncol : int, optional
        Number of legend columns; only passed on when truthy.
    rev : bool
        If True, only the first three entries of ``order`` are used, each counted from the end
        (``len(labels) - order[i]``).
    loc, handlelength :
        Passed on to ``ax1.legend``.
    fs : optional
        Legend font size; only passed on when truthy.

    Returns
    -------
    The legend returned by ``ax1.legend``.
    """
    handles, labels = ax1.get_legend_handles_labels()
    if rev:
        n_labels = len(labels)
        order = [n_labels - order[0], n_labels - order[1], n_labels - order[2]]
    legend_kwargs = {
        'handles': [handles[pos] for pos in order],
        'labels': [labels[pos] for pos in order],
        'loc': loc,
        'handlelength': handlelength,
    }
    if fs:
        legend_kwargs['fontsize'] = fs
    if ncol:
        legend_kwargs['ncol'] = ncol
    return ax1.legend(**legend_kwargs)
def plt_beats_modulation_several_with_overview_nice_big_final2(contrasts_given=[], datasets=['2020-10-20-ad-invivo-1'],
                                                               dfs_all_unique_given=[25], limit=1,
                                                               duration_exclude=0.45, nfft=int(4096), show=False):
    """Plot beat responses of recorded cells, one column per stimulus contrast.

    For every cell in ``datasets`` (except a hard-coded exclusion list) whose
    ``.nix`` file exists and contains SAM or sine multi-tags, the phase-locking
    table is filtered for that cell. One figure is produced per difference
    frequency (DF) and multi-tag type. Its rows are:

    - the local EOD of an example trial;
    - a spike raster of cut snippets;
    - the mean smoothed firing rate;
    - the mean spike-train power spectrum with labelled peaks;
    - an ISI histogram.

    Each figure is stored via ``save_visualization``.

    Parameters
    ----------
    contrasts_given : list
        If non-empty, these contrasts define the columns instead of the
        contrasts found in the table.
    datasets : list of str
        Cell identifiers to process.
    dfs_all_unique_given : list
        DFs to plot; if empty, all DFs of the cell found in the table are used.
    limit, duration_exclude : float
        Only used to build the printed ``save_name``.
    nfft : int
        FFT length used when the EOD frequency has to be re-estimated.
    show : bool
        Passed to ``save_visualization``.

    Notes
    -----
    Drops into an IPython shell (``embed()``) when finished and on several
    error paths.
    """
    # Function to load the experimental data
    save_name = 'beat_results_smoothed_limit' + str(limit) + '_minimalduration_' + str(duration_exclude) + '_all'
    print(save_name)
    frame_all = pd.read_pickle(load_folder_name('calc_phaselocking') + '/calc_phaselocking-phaselocking5_big.pkl')
    plot_style()
    for i, cell in enumerate(datasets):
        path = load_folder_name('data') + '/cells/' + cell + '/' + cell + ".nix"  # cell.split(os.path.sep)[-1]
        print(cell)
        cells_exclude = ['2020-10-29-af-invivo-1', '2019-05-07-cb-invivo-1']
        df_pos = False
        if cell not in cells_exclude:
            if os.path.exists(path):
                print('exists')
                file = nix.File.open(path, nix.FileMode.ReadOnly)
                b = file.blocks[0]
                cont2 = False
                names = []
                names_dataarrays = []
                for stims in b.data_arrays:
                    # Restricting to 'sinewave-1' seems reasonable: the only data with a higher
                    # sinewave number was not useful (this only concerns 2019-05-07-ab and
                    # 2019-05-07-ac, and on closer inspection those look like odd trials we do not need).
                    if 'sinewave-1_Contrast' in stims.name:
                        names.append(stims.name)
                    names_dataarrays.append(stims.name)
                # Only cells with SAM or sine multi-tags are plotted.
                sam = find_mt(b, 'SAM')
                sine = find_mt(b, 'sine')
                if (len(sine) > 0) or (len(sam) > 0):
                    cont2 = True
                if cont2 == True:
                    print('cont2')
                    counter = 0
                    frame_cell = frame_all[frame_all['cell'] == cell]
                    if len(frame_cell) < 1:
                        # Fall back to the older phase-locking table when the cell is missing.
                        frame_all = pd.read_pickle('../code/calc_phaselocking/calc_phaselocking-phaselocking_big.pkl')
                        frame_cell = frame_all[frame_all['cell'] == cell]
                    if len(frame_cell) > 0:
                        if len(dfs_all_unique_given) < 1:
                            dfs_all_unique = np.unique(frame_cell.df_sign.dropna())[::-1]
                            df_name = 'df_sign'
                            df_pos = ''  # 'min_df'
                            dfs_all_unique = list(dfs_all_unique)
                            # TODO: there are still problems here
                            if len(np.unique(np.array(dfs_all_unique))) < 2:
                                df_name = 'df'
                                dfs_all_unique = np.unique(frame_cell.df.dropna())[::-1]
                                dfs_all_unique = list(dfs_all_unique)
                        else:
                            dfs_all_unique = dfs_all_unique_given
                            df_name = 'df_sign'
                        if len(dfs_all_unique) > 0:
                            if df_pos == 'min_df':
                                try:
                                    dfs_all_unique = [dfs_all_unique[np.argmin(np.abs(dfs_all_unique))]]
                                except:
                                    print('df min')
                                    embed()
                            for df_chosen in dfs_all_unique:
                                if not np.isnan(df_chosen):
                                    frame_df = frame_cell[frame_cell[df_name] == df_chosen]
                                    contrasts_all_unique = np.unique(frame_df.contrast)
                                    contrasts_all_unique = contrasts_all_unique[~np.isnan(contrasts_all_unique)]
                                    if len(contrasts_given) > 0:
                                        contrasts_all_unique = contrasts_given
                                    if len(contrasts_all_unique) > 1:
                                        mt_types = frame_df.mt_type.unique()
                                        for mt_type in mt_types:
                                            if 'base' not in mt_type:
                                                contrasts_here = []
                                                frame_type = frame_df[
                                                    (frame_df.mt_type == mt_type)]  # | (frame_df.mt_type == 'base')
                                                default_settings(column=2, length=5)
                                                gs0 = gridspec.GridSpec(1, 1, hspace=0.4,
                                                                        left=0.1, top=0.94,
                                                                        right=0.97)  # width_ratios=[4, 1],
                                                # Column offset. Both former branches (special case for
                                                # 2020-10-20-ad-invivo-1 at DF 50, where the first contrast
                                                # is missing for unknown reasons) used 0.
                                                reduce = 0
                                                nr_col = int(len(np.unique(contrasts_all_unique))) - reduce
                                                # Rows: local EOD, spike raster, firing rate, spacer,
                                                # power spectrum, spacer, ISI histogram.
                                                grid2 = gridspec.GridSpecFromSubplotSpec(7, nr_col,
                                                                                         height_ratios=[1, 0.5, 1, 0.1,
                                                                                                        1, 0.7,
                                                                                                        1],
                                                                                         wspace=0.2, hspace=0.2,
                                                                                         subplot_spec=gs0[0])
                                                axts = []
                                                axfs = []
                                                axps = []
                                                mt_names = frame_type.mt_name.unique()
                                                counters = []
                                                for m, mt_name in enumerate(mt_names):
                                                    frame_name = frame_type[frame_type.mt_name == mt_name]
                                                    mt_idxs = list(map(int, np.array(frame_name.mt_idx)))
                                                    mts = b.multi_tags[mt_name]
                                                    print(mts.name)
                                                    name = mts.name
                                                    # The contrast is parsed from the multi-tag name ('...=<c>%...').
                                                    contrast = name.split('=')[1].split('%')[0]
                                                    if contrast not in contrasts_here:
                                                        print(contrast)
                                                        if len(np.where(np.round(contrasts_all_unique, 2) == np.round(
                                                                float(contrast), 2))[0]) > 0:
                                                            # The column index is the position of this contrast.
                                                            if np.isnan(float(contrast)):
                                                                counter = 0
                                                            else:
                                                                try:
                                                                    counter = np.where(
                                                                        np.round(contrasts_all_unique, 2) == np.round(
                                                                            float(contrast), 2))[0][0] - reduce  # +1
                                                                except:
                                                                    print('something')
                                                                    embed()
                                                            counters.append(counter)
                                                            try:
                                                                dfs = [mts.metadata[mts.name]['DeltaF']] * len(
                                                                    mts.positions[:])
                                                            except:
                                                                dfs = mts.metadata['DeltaF']
                                                            features, dfs, contrasts, _ = get_features_and_info(mts,
                                                                                                                dfs=dfs)
                                                            eod_frs, eod_redo = get_eod_fr_simple(b, names)
                                                            eod = b.data_arrays['LocalEOD-1'][:]
                                                            names = []
                                                            for stims in b.data_arrays:
                                                                names.append(stims.name)
                                                            print(cell + ' Beat calculation')
                                                            eods_all = []
                                                            eods_all_g = []
                                                            V_1 = []
                                                            spike_times = []
                                                            for m in mt_idxs:  # range(len(mts.positions[:]))
                                                                try:
                                                                    eods, _ = link_arrays_eod(b, mts.positions[:][m],
                                                                                              mts.extents[:][m],
                                                                                              'LocalEOD-1')
                                                                except:
                                                                    print('eods thing')
                                                                    embed()
                                                                eods_all.append(eods)
                                                                eods_g, sampling_rate = link_arrays_eod(b,
                                                                                                        mts.positions[:][m],
                                                                                                        mts.extents[:][m],
                                                                                                        'EOD')
                                                                v_1, sampling_rate = link_arrays_eod(b,
                                                                                                     mts.positions[:][m],
                                                                                                     mts.extents[:][m],
                                                                                                     'V-1')
                                                                eods_all_g.append(eods_g)
                                                                V_1.append(v_1)
                                                                if eod_redo == True:
                                                                    # Re-estimate the EOD frequency as the PSD peak.
                                                                    p, f = ml.psd(eods - np.mean(eods),
                                                                                  Fs=sampling_rate,
                                                                                  NFFT=nfft,
                                                                                  noverlap=nfft // 2)
                                                                    eod_fr = f[np.argmax(p)]
                                                                else:
                                                                    eod_fr = eod_frs[m]
                                                                print('EODF' + str(eod_fr))
                                                                # Spike times relative to the tag start, scaled by 1000 (ms).
                                                                spike_times.append(
                                                                    (mts.retrieve_data(m, 'Spikes-1')[:] -
                                                                     mts.positions[m]) * 1000)  # - cut
                                                            print(len(spike_times))
                                                            smooth = []
                                                            spikes_mats = []
                                                            for s in range(len(spike_times)):
                                                                try:
                                                                    spikes_mat = cr_spikes_mat(spike_times[s] / 1000,
                                                                                               sampling_rate,
                                                                                               int(mts.extents[:][mt_idxs[s]] * sampling_rate))  # time[-1] * sampling_rate
                                                                except:
                                                                    print('mts prob')
                                                                    embed()
                                                                spikes_mats.append(spikes_mat)
                                                                # For the later mean, every trial is cut to the shortest one.
                                                                try:
                                                                    smooth.append(gaussian_filter(
                                                                        spikes_mat[0:int(np.min(mts.extents[:]) * sampling_rate)],
                                                                        sigma=0.002 * sampling_rate))
                                                                except:
                                                                    print('embed problem')
                                                                    embed()
                                                            skip_nr = 2
                                                            # Window of skip_nr * 1000 / |DF| on the time axis (ms).
                                                            # NOTE: `m` is the last mt_idx from the loop above.
                                                            xlim = [0, 1000 * skip_nr / np.abs(dfs[m])]
                                                            nr_example = 0  # 'no'#0
                                                            ##########################################
                                                            # Row 0: local EOD of the example trial
                                                            try:
                                                                axt = plt.subplot(grid2[0, counter])
                                                            except:
                                                                print('axt something')
                                                                embed()
                                                            axts.append(axt)
                                                            stimulus = eods_all[nr_example]  # eods_g + Efield
                                                            try:
                                                                time = np.arange(0, len(stimulus) / sampling_rate,
                                                                                 1 / sampling_rate) * 1000
                                                            except:
                                                                print('time all2')
                                                                embed()
                                                            eods_am, eod_norm = extract_am(stimulus, time, norm=False,
                                                                                           kind='linear')
                                                            axt.plot(time, eod_norm, color='grey', linewidth=0.5)
                                                            am = False
                                                            if am:
                                                                axt.plot(time, eods_am, color='red')
                                                            scatter_extra = False
                                                            if scatter_extra:
                                                                axt.scatter(spike_times[nr_example],
                                                                            np.mean(eod_norm) * np.ones(
                                                                                len(spike_times[nr_example])),
                                                                            color='black', s=10,
                                                                            marker='|')  # np.max(v1)*lineoffsets=np.max(V_1[nr_example]),
                                                            axt.show_spines('')
                                                            axt.set_xlim(xlim)
                                                            axt.set_xlabel('Time [ms]')
                                                            if counter != 0:
                                                                remove_yticks(axt)
                                                                axt.set_ylabel('')
                                                            # Raw string: same text as before, without the invalid '\%' escape.
                                                            axt.text(1, 1, '$c=%s$' % contrast + r'$\%$', ha='right',
                                                                     transform=axt.transAxes)
                                                            ###########################################
                                                            # Row 1: spike raster of cut snippets
                                                            axt = plt.subplot(grid2[1, counter])
                                                            axt.set_ylabel('local')
                                                            axt.show_spines('')
                                                            time = np.arange(0, len(V_1[nr_example]) / sampling_rate,
                                                                             1 / sampling_rate) * 1000
                                                            # A fixed window is used, so there is a shift that advances in
                                                            # very small steps; the period-2 variant would apply if the
                                                            # window always had the same length.
                                                            umstuelp = False
                                                            if umstuelp:
                                                                # There is also the 'umstuelpen' (folding) variant from the
                                                                # susceptibility analysis, for the appendix.
                                                                spikes_umstuelpen(eod, sampling_rate, time)
                                                            eods_cut, spikes_cut, times_cut, cut_next, smoothed_cut = cut_spike_snippets(
                                                                spike_times[nr_example], period_based=True,
                                                                array_cut2=np.arange(0, len(eods_all[nr_example]) / sampling_rate,
                                                                                     skip_nr / np.abs(dfs[m])),
                                                                end=2000, smoothened=smooth[nr_example],
                                                                time_eod=time / 1000, norming=False)
                                                            # Snippets can hold different numbers of spikes, so scale each one
                                                            # separately instead of building a ragged np.array.
                                                            axt.eventplot([np.asarray(sp) * 1000 for sp in spikes_cut[0:4]],
                                                                          color='black')  # lineoffsets=np.max(V_1[nr_example])* np.ones(
                                                            axt.set_xlim(xlim)
                                                            remove_xticks(axt)
                                                            if counter != 0:
                                                                remove_yticks(axt)
                                                                axt.set_ylabel('')
                                                            axt.show_spines('')
                                                            axts.append(axt)
                                                            # Row 2: convolved firing rate
                                                            axf = plt.subplot(grid2[2, counter])
                                                            if len(smooth[nr_example]) != len(time):
                                                                time_here = time[0:len(smooth[nr_example])]
                                                            else:
                                                                time_here = time  # [0:len(smooth[nr_example])]
                                                            mean_firing = True
                                                            if mean_firing:
                                                                # Mean of the first four smoothed snippets, cut to the shortest.
                                                                shortest = np.min([len(sm) for sm in smoothed_cut[0:4]])
                                                                sms = [sm[0:shortest] for sm in smoothed_cut[0:4]]
                                                                time_here = time[0:np.shape(sms)[1]]
                                                                axf.plot(time_here, np.mean(sms, axis=0), color='grey',
                                                                         linewidth=0.5)
                                                            else:
                                                                axf.plot(time_here, smooth[nr_example], color='grey',
                                                                         linewidth=0.5)
                                                            axf.show_spines('')
                                                            axf.set_xlim(xlim)
                                                            axfs.append(axf)
                                                            ##########################################
                                                            # Row 4: mean power spectrum of all spike trains
                                                            axp2 = plt.subplot(grid2[4, counter])
                                                            ps = []
                                                            maxx = 1000
                                                            for s, spikes_mat in enumerate(spikes_mats):
                                                                # noverlap must be an integer (2 ** 13 / 2 is a float).
                                                                p, f = ml.psd(spikes_mat - np.mean(spikes_mat),
                                                                              Fs=sampling_rate,
                                                                              NFFT=2 ** 13,
                                                                              noverlap=2 ** 13 // 2)
                                                                ps.append(p)
                                                            axp2.set_xlim(0, maxx)
                                                            axp2.plot(f, np.mean(ps, axis=0), color='black', zorder=2,
                                                                      linestyle='-')
                                                            pp = np.mean(ps, axis=0)
                                                            eodf = np.mean(frame_name.eod_fr)
                                                            # NOTE(review): this rebinds `names`, which is passed to
                                                            # get_eod_fr_simple for the next multi-tag. Looks accidental.
                                                            names = ['0', '01', '02', '012']
                                                            names_here = [names[1]]  #
                                                            extend = True
                                                            colors_array = ['pink', 'green']
                                                            # Peaks of the names_here[0] combination are labelled only for
                                                            # contrasts above 2; otherwise only 'eodf'.
                                                            if float(contrast) > 2:
                                                                name = names_here[0]
                                                            else:
                                                                name = 'eodf'
                                                            freqs, colors_peaks, labels, alphas = chose_all_freq_combos(
                                                                [],
                                                                colors_array,
                                                                df_chosen,
                                                                maxx,
                                                                eodf,
                                                                color_eodf='black',
                                                                name=name,
                                                                color_stim='grey',
                                                                color_stim_mult='grey')
                                                            plt_peaks_several(freqs, [pp], axp2, pp, f, labels, 0,
                                                                              colors_peaks, limit=1200,
                                                                              perc_peaksize=0.15,
                                                                              alphas=alphas, extend=extend, ms=18,
                                                                              clip_on=False)
                                                            axp2.set_xlabel('Frequency [Hz]')
                                                            if counter != 0:
                                                                remove_yticks(axp2)
                                                            else:
                                                                axp2.set_ylabel(power_spectrum_name())
                                                            axps.append(axp2)
                                                            #############################
                                                            # Last row: ISI histogram of all trials
                                                            axi = plt.subplot(grid2[-1, counter])
                                                            plt_isis_phaselocking(axi, frame_name, spike_times)
                                                            axi.set_xticks_delta(2)
                                                            axi.set_xlim(0, 13)
                                                # Share axes across columns. Even axts entries are EOD rows,
                                                # odd entries are raster rows.
                                                try:
                                                    axts[0].get_shared_y_axes().join(*axts[0::2])
                                                except:
                                                    print('axt problem')
                                                    embed()
                                                axts[1].get_shared_y_axes().join(*axts[1::2])
                                                axts[0].get_shared_x_axes().join(*axts)
                                                join_y(axfs)
                                                join_y(axps)
                                                join_x(axps)
                                                fig = plt.gcf()
                                                fig.tag([axts[4], axts[2], axts[0], ], xoffs=-3)
                                                firing_rate_scalebars(axfs[np.where(np.array(counters) == 0)[0][0]],
                                                                      length=10)
                                                individual_tag = 'data ' + cell + '_DF_chosen_' + str(
                                                    df_chosen) + mt_type
                                                save_visualization(individual_tag, show)
                                                print('plotted')
                file.close()
    print('finished examples')
    embed()
def plt_beats_modulation_several_with_overview_nice_big_final(datasets=['2020-10-20-ad-invivo-1'],
                                                              dfs_all_unique_given=[25], limit=1,
                                                              duration_exclude=0.45, nfft=int(4096), show=False):
    """Plot beat responses per contrast next to phase-locking score summaries.

    If ``datasets`` is empty, the cells returned by ``find_all_dir_cells`` are
    used, sorted in descending order and passed through ``find_stop_cell``
    with a hard-coded stop cell.

    For every usable cell, DF and multi-tag type (excluding 'base', 'chirp'
    and 'SAM DC-1'), a 30x8 inch figure is created:

    - Right column: ``plt_single_phaselockloss`` for two groups of scores.
    - Left grid: one column per contrast, with the local EOD (and its AM)
      with spikes, the spikes over the voltage range, the mean smoothed
      firing rate, the spike-train power spectrum and an ISI histogram.

    Parameters
    ----------
    datasets : list of str
        Cell identifiers; empty means "all cells" (see above).
    dfs_all_unique_given : list
        DFs to plot; if empty, all DFs of the cell found in the table are used.
    limit, duration_exclude : float
        Only used to build the printed ``save_name``.
    nfft : int
        FFT length for the optional EOD PSD.
    show : bool
        Passed to ``save_visualization``.

    Notes
    -----
    Drops into an IPython shell (``embed()``) when finished and on several
    error paths.
    """
    # Function to load the experimental data
    save_name = 'beat_results_smoothed_limit' + str(limit) + '_minimalduration_' + str(duration_exclude) + '_all'
    print(save_name)
    frame_all = pd.read_pickle(load_folder_name('calc_phaselocking') + '/calc_phaselocking-phaselocking5_big.pkl')
    colors = ['red', 'green', 'purple', 'blue']
    try:
        plot_style()
    except:
        print('plotstyle not there')
    if len(datasets) < 1:
        datasets, data_dir = find_all_dir_cells()
        datasets = np.sort(datasets)[::-1]
        stop_cell = '2018-11-20-af-invivo-1'
        datasets = find_stop_cell(datasets, stop_cell)
    for i, cell in enumerate(datasets):
        path = load_folder_name('data') + '/cells/' + cell + '/' + cell + ".nix"  # cell.split(os.path.sep)[-1]
        print(cell)
        cells_exclude = ['2020-10-29-af-invivo-1', '2019-05-07-cb-invivo-1']
        df_pos = False
        if cell not in cells_exclude:
            if os.path.exists(path):
                print('exists')
                try:
                    file = nix.File.open(path, nix.FileMode.ReadOnly)
                    cont0 = True
                except:
                    cont0 = False
                if cont0:
                    b = file.blocks[0]
                    cont2 = False
                    names = []
                    names_dataarrays = []
                    for stims in b.data_arrays:
                        # Restricting to 'sinewave-1' seems reasonable: the only data with a higher
                        # sinewave number was not useful (this only concerns 2019-05-07-ab and
                        # 2019-05-07-ac, and on closer inspection those look like odd trials we do not need).
                        if 'sinewave-1_Contrast' in stims.name:
                            names.append(stims.name)
                        names_dataarrays.append(stims.name)
                    # Only cells with SAM or sine multi-tags are plotted.
                    sam = find_mt(b, 'SAM')
                    sine = find_mt(b, 'sine')
                    if (len(sine) > 0) or (len(sam) > 0):
                        cont2 = True
                    if cont2 == True:
                        print('cont2')
                        frame_cell = frame_all[frame_all['cell'] == cell]
                        if len(frame_cell) < 1:
                            # If table 5 is not finished yet, use the backup from before.
                            frame_all = pd.read_pickle(
                                '../code/calc_phaselocking/calc_phaselocking-phaselocking_big.pkl')
                            frame_cell = frame_all[frame_all['cell'] == cell]
                        if len(frame_cell) > 0:
                            if len(dfs_all_unique_given) < 1:
                                dfs_all_unique = np.unique(frame_cell.df_sign.dropna())[::-1]
                                df_name = 'df_sign'
                                df_pos = ''  # 'min_df'
                                dfs_all_unique = list(dfs_all_unique)
                                # TODO: there are still problems here
                                if len(np.unique(np.array(dfs_all_unique))) < 2:
                                    df_name = 'df'
                                    dfs_all_unique = np.unique(frame_cell.df.dropna())[::-1]
                                    dfs_all_unique = list(dfs_all_unique)
                            else:
                                dfs_all_unique = dfs_all_unique_given
                                df_name = 'df_sign'
                            if len(dfs_all_unique) > 0:
                                if df_pos == 'min_df':
                                    try:
                                        dfs_all_unique = [dfs_all_unique[np.argmin(np.abs(dfs_all_unique))]]
                                    except:
                                        print('df min')
                                        embed()
                                for df_chosen in dfs_all_unique:
                                    if not np.isnan(df_chosen):
                                        frame_df = frame_cell[frame_cell[df_name] == df_chosen]
                                        contrasts_all_unique = np.unique(frame_df.contrast)
                                        contrasts_all_unique = contrasts_all_unique[~np.isnan(contrasts_all_unique)]
                                        if len(contrasts_all_unique) > 0:
                                            mt_types = frame_df.mt_type.unique()
                                            for mt_type in mt_types:
                                                if ('base' not in mt_type) & ('chirp' not in mt_type) & (
                                                        'SAM DC-1' not in mt_type):
                                                    contrasts_here = []
                                                    frame_type = frame_df[
                                                        (frame_df.mt_type == mt_type)]  # | (frame_df.mt_type == 'base')
                                                    default_settings(column=2, length=5)
                                                    plt.figure(figsize=(30, 8))
                                                    gs0 = gridspec.GridSpec(1, 2, width_ratios=[4, 1], hspace=0.4,
                                                                            left=0.1,
                                                                            right=0.97)  #
                                                    grid1 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.2,
                                                                                             hspace=0.2,
                                                                                             subplot_spec=gs0[1])
                                                    # Right column, top panel. `axs` was previously used here
                                                    # before being assigned (UnboundLocalError).
                                                    axs = plt.subplot(grid1[0])
                                                    scores = ['amp_stim', 'amp_df', 'amp_f0',
                                                              'amp_fmax_interval']  # 'stim', 'f0',
                                                    plt_single_phaselockloss(colors, frame_cell, df_chosen, scores,
                                                                             cell,
                                                                             axs)
                                                    axs.set_xlim(-10, 100)
                                                    # Right column, bottom panel.
                                                    axs = plt.subplot(grid1[1])
                                                    scores = ['dsp_perc95_', 'dsp_max_', 'dsp_mean_']
                                                    if scores[0] in frame_cell.keys():
                                                        plt_single_phaselockloss(colors, frame_cell, df_chosen, scores,
                                                                                 cell, axs)
                                                        axs.set_xlim(-10, 100)
                                                    # Column offset. Both former branches (special case for
                                                    # 2020-10-20-ad-invivo-1 at DF 50, where the first contrast
                                                    # is missing for unknown reasons) used 0.
                                                    reduce = 0
                                                    nr_col = int(len(np.unique(contrasts_all_unique))) - reduce
                                                    # Rows: EOD, spikes, firing rate, spacer, power spectrum,
                                                    # spacer, ISI histogram.
                                                    grid2 = gridspec.GridSpecFromSubplotSpec(7, nr_col,
                                                                                             height_ratios=[1, 0.5, 1,
                                                                                                            0.1, 1, 0.7,
                                                                                                            1],
                                                                                             wspace=0.2, hspace=0.2,
                                                                                             subplot_spec=gs0[0])
                                                    axts = []
                                                    axfs = []
                                                    axps = []
                                                    mt_names = frame_type.mt_name.unique()
                                                    for m, mt_name in enumerate(mt_names):
                                                        frame_name = frame_type[frame_type.mt_name == mt_name]
                                                        mt_idxs = list(map(int, np.array(frame_name.mt_idx)))
                                                        mts = b.multi_tags[mt_name]
                                                        print(mts.name)
                                                        name = mts.name
                                                        # The contrast is parsed from the multi-tag name ('...=<c>%...').
                                                        try:
                                                            contrast = name.split('=')[1].split('%')[0]
                                                            cont3 = True
                                                        except:
                                                            cont3 = False
                                                            print('contrasts')
                                                        if cont3:
                                                            if contrast not in contrasts_here:
                                                                print(contrast)
                                                                # The column index is the position of this contrast.
                                                                if np.isnan(float(contrast)):
                                                                    counter = 0
                                                                else:
                                                                    counter = np.where(
                                                                        np.round(contrasts_all_unique, 2) == np.round(
                                                                            float(contrast), 2))[0][0] - reduce  # +1
                                                                try:
                                                                    contrasts_here.append(contrast)
                                                                except:
                                                                    print('embed problem')
                                                                    embed()
                                                                # Feature names other than id/Contrast/DeltaF (unused below).
                                                                features = [feat.data.name for feat in mts.features
                                                                            if not any(key in feat.data.name
                                                                                       for key in ('id', 'Contrast', 'DeltaF'))]
                                                                eod_frs, eod_redo = get_eod_fr_simple(b, names)
                                                                eod = b.data_arrays['LocalEOD-1'][:]
                                                                names = []
                                                                for stims in b.data_arrays:
                                                                    names.append(stims.name)
                                                                print(cell + ' Beat calculation')
                                                                eods_all = []
                                                                eods_all_g = []
                                                                V_1 = []
                                                                spike_times = []
                                                                for m in mt_idxs:  # range(len(mts.positions[:]))
                                                                    try:
                                                                        eods, _ = link_arrays_eod(b,
                                                                                                  mts.positions[:][m],
                                                                                                  mts.extents[:][m],
                                                                                                  'LocalEOD-1')
                                                                    except:
                                                                        print('eods thing')
                                                                        embed()
                                                                    eods_all.append(eods)
                                                                    eods_g, sampling_rate = link_arrays_eod(b,
                                                                                                            mts.positions[:][m],
                                                                                                            mts.extents[:][m],
                                                                                                            'EOD')
                                                                    v_1, sampling_rate = link_arrays_eod(b,
                                                                                                         mts.positions[:][m],
                                                                                                         mts.extents[:][m],
                                                                                                         'V-1')
                                                                    eods_all_g.append(eods_g)
                                                                    V_1.append(v_1)
                                                                    if eod_redo == True:
                                                                        p, f = ml.psd(eods - np.mean(eods),
                                                                                      Fs=sampling_rate,
                                                                                      NFFT=nfft,
                                                                                      noverlap=nfft // 2)
                                                                    # Spike times relative to the tag start, scaled by 1000 (ms).
                                                                    spikes = link_arrays_spikes(b,
                                                                                                first=mts.positions[:][m],
                                                                                                second=mts.extents[:][m],
                                                                                                minus_spikes=mts.positions[:][m]) * 1000
                                                                    spike_times.append(spikes)  # - cut#
                                                                print(len(spike_times))
                                                                smooth = []
                                                                spikes_mats = []
                                                                for s in range(len(spike_times)):
                                                                    try:
                                                                        spikes_mat = cr_spikes_mat(
                                                                            spike_times[s] / 1000,
                                                                            sampling_rate,
                                                                            int(mts.extents[:][mt_idxs[s]] * sampling_rate))  # time[-1] * sampling_rate
                                                                    except:
                                                                        print('mts prob')
                                                                        embed()
                                                                    spikes_mats.append(spikes_mat)
                                                                    # For the later mean, every trial is cut to the shortest one.
                                                                    try:
                                                                        smooth.append(gaussian_filter(
                                                                            spikes_mat[0:int(np.min(mts.extents[:]) * sampling_rate)],
                                                                            sigma=0.0005 * sampling_rate))
                                                                    except:
                                                                        print('embed problem')
                                                                        embed()
                                                                try:
                                                                    smooth_mean = np.mean(smooth, axis=0)
                                                                except:
                                                                    print('smoothed thing')
                                                                    embed()
                                                                plt.suptitle('data ' + cell + ' ' + mts.name)
                                                                xlim = [0, 40]
                                                                nr_example = 0  # 'no'#0
                                                                ##########################################
                                                                # Row 0: local EOD, its AM and the spikes of the example trial
                                                                try:
                                                                    axt = plt.subplot(grid2[0, counter])
                                                                except:
                                                                    print('axt something')
                                                                    embed()
                                                                axts.append(axt)
                                                                stimulus = eods_all[nr_example]  # eods_g + Efield
                                                                try:
                                                                    time = np.arange(0, len(stimulus) / sampling_rate,
                                                                                     1 / sampling_rate) * 1000
                                                                except:
                                                                    print('time all2')
                                                                    embed()
                                                                eods_am, eod_norm = extract_am(stimulus, time,
                                                                                               norm=False,
                                                                                               kind='linear')  # 'cubic'
                                                                axt.plot(time, eod_norm, color='grey', linewidth=0.5)
                                                                axt.plot(time, eods_am, color='red')
                                                                axt.scatter(spike_times[nr_example],
                                                                            np.mean(eod_norm) * np.ones(
                                                                                len(spike_times[nr_example])),
                                                                            color='black', s=10,
                                                                            marker='|')  # np.max(v1)*lineoffsets=np.max(V_1[nr_example]),
                                                                try:
                                                                    axt.show_spines('')
                                                                except:
                                                                    print('not there')
                                                                axt.set_xlim(xlim)
                                                                axt.set_xlabel('Time [ms]')
                                                                if counter != 0:
                                                                    remove_yticks(axt)
                                                                    axt.set_ylabel('')
                                                                axt.text(1, 1, 'c=' + str(contrast), ha='right',
                                                                         transform=axt.transAxes)
                                                                ###########################################
                                                                # Row 1: spikes drawn at the maximum of the voltage trace
                                                                axt = plt.subplot(grid2[1, counter])
                                                                axt.set_ylabel('local')
                                                                try:
                                                                    axt.show_spines('')
                                                                except:
                                                                    print('not there')
                                                                time = np.arange(0,
                                                                                 len(V_1[nr_example]) / sampling_rate,
                                                                                 1 / sampling_rate) * 1000
                                                                # A fixed window is used, so there is a shift that advances in
                                                                # very small steps; the period-2 variant would apply if the
                                                                # window always had the same length.
                                                                umstuelp = False
                                                                if umstuelp:
                                                                    # There is also the 'umstuelpen' (folding) variant from the
                                                                    # susceptibility analysis, for the appendix.
                                                                    spikes_umstuelpen(eod, sampling_rate, time)
                                                                axt.scatter(spike_times[nr_example],
                                                                            np.max(V_1[nr_example]) * np.ones(
                                                                                len(spike_times[nr_example])),
                                                                            color='black', s=10,
                                                                            marker='|')  # np.max(v1)*lineoffsets=np.max(V_1[nr_example]),
                                                                axt.set_xlim(xlim)
                                                                remove_xticks(axt)
                                                                if counter != 0:
                                                                    remove_yticks(axt)
                                                                    axt.set_ylabel('')
                                                                try:
                                                                    axt.show_spines('')
                                                                except:
                                                                    print('not there')
                                                                axts.append(axt)
                                                                # Row 2: convolved firing rate
                                                                axf = plt.subplot(grid2[2, counter])
                                                                if len(smooth[nr_example]) != len(time):
                                                                    time_here = time[0:len(smooth[nr_example])]
                                                                else:
                                                                    time_here = time  # [0:len(smooth[nr_example])]
                                                                mean_firing = True  # smooth_mean
                                                                if mean_firing:
                                                                    axf.plot(time_here, smooth_mean,
                                                                             color='grey', linewidth=0.5)
                                                                else:
                                                                    axf.plot(time_here, smooth[nr_example],
                                                                             color='grey', linewidth=0.5)
                                                                # Hide the spines of the firing-rate axes (the original
                                                                # called this on `axt` again by mistake).
                                                                try:
                                                                    axf.show_spines('')
                                                                except:
                                                                    print('not there')
                                                                axf.set_xlim(xlim)
                                                                axfs.append(axf)
                                                                if counter == 0:
                                                                    firing_rate_scalebars(axf)
                                                                ##########################################
                                                                # Row 4: mean power spectrum of all spike trains
                                                                axp2 = plt.subplot(grid2[4, counter])
                                                                ps = []
                                                                maxx = 1000
                                                                for s, spikes_mat in enumerate(spikes_mats):
                                                                    # noverlap must be an integer (2 ** 13 / 2 is a float).
                                                                    p, f = ml.psd(spikes_mat - np.mean(spikes_mat),
                                                                                  Fs=sampling_rate,
                                                                                  NFFT=2 ** 13,
                                                                                  noverlap=2 ** 13 // 2)
                                                                    ps.append(p)
                                                                axp2.set_xlim(0, maxx)
                                                                axp2.plot(f, np.mean(ps, axis=0), color='black',
                                                                          zorder=2,
                                                                          linestyle='-')
                                                                pp = np.mean(ps, axis=0)
                                                                eodf = np.mean(frame_name.eod_fr)
                                                                # NOTE(review): this rebinds `names`, which is passed to
                                                                # get_eod_fr_simple for the next multi-tag. Looks accidental.
                                                                names = ['0', '01', '02', '012']
                                                                names_here = [names[1]]  #
                                                                extend = True
                                                                colors_array = ['pink', 'green']
                                                                # `contrast` is a string parsed from the tag name; compare it
                                                                # numerically (str > int raised a TypeError).
                                                                if float(contrast) > 1:
                                                                    name = names_here[0]
                                                                else:
                                                                    name = 'eodf'
                                                                freqs, colors_peaks, labels, alphas = chose_all_freq_combos(
                                                                    [],
                                                                    colors_array,
                                                                    np.abs(df_chosen),
                                                                    maxx,
                                                                    eodf,
                                                                    color_eodf='black',
                                                                    name=name,
                                                                    color_stim='pink',
                                                                    color_stim_mult='pink')
                                                                plt_peaks_several(freqs, [pp], axp2, pp, f, labels, 0,
                                                                                  colors_peaks, limit=1200,
                                                                                  alphas=alphas, extend=extend, ms=18,
                                                                                  clip_on=False)
                                                                axp2.set_xlabel('Frequency [Hz]')
                                                                if counter != 0:
                                                                    remove_yticks(axp2)
                                                                else:
                                                                    axp2.set_ylabel(power_spectrum_name())
                                                                axps.append(axp2)
                                                                #############################
                                                                # Last row: ISI histogram of all trials
                                                                axi = plt.subplot(grid2[-1, counter])
                                                                plt_isis_phaselocking(axi, frame_name,
                                                                                      spike_times)
                                                    if len(axts) > 0:
                                                        # Even axts entries are EOD rows, odd entries are spike rows.
                                                        try:
                                                            axts[0].get_shared_y_axes().join(*axts[0::2])
                                                        except:
                                                            print('axt problem')
                                                            embed()
                                                        axts[1].get_shared_y_axes().join(*axts[1::2])
                                                        axts[0].get_shared_x_axes().join(*axts)
                                                        join_y(axfs)
                                                        join_y(axps)
                                                        join_x(axps)
                                                    individual_tag = 'data_' + cell + '_DF_chosen_' + str(
                                                        df_chosen) + mt_type
                                                    save_visualization(individual_tag, show)
                                                    print('plotted')
                    file.close()
    print('finished examples')
    embed()
def plt_isis_phaselocking(axi, frame_name, spike_times):
    """Plot a histogram of the inter-spike intervals of all trials on ``axi``.

    Parameters
    ----------
    axi : Axes
        Axes to draw into.
    frame_name : pandas.DataFrame
        One row per trial. ``frame_name.eod_fr.iloc[k]`` is passed to
        ``calc_isi`` together with trial ``k``.
    spike_times : sequence of array-like
        Spike times per trial, divided by 1000 before ``calc_isi`` (callers
        pass values scaled by 1000, i.e. ms). Trials may hold different
        numbers of spikes.
    """
    # Iterate the trials directly: np.array() on trials of unequal length
    # builds a ragged array, which NumPy >= 1.24 rejects with a ValueError.
    isis = [calc_isi(np.asarray(sp) / 1000, frame_name.eod_fr.iloc[sp_nr])
            for sp_nr, sp in enumerate(spike_times)]
    if isis:
        # np.concatenate raises on an empty list, so skip the histogram without trials.
        axi.hist(np.concatenate(isis), bins=100, color='grey')
    axi.axvline(1, color='black', linestyle='--', linewidth=0.5)
    try:
        axi.show_spines('b')
    except AttributeError:
        # show_spines is not a matplotlib method; it is missing when the
        # plotstyle helpers are not installed.
        pass
    axi.set_xlabel(isi_xlabel())
def firing_rate_scalebars(axt, length=10, rate_length=500):
    """Draw a time scale bar (ms) and a firing-rate scale bar (Hz) on ``axt``.

    Parameters
    ----------
    axt : Axes
        Axes to decorate. ``xscalebar``/``yscalebar`` are not standard
        matplotlib methods; presumably they come from the plotstyle helpers
        imported at the top of the file.
    length : float
        Length of the horizontal bar, labelled 'ms'.
    rate_length : float
        Height of the vertical bar, labelled 'Hz' (previously hard-coded to 500).

    Drawing is best effort. Any error, for example missing scale-bar
    methods, is ignored so that the figure can still be saved.
    """
    try:
        # NOTE(review): va/ha look swapped for the x bar ('right' is not a
        # vertical alignment). Kept as is; TODO confirm with plotstyle.
        axt.xscalebar(0.1, -0.02, length, 'ms', va='right', ha='bottom')  ##ylim[0]
        axt.yscalebar(-0.02, 0.1, rate_length, 'Hz', va='bottom', ha='left')
    except Exception:
        # Best effort: skip the scale bars instead of failing the figure, but
        # no longer swallow KeyboardInterrupt/SystemExit like the bare except did.
        pass
def spikes_umstuelpen(eod, sampling_rate, time):
    """Unfinished draft: fold ('umstuelpen') spike trains onto shifted beat windows.

    The previous body referenced names that are neither parameters nor
    defined in the function (``beat``, ``len_smoothed_b``, ``smoothed``,
    ``plot_segment``, ``tranformed_spikes``, ``version``). It also passed
    ``len_smoothed`` to ``create_shifted_spikes`` before assigning it, so every
    call ended in ``NameError``/``UnboundLocalError``. The call sites in the
    neighbouring plotting functions guard it with a hard-coded
    ``umstuelp = False``.

    Intended pipeline of the draft, kept for whoever completes it:

    1. ``shifts = np.arange(0, 200 * 0.005, 0.005)``.
    2. ``time_b = np.arange(0, len(beat) / sampling_rate, 1 / sampling_rate)``.
    3. ``am_corr = extract_am(beat, time_b, eodf=eod, norm=False,
       extract='globalmax', kind='linear')[0]``.
    4. ``create_shifted_spikes(eod, len_smoothed_b, len_smoothed, beat, am_corr,
       sampling_rate, time_b, time, smoothed, shifts, plot_segment,
       tranformed_spikes, version=version)``, returning
       ``len_smoothed, smoothed_trial, all_spikes, maxima, error, spikes_cut,
       beat_cut, am_corr_cut``.
    5. ``get_most_similiar_spikes(all_spikes, am_corr_cut, beat_cut, error,
       maxima, spikes_cut)``. The result was discarded.

    Raises
    ------
    NotImplementedError
        Always, instead of the previous opaque NameError/UnboundLocalError.
    """
    raise NotImplementedError(
        'spikes_umstuelpen is an unfinished draft: it needs the beat, smoothed rate and '
        'further inputs (len_smoothed_b, plot_segment, tranformed_spikes, version) that '
        'are not passed in')
def plt_beats_modulation_several_with_overview_nice_big_max(limit=1,
                                                            duration_exclude=0.45, nfft=int(4096), show=False):
    """Plot raw traces, spectra and ISIs per contrast for cells with large SAM contrasts.

    Cells are read from ``find_contrasts_SAMs-SAM_amplitudes.csv`` (rows with
    ``contrast`` or ``contrast_true`` above 25), reversed, and appended to a
    short hand-picked list. For every usable cell, DF with ``|DF| < 75`` and
    non-'base' multi-tag type, a figure is saved as pdf:

    - Right column: ``plt_single_phaselockloss`` score panels.
    - Left grid: per contrast column, the V-1 trace with spikes, the local
      EOD with its AM and spikes, the PSD of the example trial with labelled
      peaks, all trial PSDs with their mean, and an ISI histogram.

    Parameters
    ----------
    limit, duration_exclude : float
        Only used to build the printed ``save_name``.
    nfft : int
        FFT length for the optional EOD PSD.
    show : bool
        Passed to ``save_visualization``.

    Notes
    -----
    Drops into an IPython shell (``embed()``) when finished and on several
    error paths.
    """
    # Function to load the experimental data
    save_name = 'beat_results_smoothed_limit' + str(limit) + '_minimalduration_' + str(duration_exclude) + '_all'
    print(save_name)
    datas_new = []
    old_cells = True
    if old_cells:
        # This branch is for examining the old data sets.
        _, _ = find_all_dir_cells()
        frame_desired = pd.read_csv('../code/calc_base/find_contrasts_SAMs-SAM_amplitudes.csv')
        big_adapt = True
        if big_adapt:
            frame_big = frame_desired[(frame_desired.contrast > 25) | (frame_desired.contrast_true > 25)]
        else:
            frame_big = frame_desired  # [(frame_desired.contrast > 5) | (frame_desired.contrast_true > 5)]
        datasets = frame_big.cell.unique()
        datasets_loaded = datasets[::-1]
    else:
        frame = pd.read_pickle(
            load_folder_name(
                'calc_cocktailparty') + '/calc_data_peaks_threewave-spikes_all_psdEOD_1_nfft_16384[05,original]_psdEOD__sqrt__points_5_ALL_.pkl')
        datasets_loaded = np.sort(frame.cell.unique())[::-1]
    datasets = ['2020-10-20-ad-invivo-1', '2020-10-27-ac-invivo-1', '2020-10-29-ai-invivo-1', '2018-09-13-aa-invivo-1',
                '2020-10-29-ac-invivo-1']  # [,'2020-10-29-ai-invivo-1',]
    datasets.extend(datasets_loaded)
    frame_all = pd.read_pickle('../code/calc_phaselocking/calc_phaselocking-phaselocking5_big.pkl')
    colors = ['red', 'green', 'purple', 'blue']
    for i, cell in enumerate(datasets):
        path = '../data/cells/' + cell + '/' + cell + ".nix"  # cell.split(os.path.sep)[-1]
        print(cell)
        cells_exclude = ['2020-10-29-af-invivo-1', '2019-05-07-cb-invivo-1']
        if cell not in cells_exclude:
            if os.path.exists(path):
                print('exists')
                file = nix.File.open(path, nix.FileMode.ReadOnly)
                b = file.blocks[0]
                cont2 = False
                names = []
                names_dataarrays = []
                for stims in b.data_arrays:
                    # Restricting to 'sinewave-1' seems reasonable: the only data with a higher
                    # sinewave number was not useful (this only concerns 2019-05-07-ab and
                    # 2019-05-07-ac, and on closer inspection those look like odd trials we do not need).
                    if 'sinewave-1_Contrast' in stims.name:
                        names.append(stims.name)
                    names_dataarrays.append(stims.name)
                # Only cells with SAM or sine multi-tags are plotted.
                sam = find_mt(b, 'SAM')
                sine = find_mt(b, 'sine')
                if (len(sine) > 0) or (len(sam) > 0):
                    cont2 = True
                test = False
                if test:
                    from utils_test import test_rlx
                    test_rlx()
                if cont2 == True:
                    print('cont2')
                    frame_cell = frame_all[frame_all['cell'] == cell]
                    if len(frame_cell) < 1:
                        # Fall back to the older phase-locking table when the cell is missing.
                        frame_all = pd.read_pickle('../code/calc_phaselocking/calc_phaselocking-phaselocking_big.pkl')
                        frame_cell = frame_all[frame_all['cell'] == cell]
                    if len(frame_cell) > 0:
                        dfs_all_unique = np.unique(frame_cell.df_sign.dropna())[::-1]
                        df_name = 'df_sign'
                        df_pos = ''  # 'min_df'
                        dfs_all_unique = list(dfs_all_unique)
                        if len(np.unique(np.array(dfs_all_unique))) < 2:
                            df_name = 'df'
                            dfs_all_unique = np.unique(frame_cell.df.dropna())[::-1]
                            dfs_all_unique = list(dfs_all_unique)
                        if len(dfs_all_unique) > 0:
                            if df_pos == 'min_df':
                                try:
                                    dfs_all_unique = [dfs_all_unique[np.argmin(np.abs(dfs_all_unique))]]
                                except:
                                    print('df min')
                                    embed()
                            contrasts_all_unique = np.unique(frame_cell.contrast)
                            if len(contrasts_all_unique) > 1:
                                for df_chosen in dfs_all_unique:
                                    if np.abs(df_chosen) < 75:
                                        if not np.isnan(df_chosen):
                                            frame_df = frame_cell[frame_cell[df_name] == df_chosen]
                                            mt_types = frame_df.mt_type.unique()
                                            for mt_type in mt_types:
                                                if 'base' not in mt_type:
                                                    contrasts_here = []
                                                    frame_type = frame_df[
                                                        (frame_df.mt_type == mt_type)]  # | (frame_df.mt_type == 'base')
                                                    gs0 = gridspec.GridSpec(1, 2, width_ratios=[4, 1], hspace=0.4,
                                                                            left=0.045,
                                                                            right=0.97)  #
                                                    grid1 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.2,
                                                                                             hspace=0.2,
                                                                                             subplot_spec=gs0[1])
                                                    # Right column, top panel. `axs` was previously used here
                                                    # before being assigned (UnboundLocalError).
                                                    axs = plt.subplot(grid1[0])
                                                    scores = ['amp_stim', 'amp_df', 'amp_f0',
                                                              'amp_fmax_interval']  # 'stim', 'f0',
                                                    plt_single_phaselockloss(colors, frame_cell, df_chosen, scores,
                                                                             cell,
                                                                             axs)
                                                    axs.set_xlim(-10, 100)
                                                    # Right column, bottom panel.
                                                    axs = plt.subplot(grid1[1])
                                                    scores = ['dsp_perc95_', 'dsp_max_', 'dsp_mean_']
                                                    if scores[0] in frame_cell.keys():
                                                        plt_single_phaselockloss(colors, frame_cell, df_chosen, scores,
                                                                                 cell, axs)
                                                        axs.set_xlim(-10, 100)
                                                    nr_col = int(len(np.unique(contrasts_all_unique)) - 1)
                                                    # Rows: V-1 trace, local EOD, spacer, example PSD, all PSDs,
                                                    # ISI histogram.
                                                    grid2 = gridspec.GridSpecFromSubplotSpec(6, nr_col,
                                                                                             height_ratios=[1, 1, 0.5,
                                                                                                            1, 1,
                                                                                                            1],
                                                                                             wspace=0.2, hspace=0.2,
                                                                                             subplot_spec=gs0[0])
                                                    axts = []
                                                    axps = []
                                                    mt_names = frame_type.mt_name.unique()
                                                    for m, mt_name in enumerate(mt_names):
                                                        frame_name = frame_type[frame_type.mt_name == mt_name]
                                                        mt_idxs = list(map(int, np.array(frame_name.mt_idx)))
                                                        mts = b.multi_tags[mt_name]
                                                        print(mts.name)
                                                        name = mts.name
                                                        # The contrast is parsed from the multi-tag name ('...=<c>%...').
                                                        contrast = name.split('=')[1].split('%')[0]
                                                        if contrast not in contrasts_here:
                                                            print(contrast)
                                                            # The column index is the position of this contrast.
                                                            if np.isnan(float(contrast)):
                                                                counter = 0
                                                            else:
                                                                counter = np.where(
                                                                    np.round(contrasts_all_unique, 2) == np.round(
                                                                        float(contrast), 2))[
                                                                    0][0]  # +1
                                                            try:
                                                                contrasts_here.append(contrast)
                                                            except:
                                                                print('embed problem')
                                                                embed()
                                                            try:
                                                                dfs = [mts.metadata[mts.name]['DeltaF']] * len(
                                                                    mts.positions[:])
                                                            except:
                                                                dfs = mts.metadata['DeltaF']
                                                            # `contrasts` is an output of this call. Passing
                                                            # contrasts=contrasts read it before assignment
                                                            # (UnboundLocalError).
                                                            features, dfs, contrasts, _ = get_features_and_info(mts,
                                                                                                                dfs=dfs)
                                                            eod_frs, eod_redo = get_eod_fr_simple(b, names)
                                                            names = []
                                                            for stims in b.data_arrays:
                                                                names.append(stims.name)
                                                            print(cell + ' Beat calculation')
                                                            datas_new.append(cell)
                                                            eods_all = []
                                                            eods_all_g = []
                                                            V_1 = []
                                                            spike_times = []
                                                            for m in mt_idxs:  # range(len(mts.positions[:]))
                                                                try:
                                                                    eods, _ = link_arrays_eod(b, mts.positions[:][m],
                                                                                              mts.extents[:][m],
                                                                                              'LocalEOD-1')
                                                                except:
                                                                    print('eods thing')
                                                                    embed()
                                                                eods_all.append(eods)
                                                                eods_g, sampling_rate = link_arrays_eod(b,
                                                                                                        mts.positions[:][m],
                                                                                                        mts.extents[:][m],
                                                                                                        'EOD')
                                                                v_1, sampling_rate = link_arrays_eod(b,
                                                                                                     mts.positions[:][m],
                                                                                                     mts.extents[:][m],
                                                                                                     'V-1')
                                                                eods_all_g.append(eods_g)
                                                                V_1.append(v_1)
                                                                if eod_redo == True:
                                                                    p, f = ml.psd(eods - np.mean(eods),
                                                                                  Fs=sampling_rate,
                                                                                  NFFT=nfft,
                                                                                  noverlap=nfft // 2)
                                                                # Spike times relative to the tag start, scaled by 1000 (ms).
                                                                spike_times.append(
                                                                    (mts.retrieve_data(m, 'Spikes-1')[:] -
                                                                     mts.positions[m]) * 1000)  # - cut
                                                            print(len(spike_times))
                                                            smooth = []
                                                            spikes_mats = []
                                                            for s in range(len(spike_times)):
                                                                try:
                                                                    spikes_mat = cr_spikes_mat(spike_times[s] / 1000,
                                                                                               sampling_rate,
                                                                                               int(mts.extents[:][mt_idxs[s]] * sampling_rate))  # time[-1] * sampling_rate
                                                                except:
                                                                    print('mts prob')
                                                                    embed()
                                                                spikes_mats.append(spikes_mat)
                                                                # For the later mean, every trial is cut to the shortest one.
                                                                try:
                                                                    smooth.append(gaussian_filter(
                                                                        spikes_mat[0:int(np.min(mts.extents[:]) * sampling_rate)],
                                                                        sigma=0.0005 * sampling_rate))
                                                                except:
                                                                    print('embed problem')
                                                                    embed()
                                                            plt.suptitle('data ' + cell + ' ' + mts.name)
                                                            xlim = []
                                                            nr_example = 0
                                                            ##########################################
                                                            # Rows 3 and 4: PSD of the example trial (with peak labels)
                                                            # and PSDs of all trials with their mean
                                                            axp = plt.subplot(grid2[3, counter])
                                                            axp2 = plt.subplot(grid2[4, counter])
                                                            ps = []
                                                            maxx = 1000
                                                            for s, spikes_mat in enumerate(spikes_mats):
                                                                # noverlap must be an integer (2 ** 13 / 2 is a float).
                                                                p, f = ml.psd(spikes_mat - np.mean(spikes_mat),
                                                                              Fs=sampling_rate,
                                                                              NFFT=2 ** 13,
                                                                              noverlap=2 ** 13 // 2)
                                                                ps.append(p)
                                                                if s == nr_example:
                                                                    color = 'purple'
                                                                    zorder = 100
                                                                    axp.plot(f, p, color=color, zorder=zorder)
                                                                    eodf = np.mean(frame_name.eod_fr)
                                                                    # NOTE(review): this rebinds `names`, which is passed
                                                                    # to get_eod_fr_simple for the next multi-tag.
                                                                    names = ['0', '01', '02', '012']
                                                                    names_here = [names[1]]  #
                                                                    extend = True
                                                                    colors_array = ['pink', 'green']
                                                                    freqs, colors_peaks, labels, alphas = chose_all_freq_combos(
                                                                        [],
                                                                        colors_array,
                                                                        df_chosen,
                                                                        maxx,
                                                                        eodf,
                                                                        color_eodf='black',
                                                                        name=names_here[0],
                                                                        color_stim='pink',
                                                                        color_stim_mult='pink')
                                                                    plt_peaks_several(freqs, [p], axp, p, f, labels, 0,
                                                                                      colors_peaks, alphas=alphas,
                                                                                      extend=extend, ms=18,
                                                                                      clip_on=True)
                                                                else:
                                                                    color = 'grey'
                                                                    zorder = 1
                                                                axp2.plot(f, p, color=color, zorder=zorder)
                                                            axp2.set_xlim(0, maxx)
                                                            axp.set_xlim(0, maxx)
                                                            remove_xticks(axp)
                                                            axp2.plot(f, np.mean(ps, axis=0), color='black', zorder=2,
                                                                      linestyle='--')
                                                            # The x axis of a PSD is frequency (the label used to read 'Power').
                                                            axp2.set_xlabel('Frequency [Hz]')
                                                            if counter != 0:
                                                                remove_yticks(axp2)
                                                                axp2.set_ylabel('')
                                                            if counter != 0:
                                                                remove_yticks(axp)
                                                                axp.set_ylabel('')
                                                            axps.append(axp)
                                                            axps.append(axp2)
                                                            ###########################################
                                                            # Row 0: V-1 trace of the example trial with its spikes
                                                            stimulus = eods_all[nr_example]  # eods_g + Efield
                                                            axt = plt.subplot(grid2[0, counter])
                                                            axt.set_ylabel('local')
                                                            time = np.arange(0, len(V_1[nr_example]) / sampling_rate,
                                                                             1 / sampling_rate) * 1000
                                                            axt.plot(time, V_1[nr_example], color='purple',
                                                                     linewidth=0.5)
                                                            axt.scatter(spike_times[nr_example],
                                                                        np.max(V_1[nr_example]) * np.ones(
                                                                            len(spike_times[nr_example])),
                                                                        color='black', s=10,
                                                                        marker='|')  # np.max(v1)*lineoffsets=np.max(V_1[nr_example]),
                                                            if len(xlim) > 0:
                                                                axt.set_xlim(xlim)
                                                            axt.set_title(contrast)
                                                            remove_xticks(axt)
                                                            if counter != 0:
                                                                remove_yticks(axt)
                                                                axt.set_ylabel('')
                                                            axts.append(axt)
                                                            # Row 1: local EOD, its AM and the spikes
                                                            axt = plt.subplot(grid2[1, counter])
                                                            axts.append(axt)
                                                            try:
                                                                time = np.arange(0, len(stimulus) / sampling_rate,
                                                                                 1 / sampling_rate) * 1000
                                                            except:
                                                                print('time all')
                                                                embed()
                                                            eods_am, eod_norm = extract_am(stimulus, time, norm=False)
                                                            axt.plot(time, eod_norm, color='grey', linewidth=0.5)
                                                            axt.plot(time, eods_am, color='red')
                                                            axt.scatter(spike_times[nr_example],
                                                                        np.mean(eod_norm) * np.ones(
                                                                            len(spike_times[nr_example])),
                                                                        color='black', s=10,
                                                                        marker='|')  # np.max(v1)*lineoffsets=np.max(V_1[nr_example]),
                                                            if len(xlim) > 0:
                                                                axt.set_xlim(xlim)
                                                            axt.set_xlabel('Time [ms]')
                                                            if counter != 0:
                                                                remove_yticks(axt)
                                                                axt.set_ylabel('')
                                                            #############################
                                                            # Last row: ISI histogram of all trials. Iterate the list
                                                            # directly; np.array() on trials of unequal length is a
                                                            # ragged array that recent NumPy rejects.
                                                            axi = plt.subplot(grid2[-1, counter])
                                                            isis = []
                                                            for sp_nr, sp in enumerate(spike_times):
                                                                isis.append(
                                                                    calc_isi(sp / 1000, frame_name.eod_fr.iloc[sp_nr]))
                                                            axi.hist(np.concatenate(isis), bins=100)
                                                            axi.axvline(1, color='grey', linestyle='--')
                                                    # Even axts entries are V-1 rows, odd entries are EOD rows.
                                                    try:
                                                        axts[0].get_shared_y_axes().join(*axts[0::2])
                                                    except:
                                                        print('axt problem')
                                                        embed()
                                                    axts[1].get_shared_y_axes().join(*axts[1::2])
                                                    axts[0].get_shared_x_axes().join(*axts)
                                                    join_y(axps)
                                                    join_x(axps)
                                                    individual_tag = 'data ' + cell + '_DF_chosen_' + str(
                                                        df_chosen) + mt_type
                                                    save_visualization(individual_tag, show, pdf=True)
                                                    print('plotted')
                file.close()
    print('finished examples')
    embed()
def plt_beats_modulation_several_with_overview_nice_big(limit=1,
                                                         duration_exclude=0.45, nfft=int(4096), show=False):
    """Run ``plt_spectra_compar`` on recordings that contain large-contrast stimuli.

    The cell list comes from ``find_contrasts_SAMs-SAM_amplitudes.csv``,
    restricted to rows whose ``contrast`` or ``contrast_true`` exceeds 25.
    ``limit`` and ``duration_exclude`` only enter the printed result name.
    The function ends by opening an interactive ``embed()`` shell.
    """
    print('beat_results_smoothed_limit' + str(limit) + '_minimalduration_' + str(duration_exclude) + '_all')
    plotted_cells = []
    use_old_recordings = True  # analyse the older recordings listed in the contrast table
    if not use_old_recordings:
        peaks_frame = pd.read_pickle(
            load_folder_name(
                'calc_cocktailparty') + '/calc_data_peaks_threewave-spikes_all_psdEOD_1_nfft_16384[05,original]_psdEOD__sqrt__points_5_ALL_.pkl')
        cell_list = ['2020-10-29-ai-invivo-1', '2020-10-20-ad-invivo-1', '2018-09-13-aa-invivo-1',
                     '2020-10-29-ac-invivo-1']
        cell_list.extend(np.sort(peaks_frame.cell.unique())[::-1])
    else:
        _, _ = find_all_dir_cells()
        contrast_table = pd.read_csv('../code/calc_base/find_contrasts_SAMs-SAM_amplitudes.csv')
        only_large_contrasts = True
        if only_large_contrasts:
            contrast_table = contrast_table[(contrast_table.contrast > 25) | (contrast_table.contrast_true > 25)]
        cell_list = contrast_table.cell.unique()
    plt_spectra_compar(plotted_cells, cell_list, nfft, show,
                       add='plt_beats_modulation_several_with_overview_nice_big')
    print('finished examples')
    embed()
def plt_beats_modulation_several_with_overview_nice(limit=1,
                                                     duration_exclude=0.45, nfft=int(4096), show=False):
    """Run ``plt_spectra_compar`` on all recordings listed in the SAM contrast table.

    The cells are taken from ``find_contrasts_SAMs-SAM_amplitudes.csv``
    (no contrast restriction, because the large-contrast switch is off).
    ``limit`` and ``duration_exclude`` only enter the printed result name.
    The function ends by opening an interactive ``embed()`` shell.
    """
    print('beat_results_smoothed_limit' + str(limit) + '_minimalduration_' + str(duration_exclude) + '_all')
    plotted_cells = []
    use_old_recordings = True  # analyse the older recordings listed in the contrast table
    if not use_old_recordings:
        peaks_frame = pd.read_pickle(
            load_folder_name(
                'calc_cocktailparty') + '/calc_data_peaks_threewave-spikes_all_psdEOD_1_nfft_16384[05,original]_psdEOD__sqrt__points_5_ALL_.pkl')
        cell_list = ['2018-09-13-aa-invivo-1', '2020-10-20-ad-invivo-1']
        cell_list.extend(np.sort(peaks_frame.cell.unique())[::-1])
    else:
        _, _ = find_all_dir_cells()
        contrast_table = pd.read_csv('../code/calc_base/find_contrasts_SAMs-SAM_amplitudes.csv')
        only_large_contrasts = False
        if only_large_contrasts:
            contrast_table = contrast_table[(contrast_table.contrast > 29) | (contrast_table.contrast_true > 29)]
        cell_list = contrast_table.cell.unique()
    plt_spectra_compar(plotted_cells, cell_list, nfft, show, add='plt_beats_modulation_several_with_overview_nice')
    print('finished examples')
    embed()
def _spectra_compar_load_trials(b, mts, mt_idxs):
    """Load local EOD, membrane voltage and spike times for the given tag indices.

    Returns ``(eods_all, V_1, spike_times, sampling_rate)``. Spike times are
    shifted to the tag position and multiplied by 1000. ``sampling_rate`` is
    the value reported for the last ``'V-1'`` segment, or ``None`` when
    ``mt_idxs`` is empty.
    """
    positions = mts.positions[:]
    extents = mts.extents[:]
    eods_all = []
    V_1 = []
    spike_times = []
    sampling_rate = None
    for mt_idx in mt_idxs:
        eods, _ = link_arrays_eod(b, positions[mt_idx], extents[mt_idx], 'LocalEOD-1')
        v_1, sampling_rate = link_arrays_eod(b, positions[mt_idx], extents[mt_idx], 'V-1')
        eods_all.append(eods)
        V_1.append(v_1)
        spike_times.append((mts.retrieve_data(mt_idx, 'Spikes-1')[:] - positions[mt_idx]) * 1000)
    return eods_all, V_1, spike_times, sampling_rate


def _spectra_compar_plot_column(grid2, counter, contrast, cell, mts, frame_name, df_chosen, eods_all, V_1,
                                spike_times, spikes_mats, sampling_rate, axts, axps):
    """Draw one contrast column: voltage, EOD/AM, two spike-train PSD rows and an ISI histogram.

    The created time axes are appended to ``axts`` and the PSD axes to ``axps``.
    """
    plt.suptitle('data ' + cell + ' ' + mts.name)
    xlim = [0, 40]
    nr_example = 0  # trial highlighted in the upper PSD row and in the time traces
    maxx = 1000
    ##########################################
    # power spectra of the spike trains
    axp = plt.subplot(grid2[3, counter])
    axp2 = plt.subplot(grid2[4, counter])
    ps = []
    for s, spikes_mat in enumerate(spikes_mats):
        p, f = ml.psd(spikes_mat - np.mean(spikes_mat),
                      Fs=sampling_rate,
                      NFFT=2 ** 13,
                      noverlap=2 ** 13 // 2)  # integer overlap as required by mlab.psd
        ps.append(p)
        if s == nr_example:
            color = 'purple'
            zorder = 100
            axp.plot(f, p, color=color, zorder=zorder)
            eodf = np.mean(frame_name.eod_fr)
            # peak markers for the '01' frequency combination
            freqs, colors_peaks, labels, alphas = chose_all_freq_combos([], ['pink', 'green'], df_chosen, maxx,
                                                                        eodf, color_eodf='black', name='01',
                                                                        color_stim='pink',
                                                                        color_stim_mult='pink')
            plt_peaks_several(freqs, [p], axp, p, f, labels, 0, colors_peaks, alphas=alphas,
                              extend=True, ms=18, clip_on=True)
        else:
            color = 'grey'
            zorder = 1
        axp2.plot(f, p, color=color, zorder=zorder)
    axp2.set_xlim(0, maxx)
    axp.set_xlim(0, maxx)
    remove_xticks(axp)
    axp2.plot(f, np.mean(ps, axis=0), color='black', zorder=2, linestyle='--')
    axp2.set_xlabel('Power [Hz]')
    if counter != 0:
        remove_yticks(axp2)
        axp2.set_ylabel('')
        remove_yticks(axp)
        axp.set_ylabel('')
    axps.append(axp)
    axps.append(axp2)
    ###########################################
    # membrane voltage with spike marks
    stimulus = eods_all[nr_example]
    axt = plt.subplot(grid2[0, counter])
    axt.set_ylabel('local')
    # integer sample index: a float-step np.arange can yield one sample too many
    time_ms = np.arange(len(V_1[nr_example])) / sampling_rate * 1000
    axt.plot(time_ms, V_1[nr_example], color='purple', linewidth=0.5)
    axt.scatter(spike_times[nr_example],
                np.max(V_1[nr_example]) * np.ones(len(spike_times[nr_example])),
                color='black', s=10, marker='|')
    axt.set_xlim(xlim)
    axt.set_title(contrast)
    remove_xticks(axt)
    if counter != 0:
        remove_yticks(axt)
        axt.set_ylabel('')
    axts.append(axt)
    ###########################################
    # local EOD with its extracted amplitude modulation
    axt = plt.subplot(grid2[1, counter])
    axts.append(axt)
    time_ms = np.arange(len(stimulus)) / sampling_rate * 1000
    eods_am, eod_norm = extract_am(stimulus, time_ms, norm=False)
    axt.plot(time_ms, eod_norm, color='grey', linewidth=0.5)
    axt.plot(time_ms, eods_am, color='red')
    axt.scatter(spike_times[nr_example],
                np.mean(eod_norm) * np.ones(len(spike_times[nr_example])),
                color='black', s=10, marker='|')
    axt.set_xlim(xlim)
    axt.set_xlabel('Time [ms]')
    if counter != 0:
        remove_yticks(axt)
        axt.set_ylabel('')
    #############################
    # ISI histogram pooled over trials (iterate the ragged list directly;
    # np.array() on unequal-length trials raises in recent NumPy)
    axi = plt.subplot(grid2[-1, counter])
    isis = [calc_isi(sp / 1000, frame_name.eod_fr.iloc[sp_nr]) for sp_nr, sp in enumerate(spike_times)]
    axi.hist(np.concatenate(isis), bins=100)
    axi.axvline(1, color='grey', linestyle='--')


def _spectra_compar_plot_type(b, cell, frame_cell, frame_df, mt_type, df_chosen, contrasts_all_unique,
                              colors, datas_new, show, add):
    """Build and save the figure for one beat frequency ``df_chosen`` and one stimulus type ``mt_type``."""
    frame_type = frame_df[(frame_df.mt_type == mt_type)]
    gs0 = gridspec.GridSpec(1, 2, width_ratios=[4, 1], hspace=0.4, left=0.045, right=0.97)
    grid1 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.2, hspace=0.2, subplot_spec=gs0[1])
    # phase-locking summaries in the right column (the upper axis was previously never created)
    axs = plt.subplot(grid1[0])
    plt_single_phaselockloss(colors, frame_cell, df_chosen, ['amp_stim', 'amp_df', 'amp_f0', 'amp_fmax_interval'],
                             cell, axs)
    axs.set_xlim(-10, 100)
    axs = plt.subplot(grid1[1])
    scores = ['dsp_perc95_', 'dsp_max_', 'dsp_mean_']
    if scores[0] in frame_cell.keys():
        plt_single_phaselockloss(colors, frame_cell, df_chosen, scores, cell, axs)
    axs.set_xlim(-10, 100)
    nr_col = int(len(np.unique(contrasts_all_unique)) - 1)
    grid2 = gridspec.GridSpecFromSubplotSpec(6, nr_col, height_ratios=[1, 1, 0.5, 1, 1, 1],
                                             wspace=0.2, hspace=0.2, subplot_spec=gs0[0])
    axts = []
    axps = []
    contrasts_here = []
    for mt_name in frame_type.mt_name.unique():
        frame_name = frame_type[frame_type.mt_name == mt_name]
        mt_idxs = list(map(int, np.array(frame_name.mt_idx)))
        mts = b.multi_tags[mt_name]
        print(mts.name)
        contrast = mts.name.split('=')[1].split('%')[0]
        if contrast in contrasts_here:
            continue
        print(contrast)
        if np.isnan(float(contrast)):
            counter = 0
        else:
            counter = np.where(np.round(contrasts_all_unique, 2) == np.round(float(contrast), 2))[0][0]
        contrasts_here.append(contrast)
        print(cell + ' Beat calculation')
        datas_new.append(cell)
        eods_all, V_1, spike_times, sampling_rate = _spectra_compar_load_trials(b, mts, mt_idxs)
        print(len(spike_times))
        extents = mts.extents[:]
        spikes_mats = [cr_spikes_mat(st / 1000, sampling_rate, int(extents[mt_idx] * sampling_rate))
                       for st, mt_idx in zip(spike_times, mt_idxs)]
        _spectra_compar_plot_column(grid2, counter, contrast, cell, mts, frame_name, df_chosen, eods_all, V_1,
                                    spike_times, spikes_mats, sampling_rate, axts, axps)
    # rows 0 (voltage) and 1 (EOD) share y within each row; all time axes share x
    axts[0].get_shared_y_axes().join(*axts[0::2])
    axts[1].get_shared_y_axes().join(*axts[1::2])
    axts[0].get_shared_x_axes().join(*axts)
    join_y(axps)
    join_x(axps)
    individual_tag = 'data ' + cell + '_DF_chosen_' + str(df_chosen) + mt_type
    save_visualization(add + individual_tag, show)
    print('plotted')


def plt_spectra_compar(datas_new, datasets, nfft, show, add=''):
    """Plot per-contrast voltage, EOD/AM, spike-train PSD and ISI panels for recorded cells.

    For each cell in ``datasets`` (except a fixed exclusion list) whose
    ``../data/cells/<cell>/<cell>.nix`` exists and contains SAM or sine
    multi-tags, the beat frequencies are read from the phase-locking table.
    ``calc_phaselocking-phaselocking5_big.pkl`` is used first, falling back to
    ``calc_phaselocking-phaselocking_big.pkl`` for cells missing there. One
    figure per beat frequency and non-baseline stimulus type is saved with
    ``save_visualization(add + tag, show)``.

    Parameters
    ----------
    datas_new : list
        Appended in place with ``cell`` for every plotted contrast column.
    datasets : iterable of str
        Cell identifiers.
    nfft : int
        Unused; kept for interface compatibility.
    show : bool
        Passed to ``save_visualization``.
    add : str
        Prefix for the saved figure name.
    """
    frame_all = pd.read_pickle('../code/calc_phaselocking/calc_phaselocking-phaselocking5_big.pkl')
    frame_fallback = None  # loaded lazily; no longer overwrites frame_all for later cells
    colors = ['red', 'green', 'purple', 'blue']
    cells_exclude = ['2020-10-29-af-invivo-1', '2019-05-07-cb-invivo-1']
    for cell in datasets:
        path = '../data/cells/' + cell + '/' + cell + ".nix"
        print(cell)
        if cell in cells_exclude or not os.path.exists(path):
            continue
        print('exists')
        file = nix.File.open(path, nix.FileMode.ReadOnly)
        try:  # the file is closed on every path, not only when stimuli were found
            b = file.blocks[0]
            sam = find_mt(b, 'SAM')
            sine = find_mt(b, 'sine')
            if not ((len(sine) > 0) or (len(sam) > 0)):
                continue
            print('cont2')
            frame_cell = frame_all[frame_all['cell'] == cell]
            if len(frame_cell) < 1:
                if frame_fallback is None:
                    frame_fallback = pd.read_pickle('../code/calc_phaselocking/calc_phaselocking-phaselocking_big.pkl')
                frame_cell = frame_fallback[frame_fallback['cell'] == cell]
            if len(frame_cell) < 1:
                continue
            df_name = 'df_sign'
            dfs_all_unique = list(np.unique(frame_cell.df_sign.dropna())[::-1])
            if len(np.unique(np.array(dfs_all_unique))) < 2:
                # fewer than two distinct df_sign values: use the 'df' column instead
                df_name = 'df'
                dfs_all_unique = list(np.unique(frame_cell.df.dropna())[::-1])
            if len(dfs_all_unique) < 1:
                continue
            contrasts_all_unique = np.unique(frame_cell.contrast)
            if len(contrasts_all_unique) < 2:
                continue
            for df_chosen in dfs_all_unique:
                if np.isnan(df_chosen):
                    continue
                frame_df = frame_cell[frame_cell[df_name] == df_chosen]
                for mt_type in frame_df.mt_type.unique():
                    if 'base' in mt_type:
                        continue
                    _spectra_compar_plot_type(b, cell, frame_cell, frame_df, mt_type, df_chosen,
                                              contrasts_all_unique, colors, datas_new, show, add)
        finally:
            file.close()
def get_eod_fr_simple(b, names):
    """Return ``(eod_frs, eod_redo)`` for a nix block ``b``.

    If ``'sinewave-1_EOD Rate'`` is among ``names``, the values of that data
    array are returned with ``eod_redo=False``. Otherwise the function falls
    back to the subject's ``'EOD Frequency'`` metadata and returns
    ``eod_redo=True``.
    """
    rate_key = 'sinewave-1_EOD Rate'
    if rate_key not in names:
        return b.metadata['Recording']['Subject']['EOD Frequency'], True
    return b.data_arrays[rate_key][:], False
def plt_response(ax, sampling_rate, spike_times, smooth_mean, spikes_mats, counter, stimulus, extract=True,
                 smooth_sampling_rate=40000):
    """Fill column ``counter`` of a 5-row response grid ``ax``.

    Rows:
    0. Local EOD: the normalised EOD plus its extracted AM, or the raw
       ``stimulus`` when ``extract`` is False.
    1. Spike raster.
    2. Smoothed firing rate.
    3. Per-trial spike-train PSDs plus their mean.
    4. Mean PSD.

    Columns other than the first lose their y ticks and labels.

    Parameters
    ----------
    ax : array of axes indexed as ``ax[row, counter]`` (5 rows).
    sampling_rate : float
        Sampling rate of ``stimulus`` and ``spikes_mats`` (used as ``Fs``).
    spike_times : sequence of per-trial spike-time arrays for the raster.
    smooth_mean : 1-D array of the smoothed firing rate.
    spikes_mats : sequence of per-trial spike trains sampled at ``sampling_rate``.
    counter : int
        Column index.
    stimulus : 1-D array of the local EOD.
    extract : bool
        Plot the ``extract_am`` output instead of the raw stimulus.
    smooth_sampling_rate : float
        Sampling rate of ``smooth_mean``; defaults to the previously hard-coded 40000.
    """
    xlim = (0, 200)
    color = 'grey'
    panels = [ax[row, counter] for row in range(5)]
    ######################
    # local EOD
    panels[0].set_ylabel('mV')
    panels[0].set_title('local')
    # integer sample index keeps the time axis exactly as long as the signal
    # (a float-step np.arange can produce one extra sample and break plot())
    time = np.arange(len(stimulus)) / sampling_rate * 1000
    if extract:
        eods_am, eod_norm = extract_am(stimulus, time, norm=False)
        panels[0].plot(time, eod_norm)
        panels[0].plot(time, eods_am, color='red')
    else:
        panels[0].plot(time, stimulus, color='red')
    panels[0].set_xlim(xlim)
    #####################
    # spike raster
    panels[1].eventplot(spike_times, color=color)
    panels[1].set_xlim(xlim)
    panels[1].set_ylabel('Run nr')
    for panel in panels[:3]:
        remove_xticks(panel)
    #####################
    # smoothed firing rate
    smooth_time = np.arange(len(smooth_mean)) / smooth_sampling_rate * 1000
    panels[2].plot(smooth_time, smooth_mean, color=color)
    panels[2].set_xlim(xlim)
    panels[2].set_ylabel('FR [Hz]')
    #####################
    # power spectra of the spike trains
    ps = []
    f = None
    for spikes_mat in spikes_mats:
        p, f = ml.psd(spikes_mat - np.mean(spikes_mat), Fs=sampling_rate,
                      NFFT=2 ** 13,
                      noverlap=2 ** 13 // 2)  # mlab.psd expects an integer overlap
        ps.append(p)
        panels[3].plot(f, p, color=color)
    panels[3].set_xlim(0, 1000)
    remove_xticks(panels[3])
    if ps:  # no trials -> nothing to average
        mean_power = np.mean(ps, axis=0)
        panels[3].plot(f, mean_power, color='black')
        panels[4].plot(f, mean_power, color='black')
    panels[4].set_ylabel('[Hz]')
    panels[4].set_xlim(0, 1000)
    panels[4].set_xlabel('F [Hz]')
    if counter != 0:
        for panel in panels:
            remove_yticks(panel)
            panel.set_ylabel('')
def plt_cocktailparty_lines(ax, frame_df):
    """Plot eight amplitude scores of ``frame_df`` against the contrast column ``c1``.

    Score ``i`` is drawn on ``ax[i]`` with rows sorted by ``c1``. Dashed
    lines are used for the ``f0`` scores and solid lines for the ``B1``/``B2``
    scores. Each axis gets a y label derived from the column name and the
    x label ``'Contrast small'``.

    Parameters
    ----------
    ax : indexable of at least 8 axes.
    frame_df : pandas.DataFrame with column ``c1`` and the eight score columns.
    """
    scores_data = ['amp_f0_01_original', 'amp_f0_012_original', 'amp_f0_02_original',
                   'amp_f0_0_original',
                   'amp_B1_01_original', 'amp_B1_012_original', 'amp_B2_02_original',
                   'amp_B2_012_original', ]
    colors = ['green', 'purple', 'orange', 'black', 'green', 'blue', 'orange', 'red']
    linestyles = ['--', '--', '--', '--', '-', '-', '-', '-']
    # one sort order shared by all scores
    c1_values = np.asarray(frame_df['c1'])
    order = np.argsort(c1_values)
    c1_sorted = c1_values[order]
    for sss, (score, color, linestyle) in enumerate(zip(scores_data, colors, linestyles)):
        ax[sss].plot(c1_sorted, frame_df[score].iloc[order], color=color, linestyle=linestyle)
        ax[sss].set_ylabel(score.replace('_mean', '').replace('amp_', '') + '[Hz]', fontsize=8)
        ax[sss].set_xlabel('Contrast small')
def get_dfs_and_contrasts_from_calccocktailparty(cell, frame_data):
    """Summarise the contrasts and beat frequencies available for one cell.

    Returns ``(DF1s, DF2s, frame_data_cell, c2_unique, c2_unique_big,
    c1_unique_big, c1_unique)``:

    - ``c*_unique``: the distinct ``c1``/``c2`` values in descending order.
    - ``c*_unique_big``: the subset of those values above 7.
    - ``DF1s``/``DF2s``: the sorted distinct ``m1``/``m2`` values, rounded
      to two decimals.
    - ``frame_data_cell``: the rows of ``frame_data`` whose ``cell`` equals
      ``cell``.
    """
    frame_data_cell = frame_data[frame_data['cell'] == cell]
    c1_unique = np.sort(frame_data_cell['c1'].unique())[::-1]
    c2_unique = np.sort(frame_data_cell['c2'].unique())[::-1]
    DF1s = np.unique(np.round(frame_data_cell['m1'], 2))
    DF2s = np.unique(np.round(frame_data_cell['m2'], 2))
    return (DF1s, DF2s, frame_data_cell,
            c2_unique, c2_unique[c2_unique > 7],
            c1_unique[c1_unique > 7], c1_unique)
def plt_stim_response_saturation(a, arrays_here, arrays_sp, arrays_st, arrays_time, axes, axps, axts,
                                 colors_array_here, f, f_counter, grid_ll, names,
                                 nfft, sampling, time, freqs=[50], colors_peaks=['green', 'red'], xlim=[1, 1.12]):
    """Draw stimulus, voltage trace and log power spectrum of array ``a`` into ``grid_ll``.

    ``grid_ll`` is split into three stacked rows. The new axes are appended
    to ``axes`` (stimulus), ``axts`` (voltage) and ``axps`` (spectrum).

    Returns
    -------
    tuple
        ``(axp, axt)``: the spectrum axis and the voltage axis.
    """
    panels = gridspec.GridSpecFromSubplotSpec(3, 1,
                                              hspace=0.3,
                                              wspace=0.2,
                                              subplot_spec=grid_ll)
    # row 0: stimulus with spike marks
    stim_ax = plt.subplot(panels[0])
    axes.append(stim_ax)
    plt_stim_saturation(a, arrays_sp[a][0], arrays_st, stim_ax, colors_array_here, f,
                        f_counter, names, time, xlim=xlim)
    # row 1: membrane voltage
    vmem_ax = plt.subplot(panels[1])
    axts.append(vmem_ax)
    plt_vmem_saturation(a, arrays_sp, arrays_time, vmem_ax, colors_array_here, f,
                        time, xlim=xlim)
    # row 2: log-scaled power spectrum of the mean-free response
    psd_ax = plt.subplot(panels[2])
    axps.append(psd_ax)
    trace = arrays_here[a][0]
    power, frequencies = ml.psd(trace - np.mean(trace), Fs=sampling, NFFT=nfft,
                                noverlap=nfft // 2)
    power = log_calc_psd('log', power, np.max(power))
    plt_psd_saturation(power, frequencies, a, psd_ax, colors_array_here, freqs=freqs, colors_peaks=colors_peaks)
    return psd_ax, vmem_ax
def plt_stim_saturation(a, arrays_sp, arrays_st, axe, colors_array_here, f, f_counter, names, time,
                        xlim=[1, 1.12]):
    """Plot stimulus ``arrays_st[a]`` on ``axe`` with spikes marked at its mean level.

    The y ticks are removed unless ``f == 0``. The x ticks are removed unless
    ``a`` is the last stimulus. ``names[a]`` becomes the y label only while
    ``f_counter == 0``.
    """
    if f != 0:
        remove_yticks(axe)
    if a != len(arrays_st) - 1:
        remove_xticks(axe)
    if f_counter == 0:
        axe.set_ylabel(names[a])
    try:
        axe.plot(time, arrays_st[a], color=colors_array_here[a], linewidth=0.5)
    except:
        print('axe something')
        embed()  # interactive debugging hook kept from the original
    axe.set_xlim(xlim)
    axe.show_spines('')
    spikes_in_vmem(arrays_sp, arrays_st[a], axe, type_here='stim')
def plt_vmem_saturation(a, arrays_sp, arrays_time, axt, colors_array_here, f, time,
                        xlim=[1, 1.12]):
    """Plot the voltage trace ``arrays_time[a]`` inside ``xlim`` with spike marks at its maximum.

    ``arrays_time[a]`` may be a 1-D trace or a container whose first element
    is the trace (e.g. a list of trials or a ``(1, N)`` array); in the latter
    case the first element is plotted.

    The y ticks are removed unless ``f == 0``. The x ticks are removed unless
    ``a`` is the last trace.
    """
    if f != 0:
        remove_yticks(axt)
    if a != len(arrays_time) - 1:
        remove_xticks(axt)
    window = (time < xlim[1]) & (time > xlim[0])
    trace = arrays_time[a]
    try:
        segment = trace[window]
    except (IndexError, TypeError):
        # nested per-trial container: a boolean mask cannot index it directly
        segment = trace[0][window]
    # plotting stays outside the try so genuine plot errors are not masked
    axt.plot(time[window], segment, color=colors_array_here[a], clip_on=False)
    axt.set_xlim(xlim)
    axt.show_spines('')
    spikes_in_vmem(arrays_sp[a][0], arrays_time[a], axt, type_here='vmem')
def plt_psd_saturation(pp, ff, a, axp, colors_array_here, freqs=[50, 50], colors_peaks=['blue', 'red'],
                       xlim=(0, 300), markeredgecolor=[],
                       labels=['DF1', 'DF2', 'DF1', 'DF2', 'DF1', 'DF2', 'DF1', 'DF2']):
    """Plot power ``pp`` over frequencies ``ff`` below ``xlim[1]`` and mark the peaks at ``freqs``."""
    visible = ff < xlim[1]
    axp.plot(ff[visible], pp[visible], color=colors_array_here[a])
    axp.set_xlim(xlim)
    plt_peaks_several(freqs, [pp], axp, pp, ff, labels, 0, colors_peaks, markeredgecolors=markeredgecolor)
def vary_contrasts(freqs=[(39.5, -210.5)], printing=False, beat='', nfft_for_morph=4096 * 4,
                   gain=1,
                   cells_here=["2013-01-08-aa-invivo-1"],
                   fish_jammer='Alepto', us_name='', show=True):
    """Simulate model cells at several contrasts of the first stimulus and plot traces and spectra.

    For every cell, and for every ``(df1, df2)`` pair in ``freqs``:

    - The pair is snapped to the nearest pair present in the precomputed
      amplitude table.
    - The upper panel shows the table's beat amplitudes via
      ``plt_single_trace``, with the recalculated contrasts marked.
    - The lower panels show, for each contrast, the simulated response
      arrays and their power spectra with the DF1/DF2 peaks marked.

    One figure per cell is saved. An empty ``cells_here`` selects every cell
    of ``models_big_fit_d_right.csv``.
    """
    runs = 1
    n = 1
    dev = 0.0005
    default_settings()  # ts=13, ls=13, fs=13, lw = 0.7
    reshuffled = 'reshuffled'
    # standard combination with a small intruder contrast
    a_f2s = [0.1]
    min_amps = '_minamps_'
    dev_name = ['05']
    model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv")
    if len(cells_here) < 1:
        cells_here = np.array(model_cells.cell)
    a_fr = 1
    trials_nrs = [1]
    datapoints = 1000
    stimulus_length = 2
    results_diff = pd.DataFrame()
    position_diff = 0
    plot_style()
    default_settings(column=2, length=8.5)
    # requested contrasts; re-mapped for each frequency pair by c_dist_recalc_func
    c_nrs_requested = [0.0002, 0.05, 0.5]
    for trials_nr in trials_nrs:
        auci_wo = []
        auci_w = []
        nfft = 32768
        for cell_here in cells_here:
            full_name = ('calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_25_FirstC1_0.0001_LastC1_1.0_StimLen_'
                         + str(stimulus_length) + '_nfft_' + str(nfft)
                         + '_trialsnr_1_absolut_power_1_minamps__dev_05temporal')
            frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_name + '.csv')
            frame_cell_orig = frame[(frame.cell == cell_here)]
            if len(frame_cell_orig) < 1:
                continue
            grid0 = gridspec.GridSpec(1, 1, bottom=0.05, top=0.92, left=0.09,
                                      right=0.95,
                                      wspace=0.04)
            grid00 = gridspec.GridSpecFromSubplotSpec(2, 1,
                                                      wspace=0.04, hspace=0.27, height_ratios=[1, 2.5],
                                                      subplot_spec=grid0[0])
            grid_u = gridspec.GridSpecFromSubplotSpec(1, len(freqs),
                                                      hspace=0.7,
                                                      wspace=0.1,
                                                      subplot_spec=grid00[0])
            grid_l = gridspec.GridSpecFromSubplotSpec(1, len(freqs),
                                                      hspace=0.7,
                                                      wspace=0.1,
                                                      subplot_spec=grid00[1])
            #################################################################
            # panels where the frequency pair is fixed and the contrast varies
            f_counter = 0
            frame_cell_orig, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig)
            eodf = frame_cell_orig.f0.unique()[0]
            axts_all = []
            axps_all = []
            ax_us = []
            for f, (freq1, freq2) in enumerate(freqs):
                grid_ll = gridspec.GridSpecFromSubplotSpec(3, len(c_nrs_requested),
                                                           hspace=0.2,
                                                           wspace=0.2,
                                                           subplot_spec=grid_l[f])
                frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)]
                if len(frame_cell) < 1:
                    # snap to the nearest tabulated frequency pair
                    freq1 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df1 - freq1)))].df1
                    freq2 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df2 - freq2)))].df2
                    frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)]
                scores = ['amp_B1_01_mean', 'amp_B1_012_mean', 'amp_B2_02_mean',
                          'amp_B2_012_mean']  # 'amp_B1+B2_012_mean',
                colors = ['green', 'blue', 'orange', 'red', 'grey']
                colors_array = ['grey', 'green', 'orange', 'purple']
                linestyles = ['-', '--', '-', '--', '--']
                line_alphas = [1, 1, 1, 1, 1]
                print(cell_here + ' F1' + str(freq1) + ' F2 ' + str(freq2))
                sampling = 20000  # NOTE(review): assumed rate of the simulated arrays — confirm
                # grid_u has one row and one column per frequency pair
                ax_u1 = plt.subplot(grid_u[0, f])
                ax_us = plt_single_trace(ax_us, ax_u1, frame_cell_orig, freq1, freq2,
                                         scores=scores, colors=colors,
                                         linestyles=linestyles, alpha=line_alphas,
                                         sum=False, B_replace='F')
                if f == 0:
                    ax_u1.set_ylabel('Hz')
                plt.suptitle(cell_here + ' DF1=' + str(freq1) + ' DF2=' + str(freq2))
                axts = []
                axps = []
                # always recalculate from the requested contrasts, not from a previous pair's result
                c_nrs = c_dist_recalc_func(frame_cell, c_nrs=c_nrs_requested, cell=cell_here)
                ax_u1.scatter(c_nrs, np.zeros(len(c_nrs)), color='black', marker='^', clip_on=False)
                for c_nn, c_nr in enumerate(c_nrs):
                    v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays, names, p_arrays, ff = calc_roc_amp_core_cocktail(
                        [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff, a_f2s,
                        fish_jammer, trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing,
                        stimulus_length,
                        model_cells, position_diff, dev, cell_here, dev_name=dev_name, a_f1s=[c_nr], n=n,
                        reshuffled=reshuffled, min_amps=min_amps)
                    arrays_time = arrays[1::]
                    arrays_here = arrays[1::]
                    colors_array_here = colors_array[1::]
                    for a in range(len(arrays_here)):
                        grid_pt = gridspec.GridSpecFromSubplotSpec(2, 1,
                                                                   hspace=0.3,
                                                                   wspace=0.2,
                                                                   subplot_spec=grid_ll[a, c_nn])
                        axt = plt.subplot(grid_pt[0])
                        axts.append(axt)
                        if f != 0:
                            remove_yticks(axt)
                        if a != len(arrays_time) - 1:
                            remove_xticks(axt)
                        if f_counter == 0:
                            axt.set_ylabel(names[a])
                        if a == 0:
                            # show the contrast that was actually simulated for this column
                            axt.set_title(' c1=' + str(c_nr) + ' c2=' + str(a_f2s[0]))
                        trace = arrays_time[a][0]
                        time_axis = np.arange(len(trace)) / sampling
                        axt.plot(time_axis, trace, color=colors_array_here[a])
                        axt.set_xlim(1, 1.12)
                        #############################
                        axp = plt.subplot(grid_pt[1])
                        axps.append(axp)
                        pp, ff = ml.psd(arrays_here[a][0] - np.mean(arrays_here[a][0]), Fs=sampling, NFFT=nfft,
                                        noverlap=nfft // 2)
                        axp.plot(ff, pp, color=colors_array_here[a])
                        axp.set_xlim(0, 300)
                        if a != 2:
                            colors_peaks = [colors_array[1], colors_array[2]]
                        else:
                            colors_peaks = ['blue', 'red']
                        plt_peaks_several([freq1, np.abs(freq2)], [pp], axp, pp, ff, ['DF1', 'DF2'], 0,
                                          colors_peaks)
                        if a != 2:
                            remove_xticks(axp)
                        if c_nn != 0:
                            remove_yticks(axt)
                            remove_yticks(axp)
                if axts:
                    axts[-1].set_xlabel('Time [s]')
                    axps[-1].set_xlabel('Frequency [Hz]')
                f_counter += 1
                axts_all.extend(axts)
                axps_all.extend(axps)
            ax_us[0].legend(loc=(-0.07, 1), ncol=6)
            axts_all[0].get_shared_y_axes().join(*axts_all)
            axts_all[0].get_shared_x_axes().join(*axts_all)
            axps_all[0].get_shared_y_axes().join(*axps_all)
            axps_all[0].get_shared_x_axes().join(*axps_all)
            join_x(ax_us)
            join_y(ax_us)
            save_visualization(cell_here, show)
def spikes_in_vmem(arrays_sp, arrays_time, axt, type_here='vmem'):
    """Draw the spike times ``arrays_sp`` as an event row on ``axt``.

    The row is placed at ``max(arrays_time)`` for ``type_here == 'vmem'`` and
    at ``mean(arrays_time)`` for any other type.
    """
    if type_here != 'vmem':
        try:
            axt.eventplot(arrays_sp, lineoffsets=np.mean(arrays_time))
        except:
            print('axt something')
            embed()  # interactive debugging hook kept from the original
        return
    axt.eventplot(arrays_sp, lineoffsets=np.max(arrays_time))
def get_frame_cell_params(c_grouped, cell_here, frame, frame_cell_orig):
    """Select ``cell_here`` from ``frame`` and return its frequency summary from ``find_dfs``.

    Previously every result was computed and then discarded, so the function
    returned ``None``. It now returns the values; callers that ignore the
    return value are unaffected. ``c_grouped`` and ``frame_cell_orig`` are
    unused and kept for interface compatibility.

    Returns
    -------
    tuple
        ``(frame_cell, df1s, df2s, f1s, f2s)`` as produced by ``find_dfs``.
    """
    frame_cell = frame[(frame.cell == cell_here)]
    frame_cell, df1s, df2s, f1s, f2s = find_dfs(frame_cell)
    return frame_cell, df1s, df2s, f1s, f2s
def find_mt_type(mts):
    """Classify a multi-tag by its name.

    The checks run in this order:

    1. ``'chirp'``, ``'SAM'`` or ``'sine'`` appears in ``mts.name``: that
       keyword is returned.
    2. ``find_gwn(mts)`` is truthy: ``'stim'`` is returned.
    3. ``'three'`` appears in the name: ``'three'`` is returned.

    Raises
    ------
    ValueError
        If no category matches (previously this surfaced as an UnboundLocalError).
    """
    name = mts.name
    for keyword in ('chirp', 'SAM', 'sine'):
        if keyword in name:
            return keyword
    if find_gwn(mts):
        return 'stim'
    if 'three' in name:
        return 'three'
    raise ValueError('unknown multi-tag type: ' + str(name))
def choice_specific_indices(contrasts, negativ='negativ', units=2 * 7, cut_val=2):
    """Pick a thinned subset of contrasts and their indices.

    Modes:

    - ``'negativ'``: the lowest ``len/cut_val`` values, taking every
      ``next_step``-th value (``next_step = round(len/units)``, at least 1).
    - ``'positiv'``: the same selection applied to the highest values.
    - ``'highest'``: the ``units/cut_val`` highest values.

    Every selection is returned in ascending order.

    Returns
    -------
    tuple
        ``(contrasts_show, indeces_show)``: the selected values and their
        positions in ``contrasts``.

    Raises
    ------
    ValueError
        For an unknown ``negativ`` mode (previously an UnboundLocalError).
    """
    values = np.asarray(contrasts)
    order = np.argsort(values)  # values[order] equals np.sort(values)
    next_step = int(np.round(len(values) / units))
    if next_step == 0:
        next_step = 1
    half = int(len(values) / cut_val)
    if negativ == 'negativ':
        indeces_show = order[0:half][0::next_step][::-1]
    elif negativ == 'positiv':
        indeces_show = order[::-1][0:half][0::next_step][::-1]
    elif negativ == 'highest':
        indeces_show = order[::-1][0:int(units / cut_val)][::-1]
    else:
        raise ValueError("negativ must be 'negativ', 'positiv' or 'highest', got " + repr(negativ))
    contrasts_show = values[indeces_show]
    return contrasts_show, indeces_show
def spike_times_cocktailparty(b, delay, mt, mt_nr, load_eod_array='LocalEOD-1'):
    """Load EOD and spikes of tag ``mt_nr`` with a ``delay`` margin and time the load.

    Returns
    -------
    tuple
        ``(eod_mt, spikes_mt, time_eod, load_duration, start_time)``.
        ``time_eod`` assumes 40000 samples per second and starts at ``-delay``.
    """
    start_time = time.time()
    try:
        eod_mt, spikes_mt = load_eod_for_three(b, delay, mt, mt_nr, load_eod_array=load_eod_array)
    except:
        print('problem')
        embed()  # interactive debugging hook kept from the original
    rate = 40000
    time_eod = np.arange(0, len(eod_mt) / rate, 1 / rate) - delay
    return eod_mt, spikes_mt, time_eod, time.time() - start_time, start_time
def load_eod_for_three(b, delay, mt, mt_nr, load_eod_array='LocalEOD-1'):
    """Cut ``load_eod_array`` and spikes for tag ``mt_nr``, widened by ``delay``.

    The window starts at ``position - delay``, and ``extent + delay`` is
    passed as its second bound. Spikes are referenced to the tag position.
    The sampling rate returned by ``link_arrays`` is discarded.
    """
    tag_start = mt.positions[:][mt_nr]
    eod_mt, spikes_mt, _ = link_arrays(b, first=tag_start - delay,
                                       second=mt.extents[:][mt_nr] + delay,
                                       minus_spikes=tag_start, load_eod_array=load_eod_array)
    return eod_mt, spikes_mt
def diagonal_points():
    """Return the named ``(ratio1, ratio2)`` combinations and store them in the global ``combis``.

    The duplicated ``'vertical6'`` entry of the previous literal was removed;
    it had the same value, so the mapping is unchanged.
    """
    global combis
    combis = {'off1': (0.5, 0.67),
              'test_data_cell_2022-01-05-aa-invivo-1': (0.27, 1.27,),
              'B1-B2_diagonal': (0.27, 1.27,),
              'diagonal1': (1 / 3, 2 / 3),
              'B1+B2_diagonal': (1 / 4, 3 / 4),
              'B1+B2_diagonal2': (0.27, 0.73),
              'B1+B2_diagonal3': (0.3, 0.7),
              'B1-B2_diagonal3': (0.3, 1.3),
              'B1+B2_diagonal31': (0.31, 0.69),
              'B1+B2_diagonal32': (0.32, 0.68),
              'B1+B2_diagonal33': (0.33, 0.67),
              'B1+B2_diagonal_plus_0.2c1': ((1 / 3) + 0.2, 2 / 3),
              'Half_Fr_c1': (0.5, 0.3),
              'Half_Fr_c2': (0.3, 0.5),
              'diagonal2': (2 / 3, 1 / 3,),
              'diagonal3': (0.1, 0.9),
              'vertical1': (1, 0.7),
              'vertical4': (0.8, 0.6),
              'vertical5': (0.8, 0.55),
              'vertical2': (1, 1.05),
              'vertical3': (1 + (1.167 - 1.1644) / 0.1644, 1 + (1.18 - 1.1644) / 0.1644),
              'vertical6': (0.4, 1),
              'horizontal': (0.8, 1),
              'inside': (1 / 2, 2 / 3),
              'outside': (1.2, 2 / 3)
              }
    return combis
def plt_ROC_model_w_female_square_nonlin(frame_names=[], female='wo_female', reshuffled='reshuffled',
                                         datapoints=1000, dev=0.0005, a_f1s=[0.03], pdf=True, printing=False,
                                         plus_q='minus', freq1_ratio=1 / 2, diagonal='diagonal', freq2_ratio=2 / 3,
                                         way='absolut', stimulus_length=0.5, runs=3, trials_nr=500,
                                         cells=[], show=False, nfft=int(2 ** 15), beat='', nfft_for_morph=4096 * 4,
                                         gain=1, talk=True, fish_jammer='Alepto', us_name=''):
    """Simulate a model cell with two stimulus fish and plot traces, spectra, ROC and nonlinearity panels.

    For every model cell:

    - The baseline response is simulated ``trials_nr`` times.
    - Its mean rate ``fr`` sets the stimulus frequencies
      ``eod_fr + fr * freq1_ratio`` and ``eod_fr + fr * freq2_ratio``.
      Both offsets are negated when ``plus_q == 'minus'``. If ``'diagonal'``
      is in ``diagonal``, ``freq1_ratio`` is replaced by ``1 - freq2_ratio``.
    - The two-fish responses are simulated, ROC statistics are computed with
      ``calc_auci_pd``, and the figure is saved with ``save_visualization``.

    An empty ``cells`` now means all cells of ``models_big_fit_d_right.csv``.
    Previously it became ``len(model_cells)``, an int, and the loop over it
    raised TypeError.
    """
    if talk:
        plt.rcParams['lines.linewidth'] = 1
    model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv")
    if len(cells) < 1:
        cells = np.array(model_cells.cell)
    for cell_here in cells:
        single_waves = ['_SeveralWave_']  # , '_SingleWave_']
        for single_wave in single_waves:
            if single_wave == '_SingleWave_':
                a_f2s = [0]
            else:
                a_f2s = [0.1]
            for a_f2 in a_f2s:
                for a_f1 in a_f1s:
                    a_frs = [1]
                    titles_amp = ['base eodf']  # ,'baseline to Zero',]
                    for a, a_fr in enumerate(a_frs):
                        model_params = model_cells[model_cells['cell'] == cell_here].iloc[0]
                        eod_fr = model_params['EODf']
                        offset = model_params.pop('v_offset')
                        cell = model_params.pop('cell')
                        print(cell)
                        SAM, adapt_offset, cell_recording, constant_reduction, damping, damping_type, dent_tau_change, exponential, f1, f2, fish_emitter, fish_receiver, fish_morph_harmonics_var, lower_tol, mimick, n, phase_right, phaseshift_fr, sampling_factor, upper_tol, zeros = default_model0()
                        time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr,
                                                                     stimulus_length)
                        eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr,
                                                                                       stimulus_length, phaseshift_fr,
                                                                                       cell_recording, zeros, mimick,
                                                                                       sampling, fish_receiver, deltat,
                                                                                       nfft, nfft_for_morph,
                                                                                       fish_morph_harmonics_var=fish_morph_harmonics_var,
                                                                                       beat=beat)
                        sampling = 1 / deltat
                        variant = 'sinz'
                        # prepare for adapting the offset to the baseline modification
                        _, _ = prepare_baseline_array(time_array, eod_fr,
                                                      nfft_for_morph,
                                                      phaseshift_fr,
                                                      mimick, zeros,
                                                      cell_recording,
                                                      sampling,
                                                      stimulus_length,
                                                      fish_receiver,
                                                      deltat, nfft,
                                                      damping_type,
                                                      damping, us_name,
                                                      gain, beat=beat,
                                                      fish_morph_harmonics_var=fish_morph_harmonics_var)
                        spikes_base = [[]] * trials_nr  # slots are replaced, never mutated
                        _, _, color0, color01, color02, color012 = colors_cocktailparty_all()
                        default_figsize(width=cm_to_inch(29.21), length=cm_to_inch(12.43))
                        default_ticks_talks()
                        fig = plt.figure()
                        grid = gridspec.GridSpec(1, 2, wspace=0.35, width_ratios=[0.8, 1.6, ], hspace=0.5,
                                                 left=0.08, top=0.95, bottom=0.12,
                                                 right=0.96)
                        grid0 = gridspec.GridSpecFromSubplotSpec(5, 2, wspace=0.18, hspace=0.12,
                                                                 subplot_spec=grid[1],
                                                                 height_ratios=[1, 0.6, 1, 1, 1.25])
                        for run in range(runs):
                            print(run)
                            t1 = time.time()
                            for t in range(trials_nr):
                                stimulus = eod_fish_r
                                stimulus_base = eod_fish_r
                                if 'Zero' in titles_amp[a]:
                                    power_here = 'sinz' + '_' + zeros
                                else:
                                    power_here = 'sinz'
                                # NOTE(review): `alpha` is not defined in this function and resolves to a
                                # module-level name (scipy.stats.alpha is imported at the top of the file,
                                # possibly overridden by the star imports) — confirm this is intended.
                                _, _, _, _, _, _, _, spikes_base[t], _, _, offset_new, _, _ = simulate(
                                    cell, offset, stimulus,
                                    deltat=deltat,
                                    adaptation_variant=adapt_offset,
                                    adaptation_yes_j=f2,
                                    adaptation_yes_e=f1,
                                    adaptation_yes_t=t,
                                    adaptation_upper_tol=upper_tol,
                                    adaptation_lower_tol=lower_tol,
                                    power_variant=power_here,
                                    power_alpha=alpha,
                                    power_nr=n,
                                    reshuffle=reshuffled,
                                    **model_params)
                                if t == 0:
                                    # adopt the offset adapted in the first trial for all later trials
                                    offset = offset_new * 1
                            if printing:
                                print('Baseline time' + str(time.time() - t1))
                            base_cut, mat_base = find_base_fr(spikes_base, deltat, stimulus_length, time_array,
                                                              dev=dev)
                            fr = np.mean(base_cut)
                            if 'diagonal' in diagonal:
                                two_third_fr = fr * freq2_ratio
                                freq1_ratio = (1 - freq2_ratio)
                                third_fr = fr * freq1_ratio
                            else:
                                two_third_fr = fr * freq2_ratio
                                third_fr = fr * freq1_ratio
                            if plus_q == 'minus':
                                two_third_fr = -two_third_fr
                                third_fr = -third_fr
                            freqs2 = [eod_fr + two_third_fr]
                            freqs1 = [eod_fr + third_fr]
                            sampling_rate = 1 / deltat
                            base_cut, mat_base, smoothed0, mat0 = find_base_fr2(spikes_base, deltat,
                                                                                stimulus_length, dev=dev)
                            fr = np.mean(base_cut)
                            isi = np.diff(spikes_base[0])
                            cv0 = np.std(isi) / np.mean(isi)
                            for ff, freq1 in enumerate(freqs1):
                                freq1 = [freq1]
                                freq2 = [freqs2[ff]]
                                t1 = time.time()
                                phaseshift_f1, phaseshift_f2 = get_phaseshifts(a_f1, a_f2, phase_right,
                                                                               phaseshift_fr)
                                eod_fish1, time_fish_e = eod_fish_e_generation(time_array, a_f1, freq1, f1,
                                                                               phaseshift_f1, sampling,
                                                                               stimulus_length,
                                                                               nfft_for_morph, cell_recording,
                                                                               fish_morph_harmonics_var, zeros,
                                                                               mimick,
                                                                               fish_emitter, thistype='emitter')
                                eod_fish2, time_fish_j = eod_fish_e_generation(time_array, a_f2, freq2, f2,
                                                                               phaseshift_f2, sampling,
                                                                               stimulus_length,
                                                                               nfft_for_morph, cell_recording,
                                                                               fish_morph_harmonics_var, zeros,
                                                                               mimick,
                                                                               fish_jammer, thistype='jammer')
                                eod_stimulus = eod_fish1 + eod_fish2
                                v_mems, offset_new, mat01, mat02, mat012, smoothed01, smoothed02, smoothed012, stimulus_01, stimulus_02, stimulus_012, mat05_01, spikes_01, mat05_02, spikes_02, mat05_012, spikes_012 = get_arrays_for_three(
                                    cell, a_f2, a_f1, SAM, eod_stimulus, eod_fish_r, freq2, eod_fish1, eod_fish2,
                                    stimulus_length, offset, model_params, n, variant, adapt_offset, deltat, f2,
                                    trials_nr, time_array, f1, freq1, eod_fr, reshuffle=reshuffled, dev=dev)
                                if printing:
                                    print('Generation process' + str(time.time() - t1))
                                ##################################
                                # ROC statistics
                                array0 = [mat_base]
                                array01 = [mat05_01]
                                array02 = [mat05_02]
                                array012 = [mat05_012]
                                t_off = 10
                                position_diff = 0
                                results_diff = pd.DataFrame()
                                results_diff['f1'] = freq1
                                results_diff['f2'] = freq2
                                results_diff['f0'] = eod_fr
                                trials, results_diff, tp_012_all, tp_01_all, tp_02_all, fp_all, roc_01, roc_0, roc_02, roc_012, threshhold = calc_auci_pd(
                                    results_diff, position_diff, array012, array01, array02, array0, t_off=t_off,
                                    way=way, printing=True, datapoints=datapoints, f0='f0', sampling=sampling)
                                grid1 = gridspec.GridSpecFromSubplotSpec(2, 1, hspace=0.4,
                                                                         subplot_spec=grid[0])
                                ax_ROC = plt.subplot(grid1[0])
                                ax_nonlin = plt.subplot(grid1[1])
                                colors_wo = ['orange', 'orange', 'orange']
                                colors_w = ['green', 'green', 'green']
                                xlim = core_xlim_dist_roc()
                                plt_ROC_nonlin(xlim, frame_names, ax_ROC, ax_nonlin, cells, colors_wo, colors_w)
                                ax_nonlin.set_xlabel(core_distance_label())
                                ax_ROC.set_xlabel(core_distance_label())
                                if run == 0:
                                    plt_traces_to_roc(freq2_ratio, freq1_ratio, t_off, spikes_02, spikes_01,
                                                      spikes_012,
                                                      spikes_base, mat_base, mat05_01,
                                                      mat05_012, mat05_02, color02, color012, a_f2, trials,
                                                      sampling,
                                                      a_f1, fr, female, color01, color0, grid0, eod_fr,
                                                      freq2,
                                                      freq1, sampling_rate, stimulus_012, stimulus_02,
                                                      stimulus_01,
                                                      stimulus_base, time_array - time_array[0], vlin=False,
                                                      carrier=True)
                                    axs = plt_power_spectrum(grid0, color01, color02, color012, color0, fr,
                                                             results_diff, female, nfft, smoothed012, smoothed01,
                                                             smoothed02, smoothed0, sampling_rate, mult_val=0.15,
                                                             add_to=195, wierd_charing=False)
                                    remove_yticks(axs[1])
                                    join_y(axs)
                                # axes are addressed by creation order; the first figure axis is skipped
                                ax = fig.axes
                                ax = ax[1::]
                                ax[4 + 1].set_ylabel('Firing Rate [Hz]')
                                ax[4 + 1].set_xlabel('Time [ms]')
                                ax[5 + 1].set_xlabel('Time [ms]')
                                ax[6 + 1].set_xlabel('Frequency [Hz]')
                                ax[7 + 1].set_xlabel('Frequency [Hz]')
                                ax[6 + 1].set_ylabel('Power [Hz]')
                                for ax_here in ax[2:5]:
                                    ax_here.set_xticks([])
                                ax[1].get_shared_y_axes().join(*ax[1:3])
                                individual_tag = '_way_' + str(way) + '_runs_' + str(runs) + '_trial_nr_' + str(
                                    trials_nr) + '_stimulus_length_' + str(
                                    stimulus_length) + cell + ' cv ' + str(cv0) + single_wave + '_a_f0_' + str(
                                    a_fr) + '_a_f1_' + str(a_f1) + '_a_f2_' + str(a_f2) + '_trialsnr_' + str(
                                    trials_nr)
                                fig = plt.gcf()
                                fig.tag([fig.axes[0], fig.axes[2], fig.axes[3], fig.axes[1]], xoffs=-3.5)
                                save_visualization(individual_tag, show, pdf=pdf, counter_contrast=0, savename='')
def default_model0():
    """Return the default model settings as a flat 21-element tuple.

    The order of the returned values is:
    SAM, adapt_offset, cell_recording, constant_reduction, damping,
    damping_type, dent_tau_change, exponential, f1, f2, fish_emitter,
    fish_receiver, fish_morph_harmonics_var, lower_tol, mimick, n,
    phase_right, phaseshift_fr, sampling_factor, upper_tol, zeros.
    """
    # Insertion order of this dict defines the order of the returned tuple.
    settings = {
        'SAM': '',
        'adapt_offset': 'adaptoffsetallall2',
        'cell_recording': '',
        'constant_reduction': '',
        'damping': 0.45,  # other values tried: 0.65, 0.2, 0.5, 0.6, 0.35
        'damping_type': '',
        'dent_tau_change': 1,
        'exponential': '',
        'f1': 0,
        'f2': 0,
        'fish_emitter': 'Alepto',  # alternatives: 'Sternarchella', 'Sternopygus'
        'fish_receiver': 'Alepto',
        'fish_morph_harmonics_var': 'harmonic',
        'lower_tol': 0.995,
        'mimick': 'no',
        'n': 1,
        'phase_right': '_phaseright_',
        'phaseshift_fr': 0,
        'sampling_factor': '',
        'upper_tol': 1.005,
        'zeros': 'zeros',
    }
    return tuple(settings.values())
def plt_ROC_nonlin(xlim, frame_names, ax_ROC, ax_nonlin, cells, colors_wo, colors_w):
    """Plot AUC-vs-distance curves (``ax_ROC``) and the nonlinearity line (``ax_nonlin``).

    For every cell and every precomputed ROC table ``<calc_ROC>/<frame_name>.csv``
    the rows of that cell are drawn with ``plt_area_between`` and passed on to
    ``plt_nonlin_line``.

    Parameters
    ----------
    xlim : tuple
        x-limits applied to ``ax_ROC`` (and passed to ``plt_nonlin_line``).
    frame_names : list of str
        CSV base names inside the ``calc_ROC`` folder; at most three, since the
        legend labels below are indexed by the frame position.
    ax_ROC, ax_nonlin : matplotlib Axes
        Target axes.
    cells : list of str
        Cell identifiers matched against the ``cell`` column.
    colors_wo, colors_w : list
        Colour lists, indexed by frame position.
    """
    labels_with = ['with female', 'CLS: 100n', 'LS: 1000n']
    labels_without = ['without female', 'CLS: 100n', 'LS: 1000n']
    for cell in cells:
        for f, frame_name in enumerate(frame_names):
            path = load_folder_name('calc_ROC') + '/' + frame_name + '.csv'
            if not os.path.exists(path):
                # Nothing computed for this frame: skip it rather than handing
                # an undefined (or previous iteration's) frame_cell downstream.
                continue
            frame = pd.read_csv(path)
            frame_cell = frame[frame.cell == cell]
            if len(frame_cell) > 0:
                # NOTE(review): colors_wo/colors_w are passed in swapped order
                # relative to plt_area_between's (colors_w, colors_wo) signature;
                # kept as before so the rendered colours do not change.
                plt_area_between(frame_cell, ax_ROC, ax_ROC, colors_wo, colors_w, f,
                                 labels_without_female=labels_without[f],
                                 labels_with_female=labels_with[f])
            ax_ROC.set_xlim(xlim)
            ax_ROC.set_ylim(0, 0.5)
            ax_ROC.legend(loc=(0.5, 0.8))
            plt_nonlin_line(ax_nonlin, cell, 0, frame_cell, xlim)
def plt_traces_to_roc(freq2_ratio, freq1_ratio, t_off, spikes_02, spikes_01, spikes_012, spikes_base, mat_base,
                      mat05_01, mat05_012, mat05_02, color02, color012, a_f2, trials, sampling, a_f1, fr, female,
                      color01, color0, grid0, eod_fr, freq2, freq1, sampling_rate, stimulus_012,
                      stimulus_02, stimulus_01, stimulus_base, time_array, carrier=False, spike_events=True, vlin=True,
                      short_title=True):
    """Draw the stimulus, spike-raster and firing-rate rows that feed the ROC analysis.

    Rows are filled from the top of ``grid0``:

    * row 0: stimulus amplitude modulation (``plt_stimulus_ROC``),
    * next row: spike rasters (``plt_eventplot_ROC``), only if ``spike_events``,
    * next row: smoothed firing rates (``plt_firingrate_ROC``).

    ``female`` selects which pair of conditions is shown side by side (see the
    three plotting helpers). The x-range is the first 102 time units after
    scaling ``time_array`` by 1000 (milliseconds, assuming ``time_array`` is in
    seconds).
    """
    beat1 = freq1 - eod_fr
    beat2 = freq2 - eod_fr

    def amplitude_modulation(stimulus):
        # Only the first trial is analysed when several trials are stacked;
        # applied to all four stimuli (previously the baseline was passed as-is).
        first_trial = stimulus[0] if len(np.shape(stimulus)) > 1 else stimulus
        am, _ = extract_am(first_trial, time_array, sampling=sampling_rate, eodf=eod_fr,
                           emb=False, extract='', norm=False)
        return am

    eod_interp_base = amplitude_modulation(stimulus_base)
    eod_interp_01 = amplitude_modulation(stimulus_01)
    eod_interp_02 = amplitude_modulation(stimulus_02)
    eod_interp = amplitude_modulation(stimulus_012)

    xlim = (0, 0.102 * 1000)
    row = 0
    ax = plt_stimulus_ROC(eod_interp_base, stimulus_base, a_f1, a_f2, beat1, beat2, carrier, color0, color01,
                          color012, color02, eod_interp, eod_interp_01, eod_interp_02, female, fr, freq1_ratio,
                          freq2_ratio, grid0, short_title, stimulus_01, stimulus_012, stimulus_02, time_array, xlim,
                          counter=row)
    row += 1
    if spike_events:
        plt_eventplot_ROC(ax, female, grid0, spikes_01, spikes_012, spikes_02, spikes_base, xlim, counter=row)
        row += 1
    plt_firingrate_ROC(female, grid0, mat05_01, mat05_012, mat05_02, mat_base, sampling, t_off, time_array, trials,
                       vlin, xlim, counter=row)
def plt_firingrate_ROC(female, grid0, mat05_01, mat05_012, mat05_02, mat_base, sampling, t_off, time_array, trials,
                       vlin, xlim, counter=2):
    """Plot two smoothed firing-rate traces side by side in row ``counter`` of ``grid0``.

    Which traces are compared depends on ``female``:

    * contains ``'wo_female'``: baseline (``mat_base``) vs intruder (``mat05_01``),
    * contains ``'base_female'``: baseline vs female (``mat05_02``),
    * otherwise: female vs female + intruder (``mat05_012``).

    If ``vlin`` is set, dashed grey markers are drawn at t=0, at
    ``trials / sampling`` and at ``(trials - t_off) / sampling`` (all scaled by
    1000 like the time axis; ``trials`` and ``t_off`` are presumably sample
    counts — TODO confirm against calc_auci_pd).
    """
    color_mat = 'black'
    line_kw = dict(color='grey', linestyle='--', linewidth=0.5)

    def draw_rate(ax, rate):
        ax.plot(time_array * 1000, rate, color=color_mat)
        if vlin:
            ax.vlines((trials / sampling) * 1000, ymin=0, ymax=750, **line_kw)
            # Parenthesised so that the offset is subtracted in samples before
            # converting to the time axis (was ``trials - t_off / sampling``).
            ax.vlines(((trials - t_off) / sampling) * 1000, ymin=0, ymax=750, **line_kw)
            ax.axvline(0, **line_kw)

    if 'wo_female' in female:
        left, right = mat_base, mat05_01
    elif 'base_female' in female:
        left, right = mat_base, mat05_02
    else:
        left, right = mat05_02, mat05_012

    # Both panels use row ``counter`` (the 'base_female' case used to hard-code
    # row 2 for the left panel).
    ax = plt.subplot(grid0[counter, 0])
    ax.set_xlim(xlim)
    draw_rate(ax, left)

    ax = plt.subplot(grid0[counter, 1], sharex=ax)
    remove_yticks(ax)
    draw_rate(ax, right)
def _spikes_to_ms(spikes):
    """Scale spike times by 1000 (s -> ms), keeping one array per trial.

    Trials normally contain different numbers of spikes. ``np.array`` on such a
    ragged list raises ``ValueError`` on NumPy >= 1.24, so each trial is
    converted separately. A flat sequence of spike times is converted as a
    single train.
    """
    if any(np.ndim(trial) > 0 for trial in spikes):
        return [np.asarray(trial, dtype=float) * 1000 for trial in spikes]
    return np.asarray(spikes, dtype=float) * 1000


def plt_eventplot_ROC(ax, female, grid0, spikes_01, spikes_012, spikes_02, spikes_base, xlim, counter=1):
    """Draw two spike rasters side by side in row ``counter`` of ``grid0``.

    Which spike trains are compared depends on ``female``:

    * contains ``'wo_female'``: baseline vs intruder (``spikes_01``),
    * contains ``'base_female'``: baseline vs female (``spikes_02``),
    * otherwise: female vs female + intruder (``spikes_012``).

    ``ax`` is the axes to share the x-axis with (the stimulus panel). In the
    last case the left raster is not x-shared with it, as before.
    """
    if 'wo_female' in female:
        left_spikes, right_spikes = spikes_base, spikes_01
        ax_left = plt.subplot(grid0[counter, 0], sharex=ax)
    elif 'base_female' in female:
        left_spikes, right_spikes = spikes_base, spikes_02
        ax_left = plt.subplot(grid0[counter, 0], sharex=ax)
    else:
        left_spikes, right_spikes = spikes_02, spikes_012
        ax_left = plt.subplot(grid0[counter, 0])
    ax_left.set_xlim(xlim)
    ax_left.spines['bottom'].set_visible(False)
    ax_left.eventplot(_spikes_to_ms(left_spikes), color='black')

    ax_right = plt.subplot(grid0[counter, 1], sharex=ax_left)
    remove_yticks(ax_right)
    ax_right.spines['bottom'].set_visible(False)
    ax_right.eventplot(_spikes_to_ms(right_spikes), color='black')
def _female_long_title(beat2, a_f2, freq2_ratio):
    """Return the detailed title for the female-stimulus panel."""
    return ('Female: 02 \n $f=$' + str(np.round(beat2[0])) + ' Hz' + ' $c_{2}=$' + str(
        a_f2 * 100) + r'$\%$ ' + '\n' + r'$\frac{f}{fr}=$' + str(np.round(freq2_ratio, 2)))


def plt_stimulus_ROC(eod_interp_base, stimulus_base, a_f1, a_f2, beat1, beat2, carrier, color0, color01, color012,
                     color02, eod_interp, eod_interp_01,
                     eod_interp_02, female, fr, freq1_ratio, freq2_ratio, grid0, short_title,
                     stimulus_01, stimulus_012, stimulus_02, time_array, xlim, counter=0):
    """Draw two stimulus panels side by side in row ``counter`` of ``grid0``.

    Which stimuli are shown depends on ``female``:

    * contains ``'wo_female'``: baseline vs intruder,
    * contains ``'base_female'``: baseline vs female,
    * otherwise: female vs female + intruder.

    Each panel is drawn with ``plt_base``. With ``short_title`` the panels get
    one-word titles, otherwise titles with frequencies and contrasts.

    Returns
    -------
    matplotlib Axes
        The right-hand panel (its x-axis is shared with the left one).
    """
    if 'wo_female' in female:
        if short_title:
            left_title, right_title = 'Baseline', 'Intruder'
        else:
            left_title = 'Base: 0 \n $fr=$' + str(np.round(fr)) + 'Hz'
            right_title = ('Intruder: 01 \n $f=$' + str(np.round(beat1[0])) + 'Hz ' + ' $c_{1}=$' + str(
                a_f1 * 100) + r'$\%$' + '\n' + r' $\frac{f}{fr}=$' + str(np.round(freq1_ratio, 2)))
        left = (eod_interp_base, color0, stimulus_base, left_title)
        right = (eod_interp_01, color01, stimulus_01, right_title)
    elif 'base_female' in female:
        if short_title:
            left_title, right_title = 'Baseline', 'Female'
        else:
            left_title = 'Base: 0 \n $fr=$' + str(np.round(fr)) + 'Hz'
            # Formerly contained a literal '{len(folder)}' (raw string, not an
            # f-string) and lacked the '=' after c_{2}.
            right_title = _female_long_title(beat2, a_f2, freq2_ratio)
        left = (eod_interp_base, color0, stimulus_base, left_title)
        right = (eod_interp_02, color02, stimulus_02, right_title)
    else:
        if short_title:
            left_title, right_title = 'Female', 'Female + Intruder'
        else:
            left_title = _female_long_title(beat2, a_f2, freq2_ratio)
            right_title = 'Fem. + Int.: 012 \n $f=$' + str(np.round(beat1[0] + beat2[0])) + ' Hz'
        left = (eod_interp_02, color02, stimulus_02, left_title)
        right = (eod_interp, color012, stimulus_012, right_title)

    am, color, stimulus, title = left
    ax = plt.subplot(grid0[counter, 0])
    plt_base(ax, xlim, time_array, am, color, stimulus, carrier)
    ax.set_title(title, color=color)

    am, color, stimulus, title = right
    ax = plt.subplot(grid0[counter, 1], sharex=ax)
    plt_base(ax, xlim, time_array, am, color, stimulus, carrier)
    remove_yticks(ax)
    ax.set_title(title, color=color)
    return ax
def plt_base(ax, xlim, time_array, eod_interp_base, color0, stimulus_base, carrier):
    """Plot one amplitude-modulation trace on ``ax``, optionally over its carrier.

    The time axis is ``time_array`` scaled by 1000. With ``carrier`` the raw
    stimulus (first row if it is multi-dimensional) is drawn thin and grey, and
    the y-range is +/-1.15. Otherwise the y-range is zoomed to 0.85..1.15. The
    bottom spine is hidden in both cases.
    """
    t_scaled = time_array * 1000
    ax.set_xlim(xlim)
    ax.plot(t_scaled, eod_interp_base, color=color0)
    if not carrier:
        ax.set_ylim([0.85, 1.15])
    else:
        raw = stimulus_base[0] if np.ndim(stimulus_base) > 1 else stimulus_base
        ax.plot(t_scaled, raw, color='grey', linewidth=0.5)
        ax.set_ylim(-1.15, 1.15)
    ax.spines['bottom'].set_visible(False)
def plt_power_spectrum2(grid0, color01, color02, color012, color0, fr, results_diff, female, nfft,
                        smoothed012,
                        smoothed01, smoothed02, smoothed0, sampling_rate, counter=4, add_to=70, mult_val=0.125,
                        wierd_charing=True, log=''):
    """Plot two power spectra with labelled beat peaks (variant with ``log`` support).

    Spectra come from ``calc_ps`` (restricted to ``xlim_ROC_talk2()``).
    DF1/DF2 are the absolute differences f1-f0 and f2-f0 taken from the last
    row of ``results_diff``. For ``'wo_female'`` the baseline and intruder
    spectra are shown; ``plt_spectra_given`` then chooses the marked peaks
    itself. Otherwise the female and female + intruder spectra are shown with
    the peaks listed below.

    Returns
    -------
    list of matplotlib Axes
        As returned by ``plt_spectra_given``.
    """
    p0, p02, p01, p012, fs = calc_ps(nfft, smoothed012,
                                     smoothed01, smoothed02, smoothed0,
                                     sampling_rate=sampling_rate, log=log, xlim=xlim_ROC_talk2())
    DF1 = np.abs(results_diff.f1.iloc[-1] - results_diff.f0.iloc[-1])
    DF2 = np.abs(results_diff.f2.iloc[-1] - results_diff.f0.iloc[-1])
    if 'wo_female' in female:
        p_arrays = [p0, p01]
        # Unused by plt_spectra_given in this case, but they must be bound:
        # previously this path raised UnboundLocalError at the call below.
        freqs_all, colors_all, labels_all = [], [], []
    else:
        p_arrays = [p02, p012]
        freqs_all = [[DF2, DF2 * 2],
                     [DF2, DF2 * 2, DF1, DF1 + DF2, (DF1 + DF2) * 2]]
        colors_all = [[color02, color02],
                      [color02, color02, color01, color012, color012, color0, color0]]
        labels_all = [[r'$\Delta \mathrm{f_{Female}}$', r'2 $\Delta \mathrm{f_{Female}}$'],
                      [r'$\Delta \mathrm{f_{Female}}$', '',
                       r'$\Delta \mathrm{f_{Intruder}}$',
                       r'$|\Delta \mathrm{f_{Female}}|+|\Delta \mathrm{f_{Intruder}}|$',
                       '',
                       '',
                       '', ]]
    axs = plt_spectra_given(DF1, add_to, color0, color01, colors_all, female, fr, freqs_all, fs, grid0, labels_all,
                            mult_val, p0, p_arrays, wierd_charing, xlim=xlim_ROC_talk2(), counter=counter,
                            add_texts=[0, 0, 0, 0, 0, 0, 0, 0, 500, 0, 0, 0, 0, 0], log=log, text_extra=True)
    return axs
def plt_power_spectrum(grid0, color01, color02, color012, color0, fr, results_diff, female, nfft,
                       smoothed012,
                       smoothed01, smoothed02, smoothed0, sampling_rate, counter=4, add_to=70, mult_val=0.125,
                       wierd_charing=True):
    """Plot two power spectra with labelled beat peaks.

    Spectra come from ``calc_ps``. DF1/DF2 are the absolute differences
    f1-f0 and f2-f0 taken from the last row of ``results_diff``. For
    ``'wo_female'`` the baseline and intruder spectra are shown;
    ``plt_spectra_given`` then chooses the marked peaks itself. Otherwise the
    female and female + intruder spectra are shown with the peaks listed below
    (including DF1 + 2*DF2 in the ``color_sumpeak()`` colour).

    Returns
    -------
    list of matplotlib Axes
        As returned by ``plt_spectra_given``.
    """
    p0, p02, p01, p012, fs = calc_ps(nfft, smoothed012,
                                     smoothed01, smoothed02, smoothed0,
                                     sampling_rate=sampling_rate)
    DF1 = np.abs(results_diff.f1.iloc[-1] - results_diff.f0.iloc[-1])
    DF2 = np.abs(results_diff.f2.iloc[-1] - results_diff.f0.iloc[-1])
    if 'wo_female' in female:
        p_arrays = [p0, p01]
        # Unused by plt_spectra_given in this case, but they must be bound:
        # previously this path raised UnboundLocalError at the call below.
        freqs_all, colors_all, labels_all = [], [], []
    else:
        p_arrays = [p02, p012]
        freqs_all = [[DF2, DF2 * 2],
                     [DF2, DF2 * 2, DF1 + DF2 * 2, DF1,
                      DF1 + DF2,
                      (DF1 + DF2) * 2]]
        color0122 = color_sumpeak()
        colors_all = [[color02, color02],
                      [color02, color02, color0122, color01, color012, color012, color0, color0]]
        labels_all = [[r'$\Delta \mathrm{f_{Female}}$', r'2 $\Delta \mathrm{f_{Female}}$'],
                      [r'$\Delta \mathrm{f_{Female}}$', '',
                       '',
                       r'$\Delta \mathrm{f_{Intruder}}$',
                       r'$|\Delta \mathrm{f_{Female}}|+|\Delta \mathrm{f_{Intruder}}|$',
                       '',
                       '',
                       '', ]]
    axs = plt_spectra_given(DF1, add_to, color0, color01, colors_all, female, fr, freqs_all, fs, grid0, labels_all,
                            mult_val, p0, p_arrays, wierd_charing, xlim=xlim_ROC_talk2(), counter=counter,
                            add_texts=[0, 0, 0, 0, 0, 0, 0, 0, 500, 0, 0, 0, 0, 0], text_extra=True)
    return axs
def xlim_ROC_talk2():
    """Return the (lower, upper) frequency range used for the ROC talk spectra."""
    lower, upper = 0, 200
    return lower, upper
def color_sumpeak():
    """Return the colour used to mark the summed-beat spectral peak."""
    return 'yellow'
def plt_spectra_given(DF1, add_to, color0, color01, colors_all, female, fr, freqs_all, fs, grid0, labels_all, mult_val,
                      p0, p_arrays, wierd_charing, counter=4, xlim=(0, 300), add_texts=[0, 0, 0, 0, 0, 0, 0, ],
                      text_extra=True, log=''):
    """Plot each spectrum set of ``p_arrays`` in its own column of row ``counter``.

    For every ``p_arrays[j]`` the individual trial spectra are drawn in grey,
    their mean in black, and peaks are marked with ``plt_peaks_several``. For
    ``'wo_female'`` the marked peaks are |DF1| and ``fr``. Otherwise they are
    ``freqs_all[j]`` with ``colors_all[j]`` and ``labels_all[j]``.

    Parameters
    ----------
    wierd_charing : bool
        If true, columns after the first share x and y with the previous one.
    text_extra, log :
        Forwarded to ``plt_peaks_several``.
    add_to, mult_val, p0, color0 :
        No longer used by the active code path; kept for interface
        compatibility.

    Returns
    -------
    list of matplotlib Axes
        One axes per entry of ``p_arrays``.
    """
    axs = []
    ax = None
    for j, p_array in enumerate(p_arrays):
        if j != 0 and wierd_charing:
            ax = plt.subplot(grid0[counter, j], sharex=ax, sharey=ax)
        else:
            ax = plt.subplot(grid0[counter, j])
        axs.append(ax)
        # Iterate over the trials of the plotted array itself (not len(p0)).
        for trial_spectrum in p_array:
            ax.plot(fs, trial_spectrum, color='grey')
        p_mean = np.mean(p_array, axis=0)
        ax.plot(fs, p_mean, color='black')
        if 'wo_female' in female:
            freqs = [np.abs(DF1), fr]
            colors = [color01, color0]
            labels = ['DF1', 'baseline']
        else:
            freqs = freqs_all[j]
            colors = colors_all[j]
            labels = labels_all[j]
        ax.set_xlim(xlim)
        plt_peaks_several(freqs, p_arrays, ax, p_mean, fs, labels, 0, colors,
                          add_texts=add_texts, add_log=2.5, exact=False, text_extra=text_extra, perc_peaksize=5,
                          rel='rel', ms=24, ha='left',
                          clip_on=False, several_peaks=True, log=log)
    return axs
def sum_intruder_core():
    """Return the mathtext label for |Delta f_Female| + |Delta f_Intruder|."""
    label = r'$|\Delta \mathrm{f_{Female}}|+|\Delta \mathrm{f_{Intruder}}|$'
    return label
def plt_area_between(frame_cell, ax0, ax, colors_w, colors_wo, f, cut_starts=False, alphas=[1, 1, 1, 1, 1, 1, 1, 1, 1],
                     starts=0.25, ls='-', labels_with_female='', arrow=True, fill=True,
                     labels_without_female='', dist_redo=True):
    """Plot AUC-vs-distance curves with and without female and shade their difference.

    ``frame_cell`` rows (one cell) are averaged per ``c1``. Optionally ``c1`` is
    converted with ``c_dist_recalc_func`` (``dist_redo``), then the curves
    ``auci_02_012`` (with female) and ``auci_base_01`` (without female) are
    sorted by that x value, anchored at (5, 0.5) and resampled on a 1-unit
    grid with ``interpolate``.

    Parameters
    ----------
    ax0 : Axes
        Receives the red/blue shading between the curves (if ``fill``).
    ax : Axes
        Receives the curves, the optional arrow and the axis labels.
    colors_w, colors_wo, alphas : list
        Indexed by ``f``.
    starts : float or list
        AUC level at which the arrow connects the two curves (``starts[f]`` if a
        list).
    cut_starts : bool
        Draw the without-female curve only from that crossing onwards.

    Returns
    -------
    tuple
        ``(c1_interp, without_female, with_female)`` — the resampled grid and
        curves.
    """
    cell = frame_cell.cell.unique()[0]
    # numeric_only: the string 'cell' column cannot be averaged; pandas >= 2
    # raises TypeError instead of silently dropping it as older versions did.
    frame_cell = frame_cell.groupby('c1', as_index=False).mean(numeric_only=True)
    c1 = frame_cell.c1
    if dist_redo:
        c1 = c_dist_recalc_func(frame_cell=frame_cell, c_nrs=frame_cell.c1, cell=cell, c_dist_recalc=True)
    sorting = np.asarray(np.argsort(c1))
    c1 = np.sort(c1)
    # Reorder the AUCs positionally to follow the sorted x values. The former
    # chained ``frame_cell[col].iloc[...] = ...`` did nothing under pandas
    # copy-on-write.
    auc_with = frame_cell['auci_02_012'].to_numpy()[sorting]
    auc_without = frame_cell['auci_base_01'].to_numpy()[sorting]

    # Anchor both curves at (5, 0.5).
    upper0 = np.concatenate(([0.5], auc_with))
    lower1 = np.concatenate(([0.5], auc_without))
    upper = upper0.copy()
    lower = lower1.copy()
    lower[lower > upper0] = upper0[lower > upper0]
    upper[upper > lower1] = lower1[upper > lower1]
    c1 = [5] + list(c1)

    if fill:
        ax0.fill_between(c1, upper0, upper, color='red', edgecolor=None, zorder=2, alpha=0.05)
        ax0.fill_between(c1, lower1, lower, color='blue', edgecolor=None, zorder=2, alpha=0.05)
    ax.set_xlim(0, ax.get_xlim()[1])

    start = starts[f] if isinstance(starts, list) else starts

    # Resample both curves on a 1-unit grid.
    c1_interp = np.arange(np.min(c1), np.max(c1), 1)
    with_female = interpolate(c1, upper0, c1_interp, kind='linear')
    without_female = interpolate(c1, lower1, c1_interp, kind='linear')

    # Grid points where each curve is closest to the ``start`` AUC level.
    pos_l = np.argmin(np.abs(with_female - start))
    pos_r = np.argmin(np.abs(without_female - start))
    val_l = with_female[pos_l]

    if cut_starts:
        ax.plot(c1_interp[pos_r:], without_female[pos_r:], alpha=alphas[f], color=colors_wo[f],
                label=labels_without_female, clip_on=True, linestyle=ls)
    else:
        ax.plot(c1_interp, without_female, color=colors_wo[f], alpha=alphas[f], label=labels_without_female,
                clip_on=True, linestyle=ls)
    ax.plot(c1_interp, with_female, color=colors_w[f], alpha=alphas[f], label=labels_with_female,
            clip_on=True)

    if arrow:
        # Horizontal arrow at the with-female level, from the without-female
        # crossing to the with-female crossing. (The former code drew it in two
        # identical branches and left ``val_rr`` unbound when both levels were
        # equal.)
        ax.annotate('', xy=(c1_interp[pos_l], val_l), xytext=(c1_interp[pos_r], val_l),
                    arrowprops=dict(arrowstyle="->", color='black'),
                    textcoords='data', xycoords='data', horizontalalignment='left')
    ax.set_xlabel(core_distance_label())
    ax.set_ylabel(core_auc_label())
    return c1_interp, without_female, with_female
def arrow_annotate(ax, c1, colors_w, f, pos_l, pos_ll, pos_r, val_l, val_r):
    """Draw an arrow on ``ax`` from (c1[pos_r], val_r) to (c1[pos_l], val_l).

    The arrow colour is ``colors_w[f]``. ``pos_ll`` is accepted for interface
    compatibility; both of the former branches on it drew the same arrow.
    """
    head = (c1[pos_l], val_l)
    tail = (c1[pos_r], val_r)
    style = dict(arrowstyle="->", color=colors_w[f])
    ax.annotate('', xy=head, xytext=tail, arrowprops=style, textcoords='data', xycoords='data',
                horizontalalignment='left')
def plt_several_ROC_declining_one_with_ROC_single_in_one_dec(bt=0.12, lf=0.07, females=[], color_01='green',
color_02='red',
color_012='orange', figsize=(12, 5.5), frame_names=[
'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_20_mult_minimum_1temporal',
'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_100_mult_minimum_1temporal',
'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_1000_mult_minimum_1temporal'],
reshuffled='reshuffled', datapoints=1000, dev=0.0005,
a_f1s=[0.03], printing=False, plus_q='minus',
way='absolut',
stimulus_length=0.5, runs=3, trials_nr=500,
nfft=int(2 ** 15),
beat='', nfft_for_morph=4096 * 4, gain=1,
fish_jammer='Alepto',
us_name='', wr=[1.6, 2]):
try:
pass
except:
print('split something')
freq1_ratio = float(frame_names[0].split('FrF1rel_')[1].split('_FrF2rel')[0])
freq2_ratio = float(frame_names[0].split('FrF2rel_')[1].split('_C2')[0])
cells = [
"2013-01-08-aa-invivo-1"] # , "2012-12-13-an-invivo-1", "2012-06-27-an-invivo-1", "2012-12-21-ai-invivo-1","2012-06-27-ah-invivo-1", ]
cells_chosen = [
'2013-01-08-aa-invivo-1'] # , "2012-06-27-ah-invivo-1","2014-06-06-ac-invivo-1" ]#'2012-06-27-an-invivo-1',
plt.rcParams['lines.linewidth'] = 1
model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv")
grid_here = gridspec.GridSpec(1, 2, hspace=0.3, wspace=0.36, left=lf, top=0.92, bottom=bt,
right=0.98, width_ratios=wr) # 1.3,1 wspace=0.16
grid_here1 = gridspec.GridSpecFromSubplotSpec(1, 1, wspace=0.6, hspace=0.75,
subplot_spec=grid_here[0]) # 0.3
grid_here2 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.3, hspace=0.3,
subplot_spec=grid_here[1], height_ratios=[1.5, 3])
if len(cells) < 1:
cells = len(model_cells)
colors_wo = [color_01] # ['limegreen', 'green', 'darkgreen']
colors_w = [color_012] # ['orange', 'darkorange','goldenrod']
for cell_here in cells:
# sachen die ich variieren will
###########################################
single_waves = ['_SeveralWave_'] # , '_SingleWave_']
####### VARY HERE
for single_wave in single_waves:
if single_wave == '_SingleWave_':
a_f2s = [0] # , 0,0.2
else:
a_f2s = [0.1]
for a_f2 in a_f2s:
for a_f1 in a_f1s:
a_frs = [1]
titles_amp = ['base eodf'] # ,'baseline to Zero',]
for a, a_fr in enumerate(a_frs):
model_params = model_cells[model_cells['cell'] == cell_here].iloc[0]
eod_fr = model_params['EODf'] # .iloc[0]
offset = model_params.pop('v_offset')
cell = model_params.pop('cell')
print(cell)
SAM, adapt_offset, cell_recording, constant_reduction, damping, damping_type, dent_tau_change, exponential, f1, f2, fish_emitter, fish_receiver, fish_morph_harmonics_var, lower_tol, mimick, n, phase_right, phaseshift_fr, sampling_factor, upper_tol, zeros = default_model0()
# in case you want a different sampling here we can adujust
time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr,
stimulus_length)
# generate the eod_fish_r in the four mimick variants (copy, thunderfish, mimick, just sinus)
eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr,
stimulus_length, phaseshift_fr,
cell_recording, zeros, mimick,
sampling, fish_receiver, deltat,
nfft, nfft_for_morph,
fish_morph_harmonics_var=fish_morph_harmonics_var,
beat=beat)
sampling = 1 / deltat
variant = 'sinz'
if exponential == '':
pass
# prepare for adapting offset due to baseline modification
_, _ = prepare_baseline_array(time_array, eod_fr,
nfft_for_morph,
phaseshift_fr,
mimick, zeros,
cell_recording,
sampling,
stimulus_length,
fish_receiver,
deltat, nfft,
damping_type,
damping, us_name,
gain, beat=beat,
fish_morph_harmonics_var=fish_morph_harmonics_var)
save_name_roc = 'decline_ROC_examples_trial_nr.csv'
redo = False
version_comp, subfolder, mod_name_slash, mod_name, subfolder_path = find_code_vs_not()
cont_redo = ((os.path.exists(save_name_roc)) | (version_comp == 'public')) & (redo == False)
for run in range(runs):
print(run)
t1 = time.time()
if cont_redo:
trials_nr_base = 1
stimulus_length = 1
model_params = model_cells[model_cells['cell'] == cell_here].iloc[0]
eod_fr = model_params['EODf'] # .iloc[0]
offset = model_params.pop('v_offset')
cell = model_params.pop('cell')
time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr,
stimulus_length)
eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr,
stimulus_length,
phaseshift_fr,
cell_recording, zeros,
mimick, sampling,
fish_receiver, deltat,
nfft, nfft_for_morph,
fish_morph_harmonics_var=fish_morph_harmonics_var,
beat=beat)
else:
trials_nr_base = trials_nr
spikes_base = [[]] * trials_nr_base
for t in range(trials_nr_base):
# get the baseline properties here
# baseline_after,spikes_base,rate_adapted, rate_baseline_before, rate_baseline_after, np.array(spike_times), stimulus_power, v_dent_output[int(0.05 / deltat):-1], offset, v_mem_output
stimulus = eod_fish_r
if 'Zero' in titles_amp[a]:
power_here = 'sinz' + '_' + zeros
else:
power_here = 'sinz'
cvs, adapt_output, baseline_after_b, _, rate_adapted_b, rate_baseline_before_b, rate_baseline_after_b, \
spikes_base[t], _, _, offset_new, _, noise_final = simulate(cell, offset, stimulus,
deltat=deltat,
adaptation_variant=adapt_offset,
adaptation_yes_j=f2,
adaptation_yes_e=f1,
adaptation_yes_t=t,
power_variant=power_here,
power_nr=n,
reshuffle=reshuffled,
**model_params)
if t == 0:
# here we record the changes in the offset due to the adaptation
# and we subsequently reset the offset to be the new adapted for all subsequent trials
offset = offset_new * 1
if printing:
print('Baseline time' + str(time.time() - t1))
base_cut, mat_base = find_base_fr(spikes_base, deltat, stimulus_length, time_array, dev=dev)
fr = np.mean(base_cut)
two_third_fr = fr * freq2_ratio
third_fr = fr * freq1_ratio
if plus_q == 'minus':
two_third_fr = -two_third_fr
third_fr = -third_fr
freqs2 = [eod_fr + two_third_fr] # , eod_fr - third_fr, two_third_fr,
freqs1 = [
eod_fr + third_fr] # , eod_fr - two_third_fr, third_fr,two_third_fr,third_eodf, eod_fr - third_eodf,two_third_eodf, eod_fr - two_third_eodf, ]
base_cut, mat_base, smoothed0, mat0 = find_base_fr2(spikes_base, deltat,
stimulus_length, dev=dev)
_, _ = ISI_frequency(time_array, spikes_base[0], fill=0.0)
for ff, freq1 in enumerate(freqs1):
if cont_redo:
frame = pd.read_csv(save_name_roc)
tp_012_all = frame['tp_012'] # = tp_012_all
tp_01_all = frame['tp_01'] # = tp_01_all
tp_02_all = frame['tp_02'] # = tp_02_all
fp_all = frame['fp_all'] # = fp_all
else:
freq1 = [freq1]
freq2 = [freqs2[ff]]
t1 = time.time()
phaseshift_f1, phaseshift_f2 = get_phaseshifts(a_f1, a_f2, phase_right,
phaseshift_fr)
eod_fish1, time_fish_e = eod_fish_e_generation(time_array, a_f1, freq1, f1,
phaseshift_f1, sampling,
stimulus_length, nfft_for_morph,
cell_recording,
fish_morph_harmonics_var, zeros,
mimick, fish_emitter,
thistype='emitter')
eod_fish2, time_fish_j = eod_fish_e_generation(time_array, a_f2, freq2, f2,
phaseshift_f2, sampling,
stimulus_length, nfft_for_morph,
cell_recording,
fish_morph_harmonics_var, zeros,
mimick, fish_jammer,
thistype='jammer')
eod_stimulus = eod_fish1 + eod_fish2
v_mems, offset_new, mat01, mat02, mat012, smoothed01, smoothed02, smoothed012, stimulus_01, stimulus_02, stimulus_012, mat05_01, spikes_01, mat05_02, spikes_02, mat05_012, spikes_012 = get_arrays_for_three(
cell, a_f2, a_f1, SAM, eod_stimulus, eod_fish_r, freq2, eod_fish1, eod_fish2,
stimulus_length, offset, model_params, n, variant, adapt_offset, deltat,
f2, trials_nr, time_array, f1, freq1, eod_fr, reshuffle=reshuffled, dev=dev)
if printing:
print('Generation process' + str(time.time() - t1))
##################################
array0 = [mat_base]
array01 = [mat05_01]
array02 = [mat05_02]
array012 = [mat05_012]
t_off = 10
position_diff = 0
results_diff = pd.DataFrame()
results_diff['f1'] = freq1
results_diff['f2'] = freq2
results_diff['f0'] = eod_fr
trials, results_diff, tp_012_all, tp_01_all, tp_02_all, fp_all, roc_01, roc_0, roc_02, roc_012, threshhold = calc_auci_pd(
results_diff, position_diff, array012, array01, array02, array0, t_off=t_off,
way=way, printing=True, datapoints=datapoints, f0='f0', sampling=sampling)
frame = pd.DataFrame()
frame['tp_012'] = tp_012_all
frame['tp_01'] = tp_01_all
frame['tp_02'] = tp_02_all
frame['fp_all'] = fp_all
frame['threshhold'] = threshhold
if version_comp == 'develop':
frame.to_csv(save_name_roc)
# threshhold
if run == 0:
color = 'black'
lw = 1
else:
color = 'grey'
lw = 0.5
for f_nr, female in enumerate(females):
if female == 'w_female':
ax1 = plt.subplot(grid_here1[0])
color_e = 'lightgrey'
roc_wo_female(color, ax1, tp_02_all, tp_012_all, color_02, color_012,
title_color='black', color_e=color_e) # colors_w[ff]
plt.fill_between(tp_02_all,
tp_02_all,
tp_012_all,
color=colors_w[ff], alpha=0.8)
ax1.set_title('') #: 0
ax1.set_xlabel('False-Positive Rate ') #: 0
ax1.set_ylabel('Correct-Detection Rate ') # 01
elif female == 'wo_female':
ax1 = plt.subplot(grid_here1[0])
color_e = None
roc_female(ax1, color, fp_all, tp_01_all, lw, 'black', 'black',
title_color='black', color_e=color_e) # colors_wo[ff]
plt.fill_between(fp_all,
fp_all,
tp_01_all,
color=colors_wo[ff], alpha=0.8)
ax1.set_xlabel('False-Positive Rate ') #: 0
ax1.set_ylabel('Correct-Detection Rate ') # 01
ax1.set_title('') #: 0
else:
ax1 = plt.subplot(grid_here1[0])
roc_wo_female(color, ax1, tp_02_all, tp_012_all, 'black', 'black',
title_color='black', color_e=color_e) # colors_w[ff]
plt.fill_between(tp_02_all,
tp_02_all,
tp_012_all,
color=colors_w[ff], alpha=0.8)
roc_female(ax1, color, fp_all, tp_01_all, lw, 'black', 'black',
title_color='black') # colors_wo[ff]
plt.fill_between(fp_all,
fp_all,
tp_01_all,
color=colors_wo[ff], alpha=1)
ax1.set_xlabel('False-Positive Rate ') #: 0
ax1.set_ylabel('Correct-Detection Rate ') # 01
ax1.set_title('') #: 0
################################################
# part with the ROC declining
ax0 = plt.subplot(grid_here2[0])
distance_cm = np.arange(0, 200, 0.2)
xlim_dist = core_xlim_dist_roc()
distances_mv = c_to_dist(distance_cm, convert='dist_to_contrast')
ax0.plot(distance_cm, distances_mv, label='cubed', color='black')
ax0.set_xlim(xlim_dist)
ax0.set_ylabel('EOD Amplitude\n [mV]')
ax0.set_yticks([])
test = False
if test:
from utils_test import test_distances
test_distances()
ax0.set_yscale('log')
ax0.set_ylim(0, 2)
ax1 = plt.subplot(grid_here2[1])
for c, cell in enumerate(cells_chosen):
for f, frame_name in enumerate(frame_names):
path = load_folder_name('calc_ROC') + '/' + frame_name + '.csv'
if os.path.exists(path):
frame = pd.read_csv(path)
path_ref = load_folder_name(
'calc_ROC') + '/' + 'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_C1_0.02_len_25_nfft_32768_trialsnr_20_mult_minimum_1temporal.csv'
frame_ref = pd.read_csv(path_ref)
_, _ = find_row_col(cells, row=4)
frame_cell = frame[frame.cell == cell]
axs = [ax1]
label_f = ['with female', 'CLS: 100n with female',
'LS: 1000n with female', ]
label_f_wo = ['without female', 'CLS: 100n without female',
'LS: 1000n without female', ]
labels_w = [label_f[f], label_f[f]]
labels_wo = [label_f_wo[f], label_f_wo[f]]
for a, ax in enumerate(axs):
if len(frame_cell) > 0:
c1 = c_dist_recalc_func(eod_size_change=True, mult_eod=0.5,
frame_cell=frame_cell,
c_nrs=frame_cell.c1, cell=cell,
c_dist_recalc=True)
lw = lw_roc()
s = 15 # 100
if female == 'w_female':
ax.plot(c1, frame_cell['auci_02_012'], color=colors_w[f],
label=labels_w[f],
clip_on=True, linewidth=lw)
elif female == 'wo_female':
ax.plot(c1, frame_cell['auci_base_01'], color=colors_wo[f],
label=labels_wo[f],
clip_on=True, linewidth=lw) # , linestyle='--'
else:
plt_area_between(frame_cell, ax, ax, colors_w, colors_wo, f,
labels_with_female=labels_w[a],
labels_without_female='', arrow=True)
ax.set_xlim(xlim_dist)
ax.set_ylim(0, 0.52)
ax.set_yticks_delta(0.1)
pos = np.argmin(np.abs(frame_cell.c1 - a_f1))
if f == 0:
c1 = c_dist_recalc_func(frame_cell=frame_cell,
c_nrs=[frame_cell.c1.iloc[pos]],
cell=cell, c_dist_recalc=True)
if female == 'wo_female':
ax.scatter(c1, frame_cell['auci_base_01'].iloc[pos],
clip_on=True, color=colors_wo[0],
s=s) # , facecolor = 'none'
elif female == 'w_female':
ax.scatter(c1, frame_cell['auci_02_012'].iloc[pos],
clip_on=True, color=colors_w[0],
s=s) # ,facecolor='none'
else:
ax.scatter(c1, frame_cell['auci_base_01'].iloc[pos],
clip_on=True, color=colors_wo[0],
s=s) # , facecolor = 'none'
ax.scatter(c1, frame_cell['auci_02_012'].iloc[pos],
clip_on=True, color=colors_w[0],
s=s) # , facecolor='none'
if a == 0:
ax.legend(loc=(0.6, 0.7)) # , fontsize = 8, handlelength = 0.5
else:
ax.legend(loc=(0.6, 0.6)) # , fontsize = 8, handlelength = 0.5
ax.set_ylim(0, 0.52)
if c != 0:
remove_yticks(ax)
ax.show_spines('lb')
ax1.set_ylabel(core_auc_label()) #
ax1.set_xlabel('mV/cm')
ax1.set_xlabel(core_distance_label())
ax = plt.gcf().axes
if f_nr == 0:
fig.tag(ax[0:3], xoffs=-6.5, yoffs=1.5, ) # 0.7
plt.subplots_adjust(left=0.03, wspace=0.3)
save_visualization(frame_name, False, show_anything=False, pdf=True,
jpg=True, png=False,
counter_contrast=0, savename='', add='_' + female)
plt.show()
def core_distance_label():
    """Return the x-axis label used for the intruder-distance axis of ROC plots."""
    distance_label = 'Intruder Distance [cm]'
    return distance_label
def core_auc_label():
    """Return the y-axis label used for AUC values in ROC-decline plots."""
    auc_label = 'AUC'  # previously labelled 'Determinant'
    return auc_label
def core_xlim_dist_roc():
    """Return the x-limits ``[0, 225]`` of the distance axis in ROC-decline plots.

    A new list is built on every call, so callers may modify the result.
    """
    lower, upper = 0, 225
    return [lower, upper]
def lw_roc():
    """Return the line width used for the ROC-decline curves."""
    return 0.75  # a thicker alternative that was used before: 2
def plt_several_ROC_declining_one_with_ROC_single_in_one(bt=0.12, lf=0.07, females=[], color_01='green', color_02='red',
                                                         color_012='orange', frame_names=[
            'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_20_mult_minimum_1temporal',
            'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_100_mult_minimum_1temporal',
            'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_1000_mult_minimum_1temporal'],
                                                         reshuffled='reshuffled', datapoints=1000, dev=0.0005,
                                                         a_f1s=[0.03], printing=False, plus_q='minus', way='absolut',
                                                         stimulus_length=0.5, runs=3, trials_nr=500, nfft=int(2 ** 15),
                                                         beat='', nfft_for_morph=4096 * 4, gain=1, fish_jammer='Alepto',
                                                         us_name='', wr=[1.6, 2]):
    """Plot example ROC curves of a model cell next to the decline of the ROC measure with distance.

    Left panel (``grid_here1``): ROC curves from the simulated baseline and
    stimulus spike trains. They are either computed with ``calc_auci_pd`` or
    read back from the cache file ``decline_ROC_examples_trial_nr.csv``. There
    is one curve per run: the first run is drawn in black, later runs in grey.

    Right panels (``grid_here2``):
      - top: EOD amplitude versus distance, from ``c_to_dist_reverse``;
      - bottom: the ``auci_02_012`` / ``auci_base_01`` columns of the
        precomputed ``calc_ROC`` CSV files in ``frame_names``, plotted versus
        distance.

    Parameters
    ----------
    bt, lf : float
        Bottom and left border of the outer GridSpec.
    females : list of str
        Each entry selects what is drawn. ``'w_female'`` draws the
        with-female curve, ``'wo_female'`` the without-female curve, and
        anything else draws both. The distance panel and the output file name
        use the last entry.
    color_01, color_02, color_012 : str
        Colours of the individual ROC conditions.
    frame_names : list of str
        Names (without ``.csv``) of ``calc_ROC`` files. The two frequency
        ratios are parsed from the first name (``FrF1rel_``/``FrF2rel_``).
    reshuffled, dev : passed on to the spike-array helpers.
    datapoints, way : passed on to ``calc_auci_pd``.
    a_f1s : list of float
        Emitter contrasts. The marker on the distance curve is placed at the
        contrast closest to each value.
    plus_q : str
        ``'minus'`` places both stimulus frequencies below the receiver EODf.
    stimulus_length, runs, trials_nr, nfft, beat, nfft_for_morph, gain,
    fish_jammer, us_name : simulation settings passed to the model helpers.
    wr : list of float
        Width ratios of the two outer columns.

    Raises
    ------
    ValueError
        If ``females`` is empty.
    """
    if not females:
        # the distance panel and the saved file name depend on the last entry of
        # ``females``; without one the function would only fail (NameError)
        # after the complete simulation had been run
        raise ValueError('females must contain at least one entry')
    # the frequency ratios of the two stimuli are encoded in the frame names
    freq1_ratio = float(frame_names[0].split('FrF1rel_')[1].split('_FrF2rel')[0])
    freq2_ratio = float(frame_names[0].split('FrF2rel_')[1].split('_C2')[0])
    cells = [
        "2013-01-08-aa-invivo-1"]  # , "2012-12-13-an-invivo-1", "2012-06-27-an-invivo-1", "2012-12-21-ai-invivo-1","2012-06-27-ah-invivo-1", ]
    cells_chosen = [
        '2013-01-08-aa-invivo-1']  # , "2012-06-27-ah-invivo-1","2014-06-06-ac-invivo-1" ]#'2012-06-27-an-invivo-1',
    plt.rcParams['lines.linewidth'] = 1
    model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv")
    grid_here = gridspec.GridSpec(1, 2, hspace=0.6, wspace=0.16, left=lf, top=0.92, bottom=bt,
                                  right=0.98, width_ratios=wr)  # 1.3,1
    grid_here1 = gridspec.GridSpecFromSubplotSpec(1, 1, wspace=0.3, hspace=0.75,
                                                  subplot_spec=grid_here[0])
    grid_here2 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.3, hspace=0.3,
                                                  subplot_spec=grid_here[1], height_ratios=[1.5, 3])
    if len(cells) < 1:
        # fall back to all model cells (``len(model_cells)`` would be an int and
        # could not be iterated below)
        cells = list(model_cells.cell)
    # NOTE(review): colors_w / colors_wo hold a single colour but are indexed with
    # the frame index ``f`` in the distance panel; this only works as long as
    # at most the first frame file exists - TODO confirm.
    colors_wo = [color_01]  # ['limegreen', 'green', 'darkgreen']
    colors_w = [color_012]  # ['orange', 'darkorange','goldenrod']
    for cell_here in cells:
        # parameters to vary
        ###########################################
        single_waves = ['_SeveralWave_']  # , '_SingleWave_']
        ####### VARY HERE
        for single_wave in single_waves:
            if single_wave == '_SingleWave_':
                a_f2s = [0]  # , 0,0.2
            else:
                a_f2s = [0.1]
            for a_f2 in a_f2s:
                for a_f1 in a_f1s:
                    a_frs = [1]
                    titles_amp = ['base eodf']  # ,'baseline to Zero',]
                    for a, a_fr in enumerate(a_frs):
                        model_params = model_cells[model_cells['cell'] == cell_here].iloc[0]
                        eod_fr = model_params['EODf']  # .iloc[0]
                        offset = model_params.pop('v_offset')
                        cell = model_params.pop('cell')
                        print(cell)
                        SAM, adapt_offset, cell_recording, constant_reduction, damping, damping_type, dent_tau_change, exponential, f1, f2, fish_emitter, fish_receiver, fish_morph_harmonics_var, lower_tol, mimick, n, phase_right, phaseshift_fr, sampling_factor, upper_tol, zeros = default_model0()
                        # in case you want a different sampling here we can adjust
                        time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr,
                                                                     stimulus_length)
                        # generate the eod_fish_r in the four mimick variants (copy, thunderfish, mimick, just sinus)
                        eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr,
                                                                                       stimulus_length, phaseshift_fr,
                                                                                       cell_recording, zeros, mimick,
                                                                                       sampling, fish_receiver, deltat,
                                                                                       nfft, nfft_for_morph,
                                                                                       fish_morph_harmonics_var=fish_morph_harmonics_var,
                                                                                       beat=beat)
                        sampling = 1 / deltat
                        variant = 'sinz'
                        # prepare for adapting offset due to baseline modification
                        _, _ = prepare_baseline_array(time_array, eod_fr,
                                                      nfft_for_morph,
                                                      phaseshift_fr,
                                                      mimick, zeros,
                                                      cell_recording,
                                                      sampling,
                                                      stimulus_length,
                                                      fish_receiver,
                                                      deltat, nfft,
                                                      damping_type,
                                                      damping, us_name,
                                                      gain, beat=beat,
                                                      fish_morph_harmonics_var=fish_morph_harmonics_var)
                        save_name_roc = 'decline_ROC_examples_trial_nr.csv'
                        version_comp, subfolder, mod_name_slash, mod_name, subfolder_path = find_code_vs_not()
                        for run in range(runs):
                            print(run)
                            t1 = time.time()
                            if (os.path.exists(save_name_roc)) | (version_comp == 'public'):
                                # the ROC curves are read from the cache file further below,
                                # so the baseline is simulated with a single trial only
                                trials_nr_base = 1
                                stimulus_length = 1
                                model_params = model_cells[model_cells['cell'] == cell_here].iloc[0]
                                eod_fr = model_params['EODf']  # .iloc[0]
                                offset = model_params.pop('v_offset')
                                cell = model_params.pop('cell')
                                time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr,
                                                                             stimulus_length)
                                eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr,
                                                                                               stimulus_length,
                                                                                               phaseshift_fr,
                                                                                               cell_recording, zeros,
                                                                                               mimick, sampling,
                                                                                               fish_receiver, deltat,
                                                                                               nfft, nfft_for_morph,
                                                                                               fish_morph_harmonics_var=fish_morph_harmonics_var,
                                                                                               beat=beat)
                            else:
                                trials_nr_base = trials_nr
                            spikes_base = [[]] * trials_nr_base
                            for t in range(trials_nr_base):
                                # get the baseline properties here
                                stimulus = eod_fish_r
                                if 'Zero' in titles_amp[a]:
                                    power_here = 'sinz' + '_' + zeros
                                else:
                                    power_here = 'sinz'
                                # NOTE(review): ``alpha`` is not defined in this function; at module
                                # level it is scipy.stats.alpha unless a star import rebinds it -
                                # confirm this is the intended ``power_alpha``.
                                cvs, adapt_output, baseline_after_b, _, rate_adapted_b, rate_baseline_before_b, rate_baseline_after_b, \
                                spikes_base[t], _, _, offset_new, _, noise_final = simulate(cell, offset, stimulus,
                                                                                           deltat=deltat,
                                                                                           power_variant=power_here,
                                                                                           power_alpha=alpha,
                                                                                           power_nr=n, **model_params)
                                if t == 0:
                                    # here we record the changes in the offset due to the adaptation
                                    # and we subsequently reset the offset to be the new adapted for all subsequent trials
                                    offset = offset_new * 1
                            if printing:
                                print('Baseline time' + str(time.time() - t1))
                            base_cut, mat_base = find_base_fr(spikes_base, deltat, stimulus_length, time_array, dev=dev)
                            fr = np.mean(base_cut)
                            # stimulus frequencies are placed at fractions of the baseline rate
                            two_third_fr = fr * freq2_ratio
                            third_fr = fr * freq1_ratio
                            if plus_q == 'minus':
                                two_third_fr = -two_third_fr
                                third_fr = -third_fr
                            freqs2 = [eod_fr + two_third_fr]  # , eod_fr - third_fr, two_third_fr,
                            freqs1 = [
                                eod_fr + third_fr]  # , eod_fr - two_third_fr, third_fr,two_third_fr,third_eodf, eod_fr - third_eodf,two_third_eodf, eod_fr - two_third_eodf, ]
                            base_cut, mat_base, smoothed0, mat0 = find_base_fr2(spikes_base, deltat,
                                                                                stimulus_length, dev=dev)
                            _, _ = ISI_frequency(time_array, spikes_base[0], fill=0.0)
                            for ff, freq1 in enumerate(freqs1):
                                if (os.path.exists(save_name_roc)) | (version_comp == 'public'):
                                    # reuse the cached ROC curves
                                    frame = pd.read_csv(save_name_roc)
                                    tp_012_all = frame['tp_012']
                                    tp_01_all = frame['tp_01']
                                    tp_02_all = frame['tp_02']
                                    fp_all = frame['fp_all']
                                else:
                                    freq1 = [freq1]
                                    freq2 = [freqs2[ff]]
                                    t1 = time.time()
                                    phaseshift_f1, phaseshift_f2 = get_phaseshifts(a_f1, a_f2, phase_right,
                                                                                   phaseshift_fr)
                                    eod_fish1, time_fish_e = eod_fish_e_generation(time_array, a_f1, freq1, f1,
                                                                                   phaseshift_f1, sampling,
                                                                                   stimulus_length, nfft_for_morph,
                                                                                   cell_recording,
                                                                                   fish_morph_harmonics_var, zeros,
                                                                                   mimick, fish_emitter,
                                                                                   thistype='emitter')
                                    eod_fish2, time_fish_j = eod_fish_e_generation(time_array, a_f2, freq2, f2,
                                                                                   phaseshift_f2, sampling,
                                                                                   stimulus_length, nfft_for_morph,
                                                                                   cell_recording,
                                                                                   fish_morph_harmonics_var, zeros,
                                                                                   mimick, fish_jammer,
                                                                                   thistype='jammer')
                                    eod_stimulus = eod_fish1 + eod_fish2
                                    v_mems, offset_new, mat01, mat02, mat012, smoothed01, smoothed02, smoothed012, stimulus_01, stimulus_02, stimulus_012, mat05_01, spikes_01, mat05_02, spikes_02, mat05_012, spikes_012 = get_arrays_for_three(
                                        cell, a_f2, a_f1, SAM, eod_stimulus, eod_fish_r, freq2, eod_fish1, eod_fish2,
                                        stimulus_length, offset, model_params, n, variant, adapt_offset, deltat,
                                        f2, trials_nr, time_array, f1, freq1, eod_fr, reshuffle=reshuffled, dev=dev)
                                    if printing:
                                        print('Generation process' + str(time.time() - t1))
                                    array0 = [mat_base]
                                    array01 = [mat05_01]
                                    array02 = [mat05_02]
                                    array012 = [mat05_012]
                                    t_off = 10
                                    position_diff = 0
                                    results_diff = pd.DataFrame()
                                    results_diff['f1'] = freq1
                                    results_diff['f2'] = freq2
                                    results_diff['f0'] = eod_fr
                                    trials, results_diff, tp_012_all, tp_01_all, tp_02_all, fp_all, roc_01, roc_0, roc_02, roc_012, threshhold = calc_auci_pd(
                                        results_diff, position_diff, array012, array01, array02, array0, t_off=t_off,
                                        way=way, printing=True, datapoints=datapoints, f0='f0', sampling=sampling)
                                    frame = pd.DataFrame()
                                    frame['tp_012'] = tp_012_all
                                    frame['tp_01'] = tp_01_all
                                    frame['tp_02'] = tp_02_all
                                    frame['fp_all'] = fp_all
                                    frame['threshhold'] = threshhold
                                    # only the development version writes the cache file
                                    if version_comp == 'develop':
                                        frame.to_csv(save_name_roc)
                                # first run is highlighted, the others are drawn thin and grey
                                if run == 0:
                                    color = 'black'
                                    lw = 1
                                else:
                                    color = 'grey'
                                    lw = 0.5
                                for female in females:
                                    if female == 'w_female':
                                        ax1 = plt.subplot(grid_here1[0])
                                        roc_wo_female(color, ax1, tp_02_all, tp_012_all, color_02, color_012,
                                                      title_color='black')  # colors_w[ff]
                                        plt.fill_between(tp_02_all,
                                                         tp_02_all,
                                                         tp_012_all,
                                                         color=colors_w[ff], alpha=0.8)
                                        ax1.set_title('')  #: 0
                                        ax1.set_xlabel('False-Positive Rate ')  #: 0
                                        ax1.set_ylabel('Correct-Detection Rate ')  # 01
                                    elif female == 'wo_female':
                                        ax1 = plt.subplot(grid_here1[0])
                                        roc_female(ax1, color, fp_all, tp_01_all, lw, 'black', 'black',
                                                   title_color='black')  # colors_wo[ff]
                                        plt.fill_between(fp_all,
                                                         fp_all,
                                                         tp_01_all,
                                                         color=colors_wo[ff], alpha=0.8)
                                        ax1.set_xlabel('False-Positive Rate ')  #: 0
                                        ax1.set_ylabel('Correct-Detection Rate ')  # 01
                                        ax1.set_title('')  #: 0
                                    else:
                                        # draw both the with- and the without-female ROC
                                        ax1 = plt.subplot(grid_here1[0])
                                        roc_wo_female(color, ax1, tp_02_all, tp_012_all, 'black', 'black',
                                                      title_color='black')  # colors_w[ff]
                                        plt.fill_between(tp_02_all,
                                                         tp_02_all,
                                                         tp_012_all,
                                                         color=colors_w[ff], alpha=0.8)
                                        roc_female(ax1, color, fp_all, tp_01_all, lw, 'black', 'black',
                                                   title_color='black')  # colors_wo[ff]
                                        plt.fill_between(fp_all,
                                                         fp_all,
                                                         tp_01_all,
                                                         color=colors_wo[ff], alpha=1)
                                        ax1.set_xlabel('False-Positive Rate ')  #: 0
                                        ax1.set_ylabel('Correct-Detection Rate ')  # 01
                                        ax1.set_title('')  #: 0
                        ################################################
                        # part with the ROC declining
                        # NOTE(review): this part uses ``female`` from the last iteration of the
                        # loop over ``females`` above.
                        ax0 = plt.subplot(grid_here2[0])
                        distance = np.arange(0, 200, 0.2)
                        xlim_dist = [0, 70]
                        distances_mv = c_to_dist_reverse(distance)  # distance_changed*factor
                        ax0.plot(distance, distances_mv, label='cubed', color='black')
                        ax0.set_xlim(xlim_dist)
                        ax0.set_ylabel('EOD Amplitude')
                        ax0.set_yticks([])
                        ax0.set_yscale('log')
                        # switching to a log scale adds ticks again, so remove them once more
                        ax0.set_yticks([])
                        ax1 = plt.subplot(grid_here2[1])
                        for c, cell in enumerate(cells_chosen):
                            for f, frame_name in enumerate(frame_names):
                                path = load_folder_name('calc_ROC') + '/' + frame_name + '.csv'
                                if os.path.exists(path):
                                    frame = pd.read_csv(path)
                                    _, _ = find_row_col(cells, row=4)
                                    frame_cell = frame[frame.cell == cell]
                                    axs = [ax1]
                                    label_f = ['with female', 'CLS: 100n with female',
                                               'LS: 1000n with female', ]
                                    label_f_wo = ['without female', 'CLS: 100n without female',
                                                  'LS: 1000n without female', ]
                                    labels_w = [label_f[f], label_f[f]]
                                    labels_wo = [label_f_wo[f], label_f_wo[f]]
                                    # ``ax_nr`` instead of ``a`` so the outer loop variable is not rebound
                                    for ax_nr, ax in enumerate(axs):
                                        if len(frame_cell) > 0:
                                            c1 = c_dist_recalc_func(frame_cell=frame_cell,
                                                                    c_nrs=frame_cell.c1, cell=cell,
                                                                    c_dist_recalc=True)
                                            if female == 'w_female':
                                                ax.plot(c1, frame_cell['auci_02_012'], color=colors_w[f],
                                                        label=labels_w[f],
                                                        clip_on=True, linewidth=2)
                                            elif female == 'wo_female':
                                                ax.plot(c1, frame_cell['auci_base_01'], color=colors_wo[f],
                                                        label=labels_wo[f],
                                                        clip_on=True, linewidth=2)  # , linestyle='--'
                                            else:
                                                plt_area_between(frame_cell, ax, ax, colors_w, colors_wo, f,
                                                                 labels_with_female=labels_w[ax_nr], talk=False,
                                                                 labels_without_female='', arrow=True)
                                            ax.set_xlim(xlim_dist)
                                            ax.set_ylim(0, 0.52)
                                            ax.set_yticks_delta(0.1)
                                            # contrast closest to the emitter contrast used for the ROC example
                                            pos = np.argmin(np.abs(frame_cell.c1 - a_f1))
                                            if f == 0:
                                                # mark only that single contrast: x and y must both be one point
                                                c1 = c_dist_recalc_func(frame_cell=frame_cell,
                                                                        c_nrs=[frame_cell.c1.iloc[pos]],
                                                                        cell=cell, c_dist_recalc=True)
                                                s = 100
                                                if female == 'wo_female':
                                                    ax.scatter(c1, frame_cell['auci_base_01'].iloc[pos],
                                                               clip_on=True, color=colors_wo[0],
                                                               s=s)  # , facecolor = 'none'
                                                elif female == 'w_female':
                                                    ax.scatter(c1, frame_cell['auci_02_012'].iloc[pos],
                                                               clip_on=True, color=colors_w[0],
                                                               s=s)  # ,facecolor='none'
                                                else:
                                                    ax.scatter(c1, frame_cell['auci_base_01'].iloc[pos],
                                                               clip_on=True, color=colors_wo[0],
                                                               s=s)  # , facecolor = 'none'
                                                    ax.scatter(c1, frame_cell['auci_02_012'].iloc[pos],
                                                               clip_on=True, color=colors_w[0],
                                                               s=s)  # , facecolor='none'
                                            if ax_nr == 0:
                                                ax.legend(loc=(0.6, 0.7))  # , fontsize = 8, handlelength = 0.5
                                            else:
                                                ax.legend(loc=(0.6, 0.6))  # , fontsize = 8, handlelength = 0.5
                                            ax.set_ylim(0, 0.52)
                                            if c != 0:
                                                remove_yticks(ax)
                                            ax.show_spines('lb')
                        ax1.set_ylabel('Determinant')
                        ax1.set_xlabel('Distance [cm]')
                        plt.subplots_adjust(left=0.03, wspace=0.3)
                        save_visualization(frame_name, False, show_anything=False, pdf=True,
                                           jpg=True, png=False,
                                           counter_contrast=0, savename='', add='_' + female)
                        plt.show()
def plt_several_ROC_square_nonlin(brust_corrs=['_burstIndividual_'], nffts=['whole'], powers=[1],
                                  contrasts=[0], column=None, noises_added=[''],
                                  D_extraction_method=['additiv_visual_d_4_scaled'],
                                  internal_noise=['eRAM'], external_noise=['eRAM'],
                                  level_extraction=['_RAMdadjusted'], repeats=[1000000],
                                  receiver_contrast=[1], dendrids=[''], ref_types=[''], adapt_types=[''],
                                  c_noises=[0.1], c_signal=[0.9], cut_offs1=[300], label=r'$\frac{1}{mV^2S}$'):
    """Plot the square schematic (left) next to a model RAM matrix (right).

    The left panel is drawn by ``square_part``. The right panel loads a
    pickled model result from the ``calc_model`` folder, keeps the rows of the
    current cell, normalises it with ``RAM_norm`` and draws it with
    ``plt_RAM_perc``. A vertical colour bar labelled ``label`` is attached.

    The list arguments are combined with ``itertools.product``, giving one
    pass per combination. All passes draw into the same right-hand axes.

    NOTE(review): the pickle file name is hard-coded (``trial_nr = 250000``).
    Of the loop parameters, only ``cell``, ``trials_stim`` and ``c_sig`` change
    what is computed; the others are only printed.
    """
    plot_style()
    default_settings(column=column, width=12)  # ts=12, ls=13, fs=11,
    cells = [
        "2013-01-08-aa-invivo-1"]  # , "2012-12-13-an-invivo-1", "2012-06-27-an-invivo-1", "2012-12-21-ai-invivo-1","2012-06-27-ah-invivo-1", ]
    grid = gridspec.GridSpec(1, 2, wspace=0.35, hspace=0.5, left=0.06, top=0.8, bottom=0.15,
                             right=0.96)  # , width_ratios = [1,1,1,0.5,1] height_ratios = [1,6]bottom=0.25, top=0.8,
    ###################################
    # plot square
    ax = plt.subplot(grid[0])
    square_part(ax)
    ax.set_aspect('equal')
    ####################################
    # plot nonlin
    ax = plt.subplot(grid[1])
    trials_nrs = [1]
    iternames = [brust_corrs, cells, D_extraction_method, external_noise,
                 repeats, internal_noise, powers, nffts, dendrids, cut_offs1, trials_nrs, c_signal,
                 c_noises,
                 ref_types, adapt_types, noises_added, level_extraction, receiver_contrast, contrasts, ]
    # ``combination`` instead of ``all`` so the builtin all() is not shadowed
    for combination in it.product(*iternames):
        burst_corr, cell, var_type, stim_type_afe, trials_stim, stim_type_noise, power, nfft, dendrid, cut_off1, trial_nrs, c_sig, c_noise, ref_type, adapt_type, noise_added, extract, a_fr, a_fe = combination
        print(trials_stim, stim_type_noise, power, nfft, a_fe, a_fr, dendrid, var_type, cut_off1, trial_nrs)
        nr = '2'
        trial_nr = 250000
        save_name = load_folder_name(
            'calc_model') + '/' + 'calc_RAM_model-2__nfft_whole_power_1_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_' + str(
            trial_nr) + '_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV'
        path = save_name + '.pkl'  # '../'+
        model = pd.read_pickle(path)  # local, project-generated file
        model_show = model[(
                model.cell == cell)]
        new_keys = model_show.index.unique()  # [0:490]
        stack_plot = model_show[new_keys]  # [list(map(str, new_keys))]
        stack_plot = np.abs(stack_plot.iloc[np.arange(0, len(new_keys), 1)])
        ax.set_xlim(0, 237)
        ax.set_ylim(0, 237)
        ax.set_aspect('equal')
        model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv")
        model_params = model_cells[model_cells['cell'] == cell]
        noise_strength = model_params.noise_strength.iloc[0]  # **2/2
        D = noise_strength  # (noise_strength ** 2) / 2
        _, _, _ = D_derive(model_show, save_name, c_sig, D=D, base='', nr=nr)  # var_based
        stack_plot = RAM_norm(stack_plot, trials_stim=trials_stim, model_show=model_show)
        perc = '10'  # 'perc'
        im = plt_RAM_perc(ax, perc, stack_plot)
        ax.set_aspect('equal')
        cbar = plt.colorbar(im, ax=ax, orientation='vertical')  # pad=0.2, shrink=0.5, "horizontal"
        cbar.set_label(label, labelpad=100)  # rotation=270,
        ax.set_xlabel(F1_xlabel(), labelpad=20)
        ax.set_ylabel(F2_xlabel())
    save_visualization(jpg=True, png=False)
    plt.show()
def plt_several_ROC_declining_classified_small():
    """Plot the ROC decline of three example cells plus area summaries over all cells.

    Each example cell gets one column of five rows. The first three rows show
    one ``calc_ROC`` frame each, row 3 shows all frames overlaid, and row 4
    shows the ``amp_B1+B2_..._mean`` column on a log x-axis. Cells are
    coloured by their rank in ``cv_0`` of the reference frame.

    The right column has three scatter plots built from ``calc_areas``:
    detection-improvement area vs CV, nonlinearity area vs CV, and
    detection-improvement area vs nonlinearity area.

    Raises
    ------
    FileNotFoundError
        If none of the ``calc_ROC`` frames exists, so no reference frame was
        loaded.
    """
    frame_names = [
        'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_20_mult_minimum_1temporal',
        'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_100_mult_minimum_1temporal',
        'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_1000_mult_minimum_1temporal']
    cm = plt.get_cmap("hsv")
    cells_chosen = ['2013-01-08-aa-invivo-1', "2012-06-27-ah-invivo-1",
                    "2014-06-06-ac-invivo-1"]  # '2012-06-27-an-invivo-1',
    cells = ["2013-01-08-aa-invivo-1", "2012-12-13-an-invivo-1", "2012-06-27-an-invivo-1", "2012-12-21-ai-invivo-1",
             "2012-06-27-ah-invivo-1", ]
    x_pos = 0.02  # contrast marked by the dashed vertical lines
    grid = gridspec.GridSpec(1, 5, wspace=0.2, hspace=0.5, left=0.1, top=0.8, bottom=0.15,
                             right=0.95, width_ratios=[1, 1, 1, 0.5, 1])  # height_ratios = [1,6]bottom=0.25, top=0.8,
    grid1 = gridspec.GridSpecFromSubplotSpec(3, 1, wspace=0.3, hspace=0.75,
                                             subplot_spec=grid[-1])
    # only set once a frame file exists; needed for the summary panels below
    frame_ref = None
    for c, cell in enumerate(cells_chosen):
        grid0 = gridspec.GridSpecFromSubplotSpec(5, 1, wspace=0.2, hspace=0.35,
                                                 subplot_spec=grid[c])  # height_ratios=[1, 0.7, 1, 1],
        for f, frame_name in enumerate(frame_names):
            path = load_folder_name('calc_ROC') + '/' + frame_name + '.csv'
            if os.path.exists(path):
                frame = pd.read_csv(path)
                title = cut_title(frame_name, datapoints=100)
                plt.suptitle(title)
                path_ref = load_folder_name(
                    'calc_ROC') + '/calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_C1_0.02_len_25_nfft_32768_trialsnr_20_mult_minimum_1temporal.csv'
                frame_ref = pd.read_csv(path_ref)
                # colour each cell by its rank in the CV of the reference frame
                frame_ref = frame_ref.sort_values(by='cv_0')
                colr = [cm(float(i) / (len(frame_ref))) for i in range(len(frame_ref))]
                cells_sorted = frame_ref.cell.unique()
                _, _ = find_row_col(cells, row=4)
                frame_cell = frame[frame.cell == cell]
                ax0 = plt.subplot(grid0[f])
                ax1 = plt.subplot(grid0[3])
                ax2 = plt.subplot(grid0[4])
                axs = [ax0, ax1]
                colors = ['black', 'grey', 'lightgrey', ]
                for ax in axs:
                    if len(frame_cell) > 0:
                        # NOTE(review): other callers in this file pass ``frame_cell`` as the
                        # first argument and no ``frame_cell.c1``; check this call against
                        # the current signature of plt_area_between.
                        plt_area_between(frame_cell.c1, frame_cell, ax0, ax, colors, colors, f)
                        ax.axhline(0, linestyle='--', color='grey', linewidth=0.5)
                    col_pos = np.where(cells_sorted == cell)[0][0]
                    if f == 0:
                        ax0.set_title(cell[0:13] + '\n cv ' + str(np.round(np.mean(frame_cell.cv_0.unique()), 2)),
                                      color=colr[col_pos], fontsize=8)
                    ax.set_ylim(0, 0.5)
                    if c != 0:
                        remove_yticks(ax)
                        remove_yticks(ax2)
                    else:
                        ax2.set_ylabel('B1+B2')
                    if f == 1:
                        ax0.set_ylabel('Determinant')
                    remove_xticks(ax0)
                    remove_xticks(ax1)
                    ax.axvline(x_pos, color='grey', linestyle='--', linewidth=0.5)
                ax2.plot(frame_cell.c1, frame_cell['amp_B1+B2_012-01-02+0_norm_01B1+02B2_mean'], color=colors[f])
                ax2.set_xscale('log')
                ax2.axvline(x_pos, color='grey', linestyle='--', linewidth=0.5)
                ax2.set_xlabel('mV/cm')
                ax2.set_ylim(0, 0.5)
    if frame_ref is None:
        raise FileNotFoundError('none of the calc_ROC frames in frame_names exists')
    ######################################
    # plot the plot on the right upper part (Area vs CV)
    path = load_folder_name('calc_ROC') + '/' + frame_names[0] + '.csv'
    ax_scatter = plt.subplot(grid1[0])
    ax_scatter_nonlin_sole = plt.subplot(grid1[1])
    ax_scatter_nonlin = plt.subplot(grid1[2])
    cvs, nonlin_area, diff_areas, areas_01_scatter, nonlin, areas_012_one = calc_areas(path, frame_ref, colr, x_pos,
                                                                                       cells_chosen)
    ax_scatter.scatter(cvs, diff_areas, color=colr, s=15, clip_on=False)
    ax_scatter.axhline(0, linestyle='--', linewidth=0.5, color='grey')
    ax_scatter.set_xlabel('CV')
    ax_scatter.set_ylabel('Area Detection improvement')
    ax_scatter_nonlin_sole.scatter(cvs, nonlin_area, color=colr, s=15, clip_on=False)
    ax_scatter_nonlin_sole.axhline(0, linestyle='--', linewidth=0.5, color='grey')
    ax_scatter_nonlin_sole.set_xlabel('CV')
    ax_scatter_nonlin_sole.set_ylabel('Area Nonlinearity (B1+B2)')
    ######################################
    # plot the plot on the right lower part (Area vs Nonlin at B1+B2)
    # the labels follow the data: x is the nonlinearity area, y the detection improvement
    ax_scatter_nonlin.set_xlabel('Area Nonlinearity (B1+B2)')
    ax_scatter_nonlin.set_ylabel('Area Detection improvement')
    ax_scatter_nonlin.scatter(nonlin_area, diff_areas, color=colr, s=15, clip_on=False)
    save_visualization(png=False)
    plt.show()
def plt_ROC_model_w_female2(redo=False, t_off=10, top=0.95, bottom=0.12, add_name='', color0='green', color01='blue',
                            color02='red',
                            color012='orange', female='wo_female', reshuffled='reshuffled',
                            datapoints=1000, dev=0.0005, a_f1s=[0.03], pdf=True, printing=False, plus_q='minus',
                            freq1_ratio=1 / 2, diagonal='diagonal', freq2_ratio=2 / 3, way='absolut',
                            stimulus_length=0.5, runs=3, trials_nr=500, cells=[], show=False, nfft=int(2 ** 15),
                            beat='',
                            nfft_for_morph=4096 * 4, fr=None, gain=1, fish_jammer='Alepto', us_name=''):
    """Simulate a model cell and plot spike-train traces next to their ROC curves.

    For each cell (default: every cell in ``models_big_fit_d_right.csv``), a
    baseline and the responses to one emitter, one jammer and both together
    are simulated ``runs`` times. The ROC curves from ``calc_auci_pd`` are
    drawn in the right column: black for the first run, grey for later runs.
    The traces of the first run are drawn on the left by
    ``plt_traces_to_roc``.

    If ``decline_ROC_examples_trial_nr.csv`` exists (or this is the public
    version) and ``redo`` is False, ``stimulus_length`` is set to 0.14 s. The
    plotted ROC curves are then taken from that file instead of the fresh
    simulation.

    Parameters (selection)
    ----------------------
    female : str
        ``'wo_female'`` plots ``tp_01``, ``'base_female'`` plots ``tp_02``,
        anything else plots ``tp_02`` vs ``tp_012`` via ``roc_wo_female``.
    diagonal : str
        If it contains ``'diagonal'``, ``freq1_ratio`` is replaced by
        ``1 - freq2_ratio``.
    fr : float or None
        Rate used to place the stimulus frequencies in the first run. After
        the first run it is replaced by the measured baseline rate.
    plus_q : str
        ``'minus'`` places both stimulus frequencies below the receiver EODf.

    NOTE(review): the ``color0`` argument is overwritten with ``'black'``
    before it is first used.
    """
    save_name_roc = 'decline_ROC_examples_trial_nr.csv'
    version_comp, subfolder, mod_name_slash, mod_name, subfolder_path = find_code_vs_not()
    # use the cached ROC curves unless a recomputation is requested
    cont_redo = ((os.path.exists(save_name_roc)) | (version_comp == 'public')) & (redo == False)
    if cont_redo:
        stimulus_length = 0.14
    plt.rcParams['lines.linewidth'] = 1
    model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv")
    if len(cells) < 1:
        cells = model_cells.cell  # )
    for cell_here in cells:
        # parameters to vary
        ###########################################
        single_waves = ['_SeveralWave_']  # , '_SingleWave_']
        ####### VARY HERE
        for single_wave in single_waves:
            if single_wave == '_SingleWave_':
                a_f2s = [0]  # , 0,0.2
            else:
                a_f2s = [0.1]
            for a_f2 in a_f2s:
                for a_f1 in a_f1s:
                    a_frs = [1]
                    titles_amp = ['base eodf']  # ,'baseline to Zero',]
                    for a, a_fr in enumerate(a_frs):
                        model_params = model_cells[model_cells['cell'] == cell_here].iloc[0]
                        eod_fr = model_params['EODf']  # .iloc[0]
                        offset = model_params.pop('v_offset')
                        cell = model_params.pop('cell')
                        print(cell)
                        SAM, adapt_offset, cell_recording, constant_reduction, damping, damping_type, dent_tau_change, exponential, f1, f2, fish_emitter, fish_receiver, fish_morph_harmonics_var, lower_tol, mimick, n, phase_right, phaseshift_fr, sampling_factor, upper_tol, zeros = default_model0()
                        # in case you want a different sampling here we can adjust
                        time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr,
                                                                     stimulus_length)
                        # generate the eod_fish_r in the four mimick variants (copy, thunderfish, mimick, just sinus)
                        eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr,
                                                                                       stimulus_length, )
                        sampling = 1 / deltat
                        variant = 'sinz'
                        spikes_base = [[]] * trials_nr
                        # NOTE(review): the successive default_figsize calls presumably override
                        # each other, so only the last size would matter - confirm.
                        default_figsize(width=cm_to_inch(29.21), length=cm_to_inch(12.43))
                        default_figsize(width=cm_to_inch(29.21), length=cm_to_inch(13.98))
                        default_figsize(width=cm_to_inch(31.89), length=cm_to_inch(15))
                        add_bottom, add_right = implement_fig_borders(bottom=1.59)
                        default_ticks_talks()
                        plt.rcParams['figure.facecolor'] = 'none'
                        fig = plt.figure()
                        grid = gridspec.GridSpec(1, 2, wspace=0.4, left=0.09, top=top, bottom=bottom + add_bottom,
                                                 right=1.02 - add_right, height_ratios=[1],
                                                 width_ratios=[4, 2.8])  # 1.3,1
                        grid0 = gridspec.GridSpecFromSubplotSpec(3, 2, wspace=0.18, hspace=0.1,
                                                                 subplot_spec=grid[0],
                                                                 height_ratios=[1, 0.6, 1])  # ,0.4,1.2
                        grid1 = gridspec.GridSpecFromSubplotSpec(1, 1,
                                                                 subplot_spec=grid[1])  # wspace=0.5, hspace=0.55,
                        for run in range(runs):
                            print(run)
                            t1 = time.time()
                            for t in range(trials_nr):
                                # get the baseline properties here
                                stimulus = eod_fish_r
                                stimulus_base = eod_fish_r
                                if 'Zero' in titles_amp[a]:
                                    power_here = 'sinz' + '_' + zeros
                                else:
                                    power_here = 'sinz'
                                cvs, adapt_output, baseline_after_b, _, rate_adapted_b, rate_baseline_before_b, rate_baseline_after_b, \
                                spikes_base[t], _, _, offset_new, _, noise_final = simulate(cell, offset, stimulus,
                                                                                           deltat=deltat,
                                                                                           power_variant=power_here,
                                                                                           power_nr=n, **model_params)
                                if t == 0:
                                    # here we record the changes in the offset due to the adaptation
                                    # and we subsequently reset the offset to be the new adapted for all subsequent trials
                                    offset = offset_new * 1
                            if printing:
                                print('Baseline time' + str(time.time() - t1))
                            base_cut, mat_base = find_base_fr(spikes_base, deltat, stimulus_length, time_array, dev=dev)
                            if not fr:
                                fr = np.mean(base_cut)
                            if 'diagonal' in diagonal:
                                # both frequencies add up to the baseline rate
                                two_third_fr = fr * freq2_ratio
                                freq1_ratio = (1 - freq2_ratio)
                                third_fr = fr * freq1_ratio
                            else:
                                two_third_fr = fr * freq2_ratio
                                third_fr = fr * freq1_ratio
                            if plus_q == 'minus':
                                two_third_fr = -two_third_fr
                                third_fr = -third_fr
                            freqs2 = [eod_fr + two_third_fr]  # , eod_fr - third_fr, two_third_fr,
                            freqs1 = [
                                eod_fr + third_fr]  # , eod_fr - two_third_fr, third_fr,two_third_fr,third_eodf, eod_fr - third_eodf,two_third_eodf, eod_fr - two_third_eodf, ]
                            sampling_rate = 1 / deltat
                            base_cut, mat_base, smoothed0, mat0 = find_base_fr2(spikes_base, deltat,
                                                                                stimulus_length, dev=dev)
                            # from here on the measured baseline rate is used (also by later runs)
                            fr = np.mean(base_cut)
                            _, _ = ISI_frequency(time_array, spikes_base[0], fill=0.0)
                            isi = np.diff(spikes_base[0])
                            cv0 = np.std(isi) / np.mean(isi)
                            for ff, freq1 in enumerate(freqs1):
                                freq1 = [freq1]
                                freq2 = [freqs2[ff]]
                                # index the one-element lists: a list minus a float is not defined
                                print(cell + ' f1' + str(freq1) + ' f2 ' + str(freq2) + ' f1' + str(
                                    freq1[0] - eod_fr) + ' f2 ' + str(freq2[0] - eod_fr))
                                t1 = time.time()
                                phaseshift_f1, phaseshift_f2 = get_phaseshifts(a_f1, a_f2, phase_right, phaseshift_fr)
                                eod_fish1, time_fish_e = eod_fish_e_generation(time_array, a_f1, freq1, f1,
                                                                               phaseshift_f1, sampling, stimulus_length,
                                                                               nfft_for_morph, cell_recording,
                                                                               fish_morph_harmonics_var, zeros, mimick,
                                                                               fish_emitter, thistype='emitter')
                                eod_fish2, time_fish_j = eod_fish_e_generation(time_array, a_f2, freq2, f2,
                                                                               phaseshift_f2, sampling, stimulus_length,
                                                                               nfft_for_morph, cell_recording,
                                                                               fish_morph_harmonics_var, zeros, mimick,
                                                                               fish_jammer, thistype='jammer')
                                eod_stimulus = eod_fish1 + eod_fish2
                                v_mems, offset_new, mat01, mat02, mat012, smoothed01, smoothed02, smoothed012, stimulus_01, stimulus_02, stimulus_012, mat05_01, spikes_01, mat05_02, spikes_02, mat05_012, spikes_012 = get_arrays_for_three(
                                    cell, a_f2, a_f1, SAM, eod_stimulus, eod_fish_r, freq2, eod_fish1, eod_fish2,
                                    stimulus_length, offset, model_params, n, variant, adapt_offset, deltat, f2,
                                    trials_nr, time_array, f1, freq1, eod_fr, reshuffle=reshuffled, dev=dev,
                                    redo_stim=False)
                                if printing:
                                    print('Generation process' + str(time.time() - t1))
                                array0 = [mat_base]
                                array01 = [mat05_01]
                                array02 = [mat05_02]
                                array012 = [mat05_012]
                                position_diff = 0
                                results_diff = pd.DataFrame()
                                results_diff['f1'] = freq1
                                results_diff['f2'] = freq2
                                results_diff['f0'] = eod_fr
                                trials, results_diff, tp_012_all, tp_01_all, tp_02_all, fp_all, roc_01, roc_0, roc_02, roc_012, threshhold = calc_auci_pd(
                                    results_diff, position_diff, array012, array01, array02, array0, t_off=t_off,
                                    way=way, printing=True, datapoints=datapoints, f0='f0', sampling=sampling)
                                # first run is highlighted and drawn on top of the others
                                if run == 0:
                                    color = 'black'
                                    lw = 1.5
                                    z = 2
                                else:
                                    color = 'grey'
                                    lw = 0.8
                                    z = 1
                                color0 = 'black'
                                if cont_redo:
                                    # replace the fresh ROC curves with the cached ones
                                    frame = pd.read_csv(save_name_roc)
                                    tp_012_all = frame['tp_012']
                                    tp_01_all = frame['tp_01']
                                    tp_02_all = frame['tp_02']
                                    fp_all = frame['fp_all']
                                if 'wo_female' in female:
                                    ax_roc_wof = plt.subplot(grid1[0])
                                    roc_female(ax_roc_wof, color, fp_all, tp_01_all, lw, color0, color01,
                                               title_color=color01, z=z)
                                elif 'base_female' in female:
                                    ax_roc_wof = plt.subplot(grid1[0])
                                    roc_female(ax_roc_wof, color, fp_all, tp_02_all, lw, color0, color02, z=z,
                                               add_01='\n Female', add_base=' Baseline')
                                    ax_roc_wof.set_title('Receiver Operating Characteristics (ROC)', pad=15)
                                else:
                                    ax_roc_wf = plt.subplot(grid1[0])
                                    roc_wo_female(color, ax_roc_wf, tp_02_all, tp_012_all, color02, color012,
                                                  title_color=color012, z=z)
                                if run == 0:
                                    plt_traces_to_roc(freq2_ratio, freq1_ratio, t_off, spikes_02, spikes_01, spikes_012,
                                                      spikes_base, mat_base, mat05_01,
                                                      mat05_012, mat05_02, color02, color012, a_f2, trials, sampling,
                                                      a_f1, fr, female, color01, color0, grid0, eod_fr,
                                                      freq2,
                                                      freq1, sampling_rate, stimulus_012, stimulus_02, stimulus_01,
                                                      stimulus_base, time_array, carrier=True)
                        ax = fig.axes
                        remove_axes_roc_traces(ax, add=1)
                        for aa, ax_here in enumerate(ax[2:5]):
                            ax_here.set_xticks([])
                        # share the y-axis within consecutive pairs of trace axes
                        for aa, ax_here in enumerate(ax[1::]):
                            if aa in np.arange(0, 100, 2):
                                pair = ax[1 + aa:1 + aa + 2]
                                shared_y = ax_here.get_shared_y_axes()
                                if hasattr(shared_y, 'join'):
                                    # Matplotlib < 3.8
                                    shared_y.join(*pair)
                                elif len(pair) == 2:
                                    # newer Matplotlib returns an immutable view without join()
                                    pair[1].sharey(pair[0])
                        plt.subplots_adjust(top=0.95, left=0.09, right=0.95, hspace=0.5, bottom=0.12, wspace=0.25)
                        individual_tag = '_way_' + str(way) + '_runs_' + str(runs) + '_trial_nr_' + str(
                            trials_nr) + '_stimulus_length_' + str(
                            stimulus_length) + cell + ' cv ' + str(cv0) + single_wave + '_a_f0_' + str(
                            a_fr) + '_a_f1_' + str(a_f1) + '_a_f2_' + str(a_f2) + '_trialsnr_' + str(trials_nr)
                        save_visualization(individual_tag, show=show, add=add_name, pdf=pdf, counter_contrast=0,
                                           savename='')
def remove_axes_roc_traces(ax, add=0):
    """Strip and label the trace panels of an ROC example figure.

    Parameters
    ----------
    ax : sequence of axes
        Indexed with ``add`` as offset; callers in this file pass ``fig.axes``
        with ``add=1``. ``show_spines`` is not a stock matplotlib method; it is
        presumably provided by the ``plotstyle`` module imported at file top.
    add : int
        Index of the first trace panel inside ``ax``.
    """
    # the first four trace panels lose all their spines
    for offset in range(4):
        ax[add + offset].show_spines('')
    rate_panel = ax[add + 4]
    rate_panel.set_ylabel('Firing Rate [Hz]')
    rate_panel.set_xlabel('Time [ms]')
    # the last panel keeps only its bottom spine
    bottom_panel = ax[add + 5]
    bottom_panel.set_xlabel('Time [ms]')
    bottom_panel.show_spines('b')
def implement_fig_borders(bottom=1.89, right=2.33):
    """Convert absolute figure margins (in cm) into fractions of the default figure size.

    Parameters
    ----------
    bottom : float
        Bottom margin in centimetres.
    right : float
        Right margin in centimetres. It used to be hard-coded; the default keeps
        the previous value of 2.33 cm.

    Returns
    -------
    add_bottom, add_right : float
        The margins divided by the current default figure height and width
        (``plt.rcParams['figure.figsize']``, both in inches).
    """
    fig_width, fig_height = plt.rcParams['figure.figsize']
    add_bottom = cm_to_inch(bottom) / fig_height
    add_right = cm_to_inch(right) / fig_width
    return add_bottom, add_right
def plt_ROC_model_w_female(redo=False, t_off=10, top=0.95, bottom=0.14, add_name='', color0='green', color01='blue',
                           color02='red',
                           color012='orange', figsize=(11.5, 5.4), female='wo_female', reshuffled='reshuffled',
                           datapoints=1000, dev=0.0005, a_f1s=[0.03], pdf=True, printing=False, plus_q='minus',
                           freq1_ratio=1 / 2, diagonal='diagonal', freq2_ratio=2 / 3, way='absolut',
                           stimulus_length=0.5, runs=3, trials_nr=500, cells=[], show=False, nfft=int(2 ** 15), beat='',
                           nfft_for_morph=4096 * 4, fr=None, gain=1, fish_jammer='Alepto', us_name=''):
    """Simulate model cells and plot example traces plus ROC curves for intruder detection.

    For every model cell (all cells in ``models_big_fit_d_right.csv`` unless
    ``cells`` is given) and every emitter amplitude in ``a_f1s`` this function:

    1. simulates ``trials_nr`` baseline trials (receiver EOD only) with ``simulate``,
    2. chooses two stimulus frequencies from the baseline firing rate ``fr``.
       It scales ``fr`` by ``freq1_ratio`` / ``freq2_ratio`` and negates the
       result when ``plus_q == 'minus'``.
    3. generates the emitter (fish 1) and jammer (fish 2) EODs and simulates the
       responses with ``get_arrays_for_three``,
    4. computes ROC curves with ``calc_auci_pd`` and draws one curve per run in
       the right-hand panel. Run 0 is drawn black and thick, later runs grey
       and thin.
    5. on run 0 also draws the example traces on the left (``plt_traces_to_roc``),
    6. saves the figure with ``save_visualization``.

    The ROC variant is selected by substring checks on ``female``:

    * 'wo_female': fp vs. tp_01, drawn by ``roc_female``.
    * 'base_female': fp vs. tp_02, drawn by ``roc_female``.
    * anything else: tp_02 vs. tp_012, drawn by ``roc_wo_female``.

    ``cont_redo`` is true when ``redo`` is False and either the CSV
    ``save_name_roc`` exists or ``find_code_vs_not`` reports the 'public'
    version. In that case ``stimulus_length`` is forced to 0.14, and the plotted
    ROC values are read from that CSV instead of the freshly computed ones.
    The simulation still runs, and the example traces come from it.

    Parameters (selection)
    ----------------------
    redo : bool
        Ignore a stored ROC CSV and plot the freshly computed curves.
    t_off, way, datapoints :
        Passed on to ``calc_auci_pd``.
    female : str
        Selects the ROC variant, see above.
    runs : int
        Number of repeated simulations; each one adds an ROC curve.
    trials_nr : int
        Number of simulated trials per condition.
    fr : float or None
        Baseline firing rate used for the stimulus frequencies. It is computed
        from the baseline simulation when falsy.
    diagonal : str
        If it contains 'diagonal', ``freq1_ratio`` is replaced by ``1 - freq2_ratio``.
    cells : list
        Model cell names; empty means all cells of the model file.

    Returns
    -------
    None. The figure is written via ``save_visualization``.
    """
    save_name_roc = 'decline_ROC_examples_trial_nr.csv'
    version_comp, subfolder, mod_name_slash, mod_name, subfolder_path = find_code_vs_not()
    # Use stored ROC values when available (or in the public version) unless a redo is requested.
    cont_redo = ((os.path.exists(save_name_roc)) | (version_comp == 'public')) & (redo == False)
    if cont_redo:
        stimulus_length = 0.14
    plt.rcParams['lines.linewidth'] = 1
    model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv")
    if len(cells) < 1:
        cells = model_cells.cell # )
    for cell_here in cells:
        # things I want to vary
        ###########################################
        single_waves = ['_SeveralWave_'] # , '_SingleWave_']
        ####### VARY HERE
        for single_wave in single_waves:
            # a single-wave run switches the second fish off (a_f2 = 0)
            if single_wave == '_SingleWave_':
                a_f2s = [0] # , 0,0.2
            else:
                a_f2s = [0.1]
            for a_f2 in a_f2s:
                for a_f1 in a_f1s:
                    a_frs = [1]
                    titles_amp = ['base eodf'] # ,'baseline to Zero',]
                    for a, a_fr in enumerate(a_frs):
                        # Parameters of this model cell. 'v_offset' and 'cell' are popped so that the
                        # remaining entries can be passed to simulate() as **model_params.
                        model_params = model_cells[model_cells['cell'] == cell_here].iloc[0]
                        eod_fr = model_params['EODf'] # .iloc[0]
                        offset = model_params.pop('v_offset')
                        cell = model_params.pop('cell')
                        print(cell)
                        SAM, adapt_offset, cell_recording, constant_reduction, damping, damping_type, dent_tau_change, exponential, f1, f2, fish_emitter, fish_receiver, fish_morph_harmonics_var, lower_tol, mimick, n, phase_right, phaseshift_fr, sampling_factor, upper_tol, zeros = default_model0()
                        # if you want a different sampling, it can be adjusted here
                        time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr,
                                                                     stimulus_length)
                        # generate the eod_fish_r in the four mimick variants (copy, thunderfish, mimick, just sinus)
                        eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr,
                                                                                       stimulus_length, phaseshift_fr,
                                                                                       cell_recording, zeros, mimick,
                                                                                       sampling, fish_receiver, deltat,
                                                                                       nfft, nfft_for_morph,
                                                                                       fish_morph_harmonics_var=fish_morph_harmonics_var,
                                                                                       beat=beat)
                        sampling = 1 / deltat
                        variant = 'sinz'
                        if exponential == '':
                            pass
                        # prepare for adapting offset due to baseline modification
                        # (both return values are discarded here)
                        _, _ = prepare_baseline_array(time_array, eod_fr,
                                                      nfft_for_morph,
                                                      phaseshift_fr,
                                                      mimick, zeros,
                                                      cell_recording,
                                                      sampling,
                                                      stimulus_length,
                                                      fish_receiver,
                                                      deltat, nfft,
                                                      damping_type,
                                                      damping, us_name,
                                                      gain, beat=beat,
                                                      fish_morph_harmonics_var=fish_morph_harmonics_var)
                        spikes_base = [[]] * trials_nr
                        fig = plt.figure(figsize=figsize)
                        # left (grid0): example traces in a 3 x 2 grid; right (grid1): the ROC panel
                        grid = gridspec.GridSpec(1, 2, wspace=0.3, left=0.09, top=top, bottom=bottom,
                                                 right=0.96, height_ratios=[1], width_ratios=[4, 2.8]) # 1.3,1
                        grid0 = gridspec.GridSpecFromSubplotSpec(3, 2, wspace=0.18, hspace=0.1,
                                                                 subplot_spec=grid[0],
                                                                 height_ratios=[1, 0.6, 1]) # ,0.4,1.2
                        grid1 = gridspec.GridSpecFromSubplotSpec(1, 1,
                                                                 subplot_spec=grid[1]) # wspace=0.5, hspace=0.55,
                        for run in range(runs):
                            print(run)
                            t1 = time.time()
                            # baseline trials: the stimulus is the receiver EOD alone
                            for t in range(trials_nr):
                                stimulus = eod_fish_r
                                stimulus_base = eod_fish_r
                                if 'Zero' in titles_amp[a]:
                                    power_here = 'sinz' + '_' + zeros
                                else:
                                    power_here = 'sinz'
                                # NOTE(review): `alpha` is not assigned in this function, so it resolves to a
                                # module-level name. The file header imports scipy.stats.alpha, and the wildcard
                                # import may shadow it. Confirm this is the intended power_alpha value.
                                cvs, adapt_output, baseline_after_b, _, rate_adapted_b, rate_baseline_before_b, rate_baseline_after_b, \
                                spikes_base[t], _, _, offset_new, _, noise_final = simulate(cell, offset, stimulus,
                                                                                            deltat=deltat,
                                                                                            adaptation_variant=adapt_offset,
                                                                                            adaptation_yes_j=f2,
                                                                                            adaptation_yes_e=f1,
                                                                                            adaptation_yes_t=t,
                                                                                            power_variant=power_here,
                                                                                            power_alpha=alpha,
                                                                                            power_nr=n,
                                                                                            reshuffle=reshuffled,
                                                                                            **model_params)
                                # the offset adapted in the first trial is passed to all following simulations
                                if t == 0:
                                    offset = offset_new * 1
                            if printing:
                                print('Baseline time' + str(time.time() - t1))
                            base_cut, mat_base = find_base_fr(spikes_base, deltat, stimulus_length, time_array, dev=dev)
                            if not fr:
                                fr = np.mean(base_cut)
                            # stimulus frequency offsets as fractions of the baseline firing rate
                            if 'diagonal' in diagonal:
                                two_third_fr = fr * freq2_ratio
                                freq1_ratio = (1 - freq2_ratio)
                                third_fr = fr * freq1_ratio
                            else:
                                two_third_fr = fr * freq2_ratio
                                third_fr = fr * freq1_ratio
                            if plus_q == 'minus':
                                two_third_fr = -two_third_fr
                                third_fr = -third_fr
                            freqs2 = [eod_fr + two_third_fr] # , eod_fr - third_fr, two_third_fr,
                            freqs1 = [
                                eod_fr + third_fr] # , eod_fr - two_third_fr, third_fr,two_third_fr,third_eodf, eod_fr - third_eodf,two_third_eodf, eod_fr - two_third_eodf, ]
                            sampling_rate = 1 / deltat
                            base_cut, mat_base, smoothed0, mat0 = find_base_fr2(spikes_base, deltat,
                                                                                 stimulus_length, dev=dev)
                            # NOTE(review): this overwrites `fr`, even a caller-supplied value. Later runs and
                            # later cells therefore derive their stimulus frequencies from this rate.
                            fr = np.mean(base_cut)
                            _, _ = ISI_frequency(time_array, spikes_base[0], fill=0.0)
                            # CV of the first baseline trial; used only in the output file name below
                            isi = np.diff(spikes_base[0])
                            cv0 = np.std(isi) / np.mean(isi)
                            for ff, freq1 in enumerate(freqs1):
                                freq1 = [freq1]
                                freq2 = [freqs2[ff]]
                                # NOTE(review): freq1/freq2 are lists here. `freq1 - eod_fr` only works when
                                # eod_fr is a numpy scalar; a plain float would raise TypeError.
                                print(cell + ' f1' + str(freq1) + ' f2 ' + str(freq2) + ' f1' + str(
                                    freq1 - eod_fr) + ' f2 ' + str(freq2 - eod_fr))
                                t1 = time.time()
                                phaseshift_f1, phaseshift_f2 = get_phaseshifts(a_f1, a_f2, phase_right, phaseshift_fr)
                                # EODs of the emitter (fish 1, amplitude a_f1) and the jammer (fish 2, amplitude a_f2)
                                eod_fish1, time_fish_e = eod_fish_e_generation(time_array, a_f1, freq1, f1,
                                                                               phaseshift_f1, sampling, stimulus_length,
                                                                               nfft_for_morph, cell_recording,
                                                                               fish_morph_harmonics_var, zeros, mimick,
                                                                               fish_emitter, thistype='emitter')
                                eod_fish2, time_fish_j = eod_fish_e_generation(time_array, a_f2, freq2, f2,
                                                                               phaseshift_f2, sampling, stimulus_length,
                                                                               nfft_for_morph, cell_recording,
                                                                               fish_morph_harmonics_var, zeros, mimick,
                                                                               fish_jammer, thistype='jammer')
                                eod_stimulus = eod_fish1 + eod_fish2
                                # outputs suffixed 01 / 02 / 012 presumably belong to fish 1 alone, fish 2 alone
                                # and both fish together
                                v_mems, offset_new, mat01, mat02, mat012, smoothed01, smoothed02, smoothed012, stimulus_01, stimulus_02, stimulus_012, mat05_01, spikes_01, mat05_02, spikes_02, mat05_012, spikes_012 = get_arrays_for_three(
                                    cell, a_f2, a_f1, SAM, eod_stimulus, eod_fish_r, freq2, eod_fish1, eod_fish2,
                                    stimulus_length, offset, model_params, n, variant, adapt_offset, deltat, f2,
                                    trials_nr, time_array, f1, freq1, eod_fr, reshuffle=reshuffled, dev=dev,
                                    redo_stim=False)
                                if printing:
                                    print('Generation process' + str(time.time() - t1))
                                array0 = [mat_base]
                                array01 = [mat05_01]
                                array02 = [mat05_02]
                                array012 = [mat05_012]
                                position_diff = 0
                                results_diff = pd.DataFrame()
                                results_diff['f1'] = freq1
                                results_diff['f2'] = freq2
                                results_diff['f0'] = eod_fr
                                trials, results_diff, tp_012_all, tp_01_all, tp_02_all, fp_all, roc_01, roc_0, roc_02, roc_012, threshhold = calc_auci_pd(
                                    results_diff, position_diff, array012, array01, array02, array0, t_off=t_off,
                                    way=way, printing=True, datapoints=datapoints, f0='f0', sampling=sampling)
                                # run 0 is highlighted; later runs are thin grey background curves
                                if run == 0:
                                    color = 'black'
                                    lw = 1.5
                                    z = 2
                                else:
                                    color = 'grey'
                                    lw = 0.8
                                    z = 1
                                if cont_redo:
                                    # replace the freshly computed ROC values with the stored ones
                                    frame = pd.read_csv(save_name_roc)
                                    tp_012_all = frame['tp_012'] # = tp_012_all
                                    tp_01_all = frame['tp_01'] # = tp_01_all
                                    tp_02_all = frame['tp_02'] # = tp_02_all
                                    fp_all = frame['fp_all'] # = fp_all
                                # Every run requests the same grid cell. pyplot reuses a subplot created with the
                                # same arguments, so all runs overlay in one ROC panel.
                                if 'wo_female' in female:
                                    ax_roc_wof = plt.subplot(grid1[0])
                                    roc_female(ax_roc_wof, color, fp_all, tp_01_all, lw, color0, color01,
                                               title_color=color01, z=z)
                                elif 'base_female' in female:
                                    ax_roc_wof = plt.subplot(grid1[0])
                                    roc_female(ax_roc_wof, color, fp_all, tp_02_all, lw, color0, color02, z=z,
                                               add_01='\n Female', add_base=' Baseline')
                                    ax_roc_wof.set_title('Receiver Operating Characteristics (ROC)')
                                else:
                                    ax_roc_wf = plt.subplot(grid1[0])
                                    roc_wo_female(color, ax_roc_wf, tp_02_all, tp_012_all, color02, color012,
                                                  title_color=color012, z=z)
                                if run == 0:
                                    plt_traces_to_roc(freq2_ratio, freq1_ratio, t_off, spikes_02, spikes_01, spikes_012,
                                                      spikes_base, mat_base, mat05_01,
                                                      mat05_012, mat05_02, color02, color012, a_f2, trials, sampling,
                                                      a_f1, fr, female, color01, color0, grid0, eod_fr,
                                                      freq2,
                                                      freq1, sampling_rate, stimulus_012, stimulus_02, stimulus_01,
                                                      stimulus_base, time_array, carrier=True)
                                    # fig.axes[0] is the ROC panel (created above); the trace panels follow it
                                    ax = fig.axes
                                    ax[0 + 1].set_ylabel('Amplitude')
                                    ax[2 + 1].set_ylabel('Trials')
                                    ax[4 + 1].set_ylabel('Firing Rate [Hz]')
                                    ax[3 + 2].set_xlabel('Time [ms]')
                                    ax[4 + 2].set_xlabel('Time [ms]')
                                    for aa, ax_here in enumerate(ax[2:5]):
                                        ax_here.set_xticks([])
                                    # Share the y-axis within each consecutive pair ax[1:3], ax[3:5], ...
                                    # NOTE(review): mutating get_shared_y_axes() via .join is deprecated in newer
                                    # matplotlib (Axes.sharey is the replacement); confirm against installed version.
                                    for aa, ax_here in enumerate(ax[1::]):
                                        if aa not in np.arange(0, 100, 2):
                                            pass
                                        else:
                                            ax_here.get_shared_y_axes().join(*ax[1 + aa:1 + aa + 2])
                                    fig.tag([ax[1], ax[2], ax[0]], xoffs=-4.6, yoffs=1.5)
                        plt.subplots_adjust(top=0.95, left=0.09, right=0.95, hspace=0.5, bottom=0.12, wspace=0.25)
                        individual_tag = '_way_' + str(way) + '_runs_' + str(runs) + '_trial_nr_' + str(
                            trials_nr) + '_stimulus_length_' + str(
                            stimulus_length) + cell + ' cv ' + str(cv0) + single_wave + '_a_f0_' + str(
                            a_fr) + '_a_f1_' + str(a_f1) + '_a_f2_' + str(a_f2) + '_trialsnr_' + str(trials_nr)
                        save_visualization(individual_tag, show=show, add=add_name, pdf=pdf, counter_contrast=0,
                                           savename='')
def roc_wo_female(color, ax_roc_wf, tp_02_all, tp_012_all, color02, color012, add_01='\n Intruder + Female', z=2,
                  add_base=' Female', color_e='grey', title_color='black'):
    """Draw the "with female" ROC curve on ``ax_roc_wf``.

    The female-only rates (``tp_02_all``) go on the x-axis and the
    female-plus-intruder rates (``tp_012_all``) go on the y-axis.

    Parameters
    ----------
    color : colour of the ROC curve.
    ax_roc_wf : axes to draw on.
    tp_02_all, tp_012_all : x and y values of the curve.
    color02, color012 : colours of the x and y axis labels.
    add_01, add_base : text appended to the y and x axis labels.
    z : zorder of the ROC curve.
    color_e : colour of the dashed chance diagonal; falsy to omit it.
    title_color : colour of the title.
    """
    title_text = r'With Female: ROC\ensuremath{\rm{_{Female}}}'
    ax_roc_wf.set_title(title_text, color=title_color)
    ax_roc_wf.plot(tp_02_all, tp_012_all, color=color, zorder=z, clip_on=False)
    ax_roc_wf.set_aspect('equal')
    if color_e:
        # chance level: the identity line from (0, 0) to (1, 1)
        ax_roc_wf.plot([0, 1], [0, 1], color=color_e, linestyle='--')
    x_text = 'False-Positive Rate: ' + add_base
    y_text = 'Correct-Detection Rate: ' + add_01
    ax_roc_wf.set_xlabel(x_text, color=color02)
    ax_roc_wf.set_ylabel(y_text, color=color012)
def roc_female(ax_roc_wof, color, fp_all, tp_01_all, lw, color0, color01, color_e='grey', z=2,
               title_color='black',
               add_01='\n Intruder', add_base=' Baseline'):
    """Draw the "without female" ROC curve (false positives vs. detections) on ``ax_roc_wof``.

    Parameters
    ----------
    ax_roc_wof : axes to draw on.
    color, lw, z : colour, line width and zorder of the ROC curve.
    fp_all, tp_01_all : x (false-positive) and y (correct-detection) values.
    color0, color01 : colours of the x and y axis labels.
    color_e : colour of the dashed chance diagonal; falsy to omit it.
    title_color : colour of the title.
    add_01, add_base : text appended to the y and x axis labels.
    """
    ax_roc_wof.set_title(r'Without Female: ROC\ensuremath{\rm{_{NoFemale}}}', color=title_color)
    curve_kwargs = dict(color=color, linewidth=lw, zorder=z, clip_on=False)
    ax_roc_wof.plot(fp_all, tp_01_all, **curve_kwargs)
    if color_e:
        # chance level: the identity line from (0, 0) to (1, 1)
        ax_roc_wof.plot([0, 1], [0, 1], color=color_e, linestyle='--', clip_on=False)
    ax_roc_wof.set_aspect('equal')
    ax_roc_wof.set_xlabel('False-Positive Rate: ' + add_base, color=color0)
    ax_roc_wof.set_ylabel('Correct-Detection Rate: ' + add_01, color=color01)
def c_to_dist_reverse(distance, power=2.09, factor=12.23):
    """Map a distance back to a contrast with the power law ``factor / distance ** power``.

    The default ``power`` and ``factor`` look like fit parameters of a
    contrast-vs-distance relation (TODO confirm where they come from).
    Works element-wise for numpy arrays.
    """
    scaled_distance = distance ** power
    return factor / scaled_distance
def Pl_model(freqs=[(39.5, -210.5)], printing=False, beat='',
             nfft_for_morph=4096 * 4, gain=1, freq_mult=False, cells_here=[],
             fish_jammer='Alepto', us_name='', show=True,
             c_nrs_orig=[0.03, 0.2, 0.8]): # "2013-01-08-aa-invivo-1"
    """Plot phase-locking examples of model cells at several emitter contrasts (talk layout).

    The function covers every model cell in ``models_big_fit_d_right.csv``
    (or the cells in ``cells_here``) that appears in the precomputed
    ``calc_cocktailparty`` CSV. For each such cell and each ``(df1, df2)`` pair
    in ``freqs``, the figure shows two parts. If the exact pair is missing, df1
    and df2 are snapped independently to the nearest available values.

    * Right panel: the amplitudes of the scores ``amp_B1_01``, ``amp_f1_01`` and
      ``amp_f0_01`` versus contrast (``plt_single_trace``). The contrasts in
      ``c_nrs_orig`` are marked with the letters A, B, C.
    * Left, one column per contrast in ``c_nrs_orig``: stimulus, spike raster,
      firing rate and power spectrum of a simulation
      (``calc_roc_amp_core_cocktail``).

    Parameters
    ----------
    freqs : list of (df1, df2) tuples
        Frequency differences to show. When ``freq_mult`` is true they are
        recomputed by ``freq_two_mult_recalc``.
    c_nrs_orig : list of float
        Emitter contrasts of the example columns. At most three are allowed,
        because only three panel letters are defined.
    cells_here : list
        Model cell names; empty means all cells of the model file.
    show : bool
        Passed to ``save_visualization``.
    printing, beat, nfft_for_morph, gain, fish_jammer, us_name :
        Passed on to ``calc_roc_amp_core_cocktail``.

    Returns
    -------
    None. One figure per cell is written via ``save_visualization``.
    """
    runs = 1
    n = 1
    dev = 0.0005
    reshuffled = 'reshuffled' # ,
    # standard combination with intruder small
    a_f2s = [0.1]
    min_amps = '_minamps_'
    dev_name = ['05']
    model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv")
    if len(cells_here) < 1:
        cells_here = np.array(model_cells.cell)
    a_fr = 1
    # NOTE(review): `a` indexes `arrays` below; after the first contrast it holds the last index of the
    # inner `for a` loop instead of 0.
    a = 0
    trials_nrs = [5]
    datapoints = 1000
    stimulus_length = 2
    results_diff = pd.DataFrame()
    position_diff = 0
    plot_style()
    # the second default_figsize call presumably overrides the first (talk-sized figure)
    default_figsize(column=2, length=6) # 6.4
    default_figsize(width=cm_to_inch(33.6), length=cm_to_inch(15.2))
    default_ticks_talks()
    for trials_nr in trials_nrs: # +[trials_nrs[-1]]
        # things I want to vary
        ###########################################
        auci_wo = []
        auci_w = []
        nfft = 32768
        full_names = [
            'calc_model_amp_freqs-F1_750-975-75_F2_500-725-75_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_1.0_mult__start_0.0001_end_1_StimLen_25_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal']
        frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv')
        for cell_here in cells_here:
            c_grouped = ['c1'] # , 'c2']
            frame_cell_orig = frame[(frame.cell == cell_here)]
            # NOTE(review): `freqs` is rebound here, so with freq_mult the next cell starts from the
            # already recalculated values; confirm this is intended.
            if freq_mult:
                freqs = freq_two_mult_recalc(frame_cell_orig, freqs)
            if len(frame_cell_orig) > 0:
                print('cell there')
                # leftover debugging scaffold: the try body is empty, so the except branch never runs
                try:
                    pass
                except:
                    print('min thing')
                    embed()
                get_frame_cell_params(c_grouped, cell_here, frame, frame_cell_orig)
                # layout: grid00[0] holds one example column per contrast, grid00[1] the contrast curves
                grid0 = gridspec.GridSpec(1, 1, bottom=0.2, top=0.8, left=0.115,
                                          right=0.95,
                                          wspace=0.04) #
                grid00 = gridspec.GridSpecFromSubplotSpec(1, 2,
                                                          wspace=0.4, hspace=0.1, width_ratios = [2, 1],
                                                          subplot_spec=grid0[0]) # height_ratios=[2,1],
                grid_ll = gridspec.GridSpecFromSubplotSpec(1, len(c_nrs_orig),
                                                           hspace=0.35,
                                                           wspace=0.2,
                                                           subplot_spec=grid00[0]) # height_ratios=[1, 0.8],hspace=0.4,wspace=0.2,len(chirps)
                #grid_rr = gridspec.GridSpecFromSubplotSpec(2, 1,
                #                                           wspace=0.04, hspace=0.1,
                #                                           subplot_spec=grid0[1]) # height_ratios=[2,1],
                #################################################################
                # calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_20_FirstC1_0.0001_LastC1_1.0_StimLen_25_nfft_32768_trialsnr_1_absolut_power_1temporal.csv
                # devs_extra = ['stim','stim_rec','stim_am','original','05']#['original','05']
                # here I implement this for a single cell,
                # where we vary the single point and the contrasts
                f_counter = 0
                frame_cell_orig, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig)
                eodf = frame_cell_orig.f0.unique()[0]
                f = -1
                axts_all = []
                axps_all = []
                ax_us = []
                for freq1, freq2 in freqs:
                    f += 1
                    frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)]
                    # fall back to the nearest available df1 and df2 (chosen independently)
                    if len(frame_cell) < 1:
                        freq1 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df1 - freq1)))].df1
                        freq2 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df2 - freq2)))].df2
                        frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)]
                    color_stim = color_stim_core()
                    color_eodf = coloer_eod_fr_core()
                    print(cell_here + ' F1' + str(freq1) + ' F2 ' + str(freq2))
                    sampling = 20000
                    try:
                        ax_u1 = plt.subplot(grid00[1])
                    except:
                        print('grid search problem5')
                        embed()
                    add = get_mean_add(frame_cell_orig)
                    #_original
                    labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept(
                        add=add)
                    # scores, labels and alpha from colors_susept are overridden for this plot
                    scores = ['amp_B1_01' + add, 'amp_f1_01' + add, 'amp_f0_01' + add, ] # 'amp_B1+B2_012_mean',
                    labels = labels_pi_core()
                    alpha = [1, 1, 1]
                    c_dist_recalc = dist_recalc_phaselockingchapter()
                    ax_us = plt_single_trace(ax_us, ax_u1, frame_cell_orig, freq1, freq2,
                                             scores=scores, colors=[color01, coloer_eod_fr_core(), color_stim_core()],
                                             linestyles=['-', '-', '-'], alpha=alpha, labels=labels,
                                             sum=False, B_replace='F', default_colors=False,
                                             c_dist_recalc=c_dist_recalc, delta=False)
                    ax_u1.set_xlabel('Contrast$_{' + vary_val() + '}$ [$\%$]')
                    if f != 0:
                        print('hi')
                    else:
                        ax_u1.set_ylabel(representation_ylabel(delta=False)) # power_spectrum_name()
                    axts = []
                    axps = []
                    axes = []
                    recalc = 100
                    # contrasts of c_nrs_orig on the x-axis scale of ax_u1 (recalc_contrast_in_perc=100,
                    # presumably percent)
                    c_nrs = c_dist_recalc_func(frame_cell, c_nrs=c_nrs_orig, cell=cell_here,
                                               c_dist_recalc=c_dist_recalc, recalc_contrast_in_perc=recalc)
                    # time window: mults_period periods (in ms) of the smaller of |df1| and |df2|, starting at 1 s
                    mults_period = 3
                    xlim = [1000, 1000 + (mults_period * 1000 / np.min([np.abs(freq1), np.abs(freq2)]))]
                    # only three letters: more than three contrasts would raise IndexError below
                    letters = ['A', 'B', 'C']
                    height = 240
                    for c_nn, c_nr in enumerate(c_nrs):
                        # NOTE(review): this scatter draws all markers and is repeated for every contrast
                        ax_u1.scatter(c_nrs, height * np.ones(len(c_nrs)), color='black', marker='v', clip_on=False,
                                      s=7)
                        ax_u1.text(c_nr, height + 15, letters[c_nn], ha='center', va='center', color='black')
                        ax_u1.plot([c_nr, c_nr], [0, height], color='black', linewidth=0.05, clip_on=False)
                        ax_u1.set_ylim(0, 285)
                        # first simulation (dev 'original'): only its power spectra p_arrays are used below
                        v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays_original, names, p_arrays, ff = calc_roc_amp_core_cocktail(
                            [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff, a_f2s,
                            fish_jammer, trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing,
                            stimulus_length,
                            model_cells, position_diff, 'original', cell_here, dev_name=dev_name,
                            a_f1s=[c_nrs_orig[c_nn]], n=n,
                            reshuffled=reshuffled, min_amps=min_amps)
                        # second simulation (dev=0.0005): arrays, spikes, stimuli and frequency axis ff used below
                        v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays, names, _, ff = calc_roc_amp_core_cocktail(
                            [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff, a_f2s,
                            fish_jammer, trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing,
                            stimulus_length,
                            model_cells, position_diff, dev, cell_here, dev_name=dev_name, a_f1s=[c_nrs_orig[c_nn]],
                            n=n,
                            reshuffled=reshuffled, min_amps=min_amps)
                        # time axis in ms (note: this makes `time` a local name throughout this function)
                        time = np.arange(0, len(arrays[a][0]) / sampling, 1 / sampling)
                        time = time * 1000
                        #######################
                        # plot the first array
                        arrays_here, arrays_sp, arrays_st, arrays_time = choose_arrays_phaselocking(arrays,
                                                                                                    arrays_spikes,
                                                                                                    arrays_stim,
                                                                                                    choice='01')
                        colors_array_here = ['grey', 'grey', 'grey'] # colors_array[1::]
                        p_arrays_here = p_arrays[1::]
                        for a in range(len(arrays_here)):
                            print('a' + str(a))
                            eodf = frame_cell.f0.iloc[0]
                            f1 = frame_cell.f1.iloc[0]
                            colors_peaks = [color01, color_stim, color_eodf] # , 'red']
                            freqs_psd = [np.abs(freq1), f1, eodf]
                            # per example: stimulus (0), spikes (1), rate (3) and power spectrum (5)
                            grid_pt = gridspec.GridSpecFromSubplotSpec(6, 1,
                                                                       hspace=0.3,
                                                                       wspace=0.2,
                                                                       subplot_spec=grid_ll[a, c_nn],
                                                                       height_ratios=[1, 0.7, 0, 1, 0.25,
                                                                                      2.2
                                                                                      ]) # .2 hspace=0.4,wspace=0.2,len(chirps)
                            axe = plt.subplot(grid_pt[0])
                            axes.append(axe)
                            plt_stim_saturation(a, [], arrays_st, axe, colors_array_here, f,
                                                f_counter, names, time, xlim=xlim) # np.array(arrays_sp)*1000
                            a_f2_cm = c_dist_recalc_func(frame_cell, c_nrs=[a_f2s[0]], cell=cell_here)
                            if a == 2: # if (a_f1s[0] != 0) & (a_f2s[0] != 0):
                                # NOTE(review): `%` binds tighter than `+`, so `'=' % (...)` runs first and raises
                                # TypeError ('not all arguments converted'). grid_ll has only one row, so
                                # grid_ll[a, c_nn] above presumably fails before this branch is ever reached.
                                title_name = ' $c_{' + vary_val() + '}=%s\% c' + stable_val() + '=' % (
                                    ((int(np.round(a_f2_cm[0]))), int(np.round(c_nrs[c_nn])))) # + '$\%$'str(
                            elif a == 0: # elif (a_f1s[0] != 0):_{p}_{s}
                                title_name = ' $c_{' + vary_val() + '}=%s$' % int(np.round(
                                    c_nrs[c_nn])) + '\,$\%$, ' + '\n $\Delta f_{' + vary_val() + '}= %s$\,Hz' % (
                                                 int(freq1)) # str() #+ '$\%$'
                            elif a == 1: # elif (a_f2s[0] != 0):
                                title_name = ' $c_{' + vary_val() + '}=%s$' % int(
                                    np.round(a_f2_cm[0])) + '\,$\%$, ' + ' $\Delta f_{' + vary_val() + '}= %s$\,Hz' % (
                                                 int(freq1)) # str()
                            axe.text(1, 1, title_name, va='bottom', ha='right',
                                     transform=axe.transAxes)
                            axs = plt.subplot(grid_pt[1])
                            plt_spikes_ROC(axs, 'grey', np.array(arrays_sp[a]) * 1000, xlim)
                            #############################
                            axt = plt.subplot(grid_pt[3])
                            axts.append(axt)
                            plt_vmem_saturation(a, arrays_sp, arrays_time, axt, colors_array_here, f,
                                                time, xlim=xlim)
                            #############################
                            axp = plt.subplot(grid_pt[5])
                            axps.append(axp)
                            log = '' # 'log' # 'log'
                            # spectra are shown up to 15 % above the EOD frequency
                            maxx = eodf * 1.15 # 5
                            pp = log_calc_psd(log, p_arrays_here[a][0],
                                              np.nanmax(p_arrays_here))
                            freqs_peaks1, colors_peaks1, labels1, alphas1 = chose_all_freq_combos(freq2, colors_array,
                                                                                                  freq1,
                                                                                                  maxx, eodf,
                                                                                                  color_eodf=coloer_eod_fr_core(),
                                                                                                  name='01',
                                                                                                  stim_thing=False,
                                                                                                  color_stim=color_stim_core(),
                                                                                                  color_stim_mult=color_stim_core())
                            plt_peaks_several(freqs_peaks1, [pp], axp, pp, ff, labels1, 0, colors_peaks1, limit=10000,
                                              alphas=alphas1, ms=25, clip_on=False)
                            plt_psd_saturation(pp, ff, a, axp, colors_array_here, freqs=freqs_psd,
                                               colors_peaks=colors_peaks, xlim=(0, maxx))
                            if log:
                                axp.show_spines('b')
                                if a == 0:
                                    axp.yscalebar(-0.05, 0.5, 20, 'dB', va='center', ha='left')
                            else:
                                axp.show_spines('lb')
                                if c_nn != 0:
                                    remove_yticks(axp)
                                else:
                                    axp.set_ylabel(power_spectrum_name())
                            if a == 0:
                                axt.show_spines('')
                                # scale bars only in the first contrast column
                                if c_nn == 0:
                                    axt.xscalebar(0.3, -0.1, 5, 'ms', va='right', ha='bottom')
                                    axt.yscalebar(-0.02, 0.35, 600, 'Hz', va='left', ha='top')
                                axp.set_xlabel('Frequency [Hz]')
                            #############################
                            # ISI histogram panel, switched off (isis is always False here)
                            isis = False
                            if isis:
                                axi = plt.subplot(grid_pt[-1])
                                isis = []
                                for t in range(len(arrays_sp[a])):
                                    isi = calc_isi(arrays_sp[a][t], eodf)
                                    isis.append(isi)
                                axi.hist(np.concatenate(isis), bins=100, color='grey')
                                axi.set_xlabel(isi_xlabel())
                                axi.show_spines('b')
                            f_counter += 1
                    axts_all.extend(axts)
                    axps_all.extend(axps)
                ax_us[0].legend(ncol=1, loc=(-0.25, 1.01), columnspacing=2.5) # 5 -0.07#loc=(0.9, 0.7)
                # NOTE(review): mutating get_shared_*_axes() via .join is deprecated in newer matplotlib
                axts_all[0].get_shared_y_axes().join(*axts_all)
                axts_all[0].get_shared_x_axes().join(*axts_all)
                axps_all[0].get_shared_y_axes().join(*axps_all)
                axps_all[0].get_shared_x_axes().join(*axps_all)
                join_y(axts)
                set_same_ylim(axts)
                set_same_ylim(axps)
                join_x(axts)
                join_x(ax_us)
                join_y(ax_us)
                fig = plt.gcf()
                #fig.tag([axes[0], axes[1], axes[2]], xoffs=-3, yoffs=3)
                #fig.tag([ax_u1], xoffs=-3, yoffs=3)
                save_visualization(cell_here, show)
                print('finished cell here')
def get_mean_add(frame_cell_orig):
    """Return the column-name suffix for averaged amplitude columns.

    Returns ``'_mean_original'`` when ``frame_cell_orig`` (anything with a
    ``keys()`` method, e.g. a DataFrame or dict) has the column
    ``'amp_B1_01_mean_original'``; otherwise returns ``'_mean'``.
    """
    has_original_means = 'amp_B1_01_mean_original' in frame_cell_orig.keys()
    return '_mean_original' if has_original_means else '_mean'
def vary_contrasts50(freqs=[(39.5, -210.5)], printing=False, beat='',
                     nfft_for_morph=4096 * 4, gain=1, freq_mult=False, cells_here=[],
                     fish_jammer='Alepto', us_name='', show=True,
                     c_nrs_orig=[0.03, 0.2, 0.8]): # "2013-01-08-aa-invivo-1"
    """Plot phase-locking examples of model cells at several emitter contrasts (paper layout).

    The function covers every model cell in ``models_big_fit_d_right.csv``
    (or the cells in ``cells_here``) that appears in the precomputed
    ``calc_cocktailparty`` CSV. For each such cell and each ``(df1, df2)`` pair
    in ``freqs``, the figure shows two parts. If the exact pair is missing, df1
    and df2 are snapped independently to the nearest available values.

    * Top row, one column per contrast in ``c_nrs_orig``: stimulus, spike
      raster, firing rate and power spectrum of a simulation
      (``calc_roc_amp_core_cocktail``).
    * Bottom row, spanning all columns: the amplitudes of the scores
      ``amp_B1_01``, ``amp_f1_01`` and ``amp_f0_01`` versus contrast
      (``plt_single_trace``). The example contrasts are marked A, B, C.

    Parameters
    ----------
    freqs : list of (df1, df2) tuples
        Frequency differences to show. When ``freq_mult`` is true they are
        recomputed by ``freq_two_mult_recalc``.
    c_nrs_orig : list of float
        Emitter contrasts of the example columns. At most three are allowed,
        because only three panel letters are defined.
    cells_here : list
        Model cell names; empty means all cells of the model file.
    show : bool
        Passed to ``save_visualization``.
    printing, beat, nfft_for_morph, gain, fish_jammer, us_name :
        Passed on to ``calc_roc_amp_core_cocktail``.

    Returns
    -------
    None. One figure per cell is written via ``save_visualization``.
    """
    runs = 1
    n = 1
    dev = 0.0005
    reshuffled = 'reshuffled' # ,
    # standard combination with intruder small
    a_f2s = [0.1]
    min_amps = '_minamps_'
    dev_name = ['05']
    model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv")
    if len(cells_here) < 1:
        cells_here = np.array(model_cells.cell)
    a_fr = 1
    # NOTE(review): `a` indexes `arrays` below; after the first contrast it holds the last index of the
    # inner `for a` loop instead of 0.
    a = 0
    trials_nrs = [5]
    datapoints = 1000
    stimulus_length = 2
    results_diff = pd.DataFrame()
    position_diff = 0
    plot_style()
    default_figsize(column=2, length=6) # 6.4
    for trials_nr in trials_nrs: # +[trials_nrs[-1]]
        # things I want to vary
        ###########################################
        auci_wo = []
        auci_w = []
        nfft = 32768
        full_names = [
            'calc_model_amp_freqs-F1_750-975-75_F2_500-725-75_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_1.0_mult__start_0.0001_end_1_StimLen_25_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal']
        frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv')
        for cell_here in cells_here:
            c_grouped = ['c1'] # , 'c2']
            frame_cell_orig = frame[(frame.cell == cell_here)]
            # NOTE(review): `freqs` is rebound here, so with freq_mult the next cell starts from the
            # already recalculated values; confirm this is intended.
            if freq_mult:
                freqs = freq_two_mult_recalc(frame_cell_orig, freqs)
            if len(frame_cell_orig) > 0:
                print('cell there')
                # leftover debugging scaffold: the try body is empty, so the except branch never runs
                try:
                    pass
                except:
                    print('min thing')
                    embed()
                get_frame_cell_params(c_grouped, cell_here, frame, frame_cell_orig)
                # layout: grid_ll row 0 holds one example column per contrast, row 1 the contrast curves
                grid0 = gridspec.GridSpec(1, 1, bottom=0.08, top=0.96, left=0.115,
                                          right=0.95,
                                          wspace=0.04) #
                grid00 = gridspec.GridSpecFromSubplotSpec(1, 1,
                                                          wspace=0.04, hspace=0.1,
                                                          subplot_spec=grid0[0]) # height_ratios=[2,1],
                grid_ll = gridspec.GridSpecFromSubplotSpec(2, len(c_nrs_orig),
                                                           hspace=0.35,
                                                           wspace=0.2, height_ratios=[1, 0.8],
                                                           subplot_spec=grid00[0]) # hspace=0.4,wspace=0.2,len(chirps)
                #################################################################
                # calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_20_FirstC1_0.0001_LastC1_1.0_StimLen_25_nfft_32768_trialsnr_1_absolut_power_1temporal.csv
                # devs_extra = ['stim','stim_rec','stim_am','original','05']#['original','05']
                # here I implement this for a single cell,
                # where we vary the single point and the contrasts
                f_counter = 0
                frame_cell_orig, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig)
                eodf = frame_cell_orig.f0.unique()[0]
                f = -1
                axts_all = []
                axps_all = []
                ax_us = []
                for freq1, freq2 in freqs:
                    f += 1
                    frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)]
                    # fall back to the nearest available df1 and df2 (chosen independently)
                    if len(frame_cell) < 1:
                        freq1 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df1 - freq1)))].df1
                        freq2 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df2 - freq2)))].df2
                        frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)]
                    color_stim = color_stim_core()
                    color_eodf = coloer_eod_fr_core()
                    print(cell_here + ' F1' + str(freq1) + ' F2 ' + str(freq2))
                    sampling = 20000
                    try:
                        ax_u1 = plt.subplot(grid_ll[-1, :])
                    except:
                        print('grid search problem2')
                        embed()
                    # same suffix choice as get_mean_add(frame_cell_orig), inlined here
                    if 'amp_B1_01_mean_original' in frame_cell_orig.keys():
                        add = '_mean_original'
                    else:
                        add = '_mean'
                    labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept(
                        add=add)
                    # scores, labels and alpha from colors_susept are overridden for this plot
                    scores = ['amp_B1_01' + add, 'amp_f1_01' + add, 'amp_f0_01' + add, ] # 'amp_B1+B2_012_mean',
                    labels = labels_pi_core()
                    alpha = [1, 1, 1]
                    c_dist_recalc = dist_recalc_phaselockingchapter()
                    ax_us = plt_single_trace(ax_us, ax_u1, frame_cell_orig, freq1, freq2,
                                             scores=scores, colors=[color01, coloer_eod_fr_core(), color_stim_core()],
                                             linestyles=['-', '-', '-'], alpha=alpha, labels=labels,
                                             sum=False, B_replace='F', default_colors=False,
                                             c_dist_recalc=c_dist_recalc, delta=False)
                    ax_u1.set_xlabel('Contrast$_{' + vary_val() + '}$ [$\%$]')
                    if f != 0:
                        print('hi')
                    else:
                        ax_u1.set_ylabel(representation_ylabel(delta=False)) # power_spectrum_name()
                    axts = []
                    axps = []
                    axes = []
                    recalc = 100
                    # contrasts of c_nrs_orig on the x-axis scale of ax_u1 (recalc_contrast_in_perc=100,
                    # presumably percent)
                    c_nrs = c_dist_recalc_func(frame_cell, c_nrs=c_nrs_orig, cell=cell_here,
                                               c_dist_recalc=c_dist_recalc, recalc_contrast_in_perc=recalc)
                    # time window: mults_period periods (in ms) of the smaller of |df1| and |df2|, starting at 1 s
                    mults_period = 3
                    xlim = [1000, 1000 + (mults_period * 1000 / np.min([np.abs(freq1), np.abs(freq2)]))]
                    # only three letters: more than three contrasts would raise IndexError below
                    letters = ['A', 'B', 'C']
                    height = 240
                    for c_nn, c_nr in enumerate(c_nrs):
                        # NOTE(review): this scatter draws all markers and is repeated for every contrast
                        ax_u1.scatter(c_nrs, height * np.ones(len(c_nrs)), color='black', marker='v', clip_on=False,
                                      s=7)
                        ax_u1.text(c_nr, height + 15, letters[c_nn], ha='center', va='center', color='black')
                        ax_u1.plot([c_nr, c_nr], [0, height], color='black', linewidth=0.05, clip_on=False)
                        ax_u1.set_ylim(0, 285)
                        # first simulation (dev 'original'): only its power spectra p_arrays are used below
                        v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays_original, names, p_arrays, ff = calc_roc_amp_core_cocktail(
                            [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff, a_f2s,
                            fish_jammer, trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing,
                            stimulus_length,
                            model_cells, position_diff, 'original', cell_here, dev_name=dev_name,
                            a_f1s=[c_nrs_orig[c_nn]], n=n,
                            reshuffled=reshuffled, min_amps=min_amps)
                        # second simulation (dev=0.0005): arrays, spikes, stimuli and frequency axis ff used below
                        v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays, names, _, ff = calc_roc_amp_core_cocktail(
                            [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff, a_f2s,
                            fish_jammer, trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing,
                            stimulus_length,
                            model_cells, position_diff, dev, cell_here, dev_name=dev_name, a_f1s=[c_nrs_orig[c_nn]],
                            n=n,
                            reshuffled=reshuffled, min_amps=min_amps)
                        # time axis in ms (note: this makes `time` a local name throughout this function)
                        time = np.arange(0, len(arrays[a][0]) / sampling, 1 / sampling)
                        time = time * 1000
                        #######################
                        # plot the first array
                        arrays_here, arrays_sp, arrays_st, arrays_time = choose_arrays_phaselocking(arrays,
                                                                                                    arrays_spikes,
                                                                                                    arrays_stim,
                                                                                                    choice='01')
                        colors_array_here = ['grey', 'grey', 'grey'] # colors_array[1::]
                        p_arrays_here = p_arrays[1::]
                        # NOTE(review): grid_ll row `a` is used for the examples. With more than one array,
                        # row 1 would overlap ax_u1 (presumably choose_arrays_phaselocking returns one array).
                        for a in range(len(arrays_here)):
                            print('a' + str(a))
                            eodf = frame_cell.f0.iloc[0]
                            f1 = frame_cell.f1.iloc[0]
                            colors_peaks = [color01, color_stim, color_eodf] # , 'red']
                            freqs_psd = [np.abs(freq1), f1, eodf]
                            # per example: stimulus (0), spikes (1), rate (3) and power spectrum (5)
                            grid_pt = gridspec.GridSpecFromSubplotSpec(6, 1,
                                                                       hspace=0.3,
                                                                       wspace=0.2,
                                                                       subplot_spec=grid_ll[a, c_nn],
                                                                       height_ratios=[1, 0.7, 0, 1, 0.25,
                                                                                      2.2
                                                                                      ]) # .2 hspace=0.4,wspace=0.2,len(chirps)
                            axe = plt.subplot(grid_pt[0])
                            axes.append(axe)
                            plt_stim_saturation(a, [], arrays_st, axe, colors_array_here, f,
                                                f_counter, names, time, xlim=xlim) # np.array(arrays_sp)*1000
                            a_f2_cm = c_dist_recalc_func(frame_cell, c_nrs=[a_f2s[0]], cell=cell_here)
                            if a == 2: # if (a_f1s[0] != 0) & (a_f2s[0] != 0):
                                # NOTE(review): `%` binds tighter than `+`, so `'=' % (...)` runs first and raises
                                # TypeError ('not all arguments converted'). grid_ll has only two rows, so
                                # grid_ll[a, c_nn] above presumably fails before this branch is ever reached.
                                title_name = ' $c_{' + vary_val() + '}=%s\% c' + stable_val() + '=' % (
                                    ((int(np.round(a_f2_cm[0]))), int(np.round(c_nrs[c_nn])))) # + '$\%$'str(
                            elif a == 0: # elif (a_f1s[0] != 0):_{p}_{s}
                                title_name = ' $c_{' + vary_val() + '}=%s$' % int(np.round(
                                    c_nrs[c_nn])) + '\,$\%$, ' + ' $\Delta f_{' + vary_val() + '}= %s$\,Hz' % (
                                                 int(freq1)) # str() #+ '$\%$'
                            elif a == 1: # elif (a_f2s[0] != 0):
                                title_name = ' $c_{' + vary_val() + '}=%s$' % int(
                                    np.round(a_f2_cm[0])) + '\,$\%$, ' + ' $\Delta f_{' + vary_val() + '}= %s$\,Hz' % (
                                                 int(freq1)) # str()
                            axe.text(1, 1, title_name, va='bottom', ha='right',
                                     transform=axe.transAxes)
                            axs = plt.subplot(grid_pt[1])
                            plt_spikes_ROC(axs, 'grey', np.array(arrays_sp[a]) * 1000, xlim)
                            #############################
                            axt = plt.subplot(grid_pt[3])
                            axts.append(axt)
                            plt_vmem_saturation(a, arrays_sp, arrays_time, axt, colors_array_here, f,
                                                time, xlim=xlim)
                            #############################
                            axp = plt.subplot(grid_pt[5])
                            axps.append(axp)
                            log = '' # 'log' # 'log'
                            # spectra are shown up to 15 % above the EOD frequency
                            maxx = eodf * 1.15 # 5
                            pp = log_calc_psd(log, p_arrays_here[a][0],
                                              np.nanmax(p_arrays_here))
                            freqs_peaks1, colors_peaks1, labels1, alphas1 = chose_all_freq_combos(freq2, colors_array,
                                                                                                  freq1,
                                                                                                  maxx, eodf,
                                                                                                  color_eodf=coloer_eod_fr_core(),
                                                                                                  name='01',
                                                                                                  stim_thing=False,
                                                                                                  color_stim=color_stim_core(),
                                                                                                  color_stim_mult=color_stim_core())
                            plt_peaks_several(freqs_peaks1, [pp], axp, pp, ff, labels1, 0, colors_peaks1, limit=10000,
                                              alphas=alphas1, ms=25, clip_on=False)
                            plt_psd_saturation(pp, ff, a, axp, colors_array_here, freqs=freqs_psd,
                                               colors_peaks=colors_peaks, xlim=(0, maxx))
                            if log:
                                axp.show_spines('b')
                                if a == 0:
                                    axp.yscalebar(-0.05, 0.5, 20, 'dB', va='center', ha='left')
                            else:
                                axp.show_spines('lb')
                                if c_nn != 0:
                                    remove_yticks(axp)
                                else:
                                    axp.set_ylabel(power_spectrum_name())
                            if a == 0:
                                # unlike Pl_model, the scale bars are drawn in every contrast column
                                axt.show_spines('')
                                axt.xscalebar(0.1, -0.1, 5, 'ms', va='right', ha='bottom')
                                axt.yscalebar(-0.02, 0.35, 600, 'Hz', va='left', ha='top')
                                axp.set_xlabel('Frequency [Hz]')
                            #############################
                            # ISI histogram panel, switched off (isis is always False here)
                            isis = False
                            if isis:
                                axi = plt.subplot(grid_pt[-1])
                                isis = []
                                for t in range(len(arrays_sp[a])):
                                    isi = calc_isi(arrays_sp[a][t], eodf)
                                    isis.append(isi)
                                axi.hist(np.concatenate(isis), bins=100, color='grey')
                                axi.set_xlabel(isi_xlabel())
                                axi.show_spines('b')
                            f_counter += 1
                    axts_all.extend(axts)
                    axps_all.extend(axps)
                ax_us[0].legend(ncol=3, loc=(0, 1), columnspacing=2.5) # 5 -0.07#loc=(0.9, 0.7)
                # NOTE(review): mutating get_shared_*_axes() via .join is deprecated in newer matplotlib
                axts_all[0].get_shared_y_axes().join(*axts_all)
                axts_all[0].get_shared_x_axes().join(*axts_all)
                axps_all[0].get_shared_y_axes().join(*axps_all)
                axps_all[0].get_shared_x_axes().join(*axps_all)
                join_y(axts)
                set_same_ylim(axts)
                set_same_ylim(axps)
                join_x(axts)
                join_x(ax_us)
                join_y(ax_us)
                fig = plt.gcf()
                fig.tag([axes[0], axes[1], axes[2]], xoffs=-3, yoffs=1)
                fig.tag([ax_u1], xoffs=-3, yoffs=3)
                save_visualization(cell_here, show)
                print('finished cell here')
def color_stim_core():
    """Return the colour used for stimulus-related elements in the core figures."""
    stimulus_colour = 'grey'
    return stimulus_colour
def coloer_eod_fr_core():
    """Return the colour used for EOD-frequency elements in the core figures.

    The misspelled name is kept because callers use it.
    """
    eod_colour = 'black'
    return eod_colour
def labels_pi_core():
    """Return the three legend labels for the phase-locking score curves.

    The order matches the scores ``amp_B1_01``, ``amp_f1_01`` and ``amp_f0_01``
    that ``Pl_model`` / ``vary_contrasts50`` plot with these labels.
    """
    return [DF_pi_core(), f_pi_core(), f_eod_pi_core()]
def vary_contrasts5(freqs=[(39.5, -210.5)], printing=False, beat='',
                    nfft_for_morph=4096 * 4, gain=1, freq_mult=False, cells_here=[],
                    fish_jammer='Alepto', us_name='', show=True,
                    c_nrs_orig=[0.05, 0.1, 0.8]): # "2013-01-08-aa-invivo-1"
    """Plot peak amplitudes versus contrast together with example responses.

    For every cell in *cells_here* (every cell of ``models_big_fit_d_right.csv``
    when the list is empty) that occurs in the precomputed ``calc_cocktailparty``
    CSV, and for every ``(freq1, freq2)`` pair in *freqs*, one figure is built:

    * bottom row: ``plt_single_trace`` of the ``amp_B1_01``, ``amp_f0_01`` and
      ``amp_f1_01`` scores (suffix ``_mean_original`` if present, else ``_mean``);
    * top row: one column per entry of *c_nrs_orig* showing the stimulus, a
      spike raster, the response trace, the power spectrum with labelled peaks
      and an ISI histogram.

    Responses are simulated with ``calc_roc_amp_core_cocktail`` twice per
    contrast: once with ``'original'`` (its power spectra are plotted) and once
    with ``dev`` (its traces, spikes and stimuli are plotted). Each figure is
    written with ``save_visualization(cell_here, show)``.

    Parameters
    ----------
    freqs : list of (float, float)
        Values compared with the ``df1``/``df2`` columns and added to the cell's
        ``f0`` to form the two stimulus frequencies. If there is no exact match,
        the nearest ``df1`` and ``df2`` are chosen independently.
    c_nrs_orig : list of float
        Contrasts passed as ``a_f1s`` to the simulation, one column per entry.
    freq_mult : bool
        If True, *freqs* is recomputed with ``freq_two_mult_recalc``.
    cells_here : list of str
        Cell names; empty means all cells of the model CSV.
    show : bool
        Forwarded to ``save_visualization``.
    printing, beat, nfft_for_morph, gain, fish_jammer, us_name
        Forwarded unchanged to ``calc_roc_amp_core_cocktail``.
    """
    runs = 1
    n = 1
    dev = 0.0005
    #############################################
    # plot a single ROC Curve for the model!
    # the one from the Lisbon talk and the one we will use for Jörg
    # so here we want many contrasts and a few frequencies
    # I still want this for different frequencies and contrasts
    default_settings() # ts=13, ls=13, fs=13, lw = 0.7
    reshuffled = 'reshuffled' # ,
    # standard combination with intruder small
    a_f2s = [0.1]
    min_amps = '_minamps_'
    dev_name = ['05']
    model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv")
    if len(cells_here) < 1:
        cells_here = np.array(model_cells.cell)
    a_fr = 1
    # NOTE(review): `a` is used as an array index below (arrays[a][0]) and is
    # later reused as the loop variable of the per-array loop.
    a = 0
    trials_nrs = [5]
    datapoints = 1000
    stimulus_length = 2
    results_diff = pd.DataFrame()
    position_diff = 0
    plot_style()
    default_settings(column=2, length=6.5)
    for trials_nr in trials_nrs: # +[trials_nrs[-1]]
        # things I want to vary
        ###########################################
        auci_wo = []
        auci_w = []
        nfft = 32768
        full_names = [
            'calc_model_amp_freqs-F1_750-975-75_F2_500-725-75_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_1.0_mult__start_0.0001_end_1_StimLen_25_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal']
        frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv')
        for cell_here in cells_here:
            c_grouped = ['c1'] # , 'c2']
            frame_cell_orig = frame[(frame.cell == cell_here)]
            if freq_mult:
                # NOTE(review): `freqs` is rebound here, so for later cells the
                # recalculation starts from the previous cell's result — confirm intended.
                freqs = freq_two_mult_recalc(frame_cell_orig, freqs)
            if len(frame_cell_orig) > 0:
                print('cell there')
                try:
                    pass
                except:
                    print('min thing')
                    embed()
                get_frame_cell_params(c_grouped, cell_here, frame, frame_cell_orig)
                grid0 = gridspec.GridSpec(1, 1, bottom=0.08, top=0.92, left=0.11,
                                         right=0.95,
                                         wspace=0.04) #
                grid00 = gridspec.GridSpecFromSubplotSpec(1, 1,
                                                          wspace=0.04, hspace=0.1,
                                                          subplot_spec=grid0[0]) # height_ratios=[2,1],
                # Two rows: row 0 holds one response column per contrast, the last
                # row spans all columns for the amplitude-vs-contrast curves.
                grid_ll = gridspec.GridSpecFromSubplotSpec(2, len(c_nrs_orig),
                                                           hspace=0.35,
                                                           wspace=0.2, height_ratios=[1, 0.8],
                                                           subplot_spec=grid00[0]) # hspace=0.4,wspace=0.2,len(chirps)
                #################################################################
                # calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_20_FirstC1_0.0001_LastC1_1.0_StimLen_25_nfft_32768_trialsnr_1_absolut_power_1temporal.csv
                # devs_extra = ['stim','stim_rec','stim_am','original','05']#['original','05']
                # here this is implemented for a single cell
                # where we vary the single point and the contrasts
                f_counter = 0
                frame_cell_orig, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig)
                eodf = frame_cell_orig.f0.unique()[0]
                f = -1
                axts_all = []
                axps_all = []
                ax_us = []
                for freq1, freq2 in freqs:
                    f += 1
                    frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)]
                    if len(frame_cell) < 1:
                        # No exact match: snap each frequency to the nearest value in the table.
                        freq1 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df1 - freq1)))].df1
                        freq2 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df2 - freq2)))].df2
                    # frame_cell = frame_cell_orig[ == freq1 & ]
                    frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)]
                    color_stim = 'grey'
                    color_eodf = 'black'
                    print(cell_here + ' F1' + str(freq1) + ' F2 ' + str(freq2))
                    sampling = 20000
                    try:
                        ax_u1 = plt.subplot(grid_ll[-1, :])
                    except:
                        print('grid search problem3')
                        embed()
                    if 'amp_B1_01_mean_original' in frame_cell_orig.keys():
                        add = '_mean_original'
                    else:
                        add = '_mean'
                    labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept(
                        add=add)
                    # Override the default scores/labels with the three one-beat peaks.
                    scores = ['amp_B1_01' + add, 'amp_f0_01' + add, 'amp_f1_01' + add, ] # 'amp_B1+B2_012_mean',
                    labels = ['$\Delta f_{p}$ peak in ' + onebeat_cond() + ' $('+f_eod_name_core_rm()+' + f_{p})$',
                              '$'+f_eod_name_core_rm()+'$ peak in ' + onebeat_cond() + ' $('+f_eod_name_core_rm()+' + f_{p})$',
                              '$f_{p}$ peak in ' + onebeat_cond() + ' $('+f_eod_name_core_rm()+' + f_{p})$',
                              ]
                    alpha = [1, 1, 1]
                    c_dist_recalc = dist_recalc_phaselockingchapter()
                    ax_us = plt_single_trace(ax_us, ax_u1, frame_cell_orig, freq1, freq2,
                                             scores=scores, colors=[color01, 'black', 'grey'],
                                             linestyles=linestyles, alpha=alpha, labels=labels,
                                             sum=False, B_replace='F', default_colors=False,
                                             c_dist_recalc=c_dist_recalc)
                    if f != 0:
                        print('hi')
                    else:
                        ax_u1.set_ylabel(power_spectrum_name())
                    plt.suptitle(' $\Delta f_{p}= %s $ Hz' % (int(freq1))) # + cell_here + ' DF2=' + str(freq2)
                    axts = []
                    axps = []
                    axes = []
                    recalc = 100
                    c_nrs = c_dist_recalc_func(frame_cell, c_nrs=c_nrs_orig, cell=cell_here,
                                               c_dist_recalc=c_dist_recalc, recalc_contrast_in_perc=recalc)
                    mults_period = 3
                    # Time window (ms) starting at 1000 ms and covering three periods
                    # of the slower of the two difference frequencies.
                    xlim = [1000, 1000 + (mults_period * 1000 / np.min([np.abs(freq1), np.abs(freq2)]))]
                    for c_nn, c_nr in enumerate(c_nrs):
                        # NOTE(review): this scatter of all contrasts is redrawn on every iteration.
                        ax_u1.scatter(c_nrs, np.zeros(len(c_nrs)), color='black', marker='^', clip_on=False)
                        # First simulation ('original'): only its power spectra (p_arrays) are used below.
                        v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays_original, names, p_arrays, ff = calc_roc_amp_core_cocktail(
                            [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff, a_f2s,
                            fish_jammer, trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing,
                            stimulus_length,
                            model_cells, position_diff, 'original', cell_here, dev_name=dev_name,
                            a_f1s=[c_nrs_orig[c_nn]], n=n,
                            reshuffled=reshuffled, min_amps=min_amps)
                        # Second simulation (dev): overwrites spikes/stimuli/traces that are plotted.
                        v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays, names, _, ff = calc_roc_amp_core_cocktail(
                            [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff, a_f2s,
                            fish_jammer, trials_nr, nfft, us_name, gain, runs, a_fr, nfft_for_morph, beat, printing,
                            stimulus_length,
                            model_cells, position_diff, dev, cell_here, dev_name=dev_name, a_f1s=[c_nrs_orig[c_nn]],
                            n=n,
                            reshuffled=reshuffled, min_amps=min_amps)
                        # Time axis in ms, derived from the sample count at `sampling` Hz.
                        time = np.arange(0, len(arrays[a][0]) / sampling, 1 / sampling)
                        time = time * 1000
                        # plot the first array
                        arrays_here, arrays_sp, arrays_st, arrays_time = choose_arrays_phaselocking(arrays,
                                                                                                    arrays_spikes,
                                                                                                    arrays_stim,
                                                                                                    choice='01')
                        colors_array_here = ['grey', 'grey', 'grey'] # colors_array[1::]
                        p_arrays_here = p_arrays[1::]
                        # With choice='01' arrays_here has a single entry, so only a == 0 runs.
                        for a in range(len(arrays_here)):
                            print('a' + str(a))
                            # NOTE(review): this rebinds `eodf`, which the next contrast's
                            # simulation call uses as its base frequency.
                            eodf = frame_cell.f0.iloc[0]
                            f1 = frame_cell.f1.iloc[0]
                            colors_peaks = [color01, color_stim, color_eodf] # , 'red']
                            freqs_psd = [np.abs(freq1), f1, eodf]
                            grid_pt = gridspec.GridSpecFromSubplotSpec(8, 1,
                                                                       hspace=0.3,
                                                                       wspace=0.2,
                                                                       subplot_spec=grid_ll[a, c_nn],
                                                                       height_ratios=[1, 0.7, 0, 1, 0.25,
                                                                                      2.2, 1,
                                                                                      1.2]) # hspace=0.4,wspace=0.2,len(chirps)
                            axe = plt.subplot(grid_pt[0])
                            axes.append(axe)
                            plt_stim_saturation(a, [], arrays_st, axe, colors_array_here, f,
                                                f_counter, names, time, xlim=xlim) # np.array(arrays_sp)*1000
                            a_f2_cm = c_dist_recalc_func(frame_cell, c_nrs=[a_f2s[0]], cell=cell_here)
                            if a == 2: # if (a_f1s[0] != 0) & (a_f2s[0] != 0):
                                # NOTE(review): unreachable with choice='01'; the format string
                                # has '%s' and '% c' (chr conversion), likely not intended.
                                title_name = ' c$_{p}=%s\% c2=' % (
                                    ((int(np.round(a_f2_cm[0]))), int(np.round(c_nrs[c_nn])))) # + '$\%$'str(
                            elif a == 0: # elif (a_f1s[0] != 0):
                                title_name = ' c$_{p}=%s$' % int(np.round(c_nrs[c_nn])) + '$\%$' # str() #+ '$\%$'
                            elif a == 1: # elif (a_f2s[0] != 0):
                                title_name = ' $c2=%s$' % int(np.round(a_f2_cm[0])) + '$\%$' # str()
                            axe.text(1, 1, title_name, va='bottom', ha='right',
                                     transform=axe.transAxes)
                            #############################
                            # spike raster (spike times scaled by 1000, presumably s -> ms)
                            axs = plt.subplot(grid_pt[1])
                            plt_spikes_ROC(axs, 'grey', np.array(arrays_sp[a]) * 1000, xlim)
                            #############################
                            axt = plt.subplot(grid_pt[3])
                            axts.append(axt)
                            plt_vmem_saturation(a, arrays_sp, arrays_time, axt, colors_array_here, f,
                                                time, xlim=xlim)
                            #############################
                            # power spectrum, normalised by the maximum over all conditions
                            axp = plt.subplot(grid_pt[-3])
                            axps.append(axp)
                            log = '' # 'log'
                            maxx = eodf * 1.15 # 5
                            pp = log_calc_psd(log, p_arrays_here[a][0],
                                              np.nanmax(p_arrays_here))
                            plt_psd_saturation(pp, ff, a, axp, colors_array_here, freqs=freqs_psd,
                                               colors_peaks=colors_peaks, xlim=(0, maxx))
                            if log:
                                axp.show_spines('b')
                                if a == 0:
                                    axp.yscalebar(-0.05, 0.5, 20, 'dB', va='center', ha='left')
                            else:
                                axp.show_spines('lb')
                                if c_nn != 0:
                                    remove_yticks(axp)
                                else:
                                    axp.set_ylabel(power_spectrum_name())
                            if a == 0:
                                axt.show_spines('')
                                axt.xscalebar(0.1, -0.1, 5, 'ms', va='right', ha='bottom')
                                axt.yscalebar(-0.02, 0.35, 600, 'Hz', va='left', ha='top')
                            axp.set_xlabel('Frequency [Hz]')
                            freqs_peaks, colors_peaks, labels, alphas = chose_all_freq_combos(freq2, colors_array,
                                                                                              freq1,
                                                                                              maxx, eodf,
                                                                                              color_eodf='black',
                                                                                              name='01',
                                                                                              color_stim='grey',
                                                                                              color_stim_mult='grey')
                            plt_peaks_several(freqs_peaks, [pp], axp, pp, ff, labels, 0, colors_peaks, limit=10000,
                                              alphas=alphas, ms=18, clip_on=False)
                            #############################
                            # ISI histogram pooled over all trials
                            axi = plt.subplot(grid_pt[-1])
                            isis = []
                            for t in range(len(arrays_sp[a])):
                                isi = calc_isi(arrays_sp[a][t], eodf)
                                isis.append(isi)
                            axi.hist(np.concatenate(isis), bins=100, color='grey')
                            axi.set_xlabel(isi_xlabel())
                            axi.show_spines('b')
                        f_counter += 1
                    axts_all.extend(axts)
                    axps_all.extend(axps)
                ax_us[0].legend(ncol=3, loc=(0, 1)) # -0.07#loc=(0.9, 0.7)
                # Share axes across all traces and spectra of this cell.
                axts_all[0].get_shared_y_axes().join(*axts_all)
                axts_all[0].get_shared_x_axes().join(*axts_all)
                axps_all[0].get_shared_y_axes().join(*axps_all)
                axps_all[0].get_shared_x_axes().join(*axps_all)
                join_y(axts)
                set_same_ylim(axts)
                set_same_ylim(axps)
                join_x(axts)
                join_x(ax_us)
                join_y(ax_us)
                fig = plt.gcf()
                # NOTE(review): assumes at least three stimulus axes exist (i.e. >= 3 contrasts).
                fig.tag([axes[0], axes[1], axes[2]], xoffs=-3, yoffs=1)
                fig.tag([ax_u1], xoffs=-3, yoffs=3)
                save_visualization(cell_here, show)
                print('finished cell here')
def power_spectrum_name():
    """Return the y-axis label used for power-spectrum panels."""
    label = 'Power [Hz]'
    return label
def choose_arrays_phaselocking(arrays, arrays_spikes, arrays_stim, choice='all'):
    """Select which simulated conditions are plotted in the phase-locking figures.

    Element 0 of every input sequence is always skipped.

    Parameters
    ----------
    arrays : sequence
        Per-condition response arrays; used both as the plotted data and as the
        source of the time base.
    arrays_spikes : sequence
        Per-condition spike data, indexed like *arrays*.
    arrays_stim : sequence
        Per-condition stimulus data, indexed like *arrays*.
    choice : {'all', '01'}
        ``'all'`` keeps every element from index 1 onwards (slices);
        ``'01'`` keeps only index 1, wrapped in a one-element list.

    Returns
    -------
    tuple
        ``(arrays_here, arrays_sp, arrays_st, arrays_time)``.

    Raises
    ------
    ValueError
        If *choice* is not ``'all'`` or ``'01'``.
    """
    if choice == 'all':
        arrays_time = arrays[1::]
        arrays_here = arrays[1::]
        arrays_st = arrays_stim[1::]
        arrays_sp = arrays_spikes[1::]
    elif choice == '01':
        arrays_time = [arrays[1]]
        arrays_here = [arrays[1]]
        arrays_st = [arrays_stim[1]]
        arrays_sp = [arrays_spikes[1]]
    else:
        # Previously this fell through and failed with UnboundLocalError at the return.
        raise ValueError("choice must be 'all' or '01', got %r" % (choice,))
    return arrays_here, arrays_sp, arrays_st, arrays_time
def f_vary_name(delta=False, freq=None):
    r"""Return the LaTeX label of the varied frequency ``f_{1}`` (or ``\Delta f_{1}``).

    Parameters
    ----------
    delta : bool
        If True, label the difference frequency ``\Delta f_{1}``.
    freq : number or None
        If given (including 0), append ``=<freq>`` in math mode followed by ``\,Hz``.

    Returns
    -------
    str
        ``\ensuremath{...}`` without *freq*, otherwise ``$\ensuremath{...}=<freq>$\,Hz``.
    """
    # Raw strings: '\e', '\D' and '\,' are invalid Python escapes (SyntaxWarning
    # since 3.12); the raw literals contain exactly the same characters.
    if delta:
        val = r'\ensuremath{\Delta f_{1}}'
    else:
        val = r'\ensuremath{f_{1}}'
    # `is not None` so that a frequency of 0 Hz is still written into the label.
    if freq is not None:
        val = '$' + val + '=%s$' % freq + r'\,Hz'
    return val
def f_stable_name(freq=None, delta=False):
    r"""Return the LaTeX label of the fixed frequency ``f_{2}`` (or ``\Delta f_{2}``).

    Parameters
    ----------
    freq : number or None
        If given (including 0), append ``=<freq>`` in math mode followed by ``\,Hz``.
    delta : bool
        If True, label the difference frequency ``\Delta f_{2}``.

    Returns
    -------
    str
        ``\ensuremath{...}`` without *freq*, otherwise ``$\ensuremath{...}=<freq>$\,Hz``.
    """
    # Raw strings: '\e', '\D' and '\,' are invalid Python escapes (SyntaxWarning
    # since 3.12); the raw literals contain exactly the same characters.
    if delta:
        val = r'\ensuremath{\Delta f_{2}}'
    else:
        val = r'\ensuremath{f_{2}}'
    # `is not None` so that a frequency of 0 Hz is still written into the label.
    if freq is not None:
        val = '$' + val + '=%s$' % freq + r'\,Hz'
    return val
def strong_signals(yposs=[450, 450, 450], freqs=[(39.5, -210.5)], printing=False, beat='', nfft_for_morph=4096 * 4,
                   gain=1,
                   cells_here=["2013-01-08-aa-invivo-1"], fish_jammer='Alepto', us_name='',
                   show=True,indexes = [[0, 1, 2, 3,4]]):
    """Plot intruder-detection curves against intruder distance (talk layout).

    For each cell in *cells_here* (every cell of ``models_big_fit_d_right.csv``
    when empty) that occurs in the precomputed ``calc_cocktailparty`` CSV, and
    for each ``(freq1, freq2)`` pair in *freqs* (the nearest ``df1``/``df2`` are
    used if there is no exact match), the scores ``amp_B1_01``, ``amp_B1_012`` and
    ``amp_f0_01`` (suffix from ``get_mean_add``) are drawn with
    ``plt_single_trace``. Their labels are 'Intruder detection without female',
    'Intruder detection with female' and 'Receiver detection with intruder'.
    The last panel is limited to x 0-200 / y 0-400, its x axis is labelled
    'Intruder distance [cm]', and the figure is written with
    ``save_visualization(cell_here, show)``. No simulation is run here.

    Parameters
    ----------
    yposs : list
        Not used by the active code (only referenced in commented-out lines).
    freqs : list of (float, float)
        Frequency pairs matched against the ``df1``/``df2`` columns.
    cells_here : list of str
        Cells to plot; empty means all cells of the model CSV.
    show : bool
        Forwarded to ``save_visualization``.
    indexes : list of list of int
        One entry per detection panel; each inner list selects which entries of
        scores/labels/colors/linestyles/linewidths/alpha are plotted.
        NOTE(review): scores, labels and colors have only 3 entries, so the
        default ``[[0, 1, 2, 3, 4]]`` raises IndexError inside the try below and
        drops into ``embed()``. Callers presumably pass e.g. ``[[0, 1, 2]]`` — confirm.
        ``grid_rr`` is 1x1, so only a single inner list is supported.
    printing, beat, nfft_for_morph, gain, fish_jammer, us_name
        Not used by the visible body.
    """
    runs = 1
    n = 1
    dev = 0.001
    reshuffled = 'reshuffled' # ,
    # standard combination with intruder small
    a_f2s = [0.1]
    min_amps = '_minamps_'
    dev_name = ['05']
    model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv")
    if len(cells_here) < 1:
        cells_here = np.array(model_cells.cell)
    a_fr = 1
    a = 0
    trials_nrs = [5]
    datapoints = 1000
    stimulus_length = 2
    results_diff = pd.DataFrame()
    position_diff = 0
    plot_style()
    # The second call overrides the first figure size (21.6 cm x 14 cm).
    default_figsize(column=2, length=7.5)
    default_figsize(width=cm_to_inch(21.6), length=cm_to_inch(14))
    default_ticks_talks()
    for _ in trials_nrs: # +[trials_nrs[-1]]
        # things I want to vary
        ###########################################
        auci_wo = []
        auci_w = []
        nfft = 32768
        for cell_here in cells_here:
            # Only the second file name is actually read.
            full_names = [
                'calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_25_FirstC1_0.0001_LastC1_1.0_StimLen_' + str(
                    stimulus_length) + '_nfft_' + str(nfft) + '_trialsnr_1_absolut_power_1_minamps__dev_05temporal']
            full_names = [
                'calc_model_amp_freqs-F1_750-975-75_F2_500-725-75_C2_0.1_C1Len_50_FirstC1_0.0001_LastC1_1.0_mult__start_0.0001_end_1_StimLen_25_nfft_32768_trialsnr_1__power_1_minamps__dev_original_05AUCI_point_1temporal']
            c_grouped = ['c1'] # , 'c2']
            frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv')
            frame_cell_orig = frame[(frame.cell == cell_here)]
            if len(frame_cell_orig) > 0:
                try:
                    pass
                except:
                    print('min thing')
                    embed()
                get_frame_cell_params(c_grouped, cell_here, frame, frame_cell_orig)
                c_nrs_orig = [0.2] #0.02, 0.0002, 0.05, 0.5
                trials_nr = 20 # 20
                redo = False # True
                log = 'log' # 'log'
                grid0 = gridspec.GridSpec(1, 1, bottom=0.15, top=0.8, left=0.13,
                                          right=0.95, wspace=0.04) #
                grid00 = gridspec.GridSpecFromSubplotSpec(1, 1,
                                                          wspace=0.04, hspace=0.1,
                                                          subplot_spec=grid0[0]) # height_ratios=[2,1],
                grid_rr = gridspec.GridSpecFromSubplotSpec(1, 1,wspace=0.04, hspace=0.1,subplot_spec = grid00[0]) # height_ratios=[2,1],
                #################################################################
                # calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_20_FirstC1_0.0001_LastC1_1.0_StimLen_25_nfft_32768_trialsnr_1_absolut_power_1temporal.csv
                # devs_extra = ['stim','stim_rec','stim_am','original','05']#['original','05']
                # here this is implemented for a single cell
                # where we vary the single point and the contrasts
                f_counter = 0
                frame_cell_orig, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig)
                eodf = frame_cell_orig.f0.unique()[0]
                f = -1
                axts_all = []
                axps_all = []
                ax_us = []
                for freq1, freq2 in freqs:
                    f += 1
                    frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)]
                    if len(frame_cell) < 1:
                        # No exact match: snap each frequency to the nearest value in the table.
                        freq1 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df1 - freq1)))].df1
                        freq2 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df2 - freq2)))].df2
                        frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)]
                    print('Tuning curve needed for F1' + str(frame_cell.f1.unique()) + ' F2' + str(
                        frame_cell.f2.unique()) + ' for cell ' + str(cell_here))
                    labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept(
                        add='_mean', nr=4)
                    #add='_mean'
                    add = get_mean_add(frame_cell_orig)
                    nr=4
                    # The labels/colors/scores returned above are overridden below;
                    # only the last assignment of each list is used.
                    labels = ['$A(\Delta $' + f_vary_name() + '$)$ in ' + onebeat_cond() + ' $\Delta $' + f_vary_name(),
                              '$A(\Delta $' + f_vary_name() + '$)$ in ' + twobeat_cond() + ' $\Delta $' + f_vary_name() + '\,\&\,$\Delta $' + f_stable_name(),
                              '$A(\Delta $' + f_stable_name() + '$)$ in ' + onebeat_cond() + ' $\Delta $' + f_stable_name(),
                              '$A(\Delta $' + f_stable_name() + '$)$ in ' + twobeat_cond() + ' $\Delta $' + f_vary_name() + '\,\&\,$\Delta $' + f_stable_name()
                              ,'amp_f0_01' + add,]
                    labels = ['Intruder detection without female',
                              'Intruder detection with female',
                              'Receiver detection with intruder']
                    # 'Female detection without intruder',
                    # 'Female detection with intruder',
                    color01 = 'darkred' # 'darkgreen'
                    color02 = 'darkblue' # 'darkblue'
                    color01_012 = 'red' # 'red'#'black'##'lightgreen' # 'blue'#
                    color02_012 = 'cyan' # 'green'#'lightblue'#'grey'#
                    colors = ['green', 'orange', 'black']#color02, color02_012,
                    colors_array = ['grey', 'green', 'lightgreen', 'purple']#color01
                    dashed = (0, (nr, nr))
                    linestyles = ['-', '-', '-', '-', '-']
                    alpha = [1, 1, 1, 1, 1, 1]
                    linewidths = [1.6, 1.4, 1.4, 1.4]#1.6, 1.4,
                    scores = ['amp_B1_01' + add, 'amp_f1_01' + add, 'amp_f0_01' + add,'amp_f0_01' + add ] # 'amp_B1+B2_012_mean',
                    scores = ['amp_B1_01' + add, 'amp_B1_012' + add,
                              'amp_f0_01' + add] # 'amp_B2_02' + add,
                    #'amp_B2_012' + add,'amp_B1+B2_012_mean',#($'+f_eod_name_core_rm()+'$ + $f_{p}$)($'+f_eod_name_core_rm()+'$ + $f_{p}$ + $f_{s}$)($'+f_eod_name_core_rm()+'$ + $f_{s}$ )($'+f_eod_name_core_rm()+'$ + $f_{p}$ + $f_{s}$)
                    print(cell_here + ' F1' + str(freq1) + ' F2 ' + str(freq2))
                    sampling = 20000
                    # The helper's value is immediately overridden with True.
                    c_dist_recalc = dist_recalc_phaselockingchapter()
                    c_dist_recalc = True
                    c_nrs = c_dist_recalc_func(frame_cell, c_nrs=c_nrs_orig, cell=cell_here,
                                               c_dist_recalc=c_dist_recalc)
                    if c_dist_recalc == False:
                        c_nrs = np.array(c_nrs) * 100
                    mults_period = 3
                    start = 200 # 1000
                    xlim = [start, start + (mults_period * 1000 / np.min([np.abs(freq1), np.abs(freq2)]))]
                    letters = ['A']#, 'B'
                    #[0, 1], [2, 3],
                    for i, index in enumerate(indexes):
                        try:
                            ax_u1 = plt.subplot(grid_rr[i])
                        except:
                            print('grid search problem')
                            embed()
                        try:
                            plt_single_trace([], ax_u1, frame_cell_orig, freq1, freq2,
                                             scores=np.array(scores)[index], labels=np.array(labels)[index],
                                             colors=np.array(colors)[index],
                                             linestyles=np.array(linestyles)[index], linewidths=np.array(linewidths)[index],
                                             alpha=np.array(alpha)[index],lim_recalc = None,
                                             sum=False, B_replace='F', default_colors=False, c_dist_recalc=c_dist_recalc)
                        except:
                            print('something lenght')
                            embed()
                        ax_us.append(ax_u1)
                        frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)]
                        c1 = c_dist_recalc_here(c_dist_recalc, frame_cell)
                        #ax_u1.set_xlim(0, 50)
                        if i != 0:
                            ax_u1.set_ylabel('')
                            remove_yticks(ax_u1)
                        #if i < 2:
                        #    ax_u1.fill_between(c1, frame_cell[np.array(scores)[index][0]],
                        #                       frame_cell[np.array(scores)[index][1]], color='grey', alpha=0.1)
                        #ax_u1.scatter(c_nrs, (np.array(yposs[i]) - 35) * np.ones(len(c_nrs)), color='black', marker='v',
                        #              clip_on=False)
                    ax_us[-1].set_xlim(0, 200)
                    ax_us[-1].set_ylim(0, 400)
                    ax_us[-1].set_xlabel('Intruder distance [cm]')
                    ax_us[-1].set_ylabel('Detection')
                    axts = []
                    axps = []
                    axes = []
                    #if len(indexes[0]) == 3:
                    try:
                        reorder_legend_handles(ax_us[-1], order=[0, 2, 1],loc=(-0.14, 1.03),ncol=2, handlelength=2.5, fs = None)
                    except:#[0, 2, 4, 1,3]
                        print('handles not working')
                    #ax_us[-1].legend(loc=(-0.14, 1.03), ncol=2, handlelength=2.5) # -0.07loc=(0.4,1)
                #axts_all[0].get_shared_y_axes().join(*axts_all)
                #axts_all[0].get_shared_x_axes().join(*axts_all)
                #axps_all[0].get_shared_y_axes().join(*axps_all)
                #axps_all[0].get_shared_x_axes().join(*axps_all)
                #join_y(axts)
                #set_same_ylim(axts)
                #set_same_ylim(axps)
                #join_x(axts)
                #join_x(ax_us)
                #join_y(ax_us)
                #fig = plt.gcf()
                # NOTE(review): save_visualization is called twice for the same figure
                # (here and again below); the second call looks redundant — confirm.
                save_visualization(cell_here, show)
                #fig.tag([[axes[0], axes[1], axes[2]]], xoffs=0, yoffs=3.7)
                #fig.tag([[axes[3], axes[4], axes[5]]], xoffs=0, yoffs=3.7)
                #fig.tag([ax_us[0], ax_us[1], ax_us[2]], xoffs=-2.3, yoffs=1.4)
                save_visualization(cell_here, show)
def vary_contrasts6(yposs=[450, 450, 450], freqs=[(39.5, -210.5)], printing=False, beat='', nfft_for_morph=4096 * 4,
                    gain=1,
                    cells_here=["2013-01-08-aa-invivo-1"], fish_jammer='Alepto', us_name='',
                    show=True):
    """Plot amplitude-vs-contrast curves next to example responses and spectra.

    For each cell in *cells_here* (every cell of ``models_big_fit_d_right.csv``
    when empty) that occurs in the precomputed ``calc_cocktailparty`` CSV, and
    for each ``(freq1, freq2)`` pair in *freqs*:

    * right panel: ``plt_single_trace`` of the first four scores returned by
      ``colors_susept``, the area between the first two scores shaded grey, and
      a 'v' marker plus letter at each contrast in ``c_nrs_orig``;
    * left panels: for each contrast and each condition in ``arrays[1::]``, the
      stimulus, a spike raster, the response trace and the power spectrum.

    Power spectra are cached in ``<save_dir>_psd.npy`` and ``<save_dir>_psdf.npy``
    (``save_dir`` from ``load_savedir(level=0)``). When the file is missing or
    ``redo`` is True they are recomputed (with ``log == 'log'``: stimulus length
    50, nfft 2**22) and saved; otherwise they are loaded and the traces are
    simulated with a short stimulus (0.5, nfft 2**14). The figure is written with
    ``save_visualization(cell_here, show)``.

    Parameters
    ----------
    yposs : list
        Per-panel y positions for the contrast markers. ``yposs[i]`` is used as a
        scalar for the markers and as a sequence (``yposs[i][c_nn]``) for the
        letters. NOTE(review): with the default ``[450, 450, 450]`` the letter
        placement raises TypeError and drops into ``embed()`` — callers
        presumably pass nested lists; confirm.
    freqs : list of (float, float)
        Values compared with ``df1``/``df2`` and added to the cell's ``f0`` for the
        simulation; the nearest values are used if there is no exact match.
    cells_here : list of str
        Cells to plot.
    show : bool
        Forwarded to ``save_visualization``.
    printing, beat, nfft_for_morph, gain, fish_jammer, us_name
        Forwarded unchanged to ``calc_roc_amp_core_cocktail``.
    """
    runs = 1
    n = 1
    dev = 0.001
    reshuffled = 'reshuffled' # ,
    # standard combination with intruder small
    a_f2s = [0.1]
    min_amps = '_minamps_'
    dev_name = ['05']
    model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv")
    if len(cells_here) < 1:
        cells_here = np.array(model_cells.cell)
    a_fr = 1
    # Used as array index in `arrays[a][0]` below; later reused as a loop variable.
    a = 0
    trials_nrs = [5]
    datapoints = 1000
    stimulus_length = 2
    results_diff = pd.DataFrame()
    position_diff = 0
    plot_style()
    # The second call overrides the first figure size (33.6 cm x 15.2 cm).
    default_figsize(column=2, length=7.5)
    default_figsize(width=cm_to_inch(33.6), length=cm_to_inch(15.2))
    default_ticks_talks()
    for _ in trials_nrs: # +[trials_nrs[-1]]
        # things I want to vary
        ###########################################
        auci_wo = []
        auci_w = []
        nfft = 32768
        for cell_here in cells_here:
            full_names = [
                'calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_25_FirstC1_0.0001_LastC1_1.0_StimLen_' + str(
                    stimulus_length) + '_nfft_' + str(nfft) + '_trialsnr_1_absolut_power_1_minamps__dev_05temporal']
            c_grouped = ['c1'] # , 'c2']
            frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv')
            frame_cell_orig = frame[(frame.cell == cell_here)]
            if len(frame_cell_orig) > 0:
                try:
                    pass
                except:
                    print('min thing')
                    embed()
                get_frame_cell_params(c_grouped, cell_here, frame, frame_cell_orig)
                c_nrs_orig = [0.2] #0.02, 0.0002, 0.05, 0.5
                trials_nr = 20 # 20
                redo = False # True
                log = 'log' # 'log'
                grid0 = gridspec.GridSpec(1, 1, bottom=0.15, top=0.8, left=0.11,
                                          right=0.95, wspace=0.04) #
                # Left part (2.5) holds the example responses, right part (1) the curves.
                grid00 = gridspec.GridSpecFromSubplotSpec(1, 2,
                                                          wspace=0.04, hspace=0.1, width_ratios = [2.5,1],
                                                          subplot_spec=grid0[0]) # height_ratios=[2,1],
                grid_ll = gridspec.GridSpecFromSubplotSpec(len(c_nrs_orig), 4,
                                                           hspace=0.75,
                                                           wspace=0.1,
                                                           subplot_spec=grid00[0]) # width_ratios=[2, 1],height_ratios=[1, 1],1.2hspace=0.4,wspace=0.2,len(chirps)
                grid_rr = gridspec.GridSpecFromSubplotSpec(1, 1,wspace=0.04, hspace=0.1,subplot_spec = grid00[1]) # height_ratios=[2,1],
                #################################################################
                # calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_20_FirstC1_0.0001_LastC1_1.0_StimLen_25_nfft_32768_trialsnr_1_absolut_power_1temporal.csv
                # devs_extra = ['stim','stim_rec','stim_am','original','05']#['original','05']
                # here this is implemented for a single cell
                # where we vary the single point and the contrasts
                f_counter = 0
                frame_cell_orig, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig)
                eodf = frame_cell_orig.f0.unique()[0]
                f = -1
                axts_all = []
                axps_all = []
                ax_us = []
                for freq1, freq2 in freqs:
                    f += 1
                    frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)]
                    if len(frame_cell) < 1:
                        # No exact match: snap each frequency to the nearest value in the table.
                        freq1 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df1 - freq1)))].df1
                        freq2 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df2 - freq2)))].df2
                        frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)]
                    print('Tuning curve needed for F1' + str(frame_cell.f1.unique()) + ' F2' + str(
                        frame_cell.f2.unique()) + ' for cell ' + str(cell_here))
                    labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept(
                        add='_mean', nr=4)
                    print(cell_here + ' F1' + str(freq1) + ' F2 ' + str(freq2))
                    sampling = 20000
                    c_dist_recalc = dist_recalc_phaselockingchapter()
                    c_nrs = c_dist_recalc_func(frame_cell, c_nrs=c_nrs_orig, cell=cell_here,
                                               c_dist_recalc=c_dist_recalc)
                    if c_dist_recalc == False:
                        c_nrs = np.array(c_nrs) * 100
                    mults_period = 3
                    start = 200 # 1000
                    # Time window (ms): three periods of the slower difference frequency.
                    xlim = [start, start + (mults_period * 1000 / np.min([np.abs(freq1), np.abs(freq2)]))]
                    letters = ['A']#, 'B'
                    indexes = [[0, 1, 2, 3]]#[0, 1], [2, 3],
                    for i, index in enumerate(indexes):
                        try:
                            ax_u1 = plt.subplot(grid_rr[i])
                        except:
                            print('grid search problem')
                            embed()
                        plt_single_trace([], ax_u1, frame_cell_orig, freq1, freq2,
                                         scores=np.array(scores)[index], labels=np.array(labels)[index],
                                         colors=np.array(colors)[index],
                                         linestyles=np.array(linestyles)[index], linewidths=np.array(linewidths)[index],
                                         alpha=np.array(alpha)[index],
                                         sum=False, B_replace='F', default_colors=False, c_dist_recalc=c_dist_recalc)
                        ax_us.append(ax_u1)
                        frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)]
                        c1 = c_dist_recalc_here(c_dist_recalc, frame_cell)
                        ax_u1.set_xlim(0, 50)
                        if i != 0:
                            ax_u1.set_ylabel('')
                            remove_yticks(ax_u1)
                        if i < 2:
                            # Shade the gap between the first two selected scores.
                            ax_u1.fill_between(c1, frame_cell[np.array(scores)[index][0]],
                                               frame_cell[np.array(scores)[index][1]], color='grey', alpha=0.1)
                        ax_u1.scatter(c_nrs, (np.array(yposs[i]) - 35) * np.ones(len(c_nrs)), color='black', marker='v',
                                      clip_on=False)
                        for c_nn, c_nr in enumerate(c_nrs):
                            try:
                                ax_u1.text(c_nr, yposs[i][c_nn] + 50, letters[c_nn], color='black', ha='center', va='top')
                            except:
                                print('assigment thing')
                                embed()
                    axts = []
                    axps = []
                    axes = []
                    p_arrays_all = []
                    for c_nn, c_nr in enumerate(c_nrs):
                        #################################
                        # arrays plot
                        # NOTE(review): the cache file name depends only on load_savedir(),
                        # not on cell_here/freq1/freq2 as far as visible here, so a cached
                        # spectrum may be reused for other cells or frequencies unless redo=True.
                        save_dir = load_savedir(level=0).split('/')[0]
                        name_psd = save_dir + '_psd.npy'
                        name_psd_f = save_dir + '_psdf.npy'
                        if (not os.path.exists(name_psd)) | (redo == True):
                            # No cache (or forced redo): long simulation for well-resolved spectra.
                            if log != 'log':
                                stimulus_length_here = 0.5
                                nfft_here = 32768
                            else:
                                stimulus_length_here = 50
                                trials_nr = 20
                                nfft_here = 2 ** 22
                        else:
                            # Spectra come from the cache; a short run suffices for the traces.
                            nfft_here = 2 ** 14
                            stimulus_length_here = 0.5
                        v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays, names, p_arrays_p, ff_p = calc_roc_amp_core_cocktail(
                            [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff, a_f2s,
                            fish_jammer, trials_nr, nfft_here, us_name, gain, runs, a_fr, nfft_for_morph, beat,
                            printing,
                            stimulus_length_here,
                            model_cells, position_diff, dev, cell_here, dev_name=dev_name, a_f1s=[c_nrs_orig[c_nn]],
                            n=n,
                            reshuffled=reshuffled, min_amps=min_amps)
                        p_arrays_here = p_arrays_p[1::]
                        xlimp = (0, 300)
                        # Truncate spectra (and their frequency axis) to below xlimp[1] Hz.
                        for p in range(len(p_arrays_here)):
                            p_arrays_here[p][0] = p_arrays_here[p][0][ff_p < xlimp[1]]
                        ff_p = ff_p[ff_p < xlimp[1]]
                        # Time axis in ms, derived from the sample count at `sampling` Hz.
                        time = np.arange(0, len(arrays[a][0]) / sampling, 1 / sampling)
                        time = time * 1000
                        # plot the first array
                        arrays_time = arrays[1::] # [v_mems[1],v_mems[3]]#[1,2]#[1::]
                        arrays_here = arrays[1::] # [arrays[1],arrays[3]]#arrays[1::]#
                        arrays_st = arrays_stim[1::] # [arrays_stim[1],arrays_stim[3]]#
                        arrays_sp = arrays_spikes[1::] # [arrays_spikes[1],arrays_spikes[3]]#arrays_spikes[1::]
                        colors_array_here = ['grey', 'grey', 'grey'] # colors_array[1::]
                        p_arrays_all.append(p_arrays_here)
                        for a in range(len(arrays_here)):
                            print('a' + str(a))
                            # NOTE(review): this rebinds `freqs`, the iterable of the enclosing
                            # `for freq1, freq2 in freqs` loop. The running loop is unaffected,
                            # but the name now holds plain frequencies — consider renaming.
                            if a == 0:
                                freqs = [np.abs(freq1)] # ], np.abs(freq2)],
                            elif a == 1:
                                freqs = [np.abs(freq2)]
                            else:
                                freqs = [np.abs(freq1), np.abs(freq2)]
                            grid_pt = gridspec.GridSpecFromSubplotSpec(5, 1,
                                                                       hspace=0.3,
                                                                       wspace=0.2,
                                                                       subplot_spec=grid_ll[c_nn, a],
                                                                       height_ratios=[1, 0.7, 1, 0.25,
                                                                                      2.5]) # hspace=0.4,wspace=0.2,len(chirps)
                            axe = plt.subplot(grid_pt[0])
                            axes.append(axe)
                            plt_stim_saturation(a, [], arrays_st, axe, colors_array_here, f,
                                                f_counter, names, time, xlim=xlim) # np.array(arrays_sp)*1000
                            a_f2_cm = c_dist_recalc_func(frame_cell, c_nrs=[a_f2s[0]], cell=cell_here,
                                                         c_dist_recalc=c_dist_recalc)
                            if c_dist_recalc == False:
                                a_f2_cm = np.array(a_f2_cm) * 100
                            # Build panel titles (only drawn when `text` below is True).
                            if a == 2: # if (a_f1s[0] != 0) & (a_f2s[0] != 0):
                                fish = 'Three fish: $'+f_eod_name_core_rm()+'$\,\&\,' + f_vary_name() + '\,\&\,' + f_stable_name() # + '$'#' $\Delta '$\Delta$
                                beat_here = twobeat_cond(big=True, double=True, cond=False) + '\,' + f_vary_name(
                                    freq=int(freq1), delta=True) + ',\,$c_{1}=%s$' % (
                                                                         int(np.round(c_nrs[c_nn]))) + '$\%$' + '\n' + f_stable_name(
                                    freq=int(freq2), delta=True) + ',\,$c_{2}=%s$' % (
                                                                   int(np.round(a_f2_cm[0]))) + '$\%$' # +'$'
                                title_name = fish + '\n' + beat_here # +c1+c2
                            elif a == 0: # elif (a_f1s[0] != 0):
                                beat_here = ' ' + onebeat_cond(big=True, double=True, cond=False) + '\,' + f_vary_name(
                                    freq=int(freq1), delta=True) # +'$' + ' $\Delta '
                                fish = 'Two fish: $'+f_eod_name_core_rm()+'$\,\&\,' + f_vary_name() # +'$'
                                c1 = ',\,$c_{1}=%s$' % (int(np.round(c_nrs[c_nn]))) + '$\%$ \n '
                                title_name = fish + '\n' + beat_here + c1 # +'cm'+'cm'+'cm'
                            elif a == 1: # elif (a_f2s[0] != 0):
                                beat_here = ' ' + onebeat_cond(big=True, double=True,
                                                               cond=False) + '\,' + f_stable_name(freq=int(freq2),
                                                                                                  delta=True) # +'$'
                                fish = '\n Two fish: $'+f_eod_name_core_rm()+'$\,\&\,' + f_stable_name() # +'$'
                                c1 = ',\,$c_{2}=%s$' % (int(np.round(a_f2_cm[0]))) + '$\%$ \n'
                                title_name = fish + '\n' + beat_here + c1 # +'cm'
                            text = False
                            if text:
                                axe.text(1, 1.1, title_name, va='bottom', ha='right',
                                         transform=axe.transAxes)
                            #############################
                            axs = plt.subplot(grid_pt[1])
                            plt_spikes_ROC(axs, 'grey', np.array(arrays_sp[a]) * 1000, xlim, lw=1)
                            #############################
                            axt = plt.subplot(grid_pt[2])
                            axts.append(axt)
                            plt_vmem_saturation(a, arrays_sp, arrays_time, axt, colors_array_here, f,
                                                time, xlim=xlim)
                            # Spectrum axes are only created here; they are filled in the loop below.
                            axp = plt.subplot(grid_pt[-1])
                            axps.append(axp)
                            if a == 0:
                                axt.show_spines('')
                                if c_nn == 0:
                                    axt.xscalebar(0.2, -0.1, 10, 'ms', va='right', ha='bottom')
                                    axt.yscalebar(-0.02, 0.35, 600, 'Hz', va='left', ha='top')
                        f_counter += 1
                    # Save freshly computed spectra, or replace them with the cached ones.
                    if (not os.path.exists(name_psd)) | (redo == True):
                        np.save(name_psd, p_arrays_all)
                        np.save(name_psd_f, ff_p)
                    else:
                        ff_p = np.load(name_psd_f) # p_arrays_p
                        p_arrays_all = np.load(name_psd) # p_arrays_p
                    for c_nn, c_nr in enumerate(c_nrs):
                        for a in range(len(arrays_here)):
                            # NOTE(review): hard-codes three spectrum axes, i.e. a single
                            # contrast row; c_nn > 0 would raise IndexError here.
                            axps_here = [[axps[0], axps[1], axps[2]]]#, [axps[3], axps[4], axps[5]]
                            axp = axps_here[c_nn][a]
                            pp = log_calc_psd(log, p_arrays_all[c_nn][a][0],
                                              np.nanmax(p_arrays_all))
                            markeredgecolors = []
                            if a == 0:
                                colors_peaks = [color01] # , 'red']
                                freqs = [np.abs(freq1)] # ], np.abs(freq2)],
                            elif a == 1:
                                colors_peaks = [color02] # , 'red']
                                freqs = [np.abs(freq2)]
                            else:
                                colors_peaks = [color01_012, color02_012] # , 'red']
                                freqs = [np.abs(freq1), np.abs(freq2)]
                                markeredgecolors = [color01, color02]
                            plt_psd_saturation(pp, ff_p, a, axp, colors_array_here, freqs=freqs,
                                               colors_peaks=colors_peaks, xlim=xlimp,
                                               markeredgecolor=markeredgecolors, )
                            if log:
                                scalebar = False
                                if scalebar:
                                    axp.show_spines('b')
                                    if a == 0:
                                        axp.yscalebar(-0.05, 0.5, 20, 'dB', va='center', ha='left')
                                        axp.set_ylim(-33, 5)
                                else:
                                    axp.show_spines('lb')
                                    if a == 0:
                                        axp.set_ylabel('dB') # , va='center', ha='left'
                                    else:
                                        remove_yticks(axp)
                                    axp.set_ylim(-39, 5)
                            else:
                                axp.show_spines('lb')
                                if a != 0:
                                    remove_yticks(axp)
                                else:
                                    axp.set_ylabel(power_spectrum_name())
                            axp.set_xlabel('Frequency [Hz]')
                    axts_all.extend(axts)
                    axps_all.extend(axps)
                ax_us[-1].legend(loc=(-2.12, 1.03), ncol=2, handlelength=2.5) # -0.07loc=(0.4,1)
                # Share axes across all traces and spectra of this cell.
                axts_all[0].get_shared_y_axes().join(*axts_all)
                axts_all[0].get_shared_x_axes().join(*axts_all)
                axps_all[0].get_shared_y_axes().join(*axps_all)
                axps_all[0].get_shared_x_axes().join(*axps_all)
                join_y(axts)
                set_same_ylim(axts)
                set_same_ylim(axps)
                join_x(axts)
                #join_x(ax_us)
                #join_y(ax_us)
                fig = plt.gcf()
                #fig.tag([[axes[0], axes[1], axes[2]]], xoffs=0, yoffs=3.7)
                #fig.tag([[axes[3], axes[4], axes[5]]], xoffs=0, yoffs=3.7)
                #fig.tag([ax_us[0], ax_us[1], ax_us[2]], xoffs=-2.3, yoffs=1.4)
                save_visualization(cell_here, show)
def vary_contrasts4(yposs=[450, 450, 450], freqs=[(39.5, -210.5)], printing=False, beat='', nfft_for_morph=4096 * 4,
                    gain=1,
                    cells_here=["2013-01-08-aa-invivo-1"], fish_jammer='Alepto', us_name='',
                    show=True):
    """Plot model responses to one or two beats at several contrasts of the first stimulus.

    For every cell in ``cells_here`` that appears in the hard-coded
    ``calc_model_amp_freqs-...temporal`` table (``calc_cocktailparty`` folder),
    and for every ``(df1, df2)`` pair in ``freqs``, the figure contains:

    * a bottom row of three panels drawn by ``plt_single_trace`` (score groups
      ``[0, 1]``, ``[2, 3]`` and all four scores from ``colors_susept``) with
      ``v`` markers and letters at the contrasts used in the rows above;
    * for each contrast in ``c_nrs_orig`` one row with three columns
      (``a`` = 0, 1, 2), each showing the stimulus, a spike raster, the
      ``plt_vmem_saturation`` trace and a power spectrum. Column titles are
      the two-fish (``f_vary_name``), two-fish (``f_stable_name``) and
      three-fish conditions respectively.

    Power spectra are cached as ``<load_savedir(level=0)>_psd.npy`` and
    ``..._psdf.npy`` and recomputed only when missing or when ``redo`` is True.
    The figure is written with ``save_visualization(cell_here, show)``.

    Parameters
    ----------
    yposs : list
        Per-panel y positions for the contrast markers; ``yposs[i][c_nn]`` is
        indexed below, so every entry must itself be indexable.
        NOTE(review): the default ``[450, 450, 450]`` holds plain ints, which
        would raise TypeError at that indexing -- confirm callers always pass
        nested lists.
    freqs : list of tuple
        ``(df1, df2)`` pairs, offsets from the cell's ``f0`` (the stimulus
        frequencies passed on are ``df + eodf``). If a pair is not in the
        table, the nearest available ``df1`` / ``df2`` values are used.
    printing, beat, nfft_for_morph, gain, fish_jammer, us_name
        Passed through to ``calc_roc_amp_core_cocktail``.
    cells_here : list of str
        Cell identifiers; if empty, all cells in ``models_big_fit_d_right.csv``.
    show : bool
        Forwarded to ``save_visualization``.
    """
    runs = 1
    n = 1
    dev = 0.001
    reshuffled = 'reshuffled' # ,
    # standard combination with intruder small
    a_f2s = [0.1]
    min_amps = '_minamps_'
    dev_name = ['05']
    model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv")
    if len(cells_here) < 1:
        cells_here = np.array(model_cells.cell)
    a_fr = 1
    # `a` is reused as the panel-column index of the loops below.
    a = 0
    trials_nrs = [5]
    datapoints = 1000
    stimulus_length = 2
    results_diff = pd.DataFrame()
    position_diff = 0
    plot_style()
    default_figsize(column=2, length=7.5)
    for trials_nr in trials_nrs: # +[trials_nrs[-1]]
        # parameters I want to vary
        ###########################################
        auci_wo = []
        auci_w = []
        nfft = 32768
        for cell_here in cells_here:
            full_names = [
                'calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_25_FirstC1_0.0001_LastC1_1.0_StimLen_' + str(
                    stimulus_length) + '_nfft_' + str(nfft) + '_trialsnr_1_absolut_power_1_minamps__dev_05temporal']
            c_grouped = ['c1'] # , 'c2']
            frame = pd.read_csv(load_folder_name('calc_cocktailparty') + '/' + full_names[0] + '.csv')
            frame_cell_orig = frame[(frame.cell == cell_here)]
            # Only cells present in the table are plotted.
            if len(frame_cell_orig) > 0:
                # NOTE(review): this try/except wraps only `pass` and can never trigger.
                try:
                    pass
                except:
                    print('min thing')
                    embed()
                get_frame_cell_params(c_grouped, cell_here, frame, frame_cell_orig)
                c_nrs_orig = [0.02, 0.2] # 0.0002, 0.05, 0.5
                # NOTE(review): this overrides the outer loop variable `trials_nr`.
                trials_nr = 20 # 20
                redo = False # True
                log = 'log' # 'log'
                grid0 = gridspec.GridSpec(1, 1, bottom=0.08, top=0.93, left=0.11,
                                          right=0.95, wspace=0.04) #
                grid00 = gridspec.GridSpecFromSubplotSpec(1, 1,
                                                          wspace=0.04, hspace=0.1,
                                                          subplot_spec=grid0[0]) # height_ratios=[2,1],
                # One row per contrast plus a bottom row for the score-vs-contrast panels.
                grid_ll = gridspec.GridSpecFromSubplotSpec(len(c_nrs_orig) + 1, 3,
                                                           hspace=0.75,
                                                           wspace=0.1, height_ratios=[1, 1, 0.7],
                                                           subplot_spec=grid00[
                                                               0]) # 1.2hspace=0.4,wspace=0.2,len(chirps)
                #################################################################
                # calc_model_amp_freqs-F1_750-975-25_F2_500-775-25_C2_0.1_C1Len_20_FirstC1_0.0001_LastC1_1.0_StimLen_25_nfft_32768_trialsnr_1_absolut_power_1temporal.csv
                # devs_extra = ['stim','stim_rec','stim_am','original','05']#['original','05']
                # here this is implemented for a single cell,
                # where the single point and the contrasts are varied
                f_counter = 0
                frame_cell_orig, df1s, df2s, f1s, f2s = find_dfs(frame_cell_orig)
                eodf = frame_cell_orig.f0.unique()[0]
                f = -1
                axts_all = []
                axps_all = []
                ax_us = []
                for freq1, freq2 in freqs:
                    f += 1
                    frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)]
                    # Snap to the nearest df1/df2 available in the table for this cell.
                    if len(frame_cell) < 1:
                        freq1 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df1 - freq1)))].df1
                        freq2 = frame_cell_orig.iloc[(np.argmin(np.abs(frame_cell_orig.df2 - freq2)))].df2
                        frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)]
                    print('Tuning curve needed for F1' + str(frame_cell.f1.unique()) + ' F2' + str(
                        frame_cell.f2.unique()) + ' for cell ' + str(cell_here))
                    labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidths = colors_susept(
                        add='_mean', nr=4)
                    print(cell_here + ' F1' + str(freq1) + ' F2 ' + str(freq2))
                    sampling = 20000
                    c_dist_recalc = dist_recalc_phaselockingchapter()
                    c_nrs = c_dist_recalc_func(frame_cell, c_nrs=c_nrs_orig, cell=cell_here,
                                               c_dist_recalc=c_dist_recalc)
                    # Without recalculation the contrasts are shown in percent.
                    if c_dist_recalc == False:
                        c_nrs = np.array(c_nrs) * 100
                    mults_period = 3
                    start = 200 # 1000
                    # Time window: `mults_period` periods of the slower beat, starting at `start` (x1000 scaling).
                    xlim = [start, start + (mults_period * 1000 / np.min([np.abs(freq1), np.abs(freq2)]))]
                    letters = ['A', 'B']
                    indexes = [[0, 1], [2, 3], [0, 1, 2, 3]]
                    # Bottom row: one score-vs-contrast panel per score group in `indexes`.
                    for i, index in enumerate(indexes):
                        try:
                            ax_u1 = plt.subplot(grid_ll[-1, i])
                        except:
                            print('grid search problem4')
                            embed()
                        plt_single_trace([], ax_u1, frame_cell_orig, freq1, freq2,
                                         scores=np.array(scores)[index], labels=np.array(labels)[index],
                                         colors=np.array(colors)[index],
                                         linestyles=np.array(linestyles)[index], linewidths=np.array(linewidths)[index],
                                         alpha=np.array(alpha)[index],
                                         sum=False, B_replace='F', default_colors=False, c_dist_recalc=c_dist_recalc)
                        ax_us.append(ax_u1)
                        frame_cell = frame_cell_orig[(frame_cell_orig.df1 == freq1) & (frame_cell_orig.df2 == freq2)]
                        c1 = c_dist_recalc_here(c_dist_recalc, frame_cell)
                        ax_u1.set_xlim(0, 50)
                        if i != 0:
                            ax_u1.set_ylabel('')
                            remove_yticks(ax_u1)
                        # Shade the area between the two scores of a pair.
                        if i < 2:
                            ax_u1.fill_between(c1, frame_cell[np.array(scores)[index][0]],
                                               frame_cell[np.array(scores)[index][1]], color='grey', alpha=0.1)
                        # Mark the contrasts used in the rows above.
                        ax_u1.scatter(c_nrs, (np.array(yposs[i]) - 35) * np.ones(len(c_nrs)), color='black', marker='v',
                                      clip_on=False)
                        #
                        for c_nn, c_nr in enumerate(c_nrs):
                            # `yposs[i]` must be indexable per contrast here (see docstring note).
                            ax_u1.text(c_nr, yposs[i][c_nn] + 50, letters[c_nn], color='black', ha='center', va='top')
                            # ax_u1.plot([c_nr, c_nr], [0, 435], color='black', linewidth=0.8, clip_on=False)
                        # embed()
                    # plt.show()
                    # embed()
                    # if f != 0:
                    #     # remove_yticks(ax_u0)
                    #     # remove_yticks(ax_u1)
                    #     print('hi')
                    # else:
                    #     ax_u1.set_ylabel(power_spectrum_name())
                    # embed()
                    axts = []
                    axps = []
                    axes = []
                    p_arrays_all = []
                    # One row of example panels per contrast.
                    for c_nn, c_nr in enumerate(c_nrs):
                        #################################
                        # arrays plot
                        save_dir = load_savedir(level=0).split('/')[0]
                        name_psd = save_dir + '_psd.npy'
                        name_psd_f = save_dir + '_psdf.npy'
                        # Long stimulus / large nfft only when the PSD cache must be (re)computed.
                        if (not os.path.exists(name_psd)) | (redo == True):
                            if log != 'log':
                                stimulus_length_here = 0.5
                                nfft_here = 32768
                            else:
                                stimulus_length_here = 50
                                trials_nr = 20
                                nfft_here = 2 ** 22
                        else:
                            nfft_here = 2 ** 14
                            stimulus_length_here = 0.5
                        v_mems, arrays_spikes, arrays_stim, results_diff, position_diff, auci_wo, auci_w, arrays, names, p_arrays_p, ff_p = calc_roc_amp_core_cocktail(
                            [freq1 + eodf], [freq2 + eodf], datapoints, auci_wo, auci_w, results_diff, a_f2s,
                            fish_jammer, trials_nr, nfft_here, us_name, gain, runs, a_fr, nfft_for_morph, beat,
                            printing,
                            stimulus_length_here,
                            model_cells, position_diff, dev, cell_here, dev_name=dev_name, a_f1s=[c_nrs_orig[c_nn]],
                            n=n,
                            reshuffled=reshuffled, min_amps=min_amps)
                        p_arrays_here = p_arrays_p[1::]
                        xlimp = (0, 300)
                        # Keep only spectrum bins below the upper plot limit.
                        for p in range(len(p_arrays_here)):
                            p_arrays_here[p][0] = p_arrays_here[p][0][ff_p < xlimp[1]]
                        ff_p = ff_p[ff_p < xlimp[1]]
                        # NOTE(review): `a` here is whatever the previous iteration's inner loop left
                        # (0 from the initialisation above on the first pass); presumably all arrays
                        # have the same length -- confirm.
                        time = np.arange(0, len(arrays[a][0]) / sampling, 1 / sampling)
                        time = time * 1000
                        # plot the first array
                        arrays_time = arrays[1::] # [v_mems[1],v_mems[3]]#[1,2]#[1::]
                        arrays_here = arrays[1::] # [arrays[1],arrays[3]]#arrays[1::]#
                        arrays_st = arrays_stim[1::] # [arrays_stim[1],arrays_stim[3]]#
                        arrays_sp = arrays_spikes[1::] # [arrays_spikes[1],arrays_spikes[3]]#arrays_spikes[1::]
                        colors_array_here = ['grey', 'grey', 'grey'] # colors_array[1::]
                        p_arrays_all.append(p_arrays_here)
                        for a in range(len(arrays_here)):
                            print('a' + str(a))
                            # NOTE(review): this rebinds the `freqs` parameter; with more than one cell
                            # in `cells_here` the next `for freq1, freq2 in freqs` would iterate these
                            # floats -- confirm only one cell is used.
                            if a == 0:
                                freqs = [np.abs(freq1)] # ], np.abs(freq2)],
                            elif a == 1:
                                freqs = [np.abs(freq2)]
                            else:
                                freqs = [np.abs(freq1), np.abs(freq2)]
                            # Per column: stimulus, spikes, trace, spacer, spectrum.
                            grid_pt = gridspec.GridSpecFromSubplotSpec(5, 1,
                                                                       hspace=0.3,
                                                                       wspace=0.2,
                                                                       subplot_spec=grid_ll[c_nn, a],
                                                                       height_ratios=[1, 0.7, 1, 0.25,
                                                                                      2.5]) # hspace=0.4,wspace=0.2,len(chirps)
                            axe = plt.subplot(grid_pt[0])
                            axes.append(axe)
                            plt_stim_saturation(a, [], arrays_st, axe, colors_array_here, f,
                                                f_counter, names, time, xlim=xlim) # np.array(arrays_sp)*1000
                            a_f2_cm = c_dist_recalc_func(frame_cell, c_nrs=[a_f2s[0]], cell=cell_here,
                                                         c_dist_recalc=c_dist_recalc)
                            if c_dist_recalc == False:
                                a_f2_cm = np.array(a_f2_cm) * 100
                            # Column titles.
                            if a == 2: # if (a_f1s[0] != 0) & (a_f2s[0] != 0):
                                fish = 'Three fish: $'+f_eod_name_core_rm()+'$\,\&\,' + f_vary_name() + '\,\&\,' + f_stable_name() # + '$'#' $\Delta '$\Delta$
                                beat_here = twobeat_cond(big=True, double=True, cond=False) + '\,' + f_vary_name(
                                    freq=int(freq1), delta=True) + ',\,$c_{1}=%s$' % (
                                    int(np.round(c_nrs[c_nn]))) + '$\%$' + '\n' + f_stable_name(
                                    freq=int(freq2), delta=True) + ',\,$c_{2}=%s$' % (
                                    int(np.round(a_f2_cm[0]))) + '$\%$' # +'$'
                                title_name = fish + '\n' + beat_here # +c1+c2
                            elif a == 0: # elif (a_f1s[0] != 0):
                                beat_here = ' ' + onebeat_cond(big=True, double=True, cond=False) + '\,' + f_vary_name(
                                    freq=int(freq1), delta=True) # +'$' + ' $\Delta '
                                fish = 'Two fish: $'+f_eod_name_core_rm()+'$\,\&\,' + f_vary_name() # +'$'
                                c1 = ',\,$c_{1}=%s$' % (int(np.round(c_nrs[c_nn]))) + '$\%$ \n '
                                title_name = fish + '\n' + beat_here + c1 # +'cm'+'cm'+'cm'
                            elif a == 1: # elif (a_f2s[0] != 0):
                                beat_here = ' ' + onebeat_cond(big=True, double=True,
                                                               cond=False) + '\,' + f_stable_name(freq=int(freq2),
                                                                                                  delta=True) # +'$'
                                fish = '\n Two fish: $'+f_eod_name_core_rm()+'$\,\&\,' + f_stable_name() # +'$'
                                c1 = ',\,$c_{2}=%s$' % (int(np.round(a_f2_cm[0]))) + '$\%$ \n'
                                title_name = fish + '\n' + beat_here + c1 # +'cm'
                            axe.text(1, 1.1, title_name, va='bottom', ha='right',
                                     transform=axe.transAxes)
                            #############################
                            axs = plt.subplot(grid_pt[1])
                            plt_spikes_ROC(axs, 'grey', np.array(arrays_sp[a]) * 1000, xlim, lw=1)
                            #############################
                            axt = plt.subplot(grid_pt[2])
                            axts.append(axt)
                            plt_vmem_saturation(a, arrays_sp, arrays_time, axt, colors_array_here, f,
                                                time, xlim=xlim)
                            axp = plt.subplot(grid_pt[-1])
                            axps.append(axp)
                            if a == 0:
                                axt.show_spines('')
                                axt.xscalebar(0.1, -0.1, 10, 'ms', va='right', ha='bottom')
                                axt.yscalebar(-0.02, 0.35, 600, 'Hz', va='left', ha='top')
                            f_counter += 1
                    # Cache the spectra on the first run; reload them afterwards.
                    if (not os.path.exists(name_psd)) | (redo == True):
                        np.save(name_psd, p_arrays_all)
                        np.save(name_psd_f, ff_p)
                    else:
                        ff_p = np.load(name_psd_f) # p_arrays_p
                        p_arrays_all = np.load(name_psd) # p_arrays_p
                    # Second pass: draw the power spectra into the prepared axes.
                    for c_nn, c_nr in enumerate(c_nrs):
                        for a in range(len(arrays_here)):
                            # NOTE(review): hard-codes two contrast rows of three panels
                            # (matches len(c_nrs_orig) == 2 above).
                            axps_here = [[axps[0], axps[1], axps[2]], [axps[3], axps[4], axps[5]]]
                            axp = axps_here[c_nn][a]
                            pp = log_calc_psd(log, p_arrays_all[c_nn][a][0],
                                              np.nanmax(p_arrays_all))
                            markeredgecolors = []
                            if a == 0:
                                colors_peaks = [color01] # , 'red']
                                freqs = [np.abs(freq1)] # ], np.abs(freq2)],
                            elif a == 1:
                                colors_peaks = [color02] # , 'red']
                                freqs = [np.abs(freq2)]
                            else:
                                colors_peaks = [color01_012, color02_012] # , 'red']
                                freqs = [np.abs(freq1), np.abs(freq2)]
                                markeredgecolors = [color01, color02]
                            plt_psd_saturation(pp, ff_p, a, axp, colors_array_here, freqs=freqs,
                                               colors_peaks=colors_peaks, xlim=xlimp,
                                               markeredgecolor=markeredgecolors, )
                            # `scalebar` is only bound when `log` is truthy; `log` is hard-coded to 'log' above.
                            if log:
                                scalebar = False
                            if scalebar:
                                axp.show_spines('b')
                                if a == 0:
                                    axp.yscalebar(-0.05, 0.5, 20, 'dB', va='center', ha='left')
                                    axp.set_ylim(-33, 5)
                                else:
                                    axp.show_spines('lb')
                                    if a == 0:
                                        axp.set_ylabel('dB') # , va='center', ha='left'
                                    else:
                                        remove_yticks(axp)
                                    axp.set_ylim(-39, 5)
                            else:
                                axp.show_spines('lb')
                                if a != 0:
                                    remove_yticks(axp)
                                else:
                                    axp.set_ylabel(power_spectrum_name())
                            axp.set_xlabel('Frequency [Hz]')
                    axts_all.extend(axts)
                    axps_all.extend(axps)
                ax_us[-1].legend(loc=(-2.22, 1.2), ncol=2, handlelength=2.5) # -0.07loc=(0.4,1)
                # NOTE(review): `get_shared_*_axes().join` is deprecated in newer Matplotlib
                # releases (Axes.sharex/sharey is the replacement) -- verify against the installed version.
                axts_all[0].get_shared_y_axes().join(*axts_all)
                axts_all[0].get_shared_x_axes().join(*axts_all)
                axps_all[0].get_shared_y_axes().join(*axps_all)
                axps_all[0].get_shared_x_axes().join(*axps_all)
                join_y(axts)
                set_same_ylim(axts)
                set_same_ylim(axps)
                join_x(axts)
                join_x(ax_us)
                join_y(ax_us)
                fig = plt.gcf()
                fig.tag([[axes[0], axes[1], axes[2]]], xoffs=0, yoffs=3.7)
                fig.tag([[axes[3], axes[4], axes[5]]], xoffs=0, yoffs=3.7)
                fig.tag([ax_us[0], ax_us[1], ax_us[2]], xoffs=-2.3, yoffs=1.4)
                save_visualization(cell_here, show)
def twobeat_cond(big=False, double=False, cond=True):
    """Return the text label for the two-beat stimulus.

    Parameters
    ----------
    big : bool
        Capitalise the first letter.
    double : bool
        Append a trailing colon.
    cond : bool
        If True the label reads "two-beat condition", otherwise "two beats".

    Returns
    -------
    str
    """
    if cond:
        label = 'Two-beat condition' if big else 'two-beat condition'
    else:
        label = 'Two beats' if big else 'two beats'
    return label + (':' if double else '')
def colors_susept(add='_mean', nr=4):
    """Return labels, colours and line styles for the four beat-amplitude scores.

    The scores are ``amp_B1_01``, ``amp_B1_012``, ``amp_B2_02`` and
    ``amp_B2_012``, each suffixed with ``add``. ``nr`` is the on/off length of
    the dash pattern used for the two ``_012`` scores.

    Returns
    -------
    tuple
        ``(labels, alpha, color01, color01_012, color02, color02_012, colors,
        colors_array, linestyles, scores, linewidth)``. ``colors``,
        ``linestyles``, ``alpha`` and ``linewidth`` carry a fifth entry after
        the four scores.
    """
    scores = [stem + add for stem in ('amp_B1_01', 'amp_B1_012', 'amp_B2_02', 'amp_B2_012')]

    # LaTeX fragments shared by the legend labels (raw strings keep the backslashes literal).
    amp_open = r'$A(\Delta $'
    in_word = '$)$ in '
    delta = r' $\Delta $'
    joiner = r'\,\&\,$\Delta $'
    labels = [
        amp_open + f_vary_name() + in_word + onebeat_cond() + delta + f_vary_name(),
        amp_open + f_vary_name() + in_word + twobeat_cond() + delta + f_vary_name() + joiner + f_stable_name(),
        amp_open + f_stable_name() + in_word + onebeat_cond() + delta + f_stable_name(),
        amp_open + f_stable_name() + in_word + twobeat_cond() + delta + f_vary_name() + joiner + f_stable_name(),
    ]

    color01, color02 = 'darkred', 'darkblue'
    color01_012, color02_012 = 'red', 'cyan'
    colors = [color01, color01_012, color02, color02_012, 'grey']
    colors_array = ['grey', color01, color02, 'purple']

    dashed = (0, (nr, nr))
    linestyles = ['-', dashed, '-', dashed, dashed]
    alpha = [1] * 5
    linewidth = [1.6, 1.4, 1.6, 1.4, 1.4]
    return labels, alpha, color01, color01_012, color02, color02_012, colors, colors_array, linestyles, scores, linewidth
def square_part(ax, shrink=1, what=[], end='.pkl', folder='calc_model',
                full_name='modell_all_cell_no_sinz1_afe1_0.03__afr0_1__afj2_0.1__phaseright__len5_adaptoffset_bisecting_0.995_1.005____ratecorrrisidual35__modelbigfit_nfft4096_StartE1_1_EndE1_1.3_in0.005_StartJ2_1_EndJ2_1.3_in0.005_trialnr20__reshuffled_ThreeDiff_SameOffset'):
    """Draw a three-fish model score matrix for cell 2013-01-08-aa-invivo-1 on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes (aspect is set to equal).
    shrink : float
        Forwarded to ``colorbar_outside``.
    what : str
        Score column to plot. If empty, it defaults to
        ``'auci02_012-auci_base_01' + '_abs1000' + '_05'``.
    end, folder, full_name
        Forwarded to ``define_squares_model_three`` to locate the table.

    The colour limits end up fixed at (-0.5, 0.5).
    """
    score = 'auci02_012-auci_base_01'
    cell_orig = '2013-01-08-aa-invivo-1'
    dev = '_05'
    mult = '_abs1000'
    # AUC-type scores carry the multiplier suffix in their column name.
    is_auc_score = ('auci' in score) or ('auc' in score)
    if len(what) < 1:
        what = score + (mult if is_auc_score else '') + dev

    mat, vers_here, cell, eod_m, fr_rate_mult = define_squares_model_three(
        what=what, square=[], full_name=full_name, minimum=0, folder=folder,
        maximum=3, end=end, cell_data=cell_orig, emb=False)
    lim = find_lims(what, vers_here)
    ax.set_aspect('equal')

    try:
        power = np.unique(mat['power'])[0]
    except BaseException:
        print('power thing')
        embed()
    plt.suptitle(str(cell_orig) + ' power ' + str(power) + ' dev ' + str(dev))

    mult_type = ''
    im = plt_square(mat, True, mult_type, vers_here, lim)
    square_labels(mult_type, ax, vers_here, 0)
    # Override the generic beat labels with the intruder/female axis names.
    ax.set_xlabel(r'$\Delta \mathrm{f_{Intruder}}$ [Hz]')
    ax.set_ylabel(r'$\Delta \mathrm{f_{Female}}$ [Hz]')
    _, _, _, _, _ = colorbar_outside(ax, im, add=5, delta=0.25, round_digit=2, width=0.01,
                                     shrink=shrink)
    ax.text(1.3, 0.5, core_scatter_wfemale(), va='center', ha='center', rotation=90,
            transform=ax.transAxes)
    im.set_clim(-0.5, 0.5)
def find_lims(what, vers_here):
    """Return symmetric colour limits ``[-x, x]`` for the score ``what``.

    AUC-type scores get fixed limits: ``'auci'`` scores ``[-0.5, 0.5]`` and
    other ``'auc'`` scores ``[-1, 1]``. For any other score the bound is the
    larger of the NaN-ignoring 95th and 5th percentiles of ``|vers_here|``.
    """
    if 'auci' in what:
        return [-0.5, 0.5]
    if 'auc' in what:
        return [-1, 1]
    magnitudes = np.abs(vers_here)
    upper = np.nanpercentile(magnitudes, 95)
    lower = np.nanpercentile(magnitudes, 5)
    bound = np.max([upper, np.abs(lower)])
    return [-bound, bound]
def cut_matrix_generation(condition, minimum, maximum):
    """Restrict a DataFrame to index and column labels strictly inside ``(minimum, maximum)``.

    Returns
    -------
    tuple
        ``(sub_frame, kept_columns, kept_index)``.
    """
    def _inside(labels):
        # Open interval on both ends.
        return labels[(labels > minimum) & (labels < maximum)]

    index_chosen = _inside(condition.index)
    column_chosen = _inside(condition.columns)
    return condition.loc[index_chosen, column_chosen], column_chosen, index_chosen
def define_squares_model2(a_fe, nr, a_fj, cell_nr, what, step, cell=[], a_fr=1, adapt='adaptoffsetallall2',
                          variant='no',
                          self='', symetric='', resize=True, minimum=0.5, maximum=1.5,
                          dist_type='SimpleDist', redo=False, beat_type='', version_sinz='sinz', varied='emitter',
                          full_name='', emb=False):
    """Build matrices comparing the combined two-stimulus condition with its controls.

    If the model pickle ``name`` (assembled from the parameters, or taken from
    ``full_name``) exists, the score ``what`` is loaded for the combined
    condition (``get_condition``), for the two single-contrast controls
    (``get_control(..., 'afj', ...)`` and ``get_control(..., 'afe', ...)``)
    and for the baseline (``load_baseline_matrix``).

    * Without ``'spike_times'`` in ``what``, ``versions`` holds the raw
      matrices (``'base'``, ``'control1'`` = afe control, ``'control2'`` = afj
      control, ``'12'`` = condition), the interaction ``'diff'`` =
      condition - control_afe - control_afj + base, ``'0-1'`` =
      condition - control_afe, ``'0-2'`` = condition - control_afj, and
      ``'eod'``.
    * With ``'spike_times'`` in ``what``, the same three differences are
      computed per grid point with ``np.nanmean`` from spike trains cut to
      ``(0.5, 1)`` (presumably seconds -- confirm) and converted by
      ``create_spikes_mat``. The variant (``'05'``, ``'2'`` or
      ``'original'``) is chosen from substrings of ``what``. Results are
      cached per dataset in ``diffsquare_...pkl``; when taken from that cache,
      ``versions`` has no ``'eod'`` key.

    If the pickle does not exist, returns ``([], '', '')``.

    Returns
    -------
    versions : dict or list
    cell
        Dataset identifier as returned by ``get_condition`` (or ``''``).
    eod_m
        Last return value of ``get_control`` (or ``''``).

    NOTE(review): ``duration`` is neither a parameter nor assigned here, so
    the two ``get_control`` calls raise NameError unless a module-level
    ``duration`` exists (e.g. via the star import) -- confirm.
    NOTE(review): the ``__afj`` part of ``name_diff`` uses ``a_fr``, not
    ``a_fj``; changing it would alter the cache file name -- confirm intent.
    NOTE(review): the ``'ConspDist'`` branch computes ``consp`` and then calls
    ``embed()``; it never fills the ``diff*`` frames.
    """
    if full_name == '':
        name = load_folder_name('calc_model') + '/modell_all_cell_' + variant + '_' + version_sinz + str(
            nr) + self + '_afe' + str(a_fe) + '__afr' + str(a_fr) + '__afj' + str(
            str(a_fj)) + '__length1.5_' + adapt + '___stepefish' + str(
            step) + 'Hz_ratecorrrisidual35__modelbigfit_nfft4096' + beat_type + '.pkl'
    else:
        name = load_folder_name('calc_model') + '/' + full_name + '.pkl'
    if os.path.exists(name):
        ############################
        # Simple global scores (std, amp, ...) without temporal information
        what_orig = what
        if 'spike_times' in what:
            what = 'spike_times'
        control, condition, cell, eod_f, fr_rate_mult = get_condition(a_fe, nr, a_fj, cell_nr, what, a_fr=a_fr,
                                                                       variant=variant,
                                                                       adapt=adapt, full_name=full_name,
                                                                       version_sinz=version_sinz, resize=resize,
                                                                       symetric=symetric, minimum=minimum,
                                                                       maximum=maximum,
                                                                       beat_type=beat_type, self=self, step=step,
                                                                       cell=cell)
        base, base_matrix, baseline = load_baseline_matrix(what, cell, condition, a_fr=a_fr)
        control_afj, DF_e, dict_here, eod_m = get_control(nr, cell_nr, what, 'afj', a_fr=a_fr, adapt=adapt,
                                                          varied=varied
                                                          , symetric=symetric, duration=duration, contrast1=a_fe,
                                                          beat_type=beat_type, contrast2='0', version_sinz=version_sinz,
                                                          step=step, cell=cell, variant=variant, minimum=minimum,
                                                          maximum=maximum, self=self)
        control_afe, DF_e, dict_here, eod_m = get_control(nr, cell_nr, what, 'afe', a_fr=a_fr, adapt=adapt,
                                                          varied=varied,
                                                          contrast1='0', duration=duration, symetric=symetric,
                                                          beat_type=beat_type, contrast2=a_fj,
                                                          version_sinz=version_sinz,
                                                          step=step, cell=cell, variant=variant, minimum=minimum,
                                                          maximum=maximum, self=self)  # not found
        if 'spike_times' in what_orig:
            #############################
            # temporal information
            if maximum != []:
                max_name = '_min' + str(minimum) + '_min' + str(maximum)
            else:
                max_name = ''
            name_diff = load_folder_name('calc_model') + '/diffsquare_' + variant + '_' + version_sinz + str(
                nr) + self + '_afe' + str(a_fe) + '__afr' + str(a_fr) + '__afj' + str(
                str(a_fr)) + '__length1.5_' + adapt + '___stepefish' + str(
                step) + 'Hz_ratecorrrisidual35__modelbigfit_nfft4096' + '_' + dist_type + beat_type + max_name
            # NOTE(review): the bare string below is an unused example of a resulting
            # file name (an expression statement with no effect).
            'diffsquare_no_sinz3_afe0.1__afr1__afj1__length1.5_adaptoffsetallall2___stepefish10Hz_ratecorrrisidual35__modelbigfit_nfft4096_SimpleDist_beat_min0.5_min1'
            # Reuse cached differences for this cell unless redo is requested.
            if (os.path.exists(name_diff + '.pkl')) and (redo == False):
                diff_loaded = pd.read_pickle(name_diff + '.pkl')
                if cell in np.unique(diff_loaded['dataset']):
                    diff_load = diff_loaded[diff_loaded['dataset'] == cell]
                    diff_load.pop('dataset')
                    if '05' in what_orig:
                        dev = '05'
                    elif '2' in what_orig:
                        dev = '2'
                    else:
                        dev = 'original'
                    diff_load = diff_load[diff_load['dev'] == dev]
                    diff_load.pop('dev')
                    versions = {}
                    sorted = retrieve_mat(diff_load, '0-1-2')
                    versions['diff'] = sorted
                    sorted = retrieve_mat(diff_load, '0-1')
                    versions['0-1'] = sorted
                    sorted = retrieve_mat(diff_load, '0-2')
                    versions['0-2'] = sorted
                    diff_load.pop('dist')
                    cont = False
                    print('cell already there')
                else:
                    cont = True
            else:
                diff_loaded = pd.DataFrame()
                cont = True
            if cont:
                print('load diff ' + cell)
                versions = {}
                # get parameters
                control, nfft, cell, eod_f, fr_rate_mult = get_condition(a_fe, nr, a_fj, cell_nr, 'nfft', a_fr=a_fr,
                                                                         beat_type=beat_type, variant=variant,
                                                                         adapt=adapt,
                                                                         version_sinz=version_sinz, self=self,
                                                                         step=step, cell=cell)
                control, sampling_rate, cell, eod_f, fr_rate_mult = get_condition(a_fe, nr, a_fj, cell_nr,
                                                                                  'sampling_rate',
                                                                                  beat_type=beat_type, a_fr=a_fr,
                                                                                  variant=variant, adapt=adapt,
                                                                                  version_sinz=version_sinz, self=self,
                                                                                  step=step, cell=cell)
                control, length, cell, eod_f, fr_rate_mult = get_condition(a_fe, nr, a_fj, cell_nr, 'length', a_fr=a_fr,
                                                                           beat_type=beat_type, variant=variant,
                                                                           adapt=adapt,
                                                                           version_sinz=version_sinz, self=self,
                                                                           step=step, cell=cell)
                control, cut_spikes, cell, eod_f, fr_rate_mult = get_condition(a_fe, nr, a_fj, cell_nr, 'cut_spikes',
                                                                               beat_type=beat_type, a_fr=a_fr,
                                                                               variant=variant,
                                                                               adapt=adapt,
                                                                               version_sinz=version_sinz, self=self,
                                                                               step=step, cell=cell)
                # One frame per (difference type, variant): '' = original, '05', '2';
                # suffix _1 = condition - control_afe, _2 = condition - control_afj.
                diff = pd.DataFrame()
                diff05 = pd.DataFrame()
                diff2 = pd.DataFrame()
                diff_1 = pd.DataFrame()
                diff05_1 = pd.DataFrame()
                diff2_1 = pd.DataFrame()
                diff_2 = pd.DataFrame()
                diff05_2 = pd.DataFrame()
                diff2_2 = pd.DataFrame()
                # Analysis window for the spike trains.
                min_array = 0.5
                max_array = 1
                for i in range(len(condition)):
                    for j in range(len(condition.iloc[0])):
                        print(j)
                        print(i)
                        try:
                            arrays = [condition.iloc[i, j][0], control_afe.iloc[i, j][0], control_afj.iloc[i, j][0],
                                      base_matrix.iloc[i, j][0]]
                        except:
                            embed()
                        sampling_rate_here = sampling_rate.iloc[i, j]
                        spikes_mat = {}
                        mat05 = {}
                        mat2 = {}
                        names = ['condition', 'control_afe', 'control_afj', 'base_matrix']
                        for a, array in enumerate(arrays):
                            spikes_cut, spikes_mat[names[a]], mat05[names[a]], mat2[names[a]] = create_spikes_mat(
                                max_array - min_array,
                                array[(array < max_array) & (array > min_array)] - min_array,
                                sampling_rate_here)  # length_here- cut_spikes_here*2
                        if 'SimpleDist' in dist_type:
                            diff05.loc[condition.index[i], condition.columns[j]] = np.nanmean(
                                mat05['condition'] - mat05['control_afe'] - mat05['control_afj'] + mat05['base_matrix'])
                            diff2.loc[condition.index[i], condition.columns[j]] = np.nanmean(
                                mat2['condition'] - mat2['control_afe'] - mat2['control_afj'] + mat2['base_matrix'])
                            diff.loc[condition.index[i], condition.columns[j]] = np.nanmean(
                                spikes_mat['condition'] - spikes_mat['control_afe'] - spikes_mat['control_afj'] +
                                spikes_mat['base_matrix'])
                            diff05_1.loc[condition.index[i], condition.columns[j]] = np.nanmean(
                                mat05['condition'] - mat05['control_afe'])
                            diff2_1.loc[condition.index[i], condition.columns[j]] = np.nanmean(
                                mat2['condition'] - mat2['control_afe'])
                            diff_1.loc[condition.index[i], condition.columns[j]] = np.nanmean(
                                spikes_mat['condition'] - spikes_mat['control_afe'])
                            diff05_2.loc[condition.index[i], condition.columns[j]] = np.nanmean(
                                mat05['condition'] - mat05['control_afj'])
                            diff2_2.loc[condition.index[i], condition.columns[j]] = np.nanmean(
                                mat2['condition'] - mat2['control_afj'])
                            diff_2.loc[condition.index[i], condition.columns[j]] = np.nanmean(
                                spikes_mat['condition'] - spikes_mat['control_afj'])
                            # TODO: compute a few more differences here
                        elif 'ConspDist' in dist_type:
                            # Sliding 30-ms windows shifted in 5-ms steps (assuming sampling_rate_here is in Hz).
                            length_consp = 0.030 * sampling_rate_here
                            shift = 0.005
                            shift_conditions = np.arange(0, len(mat2['control_afe']), shift * sampling_rate_here)
                            shift_controls = np.arange(0, len(mat2['control_afe']), shift * sampling_rate_here)
                            consp = pd.DataFrame()
                            for s, shift_condition in enumerate(shift_conditions):
                                for ss, shift_control in enumerate(shift_controls):
                                    if (int(length_consp + shift_control) < len(mat2['control_afe'])) & (
                                            int(length_consp + shift_condition) < len(mat2['condition'])):
                                        consp.loc[s, ss] = np.sqrt(np.mean((mat2['control_afe'][
                                                                            0 + int(shift_control):int(
                                                                                length_consp + shift_control)]
                                                                            - mat2['condition'][
                                                                              0 + int(shift_condition):int(
                                                                                  length_consp + shift_condition)]) ** 2))
                            embed()
                diff['dataset'] = cell
                diff05['dataset'] = cell
                diff2['dataset'] = cell
                diff_1['dataset'] = cell
                diff05_1['dataset'] = cell
                diff2_1['dataset'] = cell
                diff_2['dataset'] = cell
                diff05_2['dataset'] = cell
                diff2_2['dataset'] = cell
                diff['dist'] = '0-1-2'
                diff05['dist'] = '0-1-2'
                diff2['dist'] = '0-1-2'
                diff_1['dist'] = '0-1'
                diff05_1['dist'] = '0-1'
                diff2_1['dist'] = '0-1'
                diff_2['dist'] = '0-2'
                diff05_2['dist'] = '0-2'
                diff2_2['dist'] = '0-2'
                diff['dev'] = 'original'
                diff05['dev'] = '05'
                diff2['dev'] = '2'
                diff_1['dev'] = 'original'
                diff05_1['dev'] = '05'
                diff2_1['dev'] = '2'
                diff_2['dev'] = 'original'
                diff05_2['dev'] = '05'
                diff2_2['dev'] = '2'
                # Append this cell's results to the cache (or create it).
                if len(diff_loaded) < 1:
                    vertical_stack = pd.concat(
                        [diff, diff05, diff2, diff_1, diff05_1, diff2_1, diff_2, diff05_2, diff2_2, ], axis=0)
                    vertical_stack.to_pickle(name_diff + '.pkl')
                else:
                    vertical_stack = pd.concat(
                        [diff_loaded, diff, diff05, diff2, diff_1, diff05_1, diff2_1, diff_2, diff05_2, diff2_2, ],
                        axis=0)
                    vertical_stack.to_pickle(name_diff + '.pkl')
                if '05' in what_orig:
                    dev = '05'
                elif '2' in what_orig:
                    dev = '2'
                else:
                    dev = 'original'
                dev_here = vertical_stack[vertical_stack['dev'] == dev]
                diff_output = dev_here[dev_here['dist'] == '0-1-2']
                diff_output.pop('dist')
                diff_output.pop('dev')
                diff_output.pop('dataset')
                versions['diff'] = diff_output
                diff_output = dev_here[dev_here['dist'] == '0-1']
                diff_output.pop('dist')
                diff_output.pop('dev')
                diff_output.pop('dataset')
                versions['0-1'] = diff_output
                diff_output = dev_here[dev_here['dist'] == '0-2']
                diff_output.pop('dist')
                diff_output.pop('dev')
                diff_output.pop('dataset')
                versions['0-2'] = diff_output
                versions['eod'] = eod_m
            else:
                print('load diff ' + cell)
        else:
            # Global scores: plain matrix arithmetic on the loaded tables.
            diff_output = condition - control_afe - control_afj + base_matrix
            versions = {}
            versions['base'] = base_matrix
            versions['control1'] = control_afe
            versions['control2'] = control_afj
            versions['12'] = condition
            versions['diff'] = diff_output
            versions['0-1'] = condition - control_afe
            versions['0-2'] = condition - control_afj
            versions['eod'] = eod_m
    else:
        versions = []
        cell = ''
        eod_m = ''
    if emb:
        embed()
    return versions, cell, eod_m
def get_condition(contrast1, nr, contrast2, cell_nr, what, step=60, a_fr=1, adapt='adaptoffsetallall2',
                  variant='no',
                  version_sinz='sinz', full_name='', resize=True, symetric='', SAM='SAM', square=[], three='',
                  length='1.5',
                  duration='', folder='model', minimum=[], maximum=[], f0='f0', f2='f2', f1='f1', self='', beat_type='',
                  cell=[], emb=False, end='.pkl'):  # f0 = 'eodf' f2 = 'eodj', f1 = 'eode'
    """Load a model result table and pivot the score ``what`` into an (f2 x f1) matrix.

    Two input formats are handled, selected by ``end``:

    * ``'csv'`` in ``end`` (simplified three-wave protocol): reads
      ``'../data/' + folder + '/' + full_name + end``, keeps rows with
      ``cell == cell``, rounds ``df1``/``df2`` to 2 decimals and pivots
      ``what`` with ``df2`` as rows and ``df1`` as columns.
    * otherwise a pickle from ``load_folder_name('calc_model')``, either
      ``full_name`` or a file name assembled from the remaining parameters.
      Rows are optionally restricted to ``square``, then to one dataset
      (``cell``, or the ``cell_nr``-th unique dataset if ``cell`` is empty).
      ``what`` is pivoted over ``f2`` (rows) and ``f1`` (columns); the labels
      are replaced by the relative frequencies ``(f - f0) / f0 + 1``
      (rounded to 3 decimals) and, if ``maximum`` is given, cut to labels
      strictly inside ``(minimum, maximum)``.

    Returns
    -------
    control : pandas.DataFrame
        The full loaded table.
    condition : pandas.DataFrame or list
        Pivoted score matrix, or ``[]`` if the cell/score was not found.
    cell
        The dataset used, or ``[]`` if none was found.
    eod_f : pandas.Series or list
        ``f0`` column (csv) or ``f1`` column (pickle) of the selected rows,
        or ``[]``.
    fr_rate_mult
        Firing rate (divided by the mean ``f0`` when ``resize`` is True), or
        ``[]`` if unavailable.
    """
    if 'csv' in end:
        # CSV tables come from the simplified three-wave protocol.
        control = pd.read_csv(
            '../data/' + folder + '/' + full_name + end, index_col=0)
        # Explicit copy so the column rounding below never writes into a view of `control`.
        cell_array = control[control.cell == cell].copy()
        cell_array['df2'] = np.round(cell_array.df2, 2)
        cell_array['df1'] = np.round(cell_array.df1, 2)
        condition = cell_array.pivot(index='df2', columns='df1',
                                     values=what)
        eod_f = cell_array['f0']
        # Defined up front so the return is valid when resize is False.
        fr_rate_mult = []
        if resize == True:
            fr = cell_array.fr
            fr_rate_mult = fr / np.mean(cell_array['f0'])
    else:
        # Pickles hold the non-simplified protocol used for most results.
        if full_name == '':
            control = pd.read_pickle(
                load_folder_name('calc_model') + '/modell_all_cell_' + variant + '_' + version_sinz + str(
                    nr) + self + '_afe' + str(contrast1) + '__afr' + str(a_fr) + '__afj' + str(
                    contrast2) + '__length' + str(length) + '_' + adapt + '_' + SAM + '__stepefish' + str(
                    step) + 'Hz_ratecorrrisidual35__modelbigfit_nfft4096' + duration + beat_type + symetric + three + end)
        else:
            control = pd.read_pickle(
                load_folder_name('calc_model') + '/' + full_name + end)
        if square != []:
            sqaure_array = control[control['square'] == square]
        else:
            sqaure_array = control
        if not cell:
            # No cell given: take the cell_nr-th dataset if it exists.
            if cell_nr < len(np.unique(sqaure_array['dataset'])):
                cell = np.unique(sqaure_array['dataset'])[cell_nr]
            else:
                cell = []
        if cell != []:
            cell_array = sqaure_array[sqaure_array['dataset'] == cell]
            if square != []:
                base = control[control['square'] == 'base_0']
                base_cell = base[base['dataset'] == cell]
                fr_rate = np.unique(base_cell.mean_fr)
                fr_rate_mult = fr_rate
            else:
                # The "after" baseline from the second trial on corresponds to
                # the original baseline.
                fr_rate_mult = np.mean(cell_array.rate_baseline_after.iloc[
                                       1::])
            print(fr_rate_mult)
            if (what in cell_array.keys()) and not cell_array.empty:
                condition = cell_array.pivot(index=f2, columns=f1,
                                             values=what)
                if square == 'base_0':
                    # Spread the single baseline value over the '012' grid.
                    sqaure_array_012 = control[control['square'] == '012']
                    cell_array = sqaure_array_012[sqaure_array_012['dataset'] == cell]
                    condition_012 = cell_array.pivot(index=f2, columns=f1,
                                                     values=what)
                    condition_012[:] = condition.iloc[0, 0]
                    condition = condition_012
                # Relative frequencies (f - f0) / f0 + 1 used as new axis labels.
                DF_1 = np.unique(
                    np.array((cell_array[f1] - cell_array[f0]) / cell_array[
                        f0] + 1))
                DF_2 = np.round(np.unique(np.array(
                    (cell_array[f2] - cell_array[f0]) / cell_array[
                        f0] + 1)), 3)
                eod_f = cell_array[f1]
                if resize == True:
                    fr_rate_mult = fr_rate_mult / np.mean(cell_array[f0])
                dict_here = dict(zip(np.unique(cell_array[f1]), np.round(DF_1, 3)))
                condition = condition.rename(columns=dict_here)
                condition = condition.set_index(DF_2)
                condition.columns.name = f1 + str('-f0')
                condition.index.name = f2 + str('-f0')
                if maximum != []:
                    condition, column_chosen, index_chosen = cut_matrix_generation(condition, minimum, maximum)
            else:
                condition = []
                cell = []
                eod_f = []
                fr_rate_mult = []
        else:
            condition = []
            cell = []
            eod_f = []
            fr_rate_mult = []
    if emb:
        embed()
    return control, condition, cell, eod_f, fr_rate_mult
def define_squares_model_three(emb=False, a_fe=0.1, nr=3, a_fj=0.1, cell_nr=0, what='std', step=50, cell_data=[],
                               a_fr=1,
                               adapt='adaptoffsetallall2', variant='no', square=[], full_name='', self='',
                               length=0.5, SAM='SAM', resize=True, cell=[], duration='', symmetric='',
                               three='ThreeDiff', minimum=[], maximum=[], beat_type='', folder='model', end='.pkl',
                               version_sinz='sinz'):
    """Load the three-fish model table and pivot the score ``what`` for one cell.

    The file is ``load_folder_name('calc_model') + '/' + full_name + end`` or,
    if ``full_name`` is empty, a name assembled from the remaining parameters.
    The path is printed. If the file exists, ``get_condition`` is called with
    ``cell_data`` as the cell; a complex-valued result matrix is replaced by
    its magnitude.

    Returns
    -------
    tuple
        ``(control, condition, cell, eod_f, fr_rate_mult)`` from
        ``get_condition``; if the file is missing, ``[]`` for all but
        ``cell``, which is returned unchanged.
    """
    model_dir = load_folder_name('calc_model')
    if full_name != '':
        name = model_dir + '/' + full_name + end
    else:
        name = ''.join([
            model_dir, '/modell_all_cell_', variant, '_', version_sinz, str(nr), self,
            '_afe', str(a_fe), '__afr', str(a_fr), '__afj', str(a_fj),
            '__length', str(length), '_', adapt, '_', SAM, '__stepefish', str(step),
            'Hz_ratecorrrisidual35__modelbigfit_nfft4096', duration, symmetric, three, end,
        ])
    control, condition, eod_f, fr_rate_mult = [], [], [], []
    print(name)
    if os.path.exists(name):
        control, condition, cell, eod_f, fr_rate_mult = get_condition(
            a_fe, nr, a_fj, cell_nr, what, a_fr=a_fr, variant=variant, adapt=adapt,
            full_name=full_name, version_sinz=version_sinz, SAM=SAM, symetric=symmetric,
            folder=folder, resize=resize, end=end, square=square, duration=duration,
            length=length, three=three, minimum=minimum, maximum=maximum,
            beat_type=beat_type, self=self, step=step, cell=cell_data, emb=False)
        # Complex-valued scores are reduced to their magnitude.
        if len(condition) > 0 and condition.iloc[0].dtype == complex:
            condition = np.abs(condition)
    if emb:
        embed()
    return control, condition, cell, eod_f, fr_rate_mult
def plt_square(control, pcolor, mult_type, vers, lim):
    """Draw the matrix ``vers`` with a red-blue colour map and return the artist.

    With ``pcolor`` True, ``pcolormesh`` is used. The column/index labels are
    used as-is when ``'mult'`` is in ``mult_type``; otherwise they are mapped
    to ``(label - 1) * control.f0.iloc[0]`` (presumably converting relative
    frequencies back to frequency differences -- confirm). With ``pcolor``
    False, ``imshow`` is used with an extent spanning the label ranges.
    Colour limits come from ``lim``.
    """
    if not pcolor:
        try:
            extent = [np.min(vers.columns), np.max(vers.columns),
                      np.min(vers.index), np.max(vers.index)]
            axs = plt.imshow(vers, origin='lower', cmap="RdBu_r", vmin=lim[0],
                             vmax=lim[1], extent=extent)
        except BaseException:
            print('axs problenm')
            embed()
        return axs

    x_values = np.array(list(map(float, vers.columns)))
    y_values = np.array(vers.index)
    if 'mult' not in mult_type:
        eod = control.f0.iloc[0]
        x_values = (x_values - 1) * eod
        y_values = (y_values - 1) * eod
    return plt.pcolormesh(x_values, y_values, vers, vmin=lim[0], vmax=lim[1],
                          cmap="RdBu_r")
def square_labels(mult_type, ax, vers_here, w):
    """Set axis labels for a beat-square plot.

    If ``mult_type`` contains ``'mult'``, only the x label is set (``'m2'`` or
    ``'m1'``). Otherwise the x label is ``'Beat2 [Hz]'`` or ``'Beat1 [Hz]'``,
    and the y label is set to the other beat only when ``w == 0``. The
    ``'2'``-variant is chosen when one of ``vers_here.index.names`` equals
    ``'2'`` exactly (list membership, not a substring test).
    """
    is_mult = 'mult' in mult_type
    index_is_two = '2' in vers_here.index.names
    if is_mult:
        ax.set_xlabel('m2' if index_is_two else 'm1')
        return
    if index_is_two:
        x_label, y_label = 'Beat2 [Hz]', 'Beat1 [Hz]'
    else:
        x_label, y_label = 'Beat1 [Hz]', 'Beat2 [Hz]'
    ax.set_xlabel(x_label)
    if w == 0:
        ax.set_ylabel(y_label)
def figsize_ROC_start():
    """Return the ``[width, height]`` figure size for the ROC start figure.

    The width is the two-column width from ``column2()``; the height is 2.8.
    """
    width = column2()
    height = 2.8
    return [width, height]
def plt_several_ROC_square_nonlin_single(shrink=0.5, top=0.9, loc=(0.4, 0.8), fs=14, defaultst=True,
                                         figsize=(13.5, 6.5), ):
    """Build the ROC overview figure: AUC areas, nonlinearity lines and a square.

    For each chosen cell and each of the two frame-name groups obtained from
    ``core_decline_ROC`` (default and ``pos=False``), one column is drawn:
    the AUC area panel (``plt_gain_area``) on top and the nonlinearity line
    (``plt_nonlin_line``) below. The right column holds the square matrix
    from ``square_part``, a black frame drawn at 50 and 236, and numbered
    markers (``plt_circles_matrix``). The figure is tagged, saved with
    ``save_visualization`` and shown.

    Args:
        shrink: forwarded to ``square_part``.
        top: top margin of the outer GridSpec.
        loc: legend location forwarded to ``plt_gain_area``.
        fs: legend font size forwarded to ``plt_gain_area``.
        defaultst: if True, ``default_settings`` is applied first.
        figsize: figure size; when falsy no figure is created here.
    """
    xlim = core_xlim_dist_roc()
    if defaultst:
        default_settings(width=12, ts=20, ls=20, fs=20)
    # NOTE(review): `fig` is only bound here; with a falsy `figsize` the final
    # `fig.tag(...)` call raises NameError.
    if figsize:
        fig = plt.figure(figsize=figsize)
    colors_w, colors_wo, color_base, color_01, color_02, color_012 = colors_cocktailparty_all()
    # `frame_names` is overwritten by the loop variable below; only `trial_nr` is used next.
    frame_names, trial_nr = core_decline_ROC()
    # NOTE(review): expression statement (a 1-tuple of a string); its value is
    # discarded, only the string concatenation is evaluated.
    'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_' + trial_nr + '_mult_minimum_1temporal',
    short = True
    # two groups, each holding one frame name: default and pos=False
    frame_names_both = [[
        core_decline_ROC(trial_nr='20', short=short, absolut=True)[0][0]],
        [core_decline_ROC(trial_nr='20', pos=False, short=short, absolut=True)[0][0]]]
    print(frame_names_both)
    cells = [
        "2013-01-08-aa-invivo-1"]  # , "2012-12-13-an-invivo-1", "2012-06-27-an-invivo-1", "2012-12-21-ai-invivo-1","2012-06-27-ah-invivo-1", ]
    cells_chosen = [
        '2013-01-08-aa-invivo-1']  # , "2012-06-27-ah-invivo-1","2014-06-06-ac-invivo-1" ]#'2012-06-27-an-invivo-1',
    # columns 0/1: AUC + nonlinearity per frame group, column 2: spacer, column 3: square
    grid = gridspec.GridSpec(2, 4, wspace=0.45, width_ratios=[0.27, 0.27, 0, 0.7], hspace=0.5, left=0.1, top=top,
                             bottom=0.17,
                             right=0.9, )  # , width_ratios = [1,1,1,0.5,1] height_ratios = [1,6]bottom=0.25, top=0.8,
    # df1/df2 are filled by plt_gain_area and used for the markers in the square
    df1 = []
    df2 = []
    axes = []
    for c, cell in enumerate(cells_chosen):
        for ff, frame_names in enumerate(frame_names_both):
            grid0 = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.15, hspace=0.25,
                                                     subplot_spec=grid[:, ff])  # height_ratios=[1, 0.7, 1, 1],
            ax1 = plt.subplot(grid0[0])
            frame_cell = plt_gain_area(ax1, c, cell, cells, colors_w, colors_wo, df1, df2, ff,
                                       frame_names, fs, loc, xlim)
            axes.append(ax1)
            if ff != 0:
                ax1.set_ylabel('')
            else:
                ax1.set_ylabel('AUC')
            ax1.set_xlabel('')
            ###############################
            # roc nonlin part
            ax2 = plt.subplot(grid0[1])
            plt_nonlin_line(ax2, cell, ff, frame_cell, xlim)
    # squares
    ax = plt.subplot(grid[:, 3])
    axes.append(ax)
    full_name = 'modell_all_cell_no_sinz1_afe1_0.03__afr0_1__afj2_0.1__phaseright__len5_adaptoffset_bisecting_0.995_1.005____ratecorrrisidual35__modelbigfit_nfft4096_StartE1_1_EndE1_1.3_in0.005_StartJ2_1_EndJ2_1.3_in0.005_trialnr20__reshuffled_ThreeDiff_SameOffset'
    square_part(ax, shrink=shrink, full_name=full_name)
    # black frame lines; the coordinates are hard-coded
    max_val = 236
    ax.plot([-2, 0], [50, max_val], color='black', linewidth=1)
    ax.plot([max_val, max_val], [50, max_val], color='black', linewidth=1)
    ax.plot([-2, max_val], [50, 50], color='black', linewidth=1)
    ax.plot([-2, max_val], [max_val, max_val], color='black', linewidth=1)
    third_diagonal = True
    if third_diagonal:
        # third marker: same df1 as the first, df2 shifted by twice |fr - df2[0]|.
        # `frame_cell` is the value left over from the last loop iteration.
        df1.append(df1[0])
        df2.append(df2[0] + np.abs(frame_cell.fr.iloc[0] - df2[0]) * 2)
    plt_circles_matrix(ax, df1, df2)
    plt.suptitle('')
    fig.tag(axes, xoffs=-3.6, yoffs=2.7, )
    save_visualization(jpg=True, png=False)
    plt.show()
def plt_circles_matrix(ax, df1, df2, scat=True):
    """Annotate up to three (df1, df2) points with the circled numbers 1-3.

    Every label is placed at x + 5. Point 0 gets its label at y - 15 and
    points 1 and 2 at y + 5. When ``scat`` is True, points 0 and 1 also get an
    unfilled black square marker; point 2 never gets one. Indices above 2 are
    ignored.
    """
    titles = [r'$\numcircled{1}$', r'$\numcircled{2}$', r'$\numcircled{3}$']
    # index -> (y offset of the label, whether a square marker may be drawn)
    layout = {0: (-15, True), 1: (5, True), 2: (5, False)}
    for idx in range(len(df1)):
        if idx not in layout:
            continue
        y_shift, with_marker = layout[idx]
        ax.text(df1[idx] + 5, df2[idx] + y_shift, titles[idx])
        if with_marker and scat:
            ax.scatter(df1[idx], df2[idx], facecolors='none', edgecolor='black', marker='s')
def colors_cocktailparty_all():
    """Return the cocktail-party colors together with two one-element lists.

    Returns:
        ``(colors_w, colors_wo, color_base, color_01, color_02, color_012)`` where
        ``colors_w == [color_012]`` and ``colors_wo == [color_01]``.
    """
    palette = colors_cocktailparty()
    color_base, color_01, color_02, color_012 = palette
    return [color_012], [color_01], color_base, color_01, color_02, color_012
def plt_gain_area(ax1, c, cell, cells, colors_w, colors_wo, df1, df2, ff, frame_names, fs, loc, xlim):
    """Draw the AUC area plot of ``cell`` for every existing ROC csv in ``frame_names``.

    For each frame name whose csv exists in the 'calc_ROC' folder:
    - The rows of ``cell`` are loaded and truncated to the 'LenNrs_<n>' value
      encoded in the name.
    - The mean unique df1/df2 values are appended to the caller-owned lists
      ``df1``/``df2`` (mutated in place).
    - ``plt_area_between`` draws the areas on ``ax1`` when rows exist.
    - ``ax1`` is then styled: title chosen by ``ff``, legend for ``ff == 1``,
      y-range 0-0.52.

    Returns:
        The cell's frame from the last processed csv.
    """
    # NOTE(review): if none of the csv files exists, `frame_cell` is never bound
    # and the final return raises UnboundLocalError.
    for f, frame_name in enumerate(frame_names):
        path = load_folder_name('calc_ROC') + '/' + frame_name + '.csv'
        if os.path.exists(path):
            frame = pd.read_csv(path)
            # NOTE(review): the reference frame is read but never used below.
            path_ref = load_folder_name(
                'calc_ROC') + '/' + 'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_C1_0.02_len_25_nfft_32768_trialsnr_20_mult_minimum_1temporal.csv'
            frame_ref = pd.read_csv(path_ref)
            _, _ = find_row_col(cells, row=4)
            frame_cell = frame[frame.cell == cell]
            # number of rows to keep is encoded in the file name as 'LenNrs_<n>_'
            nrs = int(frame_name.split('LenNrs_')[1].split('_')[0])
            frame_cell = frame_cell.iloc[0:nrs]
            df1.append(np.mean(frame_cell.df1.unique()))
            df2.append(np.mean(frame_cell.df2.unique()))
            axs = [ax1]
            label_f = ['with female', 'CLS: 100n with female', 'LS: 1000n with female', ]
            label_f2 = ['without female', 'CLS: 100n without female', 'LS: 1000n without female', ]
            labels = [label_f[f], label_f[f]]
            labels2 = [label_f2[f], label_f2[f]]
            for a, ax in enumerate(axs):
                if len(frame_cell) > 0:
                    plt_area_between(frame_cell, ax, ax, colors_w, colors_wo, f, labels_with_female=labels[a],
                                     labels_without_female=labels2[a])
                ax.set_xlim(xlim)
                # NOTE(review): these non-raw strings contain '\D' and '\m'; Python keeps the
                # backslashes but emits invalid-escape warnings.
                titles = [
                    ' $ \Delta \mathrm{f_{Female}} + \Delta \mathrm{f_{Intruder}} $\n' + '$ = \mathrm{f'+basename()+'}$' + r'$\numcircled{1}$ ',
                    ' $ \Delta \mathrm{f_{Female}} + \Delta \mathrm{f_{Intruder}} $\n ' + r' $\neq \mathrm{f'+basename()+'}$' + r'$\numcircled{2}$ ']
                ax.set_title(titles[ff])
                # legend only in the second column
                if ff == 1:
                    if a == 0:
                        try:
                            ax.legend(loc=loc, fontsize=fs, ncol=1)
                        except:
                            print('legend something')
                            embed()
                    else:
                        ax.legend(loc=loc, fontsize=fs, ncol=1)
                ax.set_ylim(0, 0.52)
                # set_yticks_delta / show_spines are not standard matplotlib Axes methods;
                # presumably provided by the plotstyle setup -- confirm
                ax.set_yticks_delta(0.25)
                if c != 0:
                    remove_yticks(ax)
                ax.show_spines('lb')
                if a == 0:
                    ax.set_ylabel('AUC')
                remove_xticks(ax)
                if ff != 0:
                    remove_yticks(ax)
    return frame_cell
def plt_nonlin_line(ax2, cell, ff, frame_cell, xlim):
    """Plot the nonlinearity score of ``frame_cell`` against the recalculated distance.

    The x values come from ``c_dist_recalc_func`` applied to ``frame_cell.c1``.
    The y values are the ``val_new_core`` score with negative values set to 0.
    The y-range is fixed to 0-7 and the x-range to ``xlim``. Only the first
    column (``ff == 0``) gets a y-label; the others lose their y-ticks.

    Args:
        ax2: axes to draw on.
        cell: cell identifier forwarded to ``c_dist_recalc_func``.
        ff: column index of the panel.
        frame_cell: data frame holding 'c1' and the score column.
        xlim: x-range of the panel.
    """
    c1 = c_dist_recalc_func(frame_cell=frame_cell, c_nrs=frame_cell.c1, cell=cell, c_dist_recalc=True)
    # clip() returns a new Series; the former in-place masked assignment could
    # write the zeros back into the caller's frame_cell column.
    score = val_new_core(frame_cell).clip(lower=0)
    ax2.plot(c1, score, color='black', clip_on=True)
    ax2.set_ylim(0, 7)
    ax2.set_xlim(xlim)
    if ff != 0:
        remove_yticks(ax2)
    else:
        ax2.set_ylabel(peak_b1b2_name())
    ax2.set_xlabel(core_distance_label())
def val_new_core(frame_cell):
    """Return the 'amp_B1+B2_012-01-02+0_norm_01B1_mean' column of ``frame_cell``."""
    column = 'amp_B1+B2_012-01-02+0_norm_01B1_mean'
    return frame_cell[column]
def sum_score(frame_cell):
    """Return the score column used for the summed B1+B2 nonlinearity."""
    key = 'amp_B1+B2_012-01-02+0_norm_01B1_mean'
    selected = frame_cell[key]
    return selected
def find_variable_from_savename(full_name, name='trialnr'):
    """Extract the value that follows ``name`` in an underscore-separated save name.

    The value runs from the end of ``name`` up to the next '_' after it, or up
    to the end of the string if no '_' follows.

    Args:
        full_name: the save name, e.g. '..._trialnr20__reshuffled'.
        name: the key whose value is wanted.

    Returns:
        ``(value, line_pos, name_end, name_start)``:
        - ``value`` is the extracted string.
        - ``line_pos`` is the index of the terminating '_' (``len(full_name)`` if none).
        - ``name_end`` is the index just after ``name``.
        - ``name_start`` is the index where ``name`` starts.
        If ``name`` does not occur, ``('', -1, -1, -1)`` is returned.
    """
    name_start = full_name.find(name)
    if name_start == -1:
        return '', -1, -1, -1
    name_end = name_start + len(name)
    # search after the key, so keys that themselves contain '_' work too
    line_pos = full_name.find('_', name_end)
    if line_pos == -1:
        # value at the very end: slicing to -1 would drop its last character
        line_pos = len(full_name)
    return full_name[name_end:line_pos], line_pos, name_end, name_start
def freq_two_mult_recalc(frame_cell_orig, freqs):
    """Convert the first (mult1, mult2) pair in ``freqs`` with ``(m - 1) * f0``.

    ``f0`` is the first value of ``frame_cell_orig.f0``. Only ``freqs[0]`` is
    used; a one-element list holding the converted tuple is returned.
    """
    first_pair = freqs[0]
    f0 = frame_cell_orig.f0.iloc[0]
    converted = tuple((m - 1) * f0 for m in (first_pair[0], first_pair[1]))
    return [converted]
def figure_out_score_and_add(add, c_here, frame_cell):
    """Choose the score column names for ``c_here``, dropping ``add`` if needed.

    If the first name returned by ``score_choice(c_here, add)`` is not a column
    of ``frame_cell``, the suffix is discarded and the names are rebuilt with
    ``add=''``.

    Returns:
        ``(add, score1, score2, score3)``.
    """
    scores = score_choice(c_here, add)
    if scores[0] not in frame_cell.keys():
        # the suffixed column does not exist: fall back to the plain names
        add = ''
        scores = score_choice(c_here, add)
    score1, score2, score3 = scores
    return add, score1, score2, score3
def plt_matrix_saturation_loss(ax, frame_cell, c_here='c1', add='', ims=None, ims_diff=None, imshow=False,
                               xlabel=True):
    """Plot three df1/df2 matrices side by side: two scores and a third score column.

    ``frame_cell`` is averaged per (df1, df2) pair and pivoted with df2 as rows
    and df1 as columns, once for each of the three score columns chosen by
    ``figure_out_score_and_add``. ``ax[0]`` and ``ax[1]`` show the first two
    scores. ``ax[2]`` shows the third with a symmetric color range
    ``[-lim, lim]``. Every panel gets a colorbar labeled with its score name.

    Args:
        ax: indexable collection of at least three axes.
        frame_cell: data frame with 'df1', 'df2' and the score columns.
        c_here: contrast selector passed on to ``score_choice``.
        add: optional suffix of the score column names.
        ims: optional list collecting the artists of panels 0 and 1. A fresh
            list is used when omitted.
        ims_diff: optional list collecting the artist of panel 2. A fresh list
            is used when omitted.
        imshow: draw with ``imshow`` instead of ``pcolormesh``.
        xlabel: put the 'df1' x-label on every panel.

    Returns:
        ``(lim, matrix, ss, ims)`` where ``matrix`` and ``ss`` belong to the last panel.
    """
    # mutable default lists would be shared, and keep growing, across calls
    if ims is None:
        ims = []
    if ims_diff is None:
        ims_diff = []
    add, score1, score2, score3 = figure_out_score_and_add(add, c_here, frame_cell)
    cls = ["RdBu_r", "RdBu_r", "RdBu_r"]
    # the averaged frame does not depend on the score: compute it once
    new_frame = frame_cell.groupby(['df1', 'df2'], as_index=False).mean()
    for ss, score_here in enumerate([score1, score2, score3]):
        matrix = new_frame.pivot(index='df2', columns='df1', values=score_here)
        if ss == 2:
            lim = np.max([np.max(matrix), np.abs(np.min(matrix))])
        ax[ss].set_title(score_here)
        if imshow:
            # draw into the panel of the current score; only the third panel has
            # a symmetric range (previously every score went to ax[2] and
            # `lim` was read before it existed)
            clim = {'vmin': -lim, 'vmax': lim} if ss == 2 else {}
            im = ax[ss].imshow(matrix, origin='lower', cmap=cls[ss],
                               extent=[np.min(matrix.columns),
                                       np.max(matrix.columns),
                                       np.min(matrix.index),
                                       np.max(matrix.index)], **clim)
        else:
            try:
                im = ax[ss].pcolormesh(
                    np.array(list(map(float, matrix.columns))), np.array(matrix.index),
                    matrix,
                    cmap=cls[ss],
                    rasterized=False)
            except:
                print('ims probelem')
                embed()
        if ss == 2:
            im.set_clim(-lim, lim)
            ims_diff.append(im)
        else:
            ims.append(im)
        cbar = plt.colorbar(im, ax=ax[ss])
        cbar.set_label(score_here, labelpad=100)
        ax[0].set_ylabel('df2')
        if xlabel:
            ax[ss].set_xlabel('df1')
    return lim, matrix, ss, ims
def score_choice(c_here, add=''):
    """Return the three score column names for contrast ``c_here``.

    Only ``'c1'`` is defined. It yields ``'amp_B1_012_mean' + add``,
    ``'amp_B1_01_mean' + add`` and ``'diff'``.

    Raises:
        ValueError: for any other ``c_here``. Previously an unsupported value
            ended in an UnboundLocalError.
    """
    if c_here != 'c1':
        raise ValueError("score_choice: unsupported contrast " + repr(c_here) + " (only 'c1' is defined)")
    return 'amp_B1_012_mean' + add, 'amp_B1_01_mean' + add, 'diff'
def area_vs_single_peaks_frame(frame_cell, cs_type='all'):
    """Keep only the rows whose 'c1' and 'c2' both equal ``cs_type``.

    With the default ``'all'`` the frame is returned unchanged (same object).
    """
    if cs_type == 'all':
        return frame_cell
    both_match = (frame_cell.c1 == cs_type) & (frame_cell.c2 == cs_type)
    return frame_cell[both_match]
def plt_cross(matrix, ax, type='small'):
    """Draw the diagonals y = x and y = -x as black reference lines over a matrix plot.

    Args:
        matrix: pivoted data frame; its index (y) and columns (x) give the range.
        ax: axes to draw on.
        type: selects what is drawn.
            - 'big': both diagonals on a unit-spaced grid, with the limits set to
              that grid.
            - 'small': only the parts of both diagonals strictly inside the
              matrix range, with the limits set to that range.
            - 'zerozentered': segments from the origin to the index/column
              extremes, in three sign combinations.
            - any other value: nothing is drawn.
    """
    zero_val = np.min([np.min(matrix.index) + 1, np.min(matrix.columns) + 1])
    max_val = np.max([np.max(matrix.index) - 1, np.max(matrix.columns) - 1])
    x_axis = np.arange(zero_val, max_val, 1)
    y_axis = np.arange(zero_val, max_val, 1)
    if type == 'big':
        ax.plot(x_axis, y_axis, color='black', zorder=100)
        ax.plot(np.array(x_axis), -np.array(y_axis), color='black', zorder=100)
        ax.set_xlim(x_axis[0], x_axis[-1])
        ax.set_ylim(y_axis[0], y_axis[-1])
    elif type == 'small':
        for y_axis_here in [y_axis, -y_axis]:
            restrict = ((x_axis > np.min(matrix.columns)) & (x_axis < np.max(matrix.columns))) & (
                    (y_axis_here > np.min(matrix.index)) & (y_axis_here < np.max(matrix.index)))
            ax.plot(x_axis[restrict], y_axis_here[restrict], color='black', zorder=100)
        ax.set_xlim(np.min(matrix.columns), np.max(matrix.columns))
        ax.set_ylim(np.min(matrix.index), np.max(matrix.index))
    elif type == 'zerozentered':
        # the extremes in the same order as the original explicit calls
        ends = [np.max(matrix.index), np.min(matrix.index), np.max(matrix.columns), np.min(matrix.columns)]
        for x_sign, y_sign in ((1, 1), (-1, 1), (1, -1)):
            for end in ends:
                ax.plot([0, x_sign * end], [0, y_sign * end], color='black', zorder=100)
    # The former unconditional `embed()` here opened an interactive IPython
    # shell on every call and has been removed.
def plt_square_row(orientation, matrix_extend, dev, name, fig, title_pos, shrink, auc_lim, addsteps, counter_contrast,
                   cut, dfs, gridspacing, scores, contr2, cc, grid0, cell, combinations,
                   matrix_sorted='grid_sorted', xlim=[], ylim=[], text=True):
    """Draw one row of score/AUC matrices for ``cell``, one sub-grid per contrast combination.

    For every ``(contrast11, contrast22)`` in ``combinations`` the pickle
    ``name`` is read. It is filtered to ``cell``, ``df[contr2] == contrast22``
    and ``dev``. The matching rows are drawn into ``grid0[cc]``:
    - with ``plt_ingridients`` when ``scores[0]`` does not contain 'auc';
    - with ``plt_auc`` otherwise.
    Ticks, labels and text are then adjusted. Finally ``set_same_lim``
    applies common limits to all drawn panels.

    Returns:
        ``(axes, y_max, y_min, x_max, x_min)``: the per-combination axes
        containers and the collected index/column extremes.
    """
    axes = []
    y_max = []
    y_min = []
    x_max = []
    x_min = []
    # sub-grid counter; only advanced when the pickle file exists
    c = -1
    for contrast11, contrast22 in combinations:
        print('print: ' + str(cell))
        grid1 = gridspec.GridSpecFromSubplotSpec(1, int(len(combinations)),
                                                 hspace=0.7,
                                                 wspace=0.2,
                                                 subplot_spec=
                                                 grid0
                                                 [cc])  # hspace=0.4,wspace=0.2,len(chirps)
        if os.path.exists(name):
            c += 1
            # the same pickle is re-read for every combination
            df = pd.read_pickle(name)
            if len(df) > 0:
                try:
                    frame = df[(df['cell'] == cell) & (
                            df[contr2] == contrast22) & (df['dev'] == dev)]  #
                except:
                    print('contrast think')
                    embed()
                pivots = {}
                if len(frame) > 0:
                    if 'auc' not in scores[0]:
                        grid2 = gridspec.GridSpecFromSubplotSpec(1, len(scores) + 1,
                                                                 hspace=0.4,
                                                                 wspace=0.4,
                                                                 subplot_spec=
                                                                 grid1[c])  # hspace=0.4,wspace=0.2,len(chirps)
                        # NOTE(review): `trans` is only bound in this branch but is also
                        # used by the AUC branch below (NameError if that runs first).
                        trans = False
                        ax, pivots_scores, orientation, limit_diff, axs = plt_ingridients(trans, gridspacing, scores,
                                                                                          scores, xlim, ylim, cut,
                                                                                          grid2, frame, dfs,
                                                                                          matrix_sorted, orientation,
                                                                                          matrix_extend, colorbar=False)
                        axes.append(ax)
                        if len(pivots_scores) > 0:
                            try:
                                y_max.append(np.max(pivots_scores[scores[0]].index))
                            except:
                                print('y_max')
                                embed()
                            y_min.append(np.min(pivots_scores[scores[0]].index))
                            x_max.append(np.max(pivots_scores[scores[0]].columns))
                            x_min.append(np.min(pivots_scores[scores[0]].columns))
                    else:
                        grid2 = gridspec.GridSpecFromSubplotSpec(1, len(scores) + 2,
                                                                 hspace=0.4,
                                                                 wspace=0.4,
                                                                 subplot_spec=
                                                                 grid1[c])  # hspace=0.4,wspace=0.2,len(chirps)
                        ####################################
                        # plot the auc parts
                        # only the first contrast row gets colorbars
                        if counter_contrast == 0:
                            colorbar = True
                        else:
                            colorbar = False
                        orientation, ax, limit_diff, axs, pivots = plt_auc(grid2, matrix_sorted,
                                                                           orientation, shrink, cut,
                                                                           xlim, ylim, pivots, trans,
                                                                           gridspacing,
                                                                           matrix_extend, dfs, frame, [scores],
                                                                           cell, auc_lim, addsteps,
                                                                           scores, 0,
                                                                           contr2, contrast22,
                                                                           pad=0.3,
                                                                           bar_orientation="horizontal",
                                                                           colorbar=colorbar, fig=fig)
                        pivots_scores = pivots
                        axes.append(ax)
                        y_max.append(np.max(pivots[scores[0]].index))
                        y_min.append(np.min(pivots[scores[0]].index))
                        x_max.append(np.max(pivots[scores[0]].columns))
                        x_min.append(np.min(pivots[scores[0]].columns))
                    # NOTE(review): `ax` is a dict (plt_ingridients) or a list (plt_auc);
                    # neither has axvline/axhline, so this loop looks like it raises
                    # AttributeError whenever `ax` is non-empty -- verify.
                    for _ in ax:
                        ax.axvline(1, color='grey', linewidth=0.5, )
                        ax.axhline(1, color='grey', linewidth=0.5, )
                    if c != len(combinations) - 1:
                        # dict keys index into `ax`; list entries are axes themselves
                        for a in ax:
                            if a != 0:
                                try:
                                    ax[a].set_xticks([])
                                except:
                                    a.set_xticks([])
                    else:
                        # if cell_counter in np.arange(row*col-col,row*col,1):
                        chose_xlabel_roc_matrix(ax, contrast11, contrast22, contr2, pivots_scores, scores)
                    for s in range(len(scores)):
                        if s != 0:
                            ax[s].set_yticks([])
                        else:
                            # if cell_counter in np.arange(0,row*col,col):
                            if c == len(combinations) - 1:
                                chose_ylabel_ROC_matrix(ax, contrast11, contrast22, contr2, pivots_scores, s,
                                                        scores)
                    if text:
                        if c == 0:
                            ax[0].text(0, title_pos, cell, transform=ax[0].transAxes,
                                       fontweight='bold')
                        # contrast annotation, positioned in the coordinates of ax[3]
                        ax[0].text(0
                                   , 1.2, 'C1 ' + str(
                                contrast22) + ' C2 ' + str(
                                contrast11), transform=ax[3].transAxes,
                                   fontweight='bold')  #
                    # assumes at least four panels (ax[0]..ax[3]) exist
                    ax[0].set_title('')
                    ax[1].set_title('')
                    ax[2].set_title('')
                    ax[3].set_title('')
                    ax[0].axvline(0, color='grey', linewidth=0.5, )
                    ax[0].axhline(0, color='grey', linewidth=0.5, )
                    ax[1].axvline(0, color='grey', linewidth=0.5, )
                    ax[1].axhline(0, color='grey', linewidth=0.5, )
                    ax[2].axvline(0, color='grey', linewidth=0.5, )
                    ax[2].axhline(0, color='grey', linewidth=0.5, )
                    ax[3].axvline(0, color='grey', linewidth=0.5, )
                    ax[3].axhline(0, color='grey', linewidth=0.5, )
                    if len(ax) > 4:
                        ax[4].set_title('')
    set_same_lim(xlim, ylim, y_min, y_max, x_min, x_max, axes)
    return axes, y_max, y_min, x_max, x_min
def chose_ylabel_ROC_matrix(ax, contrast11, contrast22, contrastc2, pivots_scores, s, scores):
    """Set the y-label of ``ax[s]`` to the pivot index name followed by a contrast in %.

    ``contrast22`` is used when ``contrastc2`` and the index name of the first
    score's pivot both contain '1', or both contain '2'. Otherwise
    ``contrast11`` is used.
    """
    index_name = pivots_scores[scores[0]].index.name
    key_has = ('1' in contrastc2, '2' in contrastc2)
    name_has = ('1' in index_name, '2' in index_name)
    use_second = (key_has[0] and name_has[0]) or (key_has[1] and name_has[1])
    contrast = contrast22 if use_second else contrast11
    ax[s].set_ylabel(index_name + ' ' + str(contrast) + '%', labelpad=-25)
def chose_xlabel_roc_matrix(ax, contrast11, contrast22, contrastc2, pivots_scores, scores):
    """Set the x-label of ``ax[0]`` to the pivot column name followed by a contrast in %.

    ``contrast22`` is used when ``contrastc2`` and the column name of the
    first score's pivot both contain '1', or both contain '2'. Otherwise
    ``contrast11`` is used.
    """
    column_name = pivots_scores[scores[0]].columns.name
    shares_one = ('1' in contrastc2) & ('1' in column_name)
    shares_two = ('2' in contrastc2) & ('2' in column_name)
    if shares_one | shares_two:
        label = column_name + ' ' + str(contrast22) + '%'
    else:
        label = column_name + ' ' + str(contrast11) + '%'
    ax[0].set_xlabel(label, labelpad=-15)
def set_same_lim(xlim, ylim, y_min, y_max, x_min, x_max, axes):
    """Apply one common x/y range to every axes in ``axes``.

    When ``xlim`` is non-empty, ``xlim`` and ``ylim`` are used as given.
    Otherwise the range goes from ``np.min(<mins>) * 0.99`` to
    ``np.max(<maxs>) * 1.01``.

    Args:
        xlim, ylim: explicit limits, or empty to derive them from the extremes.
        y_min, y_max, x_min, x_max: collected extremes of the plotted matrices.
        axes: iterable of axes groups. Each group is either a mapping whose
            values are axes, or any other iterable of axes.
    """
    from collections.abc import Mapping

    if len(xlim) > 0:
        ylim_here = ylim
        xlim_here = xlim
    else:
        ylim_here = [np.min(y_min) * 0.99, np.max(y_max) * 1.01]
        xlim_here = [np.min(x_min) * 0.99, np.max(x_max) * 1.01]
    for group in axes:
        # explicit dispatch instead of a bare `except` that also hid genuine
        # errors raised by set_ylim/set_xlim
        members = group.values() if isinstance(group, Mapping) else group
        for single_ax in members:
            single_ax.set_ylim(ylim_here)
            single_ax.set_xlim(xlim_here)
def plt_ingridients(trans, gridspacing, pivots_diff, scores, xlim, ylim, cut, grid_orig2, frame, dfs,
                    matrix_sorted='grid_sorted', orientation='f1 on x, f2 on y', matrix_extent='min', title=True,
                    colorbar_title=True, fig=[], colorbar=False, pad=0.1, bar_orientation='vertical', cl_outside=False):
    """Draw each score matrix and their signed combination in one grid row.

    Each score present in ``frame`` is pivoted with ``get_data_pivot_three``.
    The pivot is optionally transposed, and square-rooted for scores
    containing 'var'. It is then drawn with ``imshow`` in ``grid_orig2[s]``;
    all panels share one color range. When every score has a pivot, the
    combination ``scores[0] - scores[1] - scores[2] + scores[3]`` is drawn in
    ``grid_orig2[4]`` with a symmetric RdBu_r range, and a circle is marked at
    (1, 1).

    Returns:
        ``(ax, pivots_scores, orientation, limit_diff, axs)``:
        - ``ax`` is a dict of axes keyed by panel index.
        - ``pivots_scores`` maps score name to pivot.
        - ``limit_diff`` is the [min, max] of the combination (empty if not drawn).
        - ``axs`` is the last image artist.
    """
    ax = {}
    pivots_scores = {}
    pivots_min = []
    pivot = []
    limit_diff = []
    axs = []
    if len(frame) > 0:
        for s, score in enumerate(scores):
            if score in frame.keys():
                pivot, _, indexes, resorted, orientation, cut_type = get_data_pivot_three(frame, score,
                                                                                          matrix_extent=matrix_extent,
                                                                                          matrix_sorted=matrix_sorted,
                                                                                          orientation=orientation,
                                                                                          gridspacing=gridspacing,
                                                                                          dfs=dfs)
                if trans:  #
                    pivot = np.transpose(pivot)
                if 'var' in score:
                    # variance -> standard deviation
                    pivot = np.sqrt(pivot)
                pivots_scores[score] = pivot
                pivots_min.append(pivot)
            else:
                # missing score: placeholder of ones shaped like the previous pivot
                if len(pivot) > 0:
                    pivots_scores[score] = np.ones_like(pivot)
        # operator drawn right of each panel; assumes at most four scores
        symbol = ['$-$', '$-$', '$+$', '$=$']
        for s, score in enumerate(scores):
            if score in frame.keys():
                ax[s] = plt.subplot(grid_orig2[s])
                ax[s].text(1.5, 0.5, symbol[s],
                           fontsize=15, va='center', ha='center',
                           transform=ax[s].transAxes)  # ha='center', va='center',
                # common color range over all score pivots
                vmax = np.nanmax(pivots_min)
                vmin = np.nanmin(pivots_min)
                axs = plt.imshow(pivots_scores[score], extent=[np.min(pivots_scores[score].columns),
                                                               np.max(pivots_scores[score].columns),
                                                               np.min(pivots_scores[score].index),
                                                               np.max(pivots_scores[score].index)], vmax=vmax,
                                 vmin=vmin,
                                 origin='lower')
                if title:
                    plt.title(score, fontsize=7)
                if colorbar:
                    if cl_outside:
                        _, _, _, _, _ = colorbar_outside(ax[s], axs, fig,
                                                         orientation='bottom')
                        if colorbar_title:
                            ax[s].text(0.2, -1.4, score, transform=ax[s].transAxes)  # , va = 'center', ha = 'center'
                    else:
                        plt.colorbar(orientation=bar_orientation, pad=pad)
                if cut:
                    ax[s].set_ylim(ylim)
                    ax[s].set_xlim(xlim)
        ###############################
        # plot sum of both
        # NOTE(review): uses `s` left over from the loop above and assumes exactly
        # four scores (indices 0-3) plus a fifth grid cell.
        # `min`/`max` below shadow the builtins locally.
        if len(scores) == len(pivots_scores):
            ax[s + 1] = plt.subplot(grid_orig2[4])
            diff = pivots_scores[scores[0]] - pivots_scores[scores[1]] - pivots_scores[scores[2]] + pivots_scores[
                scores[3]]
            min = np.min(np.min(diff))
            max = np.max(np.max(diff))
            lim = np.max([np.abs(min), np.abs(max)])
            axs = plt.imshow(diff, vmin=-lim, vmax=lim, origin='lower', cmap="RdBu_r",
                             extent=[np.min(diff.columns),
                                     np.max(diff.columns),
                                     np.min(diff.index),
                                     np.max(diff.index)])
            limit_diff = np.array([np.min(np.min(diff)), np.max(np.max(diff))])
            plt.yticks([])
            if cut:
                ax[s + 1].set_ylim(ylim)
                ax[s + 1].set_xlim(xlim)
            # NOTE(review): compares a list with one of its items when the caller
            # passes pivots_diff == scores, which is then always True -- confirm intent.
            if scores != pivots_diff[-1]:  #
                plt.xticks([])
            if colorbar:
                if cl_outside:
                    _, _, _, _, _ = colorbar_outside(ax[s + 1], axs, fig,
                                                     orientation='bottom', top=True)
                else:
                    plt.colorbar(orientation=bar_orientation, pad=pad)
        # NOTE(review): raises KeyError if the combination panel was not created
        # but some score panels were.
        if len(ax) > 0:
            ax[s + 1].scatter(1, 1, marker='o', facecolors='none', edgecolors='black')
    return ax, pivots_scores, orientation, limit_diff, axs
def plt_auc(grid0, matrix, orientation, shrink, cut, xlim, ylim, pivots, trans, gridspacing, start, dfs, df_datapoint,
            pivots_diff, cell, auc_lim, addsteps, scores,
            di, contrastc2, contrast22, add=0, bar_orientation="vertical", pad=0.1, colorbar=True,
            cl_outside=True, fig=[]):
    """Draw AUC score matrices and their difference for one contrast.

    The rows of ``df_datapoint`` with ``df_datapoint[contrastc2] == contrast22``
    are pivoted for every score. Each pivot is stored in the caller-provided
    dict ``pivots``, which is mutated. When ``addsteps`` is True every score
    matrix gets its own panel. Then ``pivots_diff[di]`` (a list of score names)
    defines a difference: the first pivot minus the remaining ones. It is drawn
    with a fixed range of +-0.5 and, unless ``auc_lim == 'nonlinm'``, a second
    time with a symmetric range given by the 95th percentile of |difference|.

    Returns:
        ``(orientation, ax_cont, limit_diff, axs, pivots)`` where ``ax_cont`` is
        the list of created axes.
    """
    # NOTE(review): the good_cells / addsteps branches below are no-ops.
    good_cells = [
        '2022-01-06-ai-invivo-1',
        '2022-01-06-ag-invivo-1',
        '2022-01-08-ad-invivo-1',
        '2022-01-08-ah-invivo-1',
        '2021-07-06-ab-invivo-1',
        '2021-08-03-ac-invivo-1',
    ]  #
    if cell in good_cells:
        pass
    else:
        pass
    if addsteps == True:
        pass
    else:
        pass
    # index of the next free cell in grid0
    counter = 0
    ax_cont = []
    for ss, score in enumerate(scores):
        what = score  # + '_'+dev
        df_contrast = df_datapoint[df_datapoint[contrastc2] == contrast22]
        if len(df_contrast) > 0:
            if score in df_contrast:
                pivot, _, indexes, resorted, orientation, cut_type = get_data_pivot_three(df_contrast, what,
                                                                                          matrix_extent=start,
                                                                                          matrix_sorted=matrix,
                                                                                          orientation=orientation,
                                                                                          gridspacing=gridspacing,
                                                                                          dfs=dfs)  # 35
                try:
                    pass
                except:
                    pass
                if trans:
                    pivot = np.transpose(pivot)
                pivots[score] = pivot
                if addsteps == True:
                    ax = plt.subplot(grid0[counter])
                    counter += 1
                    plt.title(' ' + str(what), fontsize=8)  # +' dev'+str(dev)
                    ax_cont.append(ax)
                    # fixed ranges for AUC-like scores, data range otherwise
                    if 'auci' in score:
                        vmin = -0.5
                        vmax = 0.5
                    elif 'auc' in score:
                        vmin = 0
                        vmax = 1
                    else:
                        vmax = np.nanmax(pivot)
                        vmin = np.nanmin(pivot)
                    axs = plt.imshow(pivot, vmin=vmin,
                                     vmax=vmax, origin='lower', cmap="RdBu_r",
                                     extent=[np.min(pivot.columns),
                                             np.max(pivot.columns),
                                             np.min(pivot.index),
                                             np.max(pivot.index)])
                    if cut:
                        plt.xlim(xlim)
                        plt.ylim(ylim)
                    if not ((di == 1) & (ss == 0)):
                        pass
                    else:
                        plt.xlabel('EOD mult 2')
                        plt.ylabel('EOD mult 1')
                    if colorbar:
                        if cl_outside:
                            _, _, _, _, _ = colorbar_outside(ax, axs, fig, add=add,
                                                             top=True)  # colorbar_outside
                    if di != len(pivots_diff) - 1:
                        plt.xticks([])
                    if ss != 0:
                        plt.yticks([])
                # NOTE(review): `ax` is only bound when addsteps is True, and `symbol`
                # has two entries -- with addsteps False or more than two scores
                # this raises.
                symbol = ['$-$', '$=$']
                ax.text(1.5, 0.5, symbol[ss],
                        fontsize=15, va='center', ha='center',
                        transform=ax.transAxes)  # ha='center',
                if cut:
                    ax.set_ylim(ylim)
                    ax.set_xlim(xlim)
            else:
                print(str(score) + ' score not found')
    if di == 0:
        if addsteps == False:
            # NOTE(review): `ws` is not defined anywhere in this function; this line
            # raises NameError when reached.
            ax.text(1 + ws * 1, 0.5, '=', va='center', ha='center', fontsize=15,
                    transform=ax.transAxes)  # ha='center',# not found
    # difference panel: first listed score minus the remaining ones
    name = pivots_diff[di]
    ax = plt.subplot(grid0[counter])
    ax_cont.append(ax)
    if name[0] in pivots.keys():
        pivot_diff = pivots[name[0]]
        title = name[0]
        for p in range(1, len(name), 1):
            if name[p] in pivots.keys():
                pivot_diff = pivot_diff - pivots[name[p]]
                title = title + ' - ' + name[p]
        limit_diff = np.array([np.min(np.min(pivot_diff)), np.max(np.max(pivot_diff))])
        axs = plt.imshow(pivot_diff, vmin=-0.5,
                         vmax=0.5, origin='lower', cmap="RdBu_r",
                         extent=[np.min(pivot_diff.columns),
                                 np.max(pivot_diff.columns),
                                 np.min(pivot_diff.index),
                                 np.max(pivot_diff.index)])
        if cut:
            plt.xlim(xlim)
            plt.ylim(ylim)
        plt.ylabel('')
        plt.xlabel('')
        if colorbar:
            if cl_outside:
                _, _, _, _, _ = colorbar_outside(ax, axs, fig, add=add, top=True)
            else:
                plt.colorbar(shrink=shrink, orientation=bar_orientation, pad=pad)
        if di != len(pivots_diff) - 1:
            plt.xticks([])
        # `ss` here is the last loop index from the score loop above
        if ss != 0:
            plt.yticks([])
        if auc_lim != 'nonlinm':
            # same difference with a data-driven symmetric range
            ax = plt.subplot(grid0[counter + 1])
            ax_cont.append(ax)
            vmax = np.nanpercentile(np.abs(pivot_diff), 95)
            vmin = np.nanpercentile(np.abs(pivot_diff), 5)
            lim = np.max([vmax, np.abs(vmin)])
            axs = plt.imshow(pivot_diff, vmin=-lim,
                             vmax=lim, origin='lower', cmap="RdBu_r",
                             extent=[np.min(pivot_diff.columns),
                                     np.max(pivot_diff.columns),
                                     np.min(pivot_diff.index),
                                     np.max(pivot_diff.index)])
            if cut:
                plt.xlim(xlim)
                plt.ylim(ylim)
            plt.ylabel('')
            plt.xlabel('')
            if di != len(pivots_diff) - 1:
                plt.xticks([])
            if ss != 0:
                plt.yticks([])
            if colorbar:
                if cl_outside:
                    _, _, _, _, _ = colorbar_outside(ax, axs, fig, add=add, top=True)
                else:
                    plt.colorbar(orientation=bar_orientation, pad=pad)  # shrink=shrinkorientation = 'horizontal'
            counter += 1
        if '*' in score:
            pass
    else:
        limit_diff = []
        axs = []
    return orientation, ax_cont, limit_diff, axs, pivots
def condition_for_roc_thesis():
    """Set the ROC-thesis parameters as module globals and return them.

    Assigns the module-level globals ``diagonal``, ``freq1_ratio``,
    ``freq2_ratio``, ``plus_q``, ``way`` and ``length``.

    Returns:
        ``(trials_nr, length, ways, way, plus_q, freq2_ratio, freq1_ratio, diagonal)``.
    """
    global diagonal, freq1_ratio, freq2_ratio, plus_q, way, length
    # diagonal_points() is queried before the globals are (re)assigned
    combis = diagonal_points()
    diagonal = 'B1+B2_diagonal2'
    ratios = combis[diagonal]
    freq1_ratio = ratios[0]
    freq2_ratio = ratios[1]
    plus_q = 'plus'
    way = 'mult_minimum_1'
    ways = ['absolut']
    # this length is required here, otherwise the ROC computation does not work
    length = 25
    trials_nr = 20
    return trials_nr, length, ways, way, plus_q, freq2_ratio, freq1_ratio, diagonal
def save_RAM_to_csv(data_name, spikes_data_short, end=''):
    """Export the RAM data of ``data_name`` by running the save_RAM_* helpers in order.

    The order is:
    1. ``save_RAM_spikes_core``: spike csvs.
    2. ``save_RAM_overview_csv``: overview csvs.
    3. ``save_RAM_both_csv``: spikes-and-EOD pickle.
    4. ``save_RAM_eod_to_csv``: EOD csvs.
    """
    _, spikes, _ = save_RAM_spikes_core(data_name, end, spikes_data_short)
    save_RAM_overview_csv(data_name, end, spikes_data_short)
    _, selected_rows = save_RAM_both_csv(data_name, spikes, spikes_data_short)
    save_RAM_eod_to_csv(data_name, selected_rows)
def save_RAM_both_csv(data_name, spikes, spikes_data_short):
    """Store the spikes together with the matching EOD metadata as a pickle.

    The rows of the smallest amplitude and first file name are selected from
    ``spikes_data_short``. If the extra-EOD pickle of ``data_name`` exists, the
    same selection is applied to it. The output of ``save_spikes_csv`` is then
    written to ``calc_RAM/spikes_and_eod_<data_name>.pkl``.

    Returns:
        ``(file_name, spikes_selected)``. ``file_name`` is the one-element list
        from ``get_min_amp_and_first_file``, taken from the EOD data when that
        data exists.
    """
    amp, file_name = get_min_amp_and_first_file(spikes_data_short, min_find=True)
    # `amp` and `file_name` are one-element lists. `Series == list` compares
    # element-wise and raises "Lengths must match to compare" unless the frame
    # has exactly one row, so select with isin() instead.
    spikes_selected = spikes_data_short[
        spikes_data_short.amp.isin(amp) & spikes_data_short.file_name.isin(file_name)]
    eod_path = load_only_spikes_RAM(data_name=data_name, emb=False, core_name='calc_RAM_data_eod_extra__first1_order__')
    if os.path.exists(eod_path):
        eod_data_short = pd.read_pickle(eod_path)
        amp, file_name = get_min_amp_and_first_file(eod_data_short, min_find=True)
        eod_selected = eod_data_short[
            eod_data_short.amp.isin(amp) & eod_data_short.file_name.isin(file_name)]
        frame_cell = save_spikes_csv(eod_selected, spikes, spikes_selected)
        frame_cell.to_pickle('calc_RAM/spikes_and_eod_' + data_name + '.pkl')
    return file_name, spikes_selected
def save_RAM_overview_csv(data_name, end, spikes_data_short):  # , name = ''
    """Write one overview csv per (amplitude, file name) of ``data_name``.

    The csv holds the per-trial 'file_name', 'sampling', 'cell' and 'eod_fr'
    columns of the selection. It also holds the species, the individual burst
    threshold (``find_lim_here``), the reclassified cell type from the
    base-frame overview pickle, and the EOD frequency columns reshuffled with
    ``reshuffle_eodfs``. The file name comes from ``end_calc_ram_name_eod`` and
    ``end_calc_ram_name``.
    """
    amps, file_names = get_min_amp_and_first_file(spikes_data_short)
    path_sascha = load_folder_name('calc_base') + '/' + 'calc_base_data-base_frame_overview.pkl'
    frame = pd.read_pickle(path_sascha)
    # metadata row(s) of this cell; the first row is used below
    frame_c = frame[frame.cell == data_name]
    for a, amp in enumerate(amps):
        for file_name in file_names:
            frame_ov = pd.DataFrame()
            spikes_selected = spikes_data_short[
                (spikes_data_short.amp == amp) & (spikes_data_short.file_name == file_name)]  #
            spikes = get_array_from_pandas(spikes_selected['spikes'], abs=False)
            if len(spikes_selected) > 0:
                # the first Series assignment fixes frame_ov's index; the scalar
                # columns that follow are broadcast to it. Column order = csv order.
                frame_ov['file_name'] = spikes_selected.file_name
                frame_ov['sampling'] = spikes_selected.sampling
                frame_ov['cell'] = spikes_selected.cell
                # this corresponds to the estimate from the baseline, but it could
                # also be computed from the global EOD of the RAM recording
                frame_ov['eod_fr'] = spikes_selected.eod_fr
                frame_ov['species'] = frame_c.species.iloc[0]
                lim = find_lim_here(data_name, 'individual')
                frame_ov['burst_corr_individual'] = lim
                frame_ov['cell_type_reclassified'] = frame_c.cell_type_reclassified.iloc[0]
                # only the reshuffle positions are used; the reshuffled spikes are not written
                spikes, pos_reshuffled = reshuffle_spike_lengths(spikes)
                names = names_eodfs()
                vars = []
                for name in names:
                    vars.append(spikes_selected[name].iloc[0])
                frame_ov = reshuffle_eodfs(frame_ov, names, pos_reshuffled, vars, res_name='res')
                cell_type = frame_c.cell_type_reclassified.iloc[0]
                name = end_calc_ram_name_eod(end='-overview_') + end_calc_ram_name(data_name, end, file_name, amp,
                                                                                  cell_type='_' + cell_type,
                                                                                  species=frame_c.species.iloc[
                                                                                      0])  # data_name + end +'_amp_'+str(amp)+'_filename_'+str(file_name)+ '.csv'
                frame_ov.to_csv(name)
    del frame
def save_RAM_spikes_core(data_name, end, spikes_data_short, cell_type='', species=''):
    """Write the spike trains of every (amplitude, file name) combination to csv.

    For each combination with spikes, the trains are reshuffled with
    ``reshuffle_spike_lengths``, converted with ``save_spikestrains_several``
    and written to the name built from ``end_calc_ram_name_eod('-spikes_')``
    and ``end_calc_ram_name``. Spaces are removed from ``species``.

    Returns:
        ``(file_name, spikes, spikes_selected)`` of the last loop iteration.
        ``spikes`` is the reshuffled version when that iteration had spikes.
    """
    # NOTE(review): with no amplitudes or file names the returned names are
    # never bound and the return raises UnboundLocalError.
    amps, file_names = get_min_amp_and_first_file(spikes_data_short)
    for a, amp in enumerate(amps):
        for file_name in file_names:
            spikes_selected = spikes_data_short[
                (spikes_data_short.amp == amp) & (spikes_data_short.file_name == file_name)]  #
            spikes = get_array_from_pandas(spikes_selected['spikes'], abs=False)
            if len(spikes) > 0:
                spikes_df = pd.DataFrame()
                spikes, pos_reshuffled = reshuffle_spike_lengths(spikes)
                try:
                    spikes_df = save_spikestrains_several(spikes_df, spikes)
                except:
                    print('shuffling thing')
                    embed()
                name = end_calc_ram_name_eod(end='-spikes_') + end_calc_ram_name(data_name, end, file_name, amp,
                                                                                '_' + cell_type, species.replace(' ',
                                                                                                                 ''))  # data_name + end + '_amp_' + str(amp) + '_filename_' + str(file_name) + '.csv'
                spikes_df.to_csv(name, index=False)
    return file_name, spikes, spikes_selected
def get_min_amp_and_first_file(spikes_data_short, min_find=False):
    """Return the amplitudes and file names to iterate over.

    With ``min_find`` the result is ``([smallest amp], [first unique file name])``;
    more than one unique file name triggers a debugging ``embed()``. Otherwise
    the unique amplitudes and unique file names are returned as arrays.
    """
    if min_find != True:
        return spikes_data_short.amp.unique(), spikes_data_short.file_name.unique()
    smallest_amp = np.min(spikes_data_short.amp)
    unique_files = spikes_data_short.file_name.unique()
    if len(unique_files) > 1:
        print('alignment problem')
        embed()
    return [smallest_amp], [unique_files[0]]
def save_RAM_eod_to_csv(data_name, spikes_selected):
    """Write the EOD csv files for ``data_name``.

    The extra-EOD pickle of ``data_name`` is used when it exists. Otherwise,
    and only for the 'InputArr_400hz_30s' stimulus, the extra-EOD pickle of
    cell '2022-01-06-aa-invivo-1' is loaded instead. In every other case
    nothing is written.
    """
    core = 'calc_RAM_data_eod_extra__first1_order__'
    eod_path = load_only_spikes_RAM(data_name=data_name, emb=False, core_name=core)
    if not os.path.exists(eod_path):
        if spikes_selected.file_name2.iloc[0] != 'InputArr_400hz_30s':
            return
        eod_path = load_only_spikes_RAM(data_name='2022-01-06-aa-invivo-1', emb=False,
                                        core_name=core)
    save_RAM_wod_to_csv_core(data_name, pd.read_pickle(eod_path))
def save_RAM_wod_to_csv_core(data_name, eod_data_short, end='', cell_type='', species='', ):
    """Write one EOD csv per (amplitude, file name) combination of ``eod_data_short``.

    For each combination with at least one EOD trace, the longest trace is
    written as a single 'eod' column. The file name is built from
    ``end_calc_ram_name_eod('-eod_')`` and ``end_calc_ram_name``.
    """
    amps, file_names = get_min_amp_and_first_file(eod_data_short)
    for amp in amps:
        for file_name in file_names:
            rows = (eod_data_short.amp == amp) & (eod_data_short.file_name == file_name)
            eods = get_array_from_pandas(eod_data_short[rows]['eod'], abs=False)
            if len(eods) <= 0:
                continue
            trace_lengths = [len(trace) for trace in eods]
            try:
                eods_df = pd.DataFrame(eods[np.argmax(trace_lengths)], columns=['eod'])
            except:
                print('something')
                embed()
            out_name = end_calc_ram_name_eod(end='-eod_') + end_calc_ram_name(
                data_name, end, file_name, amp, cell_type='_' + cell_type, species=species)
            eods_df.to_csv(out_name, index=False)
def end_calc_ram_name_eod(end='-eod_'):
    """Return the path prefix 'calc_RAM/calc_nix_RAM' followed by ``end``."""
    prefix = 'calc_RAM/calc_nix_RAM'
    return ''.join([prefix, end])
def end_calc_ram_name(data_name='', end='', file_name='', amp='', cell_type='', species=''):
    """Build '<data_name><end>_amp_<amp>_filename_<file_name><cell_type><species>.csv'.

    ``amp`` and ``file_name`` are converted with ``str``; spaces are removed
    from ``cell_type`` and ``species``.
    """
    pieces = [data_name, end, '_amp_', str(amp), '_filename_', str(file_name),
              cell_type.replace(' ', ''), species.replace(' ', ''), '.csv']
    return ''.join(pieces)
def save_spikes_csv(eod_selected, spikes, spikes_selected):
    """Combine spikes and EOD metadata into one data frame.

    The columns, in this order, are:
    - 'spikes': ``spikes``.
    - 'file_name': the first 'file_name2' value of ``spikes_selected``, which
      is also printed.
    - 'sampling', 'cell', 'eod_fr': taken from ``eod_selected`` and aligned on
      the index.
    """
    combined = pd.DataFrame()
    combined['spikes'] = spikes
    first_file = spikes_selected.file_name2.iloc[0]
    print(first_file)
    combined['file_name'] = first_file
    for column in ('sampling', 'cell', 'eod_fr'):
        combined[column] = eod_selected[column]
    return combined
def compare_powers(show=False, step=str(30), v='diff', a_fes=[0.1], a_fjs=[0.1],
                   names=['amp_max_05']):  # a_fjs=[0, 0.01, 0.05, 0.1, 0.2]
    """Plot nonlinearity heatmaps of several cells for several model powers.

    For every combination of ``a_fes``, ``a_fjs`` and ``names``, one figure is
    drawn with a row per cell and a column per power ``nr``. Each panel is a
    seaborn heatmap of ``versions[v]`` from ``get_all_squares``. The symmetric
    color limit is computed from the first non-empty panel and reused for all
    later ones. Each figure is passed to ``save_visualization(show)``.

    Args:
        show: forwarded to ``save_visualization``.
        step: step string forwarded to ``get_all_squares``.
        v: key of the ``versions`` dict to plot (default 'diff').
        a_fes, a_fjs: amplitudes forwarded to ``get_all_squares``. They are read
            only, so the list defaults are not mutated.
        names: score names ('what') to plot.
    """
    # None until the first non-empty panel sets it. The former `lim == []`
    # check turned into an ambiguous empty-array truth test once lim was a
    # numpy scalar.
    lim = None
    default_settings(column=2, length=4)
    # the first two ids were previously fused into one string by a missing comma
    cells = ['2012-07-03-ak-invivo-1', '2012-04-20-ak-invivo-1', '2012-05-10-ad-invivo-1', '2012-06-27-ah-invivo-1',
             '2012-06-27-an-invivo-1']
    nrs = [0.5, 1, 1.5, 3]
    adapt = 'adaptoffsetallall2'
    self = ''
    version_sinz = 'sinz'
    for a_fe in a_fes:
        for a_fj in a_fjs:
            for what in names:
                plt.figure()
                grid = gridspec.GridSpec(len(cells), len(nrs), hspace=0.65, wspace=0.34, bottom=0.15, top=0.97)
                for c, cell in enumerate(cells):
                    for n, nr in enumerate(nrs):
                        full_name = 'modell_all_cell_no_sinz' + str(
                            nr) + '_afe0.1__afr1__afj0.1__length0.5_adaptoffsetallall2_0.995_1.005____stepefish50Hz_ratecorrrisidual35__modelbigfit_nfft4096_StartEmitter0.5_EndEmitter1.5_StartJammer0.5_EndJammer1.5Three_SameOffset'
                        versions, _ = get_all_squares(adapt=adapt, full_name=full_name, self=self,
                                                      version_sinz=version_sinz, cell=cell, a_fe=a_fe, nr=nr,
                                                      a_fj=a_fj, what=what, step=step,
                                                      )
                        if len(versions['diff']) == 0:
                            continue
                        plt.subplot(grid[c, n])
                        plt.title('Power:' + str(nr))
                        if lim is None:
                            low = np.min(np.min(versions[v]))
                            high = np.max(np.max(versions[v]))
                            lim = np.max([np.abs(low), np.abs(high)])
                        axs = sns.heatmap(versions[v], vmin=-lim, vmax=lim, cmap="RdBu_r",
                                          cbar_kws={'label': 'Nonlinearity [Hz]'})  # 'location': "left"
                        axs.invert_yaxis()
                plt.subplots_adjust(hspace=0.8, wspace=0.8)
                save_visualization(show)
def get_all_squares(full_name='', self='', version_sinz='sinz', cell=[], a_fe=0.2, nr=1, a_fj=0.2, what='std',
                    step='30', adapt='adaptoffset_bisecting', variant='no'):
    """Load the four model squares ('012', 'control_01', 'control_02', 'base_0') and combine them.

    Each square is obtained from ``define_squares_model_three``.
    ``versions['diff']`` is ``012 - control_01 - control_02 + base_0``
    (pandas label-aligned arithmetic), or ``[]`` when the '012' square is
    empty.

    Returns:
        ``(versions, arrays)``: ``versions`` maps square name (plus 'diff') to
        its frame, and ``arrays`` holds ``np.array`` copies of the four squares
        in load order.
    """
    squares = ['012', 'control_01', 'control_02', 'base_0']  # 'base_0'
    versions = {}
    arrays = []
    for s, square in enumerate(squares):
        # NOTE(review): `cell` is rebound to the value returned by each call and
        # that value is passed as `cell_data` to the next call -- confirm intended.
        control, vers_here, cell, eod_m, fr_rate_mult = define_squares_model_three(a_fe=a_fe, nr=nr, a_fj=a_fj,
                                                                                   what=what, step=step,
                                                                                   adapt=adapt, variant=variant,
                                                                                   square=square, cell_data=cell,
                                                                                   full_name=full_name, self=self,
                                                                                   minimum=0.5,
                                                                                   maximum=1.5,
                                                                                   version_sinz=version_sinz)
        versions[square] = vers_here
        arrays.append(np.array(vers_here))
    if len(versions['012']) > 0:
        versions['diff'] = versions['012'] - versions['control_01'] - versions['control_02'] + versions['base_0']
    else:
        versions['diff'] = []
    return versions, arrays
def plot_shemes_lis(eod_fish_r, eod_fish_e, eod_fish_j, grid0, time, ylim=[-1.1, 1.1], g=0, jammer_label='f2',
                    emitter_label='f1', remove=True, receiver_label='f0',
                    waves_present=['receiver', 'emitter', 'jammer', 'all'], color='grey', eod_fr=700, add=0,
                    xlim=[0, 0.05], extract_here=True, colors = [], sheme_shift=0, extract='', title=[], threshold=0.02, color_eod='grey'):
    """Draw one column of the receiver/emitter/jammer -> stimulus scheme.

    One row of ``grid0`` (column ``g``, rows shifted by ``sheme_shift``) is used per entry of
    ``waves_present``. 'receiver', 'emitter' and 'jammer' plot the corresponding EOD and add it to
    a running stimulus sum; 'all' plots that sum and, if ``extract_here``, the amplitude modulation
    returned by ``extract_am`` in ``color``. Any other entry (e.g. '') leaves an empty panel.

    Parameters
    ----------
    eod_fish_r, eod_fish_e, eod_fish_j : array
        Receiver, emitter and jammer EODs sampled on ``time``.
    grid0 : GridSpec-like
        Grid indexed as ``grid0[row, column]``.
    time : array
        Shared time axis; ``1 / time[1]`` is passed as the sampling rate to ``extract_am``.
    ylim, xlim : list
        Axis limits (``xlim`` is skipped when empty).
    g : int
        Column of ``grid0``; y-labels are only written for column 0.
    colors : list
        Optional per-row line colors; when empty, every row uses ``color_eod``.
    color : str
        Color of the 'all' title and of the extracted AM trace.
    color_eod : str
        Waveform line color used when ``colors`` is empty.
    title : str or list
        Title of the 'all' panel (skipped when falsy).
    remove : bool
        Hide spines and ticks of every panel.

    Returns
    -------
    ax : Axes or list
        The last axes created, or ``[]`` when ``waves_present`` is empty.
    axes : list
        All axes created, in row order.
    """
    stimulus = np.zeros(len(eod_fish_r))
    axes = []
    # Callers test ``ax != []``; returning [] for an empty scheme avoids an unbound name.
    ax = []
    for ww, w in enumerate(waves_present):
        ax = plt.subplot(grid0[ww + sheme_shift, g])
        axes.append(ax)
        # Per-row colors take precedence; otherwise honor the color_eod argument.
        color_here = color_eod if len(colors) < 1 else colors[ww]
        if w == 'receiver':
            ax.plot(time, eod_fish_r, color=color_here, lw=0.5)
            stimulus += eod_fish_r
            ax.set_ylim(ylim)
            if len(xlim) > 0:
                ax.set_xlim(xlim)
            ax.spines['bottom'].set_visible(False)
            if g == 0:
                ax.set_ylabel(receiver_label)
        elif w == 'emitter':
            ax.text(0.5, 1.01, '$+$', va='center', ha='center', transform=ax.transAxes, fontsize=20)
            ax.plot(time, eod_fish_e, color=color_here, lw=0.5)
            stimulus += eod_fish_e
            ax.set_ylim(ylim)
            if len(xlim) > 0:
                ax.set_xlim(xlim)
            ax.spines['bottom'].set_visible(False)
            if g == 0:
                ax.set_ylabel(emitter_label)
        elif w == 'jammer':
            ax.text(0.5, 1.01, '$+$', va='center', ha='center', transform=ax.transAxes, fontsize=20)
            ax.plot(time, eod_fish_j, color=color_here, lw=0.5)
            stimulus += eod_fish_j
            ax.set_ylim(ylim)
            if len(xlim) > 0:
                ax.set_xlim(xlim)
            if g == 0:
                ax.set_ylabel(jammer_label)
        elif w == 'all':
            if title:
                ax.set_title(title, color=color)
            ax.text(0.5, 1.45 + add, '$=$', va='center', ha='center', transform=ax.transAxes, fontsize=20)
            eod_interp, eod_norm = extract_am(stimulus, time, extract=extract, norm=False, sampling=1 / time[1],
                                              eodf=eod_fr,
                                              emb=False, threshold=threshold)
            plt.plot(time, stimulus, color=color_here, lw=0.5)
            if extract_here:
                plt.plot(time[1::], eod_interp[1::], color=color)
            if len(xlim) > 0:
                plt.xlim(xlim)
            plt.ylim(ylim)
            if g == 0:
                plt.ylabel('stimulus')
        # Row labels by position (overrides the per-wave labels set above for column 0).
        if g == 0:
            if ww == 0:
                plt.ylabel(receiver_label)
            elif ww == 1:
                plt.ylabel(emitter_label)
            elif ww == 2:
                plt.ylabel(jammer_label)
            elif ww == 3:
                plt.ylabel('stimulus')
        if remove:
            ax.show_spines('')
            ax.set_xticks([])
            ax.set_yticks([])
    return ax, axes
def experimental_protocol_lissbon_amps(add='', show=True,
                                       ):
    """Plot receiver + emitter -> stimulus schemes for four emitter contrasts.

    Each column shows the receiver EOD (750 Hz), the emitter EOD (680 Hz) at one contrast from
    ``a_fes`` and their sum, labelled with the contrast in percent. The figure is tagged and saved
    with ``save_visualization``.

    Parameters
    ----------
    add : str
        Forwarded as ``add`` to ``save_visualization``.
    show : bool
        Forwarded to ``save_visualization``.
    """
    default_figsize(column=2, length=3)
    grid = gridspec.GridSpec(1, 1, wspace=0.7, hspace=0.5, left=0.05, top=0.99, bottom=0.07,
                             right=0.98)
    stimulus_length = 0.1
    deltat = 1 / 20000
    eod_fr = 750  # receiver EOD frequency
    a_fr = 1
    eod_fe = 680  # emitter EOD frequency
    a_fes = [0.01, 0.2, 0.6, 1]  # emitter contrasts, one column each
    ylim = [-2.3, 2.3]
    eod_fj = 730
    a_fj = 0.05
    variant_cell = 'no'
    waves_presents = [['receiver', 'emitter', 'all']] * len(a_fes)
    color = ['black'] * len(a_fes)
    symbols = [''] * len(a_fes)
    gs = np.arange(0, len(color), 1)
    grid0 = gridspec.GridSpecFromSubplotSpec(3, len(gs), wspace=0.3, hspace=0.35,
                                             subplot_spec=grid[0])
    axes0 = []
    for i in range(len(waves_presents)):
        eod_fish_j, time, time_fish_r, eod_fish_r, time_fish_e, eod_fish_e, time_fish_sam, eod_fish_sam, stimulus_am, stimulus_sam = extract_waves(
            variant_cell, '',
            stimulus_length, deltat, eod_fr, a_fr, a_fes[i], [eod_fe], 0, eod_fj, a_fj)
        time = time * 1000  # s -> ms
        print(eod_fe - eod_fr)
        xlim = [0, 30]
        # Column 3 (contrast 1) skips the extract_am trace and gets the analytic envelope below.
        if i == 3:
            extract_here = False
        else:
            extract_here = True
        ax, axes = plot_shemes_lis(eod_fish_r, eod_fish_e, eod_fish_j, grid0, time, g=gs[i],
                                   waves_present=waves_presents[i], color=color[i], eod_fr=eod_fr, ylim=ylim, xlim=xlim,
                                   jammer_label='', emitter_label='', receiver_label='',
                                   title='', extract_here=extract_here, remove=True,
                                   threshold=-0.25)
        axes0.append(axes[0])
        if extract_here == False:
            # Cosine envelope at the beat frequency |eod_fr - eod_fe|, shifted up by 0.3.
            time = np.arange(0, stimulus_length, deltat)
            time_fish_r = time * 2 * np.pi * np.abs((eod_fr - eod_fe))
            eod_fish_r = 1 + (a_fes[i] - 0.12) * np.cos(time_fish_r)
            time = time * 1000
            ax.plot(time, eod_fish_r + 0.3, color='black')
            ax.set_ylim(ylim)
        # Raw string: '\,' and '\%' are LaTeX commands, not Python escape sequences.
        ax.text(1, 1.03, '$c=%s' % (int(np.round(a_fes[i] * 100))) + r'\,\%$', transform=ax.transAxes,
                ha='right')
        test = False
        if test:
            print(str(np.max(eod_fish_r)) + ' ' + str(np.max(eod_fish_e)))
            plt.plot(eod_fish_r)
            plt.plot(eod_fish_e)
            plt.show()
        if ax != []:
            ax.text(1.1, 0.45, symbols[i], fontsize=35, transform=ax.transAxes)
        if i == 0:
            manual = False
            if manual:
                ax.plot([0, 10], [ylim[0] + 0.01, ylim[0] + 0.01], color='black')
                ax.text(0, -0.1, '10 ms', va='center', fontsize=11, transform=ax.transAxes)
            # NOTE(review): va='right' / ha='bottom' look swapped; confirm against plotstyle's xscalebar.
            ax.xscalebar(0.25, -0.02, 10, 'ms', va='right', ha='bottom')
        extra = False
        if extra:
            # Optional: simulate a model cell on the stimulus and add spike raster and rate rows.
            models = resave_small_files("models_big_fit_d_right.csv", load_folder='calc_model_core')
            cell = '2012-07-03-ak-invivo-1'
            deltat, eod_fr, model_params, offset = get_model_params(models, cell=cell)
            spikes = [[]] * 5
            for t in range(5):
                cvs, adapt_output, baseline_after, _, rate_adapted, rate_baseline_before, rate_baseline_after, spikes[
                    t], \
                stimulus_altered, \
                v_dent_output, offset_new, v_mem_output, noise_final = simulate(cell, offset,
                                                                                eod_fish_r + eod_fish_e + eod_fish_j,
                                                                                EODf=eod_fr, deltat=deltat,
                                                                                **model_params)
            base_cut, mat_base = find_base_fr(spikes, deltat, stimulus_length, time, dev=0.0005)
            ax = plt.subplot(grid0[-2, i])
            ax.eventplot(np.array(spikes) * 1000, color='grey')
            ax.set_xlim(xlim)
            ax.show_spines('')
            ax = plt.subplot(grid0[-1, i])
            ax.plot(time, mat_base[0:len(time)], color='black')
            ax.set_xlim(xlim)
            ax.set_ylabel('Firing Rate [Hz]')
    fig = plt.gcf()
    fig.tag(axes0, xoffs=-3.5, yoffs=0.1)
    save_visualization('', show, jpg=True, png=False, counter_contrast=0, savename='', add=add)
def experimental_protocol_lissbon(add='', color=['green', 'blue', 'red', 'orange'], titles=['receiver',
                                                                                          'receiver + female',
                                                                                          'receiver + intruder',
                                                                                          'receiver + female + intruder',
                                                                                          []],
                                  waves_presents=[['receiver', '', '', 'all'],
                                                  ['receiver', 'emitter', '', 'all'],
                                                  ['receiver', '', 'jammer', 'all'],
                                                  ['receiver', 'emitter', 'jammer', 'all'],
                                                  ], figsize=(12, 5.5),
                                  show=True,
                                  ):
    """Plot the receiver / female / intruder stimulus schemes, one column per wave combination.

    Receiver (750 Hz), female/emitter (600 Hz, contrast 0.5) and intruder/jammer (680 Hz,
    contrast 0.05) EODs are generated once with ``extract_waves``. Each entry of
    ``waves_presents`` becomes a column drawn by ``plot_shemes_lis`` and titled with the matching
    entry of ``titles``. The first column gets a 20 ms scale bar, and the figure is saved with
    ``save_visualization``.
    """
    plt.figure(figsize=figsize)
    outer_grid = gridspec.GridSpec(1, 1, wspace=0.7, hspace=0.5, left=0.05, top=0.99, bottom=0.07,
                                   right=0.95)
    scheme_grid = gridspec.GridSpecFromSubplotSpec(4, 4, wspace=0.3, hspace=0.35, height_ratios=[1, 1, 1, 1],
                                                   subplot_spec=outer_grid[0])
    stimulus_length = 0.3
    deltat = 1 / 40000
    eod_fr, a_fr = 750, 1
    eod_fe, a_fe = 600, 0.5
    eod_fj, a_fj = 680, 0.05
    variant_cell = 'no'
    (eod_fish_j, time_s, time_fish_r, eod_fish_r, time_fish_e, eod_fish_e,
     time_fish_sam, eod_fish_sam, stimulus_am, stimulus_sam) = extract_waves(
        variant_cell, '', stimulus_length, deltat, eod_fr, a_fr, a_fe, [eod_fe], 0, eod_fj, a_fj)
    column_of = [0, 1, 2, 3, 4]
    symbols = ['', '', '', '', '']
    ylim = [-2, 2]
    time_ms = time_s * 1000
    for i, waves_present in enumerate(waves_presents):
        ax, _ = plot_shemes_lis(eod_fish_r, eod_fish_e, eod_fish_j, scheme_grid, time_ms, g=column_of[i],
                                waves_present=waves_present, color=color[i], eod_fr=eod_fr, ylim=ylim,
                                xlim=[0, 70],
                                jammer_label='intruder', emitter_label='female', receiver_label='receiver',
                                title=titles[i])
        if ax == []:
            continue
        ax.text(1.1, 0.45, symbols[i], fontsize=35, transform=ax.transAxes)
        if i == 0:
            # 20 ms scale bar just above the lower y limit of the first column.
            bar_y = ylim[0] + 0.01
            ax.plot([0, 20], [bar_y, bar_y], color='black')
            ax.text(0, -0.1, '20 ms', va='center', fontsize=11, transform=ax.transAxes)
            ax.set_ylim(ylim)
    figure = plt.gcf()
    # Tag the top panel of every column.
    figure.tag(figure.axes[0::4], xoffs=-3, yoffs=0.3)
    save_visualization('', show, jpg=True, png=False, counter_contrast=0, savename='', add=add)
def rainbow_title(fig, axt, titles, color_add_pos, ha='left', a=0, start_xpos=0, y_pos=1.1):
    """Write a title made of differently colored segments above an axes.

    Parameters
    ----------
    fig : Figure
        Figure whose canvas renderer is used to measure each drawn segment.
    axt : Axes
        Axes the title is placed on (positions in axes coordinates).
    titles : str or sequence of str
        A single string is written in black. A sequence is written segment by segment; each
        segment starts 0.01 (axes units) to the right of the previous segment's right edge.
    color_add_pos : nested sequence
        ``color_add_pos[a][k]`` is the color of segment ``k``.
    ha : str
        Horizontal alignment of each text item.
    a : int
        Row of ``color_add_pos`` to use.
    start_xpos, y_pos : float
        Axes-coordinate position of the first segment.
    """
    if isinstance(titles, str):
        axt.text(start_xpos, y_pos, titles, color='black', transform=axt.transAxes, ha=ha)
        return
    pos = start_xpos
    for aa, title in enumerate(titles):
        text = axt.text(pos, y_pos, title, color=color_add_pos[a][aa],
                        transform=axt.transAxes, ha=ha)
        # Draw once so the text has a measurable extent, then continue right of it.
        text.draw(fig.canvas.get_renderer())
        extent = text.get_window_extent().transformed(axt.transAxes.inverted())
        pos = extent.get_points()[1][0] + 0.01
def calc_areas(path, frame_ref, colr, x_pos, cells_chosen):
    """Compute per-cell ROC areas and nonlinearity scores from a contrast-sweep table.

    Parameters
    ----------
    path : str
        CSV expected to contain the columns 'cell', 'c1' (contrast), 'auci_base_01',
        'auci_02_012' and whatever ``get_nonlin_scores`` needs.
    frame_ref : pandas.DataFrame
        Reference table with the columns 'cell' and 'cv_0'.
    colr : list
        Per-cell scatter colors; a copy is returned with entries of ``cells_chosen`` set to 'black'.
    x_pos : float
        Contrast at which the single-point values are read (nearest ``c1``).
    cells_chosen : sequence
        Cells to highlight.

    Returns
    -------
    cvs : pandas.Series
        ``frame_ref.cv_0``.
    nonlin_area : ndarray
        Area under the nonlinearity score over ``c1``, per cell.
    diff_areas : ndarray
        AUC('auci_02_012') - AUC('auci_base_01') over ``c1``, per cell.
    areas_01_scatter : list
        Highlight colors.
    nonlin : ndarray
        Nonlinearity score at the contrast nearest ``x_pos``.
    ndarray
        'auci_02_012' - 'auci_base_01' at the contrast nearest ``x_pos``.
    """
    default_settings(column=2, length=3)
    frame = pd.read_csv(path)
    cvs = frame_ref.cv_0
    # NOTE(review): cvs has one entry per row of frame_ref, the arrays below one per unique cell;
    # they only line up if every cell appears once in frame_ref.
    cells = frame_ref.cell.unique()
    n_cells = len(cells)
    areas_01 = np.full(n_cells, np.nan)
    areas_012 = np.full(n_cells, np.nan)
    areas_01_one = np.full(n_cells, np.nan)
    areas_012_one = np.full(n_cells, np.nan)
    nonlin = np.full(n_cells, np.nan)
    nonlin_area = np.full(n_cells, np.nan)
    areas_01_scatter = colr * 1  # copy, so the caller's color list is not modified
    name = 'score'
    for c, cell in enumerate(cells):
        # .copy(): the score column is added to an independent frame, not to a view of ``frame``.
        frame_cell = frame[frame.cell == cell].copy()
        frame_cell[name] = get_nonlin_scores(frame_cell)
        areas_01_one[c] = fin_min_pos(frame_cell, 'auci_base_01', x_pos)
        areas_012_one[c] = fin_min_pos(frame_cell, 'auci_02_012', x_pos)
        # metrics.auc requires c1 to be monotonic within a cell.
        areas_01[c] = metrics.auc(frame_cell.c1, frame_cell['auci_base_01'])
        areas_012[c] = metrics.auc(frame_cell.c1, frame_cell['auci_02_012'])
        nonlin[c] = fin_min_pos(frame_cell, name, x_pos)
        nonlin_area[c] = metrics.auc(frame_cell.c1, frame_cell[name])
        if cell in cells_chosen:
            areas_01_scatter[c] = 'black'
    diff_areas = areas_012 - areas_01
    return cvs, nonlin_area, diff_areas, areas_01_scatter, nonlin, areas_012_one - areas_01_one
def get_nonlin_scores(frame_cell):
    """Return the nonlinearity score of one cell's rows, as computed by ``val_new_core``."""
    return val_new_core(frame_cell)
def nonlinval_core(frame_cell):
    """Return the column named by ``val_nonlin_chapter4()`` divided by 'amp_B1_01_mean'."""
    numerator = frame_cell[val_nonlin_chapter4()]
    denominator = frame_cell['amp_B1_01_mean']
    return numerator / denominator
def val_nonlin_chapter4():
    """Return the result-column name used as the nonlinearity measure in ``nonlinval_core``."""
    column_name = 'amp_B1+B2_012_mean'
    return column_name
def fin_min_pos(frame_cell, name, x_pos):
    """Return ``frame_cell[name]`` at the row whose contrast ``c1`` is closest to ``x_pos``.

    The nearest row is located by position, so the frame's index labels do not matter.
    """
    distance = np.abs(frame_cell.c1 - x_pos)
    nearest_row = np.argmin(distance)
    return frame_cell[name].iloc[nearest_row]
def plt_scatter_nonlin_all_main():
    """Scatter ROC-area differences and nonlinearity areas against CV and against each other.

    Panel A: AUC difference (with/without female) vs. CV. Panel B: nonlinearity area vs. CV.
    Panel C: AUC difference vs. nonlinearity area, with Pearson correlation label and a
    least-squares line. Data come from the 100-trial ROC contrast table; the CVs come from the
    reference table.
    """
    default_settings(column=2, length=2.3)
    # Available contrast sweeps (20, 100, 1000 trials); the 100-trial one is used.
    frame_names = [
        'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_20_mult_minimum_1temporal',
        'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_100_mult_minimum_1temporal',
        'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_1000_mult_minimum_1temporal']
    path = load_folder_name('calc_ROC') + '/' + frame_names[1] + '.csv'
    path_ref = load_folder_name(
        'calc_ROC') + '/' + 'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_C1_0.02_len_25_nfft_32768_trialsnr_20_mult_minimum_1temporal.csv'
    frame_ref = pd.read_csv(path_ref)
    frame_ref = frame_ref.sort_values(by='cv_0')
    cm = plt.get_cmap("hsv")
    cells_chosen = ['2013-01-08-aa-invivo-1', "2012-06-27-ah-invivo-1",
                    "2014-06-06-ac-invivo-1"]
    colr = [cm(float(i) / (len(frame_ref))) for i in range(len(frame_ref))]
    fig, ax = plt.subplots(1, 3)
    x_pos = 0.02
    cvs, nonlin_area, diff_areas, areas_01_scatter, nonlin, areas_012_one = calc_areas(path, frame_ref, colr, x_pos,
                                                                                         cells_chosen)
    color = 'grey'
    ax[0].scatter(cvs, diff_areas, color=color, s=15, clip_on=False)
    ax[0].axhline(0, linestyle='--', linewidth=0.5, color='grey')
    ax[0].set_xlabel('CV')
    ax[0].set_xlim(0, 1.1)
    ax[0].set_ylabel(core_scatter_wfemale())
    ax[1].scatter(cvs, nonlin_area, color=color, s=15, clip_on=False)
    ax[1].axhline(0, linestyle='--', linewidth=0.5, color='grey')
    ax[1].set_xlabel('CV')
    ax[1].set_xlim(0, 1.1)
    ax[1].set_ylabel(peak_b1b2_name())
    ax[2].set_xlabel(core_scatter_wfemale())
    ax[2].set_ylabel(peak_b1b2_name())
    corr, p_value = stats.pearsonr(nonlin_area, diff_areas)
    label = pearson_label(corr, p_value, nonlin_area, n=True)
    ax[2].text(1, 1.05, label, ha='right', transform=ax[2].transAxes)
    ax[2].scatter(nonlin_area, diff_areas, color=color, s=15, clip_on=False)
    model = LinearRegression()
    model.fit(nonlin_area.reshape((-1, 1)), diff_areas.reshape((-1, 1)))
    # coef_ is (1, 1) and intercept_ is (1,): take scalars so the y-values form a flat list.
    slope = float(np.ravel(model.coef_)[0])
    intercept = float(np.ravel(model.intercept_)[0])
    x_end = np.max(nonlin_area) * 1.05
    ax[2].plot([0, x_end], [intercept, intercept + x_end * slope],
               color='grey', linewidth=0.5)
    make_simple_tags(ax, xpos=-0.07, letters=['A', 'B', 'C'])
    plt.subplots_adjust(wspace=0.85, hspace=0.4, bottom=0.21, right=0.95)
    save_visualization()
    plt.show()
def core_scatter_wfemale():
    """Return the axis label for the AUC difference between the female and no-female conditions.

    A raw string is used because ``\\mathrm`` is a LaTeX command, not a Python escape sequence.
    """
    return r'$\mathrm{AUC_{Female}}-\mathrm{AUC_{NoFemale}}$ '
def peak_b1b2_name():
    """Return the axis label for the nonlinearity peak at the sum frequency.

    A raw string is used because ``\\Delta`` and ``\\mathrm`` are LaTeX commands, not Python
    escape sequences.
    """
    return r'Nonlinearity $A(\Delta \mathrm{f_{Sum}})$ [Hz]'
def core_decline_ROC(trial_nr='20', absolut=True, short=True, pos=True):
    """Return the ROC contrast-sweep file name(s) for the requested variant.

    Parameters
    ----------
    trial_nr : str
        Number of trials, inserted into the name (must be a string).
    absolut : bool
        Select the '_absoluttemporal' files instead of '_mult_minimum_1temporal'.
    short : bool
        Use the 20-step contrast sweep. When False, the 50-step sweep is used, except for the
        absolute/positive combination, which stays at 20 steps.
    pos : bool
        Use the ``b_cond_core()`` diagonal combination instead of the vertical one.

    Returns
    -------
    frame_names : list of str
        A single file name.
    trial_nr : str
        The ``trial_nr`` argument, unchanged.
    """
    lastnr = '0.1'
    combination_pos = b_cond_core() + '_FrF1rel_0.3_FrF2rel_0.7'
    combination_neg = 'vertical1_FrF1rel_1_FrF2rel_0.7'
    combination = combination_pos if pos else combination_neg
    long_sweep = (not short) and ((not absolut) or (not pos))
    len_nrs = '50' if long_sweep else '20'
    suffix = '_absoluttemporal' if absolut else '_mult_minimum_1temporal'
    frame_name = ('calc_ROC_contrasts-ROCmodel_contrasts1_' + combination + '_C2_0.1_LenNrs_' + len_nrs
                  + '_1Nrs_0.0001_LastNrs_' + lastnr + '_len_25_nfft_32768_trialsnr_' + trial_nr + suffix)
    return [frame_name], trial_nr
def comb_talk():
    """Return the combination tag ``'B1+B2_diagonal3'``."""
    tag = 'B1+B2_diagonal3'
    return tag
def comb_talk2():
    """Return the combination tag ``'B1-B2_diagonal3'``."""
    tag = 'B1-B2_diagonal3'
    return tag
def b_cond_core():
    """Return the combination tag used to build the positive ROC file names (``'B1+B2_diagonal3'``)."""
    tag = 'B1+B2_diagonal3'
    return tag
def calc_right_core_nonlin(transient, steps_name, T, steps, results, data_mat, position, dt, add_name=''):
    """Fourier-decompose ``data_mat`` at the frequency 1/T and its first three harmonics.

    The samples ``data_mat[transient:transient + steps]`` (sampling interval ``dt``) are projected
    onto cos/sin at ``(j + 1) / T`` for j = 0..3. The mean, amplitudes and phases are written into
    row ``position`` of ``results``.

    Parameters
    ----------
    transient : int
        Number of leading samples to skip (may be 0).
    steps_name : str
        Suffix for the result column names.
    T : float
        Period of the fundamental in seconds.
    steps : int
        Number of samples analysed after the transient.
    results : pandas.DataFrame
        Table receiving the columns 'c_<k><steps_name><add_name>',
        'phi_<k><steps_name><add_name>' (k = 0..3) and 'c0<steps_name><add_name>'.
    data_mat : array
        1-D signal.
    position : hashable
        Row label in ``results``.
    dt : float
        Sampling interval in seconds.
    add_name : str
        Extra column-name suffix.

    Returns
    -------
    pandas.DataFrame
        ``results`` with the new values.
    """
    n_harmonics = 4
    time_here_all = np.arange(transient + steps) * dt
    # Same analysis window for the mean and for the Fourier coefficients (after the transient).
    window = slice(transient, len(time_here_all))
    c0 = np.mean(data_mat[window])
    a_tmp = np.zeros(n_harmonics)
    b_tmp = np.zeros(n_harmonics)
    for j in range(n_harmonics):
        freq_here = (j + 1) / T
        phase = 2.0 * np.pi * time_here_all[window] * freq_here
        try:
            a_tmp[j] = np.mean(data_mat[window] * np.cos(phase))
        except Exception:
            # Typically a length mismatch between data_mat and the analysis window.
            print('a tmp problems')
            embed()
        b_tmp[j] = np.mean(data_mat[window] * np.sin(phase))
    a = a_tmp * 2.0
    b = b_tmp * 2.0
    c = np.sqrt(a ** 2 + b ** 2)
    # NOTE(review): arctan(a / b) folds the phase into (-pi/2, pi/2); kept for compatibility.
    phi = np.arctan(a / b)
    for c_nr in range(n_harmonics):
        try:
            results.loc[position, 'c_' + str(c_nr) + steps_name + add_name] = c[c_nr]
        except Exception:
            print('c problem')
            embed()
    for c_nr in range(n_harmonics):
        results.loc[position, 'phi_' + str(c_nr) + steps_name + add_name] = phi[c_nr]
    results.loc[position, 'c0' + steps_name + add_name] = c0
    return results
def average_fft(a_tmp, b_tmp, c, c0_tmp, phi, s, steps=1):
    """Normalise accumulated Fourier sums and store amplitudes/phases of four harmonics in row ``s``.

    ``c0_tmp`` is divided by ``steps``; the cos/sin sums ``a_tmp``/``b_tmp`` are scaled by
    ``2 / steps``. For k = 0..3, ``c[s, k]`` gets the amplitude ``sqrt(a**2 + b**2)`` and
    ``phi[s, k]`` gets ``arctan(a / b)``. ``c`` and ``phi`` are modified in place and returned.

    Returns
    -------
    (c0, c, phi)
    """
    norm = 1.0 * steps
    mean_value = c0_tmp / norm
    cos_coeffs = a_tmp * 2.0 / norm
    sin_coeffs = b_tmp * 2.0 / norm
    for harmonic in range(4):
        c[s, harmonic] = np.sqrt(cos_coeffs[harmonic] ** 2 + sin_coeffs[harmonic] ** 2)
        phi[s, harmonic] = np.arctan(cos_coeffs[harmonic] / sin_coeffs[harmonic])
    return mean_value, c, phi
def load_savedir(level=0, individual_tag='', frame=[], save=False, csv=False, pkl=False, emb=False):
    """Build a save-path prefix from the call stack, optionally writing ``frame`` under it.

    The prefix starts with the name of the outermost script (``initial_function``) and, depending
    on ``level``, the name of a calling function (``last_function``):

    * level 0: ``<initial>/<last>-``  (file-name prefix inside the script folder; rarely used)
    * level 1: ``<initial>/``
    * level 2: ``<initial>/<last>/``
    * level 3: ``<initial>/<last>/cells/`` (the directory is created if missing)

    Level 0 can serve as an overview file, level 1 as the varying file, and an example file can
    be placed on level 1 as a version overview.

    Parameters
    ----------
    individual_tag : str
        File name (without extension) used when ``save`` is True.
    frame : pandas.DataFrame
        Written as ``<prefix><individual_tag>.csv`` / ``.pkl`` when ``save`` and ``csv`` / ``pkl``.
    emb : bool
        Open an IPython shell for debugging.

    Returns
    -------
    str
        The save prefix.

    Raises
    ------
    ValueError
        If ``level`` is not 0, 1, 2 or 3.
    """
    import time

    stack = inspect.stack()
    if 'miniconda3' in stack[1][1]:
        # Launched through a conda-managed runner: the script frame sits deeper in the stack.
        initial_function = stack[-16][1].split('/')[-1].split('.')[0]
        last_function = stack[-16][4][0].split('(')[0].strip()
    else:
        # NOTE(review): splitting on '\\' only strips directories from Windows paths;
        # on POSIX the full path (minus extension) is kept.
        initial_function = stack[-1][1].split('\\')[-1].split('.')[0]
        last_function = stack[-2][3]  # function name of the second-outermost frame
    # Drop frame references early (inspect docs: avoid reference cycles through frames).
    del stack
    if emb:
        embed()
    print(initial_function)
    t1 = time.time()
    data_extra_fold = ''  # e.g. '_data'
    base = initial_function + data_extra_fold + '/'
    if level == 0:
        save_name = base + last_function + '-'
    elif level == 1:
        save_name = base
    elif level == 2:
        # Basic functions get their own folder inside the script folder.
        save_name = base + last_function + '/'
    elif level == 3:
        # Per-cell / per-version files go one folder deeper.
        cells_dir = base + last_function + '/cells'
        os.makedirs(cells_dir, exist_ok=True)
        save_name = cells_dir + '/'
    else:
        raise ValueError(f'level must be 0, 1, 2 or 3, got {level!r}')
    try:
        if save:
            if csv:
                frame.to_csv(save_name + individual_tag + '.csv', index=False)
            if pkl:
                frame.to_pickle(save_name + individual_tag + '.pkl')
    except Exception:
        print('save problem')
        embed()
    t2 = time.time() - t1
    print(f'save time {t2}')
    return save_name
def redo_on_cell_level(redo_level, append_cells, redo, beat_results, cell, counter_continued, cell_name='dataset',
                       range_orig1=[], range_orig2=[]):
    """Decide whether ``cell`` has to be (re)computed when appending to existing results.

    The check is only active if ``redo_level`` contains 'celllevel' and ``append_cells`` is True
    while ``redo`` is False; in every other case the cell is computed.

    Parameters
    ----------
    redo_level : str
        Mode string. 'celllevel' enables the check. If 'clusters' is present as well, a cell that
        is already stored is recomputed unless the size of its stored f1 x f2 grid (after
        discounting 10 Hz controls) equals ``len(range_orig1) * len(range_orig2)``. Without
        'clusters', a cell that is already stored is skipped.
    append_cells, redo : bool
        Whether existing results are being extended, and whether everything is redone.
    beat_results : pandas.DataFrame
        Existing results; needs the column ``cell_name`` and, in 'clusters' mode, 'f1' and 'f2'.
    cell : str
        Cell identifier to check.
    counter_continued : int
        Running count of skipped cells; incremented for every skipped cell and returned.
    cell_name : str
        Column of ``beat_results`` holding the cell identifier.
    range_orig1, range_orig2 : sequence
        Requested f1 and f2 values (only used in 'clusters' mode).

    Returns
    -------
    do_thiscell, do_thiseod : bool
        Whether to compute this cell (both flags are always set to the same value here).
    counter_continued : int
    combs : ndarray or list
        Unique (f2, f1) rows already stored for the cell in 'clusters' mode, otherwise ``[]``.
    """
    # do_thiseod - if true, compute this frequency anew
    combs = []
    if 'celllevel' in redo_level:
        # decide if cell in sample or not
        if (append_cells == True) and (redo == False):
            if cell in np.unique(beat_results[cell_name]):
                if 'clusters' in redo_level:
                    f1_present = np.unique(beat_results[beat_results[cell_name] == cell].f1)  # ].f1
                    f2_present = np.unique(beat_results[beat_results[cell_name] == cell].f2)  # ].f1
                    len_required = len(range_orig1) * len(range_orig2)
                    len_present = len(f1_present) * len(f2_present)
                    # Subtract the 10 Hz control runs: first check whether 10 Hz really is only
                    # a control (i.e. not part of the requested ranges) or not.
                    # NOTE(review): only this first test converts the ranges with int(); the
                    # elif tests compare the raw values.
                    if (10 not in list(map(int, range_orig1))) & (10 not in list(map(int, range_orig2))):
                        len_remaining = len_present - (len(f1_present) + len(f2_present))
                        beat_cell = beat_results[beat_results[cell_name] == cell]
                        beat_corrected = beat_cell[(beat_cell['f2'] != 10) & (beat_cell['f1'] != 10)]
                        combs_all = beat_corrected[['f2', 'f1']]
                    elif 10 not in range_orig1:
                        len_remaining = len_present - (len(f1_present))
                        beat_cell = beat_results[beat_results[cell_name] == cell]
                        beat_corrected = beat_cell[beat_cell['f1'] != 10]
                        combs_all = beat_corrected[['f2', 'f1']]
                    elif 10 not in range_orig2:
                        len_remaining = len_present - (len(f2_present))
                        beat_cell = beat_results[beat_results[cell_name] == cell]
                        beat_corrected = beat_cell[beat_cell['f2'] != 10]
                        combs_all = beat_corrected[['f2', 'f1']]
                    else:
                        len_remaining = len_present
                        combs_all = beat_results[beat_results[cell_name] == cell][['f2', 'f1']]
                    combs = np.unique(combs_all, axis=0)
                    # Incomplete grid -> recompute; complete grid -> skip and count it.
                    if len_remaining != len_required:
                        do_thiscell = True
                        do_thiseod = True
                    else:
                        do_thiscell = False
                        do_thiseod = False
                        counter_continued += 1
                        print('Nr ' + str(counter_continued) + 'already there')
                else:
                    do_thiscell = False
                    do_thiseod = False
                    counter_continued += 1
                    print('Nr ' + str(counter_continued) + 'already there')
            else:
                do_thiscell = True
                do_thiseod = True
        # just do all cells irrespective if they are in any files or not
        else:
            do_thiscell = True
            do_thiseod = True
    else:
        do_thiscell = True
        do_thiseod = True
    return do_thiscell, do_thiseod, counter_continued, combs
def redo_or_append(save_name, redo=False, name_orig=[]):
    """Load previously saved results to append to, or start a fresh result table.

    Parameters
    ----------
    save_name : str
        Path prefix; ``save_name + '.pkl'`` is read when ``name_orig`` is empty.
    redo : bool
        If True, ignore any existing file and start from scratch.
    name_orig : str
        Explicit pickle path that overrides ``save_name + '.pkl'``.

    Returns
    -------
    append_cells : bool
        True only if an existing, non-empty result table was loaded.
    preallocated : pandas.DataFrame
        The loaded results, or an empty DataFrame.
    position : int
        Next free row position (``len(preallocated)``).
    """
    if redo == False:  # noqa: E712 - explicit comparison kept for non-bool callers
        path = save_name + '.pkl' if len(name_orig) < 1 else name_orig
        if os.path.exists(path):
            preallocated = pd.read_pickle(path)
            # An empty table is treated like a missing file in both branches: appending to it
            # would make downstream per-cell lookups fail on missing columns.
            if len(preallocated) > 0:
                return True, preallocated, len(preallocated)
    # Redo requested, file missing, or file empty: start from an empty table.
    return False, pd.DataFrame(), 0
def calc_nonlinearity_contrasts(transient_s=0, cells=[], n=1, adapt_offset='adaptoffsetallall2', stimulus_length_orig=2,
                                freq_type='_beat_', single_train='', fft='fft', dev='original', trials_nr=150,
                                zeros='zeros', a_f1s=[0, 0.005, 0.01, 0.05, 0.1, 0.2, ], a_frs=[1], add_half=0,
                                nfft=int(2 ** 15), beat='', nfft_for_morph=4096 * 4, gain=1, fish_jammer='Alepto',
                                redo_level='celllevel', us_name='', adapt_type=''):
    """Simulate model cells over a set of emitter contrasts and collect nonlinearity measures.

    For every cell (default: the cells of ``cv.pkl``, sorted by 'cv'):

    1. read its parameters from ``models_big_fit_d_right.csv``;
    2. skip it if ``redo_on_cell_level`` and the stored row count say it is complete;
    3. simulate ``trials_nr_base`` baseline trials with offset adaptation, which yields the adapted
       offset and the baseline rate ``fr``;
    4. for every emitter frequency from ``get_freqs_contrasts`` and every contrast in ``a_f1s``
       not yet in the results, call ``calc_single_af_nonlin``. Results are presumably stored
       under ``save_name`` there — confirm.

    Existing results are loaded via ``redo_or_append``.
    """
    stimulus_length = stimulus_length_orig
    model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv")
    if len(cells) < 1:
        frame = pd.read_pickle('cv.pkl')
        frame = frame.sort_values(by='cv')
        cells = frame.cell
    freqs2 = [0]  # jammer frequencies (scalars)
    single_waves = ['_SingleWave_']
    a_f2s = [np.max(a_f1s)]  # jammer contrast = largest emitter contrast
    for a_fr in a_frs:
        for _ in single_waves:
            for a_f2 in a_f2s:
                try:
                    save_name = save_name_nonlinearity(add_half, a_f2s=a_f2s, freqs2=freqs2, a_f1_end=a_f1s[-1],
                                                       transient_s=transient_s, n=n, adapt_offset=adapt_offset,
                                                       freq_type=freq_type, adapt=adapt_type,
                                                       stimulus_length=stimulus_length, fft=fft, dev=dev,
                                                       trials_nr=trials_nr, a_fr=a_fr, zeros=zeros)
                except Exception:
                    print('save name problem')
                    embed()
                print(save_name)
                redo = False
                save_dir = load_savedir(level=1)
                append_cells, results, position = redo_or_append(
                    save_dir + 'modell_all_cell_' + save_name, redo=redo, name_orig=save_name)
                counter_continued = 0
                for cell in cells:
                    try:
                        model_cells_here = model_cells[model_cells['cell'] == cell]
                        model_params = model_cells_here.iloc[0]
                    except Exception:
                        print('single positional index doesnt exists')
                        embed()
                    eod_fr = model_params['EODf']
                    # Remove v_offset and cell; the remaining entries go to simulate() as kwargs.
                    offset = model_params.pop('v_offset')
                    cell = model_params.pop('cell')
                    print(cell)
                    do_this_cell_orig, do_thiseod, counter_continued, combs_all = redo_on_cell_level(
                        redo_level, append_cells, redo, results, cell, counter_continued, cell_name='cell')
                    if type(add_half) == str:
                        freqs1_len = freqs_array(add_half, eod_fr)
                    else:
                        freqs1_len = [0]
                    if do_this_cell_orig == False:
                        # Cell already stored: redo only if some (frequency, contrast) rows are missing.
                        results_cell = results[results['cell'] == cell]
                        do_this_cell_now = len(results_cell) != len(a_f1s) * len(freqs1_len)
                    else:
                        do_this_cell_now = True
                    if not do_this_cell_now:
                        continue

                    # Fixed simulation settings.
                    f1 = 0
                    f2 = 0
                    sampling_factor = ''
                    phaseshift_fr = 0
                    cell_recording = ''
                    mimick = 'no'
                    fish_morph_harmonics_var = 'harmonic'
                    fish_emitter = 'Alepto'
                    fish_receiver = 'Alepto'
                    phase_right = '_phaseright_'
                    constant_reduction = ''
                    lower_tol = 0.995
                    upper_tol = 1.005
                    SAM = ''
                    damping = 0.45
                    damping_type = ''
                    exponential = ''
                    dent_tau_change = 1
                    # in case you want a different sampling here we can adjust
                    time_array, sampling, deltat = deltat_choice(model_params, sampling_factor, eod_fr,
                                                                 stimulus_length)
                    # generate eod_fish_r in one of the mimick variants (copy, thunderfish, mimick, sinus)
                    eod_fish_r, deltat, eod_fr, time_array = eod_fish_r_generation(time_array, eod_fr, a_fr,
                                                                                   stimulus_length, phaseshift_fr,
                                                                                   cell_recording, zeros, mimick,
                                                                                   sampling, fish_receiver, deltat,
                                                                                   nfft, nfft_for_morph,
                                                                                   fish_morph_harmonics_var=fish_morph_harmonics_var,
                                                                                   beat=beat)
                    sampling = 1 / deltat
                    slope = 0
                    add = 0
                    plus = 0
                    if exponential == '':
                        v_exp = 1
                        exp_tau = 0.001
                    # prepare for adapting the offset due to baseline modification
                    _, _ = prepare_baseline_array(time_array, eod_fr,
                                                  nfft_for_morph,
                                                  phaseshift_fr,
                                                  mimick, zeros,
                                                  cell_recording,
                                                  sampling,
                                                  stimulus_length,
                                                  fish_receiver,
                                                  deltat, nfft,
                                                  damping_type,
                                                  damping, us_name,
                                                  gain, beat=beat,
                                                  fish_morph_harmonics_var=fish_morph_harmonics_var)
                    # Baseline trials; each one continues from the previous trial's adapted offset.
                    trials_nr_base = 10
                    spike_adapted = [[]] * trials_nr_base
                    for t in range(trials_nr_base):
                        if a_fr == 0:
                            power_here = 'sinz' + '_' + zeros
                        else:
                            power_here = 'sinz'
                        # NOTE(review): ``alpha`` resolves to the module-level name (scipy.stats.alpha
                        # is imported at file top) — confirm this is what simulate() expects.
                        cvs, adapt_output, baseline_after_b, _, rate_adapted_b, rate_baseline_before_b, rate_baseline_after_b, \
                            spike_adapted[t], _, _, offset_new, _, noise_final = simulate(cell, offset, eod_fish_r,
                                                                                          deltat=deltat,
                                                                                          adaptation_variant=adapt_offset,
                                                                                          adaptation_yes_j=f2,
                                                                                          adaptation_yes_e=f1,
                                                                                          adaptation_yes_t=t,
                                                                                          power_variant=power_here,
                                                                                          power_alpha=alpha, power_nr=n,
                                                                                          **model_params)
                        if t == 0:
                            print('first Baseline ' + str(rate_baseline_before_b))
                        offset = offset_new * 1
                    # Baseline characteristics
                    mean_isi, std_isi, fr, isi, cv0, ser0, ser_first0, ser_sum = calc_baseline_char(spike_adapted,
                                                                                                    stimulus_length,
                                                                                                    trials_nr_base)
                    print('fr base ' + str(fr))
                    freqs1 = get_freqs_contrasts(a_fr, add_half, eod_fr, fr)
                    sampling_rate = 1 / deltat
                    for freq1_value in freqs1:
                        freq1 = [freq1_value]
                        for freq2_value in freqs2:
                            # freqs2 holds scalars; downstream code expects a one-element list
                            # (indexing the scalar, as before, raised TypeError).
                            freq2 = [freq2_value]
                            if 'f1' in results.keys():
                                results_f = results[results['f1'] == freq1[0]]
                            else:
                                results_f = results
                            for aa, a_f1 in enumerate(a_f1s):
                                if do_this_cell_orig:
                                    do_af1 = True
                                else:
                                    do_af1 = np.round(a_f1, 6) not in np.array(np.round(results_f.a_f1, 6))
                                print('f1_' + str(freq1) + '_af1_' + str(a_f1))
                                if do_af1:
                                    results, position = calc_single_af_nonlin(transient_s, freq_type, ser0,
                                                                              single_train, dev, fft, a_fr,
                                                                              trials_nr, results, nfft,
                                                                              damping_type, gain,
                                                                              save_name, cvs, position, cv0, fr,
                                                                              cell,
                                                                              sampling_rate, model_params, n,
                                                                              dent_tau_change, constant_reduction,
                                                                              exponential, plus,
                                                                              slope, add,
                                                                              deltat, alpha,
                                                                              lower_tol,
                                                                              upper_tol,
                                                                              v_exp,
                                                                              exp_tau, f2,
                                                                              fish_jammer,
                                                                              freq2,
                                                                              damping, us_name, a_f2, eod_fish_r,
                                                                              SAM,
                                                                              aa, offset, freq1, eod_fr,
                                                                              phase_right,
                                                                              a_f1, phaseshift_fr, nfft_for_morph,
                                                                              cell_recording,
                                                                              fish_morph_harmonics_var, time_array,
                                                                              mimick,
                                                                              fish_emitter,
                                                                              f1, sampling, stimulus_length,
                                                                              adapt_type=adapt_type)
def get_freqs_contrasts(a_fr, add_half, eod_fr, fr):
    """Return the list of emitter frequencies to simulate.

    Parameters
    ----------
    a_fr : float
        Receiver amplitude. Only when it equals 1 are frequencies placed relative to ``eod_fr``.
    add_half : float or str
        Offset added to half the baseline rate, or a 'frange_from_<a>_to_<b>_in_<step>' spec
        for a frequency scan (only used when ``a_fr == 1``).
    eod_fr : float
        Receiver EOD frequency.
    fr : float
        Baseline firing rate.

    Returns
    -------
    list or ndarray
        ``freqs_array(add_half, eod_fr)`` for a scan spec, ``[eod_fr - (fr / 2 + add_half)]`` for
        ``a_fr == 1``, otherwise ``[fr / 2 + add_half]``.
    """
    if a_fr != 1:
        return [fr / 2 + add_half]
    if type(add_half) is str:
        # frequency scan
        return freqs_array(add_half, eod_fr)
    beat = fr / 2 + add_half
    return [eod_fr - beat]
def freqs_array(add_half, eod_fr):
    """Expand a 'frange_from_<a>_to_<b>_in_<step>' spec into absolute frequencies.

    Returns ``np.arange(eod_fr + a, eod_fr + b, step)``; for example,
    'frange_from_0_to_400_in_1' yields eod_fr, eod_fr + 1, ..., eod_fr + 399.
    """
    lower = float(add_half.split('frange_from_')[1].split('_to')[0])
    upper = float(add_half.split('to_')[1].split('_in')[0])
    step = float(add_half.split('in_')[1].split('_')[0])
    return np.arange(eod_fr + lower, eod_fr + upper, step)
def calc_single_af_nonlin(transient_s, freq_type, ser0, single_train, dev, fft, a_fr, trials_nr, results, nfft,
damping_type, gain, save_name, cvs, position, cv0, fr, cell, sampling_rate, model_params, n,
dent_tau_change, constant_reduction, exponential, plus,
slope, add,
deltat, alpha,
lower_tol,
upper_tol,
v_exp,
exp_tau, f2,
fish_jammer, freq2, damping, us_name, a_f2, eod_fish_r, SAM, aa,
offset, freq1, eod_fr, phase_right, a_f1, phaseshift_fr, nfft_for_morph,
cell_recording,
fish_morph_harmonics_var, time_array, mimick,
fish_emitter,
f1, sampling, stimulus_length, adapt_type=''):
print('af1_nr ' + str(aa) + ' offset ' + str(offset))
beat1 = freq1 - eod_fr
phaseshift_f1, phaseshift_f2 = get_phaseshifts(a_f1, a_f2, phase_right, phaseshift_fr)
eod_fish1, time_fish_e = eod_fish_e_generation(time_array, a_f1, freq1, f1, phaseshift_f1, sampling,
stimulus_length, nfft_for_morph, cell_recording,
fish_morph_harmonics_var, 'zeros', mimick, fish_emitter,
thistype='emitter')
eod_fish2, time_fish_j = eod_fish_e_generation(time_array, a_f2, freq2, f2, phaseshift_f2, sampling,
stimulus_length, nfft_for_morph, cell_recording,
fish_morph_harmonics_var, 'zeros', mimick, fish_jammer,
thistype='jammer')
eod_stimulus = eod_fish1 + eod_fish2
t1 = time.time()
spikes = [[]] * trials_nr
for t in range(trials_nr):
stimulus, eod_fish_sam = create_stimulus_SAM(SAM, eod_stimulus, eod_fish_r, freq1, f1,
eod_fr,
time_array, a_f1, eod_fj=freq2, j=f2,
a_fj=a_f2)
# damping variants
std_dump, max_dump, range_dump, stimulus, damping_output = all_damping_variants(stimulus, time_array,
damping_type, eod_fr, gain,
damping, us_name, plot=False,
std_dump=0, max_dump=0,
range_dump=0)
# WIR BRAUCHEN KEIN ADAPT HIER DAS HABEN WIR SCHON BEI DER BASELIN GEMACHT
adapt_offset_here = 'no'
if a_fr == 0:
power_here = 'sinz' + '_' + zeros # not found
else:
power_here = 'sinz'
_, adapt_output, baseline_after, _, _, _, \
_, spikes[t], \
_, \
_, offset_new, _, noise_final = simulate(cell, offset, stimulus, deltat=deltat,
adaptation_variant=adapt_offset_here, adaptation_yes_j=f2,
adaptation_yes_e=f1, adaptation_yes_t=t,
adaptation_upper_tol=upper_tol, adaptation_lower_tol=lower_tol,
power_variant=power_here, power_alpha=alpha, power_nr=n,
tau_change_choice=constant_reduction, tau_change_val=dent_tau_change,
sigmoidal_mult=1, sigmoidal_plus=plus, sigmoidal_slope=slope,
sigmoidal_add=add, LIF_adapt_type=adapt_type,
LIF_exponential=exponential, LIF_exponential_tau=exp_tau,
LIF_expontential__v=v_exp, **model_params)
mean_isi, std_isi, fr1, isi, cv1, ser1, ser_first_stim, ser_sum_stim = calc_baseline_char(spikes, stimulus_length,
trials_nr)
print('fr stim' + str(fr1))
# hier noch das psd einer gemittelten rate
##################
# hier das mittel der psds
t2 = time.time()
print('model' + str(t2 - t1))
if fft == 'psd':
spikes_mat = [[]] * len(spikes)
pp = [[]] * len(spikes)
for s in range(len(spikes)):
spikes_mat[s] = cr_spikes_mat(spikes[s], 1 / deltat,
int(stimulus_length * 1 / deltat))
for s in range(len(spikes)):
pp[s], f = ml.psd(spikes_mat[s] - np.mean(spikes_mat[s]), Fs=1 / deltat,
NFFT=nfft,
noverlap=nfft // 2)
pp_mean = np.mean(pp, axis=0)
names = peaks_1d(fr, a_fr, beat1, freq1)
results = find_peaks_power(names, f, pp_mean, '_original', position, results)
t2 = time.time()
print('psd' + str(t2 - t1))
elif 'fft' in fft:
test = False
if test:
from utils_test import test_spikes
test_spikes()
spikes_here = np.concatenate(spikes)
mat_orig, time2, samp2 = calc_spikes_hist(trials_nr, stimulus_length,
spikes_here,
deltat=deltat)
mat_orig2, time2, samp2 = calc_spikes_hist(trials_nr, stimulus_length, spikes_here, deltat=1 / 500)
mat_orig_eod, time_eod, samp_eod = calc_spikes_hist(trials_nr, stimulus_length, spikes_here, deltat=1 / eod_fr)
mat_orig_02, time02, samp02 = calc_spikes_hist(trials_nr, stimulus_length, spikes_here, deltat=1 / 2000)
mat_orig_05, time05, samp05 = calc_spikes_hist(trials_nr, stimulus_length, spikes_here, deltat=1 / 5000)
sampling_rates = [sampling_rate]
if single_train == '_singletrain_':
if dev == 'original':
mats = [spikes_mat[0]]
mat_names = ['']
else:
smoothed05 = gaussian_filter(spikes_mat, sigma=0.0005 * sampling_rate)
mats = [smoothed05[0]]
mat_names = ['']
else:
if dev == 'original':
several = False
if several == 'several':
mats = [mat_orig2, mat_orig_eod, mat_orig_02, mat_orig_05, mat_orig]
mat_names = ['_500', '_eodfr', '_2000', '_5000', '_dt']
test = False
if test:
from utils_test import test_mean
test_mean()
sampling_rates = [500, eod_fr, 2000, 5000, sampling_rate]
else:
mats = [mat_orig]
mat_names = ['_dt']
sampling_rates = [sampling_rate]
else:
smoothed05 = gaussian_filter(spikes_mat, sigma=0.0005 * sampling_rate) # not there
mat = np.mean(smoothed05, axis=0)
mats = [mat]
mat_names = ['']
for m, mat_use in enumerate(mats):
dt = 1 / sampling_rates[m]
transient = int(transient_s / dt)
if 'fr' in freq_type:
try:
freq_type_var = 0.5 * fr
except:
print('fr0 problem')
embed()
elif 'beat' in freq_type:
freq_type_var = np.abs(beat1[0])
T = 1 / freq_type_var ## 1 Period of the external signal (1/f)
stimulus_times = int(
(len(mat_use) - transient) * dt / T) # np.arange(1, int(stimulus_length / (1 / (0.5 * fr))), 25)
try:
maximal_period = int((T / dt) * stimulus_times) # Mascha hat 476 steps
except:
print('fr1 problem')
embed()
single_period = int(np.round(T / dt)) # single_period = int(T / dt+0.5)
steps_all = [maximal_period] # , single_period]
steps_name = ['_all'] # , '_one']
for s, steps in enumerate(steps_all):
results = calc_right_core_nonlin(transient, steps_name[s], T, steps, results, mat_use, position, dt,
add_name=mat_names[m])
subsequent = ''
if 'subsequent' in subsequent:
#################################################
# for subseuqent steps
## das ist die fft berechnung!
get_cfs2(T, dt, m, mat_names, mat_use, maximal_period, position, results, single_period)
results.loc[position, 'a_fr'] = a_fr
results.loc[position, 'a_f2'] = a_f2
results.loc[position, 'a_f1'] = a_f1
results.loc[position, 'f1'] = freq1[0]
results.loc[position, 'f2'] = freq2[0]
results.loc[position, 'cell'] = cell
results.loc[position, 'max_adapt'] = np.nanmin(adapt_output)
results.loc[position, 'min_adapt'] = np.nanmax(adapt_output)
results.loc[position, 'fr'] = fr
results.loc[position, 'ser'] = ser0
results.loc[position, 'cv'] = cv0
results.loc[position, 'fr_stim'] = fr1
results.loc[position, 'ser_stim'] = ser1
results.loc[position, 'ser_sum_stim'] = ser_sum_stim
results.loc[position, 'ser_first_stim'] = ser_first_stim
results.loc[position, 'cv_stim'] = cv1
results.loc[position, 'eod_fr'] = eod_fr # 500 *30
for cv_all in cvs:
if 'cv' in cv_all:
try:
results.loc[position, cv_all] = cvs[cv_all]
except:
print('cv something')
embed()
position += 1
results.to_pickle(save_name)
return results, position
def get_cfs2(T, dt, m, mat_names, mat_use, maximal_period, position, results, single_period):
    """Estimate Fourier coefficients of ``mat_use`` window by window and store their means.

    ``mat_use`` is cut into consecutive windows that start at
    ``np.arange(0, maximal_period, single_period)``. For each pair of consecutive
    starts, the window mean and the cosine/sine projections onto the fundamental
    (frequency ``1 / T``) and its next three harmonics are computed. These are
    handed to ``average_fft``, which returns updated ``c0``, ``c`` and ``phi``.
    Afterwards, the column means of ``c`` and ``phi`` and the ``c0`` from the final
    ``average_fft`` call are written into row ``position`` of ``results``.
    ``results`` is modified in place; nothing is returned.

    Parameters
    ----------
    T : float
        Period of the reference signal in seconds.
    dt : float
        Sampling interval of ``mat_use`` in seconds.
    m : int
        Index into ``mat_names``; selects the suffix of the result column names.
    mat_names : sequence of str
        Column-name suffixes.
    mat_use : array-like
        Presumably a 1-D rate/histogram signal -- TODO confirm against callers.
    maximal_period : int
        Number of samples to analyse.
    position : hashable
        Row label in ``results`` that is written to.
    results : pandas.DataFrame
        Receives the columns ``c_<k>_mean<suffix>``, ``phi_<k>_mean<suffix>`` and
        ``c0_mean<suffix>``.
    single_period : int
        Number of samples per window (one period of ``T``).

    NOTE(review): if ``steps_all`` has fewer than two entries, the loop never runs.
    ``c0`` is then unbound at the last assignment (NameError).
    NOTE(review): the loop fills rows ``0 .. len(steps_all) - 2`` only. Unless
    ``average_fft`` also writes the last row, the trailing all-zero row of
    ``c``/``phi`` is included in the means -- confirm this is intended.
    """
    steps_all = np.arange(0, maximal_period, single_period)
    c = np.zeros([len(steps_all), 4])
    phi = np.zeros([len(steps_all), 4])
    for s in range(len(steps_all) - 1):
        c0_tmp = 0.0
        a_tmp = np.zeros(4)
        b_tmp = np.zeros(4)
        # Calculate a and b for the fundamental and higher harmonics
        for j in range(4):
            if j < 1:
                # window mean (DC term); only needed once per window
                c0_tmp = np.mean(mat_use[steps_all[s]:steps_all[s + 1]])
            # harmonic j + 1 of the fundamental frequency 1 / T
            freq_here = (j + 1) / T
            # sample times inside one window, starting at 0
            time_here_all = np.arange(0, single_period, 1) * dt
            try:
                a_tmp[j] = np.mean(
                    mat_use[steps_all[s]:steps_all[s + 1]] * np.cos(2.0 * np.pi * time_here_all * freq_here))
            except:
                print('a tmp problems')
                embed()
            b_tmp[j] = np.mean(
                mat_use[steps_all[s]:steps_all[s + 1]] * np.sin(2.0 * np.pi * time_here_all * freq_here))
        # Average
        c0, c, phi = average_fft(a_tmp, b_tmp, c, c0_tmp, phi, s)
    row, col = np.shape(c)
    for c_nr in range(col):
        results.loc[position, 'c_' + str(c_nr) + '_mean' + mat_names[m]] = np.mean(c[:, c_nr])
    for c_nr in range(col):
        results.loc[position, 'phi_' + str(c_nr) + '_mean' + mat_names[m]] = np.mean(phi[:, c_nr])
    results.loc[position, 'c0' + '_mean' + mat_names[m]] = c0
def arg_left_corr(arg, right_step=3, left_step=2):
    """Return the half-open index window ``[left, right)`` around ``arg``.

    The window nominally spans ``arg - left_step`` to ``arg + right_step``. If the
    left edge would be negative, the whole window is shifted right so that it
    starts at 0 and keeps its width.
    """
    left = arg - left_step
    right = arg + right_step
    # amount by which the window has to move right to stay inside the array
    overshoot = -left if left < 0 else 0
    return left + overshoot, right + overshoot
def find_peaks_power(names, f, pp, title, position, results, a_start='a_', f_start='f_', points=5):  #
    """Store spectral peak amplitudes at the frequencies in ``names`` into ``results``.

    For every key ``name`` of ``names``, the bin of ``f`` closest to ``names[name]``
    is located. Two values are written into row ``position``:
    ``sqrt(sum(pp[window] * df))`` into column ``a_start + name + title``, and
    the target frequency into column ``f_start + name + title``. Here ``df`` is
    the bin width ``|f[1] - f[0]|``.

    ``points`` selects the window: 5 bins (``[k-2, k+3)``), 3 bins
    (``[k-1, k+2)``) or the single bin ``k``. Windows are shifted right at the
    lower array edge by ``arg_left_corr``. For any other ``points`` value nothing
    is written.

    Returns the (copied) DataFrame. If nothing was written, the input object
    itself is returned.
    """
    # bins in the window -> (right_step, left_step) for arg_left_corr
    window_steps = {5: (3, 2), 3: (2, 1)}
    for name in names:
        peak_freq = names[name]
        closest = np.argmin(np.abs(f - peak_freq))
        if points not in (1, 3, 5):
            continue
        bin_width = np.abs(f[1] - f[0])
        if points == 1:
            amplitude = np.sqrt(pp[closest] * bin_width)
        else:
            right_step, left_step = window_steps[points]
            lo, hi = arg_left_corr(closest, right_step=right_step, left_step=left_step)
            amplitude = np.sqrt(np.sum((pp[lo:hi]) * bin_width))
        results = results.copy()
        results.loc[position, a_start + name + title] = amplitude
        results.loc[position, f_start + name + title] = peak_freq
    return results
def plot_stimulus(ax, time_ms, beat_here, am_corr_synth, color, ylim=(-2.5, 4)):
    """Draw a beat trace (thin grey) with its amplitude modulation on top.

    Parameters
    ----------
    ax : Axes
        Must provide ``show_spines`` (presumably the project's plotstyle axes
        extension -- TODO confirm).
    time_ms : array-like
        Time axis in milliseconds; its last value sets the right x limit.
    beat_here, am_corr_synth : array-like
        Beat signal and amplitude-modulation trace, sampled on ``time_ms``.
    color : color
        Colour of the amplitude-modulation trace.
    ylim : pair of float
        Y limits. A tuple rather than a list, so the default is not a shared
        mutable object.
    """
    ax.show_spines('')
    ax.plot(time_ms, beat_here, color='grey', linewidth=0.5)
    ax.plot(time_ms, am_corr_synth, color=color)
    ax.set_xlim(0, time_ms[-1])
    ax.set_ylim(ylim)
def plot_raster(ax, all_spikes, color, i, plot_segment, name='raster'):
    """Draw a spike raster of ``all_spikes`` clipped to ``[0, plot_segment]``.

    A vertical ``name`` label is added left of the axis when ``i`` is a
    multiple of 3 (presumably the left column of a three-column layout).
    """
    raster_style = dict(orientation='horizontal', linelengths=0.8, linewidths=1, colors=[color])
    ax.eventplot(all_spikes, **raster_style)
    ax.show_spines('')
    ax.set_xlim(0, plot_segment)
    if i % 3 != 0:
        return
    ax.text(-0.05, 0.6, name, transform=ax.transAxes, rotation=90, va='center', ha='right')
def plot_peri(ax, smoothed, sampling_rate, lw_beat_corr, color, i, name='rate'):
    """Plot the trial-averaged rate ``mean(smoothed, axis=0)`` against time in ms.

    The time axis is derived from the length of the last row of ``smoothed`` and
    ``sampling_rate``. The y range is fixed to ``[-10, 850]``. Panels with
    ``i % 3 == 0`` get a vertical ``name`` label; panels with ``i % 3 == 2`` get
    scale bars (10 ms, 800 s^-1).
    """
    time_s = np.arange(len(smoothed[-1])) / sampling_rate
    trial_mean = np.mean(smoothed, axis=0)
    ax.plot(1000 * time_s, trial_mean, linewidth=lw_beat_corr, color=color, clip_on=False)
    ax.set_xlim(0, 1000 * time_s[-1])
    ax.set_ylim(-10, 850)
    ax.show_spines('')
    column = i % 3
    if column == 0:
        ax.text(-0.05, 0.3, name, transform=ax.transAxes, rotation=90, va='center', ha='right')
    elif column == 2:
        ax.scalebars(1.05, 0, 10, 800, 'ms', 's$^{-1}$', ha='right', vat='bottom')
def load_cell(data, fname='singlecellexample5', big_file='beat_results_smoothed_limit35minimalduration0.3', redo=False):
    """Load (or build and cache) per-difference-frequency beat data of one cell at contrast 20.

    If ``fname + '.csv'`` does not exist, or ``redo`` is True, the cache is built:
    - the pickle ``big_file`` under ``load_folder_name('calc_model')`` is read and
      restricted to ``dataset == data`` and ``contrasts == 20``;
    - rows are averaged per ``df``;
    - the first spike train of every ``df``, and the cell's baseline firing rate
      from the ``calc_base`` frame, are added;
    - the result is written with ``to_csv`` and ``save_object_from_frame``.
    Otherwise the cached frame is read back with ``load_object_to_pandas``.

    Parameters
    ----------
    data : str
        Cell/dataset identifier.
    fname : str
        Cache file stem (no extension).
    big_file : str
        Stem of the beat-results pickle.
    redo : bool
        Force rebuilding the cache.

    Returns
    -------
    The per-``df`` frame. On the cached branch this is whatever
    ``load_object_to_pandas`` returns (presumably a DataFrame -- TODO confirm).
    """
    if (not os.path.exists(fname + '.csv')) or (redo == True):
        print('reloaded')
        data_all = pd.read_pickle(load_folder_name('calc_model') + '/' + big_file + '.pkl')
        just_cell = data_all[data_all['dataset'] == data]
        spikes_data = just_cell[just_cell['contrasts'] == 20]
        results1 = pd.DataFrame(spikes_data)
        # NOTE(review): since pandas 2.0, mean() on object columns (e.g. lists of
        # spike times) may raise TypeError; numeric_only=True may be needed -- confirm.
        results = results1.groupby(['df']).mean()
        spikes = []
        # first trial's spike times per df, in sorted df order (same order as the groupby index)
        for d in np.unique(results1['df']):
            spikes.append(results1[results1['df'] == d].spike_times.iloc[0])
        # NOTE(review): results is indexed by df, while results1 keeps its original row
        # index. This assignment aligns on index labels and will mostly produce NaN --
        # verify that this is the intended 'base' column.
        results['base'] = results1['amp_max_beat_05']
        results['spikes'] = spikes
        results['df'] = np.unique(results1['df'])
        baseline = pd.read_pickle(load_folder_name('calc_base') + '/calc_base_data-base_frame.pkl')
        baseline_cell = baseline[baseline.cell == data]
        base = baseline_cell['fr'].iloc[0]
        results['fr'] = base
        spikes_data = results
        names = ['spikes']
        results.to_csv(fname + '.csv')
        # the 'spikes' column holds arrays and is saved separately from the csv
        save_object_from_frame(results, names, fname)
    else:
        names = ['spikes']
        spikes_data = load_object_to_pandas(names, fname)
    return spikes_data
def ws_nonlin_systems():
    """Return the constant spacing value 0.15 (presumably a gridspec ``wspace``)."""
    return 0.15
def restrict_cell_type(cells, cell_type):
    """Split cell names into p-units and the rest, based on recording dates.

    A cell goes into the second list if its name contains one of the excluded
    recording dates below. The one exception is ``'2021-11-05-ai-invivo-1'``,
    which is always treated as a p-unit. All other cells are p-units.

    Returns
    -------
    (selected, p_units_cells, pyramidals)
        ``selected`` is ``p_units_cells`` when ``cell_type == 'p-units'`` and
        ``pyramidals`` otherwise.
    """
    excluded_dates = ('2022-02-08', '2022-02-07', '2022-02-04', '2022-02-03',
                      '2021-11-11', '2021-11-05', '2021-11-04')
    forced_p_unit = '2021-11-05-ai-invivo-1'
    p_units_cells = []
    pyramidals = []
    for cell in cells:
        if cell == forced_p_unit or not any(date in cell for date in excluded_dates):
            p_units_cells.append(cell)
        else:
            pyramidals.append(cell)
    selected = p_units_cells if cell_type == 'p-units' else pyramidals
    return selected, p_units_cells, pyramidals
def plt_peaks_several(freqs, p_arrays, axs_p, p0_means, fs, labels=None, j=1, colors=None, emb=False, marker='o',
                      markeredgecolors=None, zorder=2, ha='left', add_texts=None, limit=None, texts_left=None,
                      add_log=2,
                      rots=None, several_peaks_nr=2, exact=True, text_extra=False, perc_peaksize=0.04, rel='rel',
                      alphas=None,
                      extend=False, ms=25, clip_on=False, several_peaks=True, alpha=1, log='', add_not_log = 0):
    """Mark several frequencies on a spectrum axis, offsetting repeated ones.

    For each ``freqs[f]``:
    - a vertical offset is computed with ``get_add_for_several_peaks``, so a
      frequency that was already marked is drawn shifted;
    - per-peak styling is taken from the optional per-peak sequences
      (``colors``, ``markeredgecolors``, ``labels``, ``alphas``, ``rots``,
      ``add_texts``, ``texts_left``);
    - the marker is drawn on ``axs_p`` with ``plt_peaks``, searching ``p0_means``
      over the frequency axis ``fs``.

    ``p_arrays`` is only used to scale the offsets.

    Returns
    -------
    list
        The ``p_scatter`` value returned by ``plt_peaks`` for every frequency.
    """
    df_passed = []  # int(freq) of every peak drawn so far; used to offset repeats
    p_passed = []
    for f in range(len(freqs)):
        add, add_text = get_add_for_several_peaks(add_log, df_passed, emb, exact, f, freqs, j, log, p_arrays,
                                                  perc_peaksize, rel, add_not_log=add_not_log)
        if alphas is not None and len(alphas) > 0:
            alpha = alphas[f]
        if rots is not None and len(rots) > 0:
            rot = rots[f]
        else:
            rot = 45
        # an explicit add_texts entry replaces the label offset computed above
        if add_texts is None:
            add_text = 0
        else:
            add_text = add_texts[f]
            print('extraf' + str(add_texts[f]))
        # Do not touch add_texts here: it may be None when only texts_left is given.
        if texts_left is not None and len(texts_left) > 0:
            text_left = texts_left[f]
        else:
            text_left = 0
        try:
            if colors is None:
                color = 'black'
            else:
                color = colors[f]
        except:
            print('colors something')
            embed()
        if markeredgecolors is not None and len(markeredgecolors) > 0:
            try:
                markeredgecolor = markeredgecolors[f]
            except:
                print('marker something')
                embed()
        else:
            markeredgecolor = color
        label = '' if labels is None else labels[f]
        f_scatter, p_scatter = plt_peaks(axs_p, p0_means, freqs[f], fs, fr_color=color, s=ms,
                                         label=label, marker=marker, zorder=zorder, markeredgecolor=markeredgecolor,
                                         ha=ha, several_peaks_nr=several_peaks_nr, limit=limit, rot=rot,
                                         text_left=text_left, add_text=add_text, text_extra=text_extra,
                                         extend=extend, add=add, alpha=alpha, clip_on=clip_on,
                                         several_peaks=several_peaks)
        df_passed.append(int(freqs[f]))
        p_passed.append(p_scatter)
    return p_passed
def get_add_for_several_peaks(add_log, df_passed, emb, exact, f, freqs, j, log, p_arrays, perc, rel, add_not_log = 0, j_extra=False):
    """Return ``(add, add_text)``: vertical offsets for the marker and label of ``freqs[f]``.

    With ``rel == 'rel'``, offsets scale with the maximum of ``p_arrays``. A
    frequency already in ``df_passed`` (see ``check_if_peak_occured``) is shifted
    by an amount proportional to how often it occurred (``count_of_occurance``).
    In log mode the count is scaled by ``add_log``; otherwise by the peak
    fraction plus ``add_not_log``.

    Otherwise, fixed offsets are used: 60 (label +30) for repeats and 30 for new
    frequencies.

    ``j_extra`` with ``j == 0`` overrides the repeat offset. ``emb`` opens an
    IPython shell for debugging (rel mode, repeated peaks only).
    """
    def peak_max():
        # evaluated lazily, exactly where the offsets need it
        return np.max(np.max(p_arrays))

    if rel != 'rel':
        if int(freqs[f]) not in df_passed:
            return 30, peak_max() * perc  # * 0.01
        add = 60
        add_text = add + 30
        if j_extra and j == 0:
            add = peak_max() * perc  # * 0.25
            add_text = add + 30
        return add, add_text

    if not check_if_peak_occured(df_passed, exact, f, freqs):
        return peak_max() * 0.01, peak_max() * perc
    count = count_of_occurance(df_passed, exact, f, freqs)
    if log == 'log':
        add = (peak_max() * perc) + add_log * count
    else:
        add = (peak_max() * perc * count) + add_not_log * count
    add_ext = 30
    add_text = add + add_ext * count
    if j_extra and j == 0:  # unclear why this would ever be needed
        add = (peak_max() * perc * count)
        add_text = add + add_ext * count
    if emb:
        embed()
    return add, add_text
def count_of_occurance(df_passed, exact, f, freqs, tol=10):
    """Count how many already-marked frequencies coincide with ``freqs[f]``.

    Parameters
    ----------
    df_passed : sequence of numbers
        Frequencies that were already marked (``plt_peaks_several`` stores
        ``int(freq)`` values).
    exact : bool
        If True, count entries equal to ``int(freqs[f])``. Otherwise count
        entries within ``tol`` of ``freqs[f]``.
    f : int
        Index into ``freqs``.
    freqs : sequence of numbers
    tol : float, optional
        Matching tolerance used when ``exact`` is False (default 10).

    Returns
    -------
    int
    """
    if exact == True:  # kept as ``== True`` to match check_if_peak_occured's contract
        target = int(freqs[f])
        return sum(1 for df in df_passed if df == target)
    # np.asarray so that Python lists and plain Python floats can be subtracted
    passed = np.asarray(df_passed, dtype=float)
    if passed.size == 0:
        return 0
    return int(np.sum(np.abs(passed - freqs[f]) < tol))
def check_if_peak_occured(df_passed, exact, f, freqs, tol=10):
    """Return True if ``freqs[f]`` matches a frequency in ``df_passed``.

    Parameters
    ----------
    df_passed : sequence of numbers
        Frequencies that were already marked.
    exact : bool
        If True, test membership of ``int(freqs[f])``. Otherwise test whether
        any entry lies within ``tol`` of ``freqs[f]``.
    f : int
        Index into ``freqs``.
    freqs : sequence of numbers
    tol : float, optional
        Matching tolerance used when ``exact`` is False (default 10).
    """
    if exact == True:
        return int(freqs[f]) in df_passed
    # np.asarray so that Python lists and plain Python floats can be subtracted
    passed = np.asarray(df_passed, dtype=float)
    if passed.size == 0:
        return False
    return bool(np.min(np.abs(passed - freqs[f])) < tol)
def plt_peaks(ax, p, fr, f_axis, several_peaks_nr=2, zorder=2, marker='o', markeredgecolor=None, several_peaks=True,
              ha='left', limit=None, text_left=0, rot=45, add_text=0, text_extra=False, extend=True, fr_color='grey',
              add=0, s=12, label='',f_scatter = None, p_scatter = None, alpha=1, clip_on=False):
    """Mark the value of ``p`` near frequency ``fr`` on ``ax`` and return its position.

    The bin of ``f_axis`` closest to ``fr`` is located. With ``several_peaks``,
    ``chose_beat_peak`` picks the largest ``p`` within ``+-several_peaks_nr``
    bins; otherwise only that bin is used.

    If the chosen maximum is larger than the window minimum, a marker is drawn at
    ``p + add``. Otherwise the minimum bin is marked at ``p - add``. With a single
    candidate (``several_peaks=False``) max == min, so this second branch is
    always taken.

    ``limit`` suppresses markers whose value is not above it. ``extend`` also
    draws the 5 bins around the peak. ``text_extra`` writes ``label`` next to
    the marker.

    Returns
    -------
    (f_scatter, p_scatter)
        Frequency and value of the marked bin. If ``fr >= f_axis[-1]``, the
        ``f_scatter``/``p_scatter`` arguments (default None) are returned unchanged.
    """
    # THIS is the CORRECT variant
    if fr < f_axis[-1]:
        minimum = np.argmin(np.abs(fr - f_axis))
        #embed()
        try:
            if several_peaks:
                # usually used for power spectra
                max_pos, minimums = chose_beat_peak(minimum, p, several_peaks_nr)
            else:
                # used in all other cases
                minimums = [minimum]
                max_pos = np.argmax(p[minimums])
        except:
            print('maxima things')
            embed()
        new_f = minimums[max_pos]
        max_pos_neg = np.argmin(p[minimums])
        new_f_neg = minimums[max_pos_neg]
        if not markeredgecolor:
            markeredgecolor = fr_color
        try:
            if p[new_f] > p[new_f_neg]:
                # there is a real maximum in the window: mark it above the curve
                f_scatter = f_axis[new_f]
                p_scatter = p[new_f]
                if limit:
                    if p_scatter > limit:
                        cont = True
                    else:
                        cont = scatter_peaks()
                else:
                    cont = True
                if cont:
                    if label != '':
                        ax.scatter(f_scatter, p_scatter + add, color=fr_color, zorder=zorder, s=s, label=label,
                                   clip_on=clip_on,
                                   alpha=alpha, edgecolor=markeredgecolor, marker=marker)
                    else:
                        ax.scatter(f_scatter, p_scatter + add, color=fr_color, zorder=zorder, s=s, clip_on=clip_on,
                                   alpha=alpha, marker=marker, edgecolor=markeredgecolor)
                    if extend:
                        # NOTE(review): for new_f < 2 the slice start is negative and the slice may be empty
                        ax.plot(f_axis[new_f - 2: new_f + 3], p[new_f - 2: new_f + 3], color=fr_color, alpha=0.5,
                                zorder=100)
                    if text_extra:  # +add_text
                        ax.text(f_scatter - text_left, p_scatter + add + add_text, label, ha=ha, rotation=rot,
                                color=fr_color)
            else:
                # flat window (max == min): mark the minimum bin below the curve
                try:
                    max_pos = np.argmin(p[minimums])
                    new_f = minimums[max_pos]
                except:
                    new_f = minimum  # minimums[minimum]
                f_scatter = f_axis[new_f]
                p_scatter = p[new_f]
                if limit:
                    if p_scatter > limit:
                        cont = True
                    else:
                        cont = False
                else:
                    cont = True
                if cont:
                    if label != '':
                        ax.scatter(f_scatter, p_scatter - add, color=fr_color, zorder=2, s=s, label=label, clip_on=clip_on,
                                   alpha=alpha, edgecolor=markeredgecolor, marker=marker)
                    else:
                        ax.scatter(f_scatter, p_scatter - add, color=fr_color, zorder=2, s=s, clip_on=clip_on, alpha=alpha,
                                   edgecolor=markeredgecolor, marker=marker)
                    if extend:
                        ax.plot(f_axis[new_f - 2:new_f + 3], p[new_f - 2:new_f + 3], color=fr_color, alpha=0.5, zorder=100)
                    if text_extra:  # +add_text
                        ax.text(f_scatter + 4, p_scatter - add + 2 + add_text, label, rotation=rot, color=fr_color, ha=ha)
        except:
            print('peaks thing inside')
            embed()
    return f_scatter, p_scatter
def chose_beat_peak(minimum, p, several_peaks_nr):
    """Pick the largest value of ``p`` in a window around index ``minimum``.

    Usually used for power spectra, where the true peak may sit a few bins away
    from the bin closest to the target frequency. The window covers
    ``minimum - several_peaks_nr`` to ``minimum + several_peaks_nr`` (inclusive).

    Returns
    -------
    max_pos : int
        Position of the maximum *within* ``minimums``.
    minimums : array-like of int
        Candidate indices into ``p``; ``minimums[max_pos]`` is the chosen bin.
        If the window does not fit inside ``p`` (at either edge) or is empty,
        only ``[minimum]`` is used.
    """
    offsets = np.arange(0, several_peaks_nr * 2 + 1, 1) - several_peaks_nr
    minimums = minimum + offsets
    # Negative indices would silently wrap around to the end of ``p`` and could
    # select an unrelated bin, so both array edges fall back to the centre bin.
    if minimums.size == 0 or minimums.min() < 0 or minimums.max() >= len(p):
        minimums = [minimum]
    max_pos = np.argmax(p[minimums])
    return max_pos, minimums
def scatter_peaks():
    """Return False: the 'draw marker' flag used by ``plt_peaks`` when a peak is not above ``limit``."""
    return False
def calc_beat_spikes(final_eod, sampling_rate, final_DF, i, cell, plus_bef, minus_bef, version='spikes', data_beat=[],
                     trial_nr=0):
    """Extract the response and beat segment for the ``i``-th difference frequency.

    Parameters
    ----------
    final_eod, final_DF : sequences
        EOD frequency and difference frequency, both indexed by ``i``.
    sampling_rate : float
        Sampling rate in Hz.
    cell : pandas.DataFrame
        Needs 'df' and 'local' columns; only used when ``version != 'spikes'``.
    plus_bef, minus_bef : float
        The segment length is ``|minus_bef| - |plus_bef|``.
    version : str
        'spikes': spike times from ``data_beat['spikes']`` are binned and smoothed
        with a Gaussian of width ``gaussian_intro()`` seconds. Otherwise the
        rectified ``cell['local']`` trace is used instead.
    data_beat : pandas.DataFrame
        Needs columns 'df', 'spikes', 'efield', 'global' and optionally 'ds'.
        (The list default is never mutated, but a DataFrame is required.)
    trial_nr : int
        Row used when several rows match the difference frequency.

    Returns
    -------
    am_final, beat_final, smoothed, tranformed_spikes, spike_sm, spike, spikes_mat, plot_segment
    """
    ll = np.abs(plus_bef)
    ul = np.abs(minus_bef)
    df = final_DF[i]
    eod = final_eod[i]
    len_smoothed = []
    len_smoothed_b = []
    # rows of data_beat belonging to this difference frequency
    rows_df = data_beat[data_beat['df'] == df]
    if version == 'spikes':
        spikes_col = rows_df['spikes']
        if len(spikes_col) == 1:
            tranformed_spikes = np.array(spikes_col.iloc[0])
            if len(spikes_col.iloc[0]) == 1:
                # spike times wrapped in one extra list level
                tranformed_spikes = np.array(spikes_col.iloc[0][0])
        else:
            tranformed_spikes = np.array(spikes_col.iloc[trial_nr])
        size = int(tranformed_spikes[-1] * sampling_rate + 5)  # duration.iloc[0]
        spikes_mat = np.zeros(size)
        spikes_idx = np.round(tranformed_spikes * sampling_rate)
        # one spike per bin expressed as a rate (sampling_rate Hz)
        spikes_mat[spikes_idx.astype(int)] = 1 * sampling_rate
        smoothed = gaussian_filter(spikes_mat, sigma=gaussian_intro() * sampling_rate)
    else:
        spikes_mat = []
        spikes = cell[cell['df'] == df]['local']
        if len(spikes) == 1:
            tranformed_spikes = np.array(spikes.iloc[0])
        else:
            tranformed_spikes = np.array(spikes.iloc[trial_nr])
        # rectified 'local' trace (presumably the local EOD) replaces the smoothed spike train
        smoothed = tranformed_spikes * 1
        smoothed[smoothed < 0] = 0
    plot_segment = ul - ll
    time_spikes = np.arange(0, len(smoothed) / sampling_rate, 1 / sampling_rate)
    beat_version = 'sumu'
    if beat_version == 'local':
        # take the beat from the recorded local EOD
        beat_col = rows_df['local']
        if len(beat_col) == 1:
            beat = np.array(beat_col.iloc[0])
            if len(beat) == 1:
                beat = np.array(beat_col.iloc[0][0])
        else:
            beat = np.array(beat_col.iloc[trial_nr])
    else:
        # beat = normalized global EOD + normalized electric field stimulus
        efield_col = rows_df['efield']
        if len(efield_col) == 1:
            efield = np.array(efield_col.iloc[0])
            if len(efield) == 1:
                efield = np.array(efield_col.iloc[0][0])
        else:
            efield = np.array(efield_col.iloc[trial_nr])
        efield = zenter_and_normalize(efield, 0.2)
        global_col = rows_df['global']
        if len(global_col) == 1:
            global_eod = np.array(global_col.iloc[0])
            if len(global_eod) == 1:
                global_eod = np.array(global_col.iloc[0][0])
        else:
            global_eod = np.array(global_col.iloc[trial_nr])
        global_eod = zenter_and_normalize(global_eod, 1)
        beat = global_eod + efield
    if 'ds' in data_beat.keys():
        # presumably the traces were stored downsampled by ``ds``: interpolate back
        ds = int(data_beat.ds.iloc[0])
        time_beat = np.arange(0, ds * len(beat) / sampling_rate, 1 / sampling_rate)
        beat = interpolate(time_beat[::ds], beat, time_beat, kind='cubic')
    # A fixed window is shifted in small steps (not whole beat periods) so that the
    # best-matching segment can be found finely.
    shift_period = 0.005  # period * 2#
    shifts = np.arange(0, 200 * shift_period, shift_period)
    time_b = np.arange(0, len(beat) / sampling_rate, 1 / sampling_rate)
    am_corr = extract_am(beat, time_b, eodf=eod, norm=False, extract='globalmax', kind='cubic')[0]
    len_smoothed, smoothed_trial, all_spikes, maxima, error, spikes_cut, beat_cut, am_corr_cut = create_shifted_spikes(
        eod, len_smoothed_b, len_smoothed, beat, am_corr, sampling_rate, time_b, time_spikes, smoothed, shifts,
        plot_segment, tranformed_spikes, version=version)
    am_final, beat_final, most_similiar, spike, spike_sm = get_most_similiar_spikes(all_spikes, am_corr_cut, beat_cut,
                                                                                     error, maxima, spikes_cut)
    return am_final, beat_final, smoothed, tranformed_spikes, spike_sm, spike, spikes_mat, plot_segment
def gaussian_intro():
    """Return the Gaussian smoothing width in seconds (0.001 s); callers multiply it by the sampling rate."""
    sigma_seconds = 0.001
    return sigma_seconds
def plot_power(ax, stim_f, spikes_mat, sampling_rate, main_color, eod, color, i, ms=3, nfft=4096 * 4):
    """Plot the normalized power spectrum (dB re. maximum) of ``spikes_mat`` and mark key peaks.

    The following are marked via ``plt_peaks_several``:
    - ``eod[i]``;
    - the largest spectral value below ``eod[i] / 2``;
    - ``stim_f``, if it lies inside the 0-1000 Hz x range.

    ``eod[i] / 2`` is drawn as a dashed line. ``i`` also controls decoration:
    x labels for ``i >= 3``, a 'psd' label for ``i % 3 == 0`` and a dB scale bar
    for ``i % 3 == 2``.
    """
    eod_f = eod[i]
    power, freq_axis = ml.psd(spikes_mat - np.mean(spikes_mat), Fs=sampling_rate, NFFT=nfft, noverlap=nfft // 2)
    power_db = 10 * np.log10(power / np.max(power))
    ax.plot(freq_axis, power_db, zorder=1, color=main_color, linewidth=1)
    peak_idx = np.argmax(power_db[freq_axis < 0.5 * eod_f])
    # plt_peaks_several takes care of placing markers so they do not overlap
    xlim_max = 1000
    peak_freqs = [eod_f, freq_axis[peak_idx]]
    if stim_f < xlim_max:
        peak_freqs.append(stim_f)
    plt_peaks_several(peak_freqs, [power_db], ax, power_db, freq_axis, ['', '', ''], 5, ['white', color, 'black'],
                      markeredgecolors=['black', color, 'black'], add_log=0.1, several_peaks_nr=4, rel='rel', ms=ms,
                      clip_on=False, log='log')
    ax.axvline(x=eod_f / 2, color='black', linestyle='dashed', lw=0.5)
    ax.set_xlim(0, xlim_max)
    ax.set_ylim(-20, 10)
    ax.show_spines('b')
    ax.set_xticks_delta(500)
    if i >= 3:
        ax.set_xlabel('Frequency [Hz]')
    else:
        ax.set_xticks_blank()
    column = i % 3
    if column == 0:
        ax.text(-0.05, 0.5, 'psd', transform=ax.transAxes, rotation=90, va='center', ha='right')
    elif column == 2:
        ax.yscalebar(1.05, 0.0, 20, 'dB', ha='right')
def plt_RAM_explained_single3():
    """Draw the model panel of the RAM explanation figure and save it.

    The image comes from ``plt_model_big``. Its colour limits are set with
    ``set_clim_same`` on ``|stack_final|`` (``nr_clim='perc'``, ``percnr=95``).
    """
    plot_style()
    cells = ["2012-06-27-ah-invivo-1"]  # ,"2013-01-08-aa-invivo-1" , "2014-06-06-ac-invivo-1"]
    run = 0  # single pass; passed to save_visualization as str(run)
    default_figsize(column=2, length=2.5)
    outer_grid = gridspec.GridSpec(1, 1, wspace=0.35, left=0.1, top=0.95, bottom=0.13, right=0.87,
                                   hspace=0.35)
    model_grid = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.5, hspace=0.4,
                                                  subplot_spec=outer_grid[0])
    for _cell in cells:
        # the cell name is not used: the model panel is drawn once per entry
        ax = plt.subplot(model_grid[:])
        _, im, stack_final, _ = plt_model_big(ax)
        set_clim_same([im], mats=[np.abs(stack_final)], lim_type='up', nr_clim='perc', clims='', percnr=95)
    plt.subplots_adjust(hspace=0.85, wspace=0.25)
    save_visualization(str(run), False, counter_contrast=0, savename='')
def plt_RAM_explained_single2(exponential=''):
    """Draw RAM stimulus/response traces next to the model panel and save the figure.

    The left column shows the traces from ``get_RAM_stimulus`` (via
    ``plt_RAM_arrays``). The right column shows ``plt_model_big`` with colour
    limits set by ``set_clim_same`` (``nr_clim='perc'``, ``percnr=95``).
    ``exponential`` is forwarded to ``get_RAM_stimulus``.
    """
    plot_style()
    default_settings(width=12)  # , ts=12, ls=13, fs=11
    cells = ["2012-06-27-ah-invivo-1"]  # ,"2013-01-08-aa-invivo-1" , "2014-06-06-ac-invivo-1"]
    model_cells = pd.read_csv(load_folder_name('calc_model_core') + "/models_big_fit_d_right.csv")
    run = 0  # single pass; passed to save_visualization as str(run)
    default_settings(column=2, length=3)
    outer_grid = gridspec.GridSpec(1, 2, wspace=0.35, left=0.1, top=0.95, bottom=0.16, right=0.87,
                                   hspace=0.35)
    trace_grid = gridspec.GridSpecFromSubplotSpec(3, 1, wspace=0.5, hspace=0.02,
                                                  subplot_spec=outer_grid[0])
    model_grid = gridspec.GridSpecFromSubplotSpec(2, 1, wspace=0.5, hspace=0.4,
                                                  subplot_spec=outer_grid[1])
    for cell in cells:
        arrays, arrays2, colors, deltat, spikes = get_RAM_stimulus(cell, exponential, model_cells)
        # left column: stimulus and response traces
        plt_RAM_arrays(arrays, arrays2, colors, deltat, trace_grid, spikes)
        # right column: model panel
        ax = plt.subplot(model_grid[:])
        _, im, stack_final, _ = plt_model_big(ax)
        set_clim_same([im], mats=[np.abs(stack_final)], lim_type='up', nr_clim='perc', clims='', percnr=95)
    fig = plt.gcf()
    fig.tag([fig.axes[1], fig.axes[-2]], xoffs=-5.5, yoffs=0.7)
    plt.subplots_adjust(hspace=0.85, wspace=0.25)
    save_visualization(str(run), False, counter_contrast=0, savename='')
def plt_RAM_arrays(arrays, arrays2, colors, deltat, grid1, spikes):
    """Plot a spike raster and up to three traces into the rows of ``grid1`` (first 100 ms).

    Parameters
    ----------
    arrays : sequence
        Traces for rows 0..2 of ``grid1``. An entry equal to the string ``''``
        marks an empty slot.
    arrays2 : sequence
        Optional black overlay per row (``''`` for none).
    colors : sequence
        Colour per entry of ``arrays``.
    deltat : float
        Sampling interval in seconds.
    grid1 : gridspec
        Needs at least ``len(arrays)`` rows.
    spikes : array-like
        Spike times in seconds, drawn as a raster in row 1.
    """
    def _is_empty_slot(trace):
        # Only strings are compared with ''. For numpy arrays the comparison is
        # elementwise, and using it in ``if`` raises "truth value is ambiguous".
        return isinstance(trace, str) and trace == ''

    def _time_ms(trace):
        # Integer arange guarantees one time point per sample. A float-step
        # arange can produce an extra point and mismatch the trace length.
        return np.arange(len(trace)) * deltat * 1000

    ax = plt.subplot(grid1[1])
    ax.eventplot(np.array(spikes) * 1000, color='black')
    ax.show_spines('')
    rx = 0.1 * 1000
    ax.set_xlim(0, rx)
    ylabel = ['', '', 'Firing Rate [Hz]']
    for i in range(len(arrays)):
        if _is_empty_slot(arrays[i]):
            continue
        ax = plt.subplot(grid1[i])
        try:
            ax.plot(_time_ms(arrays[i]), arrays[i], color=colors[i])
        except:
            print('arrays problem')
            embed()
        if not _is_empty_slot(arrays2[i]):
            ax.plot(_time_ms(arrays2[i]), arrays2[i], color='black')
        ax.set_xlim(0, rx)
        if i < 2:
            ax.show_spines('')
            remove_xticks(ax)
        ax.set_ylabel(ylabel[i])
    ax.set_xlabel('Time [ms]')
def phaselocking_loss2(show=True):
    """Plot an FI-curve example (stimulus, raster, rate) next to fitted onset/steady FI curves.

    For each cell in ``data_names``, the cell's nix file is opened. For every
    multi-tag whose name contains 'rectangle' or 'FI', the trials with a contrast
    within +-3 of ``val`` (31) are plotted cumulatively:
    - a synthetic 750 Hz stimulus trace;
    - a spike raster;
    - the trial-averaged smoothed rate, with the onset and steady-state windows
      highlighted.

    The right column shows the fitted onset/steady FI curves loaded from a
    precomputed ``.npy`` file (the 'chirp' branch). ``save_visualization`` is
    called after every plotted trial.

    Parameters
    ----------
    show : bool
        Forwarded to ``save_visualization``.

    NOTE(review): known issues, left unchanged here:
    * ``nix_there`` is only bound when ``name_core`` exists; otherwise NameError.
    * The nix file ``f`` is never closed.
    * The FI5 csv path contains a backslash that only survives because it does not
      form an escape sequence (Windows-style separator); newer Python warns.
    * The ``my_curves == 'new'`` branch and the final ``else`` are unreachable
      (``my_curves`` is hard-coded to 'chirp'). They use names not defined in this
      function (``cell``, ``sort_model``, ``ymin``, ``ymax``) and index the string
      ``colors``.
    * ``spike_mats - negativ`` subtracts a number from a list. This only works if
      ``negativ`` is a numpy scalar and the trains form a regular array -- TODO confirm.
    """
    _, _ = find_all_dir_cells()
    data_names = ['2019-09-10-ae-invivo-1']
    plot_style()
    default_figsize(column=2, length=3.5)
    for data_name in data_names:
        print(data_name)
        #############################################
        # print traces
        name_core = load_folder_name('data') + 'cells/' + data_name
        nix_name = name_core + '/' + data_name + '.nix' # '/'
        if os.path.exists(name_core):
            f = nix.File.open(nix_name, nix.FileMode.ReadOnly)
            nix_there = True
        # NOTE(review): nix_there is unbound here if name_core does not exist
        if nix_there:
            b = f.blocks[0]
            all_mt_names = find_mt_all(b)
            ts = find_tags_list(b, names='ficurve')
            if len(ts) > 0:
                for n, names_mt_gwn in enumerate(all_mt_names):
                    if ('rectangle' in names_mt_gwn) | ('FI' in names_mt_gwn):
                        mt = b.multi_tags[names_mt_gwn]
                        features, delay_name = feature_extract(mt, )
                        Intensity, preIntensity, contrasts, precontrasts = find_contrasts(features, mt)
                        if len(np.shape(contrasts)) > 1:
                            contrasts = np.concatenate(contrasts)
                        negativ = 'negativ' # 'positiv'#'highest'#'negativ' # 'positiv'
                        val = 31
                        # indices of trials whose contrast lies within +-3 of val
                        indeces_show = np.arange(0, len(contrasts), 1)[
                            (contrasts > val - 3) & (contrasts < val + 3)] ##[np.argsort(contrasts)[-1]]
                        # NOTE(review): Windows-style backslash separator in the path below
                        save_name = load_folder_name('calc_FI_Curve') + '\FI5_with_f0_nfft_16384' # FI_with_f0'
                        frame = pd.read_csv(save_name + '.csv')
                        names_all = [['ss_s', 'ss_r']] # , [['ss_s', 'on_s']] # , 'on_s', 'on_r',
                        linestyles = ['-', '--', '-', '--', '-', '--', '-', '--']
                        _, _ = find_row_col(frame.cell.unique())
                        axes = []
                        grid0 = gridspec.GridSpec(1, 2, bottom=0.15, top=0.92, left=0.1,
                                                  right=0.98,
                                                  wspace=0.27, hspace=0.45, width_ratios=[2, 1.3]) #
                        gridr = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=grid0[1],
                                                                 hspace=0.5)
                        frame_cv = pd.read_pickle(
                            load_folder_name('calc_base') + '/calc_base_data-base_frame.pkl')
                        not_frame_punits_cells = frame_cv[
                            frame_cv['cell_type_reclassified'] != ' P-unit'].cell.unique()
                        cells_saved = frame.cell.unique()
                        c_counter = 0
                        cells_p_units = np.setdiff1d(list(cells_saved), list(not_frame_punits_cells))
                        frame_cell = frame[frame.cell == data_name]
                        frame_cell['contrasts'] = np.round(frame_cell['contrasts'])
                        frame_cell = frame_cell.groupby('contrasts', as_index=True).mean()
                        sort_order = np.argsort(frame_cell.index)
                        logs = [''] # , 'log'
                        lables = ['Onset State Curve', 'Steady State Curve'] #
                        coloro = 'red'  # onset colour
                        colors = 'green'  # steady-state colour
                        delta = 200
                        for l in range(len(logs)):
                            # only the 'chirp' branch is reachable with this hard-coded value
                            my_curves = 'chirp'
                            if my_curves == 'new':
                                for nn, names in enumerate(names_all):
                                    ax = plt.subplot(gridr[nn, l])
                                    for n, name in enumerate(names):
                                        contrast_labels = frame_cell.index[sort_order] / 100
                                        flip = False
                                        if flip:
                                            steady_func = frame_cell[name].iloc[sort_order][::-1]
                                            ax.scatter(-contrast_labels,
                                                       steady_func,
                                                       color=colors[n],
                                                       linestyle=linestyles[n])  # label=name,
                                        else:
                                            steady_func = frame_cell[name].iloc[sort_order]
                                            ax.scatter(contrast_labels,
                                                       steady_func, color=colors[n],
                                                       linestyle=linestyles[n])  # label=name,
                                        try:
                                            plt_FI_curve(contrast_labels, bolzmann_steady, steady_func,
                                                         label=lables[n], color=colors[n])
                                        except:
                                            print('color something')
                                            embed()
                                    if logs[l] == 'log':
                                        ax.set_xscale(logs[l])
                                    else:
                                        ax.axvline(0, color='grey', linestyle='--', linewidth=0.5)
                                    if l == 0:
                                        ax.legend(ncol=2, loc=(0, 1.05))
                                        ax.set_ylabel('Firing Frequency [Hz]')
                                    else:
                                        remove_yticks(ax)
                            elif my_curves == 'chirp':
                                # precomputed onset/steady FI points and fit parameters per cell
                                onset_state, steady_state, sorted_contrast, steady, onset, indices, mean_indices = np.load(
                                    load_folder_name(
                                        'calc_FI_Curve') + '/F_I_curve-distances5st_nm_w2_dm_alpha_consp2_bnr6_ROCsFI_cells.npy',
                                    allow_pickle=True)
                                ax = plt.subplot(gridr[0])
                                ax.plot(sorted_contrast[data_name],
                                        steady_function(sorted_contrast[data_name], steady[data_name][0],
                                                        steady[data_name][1],
                                                        steady[data_name][2]), zorder=1,
                                        label='Fitted function for the steady F-I Cure', color='black')
                                ax.plot(sorted_contrast[data_name],
                                        onset_function(sorted_contrast[data_name], onset[data_name][0],
                                                       onset[data_name][1], onset[data_name][2]),
                                        label='Fitted function for the onset F-I Cure', zorder=1, color='black')
                                s = 30
                                # the data points at contrast == val are highlighted with a black edge
                                steady_val = np.array(steady_state[data_name])[sorted_contrast[data_name] == val]
                                onset_val = np.array(onset_state[data_name])[sorted_contrast[data_name] == val]
                                ax.scatter(sorted_contrast[data_name], onset_state[data_name], s=s, clip_on=False,
                                           color=coloro, zorder=120, alpha=0.5)  # color='black',
                                ax.scatter(sorted_contrast[data_name], steady_state[data_name], s=s, clip_on=False,
                                           color=colors, zorder=120, alpha=0.5)  # color='grey',
                                ax.scatter(sorted_contrast[data_name][sorted_contrast[data_name] == val], onset_val,
                                           s=s, color=coloro,
                                           clip_on=False, zorder=100, alpha=1, edgecolor='black')
                                ax.scatter(sorted_contrast[data_name][sorted_contrast[data_name] == val],
                                           steady_val, s=s, color=colors,
                                           clip_on=False, zorder=100, alpha=1, edgecolor='black')
                                ax.set_yticks_delta(delta)
                                print('onsetsnip' + str(onset_val))
                                print('steadysnip' + str(steady_val))
                            else:
                                # NOTE(review): unreachable; cell, sort_model, ymin and ymax are not defined here
                                data = pd.read_csv(
                                    '../data/Kennlinien/cell_fi_curves_csvs/' + cell + '.csv')  # not found
                                ''' df_cell['inputs'].iloc[0]
                                df_cell['on_r'].iloc[0]
                                df_cell['ss_r'].iloc[0]
                                df_cell['on_s'].iloc[0]
                                df_cell['ss_s'].iloc[0]
                                data['contrasts']'''
                                sort_data = np.argsort(data['contrasts'])
                                plt.scatter(data['contrasts'][sort_data],
                                            (data['f_onset'][sort_model] - ymin) / ymax,
                                            color='black', s=7, zorder=2)  # not found
                                colors = ['green', 'blue', 'orange', 'pink', 'purple', 'red']
                                plt.xlabel('Contrast [%]')
                                plt.ylabel('Firing Rate [Hz]')
                            ax.set_xlabel('Contrast [$\%$]')
                            ax.set_ylabel('Firing Rate [Hz]')
                            axes.append(ax)
                        # spike trains and smoothed rates accumulate across the selected trials
                        spike_mats = []
                        smootheneds = []
                        for idx, mt_idx in enumerate(indeces_show):  # range(len(mt.positions[:]))
                            print(idx)
                            gridl = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=grid0[0],
                                                                     hspace=0.1, wspace=0.35)
                            try:
                                delay_orig = mt.features[delay_name].data[:][mt_idx][0]
                            except:
                                delay_orig = mt.features[delay_name].data[:][mt_idx]
                            delay = delay_and_reality_alignment(mt, mt.extents[mt_idx], mt_idx,
                                                                mt.extents[mt_idx])
                            # time included before stimulus onset: at most 0.5 s, limited by delay
                            negativ = 0.5
                            if delay < negativ:
                                negativ = delay
                            if negativ < 0:
                                negativ = delay_orig
                            if mt_idx != len(mt.extents) - 1:
                                try:
                                    pass
                                except:
                                    print('positiv thing')
                                    embed()
                                # time included after stimulus offset: at most 0.5 s, limited by delay
                                positive = 0.5
                                if delay < positive:
                                    positive = delay
                            else:
                                positive = delay
                            # I think there are problems when the multi-tag before or after was negative;
                            # that is why in that case both simply become the delay!
                            if positive < 0:
                                positive = negativ * 1
                            if positive < 0:
                                positive = delay_orig
                                print('positive thing')
                                embed()
                            duration = mt.extents[mt_idx][0]
                            if mt.extents[mt_idx] > 0:
                                f_snippet = np.min([duration / 2, negativ, positive])
                                if f_snippet < 0:
                                    print('snippet thing')
                                    embed()
                                # only trials whose shortest margin exceeds 0.1 s are plotted
                                if f_snippet > 0.1:
                                    start_time = mt.positions[mt_idx] - negativ
                                    eod_g, spikes_mt, sampling = link_arrays(b, start_time,
                                                                             duration + negativ + positive,
                                                                             start_time,
                                                                             load_eod_array='LocalEOD-1')  # 'EOD'
                                    spike_mats.append(spikes_mt)
                                    v1, sampling = link_arrays_eod(b, start_time,
                                                                   duration + negativ + positive,
                                                                   array_name='V-1')
                                    eod_field, sampling = link_arrays_eod(b, start_time,
                                                                          duration + negativ + positive,
                                                                          array_name='GlobalEFieldStimulus')
                                    spikes_mat = cr_spikes_mat(spikes_mt, sampling, int((mt.extents[
                                                                                           mt_idx] + negativ + positive) * sampling))
                                    smoothened = gaussian_filter(spikes_mat, sigma=0.001 * sampling)
                                    smootheneds.append(smoothened)
                                    dt = 1 / sampling
                                    axs = []
                                    xlim = [-0.1 * 1000, 0.5 * 1000]
                                    ax = plt.subplot(gridl[0])
                                    axes.append(ax)
                                    ax.set_xlim(xlim)
                                    time_array = np.arange(0, len(eod_g) * dt, dt) - negativ
                                    # a synthetic 750 Hz carrier replaces the recorded EOD; its amplitude
                                    # is raised by val percent between 0 and 0.4 s
                                    time_fish_e = time_array * 2 * np.pi * 750  # eod_fe[e]
                                    eod_g = 100 * np.sin(time_fish_e)
                                    eod_g[(time_array > 0) & (time_array < 0.4)] = eod_g[(time_array > 0) & (
                                            time_array < 0.4)] * ((100 + val) / 100)
                                    ax.plot(time_array * 1000, eod_g, color='grey', linewidth=0.5)  # 0.2
                                    axs.append(ax)
                                    ax.set_ylabel('Contrast [$\%$]')
                                    remove_xticks(ax)
                                    ax.show_spines('l')
                                    ax.set_title('Contrast\,$=%s$' % val + '\,$\%$')
                                    ax = plt.subplot(gridl[1])
                                    axs.append(ax)
                                    ax.set_xlim(xlim)
                                    time_array = np.arange(0, len(v1) * dt, dt) - negativ
                                    # NOTE(review): spike_mats is a list; see the docstring note on this subtraction
                                    ax.eventplot((spike_mats - negativ) * 1000, color='black', linewidths=0.3)
                                    ax.show_spines('l')
                                    ax.set_ylabel('Trials')
                                    remove_xticks(ax)
                                    ax = plt.subplot(gridl[-1])
                                    ax.set_xlabel('Time [ms]')
                                    ax.set_ylabel('Firing Rate [Hz]')
                                    smoothed_mean = np.mean(smootheneds, axis=0)
                                    time_cut = time_array[0: len(np.mean(smootheneds, axis=0))] * 1000
                                    ax.plot(time_cut, smoothed_mean, color='black')
                                    ax.set_xlim(xlim)
                                    # baseline: mean rate before stimulus onset (time_cut < 1 ms)
                                    ax.axhline(np.mean(smoothed_mean[time_cut < 1]), color='grey', linewidth=0.5)
                                    ax.set_yticks_delta(delta)
                                    minus = 15
                                    # onset window from start_fi_o()/end_fi_o(); steady window 300 ms later
                                    # from start_fi_s()/end_fi_s() (times in ms)
                                    onset_snip = smoothed_mean[(time_cut > start_fi_o()) & (time_cut < end_fi_o() - minus)]
                                    steady_snip = smoothed_mean[
                                        (time_cut > (300 + start_fi_s())) & (time_cut < (300 + end_fi_s()))]
                                    print('onsetsnip' + str(np.mean(onset_snip)))
                                    print('steadysnip' + str(np.mean(steady_snip)))
                                    ax.plot(time_cut[(time_cut > start_fi_o()) & (time_cut < end_fi_o() - minus)], onset_snip,
                                            color=coloro)
                                    ax.plot(time_cut[(time_cut > (300 + start_fi_s())) & (time_cut < (300 + end_fi_s()))],
                                            steady_snip,
                                            color=colors)
                                    c_counter += 1
                                    print(show)
                                    fig = plt.gcf()
                                    fig.tag(axes[::-1], xoffs=-3.5, yoffs=1.8)
                                    save_visualization(
                                        data_name + '_idx_' + str(mt_idx) + '_contrast_' + str(contrasts[mt_idx]),
                                        show)
        print('finished plotting')
def plt_model_small(ax, pos_rel=-0.07, ls='--', lw=0.5, cell='2012-07-03-ak-invivo-1', colorx='black', colory='black'):
    """Plot the stored model susceptibility matrix (forward quadrant) of one cell on `ax`.

    Loads the forward- and reverse-quadrant stacks with ``get_stack_one_quadrant`` from
    pickles in the 'calc_model' folder, draws ``abs(stack)`` with ``plt_RAM_perc``, clips
    the colour range at the 95th percentile, adds label arrows, a colorbar outside the
    axes and white lines at x=0 and y=0.

    Parameters
    ----------
    ax : axes to draw into.
    pos_rel : offset of the label arrows (ypos of the x arrow, xpos of the y arrow).
    ls, lw : line style and width of the white zero lines.
    cell : cell name passed to the stack loaders.
    colorx, colory : colours of the x- and y-label arrows.

    Returns
    -------
    perc : the string 'perc' handed to ``plt_RAM_perc``.
    im : the image returned by ``plt_RAM_perc``.
    stack_final : the rescaled forward-quadrant stack that was plotted
        (it is NOT multiplied back by ``resize_val``).
    """
    cells_given = [cell]
    # This should now work with the files below.
    save_name_rev = load_folder_name(
        'calc_model') + '/' + 'calc_RAM_model-2__nfft_whole_power_1_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_1000000_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV_revQuadrant_'
    save_name = load_folder_name(
        'calc_model') + '/' + 'calc_RAM_model-2__nfft_whole_power_1_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_1000000_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV'
    cell_add, cells_save = find_cell_add(cells_given)
    perc = 'perc'
    path_rev = save_name_rev + '.pkl'  # '../'+
    path = save_name + '.pkl'  # '../'+
    # path_rev = 'model_full_nfft_whole_p_1_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TS_1000000_a_fr_1__TrialsNr_1__revQuadrant_2012-07-03-ak-invivo-1.csv'
    # path = 'model_full_nfft_whole_p_1_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TS_1000000_a_fr_1__TrialsNr_1_2012-07-03-ak-invivo-1.csv'
    #
    # The stack_saved value of the first call is overwritten by the second call.
    stack_rev, stack_saved = get_stack_one_quadrant(cell, cell_add, cells_save, path_rev, save_name_rev, redo=False,
                                                    creation_time_update=True, size_update=True)  # direct_load = True,
    stack, stack_saved = get_stack_one_quadrant(cell, cell_add, cells_save, path, save_name)  # direct_load = True
    # embed()
    # full_matrix combines both quadrants, but this small variant plots only the forward stack.
    full_matrix = create_full_matrix2(np.array(stack), np.array(stack_rev))
    stack_final = stack#get_axis_on_full_matrix(full_matrix, stack)
    stack_final, add_nonlin_title, resize_val = rescale_colorbar_and_values(stack_final,
                                                                            add_nonlin_title='k')  # , add_nonlin_title = 'k'
    # The unit prefix returned by the rescaling is discarded: the colorbar label gets an empty prefix.
    add_nonlin_title = ''
    im = plt_RAM_perc(ax, perc, np.abs(stack_final))
    # Upper colour limit at the 95th percentile of |stack_final|.
    set_clim_same([im], mats=[np.abs(stack_final)], lim_type='up', nr_clim='perc', clims='', percnr=95)
    set_xlabel_arrow(ax, xpos=1, ypos=pos_rel, color=colorx)
    set_ylabel_arrow(ax, xpos=pos_rel, ypos=0.97, color=colory)
    cbar, left, bottom, width, height = colorbar_outside(ax, im, add=5, width=0.01)
    cbar.set_label(nonlin_title(add_nonlin_title=' [' + add_nonlin_title), rotation=90, labelpad=8)
    ax.axhline(0, color='white', linewidth=lw, linestyle=ls)
    ax.axvline(0, color='white', linewidth=lw, linestyle=ls)
    return perc, im, stack_final
def plt_model_big(ax, pos_rel=-0.07, ls='--', lw=0.5, cell='2012-07-03-ak-invivo-1', colorx='black', colory='black', lines = True):
    """Plot the model susceptibility of one cell on the full (forward + reverse) frequency plane.

    Loads the forward and reverse-quadrant stacks (the forward one with ``redo=True``),
    merges them with ``create_full_matrix2``/``get_axis_on_full_matrix`` and draws
    ``abs(stack)`` with ``plt_RAM_perc``. The colour scale is clipped at the 95th
    percentile, and label arrows and a colorbar outside the axes are added.

    Parameters
    ----------
    ax : axes to draw into.
    pos_rel : offset of the label arrows (ypos of the x arrow, xpos of the y arrow).
    ls, lw : line style and width of the white zero lines.
    cell : cell name passed to the stack loaders.
    colorx, colory : colours of the x- and y-label arrows.
    lines : when True, draw white lines at x=0 and y=0.

    Returns
    -------
    perc : the string 'perc' handed to ``plt_RAM_perc``.
    im : the image returned by ``plt_RAM_perc``.
    stack_final : the plotted matrix multiplied back by ``resize_val`` (i.e. the unscaled values).
    stack_saved : second return value of the forward-quadrant ``get_stack_one_quadrant`` call.
    """
    cells_given = [cell]
    # This should now work with the files below.
    save_name_rev = load_folder_name(
        'calc_model') + '/' + 'calc_RAM_model-2__nfft_whole_power_1_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_1000000_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV_revQuadrant_'
    save_name = load_folder_name(
        'calc_model') + '/' + 'calc_RAM_model-2__nfft_whole_power_1_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TrialsStim_1000000_a_fr_1__trans1s__TrialsNr_1_fft_o_forward_fft_i_forward_Hz_mV'
    cell_add, cells_save = find_cell_add(cells_given)
    perc = 'perc'
    path_rev = save_name_rev + '.pkl'  # '../'+
    path = save_name + '.pkl'  # '../'+
    # path_rev = 'model_full_nfft_whole_p_1_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TS_1000000_a_fr_1__TrialsNr_1__revQuadrant_2012-07-03-ak-invivo-1.csv'
    # path = 'model_full_nfft_whole_p_1_RAM_additiv_cv_adapt_factor_scaled_cNoise_0.1_cSig_0.9_cutoff1_300_cutoff2_300no_sinz_length1_TS_1000000_a_fr_1__TrialsNr_1_2012-07-03-ak-invivo-1.csv'
    #
    stack_rev, stack_saved = get_stack_one_quadrant(cell, cell_add, cells_save, path_rev, save_name_rev, redo=False,
                                                    creation_time_update=True, size_update=True)  # direct_load = True,
    # NOTE(review): redo = True here forces the forward stack to be rebuilt on every call;
    # confirm this is intended and not a leftover from debugging.
    stack, stack_saved = get_stack_one_quadrant(cell, cell_add, cells_save, path, save_name, redo = True)  # direct_load = True
    #embed()
    full_matrix = create_full_matrix2(np.array(stack), np.array(stack_rev))
    stack_final = get_axis_on_full_matrix(full_matrix, stack)
    add_nonlin_title = ''
    stack_final, add_nonlin_title, resize_val = rescale_colorbar_and_values(stack_final,
                                                                            add_nonlin_title= add_nonlin_title)  # , add_nonlin_title = 'k'
    im = plt_RAM_perc(ax, perc, np.abs(stack_final))
    # Upper colour limit at the 95th percentile of |stack_final|.
    set_clim_same([im], mats=[np.abs(stack_final)], lim_type='up', nr_clim='perc', clims='', percnr=95)
    set_xlabel_arrow(ax, xpos=1, ypos=pos_rel, color=colorx)
    set_ylabel_arrow(ax, xpos=pos_rel, ypos=0.97, color=colory)
    cbar, left, bottom, width, height = colorbar_outside(ax, im, add=5, width=0.01)
    cbar.set_label(nonlin_title(add_nonlin_title=' [' + add_nonlin_title), rotation=90, labelpad=8)
    #embed()
    if lines:
        ax.axhline(0, color='white', linewidth=lw, linestyle=ls)
        ax.axvline(0, color='white', linewidth=lw, linestyle=ls)
    # Undo the display rescaling so the caller receives the original magnitudes.
    stack_final = stack_final * resize_val
    return perc, im, stack_final, stack_saved
def FI_curves_plot(contrast_labels, onset_function, onset_state, steady_function, steady_state,
                   label='Fitted function for the onset F-I Cure'):
    """Plot steady-state and onset F-I data with their fitted curves into the current axes.

    The steady-state points and fit are drawn by ``plt_FI_curve``. The onset points
    are drawn in black and fitted with ``onset_function``; the fit has three free
    parameters, the first bounded to ``[0, 4 * max(onset_state)]``.
    """
    # Steady-state data and fit.
    plt_FI_curve(contrast_labels, steady_function, steady_state)

    # Onset data and fit.
    plt.scatter(contrast_labels, onset_state, color='black')
    lower_bounds = [0, -np.inf, - np.inf]
    upper_bounds = [np.max(onset_state) * 4, np.inf, np.inf]
    onset_params, _ = optimize.curve_fit(onset_function, contrast_labels, onset_state,
                                         bounds=(lower_bounds, upper_bounds))
    onset_fit = onset_function(contrast_labels, onset_params[0], onset_params[1], onset_params[2])
    plt.plot(contrast_labels, onset_fit, label=label, color='black')

    # Axis cosmetics.
    plt.ylim([0, 720])
    plt.xlabel('Contrasts [%]', labelpad=10)
    plt.ylabel('Firing Frequency [Hz]')
    plt.legend(['Steady State Curve', 'Onset State Curve'])
def plt_FI_curve(contrast_labels, steady_function, steady_state, color='grey',
                 label='Fitted function for the steady F-I Cure'):
    """Scatter the steady-state F-I data (grey) and overlay a fit of ``steady_function``.

    The fit has three free parameters, the first bounded to ``[0, 4 * max(steady_state)]``.
    ``color`` and ``label`` apply to the fitted line only; the points are always grey.
    """
    plt.scatter(contrast_labels, steady_state, color='grey')
    lower_bounds = [0, -np.inf, - np.inf]
    upper_bounds = [np.max(steady_state) * 4, np.inf, np.inf]
    steady_params, _ = optimize.curve_fit(steady_function, contrast_labels, steady_state,
                                          bounds=(lower_bounds, upper_bounds))
    steady_fit = steady_function(contrast_labels, steady_params[0], steady_params[1], steady_params[2])
    plt.plot(contrast_labels, steady_fit, label=label, color=color)
def second_saturation_freq():
    """Return the fixed pair (39.5, -10.5) used as the second saturation frequency setting.

    NOTE(review): the meaning of the two numbers is not visible here; confirm with callers.
    """
    pair = (39.5, -10.5)
    return pair
def time_nonlin_first_sine(sampling=20000, f0=40, duration=0.1, amp=1):
    """Build a sine of frequency ``f0`` sampled at ``sampling`` for ``duration`` seconds.

    Returns
    -------
    (dt, f0, sine, t) : the sample interval ``1 / sampling``, the frequency passed in,
    ``amp * sin(2 pi f0 t)`` and the time axis ``t``.
    """
    step = 1 / sampling
    t = np.arange(0, duration, 1 / sampling)
    phase = t * 2 * np.pi * f0
    wave = np.sin(phase) * amp
    return step, f0, wave, t
def second_sine(f0, time_array, amp=1.25, phase=1):
    """Return ``amp * sin(2 pi f0 t + phase)`` evaluated on ``time_array``."""
    angle = time_array * 2 * np.pi * f0
    return amp * np.sin(angle + phase)
def circle_plot(ax, ax_prev, ws=None, lw=1.5):
    """Fill `ax` with a black square block symbol and hide all spines.

    Despite the name, a 20x20 rectangle is drawn. ``ax_prev`` and ``lw`` are unused;
    ``ws`` is resolved via ``ws_nonlin_systems()`` when not given, but not otherwise used.
    """
    if not ws:
        ws = ws_nonlin_systems()
    block = plt.Rectangle((0, 0), 20, 20, fc='black', ec="black")
    ax.add_patch(block)
    ax.show_spines('')
def rectangle_plot(ax, ax_prev, ws=None, lw=1.5):
    """Draw a black system block with an outgoing arrow and the title H{s(t)}.

    The arrow starts at the right edge of `ax` (axes fraction x=1) and extends by ``ws``.
    ``ws`` defaults to ``ws_nonlin_systems()``.
    """
    if not ws:
        ws = ws_nonlin_systems()
    ax.add_patch(plt.Rectangle((0, 0), 20, 20, fc='black', ec="black"))
    arrow_style = {"arrowstyle": "->", "linestyle": "-", "linewidth": lw, "color": 'black'}
    ax.annotate('', ha='center', xycoords='axes fraction',
                xy=(1 + ws, 0.5), textcoords='axes fraction',
                xytext=(1, 0.5),
                arrowprops=arrow_style,
                zorder=1, annotation_clip=False, transform=ax_prev.transAxes, )
    ax.set_title('$H\{s(t)\}$')
    ax.show_spines('')
def base_csvs_save(cell, frame=None, load_folder='calc_base'):
    """Export the baseline spike trains and EOD frequencies of one cell to a CSV file.

    Parameters
    ----------
    cell : cell name; selects the rows of ``frame`` with ``frame.cell == cell``.
    frame : overview DataFrame. When None or empty it is loaded from
        ``load_folder_name('calc_base')/calc_base_data-base_frame_nfftmedium__overview.pkl``.
    load_folder : directory the CSV is written to
        (``<load_folder>/base_csvs_save-spikesonly_<cell>.csv``).

    Returns
    -------
    frame_c : the rows of ``frame`` belonging to ``cell``.

    Raises
    ------
    ValueError : if 'freq_steps_medium' is not nested 1, 2 or 3 levels deep.
    """
    if frame is None or len(frame) < 1:
        path_sascha = load_folder_name('calc_base') + '/' + 'calc_base_data-base_frame_nfftmedium__overview.pkl'
        frame = pd.read_pickle(path_sascha)
    frame_c = frame[frame.cell == cell]

    frame_cell = pd.DataFrame()
    spikes_all, isi, frs_calc, spikes_cont = load_spikes(np.array(frame_c.spikes.iloc[0]), 1, ms_factor=1)
    # Longest spike train first, so it defines the number of CSV rows.
    spikes_all, pos_reshuffled = reshuffle_spike_lengths(spikes_all)
    save_spikestrains_several(frame_cell, spikes_all)
    frame_cell['fr'] = frame_c.fr.iloc[0]

    # The stored EODf / frequency-step entries are nested 1-3 levels deep; unwrap each
    # column by the nesting depth of 'freq_steps_medium'.
    depth = len(np.shape(frame_c['freq_steps_medium'].iloc[0]))
    if depth not in (1, 2, 3):
        raise ValueError('unexpected nesting depth %d of freq_steps_medium for cell %s' % (depth, cell))
    eodf_values = []
    for column in ['EODfs_medium', 'freq_steps_medium', 'EODfs', 'freq_steps_trial']:
        value = frame_c[column].iloc[0]
        for _ in range(depth - 1):
            value = value[0]
        eodf_values.append(value)
    names = ['EODf_res', 'freq_step_res', 'EODf', 'freq_step_trial']
    frame_cell = reshuffle_eodfs(frame_cell, names, pos_reshuffled, eodf_values)

    lim = find_lim_here(cell, 'individual')
    frame_cell['burst_corr_individual'] = float('nan')
    # Single .loc assignment. A chained `frame_cell[col].iloc[0] = lim` is silently
    # ignored under pandas Copy-on-Write.
    frame_cell.loc[frame_cell.index[0], 'burst_corr_individual'] = lim
    frame_cell['sampling'] = frame_c.sampling.iloc[0]
    frame_cell['cell'] = frame_c.cell.iloc[0]
    frame_cell['eod_fr'] = frame_c.EODf.iloc[0]

    frame_cell.to_csv(load_folder + '/base_csvs_save-spikesonly_' + cell + '.csv')
    return frame_c
def reshuffle_eodfs(frame_cell, names, pos_reshuffled, vars, res_name='arbitrary'):
    """Store per-trial EOD values in the first row of ``frame_cell``, reordered like the spike trains.

    Parameters
    ----------
    frame_cell : DataFrame to write into; values go into the row labelled 0.
    names : column base names, one per entry of ``vars``.
    pos_reshuffled : trial order (as produced by ``reshuffle_spike_lengths``).
    vars : per-trial value sequences, one per name.
    res_name : names containing this substring (or 'all') are stored without reordering.

    Returns
    -------
    frame_cell, with one column per key produced by ``resave_vars_corr``.
    """
    stack_sp = {}
    for v, var in enumerate(vars):
        if (res_name not in names[v]) and ('all' not in names[v]):
            try:
                var = np.array(var)[pos_reshuffled]
            except Exception:
                print('reshuffle thing')
        stack_sp = resave_vars_corr(names, res_name, stack_sp, v, var)
    for key, value in stack_sp.items():
        if key not in frame_cell.columns:
            # np.float was removed in NumPy 1.24; the builtin float is equivalent.
            frame_cell[key] = float('nan')
        try:
            # Single .loc call instead of the chained frame_cell[key].loc[0], which
            # does nothing under pandas Copy-on-Write.
            frame_cell.loc[0, key] = value
        except Exception:
            print('sequence thing')
            embed()
    return frame_cell
def save_spikestrains_several(frame_cell, spikes_all):
    """Write each spike train into its own column 'spikes<i>' of ``frame_cell`` (in place).

    A train shorter than the existing rows is written into the first ``len(train)``
    rows; the remaining rows are NaN. The frame is also returned for convenience.
    """
    for ss, sp in enumerate(spikes_all):
        column = 'spikes' + str(ss)
        try:
            frame_cell[column] = sp
        except (ValueError, TypeError):
            # Length mismatch with the existing rows: pad with NaN and fill the head.
            # A single .loc call is used because the chained `frame_cell[col].iloc[...] = sp`
            # is a no-op under pandas Copy-on-Write.
            frame_cell[column] = float('nan')
            frame_cell.loc[frame_cell.index[:len(sp)], column] = sp
            print('spikes something')
    return frame_cell
def reshuffle_spike_lengths(spikes_all):
    """Order spike trains from longest to shortest.

    Returns
    -------
    (spikes, order) : the reordered trains as a NumPy array and the permutation used
    (indices into ``spikes_all``). Equal-length trains give a regular 2-D array;
    trains of different lengths give a 1-D object array of trains.
    """
    lengths = [len(train) for train in spikes_all]
    pos_reshuffled = np.argsort(lengths)[::-1]
    try:
        spikes_array = np.array(spikes_all)
    except ValueError:
        # NumPy >= 1.24 refuses to build ragged arrays implicitly; build an object array.
        spikes_array = np.empty(len(spikes_all), dtype=object)
        for index, train in enumerate(spikes_all):
            spikes_array[index] = train
    return spikes_array[pos_reshuffled], pos_reshuffled
def rename(model_folder, dir_prev, dir_new, function_name='calc_phaselocking-'):
    """Rename ``model_folder/dir_prev`` to ``model_folder/<function_name><dir_new>`` once.

    If ``dir_prev`` already contains ``function_name`` (and the prefix is non-empty),
    nothing is renamed and the prefixed path of ``dir_prev`` is reported.

    Returns
    -------
    (not_renamed, change_to) : ``not_renamed`` is False only when a rename was attempted
    and did not raise; ``change_to`` is the target path.
    """
    # Guard against prefixing the same directory twice.
    already_prefixed = (function_name in dir_prev) and (function_name != '')
    if already_prefixed:
        return True, model_folder + '/' + function_name + dir_prev
    change_to = model_folder + '/' + function_name + dir_new
    if os.path.exists(change_to):
        print('already there:', model_folder + '/' + dir_prev, 'to', change_to)
        return True, change_to
    try:
        os.rename(model_folder + '/' + dir_prev, change_to)
    except:
        print('some problem renaming')
        embed()
    return False, change_to
def title_motivation():
    """Return the panel titles: receiver alone, with f1, with f2, with both, and an empty slot."""
    core = f_eod_name_core_rm()
    receiver = r'$' + core + '$'
    return [
        f_eod_name_rm(),
        receiver + r' \& $f_{1}$',
        receiver + r' \& $f_{2}$',
        receiver + r' \& $f_{1}$ \& $f_{2}$',
        [],
    ]
def rem_variable(rm_var=None):
    """Return the label-formatting settings.

    ``rm`` selects the roman-subscript form of LaTeX labels; ``size`` is returned as is.
    A fresh default dict is built on every call, so a caller mutating the result
    cannot change the defaults seen by later calls.
    """
    if rm_var is None:
        rm_var = {'rm': True, 'size': 'small'}
    return rm_var
def f_eod_name_rm():
    """Return the EOD-frequency symbol wrapped in $...$ for matplotlib mathtext."""
    core = f_eod_name_core_rm()
    return '$' + core + '$'
def f_eod_name_core_rm():
    """Return the LaTeX core (without $ delimiters) of the EOD-frequency symbol.

    The roman-subscript variant is used when ``rem_variable()['rm'] == True``.
    """
    use_roman = rem_variable()['rm'] == True
    return r'f\rm{_{EOD}}' if use_roman else r'f_{EOD}'
def exp_params(exp_tau, exponential, v_exp):
    """Return the ``(exp_tau, v_exp)`` parameter grid for an exponential model variant.

    '' -> (0.001, 1); 'EIF' -> three tau values with v_exp 0; 'CIF' -> tau 0 with a
    range of v_exp values. Any other ``exponential`` returns the inputs unchanged.
    """
    if exponential == '':
        return 0.001, 1
    if exponential == 'EIF':
        return np.array([0.001, 0.01, 0.1]), np.array([0])
    if exponential == 'CIF':
        return np.array([0]), np.array([0, 0.5, 1, 1.5, 2, 0.2, -0.5, -1])
    return exp_tau, v_exp
def filter_square_params(c_grouped, cell_here, frame, frame_cell_orig):
    """Build the df2 x df1 matrix of summed 'diff' values for one cell.

    Parameters
    ----------
    c_grouped : sequence whose first element is passed to ``find_deltas``/``find_diffs``.
    cell_here : cell name used to filter ``frame``.
    frame : DataFrame with a 'cell' column.
    frame_cell_orig : unused; kept for signature compatibility.

    Returns
    -------
    (frame_cell, matrix) : the processed rows of the cell and the pivoted matrix
    (index 'df2', columns 'df1').
    """
    frame_cell = frame[(frame.cell == cell_here)]
    frame_cell, df1s, df2s, f1s, f2s = find_dfs(frame_cell)
    diffs = find_deltas(frame_cell, c_grouped[0])
    frame_cell = find_diffs(c_grouped[0], frame_cell, diffs, add='_original')
    # numeric_only=True: pandas >= 2.0 would otherwise try to sum text columns such as 'cell'.
    new_frame = frame_cell.groupby(['df1', 'df2'], as_index=False).sum(numeric_only=True)
    matrix = new_frame.pivot(index='df2', columns='df1', values='diff')
    return frame_cell, matrix
def plt_square_here3(ax, frame, score_here, c_nr=0.1, cls="RdBu_r", c_here='c1'):
    """Draw a df1 x df2 colour map of ``score_here`` at the contrast closest to ``c_nr``.

    Rows with ``f1 == f2`` are excluded and duplicates per (df1, df2) are averaged.

    Returns
    -------
    (im, matrix) : the QuadMesh from ``pcolormesh`` and the pivoted matrix.
    """
    cs = frame[c_here].unique()
    c_chosen = cs[np.argmin(np.abs(cs - c_nr))]
    frame_cell = frame[(frame[c_here] == c_chosen)]
    frame_cell = frame_cell[~ (frame_cell.f1 == frame_cell.f2)]
    frame_cell, df1s, df2s, f1s, f2s = find_dfs(frame_cell)
    # numeric_only=True: pandas >= 2.0 raises TypeError on text columns in mean().
    new_frame = frame_cell.groupby(['df1', 'df2'], as_index=False).mean(numeric_only=True)
    matrix = new_frame.pivot(index='df2', columns='df1', values=score_here)
    try:
        im = ax.pcolormesh(
            np.array(list(map(float, matrix.columns))), np.array(matrix.index),
            matrix,
            cmap=cls,
            rasterized=False)
    except Exception:
        print('ims probelem')
        embed()
        # Re-raise so the caller sees the real error instead of an unbound `im`.
        raise
    return im, matrix
def roc_filename2():
    """Return the stored ROC-contrast result name for the 1000-trial model run."""
    stem = ('calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_'
            'LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_')
    return stem + '1000' + '_mult_minimum_1temporal'
def roc_filename1():
    """Return the stored ROC-contrast result name for the 100-trial model run."""
    stem = ('calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_'
            'LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_')
    return stem + '100' + '_mult_minimum_1temporal'
def roc_filename0():
    """Return the stored ROC-contrast result name for the 20-trial model run."""
    stem = ('calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_'
            'LenNrs_20_1Nrs_0.0001_LastNrs_0.1_len_25_nfft_32768_trialsnr_')
    return stem + '20' + '_mult_minimum_1temporal'
# Alternative naming variant (LastNrs_1.0):
# 'calc_ROC_contrasts-ROCmodel_contrasts1_diagonal1_FrF1rel_0.33_FrF2rel_0.67_C2_0.1_LenNrs_20_1Nrs_0.0001_LastNrs_1.0_len_25_nfft_32768_trialsnr_' + trial_nr + '_mult_minimum_1temporal',
def isi_xlabel():
    """Return the ISI axis label '1/$<EOD frequency symbol>$'."""
    return '1/$%s$' % f_eod_name_core_rm()
def get_global_eod_for_eodf(b, duration, mt, mt_nr_small):
    """Load the EOD trace starting at tag position ``mt_nr_small`` for ``duration`` seconds.

    The 'EOD' array is tried first, then 'EOD-1'.

    Returns
    -------
    (global_eod, sampling) : the trace and its sampling rate, or ``([], 0)`` when
    neither array could be read.
    """
    # Defaults for the failure path; previously global_eod was left unbound there,
    # so the return raised UnboundLocalError.
    global_eod = []
    sampling = 0
    try:
        global_eod, sampling = link_arrays_eod(b, mt.positions[:][mt_nr_small], duration, 'EOD')
    except Exception:
        try:
            global_eod, sampling = link_arrays_eod(b, mt.positions[:][mt_nr_small], duration,
                                                   'EOD-1')
        except Exception:
            b_names = get_data_array_names(b)
            for b_name in b_names:
                if 'eod' in b_name:
                    print('EOD 1 prob')
                    embed()
            global_eod = []
            sampling = 0
    return global_eod, sampling
def get_eodf_here(b, eodf_orig, global_eod, mt_nr_small, nfft_eod, sampling):
    """Estimate the EOD frequency of a trace, using the prior for tag ``mt_nr_small``.

    Parameters
    ----------
    b : data block forwarded to ``get_eodf``.
    eodf_orig : prior EOD frequency; either a sequence/Series indexed by tag
        number or a single scalar.
    global_eod : EOD samples.
    mt_nr_small : index of the tag whose prior is used.
    nfft_eod : FFT length for the PSD estimate.
    sampling : sampling rate of ``global_eod``.

    Returns
    -------
    The estimated EOD frequency, or NaN when ``sampling`` is not positive, the trace
    is 0.1 s or shorter, or the estimate failed.
    """
    # `not (x > 0)` also rejects a NaN sampling rate, matching the original nested checks.
    if not (sampling > 0) or not (len(global_eod) / sampling > 0.1):
        return float('nan')
    try:
        # np.ndim guards against scalar priors, for which len() would raise.
        if np.ndim(eodf_orig) > 0 and len(eodf_orig) > 0:
            try:
                eod_fr_orig = eodf_orig.iloc[mt_nr_small]
            except Exception:
                eod_fr_orig = eodf_orig[mt_nr_small]
        else:
            eod_fr_orig = eodf_orig
        eod_fr, p, f = get_eodf(global_eod, b, eod_fr_orig, nfft_eod=nfft_eod)
    except Exception:
        print('still eodf problem')
        embed()
        # Previously eod_fr stayed unbound here and the return raised UnboundLocalError.
        eod_fr = float('nan')
    return eod_fr
def get_eodf(global_eod, b, eodf_orig, nfft_eod=2 ** 16):
    """Estimate the EOD frequency from a trace's power spectrum, cross-checked against a prior.

    Parameters
    ----------
    global_eod : EOD samples; may be empty.
    b : data block passed to ``get_sampling`` to obtain the EOD sampling rate.
    eodf_orig : prior EOD frequency; a falsy or NaN value disables the cross-check.
    nfft_eod : FFT length of the PSD estimate.

    Returns
    -------
    (eod_fr, p, f) : the frequency estimate plus the PSD power and frequency arrays.
    ``p`` and ``f`` are None when ``global_eod`` is empty.
    """
    if len(global_eod) > 0:
        sampling_rate = get_sampling(b, load_eod_array='EOD')
        eod_fr, p, f = calc_freq_from_psd(global_eod, sampling_rate, nfft=nfft_eod)
    else:
        # No trace: fall back to the prior estimate, or NaN when there is none.
        p = None
        f = None
        eod_fr = eodf_orig if eodf_orig else float('nan')
    if eodf_orig and not np.isnan(eodf_orig):
        if (np.max(np.array(eodf_orig)) > 300) & (np.min(np.array(eodf_orig)) < 1000):
            eodf_orig = np.array(eodf_orig)
            if np.abs(eod_fr - eodf_orig) > 25:
                # The PSD estimate disagrees with the prior: check whether the result is
                # stable across several FFT lengths.
                sampling_rate = get_sampling(b, load_eod_array='EOD')
                frequency_saves, frequency = nfft_improval(sampling_rate, eodf_orig,
                                                           global_eod, eod_fr)
                if np.min(np.abs(frequency_saves - eodf_orig)) > 300:
                    print('EODf diff too big')
                # If any FFT length lands within 10 Hz of the prior, take the closest estimate.
                if np.min(np.abs(frequency_saves - eodf_orig)) < 10:
                    eod_fr = frequency_saves[np.argmin(np.abs(frequency_saves - eodf_orig))]
                # Estimates outside 350-1000 Hz are implausible: fall back to the prior.
                # `|` is required here; a value can never be both > 1000 and < 350.
                if (eod_fr > 1000) | (eod_fr < 350):
                    eod_fr = eodf_orig
    return eod_fr, p, f
def calc_freq_from_psd(noise_eod, eod_sampling_frequency, nfft=2 ** 16):
    """Return the frequency of the PSD peak of a mean-removed signal, plus the PSD itself.

    The PSD uses segments of ``nfft`` samples with 50 % overlap.

    Returns
    -------
    (peak_frequency, p, f)
    """
    centered = noise_eod - np.mean(noise_eod)
    power, freqs = ml.psd(centered, Fs=eod_sampling_frequency, NFFT=nfft, noverlap=nfft // 2)
    peak_frequency = freqs[np.argmax(power)]
    return peak_frequency, power, freqs
def nfft_improval(sampling_rate, frequency_orig, global_stimulus, frequency,
                  nffts=[2 ** 16, 2 ** 15, 2 ** 14, 2 ** 13]):
    """Re-estimate a peak frequency with several FFT lengths.

    This exists because of one cell: the estimate closest to the prior is sought.

    Returns
    -------
    (frequency_saves, frequency) : all per-nfft estimates, and the last estimate within
    20 Hz of ``frequency_orig`` (or the input ``frequency`` if none was that close).
    """
    frequency_saves = []
    for nfft_here in nffts:
        estimate = calc_freq_from_psd(global_stimulus, sampling_rate, nfft=nfft_here)[0]
        if np.abs(frequency_orig - estimate) < 20:
            frequency = estimate
        frequency_saves.append(estimate)
    return frequency_saves, frequency
def get_nffts_medium(baseline_eod_long, nfft, sampling):
    """Return the PSD peak frequency of a long baseline trace with its frequency resolution.

    Returns
    -------
    (eod_fr_long, freq_step_maximal, maximal_length) : peak frequency for ``nfft``,
    ``sampling / len(trace)`` and ``len(trace)``.
    """
    step, length = get_freq_step(baseline_eod_long, sampling)
    peak = calc_freq_from_psd(baseline_eod_long, sampling, nfft=nfft)[0]
    return peak, step, length
def get_freq_step(baseline_eod_long, sampling):
    """Return ``(sampling / len(trace), len(trace))``, the spectral resolution of the whole trace."""
    length = len(baseline_eod_long)
    return get_freq_steps(length, sampling), length
def get_freq_steps(maximal_length, sampling):
    """Return the frequency resolution ``sampling / maximal_length`` of a length-N spectrum."""
    return sampling / maximal_length
def find_eod_fr_mt(b, mts, extended=False, indices=[], freq_step_nfft_eod=0.6103515625):
    """Collect EOD-frequency estimates for the tags of a multi-tag ``mts`` in block ``b``.

    Prior EODf values come from an 'EODf' feature of ``mts`` if present; otherwise
    ``find_eod_fr_orig`` is used. ``find_eodf_three`` then computes the per-tag
    estimates. With ``extended=True``, two more estimates are made: one from the EOD
    trace spanning all selected tags, and one from a trace long enough for a 0.01 Hz
    resolution.

    Returns
    -------
    (eod_frs, eod_frs_orig, eod_fr_medium, freq_step_medium, freq_step_mts, eod_fr_mts,
     freq_steps_single). The four 'medium'/'mts' values are empty lists unless ``extended``.

    NOTE(review): ``indices`` defaults to a shared list; it is only read here.
    NOTE(review): with ``extended=True`` an empty ``indices`` makes ``np.argmin`` fail — confirm callers always pass indices then.
    """
    # Some of the arrays do not have this information.
    # First load what is already stored, then verify it against the power spectrum.
    # We simply use these values, because having them for every tag is difficult.
    eod_redo, eod_frs_orig_b = find_eod_fr_orig(mts, b, mts.positions[:])
    if len(eod_frs_orig_b) != len(mts.positions[:]):
        # One value per tag is needed; otherwise fall back to the metadata value.
        eod_redo, eod_frs_orig_b = find_eod_fr_orig(mts, b, mts.positions[:], type='redo')
    # data_array_names = get_data_array_names(b)
    features, delay_name = feature_extract(mts)
    eod_frs_orig = []
    for feat in features:
        if 'EODf' in feat:
            print(True)
            # Both branches are identical; the [indices] selection is commented out.
            if len(indices) > 0:
                eod_frs_orig = mts.features[feat].data[:]  # [indices]
            else:
                eod_frs_orig = mts.features[feat].data[:]  # [indices]
    if len(eod_frs_orig) < 1:  # | (np.sum(np.isnan(np.array(eod_frs_orig_b)))> 0)
        # No 'EODf' feature: use the values from find_eod_fr_orig.
        if len(indices) > 0:
            eod_frs_orig = np.array(eod_frs_orig_b)  # [indices]
        else:
            eod_frs_orig = eod_frs_orig_b
    eod_frs, eodf_orig, freq_steps_single = find_eodf_three(b, mts, eod_frs_orig, mt_idx=indices,
                                                            freq_step_nfft_eod=freq_step_nfft_eod, max_eod=True)
    eod_fr_medium = []
    freq_step_medium = []
    freq_step_mts = []
    eod_fr_mts = []
    if extended:
        ##################################
        # EODf over all selected tags: one trace from the earliest selected tag to the
        # end of the latest selected tag.
        print('doing the rest')
        mt_poss = concat_mts_pos(mts)  # [indices]
        mt_poss_ind = concat_mts_pos(mts)[indices]
        min_pos = mt_poss_ind[np.argmin(mt_poss_ind)]
        mt_nr_small = np.where(mt_poss == min_pos)[0][0]
        start_pos = mts.positions[:][mt_nr_small]
        max_pos = mt_poss_ind[np.argmax(mt_poss_ind)]
        mt_nr_max = np.where(mt_poss == max_pos)[0][0]
        duration = (np.max(mt_poss_ind) + mts.extents[:][mt_nr_max]) - start_pos
        global_eod, sampling = get_global_eod_for_eodf(b, duration, mts, mt_nr_small)
        # nfft equal to the trace length -> resolution sampling / len(trace).
        nfft_eod = len(global_eod)
        eod_fr_mts = get_eodf_here(b, eodf_orig, global_eod, mt_nr_small, nfft_eod, sampling)
        freq_step_mts, maximal_length = get_freq_step(global_eod, sampling)
        ##################################
        # EODf with a higher frequency resolution.
        # TODO: still use the desired resolution here.
        freq_step = 0.01
        # A resolution of freq_step Hz needs sampling / freq_step samples.
        nfft = int(np.round(sampling / freq_step))
        baseline_eod_long = link_arrays_eod(b, first=start_pos, second=nfft / sampling,
                                            array_name='EOD')[0]
        eod_fr_medium, freq_step_medium, maximal_length = get_nffts_medium(baseline_eod_long,
                                                                           nfft, sampling)
    return eod_frs, eod_frs_orig, eod_fr_medium, freq_step_medium, freq_step_mts, eod_fr_mts, freq_steps_single
def find_eod_fr_orig(mts, b, mt_length, type='SAM'):
    """Return prior EOD frequencies for the tags of ``mts``.

    With ``type == 'SAM'`` and a '<mts.name>_EOD Rate' data array present, that array is
    returned (``eod_redo`` False). Otherwise the recording's metadata
    'Recording/Subject/EOD Frequency' (NaN if missing) is repeated ``len(mt_length)``
    times (``eod_redo`` True).

    Returns
    -------
    (eod_redo, eod_frs)
    """
    names = get_data_array_names(b)
    rate_array_name = mts.name + '_EOD Rate'
    if rate_array_name in names and type == 'SAM':
        return False, b.data_arrays[rate_array_name][:]  # sinewave-1
    try:
        subject_eodf = b.metadata['Recording']['Subject']['EOD Frequency']
    except:
        subject_eodf = float('nan')
    return True, [subject_eodf] * len(mt_length)
def concat_mts(indices, mt):
    """Return the extents of the selected tags of ``mt`` as a flat array.

    A nested (2-D) selection is concatenated into one dimension; a 1-D selection is
    returned as is.
    """
    selected = mt.extents[:][indices]
    if len(np.shape(selected)) <= 1:
        return selected
    try:
        return np.concatenate(selected)
    except Exception:
        print('still some shape mt problems')
        embed()
        # Re-raise the real error instead of returning an unbound variable.
        raise
def concat_mts_pos(mts):
    """Return all tag positions of ``mts`` as a flat array.

    A nested (2-D) positions array is concatenated; a 1-D one is returned as is.
    """
    positions = mts.positions[:]
    if len(np.shape(positions)) <= 1:
        return positions
    try:
        return np.concatenate(positions)
    except Exception:
        print('still some shape mt problems')
        embed()
        # Re-raise the real error instead of returning an unbound variable.
        raise
def resave_vars_eodfs(names, stack_sp, vars, res_name='res'):
    """Store every entry of ``vars`` into ``stack_sp`` via ``resave_vars_corr`` and return it."""
    for position, value in enumerate(vars):
        stack_sp = resave_vars_corr(names, res_name, stack_sp, position, value)
    return stack_sp
def resave_vars_corr(names, res_name, stack_sp, v, var):
    """Store ``var`` in ``stack_sp`` under ``names[v]``.

    Names containing ``res_name`` or 'all' get ``var`` unchanged. Any other name gets
    one entry per trial ('<name>0', '<name>1', ...) plus the mean under '<name>'.
    """
    name = names[v]
    if res_name in name or 'all' in name:
        stack_sp[name] = var
        return stack_sp
    for trial_index, trial_value in enumerate(var):
        stack_sp[name + str(trial_index)] = trial_value
    stack_sp[name] = np.mean(var)
    return stack_sp
def names_eodfs():
    """Return the column names used for stored EOD-frequency values and frequency steps."""
    eodf_names = ['EODf', 'EODf_all', 'EODf_res']
    step_names = ['freq_step_trial', 'freq_step_res', 'freq_step_all']
    return eodf_names + step_names
def first_saturation_freq():
    """Return the fixed pair (20.5, -300.5) used as the first saturation frequency setting.

    A previous value was (39.5, -210.5).
    NOTE(review): the meaning of the two numbers is not visible here; confirm with callers.
    """
    pair = (20.5, -300.5)
    return pair
def plt_FI_data_alex(cell, data, model, sort_data, sort_model):
    """Plot the normalised onset F-I curve of one cell's data, with the data points.

    The curve comes from ``interp_fi`` and is min-max normalised. The points are
    normalised with the same offset and range.

    Returns
    -------
    colors : list of colour names for subsequent model curves.
    """
    plt.title(cell)
    x, d, y, ymax, ymin, _ = interp_fi(model['inputs'].iloc[0][sort_model], data['f_onset'][sort_data])
    plt.plot(x, d, color='black', linewidth=2, zorder=2, label='data')
    # Both axes are indexed with sort_data so each contrast is paired with its own
    # onset rate (sort_model orders the model inputs, not the data rows).
    plt.scatter(data['contrasts'][sort_data], (data['f_onset'][sort_data] - ymin) / ymax,
                color='black', s=7, zorder=2)
    colors = ['green', 'blue', 'orange', 'pink', 'purple', 'red']
    plt.xlabel('Contrast [%]')
    plt.ylabel('Firing Rate [Hz]')
    return colors
def load_fi_curves_alex(cell, df_cell, n, nn, results):
    """Compare the F-I curves of model variant ``n`` with the recorded data of ``cell``.

    Appends a dict to ``results`` holding 'cell', 'n' and, for every model curve
    ('on_r', 'on_s' vs. data onset; 'ss_r', 'ss_s' vs. data steady state), the four
    errors from ``mse`` under the keys '<curve>', '<curve>_fmax', '<curve>_k' and
    '<curve>half'.

    Parameters
    ----------
    cell : cell name; data are read from '../data/Kennlinien/cell_fi_curves_csvs/<cell>.csv'.
    df_cell : model DataFrame with column 'n' and per-row arrays 'inputs', 'on_r', 'on_s',
        'ss_r', 'ss_s'.
    n : model variant to select.
    nn : passed through unchanged.
    results : list that is appended to.

    Returns
    -------
    (data, model, n, nn, sort_data, sort_model). ``n`` and ``nn`` are the values
    passed in; previously the inner loops overwrote them with 'half' and 3.
    """
    results.append({})
    results[-1]['cell'] = cell
    results[-1]['n'] = n
    model = df_cell[df_cell['n'] == n]
    data = pd.read_csv('../data/Kennlinien/cell_fi_curves_csvs/' + cell + '.csv')
    sort_data = np.argsort(data['contrasts'])
    sort_model = np.argsort(model['inputs'].iloc[0])
    inputs = model['inputs'].iloc[0][sort_model]
    suffixes = ['', '_fmax', '_k', 'half']
    # (model column, matching data column)
    comparisons = [('on_r', 'f_onset'), ('on_s', 'f_onset'),
                   ('ss_r', 'f_steady_state'), ('ss_s', 'f_steady_state')]
    for model_column, data_column in comparisons:
        errors = mse(inputs, data[data_column][sort_data], model[model_column].iloc[0][sort_model])
        for suffix, error in zip(suffixes, errors):
            results[-1][model_column + suffix] = error
    return data, model, n, nn, sort_data, sort_model
def interp_fi(xdata, ydata):
    """Fit a Boltzmann function to F-I data and evaluate it on a dense grid.

    The grid has 1000 points from ``1.1 * xdata[0]`` to ``1.1 * xdata[-1]``. The first
    fit parameter is bounded to ``[0, 4 * max(ydata)]``. If anything in the fit fails,
    all curve outputs are NaN and the parameters are ``[nan, nan, nan]``.

    Returns
    -------
    (x, y_normalised, y, y_range, y_min, popt) where ``y_normalised`` is
    ``(y - min(y)) / range``.
    """
    nan = float('nan')
    try:
        popt, pcov = curve_fit(bolzmann, xdata, ydata, bounds=(
            [0, - np.inf, - np.inf], [np.max(ydata) * 4, np.inf, np.inf]))
        grid = np.linspace(xdata[0] * 1.1, xdata[-1] * 1.1, 1000)
        fitted = bolzmann(grid, *popt)
        shifted = fitted - np.min(fitted)
        normalised = shifted / np.max(shifted)
    except:
        popt = [nan, nan, nan]
        grid = nan
        fitted = nan
        shifted = nan
        normalised = nan
    return grid, normalised, fitted, np.max(shifted), np.min(fitted), popt
def mse(x, data, model):
    """Compare Boltzmann fits of ``data`` and ``model`` over the same inputs ``x``.

    Returns
    -------
    (curve_mse, d_param0, d_param1, d_param2) : the mean squared difference of the
    normalised fitted curves, and the squared difference of each fit parameter.
    """
    data_fit = interp_fi(x, data)
    model_fit = interp_fi(x, model)
    data_curve, data_params = data_fit[1], data_fit[5]
    model_curve, model_params = model_fit[1], model_fit[5]
    curve_error = np.mean((data_curve - model_curve) ** 2)
    param_errors = tuple((data_params[i] - model_params[i]) ** 2 for i in range(3))
    return (curve_error,) + param_errors
def onset_function(x, f_max, k, I_half):
    """Logistic onset F-I curve: ``f_max / (1 + exp(-k (x - I_half)))``."""
    exponent = -k * (x - I_half)
    return f_max / (1 + np.exp(exponent))
def steady_function(x, f_max, k, I_half):
    """Logistic steady-state F-I curve: ``f_max / (1 + exp(-k (x - I_half)))``."""
    exponent = -k * (x - I_half)
    return f_max / (1 + np.exp(exponent))
def start_fi_o():
    """Return the start (7, in ms on the stimulus-onset-aligned time axis) of the onset-rate window."""
    onset_window_start = 7
    return onset_window_start
def end_fi_o():
    """Return the end (55, in ms on the stimulus-onset-aligned time axis) of the onset-rate window."""
    onset_window_end = 55
    return onset_window_end
def start_fi_s():
    """Return the steady-state window start offset (30 ms; callers add 300 ms)."""
    steady_window_start = 30
    return steady_window_start
def end_fi_s():
    """Return the steady-state window end offset (90 ms; callers add 300 ms)."""
    steady_window_end = 90
    return steady_window_end
def trial_nrs_ram_model():
    """Return the trial counts used for the RAM model runs, as a NumPy array."""
    small_counts = [9, 11, 20, 30, 100, 500, 1000, 10000]
    large_counts = [100000, 250000, 500000, 750000, 1000000]
    return np.array(small_counts + large_counts)
def colors_suscept_paper_dots():
    """Return the dot colours as (color01, color012, color01_2, color02, color0_burst, color0)."""
    return (
        'green',      # color01
        'orange',     # color012
        'red',        # color01_2
        'purple',     # color02
        'darkgreen',  # color0_burst
        'blue',       # color0
    )
def plt_voltage_trace(cell, eod_fr, frame_cell, axs, lim_here, test, spikes_plotted_lower=True, spikes_plotted=True,
                      dir='', scaling_factor=1, color_trace='grey', color_first_spike='black',
                      color_second_spike='blue', xlim=0.400):
    """Plot the baseline voltage trace of a cell with its spikes marked, and return its spike trains.

    Opens '<dir>../data/cells/<cell>/<cell>.nix', takes the longest baseline tag, and plots
    the 'V-1' trace over the first ``xlim`` seconds (time scaled by ``scaling_factor``).
    Spikes are marked at the trace maximum (``spikes_plotted``) and/or its 90th percentile
    (``spikes_plotted_lower``). If any inter-spike interval is shorter than 1.5 EOD periods,
    burst-corrected spikes from ``correct_burstiness`` (with ``lim_here``) are overlaid
    in ``color_second_spike``.

    Parameters
    ----------
    eod_fr : EOD frequency used to express ISIs in EOD periods.
    frame_cell : unused in this function.
    axs : axes to draw into.
    test : if True, also draw the trace and spikes into the current pyplot figure.

    Returns
    -------
    spike_times_all_full : list with the spike times of the whole baseline tag
        (empty if the file or a suitable tag is missing).
    """
    spike_times_all_full = []
    if os.path.exists(dir + '../data/cells/' + cell + '/' + cell + '.nix'):
        f, nix_exists, nix_missing = load_f(['cells'], 0, cell, dir=dir)
        if nix_exists:
            b = f.blocks[0]
            cont_baseline, nix_missing, ts = find_tags_baseline(b, nix_missing)
            if cont_baseline & (len(ts) > 0):
                spike_times_all = []
                data_array_names = get_data_array_names(b)
                if 'eod' in ''.join(data_array_names).lower():
                    # Use the baseline tag with the longest extent.
                    lengths = []
                    for t in ts:
                        lengths.append(t.extent[:][0])
                    tag = ts[np.argmax(lengths)]
                    add_mt = True
                    if add_mt:
                        if 'base' in tag.name.lower():
                            print(tag.name)
                            tag_here = b.tags[tag.name]  # 2/3
                            if len(tag_here.extent[:]) > 0:
                                duration = restrict_base_durationts(tag_here.extent[:][0])
                                # NOTE(review): both branches assign xlim, so `duration` has no
                                # effect; confirm whether the else branch should use `duration`.
                                if duration < xlim:
                                    duration_base = xlim
                                else:
                                    duration_base = xlim
                                # Spike times relative to the tag start: the plotted window and the full tag.
                                spike_times = link_arrays_spikes(b, first=tag.position[:][0],
                                                                 second=duration_base,
                                                                 minus_spikes=tag.position[:][0])
                                spike_times_full = link_arrays_spikes(b, first=tag.position[:][0],
                                                                      second=tag.extent[:][0],
                                                                      minus_spikes=tag.position[:][0])
                                spike_times_all.append(spike_times)
                                spike_times_all_full.append(spike_times_full)
                                eods_g, sampling = link_arrays_eod(b, first=tag.position[:][0],
                                                                   second=duration_base,
                                                                   array_name='V-1')
                                axs.plot(np.arange(0, len(eods_g) / sampling, 1 / sampling) * scaling_factor,
                                         eods_g, linewidth=0.8, color=color_trace)
                                if spikes_plotted:
                                    if len(spike_times) > 0:
                                        # Mark spikes inside the window at the trace maximum.
                                        spikes_here = spike_times
                                        spikes_here = spikes_here[spikes_here < xlim]
                                        axs.scatter(spikes_here * scaling_factor,
                                                    np.percentile(eods_g, 100) * np.ones(len(spikes_here)),
                                                    color=color_first_spike, clip_on=False)
                                if spikes_plotted_lower:
                                    if len(spike_times) > 0:
                                        axs.scatter(spike_times * scaling_factor,
                                                    np.percentile(eods_g, 90) * np.ones(len(spike_times)),
                                                    color=color_first_spike)
                                # ISIs in units of EOD periods.
                                hists2 = [(np.diff(spike_times) / (1 / eod_fr))]
                                if len(hists2[0]) > 0:
                                    try:
                                        np.min(hists2) < 1.5
                                    except:
                                        print('hist thing')
                                        embed()
                                    # Bursty cell: overlay the burst-corrected spikes.
                                    if np.min(hists2[0]) < 1.5:
                                        burst_corr = '_burstIndividual_'
                                        hists2, spikes_ex, frs_calc2 = correct_burstiness(hists2, [spike_times],
                                                                                          [eod_fr] * len(
                                                                                              [spike_times]),
                                                                                          [eod_fr] * len(
                                                                                              [spike_times]),
                                                                                          lim=lim_here,
                                                                                          burst_corr=burst_corr)
                                        if spikes_plotted:
                                            try:
                                                spikes_here = spikes_ex[0][spikes_ex[0] < xlim]
                                                axs.scatter(spikes_here * scaling_factor,
                                                            np.percentile(eods_g, 100) * np.ones(len(spikes_here)),
                                                            clip_on=False,
                                                            color=color_second_spike)
                                            except:
                                                print('scatter thing')
                                                embed()
                                        if spikes_plotted_lower:
                                            axs.scatter(spikes_ex[0] * scaling_factor,
                                                        np.percentile(eods_g, 90) * np.ones(len(spikes_ex[0])),
                                                        color=color_second_spike)
                                axs.set_xlim(0, xlim * scaling_factor)
                                if test:
                                    plt.plot(np.arange(0, len(eods_g) / sampling, 1 / sampling) * scaling_factor,
                                             eods_g, color=color_first_spike)
                                    plt.scatter(spike_times * scaling_factor,
                                                np.percentile(eods_g, 90) * np.ones(len(spike_times)),
                                                color=color_first_spike)
            if len(ts) < 1:
                axs.set_title('no nix')
    return spike_times_all_full
def find_tags_baseline(b, cont_rlx):
    """Look up the 'baseline' tags of block ``b``, retrying once on failure.

    Returns
    -------
    (cont_baseline, cont_rlx, ts) : ``(True, cont_rlx, tags)`` on success, or
    ``(False, True, [])`` if both attempts fail.
    """
    for attempt in range(2):
        try:
            ts = find_tags_list(b, names='baseline')
        except:
            if attempt == 0:
                print('ts problem')
            continue
        return True, cont_rlx, ts
    return False, True, []
def load_f(data_dir, c, cell, dir=''):
    """Open '<dir>../data/<data_dir[c]>/<cell>/<cell>.nix' read-only.

    Returns
    -------
    (f, cont_here, cont_rlx) : ``(file, True, False)`` on success, or
    ``([], False, True)`` if the file could not be opened.
    """
    try:
        nix_path = dir + '../data/' + data_dir[c] + '/' + cell + '/' + cell + '.nix'
        return nix.File.open(nix_path, nix.FileMode.ReadOnly), True, False
    except:
        return [], False, True
def perc_model_full():
    """Return the percentile (95) used for the full-model colour limits."""
    percentile = 95
    return percentile
def get_frame_for_base_plot(cells, save_names=None, based_on_ram_overview=True, species=' Apteronotus leptorhynchus'):
    """Load baseline statistics of ``cells`` for the baseline overview plots.

    When ``based_on_ram_overview`` is True, the result is restricted to the P-units and
    ampullary cells of ``species`` that pass ``setting_overview_score`` in the RAM overview.

    Returns
    -------
    (cell_type_type, frame_load, frame_spikes) : the cell-type column name, the filtered
    baseline frame and the spike frame from ``load_cv_vals_susept``.
    """
    cell_type_type = 'cell_type_reclassified'
    frame, frame_spikes = load_cv_vals_susept(cells, EOD_type='synch',
                                              names_keep=['gwn', 'fs', 'EODf', 'cv', 'fr', 'width_75', 'vs',
                                                          'cv_burst_corr_individual', 'fr_burst_corr_individual',
                                                          'width_75_burst_corr_individual', 'vs_burst_corr_individual',
                                                          'cell_type_reclassified', 'cell'])
    frame = unify_cell_names(frame, cell_type=cell_type_type)
    if not save_names:
        # Older variants: 'calc_RAM_overview-_simplified_noise_data9_nfft1sec_original__StimPreSaved4__mean5__...'
        save_names = ['calc_RAM_overview-_simplified_' + version_final()]
    overview_name = save_names[0]
    frame_load_sp = load_overview_susept(overview_name, redo=False, redo_class=False)
    # Reload when the remote csv has been updated.
    if update_ssh_file(load_folder_name('calc_RAM') + '/' + overview_name + '.csv') == 'yes':
        frame_load_sp = load_overview_susept(overview_name, redo=True, redo_class=False)
    if not based_on_ram_overview:
        return cell_type_type, frame, frame_spikes
    cells_keep = []
    for cell_type_here in [' P-unit', ' Ampullary', ]:
        frame_file = setting_overview_score(frame_load_sp, cell_type_here, min_amp='min', species=species)
        cells_keep.extend(frame_file.cell.unique())
    frame_load = frame[frame['cell'].isin(cells_keep)]
    return cell_type_type, frame_load, frame_spikes
def get_grids_for_cv_fr(grid_scatter):
    """Split ``grid_scatter`` into a scatter panel with marginal-histogram panels above and to the right.

    Returns
    -------
    (axs, axx, axy) : the main scatter axes (bottom left), the top histogram axes
    and the right histogram axes.
    """
    panels = gridspec.GridSpecFromSubplotSpec(2, 2, grid_scatter, wspace=0, hspace=0.05,
                                              height_ratios=[0.4, 2.5],
                                              width_ratios=[2.5, 0.25])
    ax_top = plt.subplot(panels[0, 0])
    ax_main = plt.subplot(panels[1, 0])
    try:
        ax_right = plt.subplot(panels[1, 1])
    except:
        print('hist thing')
        embed()
    return ax_main, ax_top, ax_right
def plt_fr_cv_base(ax0, ax_cv, ax_fr, add, frame_load,
                   gg, s=5, fr_lim=700, cell_types=[' P-unit', ' Ampullary'], alpha=0.3, xmax=[1.51, 1.51], cvmax=1.51,
                   cell_type_type='cell_type_reclassified', color_given=None, colors=None, annotate=False,
                   species=' Apteronotus leptorhynchus'):
    """Plot baseline firing rate vs. CV with marginal kernel histograms per cell type.

    Parameters
    ----------
    ax0 : scatter axes. ax_cv : top (CV) histogram axes. ax_fr : side (rate) histogram axes.
    add, gg : ``add[gg]`` is the column suffix (e.g. burst-corrected) and ``gg`` also picks the axis labels.
    frame_load : baseline DataFrame; rows with ``gwn`` or ``fs`` True are used.
    color_given : single colour for all cell types; otherwise ``colors[cell_type]`` is used
        per type (``colors`` defaults to ``colors_overview()``).
    species : unused.

    Returns
    -------
    (x_axis, y_axis) : the strings 'cv' and 'fr'.
    """
    if not colors:
        colors = colors_overview()
    for c_here, cell_type in enumerate(cell_types):
        # Resolve the colour per cell type. Previously color_given was overwritten on
        # the first iteration, so every type was drawn in the first type's colour.
        color_here = color_given if color_given else colors[str(cell_type)]
        frame_g = frame_load[
            (frame_load[cell_type_type] == cell_type) & ((frame_load.gwn == True) | (frame_load.fs == True))]
        kernel_histogram(ax_cv, color_here, np.array(frame_g['cv' + add[gg]]), xmin=0, xmax=xmax[c_here],
                         alpha=0.5, step=0)
        ax_cv.show_spines('')
        remove_yticks(ax_cv)
        ax_fr.get_shared_y_axes().join(*[ax_fr, ax0])
        ax_cv.get_shared_x_axes().join(*[ax_cv, ax0])
        kernel_histogram(ax_fr, color_here, np.array(frame_g['fr' + add[gg]]),
                         step=0, alpha=0.5, orientation='vertical')
        ax_fr.show_spines('')
        ax_cv.show_spines('')
        remove_xticks(ax_cv)
        remove_xticks(ax_fr)
        remove_yticks(ax_fr)
    y_axis = 'fr'
    x_axis = 'cv'
    ptl_fr_cv(add[gg], alpha, annotate, ax0, cell_type_type, cell_types, frame_load, s, cv=y_axis, fr=x_axis,
              color_given=color_given)
    add_namex = [cv_base_name(), cv_base_name_corr()]
    add_namey = [fbasenamehz(), fbasecorrectedname()]
    ax0.set_xlabel(add_namex[gg])
    ax0.set_ylabel(add_namey[gg])
    ax0.set_ylim(0, fr_lim)
    ax0.set_xlim(0, cvmax)
    return x_axis, y_axis
def fbasecorrectedname():
    """Return the axis label for the corrected baseline firing rate, with a Hz unit.

    The subscript is wrapped in ``\\rm`` when ``rem_variable()['rm']`` is truthy.
    """
    if rem_variable()['rm']:
        return r'$f\rm{_{BaseCorrected}}$ [Hz]'
    return r'$f_{BaseCorrected}$ [Hz]'
def fbasenamehz():
    """Return the baseline firing-rate label followed by a ``' [Hz]'`` unit.

    Both branches of the former ``rem_variable()['rm']`` check produced the same
    string, so the branch is gone.  ``fbasename`` already picks between the roman
    and the plain subscript.
    """
    return fbasename() + ' [Hz]'
def fbasename():
    """Return the math-mode symbol for the baseline rate: ``$f`` + ``basename()`` + ``$``."""
    return f'$f{basename()}$'
def fbasename_small():
    """Return ``$f`` + the lower-case baseline subscript from ``basename_small()`` + ``$``."""
    return f'$f{basename_small()}$'
def stimname():
    """Return the LaTeX subscript for 'Stim'; roman (``\\rm``) when ``rem_variable()['rm']`` is truthy."""
    return r'\rm{_{Stim}}' if rem_variable()['rm'] else r'_{Stim}'
def basename():
    """Return the LaTeX subscript for 'Base'; roman (``\\rm``) when ``rem_variable()['rm']`` is truthy."""
    return r'\rm{_{Base}}' if rem_variable()['rm'] else r'_{Base}'
def basename_small():
    """Return the lower-case LaTeX subscript 'base'; roman (``\\rm``) when ``rem_variable()['rm']`` is truthy."""
    return r'\rm{_{base}}' if rem_variable()['rm'] else r'_{base}'
def cv_base_name_corr():
    """Return the label for the corrected baseline CV; roman subscript when ``rem_variable()['rm']`` is truthy."""
    return r'CV$\rm{_{BaseCorrected}}$' if rem_variable()['rm'] else r'CV$_{BaseCorrected}$'
def annotate_left_arrow(ax, lw=1.5, ws=None):
    """Draw a black '->' arrow outside ``ax`` from (1, 0.5) to (1 + ws, 0.5) in axes fraction.

    Parameters
    ----------
    ax
        Axes the annotation is attached to; clipping is disabled so it can extend
        beyond the axes.
    lw : float
        Arrow line width.
    ws : float or None
        Horizontal arrow length in axes fraction; falsy values fall back to
        ``ws_nonlin_systems()``.
    """
    if not ws:
        ws = ws_nonlin_systems()
    # ``ls=8`` was removed: the kwargs are applied to the annotation's Text, which
    # has no ``ls``/``linestyle`` property, so Matplotlib rejected the call.
    ax.annotate('', ha='center', xycoords='axes fraction',
                xy=(1 + ws, 0.5), textcoords='axes fraction',
                xytext=(1, 0.5),
                arrowprops={"arrowstyle": "->",
                            "linestyle": "-",
                            "linewidth": lw,
                            "color": 'black'},
                zorder=1, annotation_clip=False, transform=ax.transAxes)
def plt_single_matrix(ax, stack_final, ls=8, y_label=True, fr_name = 'fr'):
    """Draw one normalised susceptibility matrix into ``ax`` with a colour bar.

    The matrix is built with ``RAM_norm_data`` from the first ``isf`` value and the
    first unique ``snippets`` value of ``stack_final``.  It is then rescaled, plotted,
    clipped to its 95th percentile, and overlaid with ``plt_triangle`` at the firing
    rate ``stack_final[fr_name].iloc[0]``.

    Returns
    -------
    tuple
        ``(cbar, fr, mat, im)``: colour bar, firing rate used for the triangle,
        rescaled matrix and image handle.
    """
    new_keys, stack_plot = convert_csv_str_to_float(stack_final)
    first_isf = stack_final['isf'].iloc[0]
    first_snippet = stack_final['snippets'].unique()[0]
    mat = RAM_norm_data(first_isf, stack_plot, first_snippet, stack_here=stack_final)
    mat, add_nonlin_title, resize_val = rescale_colorbar_and_values(mat)
    im = plt_RAM_perc(ax, 'no', mat)
    if y_label:
        set_ylabel_arrow(ax, xpos=-0.17, ypos=0.97)
    # NOTE(review): source indentation was lost; the x label is assumed to be drawn
    # regardless of ``y_label``.
    set_xlabel_arrow(ax, xpos=1, ypos=-0.23)
    set_clim_same([im], mats=[mat], lim_type='up', nr_clim='perc', clims='', percnr=95)
    fr = stack_final[fr_name].iloc[0]
    plt_triangle(ax, fr, fr, new_keys[-1], eod_metrice=False, nr=1)
    # TODO: change clim values with different Hz values
    cbar, _left, _bottom, _width, _height = colorbar_outside(ax, im, add=5, width=0.01, ls=ls)
    cbar.set_label(nonlin_title(add_nonlin_title=' [' + add_nonlin_title), rotation=90, labelpad=8)
    return cbar, fr, mat, im
def load_stack_data_susept(cell, save_name, end=''):
    """Load the susceptibility stack of ``cell`` and return its matrix.

    Excluded files (``get_file_names_exclude``, ``exclude_file_name_short``) are
    dropped.  Among the remaining files, the one with the largest cut-off from
    ``get_cutoffs_nr`` is kept, and ``get_stack_final`` picks amplitude, length
    and trial count.

    Returns
    -------
    tuple
        ``(mat, stack_final)`` from ``get_mat_susept`` and ``get_stack_final``.

    Raises
    ------
    ValueError
        If no file remains or a cut-off cannot be parsed as float.  The original
        opened an IPython shell here and then failed with a NameError.
    """
    load_name = load_folder_name('calc_RAM') + '/' + save_name + end
    add = '_cell' + cell + end
    stack_cell = load_data_susept(load_name + '_' + cell + '.pkl', load_name + '_' + cell, add=add,
                                  load_version='csv')
    file_names_exclude = get_file_names_exclude()
    stack_cell = stack_cell[~stack_cell['file_name'].isin(file_names_exclude)]
    file_names = exclude_file_name_short(stack_cell.file_name.unique())
    cut_off_nr = get_cutoffs_nr(file_names)
    try:
        cut_offs = [float(nr) for nr in cut_off_nr]
    except (TypeError, ValueError) as err:
        raise ValueError('could not parse cut-off frequencies %r for cell %s' % (cut_off_nr, cell)) from err
    if not cut_offs:
        raise ValueError('no usable stimulus file left for cell %s' % cell)
    file_name = file_names[int(np.argmax(cut_offs))]
    stack_file = stack_cell[stack_cell['file_name'] == file_name]
    stack_final = get_stack_final(cell, stack_file)
    mat, _new_keys = get_mat_susept(stack_final)
    return mat, stack_final
def get_stack_final(cell, stack_file):
    """Restrict ``stack_file`` to a single amplitude, the longest stimulus and the most trials.

    The amplitude is the minimum of ``restrict_punits(cell, [smallest amplitude])``.
    """
    smallest_amp = np.min(stack_file.amp.unique())
    chosen_amp = np.min(restrict_punits(cell, [smallest_amp]))
    at_amp = stack_file[stack_file['amp'] == chosen_amp]
    longest = np.max(at_amp.stimulus_length.unique())
    most_trials = np.max(at_amp.trial_nr.unique())
    keep = (at_amp['stimulus_length'] == longest) & (at_amp.trial_nr == most_trials)
    return at_amp[keep]
def data_overview_punit(cell_types=(' P-unit',)):
    """Plot the nonlinearity score against CV for each cell type (talk figure).

    For each entry of ``cell_types`` the overview frame loaded by
    ``load_overview_susept`` is restricted with ``setting_overview_score`` to
    Apteronotus leptorhynchus.  Rows with ``cv_stim >= 5`` are dropped, and the
    score is scattered against ``cv_stim`` in column ``c`` of
    ``overview_mod_grid(cell_types)``.  The figure is saved with
    ``save_visualization(pdf=True)``.

    Parameters
    ----------
    cell_types : sequence of str
        Cell-type labels, one grid column each.  At most two entries are supported
        because ``scores``, ``max_xs`` and ``burst_fraction`` hold two values.
    """
    plot_style()
    default_figsize(width=cm_to_inch(28), length=cm_to_inch(12))
    default_ticks_talks()
    var_it = 'Response Modulation [Hz]'
    var_it2 = ''
    grid0 = overview_mod_grid(cell_types)
    # Choice: the mean is used to average out stimulus-independent noise.
    save_names = ['calc_RAM_overview-_simplified_noise_data12_nfft0.5sec_original__StimPreSaved4__direct_']
    species = ' Apteronotus leptorhynchus'
    burst_fraction = [1, 1]
    burst_corr_reset = 'burst_fraction_burst_corr_individual_stim'
    redo = False
    frame_load_sp = load_overview_susept(save_names[0], redo=redo, redo_class=redo)
    scores = ['max(diag5Hz)/med_diagonal_proj_fr', 'max(diag5Hz)/med_diagonal_proj_fr']
    max_xs = [[[], [], []], [[], [], []]]

    # Per-panel settings; only the first entry is drawn because var_types has one element.
    x_axis = ['cv_stim', 'cv_base', 'response_modulation']
    var_item_names = [var_it, var_it, var_it2]
    var_types = ['']
    x_axis_names = [x_axis_talk(), 'CV$_{stim}$', 'Response Modulation [Hz]']
    score_name = [nonlinearity_name_talk(), NLI_scorename2(), NLI_scorename2()]
    log = ''  # '', 'logy' or 'logall'
    labelpad = 0.5
    fs, ms = size_talk_overview()
    punit_color = colors_overview()[' P-unit_talk']

    for c, cell_type_here in enumerate(cell_types):
        frame_file = setting_overview_score(frame_load_sp, cell_type_here, min_amp='range', species=species)
        # Excluded because of a spike-detection problem; affects three cells:
        # 2018-08-14-af-invivo-1, 2018-09-05-aj-invivo-1, 2022-01-08-ah-invivo-1
        frame_file = frame_file[frame_file.cv_stim < 5]
        max_x = max_xs[c]
        score = scores[c]
        scores_here = [score, score, score]
        for v, var_type in enumerate(var_types):
            # The original tracked the column with a counter advanced once per cell type,
            # i.e. it equalled ``c``.
            axs = plt.subplot(grid0[v, c])
            cmap, _, y_axis = scatter_with_marginals_colorcoded(var_item_names[v], axs, cell_type_here, x_axis[v],
                                                                frame_file, scores_here[v],
                                                                burst_fraction_reset=burst_corr_reset,
                                                                var_item=var_type, labelpad=labelpad, max_x=max_x[v],
                                                                x_pos=1, fs=fs, ms=ms, burst_fraction=burst_fraction[c],
                                                                sides=False, ha='right',
                                                                color_given=punit_color,
                                                                legend_spacing=0.15)
            print(cell_type_here + ' median ' + scores_here[v] + '' + str(np.nanmedian(frame_file[scores_here[v]])))
            print(cell_type_here + ' max ' + x_axis[v] + '' + str(np.nanmax(frame_file[x_axis[v]])))
            axs.set_ylabel(score_name[v])
            axs.set_xlabel(x_axis_names[v], labelpad=labelpad)
            axs.set_ylim(0, 3.8)
            axs.set_xlim(0, 1.8)
            if log in ('logy', 'logall'):
                axs.set_yscale('log')
                if log == 'logall':
                    axs.set_xscale('log')
                make_log_ticks([axs])
    save_visualization(pdf=True)
def overview_mod_grid(cell_types, right = 0.98, ws = 0.75):
    """Return a one-row GridSpec with one column per entry of ``cell_types``."""
    layout = dict(wspace=ws, bottom=0.2, hspace=0.45, left=0.12, right=right, top=0.95)
    n_columns = len(cell_types)
    return gridspec.GridSpec(1, n_columns, **layout)
def nonlinearity_name_talk():
    """Return the short y-axis label used for the nonlinearity score in talks."""
    label = 'Nonlinearity'
    return label
def size_talk_overview():
    """Return ``(fs, ms)`` = (20, 13), the font and marker sizes passed to the talk overview scatters."""
    return 20, 13
def plt_specific_cells(axs, cell_type_here, cv_name, frame_file, score, marker=None):
    """Circle the example cells on an existing scatter.

    For P-units the first two cells of ``p_units_to_show(type_here='contrasts')``
    are highlighted, otherwise the first of ``p_units_to_show(type_here='amp')``.
    The points are ``frame_file[cv_name]`` against ``frame_file[score]`` for the
    matching rows, drawn as open black markers.

    Parameters
    ----------
    marker : sequence of two marker styles or None
        If empty/None a single scatter with the default marker is drawn.
        Otherwise rows 0-1 of the selection use ``marker[1]`` and rows 2-3
        use ``marker[0]``.
    """
    # Add the contrast example cells.
    if cell_type_here == ' P-unit':
        cells_plot2 = p_units_to_show(type_here='contrasts')[0:2]
    else:
        cells_plot2 = [p_units_to_show(type_here='amp')[0]]
    cells_extra = frame_file[frame_file['cell'].isin(cells_plot2)].index
    x_vals = frame_file[cv_name].loc[cells_extra]
    y_vals = frame_file[score].loc[cells_extra]
    if not marker:
        axs.scatter(x_vals, y_vals,
                    s=9, facecolor="None", edgecolor='black', alpha=0.7, clip_on=False)
    else:
        # Explicit positional slicing: ``Series[0:2]`` is ambiguous between label and
        # position depending on the index dtype.
        axs.scatter(x_vals.iloc[0:2], y_vals.iloc[0:2],
                    s=9, facecolor="None", marker=marker[1], edgecolor='black', alpha=0.7,
                    clip_on=False)
        axs.scatter(x_vals.iloc[2:4], y_vals.iloc[2:4],
                    s=9, facecolor="None", marker=marker[0], edgecolor='black', alpha=0.7,
                    clip_on=False)
def get_grid_4(ax_j, axls, axss, grid0):
    """Split ``grid0`` into a 2x2 sub-grid and create three axes in it.

    The axes are created at (0, 0), (1, 0) and (1, 1); the top-right cell stays
    empty.  Each new axes is appended to the matching collector list
    (``ax_j``, ``axss``, ``axls``).

    Returns
    -------
    tuple
        ``(axk, axl, axs, axls, axss, ax_j)``.
    """
    inner = gridspec.GridSpecFromSubplotSpec(2, 2, grid0, hspace=0.1, wspace=0.1,
                                             height_ratios=[0.35, 3], width_ratios=[3, 0.5])
    axk = plt.subplot(inner[0, 0])
    axs = plt.subplot(inner[1, 0])
    axl = plt.subplot(inner[1, 1])
    ax_j.append(axk)
    axss.append(axs)
    axls.append(axl)
    return axk, axl, axs, axls, axss, ax_j
def conv_integers(threshold, power_kernel):
    """Add ``power_kernel`` to every entry of ``threshold`` and concatenate the results.

    Parameters
    ----------
    threshold : iterable
        Offsets; each is converted with ``np.array`` and added to the kernel.
    power_kernel : array_like
        The kernel values that are shifted.

    Returns
    -------
    numpy.ndarray
        ``concatenate([kernel + t for t in threshold])``.  An empty float array is
        returned for an empty ``threshold``; ``np.concatenate`` would raise there.
    """
    kernel = np.array(power_kernel)
    # Iterate instead of indexing so pandas Series with non-positional labels work too.
    shifted = [kernel + np.array(offset) for offset in threshold]
    if not shifted:
        return np.empty(0)
    return np.concatenate(shifted)
def NLI_burstcorr_name2():
    """Return the label of the nonlinearity score normalised by the burst-corrected rate."""
    label = 'PNL($f_{BurstCorr}$)'
    return label
def grid_evolutionary():
    """Return the 1x4 GridSpec used by the cell/species comparison figures.

    Width ratios are [2.2, 0, 1.8, 2.2], so column 1 has zero width.
    """
    ratios = [2.2, 0, 1.8, 2.2]
    return gridspec.GridSpec(1, 4, width_ratios=ratios, wspace=0.6, hspace=0.5,
                             left=0.05, right=0.89, bottom=0.24, top=0.83)
def didactic_sine_spectrum(axps, axts, color, sampling, sines, time_array, titles, freqs=None, colors=None,
                           colors_peaks=None, labels=None, log=True):
    """Plot each signal in ``sines`` as a time trace and as a power spectrum.

    For signal ``ss`` the trace (time in ms, first 100 ms) goes into ``axts[ss]`` and
    the PSD from ``matplotlib.mlab.psd`` (NFFT 2**17, 0-100 Hz) into ``axps[ss]``.
    Only the first column keeps y labels and ticks.

    Parameters
    ----------
    color
        Trace colour, replaced by ``colors[ss]`` when ``colors`` is given.
    freqs : list or None
        Per-signal peak frequencies to mark with ``plt_peaks_several``.
    colors_peaks, labels : list or None
        Per-signal peak colours/labels; default ``[['red', 'purple']]`` and
        ``[[r'$f_{1}$', r'$f_{2}$']]``.
    log : bool
        Convert the PSD with ``calc_log``, label it 'dB' and fix the y range to
        (-25, 0).  This was hard-coded to True before.
    """
    if colors_peaks is None:
        colors_peaks = [['red', 'purple']]
    if labels is None:
        labels = [[r'$f_{1}$', r'$f_{2}$']]
    for ss, sine in enumerate(sines):
        if colors:
            color = colors[ss]
        ax_t = axts[ss]
        ax_p = axps[ss]
        ax_t.plot(time_array * 1000, sine, color=color)
        ax_t.set_ylim(np.min(sine) * 1.02, np.max(sine) * 1.02)
        ax_t.set_xlim(0, 0.1 * 1000)
        ax_t.show_spines('lb')
        ax_t.set_xlabel('Time [ms]')
        ax_t.set_title(titles[ss])
        p_array, f = ml.psd(sine - np.mean(sine), Fs=sampling, NFFT=2 ** 17,
                            noverlap=2 ** 15 // 2)
        if log:
            p_array = calc_log(p_array)
        ax_p.plot(f, p_array, color='black')
        ax_p.set_xlim(0, 100)
        if ss == 0:
            ax_t.set_ylabel('Signal')
            # The original labelled every spectrum panel 'dB', including the later
            # panels whose y ticks are removed; the label now stays on the first column.
            ax_p.set_ylabel('dB' if log else 'PSD [1/Hz]')
        else:
            remove_yticks(ax_p)
            remove_yticks(ax_t)
        ax_p.set_xlabel('Frequency [Hz]')
        if freqs:
            plt_peaks_several(freqs[ss], [p_array], ax_p, p_array, f, labels=labels[ss],
                              colors=colors_peaks[ss], perc_peaksize=2)
        if log:
            ax_p.set_ylim(-25, 0)
def retrieve_mat(diff_load, name):
    """Return the rows of ``diff_load`` whose ``dist`` equals ``name``, without ``dist``.

    Columns containing any NaN are dropped first, and the remaining columns are
    sorted by name.  ``diff_load`` itself is not modified.
    """
    droped = diff_load.dropna(axis=1)
    # ``drop`` returns a new frame; ``pop`` on the boolean-filtered slice relied on
    # chained assignment and can trigger SettingWithCopyWarning.
    cleaned = droped.loc[droped['dist'] == name].drop(columns='dist')
    return cleaned.reindex(sorted(cleaned.columns), axis=1)
def load_baseline_matrix(what, cell, pivot1, a_fr=1):
    """Build a matrix shaped like ``pivot1`` filled with the baseline value of ``cell``.

    The baseline table is read from a fixed model pickle in the 'calc_model' folder.
    For ``what == 'spike_times'`` every cell of the matrix holds the first row's
    spike times.  Otherwise every entry is the nan-mean of column ``what``
    (``a_fr == 1``) or zero.

    Returns
    -------
    tuple
        ``(base, base_matrix, baseline)``: the baseline value, the filled matrix and
        the full baseline table.
    """
    pickle_path = load_folder_name(
        'calc_model') + '/modell_all_cell_no_sinz1_afe0__afr1__afj0__length1.5_adaptoffsetallall2___stepefish10Hz_ratecorrrisidual35__modelbigfit_nfft4096_base.pkl'
    baseline = pd.read_pickle(pickle_path)
    baseline_cell = baseline[baseline['dataset'] == cell]
    base_matrix = pivot1 * 1  # arithmetic copy so that pivot1 is never written to
    if what == 'spike_times':
        base = baseline_cell.iloc[0]['spike_times']
        # NOTE(review): storing a sequence per cell presumably needs an object-dtype
        # pivot1 -- confirm with the callers.
        n_rows, n_cols = base_matrix.shape
        for row in range(n_rows):
            for col in range(n_cols):
                base_matrix.iloc[row, col] = base
    else:
        base = np.nanmean(baseline_cell[what])
        base_matrix[:] = base * np.ones_like(pivot1) if a_fr == 1 else np.zeros_like(pivot1)
    return base, base_matrix, baseline
def get_control(nr, cell_nr, what, afe, contrast1='0.1', a_fr=1, contrast2='0', minimum=0.5, maximum=1.5, cell=[],
                version_sinz='sinz',
                adapt='adaptoffsetallall2', symetric='', beat_type='', step=10, variant='no', self='',
                varied='emitter'):
    """Load a model control table for one cell and relabel its axes as frequency ratios.

    The pickle name is assembled from the simulation parameters and read from the
    'calc_model' folder.  If ``cell`` is empty, the ``cell_nr``-th unique dataset is
    used.  Column labels are mapped from ``eodj`` to ``eodj/eodf`` (rounded to 0.01),
    and the index is set to the rounded ``eode/eodf`` ratios.  If ``maximum != []``
    the matrix is cut with ``cut_matrix_generation(control_afj, minimum, maximum)``.

    Returns
    -------
    tuple
        ``(control_afj, DF_e, dict_here, control_array['eodf'])``.

    NOTE(review): ``duration`` (used in the file name) and ``rename`` (used to build
    ``control_afj``) are not defined in this function.  Unless they come from the
    star import at the top of the module, this raises NameError.  ``rename`` looks
    like a leftover of ``<pivot of control_array by what>.rename`` -- TODO confirm.
    ``what`` and ``afe`` are not used in the body.
    """
    # Pickle file name encoding the model parameters.
    name = 'modell_all_cell_' + variant + '_' + version_sinz + str(
        nr) + self + '_afe' + str(contrast1) + '__afr' + str(a_fr) + '__afj' + str(
        contrast2) + '__length1.5_' + adapt + '___stepefish' + str(
        step) + 'Hz_ratecorrrisidual35__modelbigfit_nfft4096' + duration + beat_type + symetric + '.pkl'
    control = pd.read_pickle(
        load_folder_name('calc_model') + '/' + name)
    if not cell:
        cell = np.unique(control['dataset'])[cell_nr]
    control_array = control[control['dataset'] == cell]
    # Relative frequency of fish "j" (eodj/eodf), times 100 for rounding below.
    DF_j = np.unique(np.array(
        (control_array['eodj'] - control_array['eodf']) / control_array[
            'eodf'] + 1)) * 100
    # Pairs the sorted unique eodj values with the sorted unique ratios; this is only a
    # one-to-one mapping if eodf is constant for the cell -- TODO confirm.
    dict_here = dict(
        zip(np.unique(control_array['eodj']), np.round(DF_j) / 100))
    control_afj = rename(columns=dict_here)
    # Relative frequency of fish "e" (eode/eodf), rounded to two decimals.
    DF_e = np.round(np.unique(
        np.array((control_array['eode'] - control_array['eodf']) / control_array[
            'eodf'] + 1)) * 100) / 100
    control_afj = control_afj.set_index(DF_e)
    if varied == 'emitter':
        control_afj.columns.name = 'fish1-fish0 $f_{stim}/f_{EOD}$'  # 'DeltaF-eodj-eodf'
        control_afj.index.name = 'fish2-fish0 $f_{stim}/f_{EOD}$'  # 'DeltaF-eode-eodf'
    else:
        control_afj.columns.name = 'fish2-fish0 $f_{stim}/f_{EOD}$'  # 'DeltaF-eodj-eodf'
        control_afj.index.name = 'fish1-fish0 $f_{stim}/f_{EOD}$'  # 'DeltaF-eode-eodf'
    if maximum != []:
        control_afj, column_chosen, index_chosen = cut_matrix_generation(control_afj, minimum, maximum)
    return control_afj, DF_e, dict_here, control_array['eodf']
def create_spikes_mat(length, spikes_cut, sampling_rate, results=[], trial_nr=1, test_saturation=False):
    """Convert per-trial spike times into spike trains and smooth them.

    Parameters
    ----------
    length : float
        Trial duration; each train has ``int(length * sampling_rate)`` samples.
    spikes_cut : sequence
        One spike-time collection per trial, converted with ``cr_spikes_mat``.
    sampling_rate : float
        Samples per unit of ``length`` (presumably Hz -- TODO confirm).
    results : list
        Only forwarded to the saturation diagnostic plot.
    trial_nr : int
        Unused; kept for backward compatibility.
    test_saturation : bool
        If True, draw ``utils_test.plt_saturation_effect2``.

    Returns
    -------
    tuple
        ``(spikes_cut, spikes_mat, smoothened_spikes_mat05, smoothened_spikes_mat2)``.
        ``spikes_mat`` is the list of per-trial trains.  The last two are trial
        means after Gaussian smoothing with sigma = 0.0005 * sampling_rate and
        0.002 * sampling_rate samples.
    """
    n_samples = int(length * sampling_rate)
    spikes_mat = [cr_spikes_mat(spikes, sampling_rate, n_samples) for spikes in spikes_cut]
    trains = np.asarray(spikes_mat)

    def _smooth(width):
        # Smooth along time only. The original passed a scalar sigma, which also
        # blurred across trials (axis 0) and mixed neighbouring trials in the
        # per-trial arrays handed to the saturation plot.
        sigma = width * sampling_rate
        if trains.ndim == 2:
            sigma = (0, sigma)
        return gaussian_filter(trains, sigma=sigma)

    smoothed05 = _smooth(0.0005)
    smoothed2 = _smooth(0.002)
    smoothened_spikes_mat05 = np.mean(smoothed05, axis=0)
    smoothened_spikes_mat2 = np.mean(smoothed2, axis=0)
    if test_saturation:
        from utils_test import plt_saturation_effect2
        plt_saturation_effect2(sampling_rate, smoothened_spikes_mat2, smoothed2, smoothed05, results,
                               smoothened_spikes_mat05, spikes_mat, show=False)
    return spikes_cut, spikes_mat, smoothened_spikes_mat05, smoothened_spikes_mat2
def get_cut_off_for_wn(cut_off_nr, file_name):
    """Append the cut-off token parsed from ``file_name`` to ``cut_off_nr`` and return that list.

    Parsing uses the lower-cased text before the first 'hz'.  If that part
    contains 'wn', the piece following the first 'wn' is taken (up to a second
    'wn', if any).  Otherwise its second '_'-separated field is taken.  The token
    stays a string.
    """
    head = file_name.lower().split('hz')[0]
    token = head.split('wn')[1] if 'wn' in head else head.split('_')[1]
    cut_off_nr.append(token)
    return cut_off_nr
def color_beats():
    """Return the colour used for beat traces."""
    beat_color = 'red'
    return beat_color
def power_didactic_subplots():
    """Apply the fixed subplot spacing used by the didactic power-spectrum figures."""
    spacing = dict(left=0.1, right=0.97, bottom=0.2, top=0.85, wspace=0.25, hspace=0.85)
    plt.subplots_adjust(**spacing)
def reset_yaxis_cords(axes, ypos=-0.1):
    """Move the y-axis label of every axes in ``axes`` to ``(ypos, 0.5)`` via ``set_label_coords``."""
    for axis_obj in axes:
        y_axis_obj = axis_obj.yaxis
        y_axis_obj.set_label_coords(ypos, 0.5)
def x_axis_talk():
    """Return the x-axis label for the CV during noise stimulation used in talks."""
    label = 'CV (Noise)'
    return label
def cellscompar2(cells_plot2, amp_desired=[5, 20]):  # [0, 1.1]
    """Compare P-units of two species: one susceptibility matrix per cell plus a score scatter.

    Upper part: for each cell in ``cells_plot2`` the susceptibility matrix of the
    smallest recorded amplitude is drawn into ``grid_evolutionary()[c * 3]``, with
    the species name as title.  Lower part: for each species (as many as there are
    cells) the nonlinearity score is scattered against ``cv_stim`` in the shared
    axes ``gridr[2]``.  The figure is saved with ``save_visualization``.

    Parameters
    ----------
    cells_plot2 : list of str
        Cell ids; index 0 is labelled Apteronotus leptorhynchus, index 1 Eigenmannia
        virescens.
    amp_desired : list
        Not used for the selection (the smallest recorded amplitude is always
        shown); kept for backward compatibility.
    """
    plot_style()
    save_names = [version_final()]
    cell_type_type = 'cell_type_reclassified'
    frame, frame_spikes = load_cv_vals_susept(cells_plot2, EOD_type='synch',
                                              names_keep=['spikes', 'gwn', 'fs', 'EODf', 'cv', 'fr', 'width_75', 'vs',
                                                          'cv_burst_corr_individual',
                                                          'fr_burst_corr_individual',
                                                          'width_75_burst_corr_individual',
                                                          'vs_burst_corr_individual', 'cell_type_reclassified',
                                                          'cell'], path_sp='/calc_base_data-base_frame_overview.pkl',
                                              frame_general=False)
    default_settings_cells_susept(cells_plot2)
    cell_types = [' P-unit']
    plot_style()
    size_evolutionary()
    default_ticks_talks()
    default_lw_RAM_talks()
    names = [' P-unit_talk', ' eigen_P-unit_talk']
    gridr = grid_evolutionary()
    species_names = [' Apteronotus leptorhynchus', ' Eigenmannia virescens']
    colors = colors_overview()

    # Upper part: one susceptibility matrix per cell.
    for c, cell in enumerate(cells_plot2):
        print(cell)
        cell_type, eod_fr, fr, frs_calc, isi, spikes, spikes_all = get_base_params(cell, cell_type_type, frame)
        add_here = '_cell' + cell
        ims = []
        mats = []
        for save_name in save_names:
            load_name = load_folder_name('calc_RAM') + '/' + save_name + '_' + cell
            stack = load_data_susept(load_name + '.pkl', load_name, add=add_here, load_version='csv',
                                     cells=cells_plot2)
            if len(stack) > 0:
                files, stack = exclude_cut_filenames(cell_type, stack, fexclude=True)
                stack_file = stack[stack['file_name'] == files[0]]
                # Only the smallest recorded amplitude is shown. (A former debug hook
                # compared the cell id string with a list and could never fire.)
                amp = np.min(stack_file['amp'].unique())
                mat, stack_final = load_stack_data_susept(cell, save_name=version_final(), end='')
                if amp in np.array(stack_file['amp']):
                    ax = plt.subplot(gridr[c * 3])
                    cbar, fr, mat, im = plt_single_matrix(ax, stack_final, ls=None)
                    ax.set_title(species_names[c], color=colors[names[c]], pad=30)
        # NOTE(review): ims/mats are never filled, so this receives empty lists.
        set_clim_same(ims, mats=mats, lim_type='up', percnr=95)

    # Lower part: nonlinearity score vs. CV. The mean is used to average out
    # stimulus-independent noise.
    save_names2 = ['calc_RAM_overview-_simplified_noise_data12_nfft0.5sec_original__StimPreSaved4__direct_']
    burst_fraction = [1, 1]
    burst_corr_reset = 'burst_fraction_burst_corr_individual_stim'
    redo = False
    frame_load_sp = load_overview_susept(save_names2[0], redo=redo, redo_class=redo)
    scores = ['max(diag5Hz)/med_diagonal_proj_fr_base_w_burstcorr', 'max(diag5Hz)/med_diagonal_proj_fr']
    x_axiss = ['cv_stim', 'cv_stim']
    max_xs = [[[], [], []], [[], [], []]]
    var_it = 'Response Modulation [Hz]'
    var_item_names = [var_it, var_it]
    var_types = ['']
    x_axis_names = [x_axis_talk()]
    score_name = ['Nonlinearity($f_{BaseCorr}$)']
    log = ''  # '', 'logy' or 'logall'
    labelpad = 0.5
    fs, ms = size_talk_overview()
    for c, species in enumerate(species_names[0:len(cells_plot2)]):
        for cell_type_here in cell_types:
            frame_file = setting_overview_score(frame_load_sp, cell_type_here, min_amp='range', species=species)
            frame_file = frame_file[frame_file.cv_stim < 5]
            x_axis = [x_axiss[c]]
            max_x = max_xs[c]
            scores_here = [scores[c]]
            for v, var_type in enumerate(var_types):
                if c == 0:
                    # Created once; the second species is drawn into the same axes.
                    axs = plt.subplot(gridr[2])
                ymin = 'no' if log == 'logy' else 0
                cmap, _, y_axis = scatter_with_marginals_colorcoded(var_item_names[v], axs, cell_type_here, x_axis[v],
                                                                    frame_file, scores_here[v], ymin=ymin,
                                                                    xmin=0, burst_fraction_reset=burst_corr_reset,
                                                                    var_item=var_type, labelpad=labelpad,
                                                                    max_x=max_x[v], xlim=None, x_pos=1, fs=fs, ms=ms,
                                                                    c=c, burst_fraction=burst_fraction[c], sides=False,
                                                                    color_text=colors[names[c]], ha='right', y_val=1.15,
                                                                    color_given=colors[names[c]],
                                                                    legend_spacing=0.1)
                print(cell_type_here + ' median ' + scores_here[v] + '' + str(
                    np.nanmedian(frame_file[scores_here[v]])))
                print(cell_type_here + ' max ' + x_axis[v] + '' + str(np.nanmax(frame_file[x_axis[v]])))
                axs.set_ylabel(score_name[v])
                axs.set_xlabel(x_axis_names[v], labelpad=labelpad)
                axs.set_ylim(0, 7)
                axs.set_xlim(0, 1.7)
                if log in ('logy', 'logall'):
                    axs.set_yscale('log')
                    if log == 'logall':
                        axs.set_xscale('log')
                    make_log_ticks([axs])
    save_visualization(pdf=True, individual_tag=cells_plot2[0])
def size_evolutionary():
    """Set the default figure size of the comparison figures to 33.4 cm x 11.8 cm."""
    width_cm, length_cm = 33.4, 11.8
    default_figsize(width=cm_to_inch(width_cm), length=cm_to_inch(length_cm))
def cellscompar(amp_desired=[5, 20], xlim=[],
                cells_plot2=[], RAM=True, scale_val=False):  # [0, 1.1]
    """Compare a P-unit and an ampullary cell: susceptibility matrices plus a score scatter.

    Upper part: for each cell in ``cells_plot2`` the susceptibility matrix of the
    smallest recorded amplitude is drawn into ``grid_evolutionary()[c * 3]``, titled
    with the cell-type name.  Lower part: for the first ``len(cells_plot2)`` cell
    types of Apteronotus leptorhynchus, the nonlinearity score is scattered against
    ``cv_stim`` in the shared axes ``gridr[2]``.  The figure is saved with
    ``save_visualization``.

    Parameters
    ----------
    cells_plot2 : list of str
        Cell ids; index 0 is treated as P-unit, index 1 as ampullary cell.  Must not
        be empty (its first entry tags the saved figure).
    amp_desired, xlim, RAM, scale_val
        Not used; kept for backward compatibility.
    """
    plot_style()
    save_names = [version_final()]
    cell_type_type = 'cell_type_reclassified'
    frame, frame_spikes = load_cv_vals_susept(cells_plot2, EOD_type='synch',
                                              names_keep=['spikes', 'gwn', 'fs', 'EODf', 'cv', 'fr', 'width_75', 'vs',
                                                          'cv_burst_corr_individual',
                                                          'fr_burst_corr_individual',
                                                          'width_75_burst_corr_individual',
                                                          'vs_burst_corr_individual', 'cell_type_reclassified',
                                                          'cell'], path_sp='/calc_base_data-base_frame_overview.pkl',
                                              frame_general=False)
    default_settings_cells_susept(cells_plot2)
    cell_types = [' P-unit', ' Ampullary', ]
    cell_types_name = [' P-units', 'Ampullary cells', ]
    style()
    # The three default_figsize calls that used to precede this were overwritten by it.
    size_evolutionary()
    default_ticks_talks()
    default_lw_RAM_talks()
    gridr = grid_evolutionary()
    name = [' P-unit_talk', ' Ampullary_talk']

    # Upper part: one susceptibility matrix per cell.
    for c, cell in enumerate(cells_plot2):
        print(cell)
        cell_type, eod_fr, fr, frs_calc, isi, spikes, spikes_all = get_base_params(cell, cell_type_type, frame)
        add_here = '_cell' + cell
        for save_name in save_names:
            load_name = load_folder_name('calc_RAM') + '/' + save_name + '_' + cell
            stack = load_data_susept(load_name + '.pkl', load_name, add=add_here, load_version='csv',
                                     cells=cells_plot2)
            if len(stack) > 0:
                files, stack = exclude_cut_filenames(cell_type, stack, fexclude=True)
                stack_file = stack[stack['file_name'] == files[0]]
                # Only the smallest recorded amplitude is shown. (A former debug hook
                # compared the cell id string with a list and could never fire.)
                amp = np.min(stack_file['amp'].unique())
                mat, stack_final = load_stack_data_susept(cell, save_name=version_final(), end='')
                if amp in np.array(stack_file['amp']):
                    ax = plt.subplot(gridr[c * 3])
                    ax.set_title(cell_types_name[c], color=colors_overview()[name[c]])
                    cbar, fr, mat, im = plt_single_matrix(ax, stack_final, y_label=True, ls=None)

    # Lower part: nonlinearity score vs. CV. The mean is used to average out
    # stimulus-independent noise.
    save_names2 = ['calc_RAM_overview-_simplified_noise_data12_nfft0.5sec_original__StimPreSaved4__direct_']
    species = ' Apteronotus leptorhynchus'
    burst_fraction = [1, 1]
    burst_corr_reset = 'burst_fraction_burst_corr_individual_stim'
    redo = False
    frame_load_sp = load_overview_susept(save_names2[0], redo=redo, redo_class=redo)
    scores = ['max(diag5Hz)/med_diagonal_proj_fr', 'max(diag5Hz)/med_diagonal_proj_fr']
    max_xs = [[[], [], []], [[], [], []]]
    var_it = 'Response Modulation [Hz]'
    x_axis = ['cv_stim', 'cv_base', 'response_modulation']
    var_item_names = [var_it, var_it]
    var_types = ['']
    x_axis_names = [x_axis_talk(), 'CV$_{stim}$', 'Response Modulation [Hz]']
    score_name = [nonlinearity_name_talk(), NLI_scorename2(), NLI_scorename2()]
    log = ''  # '', 'logy' or 'logall'
    labelpad = 0.5
    fs, ms = size_talk_overview()
    for c, cell_type_here in enumerate(cell_types[0:len(cells_plot2)]):
        frame_file = setting_overview_score(frame_load_sp, cell_type_here, min_amp='range', species=species)
        # Excluded because of a spike-detection problem; affects three cells:
        # 2018-08-14-af-invivo-1, 2018-09-05-aj-invivo-1, 2022-01-08-ah-invivo-1
        frame_file = frame_file[frame_file.cv_stim < 5]
        max_x = max_xs[c]
        score = scores[c]
        scores_here = [score, score, score]
        color_here = colors_overview()[name[c]]
        for v, var_type in enumerate(var_types):
            if c == 0:
                # Created once; the second cell type is drawn into the same axes.
                axs = plt.subplot(gridr[2])
            ymin = 'no' if log == 'logy' else 0
            cmap, _, y_axis = scatter_with_marginals_colorcoded(var_item_names[v], axs, cell_type_here, x_axis[v],
                                                                frame_file, scores_here[v], ymin=ymin,
                                                                xmin=0, burst_fraction_reset=burst_corr_reset,
                                                                var_item=var_type, labelpad=labelpad, max_x=max_x[v],
                                                                xlim=None, x_pos=1, fs=fs, ms=ms, c=c,
                                                                burst_fraction=burst_fraction[c], sides=False,
                                                                color_text=color_here, ha='right',
                                                                y_val=1.15, color_given=color_here,
                                                                legend_spacing=0.1)
            print(cell_type_here + ' median ' + scores_here[v] + '' + str(
                np.nanmedian(frame_file[scores_here[v]])))
            print(cell_type_here + ' max ' + x_axis[v] + '' + str(np.nanmax(frame_file[x_axis[v]])))
            axs.set_xlim(0, 1.7)
            axs.set_ylim(0, 35)
            axs.set_ylabel(score_name[v])
            axs.set_xlabel(x_axis_names[v], labelpad=labelpad)
            if log in ('logy', 'logall'):
                axs.set_yscale('log')
                if log == 'logall':
                    axs.set_xscale('log')
                make_log_ticks([axs])
    save_visualization(pdf=True, individual_tag=cells_plot2[0])
def default_lw_RAM_talks():
    """Set the global Matplotlib default line width to 3 (talk figures)."""
    talk_linewidth = 3
    plt.rcParams['lines.linewidth'] = talk_linewidth
def diff_label():
    """Return the LaTeX label for the difference frequency |Δf1 - Δf2|.

    Raw string: ``'\\D'`` in a plain literal is an invalid escape sequence
    (SyntaxWarning since Python 3.12).
    """
    return r'$|\Delta f_{1} - \Delta f_{2}|$'
def two_deltaf2_label():
    """Return the LaTeX label for 2|Δf2| (raw string avoids the invalid ``\\D`` escape)."""
    return r'$2|\Delta f_{2}|$'
def two_deltaf1_label():
    """Return the LaTeX label for 2|Δf1| (raw string avoids the invalid ``\\D`` escape)."""
    return r'$2|\Delta f_{1}|$'
def sum_label():
    """Return the LaTeX label for the sum frequency |Δf1 + Δf2| (raw string avoids the ``\\D`` escape)."""
    return r'$|\Delta f_{1} + \Delta f_{2}|$'
def deltaf2_label():
    """Return the LaTeX label for |Δf2| (raw string avoids the invalid ``\\D`` escape)."""
    return r'$|\Delta f_{2}|$'
def deltaf1_label():
    """Return the LaTeX label for |Δf1| (raw string avoids the invalid ``\\D`` escape)."""
    return r'$|\Delta f_{1}|$'