# model_mutations_2022/Figures/simulation_model_comparison.py
#
# 242 lines
# 11 KiB
# Python

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import string
# from plotstyle import plot_style
from plotstyle import sim_style
import seaborn as sns
import scipy.stats as stats
import matplotlib.lines as mlines
def cm2inch(*tupl):
    """Convert lengths in centimetres to inches.

    Accepts either the lengths as separate positional arguments,
    ``cm2inch(17.95, 15)``, or a single tuple of lengths,
    ``cm2inch((17.95, 15))``.

    Returns:
        A tuple with every length divided by 2.54.
    """
    # A lone tuple argument is unpacked; otherwise the positional args are used.
    lengths_cm = tupl[0] if isinstance(tupl[0], tuple) else tupl
    return tuple(length / 2.54 for length in lengths_cm)
def Kendall_tau(df):
    """Compute pairwise Kendall rank correlations and their p-values.

    Parameters
    ----------
    df : pandas.DataFrame
        One column per variable; rows are paired observations.

    Returns
    -------
    tau : pandas.DataFrame
        The matrix returned by ``df.corr(method='kendall')``.
    p : pandas.DataFrame
        Float matrix with the same labels as ``tau`` holding the p-value of
        ``scipy.stats.kendalltau`` for every pair of distinct columns (rows
        with a NaN in either column are omitted). The diagonal is NaN.
    """
    tau = df.corr(method='kendall')
    columns = df.columns
    n_cols = len(columns)
    # Float dtype (rather than the object dtype of an empty DataFrame) so the
    # p-values can be used in numeric operations directly.
    p = pd.DataFrame(np.nan, index=columns, columns=columns, dtype=float)
    for i in range(n_cols):
        for j in range(i + 1, n_cols):
            _, p_value = stats.kendalltau(df.iloc[:, i], df.iloc[:, j], nan_policy='omit')
            # The test is symmetric in its two samples, so compute each
            # unordered pair once and mirror it.
            p.iloc[i, j] = p_value
            p.iloc[j, i] = p_value
    return tau, p
def correlation_plot(ax, df='AUC', title='', cbar=False):
    """Draw a between-model Kendall-tau heatmap of the simulated mutation effects.

    Columns of the data are the nine neuron models. Rows are presumably the
    simulated mutations (e.g. 'wt', 'V174F'), as used by ``mutation_plot2``.
    Only the strictly-lower triangle of the model-by-model tau matrix is
    shown. The first model's row and the last model's column are dropped
    because they would contain only masked cells.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw the heatmap into. When a colour bar is requested, it is
        added as a new axes of ``ax.figure``.
    df : {'AUC', 'rheo'} or pandas.DataFrame
        ``'AUC'`` loads ``./Figures/Data/sim_mut_AUC.csv``. ``'rheo'`` loads
        ``./Figures/Data/sim_mut_rheo.csv`` (paths relative to the working
        directory). A DataFrame with the same column layout may be passed
        directly; it is not modified.
    title : str
        Title of ``ax``.
    cbar : bool
        If true, add a colour bar (range -1..1) at figure coordinates
        ``[0.94, .25, .03, .4]`` titled "Kendall tau".

    Raises
    ------
    ValueError
        If ``df`` is neither a DataFrame nor one of ``'AUC'`` / ``'rheo'``.
    """
    data_files = {
        'AUC': './Figures/Data/sim_mut_AUC.csv',
        'rheo': './Figures/Data/sim_mut_rheo.csv',
    }
    if isinstance(df, pd.DataFrame):
        data = df
    elif df in data_files:
        data = pd.read_csv(data_files[df], index_col='Unnamed: 0')
    else:
        raise ValueError(f"df must be 'AUC', 'rheo' or a DataFrame, got {df!r}")

    kv = r'$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$'
    # Internal model keys -> column labels used for the correlation matrix.
    model_keys = ['RS_pyramidal', 'RS_inhib', 'FS', 'Cb_stellate', 'Cb_stellate_Kv',
                  'Cb_stellate_Kv_only', 'STN', 'STN_Kv', 'STN_Kv_only']
    model_columns = ['RS pyramidal', 'RS inhibitory', 'FS', 'Cb stellate',
                     'Cb stellate +' + kv, r'Cb stellate $\Delta$' + kv, 'STN',
                     'STN +' + kv, r'STN $\Delta$' + kv]
    # Non-inplace rename so a caller-supplied DataFrame is left untouched.
    data = data.rename(columns=dict(zip(model_keys, model_columns)))[model_columns]

    # Only tau is needed here; this is the same matrix Kendall_tau() returns
    # as its first element, without computing the p-values.
    tau = data.corr(method='kendall')
    tau = tau.drop(columns=model_columns[-1], index=model_columns[0])
    # Hide everything strictly above the diagonal of the trimmed matrix,
    # which is the strictly-lower triangle of the full model matrix.
    mask = np.triu(np.ones(tau.shape, dtype=bool), k=1)

    # Tick labels: the first three models are shown with a "+Kv1.1" suffix.
    display_names = {
        'RS pyramidal': 'RS pyramidal +' + kv,
        'RS inhibitory': 'RS inhibitory +' + kv,
        'FS': 'FS +' + kv,
    }
    tau = tau.rename(columns=display_names, index=display_names)

    show_cbar = bool(cbar)
    cbar_ax = None
    if show_cbar:
        # Only create the colour-bar axes when it will actually be used, so
        # repeated calls do not stack invisible axes and duplicate titles.
        cbar_ax = ax.figure.add_axes([0.94, .25, .03, .4])
        for side in ('left', 'bottom', 'right', 'top'):
            cbar_ax.spines[side].set_visible(False)
        cbar_ax.set_xticks([])
        cbar_ax.set_yticks([])

    cmap = sns.diverging_palette(220, 10, as_cmap=True)
    sns.heatmap(tau, annot=False, mask=mask, center=0, vmax=1, vmin=-1, linewidths=.5,
                square=True, ax=ax, cbar=show_cbar, cmap=cmap, cbar_ax=cbar_ax)
    if show_cbar:
        cbar_ax.set_title(r'Kendall $\tau$', y=1.02, loc='left')
    ax.set_title(title)
def mutation_plot2(ax, model='RS_pyramidal'):
    """Scatter AUC_contrast against delta rheobase for every mutation in one model.

    Reads ``./Figures/Data/sim_mut_AUC.csv`` and ``./Figures/Data/sim_mut_rheo.csv``
    (relative to the working directory; the first column is the row index).
    The following are drawn:

    * every row of the AUC table as a small grey marker;
    * the ``'wt'`` row as a black square;
    * the rows V174F, F414C, E283K and V404I as larger markers, coloured with
      seaborn's "pastel" palette entries 0, 1, 2 and 5 respectively.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw into.
    model : str
        Model key. Must be one of 'RS_pyramidal', 'RS_inhib', 'FS',
        'Cb_stellate', 'Cb_stellate_Kv', 'Cb_stellate_Kv_only', 'STN',
        'STN_Kv' or 'STN_Kv_only'. It selects the data column, the panel
        title and the marker shape.

    Returns
    -------
    matplotlib.axes.Axes
        The same ``ax``.

    Raises
    ------
    ValueError
        If ``model`` is not one of the keys above.
    """
    kv = r'$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$'
    # model key -> (column in the CSV files, panel title, marker shape)
    model_info = {
        'RS_pyramidal': ('RS pyramidal', 'RS Pyramidal +' + kv, 'D'),
        'RS_inhib': ('RS inhibitory', 'RS Inhibitory +' + kv, '^'),
        'FS': ('FS', 'FS +' + kv, 'D'),
        'Cb_stellate': ('Cb stellate', 'Cb stellate', 'o'),
        'Cb_stellate_Kv': ('Cb stellate +' + kv, 'Cb stellate +' + kv, '^'),
        'Cb_stellate_Kv_only': (r'Cb stellate $\Delta$' + kv, r'Cb stellate $\Delta$' + kv, 's'),
        'STN': ('STN', 'STN', 's'),
        'STN_Kv': ('STN +' + kv, 'STN +' + kv, 'D'),
        'STN_Kv_only': (r'STN $\Delta$' + kv, r'STN $\Delta$' + kv, 's'),
    }
    if model not in model_info:
        raise ValueError(f'unknown model {model!r}; expected one of {sorted(model_info)}')
    column, title, marker = model_info[model]

    auc = pd.read_csv('./Figures/Data/sim_mut_AUC.csv', index_col='Unnamed: 0')
    rheo = pd.read_csv('./Figures/Data/sim_mut_rheo.csv', index_col='Unnamed: 0')

    # Background: every mutation in the AUC table. rheo is indexed with the
    # same row labels so the x and y series line up.
    mutations = auc.index
    ax.plot(rheo.loc[mutations, column], auc.loc[mutations, column], linestyle='',
            markeredgecolor='grey', markerfacecolor='grey', marker=marker, markersize=2)
    # Wild type as a black square.
    ax.plot(rheo.loc['wt', column], auc.loc['wt', column], 'sk')

    palette = sns.color_palette("pastel")
    # Highlighted mutations with their palette index. Edge and face share one
    # colour (previously V404I mixed palette entries 3 and 5).
    for mutation, colour_idx in (('V174F', 0), ('F414C', 1), ('E283K', 2), ('V404I', 5)):
        colour = palette[colour_idx]
        ax.plot(rheo.loc[mutation, column], auc.loc[mutation, column], linestyle='',
                markeredgecolor=colour, markerfacecolor=colour, marker=marker, markersize=4)

    ax.set_title(title)
    ax.set_xlabel(r'$\Delta$ Rheobase (nA)', fontsize=6)
    ax.set_ylabel('$AUC_{contrast}$', fontsize=6)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    return ax
def mutation_legend(ax, marker_s_leg, pos, ncol):
    """Add a frameless legend for the wild type and four mutations to ``ax``.

    Entries, in order: 'Wild type' (black square, size ``marker_s_leg + 2``),
    then V174F, F414C, E283K and V404I as circles of size ``marker_s_leg``.
    The mutation circles use seaborn "pastel" palette entries 0, 1, 2 and 5.

    Parameters:
        ax: Axes that receives the legend.
        marker_s_leg: Marker size for the mutation entries.
        pos: ``bbox_to_anchor`` for the legend, which is centred on it.
        ncol: Number of legend columns.
    """
    palette = sns.color_palette("pastel")
    # (label, palette index) for each mutation entry, in legend order.
    mutation_entries = [('V174F', 0), ('F414C', 1), ('E283K', 2), ('V404I', 5)]

    handles = [mlines.Line2D([], [], color='k', marker='s', markersize=marker_s_leg + 2,
                             linestyle='None', label='Wild type')]
    handles.extend(
        mlines.Line2D([], [], color=palette[idx], marker='o', markersize=marker_s_leg,
                      linestyle='None', label=label)
        for label, idx in mutation_entries
    )
    ax.legend(handles=handles, loc='center', bbox_to_anchor=pos, ncol=ncol, frameon=False)
# sim_style() is used for this figure instead of plot_style().
sim_style()

# ---------------------------------------------------------------------------
# Layout: the left 4 of 6 grid columns hold a 3x3 block of per-model scatter
# panels; the right 2 columns hold two stacked correlation heatmaps.
# ---------------------------------------------------------------------------
fig = plt.figure()  # figsize=cm2inch(17.6,15)
gs0 = fig.add_gridspec(1, 6, wspace=3.5)
gsl = gs0[0:4].subgridspec(3, 3, wspace=0.9, hspace=0.8)
gsr = gs0[4:6].subgridspec(2, 1, wspace=0.6, hspace=1)

# Scatter axes are created row by row, then the two heatmap axes.
axs = [fig.add_subplot(gsl[row, col]) for row in range(3) for col in range(3)]
axr0 = fig.add_subplot(gsr[0, 0])
axr1 = fig.add_subplot(gsr[1, 0])

# One model per scatter panel, in the same row-major order as `axs`.
panel_models = ('RS_pyramidal', 'RS_inhib', 'FS',
                'Cb_stellate', 'Cb_stellate_Kv', 'Cb_stellate_Kv_only',
                'STN', 'STN_Kv', 'STN_Kv_only')
axs = [mutation_plot2(panel_ax, model=name) for panel_ax, name in zip(axs, panel_models)]
ax00, ax01, ax02, ax10, ax11, ax12, ax20, ax21, ax22 = axs

# Shared mutation legend, positioned relative to the bottom-middle panel.
marker_s_leg = 4
pos = (0.25, -0.45)
ncol = 5
mutation_legend(ax21, marker_s_leg, pos, ncol)

# Correlation heatmaps; only the rheobase heatmap adds the colour bar.
correlation_plot(axr1, df='AUC', title='$AUC_{contrast}$', cbar=False)
correlation_plot(axr0, df='rheo', title=r'$\Delta rheobase$', cbar=True)

# Panel letters: A-I for the scatter panels, then the next two for the heatmaps.
letters = string.ascii_uppercase
for letter, panel_ax in zip(letters, axs):
    panel_ax.text(-0.48, 1.08, letter, transform=panel_ax.transAxes, size=10, weight='bold')
for letter, heat_ax in zip(letters[len(axs):], (axr0, axr1)):
    heat_ax.text(-0.38, 1.2, letter, transform=heat_ax.transAxes, size=10, weight='bold')

# Save at the final print size, keeping the figure's own dpi.
print(fig.dpi)
fig.set_size_inches(cm2inch(17.95, 15))
fig.savefig('./Figures/simulation_model_comparison.pdf', dpi=fig.dpi)  # bbox_inches='tight' not used
plt.show()