# File: susceptibility1/burst_cells_suscept_appendix.py
# Date: 2024-02-19 15:46:02 +01:00
#
# (source viewer info: 325 lines, 15 KiB, Python)

from utils_suseptibility import *#p_units_to_show,burst_cells
#from plt_RAM import plt_punit
from burst_cells_suscept import burst_cells
# plt_cellbody_singlecell
def plt_cellbody_singlecell_bursts2(grid1, frame, save_names, cells_plot, cell_type_type, plus=1, ax3=[],
                                    burst_corr='_burst_corr_individual', cells_plot2=None):
    """Draw the susceptibility panels for bursty cells (appendix figure).

    For every cell in ``cells_plot`` one region of ``grid1`` is filled with:

    * baseline panels (ISI and baseline PSD, via ``base_cells_susept``) in
      the first column of the upper grid,
    * the stimulus trace (upper row) and example spikes (second row) in the
      last two columns of the upper grid,
    * one susceptibility square plus projected PSDs per entry of
      ``save_names`` in the lower part. The first entry is titled
      "All spikes", the second "First spike".

    Parameters
    ----------
    grid1 : gridspec
        Outer grid, split per cell by ``grids_upper_susept_pics``.
    frame : DataFrame (presumably pandas)
        Per-cell table. Columns read here: ``cell``, ``spikes``, ``fr``,
        ``EODf``, ``fr_burst_corr_individual`` and ``cell_type_type``.
    save_names : list of str
        Names of the precomputed RAM files. At most two entries are
        supported, because ``save_name_type``/``title_squares``/``zorder``
        are indexed with the loop position.
    cells_plot : list of str
        Cells to plot, in this order.
    cell_type_type : str
        Column of ``frame`` holding the cell-type label (used for colours).
    plus : int
        Unused; kept for interface compatibility.
    ax3 : axes or []
        If not ``[]``, a ``cv`` vs. ``cv_stim`` scatter point is drawn on it
        for every cell. It is only compared, never mutated.
    burst_corr : str
        Burst-correction variant passed to ``find_lim_here`` and
        ``correct_burstiness``.
    cells_plot2 : list of str, optional
        Cell list handed to ``load_data_susept`` (``cells=``). Its length also
        selects the tag offsets. Defaults to a module-level ``cells_plot2``
        if one exists, else to ``cells_plot``.
    """
    if cells_plot2 is None:
        # This function used to read a global ``cells_plot2`` that only exists
        # when the file is executed as a script (or if the star import happens
        # to provide one); everywhere else it raised NameError. Keep honouring
        # such a global, otherwise use the cells that are actually plotted.
        cells_plot2 = globals().get('cells_plot2', cells_plot)

    colors = colors_overview()
    axis = []
    tags_cell = []
    # Fixed layout/behaviour switches. Hoisted so they are also defined for
    # the final tagging step when ``cells_plot`` is empty.
    test_clim = False
    several = False
    extra_input = False
    save_name_type = ['_allspikes_', '_burstIndividual_']
    title_squares = ['All spikes, ', 'First spike, ']
    var = ['fr', 'fr_burst_corr_individual']  # firing-rate column per save_name
    zorder = [100, 1]
    add = ['', '_burst_corr_individual']  # only used by the several=True branch

    for c, cell in enumerate(cells_plot):
        print(cell)
        frame_cell = frame[(frame['cell'] == cell)]
        frame_cell = unify_cell_names(frame_cell, cell_type=cell_type_type)
        try:
            cell_type = frame_cell[cell_type_type].iloc[0]
        except Exception:
            print('cell type prob')
            embed()
        spikes = frame_cell.spikes.iloc[0]
        tags = []
        fr = frame_cell.fr.iloc[0]
        eod_fr = frame_cell.EODf.iloc[0]
        print('EODF' + str(eod_fr))
        spikes_base, isi, frs_calc, cont_spikes = load_spikes(spikes, eod_fr)
        colors_b = ['grey', colors[cell_type]]
        ims = []

        # --- layout -------------------------------------------------------
        wr_l = wr_l_cells_susept()
        wr_u = [1, 0.1, 1, 1]
        grid_cell, grid_upper = grids_upper_susept_pics(c, grid1, hs=0.75, row=2, hr=[1, 0.8], wr_u=wr_u)
        wss = ws_for_susept_pic()
        # TODO: layout is still off when the two entries of wss differ
        width_ratios = [2 + wss[0], 2 + wss[1]]
        axes = []
        axos = []
        axds = []
        axd2, axi, axo2, grid_lower, grid_s1, grid_s2 = grids_for_psds(save_names, extra_input, grid_cell,
                                                                    several, widht_ratios=width_ratios,
                                                                    wss=wss, wr=wr_l)
        mats = []
        ax_psds = []
        add_nonlin_title = None
        # Defaults so the later panels work even if a branch below is skipped
        # (previously NameError when isi was empty / no stack was loaded).
        right = False
        xpos_xlabel = -0.24
        # Fallback when no ISI is below the burst limit: correct_burstiness is
        # then not called. Presumably the uncorrected baseline spikes are the
        # right input in that case; TODO confirm (previously NameError).
        spikes_ex = spikes_base

        for aa, save_name in enumerate(save_names):
            add_save = '_cell' + cell + save_name_type[aa]
            title_square = title_squares[aa]
            load_name = load_folder_name('calc_RAM') + '/' + save_name + '_' + cell
            stack = load_data_susept(load_name + '.pkl', load_name, add=add_save, load_version='csv',
                                     cells=cells_plot2)
            if len(stack) == 0:
                continue

            files = stack['file_name'].unique()
            file_name = files[0]
            stack_file = stack[stack['file_name'] == file_name]
            # Only the smallest stimulus amplitude of the first file is shown.
            amps_defined = [np.min(stack_file['amp'].unique())]
            for amp in amps_defined:
                if amp not in np.array(stack_file['amp']):
                    continue
                alpha = find_alpha_val(aa, save_names)
                xlim = [0, 1.1]

                # --- susceptibility square and projected PSDs -------------
                if not several:
                    fr = frame[frame.cell == cell][var[aa]].iloc[0]
                    fr_bc = frame[frame.cell == cell][var[-1]].iloc[0]
                    (diagonals_prj_l, axi, eod_fr, fr, stack_final1, axds, axos, ax_square, axo2, axd2, mat,
                     add_nonlin_title) = plt_psds_in_one_squares_next(
                        aa, add_save, amp, amps_defined, axds, axes, axis, axos, c, cells_plot, colors_b,
                        eod_fr, file_name, grid_lower, ims, load_name, save_names, stack_file, wss, xlim=[],
                        test_clim=test_clim, zorder=zorder[aa], alpha=alpha, extra_input=extra_input, fr=fr,
                        title_square=title_square, fr_diag=fr_bc, xpos_xlabel=xpos_xlabel,
                        add_nonlin_title=add_nonlin_title, color=colors[cell_type], axo2=axo2,
                        peaks_extra=True, axd2=axd2, axi=axi, iterate_var=save_names, amp_give=False)
                    mats.append(mat)
                    print(np.max(np.max(mat)))
                else:
                    (axi, eod_fr, fr, stack_final1, stack_spikes, axds, axos, ax_square, axo2,
                     axd2) = plt_psds_in_one_squares(
                        aa, add, amp, amps_defined, axds, axes, axis, axos, c, cells_plot, colors_b, eod_fr,
                        file_name, files, fr, grid_s1, grid_s2, ims, load_name, save_names, stack_file, wss,
                        xlim, axo2=axo2, axd2=axd2, iterate_var=save_names)
                if aa == 0:
                    tags.append(axi if extra_input else axo2)
                    tags.append(ax_square)
                if aa == 1:
                    tags.append(axd2)
                ax_psds.append(axo2)
                ax_psds.append(axd2)

                # --- optional CV (baseline) vs. CV (stimulus) scatter -----
                if ax3 != []:
                    try:
                        frame_g = base_to_stim(load_name, frame, cell_type_type, cell_type, stack=stack)
                    except Exception:
                        print('stim problem')
                        embed()
                    try:
                        ax3.scatter(frame_g['cv'], frame_g['cv_stim'], zorder=2, alpha=1,
                                    label=cell_type, s=15,
                                    color=colors[str(cell_type)], facecolor='white')
                    except Exception:
                        print('scatter problem')
                        embed()

                # --- example spikes (first save_name only) ----------------
                color_here = ['grey', colors[str(cell_type)]][aa]
                if aa == 0:
                    axss = plt.subplot(grid_upper[1, -2::])
                    eod_fr, length, stack_final, stack_final1, trial_nr = stack_preprocessing(amp, stack_file)
                    stack_spikes = load_data_susept(load_name + '.pkl', load_name, add=add_save,
                                                    load_version='csv', load_type='spikes',
                                                    trial_nr=trial_nr, stimulus_length=length, amp=amp,
                                                    file_name=file_name, redo=True)
                    plt_spikes(amps_defined, aa, c, cell, cell_type, cells_plot, color_here, eod_fr, fr, axss,
                               stack_final1, stack_spikes, xlim, axi=axi, xlim_e=[0, 150], alpha=alpha,
                               spikes_max=3)

                # --- baseline ISI / PSD panels ----------------------------
                if len(isi) > 0:
                    if aa == 0:
                        grid_p = gridspec.GridSpecFromSubplotSpec(1, 2, grid_upper[:, 0], width_ratios=[1.4, 2],
                                                                  wspace=0.3, hspace=0.55)
                        ax_isi = plt.subplot(grid_p[0])
                        ax_p = plt.subplot(grid_p[1])
                        tags.insert(0, ax_isi)
                        lim_here = find_lim_here(cell, burst_corr=burst_corr)
                        if np.min(np.concatenate(isi)) < lim_here:
                            # First pass: keep the raw ISIs, only take the corrected spikes.
                            _, spikes_ex, frs_calc2 = correct_burstiness(isi, spikes_base,
                                                                          [eod_fr] * len(spikes_base),
                                                                          [eod_fr] * len(spikes_base),
                                                                          lim=lim_here, burst_corr=burst_corr)
                    else:
                        # Second pass: the burst-corrected ISIs replace the raw ones.
                        lim_here = find_lim_here(cell, burst_corr=burst_corr)
                        if np.min(np.concatenate(isi)) < lim_here:
                            isi, spikes_ex, frs_calc2 = correct_burstiness(isi, spikes_base,
                                                                            [eod_fr] * len(spikes_base),
                                                                            [eod_fr] * len(spikes_base),
                                                                            lim=lim_here, burst_corr=burst_corr)
                    ax_isi = base_cells_susept(ax_isi, ax_p, c, cell, cell_type, cells_plot, colors, eod_fr, frame,
                                               isi, right, spikes_ex, stack, xlim, add_texts=[-3.1, 0],
                                               texts_left=[250.3, 0], peaks=True, pos=-0.55,
                                               titles=['Bursty P-unit,', 'Bursty P-unit,'],
                                               fr_name='$f_{BaseCorrected}$')

                # --- stimulus trace (first save_name only) ----------------
                xlim_e = [0, 100]
                if aa == 0:
                    axe = plt.subplot(grid_upper[0, -2::])
                    plt_stimulus(eod_fr, axe, stack_final1, xlim_e, files[0])
                    tags.insert(1, axe)

        # --- per-cell finishing ----------------------------------------------
        set_same_ylimscale(ax_psds)
        labels_for_psds(axd2, axi, axo2, extra_input, right=right, xpos_xlabel=xpos_xlabel)
        tags_cell.append(tags)
        if not test_clim:
            set_clim_same_here(ims, mats=mats, lim_type='up', mean_type=True, percnr=94)
        # TODO: these axes should end up with the same ylim
        try:
            set_same_ylim(axos)
            set_same_ylim(axds)
        except Exception:
            print('axo thing')

    # --- panel letters for all cells -----------------------------------------
    if not test_clim:
        try:
            if len(cells_plot2) == 1:
                tags_susept_pictures(tags_cell)
            else:
                tags_susept_pictures(tags_cell, yoffs=np.array([1.1, 1.1, 2.9, 2.9, 2.9, 2.9]))
        except Exception:
            print('tag thing')
            embed()
if __name__ == '__main__':
    # Script entry point: fetch the bursty P-units and render the figure for
    # the second one only.
    cells_plot2 = p_units_to_show(type_here='bursts')
    chosen_cells = [cells_plot2[1]]
    burst_cells(cells_plot2=chosen_cells, show=True, cell_class=' P-unit')