# NOTE: file-viewer residue from the original scrape ("289 lines, 13 KiB, Python"); not part of the module.
import numpy as np
from matplotlib import gridspec, pyplot as plt

from plotstyle import plot_style
from threefish.defaults import default_figsize
from threefish.load import save_visualization
from threefish.RAM.plot_labels import title_motivation
from threefish.RAM.plot_subplots import circle_plot, colors_suscept_paper_dots, plot_arrays_ROC_psd_single3, \
    plot_shemes4
from threefish.RAM.reformat import chose_certain_group, extract_waves, load_cells_three, \
    predefine_grouping_frame, save_arrays_susept
from threefish.RAM.values import find_all_dir_cells, ws_nonlin_systems
from threefish.reformat import load_b_public


def motivation_all_small(dev_desired='1', ylim=None, c1=10, devs=None,
                         figsize=None, save=True, end='0', sorted_on='LocalReconst0.2Norm'):
    """Build the "motivation" figure for one recorded cell and optionally save it.

    The figure has two stacked parts: a thin top row of pictograms
    (``plot_pictograms``) and, below it, a 5x4 grid holding the stimulus
    panels (``plot_stimulus_motivation``) and the response panels
    (``plot_arrays_ROC_psd_single3``).

    Parameters
    ----------
    dev_desired : str
        Forwarded unchanged to ``save_arrays_susept``.
    ylim : list of float, optional
        y-limits forwarded to ``plot_stimulus_motivation``. Defaults to
        ``[-1.25, 1.25]``; a fresh list is created per call so the default
        cannot be shared between calls.
    c1 : number
        Value of the ``'c1'`` column used to select rows of the loaded frame;
        also written into the output tag.
    devs : optional
        Accepted for backward compatibility; not used by this function.
    figsize : tuple, optional
        Passed to ``plt.figure``; when falsy the current default size is used.
    save : bool
        When true, ``save_visualization`` is called for each figure.
    end : str
        Only used by the ``load_cells_three`` fallback when the hard-coded
        cell list is empty.
    sorted_on : str
        Forwarded unchanged to ``save_arrays_susept``.

    Returns
    -------
    None
    """
    if ylim is None:
        ylim = [-1.25, 1.25]

    plot_style()
    default_figsize(column=2, length=4.3)
    show = True

    datasets, data_dir = find_all_dir_cells()

    # Hard-coded analysis settings for this figure.
    cells = ['2021-08-03-ac-invivo-1']
    c2 = 10
    eodftype = '_psdEOD_'
    chirps = ['']
    extract = ''
    DF1_desired = [1.2]
    DF2_desired = [0.95]
    detection = 'MeanTrialsIndexPhaseSort'
    mean_type = '_' + detection
    array_chosen = 1

    # Fallback for an empty cell list. The previous version overwrote `cells`
    # again right after this branch, which made the fallback pointless; with
    # the non-empty literal above the branch is not taken either way.
    if not cells:
        data_dir, cells = load_cells_three(end, data_dir=data_dir, datasets=datasets)

    for cell in cells:
        contrasts = [c2]

        # The index handed to load_b_public / save_arrays_susept is the
        # contrast-loop index: in the previous version the inner loop variable
        # shadowed the cell index of the outer loop. Both are 0 for the single
        # cell and contrast configured here.
        # NOTE(review): confirm which index those helpers actually expect.
        for contrast_idx, _contrast in enumerate(contrasts):
            b = load_b_public(contrast_idx, cell, data_dir)

            frame_loaded = predefine_grouping_frame(b, eodftype=eodftype, cell_name=cell)
            selected = (frame_loaded['c2'] == c2) & (frame_loaded['c1'] == c1)
            frame_loaded = frame_loaded[selected]

            for gg in range(len(DF1_desired)):
                # Axes collected per figure, so that tagging never sees axes
                # that belong to a figure from an earlier iteration.
                ax_s = []
                ax_w = []

                # all trials in one
                group_mean = group_saved_matrix(DF1_desired, DF2_desired, gg, frame_loaded)

                # load plotting arrays
                arrays, arrays_original, spikes_pure = save_arrays_susept(
                    data_dir, cell, contrast_idx, chirps, extract, group_mean, mean_type,
                    plot_group=0, rocextra=False, sorted_on=sorted_on,
                    dev_desired=dev_desired)

                fig = plt.figure(figsize=figsize) if figsize else plt.figure()
                grid = gridspec.GridSpec(2, 1, wspace=0.7, hspace=0.15, left=0.055, top=0.96,
                                         bottom=0.15, right=0.935,
                                         height_ratios=[0.5, 5.3])

                # pictograms (top row)
                grid00 = gridspec.GridSpecFromSubplotSpec(1, 4, wspace=0.15, hspace=0.05,
                                                          subplot_spec=grid[0, :])
                plot_pictograms(ax_s, grid00)

                # stimulus panels (first row of the lower grid)
                grid0 = gridspec.GridSpecFromSubplotSpec(5, 4, wspace=0.15, hspace=0.35,
                                                         subplot_spec=grid[1, :],
                                                         height_ratios=[1, 0.35, 1.2, 0, 3, ])
                (color0, color01, color012, color01_2,
                 color02, color0_burst, xlim) = plot_stimulus_motivation(ax_w, grid0,
                                                                         group_mean, ylim)

                # spike response (bottom); entry 0 of the loaded arrays is put
                # first, followed by entries 2, 1 and 3 in that order.
                smoothed_base = arrays[0][0]
                mat_base = arrays_original[0][0]
                # Called for its plotting side effect; the returned values
                # were never used.
                plot_arrays_ROC_psd_single3(
                    [[smoothed_base], arrays[2], arrays[1], arrays[3]],
                    [[mat_base], arrays_original[2], arrays_original[1], arrays_original[3]],
                    spikes_pure, cell, grid0, mean_type, group_mean,
                    xlim=xlim, row=1, array_chosen=array_chosen,
                    color0_burst=color0_burst, color01=color01, color02=color02,
                    ylim_log=(-22, 3), color012=color012, color012_minus=color01_2,
                    color0=color0)

                individual_tag = ('DF1' + str(DF1_desired[gg]) + 'DF2' + str(DF2_desired[gg])
                                  + cell + '_c1_' + str(c1) + '_c2_' + str(c2) + mean_type)

                # NOTE(review): `Figure.tag` is not a stock matplotlib method;
                # presumably it is installed by the plotting-style setup -- confirm.
                fig.tag(ax_s, xoffs=-1.9, yoffs=1.2)
                if save:
                    save_visualization(individual_tag=individual_tag, show=show, pdf=True)


def group_saved_matrix(DF1_desired, DF2_desired, gg, mt_sorted):
    """Select the group matching the ``gg``-th DF1/DF2 pair from ``mt_sorted``.

    The frame is grouped twice -- once by ``['c1', 'c2', 'm1, m2']`` and once
    additionally by ``'repro_tag_id'`` -- and ``chose_certain_group`` is asked
    for the desired DF1/DF2 pair in each grouping.

    Parameters
    ----------
    DF1_desired, DF2_desired : sequence
        Candidate values; only element ``gg`` of each is used.
    gg : int
        Index into ``DF1_desired`` / ``DF2_desired``.
    mt_sorted : DataFrame-like
        Object providing ``groupby`` with the columns named above.

    Returns
    -------
    list
        ``[grouped_orig[0][0], grouped_mean]``: the first element of the first
        entry returned for the per-``repro_tag_id`` grouping, and the result
        for the coarser grouping (requested with ``concat=True``).
    """
    df1 = DF1_desired[gg]
    df2 = DF2_desired[gg]

    # NOTE(review): 'm1, m2' is passed as ONE key containing a comma;
    # presumably that is the literal column name -- verify against the frame.
    pooled_groups = mt_sorted.groupby(['c1', 'c2', 'm1, m2'], as_index=False)
    grouped_mean = chose_certain_group(df1, df2, pooled_groups,
                                       several=True, emb=False, concat=True)

    per_repro_groups = mt_sorted.groupby(['c1', 'c2', 'm1, m2', 'repro_tag_id'],
                                         as_index=False)
    grouped_orig = chose_certain_group(df1, df2, per_repro_groups, several=True)

    return [grouped_orig[0][0], grouped_mean]


def plot_stimulus_motivation(ax_w, grid0, group_mean, ylim, bar=False):
    """Draw the four stimulus panels and return the colors and x-limits used.

    Frequencies come from ``group_mean[1]`` (means of its ``eodf``, ``DF1``
    and ``DF2`` attributes), the two amplitudes from ``group_mean[0][1]`` and
    ``group_mean[0][0]`` divided by 100. ``extract_waves`` builds the
    waveforms and ``plot_shemes4`` draws one panel per wave combination.

    Parameters
    ----------
    ax_w : list
        Extended in place with whatever ``plot_shemes4`` returns per panel.
    grid0 : grid spec
        Forwarded to ``plot_shemes4``; the panel index is passed as ``g``.
    group_mean : sequence
        ``[key, frame]`` as described above.
    ylim : sequence of float
        y-limits forwarded to ``plot_shemes4``; ``ylim[0]`` also positions
        the optional scale bar.
    bar : bool, optional
        When true, a 20 ms scale bar is drawn in the first panel. Defaults to
        ``False``, which matches the previously hard-coded setting.

    Returns
    -------
    tuple
        ``(color0, color01, color012, color01_2, color02, color0_burst, xlim)``
    """
    xlim = [0, 100]
    stimulus_length = 0.3
    deltat = 1 / 40000

    frame = group_mean[1]
    eodf = np.mean(frame.eodf)
    eod_fr = eodf
    a_fr = 1
    eod_fe = eodf + np.mean(frame.DF2)
    a_fe = group_mean[0][1] / 100
    eod_fj = eodf + np.mean(frame.DF1)
    a_fj = group_mean[0][0] / 100
    variant_cell = 'no'

    # extract_waves returns ten values; only four of them are needed here.
    eod_fish_j, time_array, _, eod_fish_r, _, eod_fish_e, _, _, _, _ = extract_waves(
        variant_cell, '',
        stimulus_length, deltat, eod_fr, a_fr, a_fe, [eod_fe], 0, eod_fj, a_fj)
    # Scaled by 1000 to match xlim = [0, 100] and the '20 ms' bar label
    # (presumably seconds -> milliseconds).
    time_array = time_array * 1000

    titles = title_motivation()
    color01, color012, color01_2, color02, color0_burst, color0 = colors_suscept_paper_dots()

    # One entry per panel: (waves shown, value of the `extracted` flag).
    panels = [
        (['receiver', '', '', 'all'], False),
        (['receiver', 'emitter', '', 'all'], True),
        (['receiver', '', 'jammer', 'all'], True),
        (['receiver', 'emitter', 'jammer', 'all'], True),
    ]
    symbol = ''  # text placed to the right of each panel (empty for all panels)

    for col, (waves_present, extracted) in enumerate(panels):
        ax = plot_shemes4(eod_fish_r, eod_fish_e, eod_fish_j, grid0, time_array,
                          g=col, title_top=True, eod_fr=eod_fr,
                          waves_present=waves_present, ylim=ylim,
                          xlim=xlim, color_am='black',
                          color_am2=color01_2, extracted=extracted, extracted2=False,
                          title=titles[col], add=0.1)
        ax_w.append(ax)

        if ax:
            ax.text(1.1, 0.45, symbol, fontsize=35, transform=ax.transAxes)
            if bar and col == 0:
                ax.plot([0, 20], [ylim[0] + 0.01, ylim[0] + 0.01], color='black')
                ax.text(0, -0.16, '20 ms', va='center', fontsize=10,
                        transform=ax.transAxes)

    return color0, color01, color012, color01_2, color02, color0_burst, xlim


def plot_pictograms(ax_s, grid00, horizontal=True):
    """Draw four "input -> system -> output" pictograms, one per cell of ``grid00``.

    Each pictogram uses three axes: an input label with an arrow, a circle
    drawn by ``circle_plot``, and an arrow leading to the output label.

    Parameters
    ----------
    ax_s : list
        Extended in place with the first (input) axes of every pictogram.
    grid00 : grid spec
        Indexed with ``0..3`` to obtain the subplot spec of each pictogram.
    horizontal : bool, optional
        Layout of the three axes: side by side (default, the previously
        hard-coded layout) or stacked vertically.

    Returns
    -------
    None
    """
    texts1 = ['', '$s_{1}(t)$', '$s_{2}(t)$', '$s_{1} +s_{2}(t)$']
    texts2 = ['$r_{0}$', '$r_{0} +r_{1}(t)$', '$r_{0} +r_{2}(t)$',
              r'$r_{t} \neq r_{0}+r_{1}(t)+r_{2}(t)$']
    lw = 0.5
    fs = 8
    middle = 0.5

    for g, (stim_label, resp_label) in enumerate(zip(texts1, texts2)):
        if horizontal:
            grid000 = gridspec.GridSpecFromSubplotSpec(1, 4, wspace=0, hspace=0,
                                                       subplot_spec=grid00[g],
                                                       width_ratios=[2, 0.7, 2, 1.6])
        else:
            grid000 = gridspec.GridSpecFromSubplotSpec(3, 1, wspace=0, hspace=0,
                                                       subplot_spec=grid00[g])

        # ---- input label and arrow ----
        # NOTE(review): `show_spines` is not a stock matplotlib Axes method;
        # presumably installed by the plotting-style setup -- confirm.
        ax0 = plt.subplot(grid000[0])
        ax0.show_spines('')
        if horizontal:
            start = 0.7
            # The first pictogram has no stimulus, hence no input arrow.
            if stim_label != '':
                _pictogram_arrow(ax0, tail=(start, middle), head=(1, middle), lw=lw)
            ax0.text(start, 0.5, stim_label, transform=ax0.transAxes, ha='right',
                     va='center', fontsize=fs)
        else:
            start = 1.5
            # Was `g != texts1[g]` (int compared with str, always true); now
            # the same empty-label test as in the horizontal layout.
            if stim_label != '':
                _pictogram_arrow(ax0, tail=(middle, 0), head=(middle, start), lw=lw,
                                 arrowstyle="<-")
            ax0.text(0.5, start, stim_label, transform=ax0.transAxes, ha='center',
                     va='center')
        ax_s.append(ax0)

        # ---- system symbol ----
        ax1 = plt.subplot(grid000[1])
        circle_plot(ax1, ws_nonlin_systems())
        ax1.show_spines('')
        ax1.set_xlim(0, 20)
        ax1.set_ylim(-20, 40)

        # ---- output arrow and label ----
        ax2 = plt.subplot(grid000[2])
        if horizontal:
            tip = 0.3
            _pictogram_arrow(ax2, tail=(0, middle), head=(tip, middle), lw=lw)
            ax2.text(tip, 0.5, resp_label, transform=ax2.transAxes, ha='left',
                     va='center', fontsize=fs)
        else:
            tip = -0.5
            _pictogram_arrow(ax2, tail=(middle, 1), head=(middle, tip), lw=lw)
            ax2.text(middle, tip, resp_label, transform=ax2.transAxes, ha='center',
                     va='center')
        ax2.show_spines('')


def _pictogram_arrow(ax, tail, head, lw, arrowstyle="->"):
    """Draw a thin black arrow on ``ax``; ``tail``/``head`` are axes fractions.

    ``head`` is passed as ``xy`` and ``tail`` as ``xytext`` of an empty
    annotation, so the arrowhead direction follows ``arrowstyle``.
    """
    ax.annotate('', ha='center', xycoords='axes fraction',
                xy=head, textcoords='axes fraction',
                xytext=tail,
                arrowprops={"arrowstyle": arrowstyle,
                            "linestyle": "-",
                            "linewidth": lw,
                            "color": 'black'},
                zorder=1, annotation_clip=False, transform=ax.transAxes)


if __name__ == '__main__':
    # Script entry point: build and save the figure with the settings below.
    motivation_all_small(dev_desired='1', c1=10, devs=['05'], save=True, end='all',
                         sorted_on='LocalReconst0.2NormAm')