242 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			242 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import numpy as np
 | |
| import matplotlib.pyplot as plt
 | |
| import pandas as pd
 | |
| import os
 | |
| import string
 | |
| # from plotstyle import plot_style
 | |
| from plotstyle import sim_style
 | |
| import seaborn as sns
 | |
| import scipy.stats as stats
 | |
| import matplotlib.lines as mlines
 | |
| 
 | |
def cm2inch(*tupl):
    """Convert lengths given in centimetres to inches.

    Accepts either several numbers (``cm2inch(17.6, 15)``) or a single tuple
    as the first argument (``cm2inch((17.6, 15))``); in the tuple form any
    further arguments are ignored. Always returns a tuple.
    """
    cm_per_inch = 2.54
    lengths_cm = tupl[0] if isinstance(tupl[0], tuple) else tupl
    return tuple(length / cm_per_inch for length in lengths_cm)
 | |
| 
 | |
def Kendall_tau(df):
    """Compute pairwise Kendall tau coefficients and their p-values.

    Parameters
    ----------
    df : pandas.DataFrame
        One column per variable. NaNs are ignored pairwise: ``DataFrame.corr``
        uses pairwise-complete observations and ``kendalltau`` is called with
        ``nan_policy='omit'``.

    Returns
    -------
    tau : pandas.DataFrame
        Kendall tau correlation matrix (``df.corr(method='kendall')``).
    p : pandas.DataFrame
        Float matrix of two-sided p-values with the same labels as ``tau``.
        The diagonal is left as NaN.
    """
    tau = df.corr(method='kendall')

    n_cols = df.shape[1]
    pvals = np.full((n_cols, n_cols), np.nan)
    # Kendall's test is symmetric in its two arguments, so test each unordered
    # pair once and mirror the result instead of running it twice.
    for i in range(n_cols):
        for j in range(i + 1, n_cols):
            _, pval = stats.kendalltau(df.iloc[:, i], df.iloc[:, j], nan_policy='omit')
            pvals[i, j] = pvals[j, i] = pval

    p = pd.DataFrame(pvals, index=df.columns, columns=df.columns)
    return tau, p
 | |
| 
 | |
def correlation_plot(ax, df='AUC', title='', cbar=False):
    """Draw a lower-triangular heatmap of pairwise Kendall tau between models.

    The first row and last column of the model-by-model tau matrix are
    dropped and everything strictly above the diagonal is masked, so each
    unordered pair of models appears exactly once.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw the heatmap into.
    df : {'AUC', 'rheo'} or pandas.DataFrame, default 'AUC'
        'AUC' loads ``./Figures/Data/sim_mut_AUC.csv`` and 'rheo' loads
        ``./Figures/Data/sim_mut_rheo.csv``. A DataFrame may be passed
        instead; it needs one column per model, named either by the model
        keys (e.g. ``'RS_pyramidal'``) or by the display names listed in
        ``model_names`` below. A passed DataFrame is not modified.
    title : str
        Title for ``ax``.
    cbar : bool
        If true, add a colour-bar axes to ``ax.figure`` and draw the
        Kendall-tau colour bar into it.

    Returns
    -------
    matplotlib.axes.Axes
        ``ax``.

    Raises
    ------
    ValueError
        If ``df`` is a string other than 'AUC' or 'rheo'.
    """
    if isinstance(df, str):
        data_files = {'AUC': './Figures/Data/sim_mut_AUC.csv',
                      'rheo': './Figures/Data/sim_mut_rheo.csv'}
        if df not in data_files:
            raise ValueError(f"df must be 'AUC', 'rheo' or a DataFrame, got {df!r}")
        df = pd.read_csv(data_files[df], index_col='Unnamed: 0')

    kv = r'$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$'
    delta_kv = r'$\Delta$' + kv
    model_keys = ['RS_pyramidal', 'RS_inhib', 'FS', 'Cb_stellate', 'Cb_stellate_Kv',
                  'Cb_stellate_Kv_only', 'STN', 'STN_Kv', 'STN_Kv_only']
    model_names = ['RS pyramidal', 'RS inhibitory', 'FS', 'Cb stellate',
                   'Cb stellate +' + kv, 'Cb stellate ' + delta_kv, 'STN',
                   'STN +' + kv, 'STN ' + delta_kv]

    # Non-inplace rename so a caller-supplied DataFrame is left untouched;
    # columns already carrying the display names pass through unchanged.
    df = df.rename(columns=dict(zip(model_keys, model_names)))[model_names]

    # Only the coefficients are plotted, so the p-values are not computed.
    tau = df.corr(method='kendall')
    tau = tau.drop(columns=model_names[-1], index=model_names[0])

    # Hide the cells strictly above the diagonal.
    mask = np.triu(np.ones(tau.shape, dtype=bool), k=1)

    # Display labels: append '+Kv1.1' to the first three models.
    tick_names = {'RS pyramidal': 'RS pyramidal +' + kv,
                  'RS inhibitory': 'RS inhibitory +' + kv,
                  'FS': 'FS +' + kv}
    tau = tau.rename(columns=tick_names, index=tick_names)

    cmap = sns.diverging_palette(220, 10, as_cmap=True)
    heatmap_kws = dict(annot=False, mask=mask, center=0, vmax=1, vmin=-1,
                       linewidths=.5, square=True, ax=ax, cmap=cmap)
    if cbar:
        # Colour bar positioned in figure coordinates, right of the panels.
        cbar_ax = ax.figure.add_axes([0.94, .25, .03, .4])
        for side in ('left', 'bottom', 'right', 'top'):
            cbar_ax.spines[side].set_visible(False)
        cbar_ax.set_xticks([])
        cbar_ax.set_yticks([])
        sns.heatmap(tau, cbar=True, cbar_ax=cbar_ax, **heatmap_kws)
        cbar_ax.set_title(r'Kendall $\tau$', y=1.02, loc='left')
    else:
        sns.heatmap(tau, cbar=False, **heatmap_kws)
    ax.set_title(title)
    return ax
 | |
| 
 | |
def mutation_plot2(ax, model='RS_pyramidal'):
    """Scatter delta-rheobase against AUC_contrast for every mutation in one model.

    Every row of the AUC table is drawn as a small grey marker. Wild type
    ('wt' row) is drawn as a black square. Four named mutations (V174F,
    F414C, E283K, V404I) are drawn as larger pastel markers.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw into.
    model : str, default 'RS_pyramidal'
        Model key: one of 'RS_pyramidal', 'RS_inhib', 'FS', 'Cb_stellate',
        'Cb_stellate_Kv', 'Cb_stellate_Kv_only', 'STN', 'STN_Kv',
        'STN_Kv_only'.

    Returns
    -------
    matplotlib.axes.Axes
        ``ax``.

    Raises
    ------
    ValueError
        If ``model`` is not one of the keys above.
    """
    kv = r'$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$'
    delta_kv = r'$\Delta$' + kv
    # model key -> (column in the CSV tables, panel title, marker)
    model_info = {
        'RS_pyramidal': ('RS pyramidal', 'RS Pyramidal +' + kv, 'D'),
        'RS_inhib': ('RS inhibitory', 'RS Inhibitory +' + kv, '^'),
        'FS': ('FS', 'FS +' + kv, 'D'),
        'Cb_stellate': ('Cb stellate', 'Cb stellate', 'o'),
        'Cb_stellate_Kv': ('Cb stellate +' + kv, 'Cb stellate +' + kv, '^'),
        'Cb_stellate_Kv_only': ('Cb stellate ' + delta_kv, 'Cb stellate ' + delta_kv, 's'),
        'STN': ('STN', 'STN', 's'),
        'STN_Kv': ('STN +' + kv, 'STN +' + kv, 'D'),
        'STN_Kv_only': ('STN ' + delta_kv, 'STN ' + delta_kv, 's'),
    }
    try:
        column, panel_title, marker = model_info[model]
    except KeyError:
        raise ValueError(f'unknown model {model!r}; expected one of {list(model_info)}') from None

    auc = pd.read_csv('./Figures/Data/sim_mut_AUC.csv', index_col='Unnamed: 0')
    rheo = pd.read_csv('./Figures/Data/sim_mut_rheo.csv', index_col='Unnamed: 0')

    # All mutations as small grey background points; rheobase rows are
    # looked up in the AUC table's row order so x and y stay aligned.
    mut_names = auc.index
    ax.plot(rheo.loc[mut_names, column], auc.loc[mut_names, column], linestyle='',
            markeredgecolor='grey', markerfacecolor='grey', marker=marker, markersize=2)

    # Wild type as a black square.
    ax.plot(rheo.loc['wt', column], auc.loc['wt', column], 'sk')

    # Highlighted mutations. Edge and face share one palette entry, and the
    # indices are the ones mutation_legend uses for the same mutations.
    palette = sns.color_palette("pastel")
    for mutation, colour_idx in (('V174F', 0), ('F414C', 1), ('E283K', 2), ('V404I', 5)):
        colour = palette[colour_idx]
        ax.plot(rheo.loc[mutation, column], auc.loc[mutation, column], linestyle='',
                markeredgecolor=colour, markerfacecolor=colour, marker=marker, markersize=4)

    ax.set_title(panel_title)
    ax.set_xlabel(r'$\Delta$ Rheobase (nA)', fontsize=6)
    ax.set_ylabel('$AUC_{contrast}$', fontsize=6)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    return ax
 | |
| 
 | |
def mutation_legend(ax, marker_s_leg, pos, ncol):
    """Attach a frameless legend for wild type and four mutations to ``ax``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes that receives the legend.
    marker_s_leg : float
        Marker size of the mutation entries; the wild-type square is drawn
        2 larger.
    pos : tuple
        ``bbox_to_anchor`` for the legend, anchored at its centre.
    ncol : int
        Number of legend columns.
    """
    palette = sns.color_palette("pastel")
    # (label, pastel palette index) for each highlighted mutation, in legend order.
    mutation_entries = [('V174F', 0), ('F414C', 1), ('E283K', 2), ('V404I', 5)]

    legend_handles = [mlines.Line2D([], [], color='k', marker='s', markersize=marker_s_leg + 2,
                                    linestyle='None', label='Wild type')]
    legend_handles.extend(
        mlines.Line2D([], [], color=palette[idx], marker='o', markersize=marker_s_leg,
                      linestyle='None', label=label)
        for label, idx in mutation_entries
    )

    ax.legend(handles=legend_handles, loc='center', bbox_to_anchor=pos, ncol=ncol, frameon=False)
 | |
| 
 | |
| 
 | |
| 
 | |
# Script: build the simulation model-comparison figure and save it as a PDF.
sim_style()

# Layout: a 3x3 grid of per-model scatter panels on the left (columns 0-3 of
# a 6-column grid) and two stacked correlation heatmaps on the right (4-5).
fig = plt.figure()
gs0 = fig.add_gridspec(1, 6, wspace=3.5)
gsl = gs0[0:4].subgridspec(3, 3, wspace=0.9, hspace=0.8)
gsr = gs0[4:6].subgridspec(2, 1, wspace=0.6, hspace=1)

# Left-hand panels in row-major order; all axes are created before drawing.
axs = [fig.add_subplot(gsl[row, col]) for row in range(3) for col in range(3)]
ax00, ax01, ax02, ax10, ax11, ax12, ax20, ax21, ax22 = axs
axr0 = fig.add_subplot(gsr[0, 0])
axr1 = fig.add_subplot(gsr[1, 0])

# Plot mutations in each model, one model per left-hand panel.
panel_models = ['RS_pyramidal', 'RS_inhib', 'FS',
                'Cb_stellate', 'Cb_stellate_Kv', 'Cb_stellate_Kv_only',
                'STN', 'STN_Kv', 'STN_Kv_only']
for panel_ax, model in zip(axs, panel_models):
    mutation_plot2(panel_ax, model=model)

mutation_legend(ax21, marker_s_leg=4, pos=(0.25, -0.45), ncol=5)

# Correlation matrices; the colour bar is drawn once, by the top heatmap.
correlation_plot(axr1, df='AUC', title='$AUC_{contrast}$', cbar=False)
correlation_plot(axr0, df='rheo', title=r'$\Delta rheobase$', cbar=True)

# Subplot labels: A-I for the scatter panels, then J and K for the heatmaps.
for letter, panel_ax in zip(string.ascii_uppercase, axs):
    panel_ax.text(-0.48, 1.08, letter, transform=panel_ax.transAxes, size=10, weight='bold')
for letter, panel_ax in zip(string.ascii_uppercase[len(axs):], (axr0, axr1)):
    panel_ax.text(-0.38, 1.2, letter, transform=panel_ax.transAxes, size=10, weight='bold')

# Save at the final print size.
fig.set_size_inches(cm2inch(17.95, 15))
fig.savefig('./Figures/simulation_model_comparison.pdf', dpi=fig.dpi)
plt.show()
 | |
| 
 |