# highbeats_pdf/notchosen/examples.py
# 2020-11-30 13:15:45 +01:00
#
# 429 lines
# 19 KiB
# Python

import nixio as nix
import os
from IPython import embed
#from utility import *
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import matplotlib.mlab as ml
import scipy.integrate as si
from scipy.ndimage import gaussian_filter
from IPython import embed
from myfunctions import *
#from axes import label_axes, labelaxes_params
from myfunctions import auto_rows
#from differentcells import default_settings
#from differentcells import plot_single_cells
import matplotlib.gridspec as gridspec
from functionssimulation import single_stim
import math
from functionssimulation import find_times
from functionssimulation import rectify
from functionssimulation import global_maxima
from functionssimulation import integrate_chirp
from functionssimulation import find_periods
from myfunctions import default_settings
from axes import labelaxes_params,label_axes
from mpl_toolkits.axes_grid1 import host_subplot
import mpl_toolkits.axisartist as AA
import string
def plot_single_cells(ax, colors=('#BA2D22', '#F47F17', '#AAB71B', '#3673A4', '#53379B'),
                      data=('2019-10-21-aa-invivo-1', '2019-11-18-af-invivo-1',
                            '2019-10-28-aj-invivo-1'),
                      var='05'):
    """Plot beat frequency and modulation of single cells against the difference frequency.

    Reads 'beat_results_smoothed.pkl' from the working directory. For every
    dataset in ``data`` it draws
      * column ``result_frequency_<var>`` against EOD multiples
        (``delta_f / eodf + 1``) into ``ax[0]``, and
      * column ``result_amplitude_max_<var>`` against ``delta_f`` into a host
        axes at ``ax[1]``. A twin x-axis in EOD multiples sits 40 points below
        that axes. The same curve is redrawn in grey on the twin so both x-axes
        stay aligned.

    Parameters
    ----------
    ax : indexable of subplot specs
        ``ax[0]`` is passed to ``plt.subplot`` and ``ax[1]`` to ``host_subplot``.
    colors : sequence of str
        Only ``colors[0]`` is used, for every dataset.
    data : iterable of str
        Values of the 'dataset' column to plot.
    var : str
        Column suffix selecting the result variant (e.g. 'original', '05').

    Returns
    -------
    axis : dict
        Key 1 is the frequency panel. Key 2 is the modulation panel. Key 3 is
        the modulation panel's EOD-multiple twin.
    y2 : pandas.Series
        Modulation values of the last dataset in ``data``.

    Raises
    ------
    ValueError
        If ``data`` is empty or names a dataset that is not in the pickle.
    """
    data_all = pd.read_pickle('beat_results_smoothed.pkl')
    nr_size = 10

    # Gather all curves first. The axes can then be created exactly once:
    # host_subplot/twiny would otherwise stack a new axes per dataset.
    curves = []
    for dataset in data:
        d = data_all[data_all['dataset'] == dataset]
        if d.empty:
            raise ValueError('dataset %r not found in beat_results_smoothed.pkl' % (dataset,))
        eod = d['eodf'].iloc[0]
        x = d['delta_f'] / d['eodf'] + 1  # sender frequency in multiples of the receiver EOD
        xx = d['delta_f']
        y = d['result_frequency_' + var]
        y2 = d['result_amplitude_max_' + var]
        curves.append((x, xx, y, y2))
    if not curves:
        raise ValueError('data must name at least one dataset')

    x_max = max(np.max(x) for x, _, _, _ in curves)
    xx_max = max(np.max(xx) for _, xx, _, _ in curves)
    y_sum = [np.nanmax(y) for _, _, y, _ in curves]
    # Tick spacing and the left limit use the EOD of the last dataset.
    # This equals the original behaviour, where `eod` leaked out of the loop.

    axis = {}
    axis[1] = plt.subplot(ax[0])
    axis[2] = host_subplot(ax[1], axes_class=AA.Axes)
    for x, xx, y, y2 in curves:
        axis[1].plot(x, y, zorder=1, color=colors[0])
        axis[2].plot(xx, y2, label="Beats [Hz]", zorder=2, color=colors[0])

    axis[1].set_ylabel('AF [Hz]')
    # The frequency panel shares its x range with the panel below, so only
    # its tick labels are hidden.
    axis[1].tick_params(axis='x', labelbottom=False)
    axis[2].set_ylabel('Modulation ')
    for key in (1, 2):
        axis[key].spines['right'].set_visible(False)
        axis[key].spines['top'].set_visible(False)
    axis[1].set_xlim([0, x_max])
    axis[2].set_xlim([-eod, xx_max])

    axis[2].text(-0.02, 1.1, string.ascii_uppercase[4],
                 transform=axis[2].transAxes,
                 size=nr_size, weight='bold')
    axis[1].text(-0.02, 1.1, string.ascii_uppercase[3],
                 transform=axis[1].transAxes,
                 size=nr_size, weight='bold')

    # Second x-axis in EOD multiples, drawn 40 points below the host axes.
    # delta_f = -eod maps to 0 multiples and xx_max to x_max, so the limits
    # below keep both scales aligned.
    axis[3] = axis[2].twiny()
    axis[3].set_xlabel('EOD multiples')
    offset = -40
    new_fixed_axis = axis[3].get_grid_helper().new_fixed_axis
    axis[3].axis["bottom"] = new_fixed_axis(loc="bottom", axes=axis[3],
                                            offset=(0, offset))
    axis[3].spines['right'].set_visible(False)
    axis[3].spines['top'].set_visible(False)
    axis[3].axis["bottom"].toggle(all=True)
    axis[2].set_xlabel("Difference frequency [Hz]")
    axis[3].set_xlim([0, x_max])
    for x, _, _, y2 in curves:
        axis[3].plot(x, y2, color='grey', zorder=1)

    axis[2].set_xticks(np.arange(-eod, xx_max, eod / 2))
    # NOTE(review): this limit comes from the *frequency* maxima (y), but it
    # is applied to the *modulation* panel. Kept as in the original; confirm
    # this is intended.
    axis[2].set_ylim([0, np.nanmax(y_sum)])
    plt.subplots_adjust(wspace=0.4, left=0.17, right=0.96, bottom=0.2)
    return axis, curves[-1][3]
def plot_beat_corr(ax, lower, beat_corr_col='red', df_col='pink', ax_nr=3, multiple=3, eod_fr=500):
    """Plot the EOD-adjusted beat and the absolute difference frequency against EOD multiples.

    A receiver EOD of ``eod_fr`` Hz is paired with sender frequencies from 0 to
    ``eod_fr * multiple`` Hz (exclusive) in 5 Hz steps. Both curves are drawn
    into a new axes at ``lower[ax_nr]``:
      * the EOD-adjusted beat, i.e. the distance of the sender frequency to the
        nearest integer multiple of ``eod_fr`` (skipped if ``beat_corr_col`` is
        'no'), and
      * ``|sender - eod_fr|`` (skipped if ``df_col`` is 'no').
    Both y-values are expressed as fractions of ``eod_fr``, the same unit as
    the x-axis.

    Parameters
    ----------
    ax : dict
        Receives the new axes under key 'corr'.
    lower : indexable of subplot specs
        ``lower[ax_nr]`` is passed to ``plt.subplot``.
    beat_corr_col, df_col : str
        Line colours, or 'no' to omit the respective curve.
    ax_nr : int
        Index into ``lower``.
    multiple : int
        Upper end of the sender-frequency range in multiples of ``eod_fr``.
    eod_fr : float
        Receiver EOD frequency in Hz (defaults to the previously hard-coded 500).

    Returns
    -------
    dict
        ``ax``, now containing 'corr'.
    """
    eod_fe = np.arange(0, eod_fr * multiple, 5)
    beats = eod_fe - eod_fr
    # Fold every sender frequency onto [0, eod_fr / 2]: distance to the
    # nearest integer multiple of the receiver EOD.
    beat_corr = eod_fe % eod_fr
    folded = beat_corr > eod_fr / 2
    beat_corr[folded] = eod_fr - beat_corr[folded]
    mult = beats / eod_fr + 1  # sender frequency in EOD multiples

    corr = plt.subplot(lower[ax_nr])
    ax['corr'] = corr
    corr.set_xticks(np.arange(0, 10, 0.5))
    corr.set_yticks(np.arange(0, 10, 0.5))

    if beat_corr_col != 'no':
        # Normalise by eod_fr. The old "/ (eod_fr + 1)" mirrored the
        # "/ eod_fr + 1" of the x-transform by mistake.
        beat_corr_norm = beat_corr / eod_fr
        corr.plot(mult, beat_corr_norm, color=beat_corr_col, alpha=0.7)
        corr.set_ylim([0, np.max(beat_corr_norm) * 1.4])
        corr.set_xlim([mult[0], mult[-1]])
    if df_col != 'no':
        corr.plot(mult, np.abs(beats) / eod_fr, color=df_col, alpha=0.7)

    corr.spines['right'].set_visible(False)
    corr.spines['top'].set_visible(False)
    corr.spines['left'].set_visible(True)
    corr.spines['bottom'].set_visible(True)
    corr.set_xlabel('EOD multiples')
    # NOTE(review): the plotted values are fractions of eod_fr although the
    # label says [Hz]; label kept unchanged.
    corr.set_ylabel('EOD adj. beat [Hz]', fontsize=10)
    corr.grid()
    return ax
def try_resort_automatically(values=None):
    """Thin out densely sampled runs of difference frequencies (experimental).

    Steps:
      1. Keep the first value and every value whose gap to its predecessor is
         below 21.
      2. In what remains, a run starts one position after a negative second
         difference (the spacing shrinks). It ends one position before a
         positive second difference (the spacing grows again).
      3. Inside each (start, end) pair keep every second value. Drop the
         others in ``[start, end)``. Values outside all runs are kept.

    Parameters
    ----------
    values : array-like, optional
        Sorted difference frequencies. Defaults to the module-level ``dfs``,
        which only exists when this file runs as a script.

    Returns
    -------
    numpy.ndarray
        The thinned values, sorted ascending. Empty input returns an empty array.
    """
    if values is None:
        values = dfs
    values = np.asarray(values)
    if values.size == 0:
        return values

    gaps = np.diff(values)
    fast_sampling = values[np.concatenate([[True], gaps < 21])]
    second_derivative = np.diff(np.diff(fast_sampling))
    starts = np.flatnonzero(np.concatenate([[False], second_derivative < 0]))
    ends = np.flatnonzero(np.concatenate([second_derivative > 0, [False]])) - 1

    kept = []
    dropped = []
    # zip() tolerates unequal numbers of starts and ends instead of raising IndexError.
    for start, end in zip(starts, ends):
        # np.arange (not slicing) so a negative `end` yields an empty run.
        kept.append(fast_sampling[np.arange(start, end, 2)])
        dropped.append(np.arange(start, end, 1))

    if kept:
        residual = np.concatenate(kept)
        drop_idx = np.concatenate(dropped)
    else:
        residual = fast_sampling[:0]
        drop_idx = np.array([], dtype=int)
    # Remove the run positions themselves. The old `fast_sampling[~indeces]`
    # bit-flipped integer indices into wrong negative positions.
    remaining = np.delete(fast_sampling, drop_idx)
    return np.sort(np.concatenate([residual, remaining]))
def _hide_xticklabels(axes):
    """Hide the x tick labels of ``axes`` while keeping the tick marks."""
    axes.tick_params(axis='x', labelbottom=False)


def _example_stim_panel(ax, i, color, eod, delta_f, cell_spec, minus_bef, plus_bef):
    """Split ``cell_spec`` into a (stimulus, power) pair and draw example ``i``'s stimulus.

    Stores the left axes as ``ax['scatter_small<i>']`` and passes the 1x2 grid
    to ``single_stim`` with a sender EOD of ``delta_f + eod``. Returns the grid
    so the caller can put the power spectrum into its second cell.
    """
    power = gridspec.GridSpecFromSubplotSpec(1, 2, width_ratios=[1, 1.7], hspace=0.2, wspace=0.2,
                                             subplot_spec=cell_spec)
    ax_nr = 0
    ax['scatter_small' + str(i)] = plt.subplot(power[ax_nr])
    sampling = 500 * 200
    single_stim(ax, [color], 1, 1, eod, [delta_f + eod], 0, power, delta_t=0.001, add='no',
                minus_bef=minus_bef, plus_bef=plus_bef, sampling=sampling,
                col_basic='silver', col_hline='no', labels=False, a_fr=1, ax_nr=ax_nr,
                phase_zero=[0], shift_phase=0, df_col='no', beat_corr_col='no',
                size=[120], a_fe=0.8)
    return power


def _spike_psd(spike_trains, sampling_rate, nfft, smooth):
    """Return (frequencies, mean power spectral density) of a set of spike trains.

    Each train (spike times in seconds) is shifted to start at its first spike,
    binned at ``sampling_rate`` and scaled by ``sampling_rate``. If ``smooth``
    is true it is convolved with a 0.5 ms Gaussian. It is then mean-subtracted
    and passed to ``matplotlib.mlab.psd``. The spectra are averaged over trains.
    """
    spectra = []
    freqs = None
    for train in spike_trains:
        train = np.asarray(train)
        bins = ((train - train[0]) * sampling_rate).astype(int)
        rate = np.zeros(bins[-1] + 2)
        rate[bins] = 1
        rate = rate * sampling_rate
        if smooth:
            rate = gaussian_filter(rate, sigma=0.0005 * sampling_rate) * sampling_rate
        # mlab.psd expects an integer noverlap; nfft / 2 would be a float.
        pxx, freqs = ml.psd(rate - np.mean(rate), Fs=sampling_rate, NFFT=nfft, noverlap=nfft // 2)
        spectra.append(pxx)
    return freqs, np.mean(spectra, axis=0)


if __name__ == "__main__":
    # Previously used cell: '2019-10-21-aa-invivo-1'.
    data = ['2019-09-23-ad-invivo-1']
    # NOTE(review): the second call presumably supersedes the first; both are
    # kept because labelaxes_params is defined outside this file.
    labelaxes_params(xoffs=1, yoffs=0, labels='A', font=dict(fontweight='bold'))
    labelaxes_params(xoffs=-6, yoffs=1, labels='A', font=dict(fontweight='bold'))
    default_settings(data, intermediate_width=6.29, intermediate_length=7.5, ts=6, ls=8, fs=9)
    fig = plt.figure()
    ax = {}

    data_all = pd.read_pickle('data_beat.pkl')
    d = data_all[data_all['dataset'] == data[0]]
    eod = d['eodf'].iloc[0]
    dfs = np.unique(d['df'])

    # Columns of the top row:
    # small examples | spike raster | marker strip | large examples.
    grid = gridspec.GridSpec(2, 4, wspace=0.0, height_ratios=[6, 2], width_ratios=[1, 1, 0.3, 3],
                             hspace=0.2)
    nr_size = 10

    # Example difference frequencies (Hz). The last three are the first three
    # shifted up by one EOD.
    low_nr = 60
    from_middle = 45
    example_df = [low_nr - eod, eod / 2 - from_middle - eod, eod - low_nr - eod,
                  low_nr, eod / 2 - from_middle, low_nr + eod]
    # rows is derived from the examples actually plotted, so the rasters and
    # the tick-label logic cannot disagree.
    rows = len(example_df)
    first = ['crimson', 'lightcoral', 'darkviolet']
    third = ['orange', 'orangered', 'darkred']
    fourth = ['DarkGreen', 'LimeGreen', 'YellowGreen']
    colors = np.concatenate([fourth, third, first])
    minus_bef = -250
    plus_bef = -200

    # Small example panels in the first column.
    left_raster = gridspec.GridSpecFromSubplotSpec(rows, 1, subplot_spec=grid[0, 0],
                                                   wspace=0.05, hspace=0.3)
    for i, delta_f in enumerate(example_df):
        _example_stim_panel(ax, i, colors[i], eod, delta_f, left_raster[i], minus_bef, plus_bef)

    # Invisible strip between the raster and the large examples; it only
    # carries the triangle markers pointing at the example frequencies.
    ax['between'] = plt.subplot(grid[0, 2])
    for side in ('right', 'top', 'left', 'bottom'):
        ax['between'].spines[side].set_visible(False)
    ax['between'].set_ylim([np.min(dfs), np.max(dfs)])
    ax['between'].set_xlim([-0.5, 30])
    ax['between'].set_xticks([])
    ax['between'].set_yticks([])
    ax['between'].set_ylim(ax['between'].get_ylim()[::-1])

    # Spike raster of 0.1-0.3 s after the first spike, one row per difference frequency.
    ax['scatter'] = plt.subplot(grid[0, 1])
    ax['scatter'].spines['right'].set_visible(False)
    ax['scatter'].spines['top'].set_visible(False)
    # Keep every second frequency in the densely sampled middle range.
    new_dfs = np.concatenate([dfs[0:25], dfs[25:40:2], dfs[40:53:2], dfs[54:-1]])
    ll = 0.1
    ul = 0.3
    for df_value in new_dfs:
        spikes = d[d['df'] == df_value]['spike_times']
        transformed_spikes = spikes.iloc[0] - spikes.iloc[0][0]
        in_window = (transformed_spikes > ll) & (transformed_spikes < ul)
        used_spikes = transformed_spikes[in_window] * 1000
        ax['scatter'].scatter(used_spikes, np.ones(len(used_spikes)) * df_value, s=0.2, color='silver')
    ax['scatter'].set_ylim([np.min(dfs), np.max(dfs)])
    ax['scatter'].set_xlim([ll * 1000, ul * 1000])
    ax['scatter'].set_ylabel('Difference frequency [Hz]')
    ax['scatter'].set_xlabel('Time [ms]')
    ax['scatter'].set_ylim(ax['scatter'].get_ylim()[::-1])
    ax['scatter'].text(-0.1, 1.025, string.ascii_uppercase[0], transform=ax['scatter'].transAxes,
                       size=nr_size, weight='bold')

    # Bottom row: beat frequency and modulation of the cell.
    bottom = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=grid[1, :], wspace=0, hspace=0.5)
    main_color = 'darkgrey'
    var = 'original'  # '05' would additionally smooth the spectra below with a 0.5 ms Gaussian
    axis, y2 = plot_single_cells(bottom, colors=[main_color], data=data, var=var)

    # Large example panels: stimulus, spike raster and spike-train power spectrum.
    right_raster = gridspec.GridSpecFromSubplotSpec(rows, 1, subplot_spec=grid[0, 3],
                                                    wspace=0.05, hspace=0.3)
    nfft = 4096
    sampling_rate = 40000
    ll = np.abs(plus_bef)   # ms
    ul = np.abs(minus_bef)  # ms
    max_p = []
    for i, delta_f in enumerate(example_df):
        scatter_key = 'scatter_small' + str(i)
        power = _example_stim_panel(ax, i, colors[i], eod, delta_f, right_raster[i],
                                    minus_bef, plus_bef)
        ax['between'].scatter(0.12, delta_f, zorder=2, s=25, marker='<', color=colors[i])

        # Recorded difference frequency closest to the example.
        df = new_dfs[np.argmin(np.abs(new_dfs - delta_f))]
        spikes = d[d['df'] == df]['spike_times']
        transformed_spikes = spikes.iloc[0] * 1000 - spikes.iloc[0][0] * 1000
        in_window = (transformed_spikes > ll) & (transformed_spikes < ul)
        used_spikes = transformed_spikes[in_window]
        if len(used_spikes):
            used_spikes = used_spikes - used_spikes[0]
        ax[scatter_key].scatter(used_spikes, np.ones(len(used_spikes)) * -2, zorder=2, s=2,
                                marker='|', color='black')
        ax[scatter_key].set_ylim([-2.5, 2.2])

        ax_power = plt.subplot(power[1])
        ax['power' + str(i)] = ax_power
        f, p = _spike_psd(spikes, sampling_rate, nfft, smooth=(var == '05'))
        ax_power.plot(f, p, zorder=1, color=main_color)
        max_p.append(np.max(p))
        below_half = f < 0.5 * eod
        peak = np.argmax(p[below_half])  # f ascends from 0, so indices align with f
        peak_power = p[below_half][peak]
        ax_power.scatter(f[peak], peak_power, zorder=2, color=colors[i], s=25)
        eod_idx = np.argmin(np.abs(f - eod))
        ax_power.scatter(f[eod_idx], p[eod_idx] * 0.90, zorder=2,
                         s=25, color='darkgrey', edgecolor='black')
        ax_power.axvline(x=eod / 2, color='black', linestyle='dashed', lw=0.5)
        ax_power.set_xlim([-40, 1600])

        # Mark the example on the summary curves of the bottom row.
        axis[3].scatter(delta_f / eod + 1, np.sqrt(peak_power * np.abs(f[0] - f[1])),
                        zorder=3, s=20, marker='o', color=colors[i])
        axis[1].scatter(delta_f / eod + 1, f[peak], zorder=2, s=20, marker='o', color=colors[i])

        if i != rows - 1:
            _hide_xticklabels(ax[scatter_key])
            _hide_xticklabels(ax_power)
        else:
            ax_power.set_xlabel('Frequency [Hz]')
            ax[scatter_key].set_xlabel('Time [ms]')
        ax_power.set_yticks([])
        ax_power.spines['left'].set_visible(False)
        ax[scatter_key].spines['left'].set_visible(False)
        ax[scatter_key].set_yticks([])
        for side in ('right', 'top'):
            ax_power.spines[side].set_visible(False)
            ax[scatter_key].spines[side].set_visible(False)

    # Common power scale across all example spectra.
    top_power = np.max(max_p)
    for i in range(rows):
        ax['power' + str(i)].set_ylim([0, top_power])
    ax['power' + str(0)].text(-0.1, 1.1, string.ascii_uppercase[2], transform=ax['power' + str(0)].transAxes,
                              size=nr_size, weight='bold')
    ax['scatter_small' + str(0)].text(-0.1, 1.1, string.ascii_uppercase[1],
                                      transform=ax['scatter_small' + str(0)].transAxes,
                                      size=nr_size, weight='bold')
    plt.subplots_adjust(left=0.11, bottom=0.18, top=0.94)
    plt.savefig('singlecellexample5.pdf')
    plt.savefig('../highbeats_pdf/singlecellexample5.pdf')
    plt.show()