"""Build the simulation model-comparison figure.

Draws one scatter panel per model (change in rheobase against normalized
change in AUC for every mutation) plus two Kendall-tau correlation
heatmaps between models, then saves the result as a PDF and shows it.

Running (or importing) this module reads the two CSV tables under
``./Figures/Data`` and writes ``./Figures/simulation_model_comparison.pdf``.
"""
import os
import string

import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as stats
import seaborn as sns

from Figures.plotstyle import sim_style

# --------------------------------------------------------------------------
# Shared constants
# --------------------------------------------------------------------------
AUC_CSV = './Figures/Data/sim_mut_AUC.csv'
RHEO_CSV = './Figures/Data/sim_mut_rheo.csv'
_DATA_FILES = {'AUC': AUC_CSV, 'rheo': RHEO_CSV}

# LaTeX fragments are raw strings so that "\m" / "\D" are not treated as
# (invalid) escape sequences; the resulting text is unchanged.
_KV = r'$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$'
_PLUS_KV = ' +' + _KV
_DELTA_KV = r' $\Delta$' + _KV

# Model keys accepted by mutation_plot, in panel order.
MODELS = ['RS_pyramidal', 'RS_inhib', 'FS', 'Cb_stellate', 'Cb_stellate_Kv',
          'Cb_stellate_Kv_only', 'STN', 'STN_Kv', 'STN_Kv_only']

# Column names used to address the data tables (same order as MODELS).
MODEL_NAMES = ['RS pyramidal', 'RS inhibitory', 'FS', 'Cb stellate',
               'Cb stellate' + _PLUS_KV, 'Cb stellate' + _DELTA_KV,
               'STN', 'STN' + _PLUS_KV, 'STN' + _DELTA_KV]

# Panel titles used by mutation_plot (same order as MODELS).
MODEL_DISPLAY_NAMES = ['RS Pyramidal' + _PLUS_KV,
                       'RS Inhibitory' + _PLUS_KV,
                       'FS' + _PLUS_KV] + MODEL_NAMES[3:]

# Axis labels of the correlation heatmaps: table column name -> shown label.
_HEATMAP_LABELS = {
    'RS pyramidal': 'RS pyramidal' + _PLUS_KV,
    'RS inhibitory': 'RS inhibitory' + _PLUS_KV,
    'FS': 'FS' + _PLUS_KV,
    'Cb stellate': 'Cb stellate',
    'Cb stellate' + _PLUS_KV: 'Cb stellate' + _PLUS_KV,
    'Cb stellate' + _DELTA_KV: 'Cb stellate' + _DELTA_KV,
    'STN': 'STN',
    'STN' + _PLUS_KV: 'STN' + _PLUS_KV,
}

# Marker shape per displayed model name.
MARKERS = {
    'Cb stellate': 'o',
    'RS Inhibitory': 'o',
    'FS': 'o',
    'RS Pyramidal': '^',
    'RS Inhibitory' + _PLUS_KV: '^',
    'Cb stellate' + _PLUS_KV: '^',
    'FS' + _PLUS_KV: 'D',
    'RS Pyramidal' + _PLUS_KV: 'D',
    'STN' + _PLUS_KV: 'D',
    'Cb stellate' + _DELTA_KV: 's',
    'STN' + _DELTA_KV: 's',
    'STN': 's',
}

# Highlighted mutations and their index into seaborn's "pastel" palette.
# Shared by mutation_plot and mutation_legend so that both always agree.
_HIGHLIGHTED_MUTATIONS = (('V174F', 0), ('F414C', 1), ('E283K', 2),
                          ('V404I', 5))


def cm2inch(*tupl):
    """Convert centimetres to inches.

    Accepts either separate numbers (``cm2inch(22.2, 15)``) or a single
    tuple (``cm2inch((22.2, 15))``) and returns a tuple of the converted
    values.
    """
    inch = 2.54
    values = tupl[0] if isinstance(tupl[0], tuple) else tupl
    return tuple(v / inch for v in values)


def Kendall_tau(df):
    """Return pairwise Kendall tau correlations and their p-values.

    Parameters
    ----------
    df : pandas.DataFrame
        One column per variable to correlate.

    Returns
    -------
    tau : pandas.DataFrame
        Correlation matrix from ``df.corr(method='kendall')``.
    p : pandas.DataFrame
        p-values from ``scipy.stats.kendalltau`` (NaNs omitted) for every
        pair of distinct columns; the diagonal is left unset (NaN).
    """
    tau = df.corr(method='kendall')
    p = pd.DataFrame(columns=df.columns, index=df.columns)
    for first in df.columns:
        for second in df.columns:
            if first != second:
                _, p.loc[first, second] = stats.kendalltau(
                    df[first], df[second], nan_policy='omit')
    return tau, p


def _load_table(path):
    """Read a data table and address its columns by MODEL_NAMES.

    Columns named after the model keys in MODELS are renamed to the
    matching MODEL_NAMES entry; columns already using those names are left
    untouched (renaming a missing key is a no-op).
    """
    table = pd.read_csv(path, index_col='Unnamed: 0')
    return table.rename(columns=dict(zip(MODELS, MODEL_NAMES)))


def correlation_plot(ax, df='AUC', title='', cbar=False):
    """Draw the between-model Kendall tau matrix as a heatmap on ``ax``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes that receives the heatmap.
    df : {'AUC', 'rheo'} or pandas.DataFrame
        Which data table to load, or an already loaded table with one
        column per model.
    title : str
        Title placed above the heatmap.
    cbar : bool
        If true, a horizontal colourbar is added to the figure owning
        ``ax`` at a fixed figure position.

    Raises
    ------
    ValueError
        If ``df`` is a string other than 'AUC' or 'rheo'.
    """
    if isinstance(df, str):
        if df not in _DATA_FILES:
            raise ValueError(
                "df must be one of {} or a DataFrame, got {!r}".format(
                    sorted(_DATA_FILES), df))
        table = _load_table(_DATA_FILES[df])
    else:
        table = df.rename(columns=dict(zip(MODELS, MODEL_NAMES)))
    table = table[MODEL_NAMES]

    # calculate correlation matrix
    tau, _ = Kendall_tau(table)
    # Dropping the last column and the first row moves the first
    # sub-diagonal of the full matrix onto the main diagonal, so only
    # pairs of different models remain in the lower triangle.
    tau = tau.drop(columns=MODEL_NAMES[-1], index=MODEL_NAMES[0])

    # mask to hide upper triangle of matrix (the diagonal stays visible)
    mask = np.zeros(tau.shape, dtype=bool)
    mask[np.triu_indices_from(mask)] = True
    np.fill_diagonal(mask, False)

    # relabel rows/columns for display
    tau = tau.rename(columns=_HEATMAP_LABELS, index=_HEATMAP_LABELS)
    tau = tau[list(_HEATMAP_LABELS.values())]

    heatmap_kws = dict(annot=False, mask=mask, center=0, vmax=1, vmin=-1,
                       linewidths=.5, square=True, ax=ax,
                       cmap=sns.diverging_palette(220, 10, as_cmap=True))

    # plotting with or without colorbar
    if cbar:
        # Create the colourbar axes only when it is used, on the figure
        # that owns `ax` rather than on a module-level global.
        cbar_ax = ax.figure.add_axes([0.685, 0.44, .15, .01])
        for side in ('left', 'bottom', 'right', 'top'):
            cbar_ax.spines[side].set_visible(False)
        cbar_ax.set_xticks([])
        cbar_ax.set_yticks([])
        sns.heatmap(tau, cbar=True, cbar_ax=cbar_ax,
                    cbar_kws={"orientation": "horizontal",
                              "ticks": [-1, -0.5, 0, 0.5, 1]},
                    **heatmap_kws)
        cbar_ax.set_title(r'Kendall $\tau$', y=1.02, loc='center',
                          fontsize=6)
        # tick_params sets the label size without touching Tick.label,
        # which newer matplotlib releases no longer provide.
        cbar_ax.tick_params(length=3, labelsize=6)
    else:
        sns.heatmap(tau, cbar=False, **heatmap_kws)
    ax.set_title(title, fontsize=8)


def mutation_plot(ax, model='RS_pyramidal'):
    """Scatter the change in rheobase against the change in AUC for a model.

    All mutations are drawn as small grey markers, the wild type as a
    black square and the mutations in ``_HIGHLIGHTED_MUTATIONS`` as larger
    coloured markers. Dashed light-grey lines mark zero on both axes.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    model : str
        One of the keys in ``MODELS``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that was passed in.

    Raises
    ------
    ValueError
        If ``model`` is not in ``MODELS``.
    """
    if model not in MODELS:
        raise ValueError(
            "unknown model {!r}; expected one of {}".format(model, MODELS))
    idx = MODELS.index(model)
    column = MODEL_NAMES[idx]
    display_name = MODEL_DISPLAY_NAMES[idx]
    marker = MARKERS[display_name]

    AUC = _load_table(AUC_CSV)
    rheo = _load_table(RHEO_CSV)
    mut_names = AUC.index

    # Rheobase values are scaled by 1000 to match the axis label (pA);
    # presumably the table stores nA -- TODO confirm against the data.
    ax.plot(rheo.loc[mut_names, column] * 1000, AUC.loc[mut_names, column],
            linestyle='', markeredgecolor='grey', markerfacecolor='grey',
            marker=marker, markersize=2)
    # Wild type uses the same scaling as every other point on this axis.
    ax.plot(rheo.loc['wt', column] * 1000, AUC.loc['wt', column], 'sk')

    mut_col = sns.color_palette("pastel")
    for mutation, colour_idx in _HIGHLIGHTED_MUTATIONS:
        colour = mut_col[colour_idx]
        ax.plot(rheo.loc[mutation, column] * 1000, AUC.loc[mutation, column],
                linestyle='', markeredgecolor=colour, markerfacecolor=colour,
                marker=marker, markersize=4)

    ax.set_title(display_name, pad=14)
    ax.set_xlabel(r'$\Delta$ Rheobase (pA)')
    ax.set_ylabel(r'Normalized $\Delta$AUC')
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.ticklabel_format(axis="y", style="sci", scilimits=(0, 0),
                        useMathText=True)

    # Zero reference lines spanning the limits chosen by autoscaling.
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    ax.hlines(0, xmin, xmax, colors='lightgrey', linestyles='--')
    ax.vlines(0, ymin, ymax, colors='lightgrey', linestyles='--')
    return ax


def mutation_legend(ax, marker_s_leg, pos, ncol):
    """Attach a legend for the wild type and the highlighted mutations.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes the legend is attached to.
    marker_s_leg : float
        Marker size of the mutation entries; the wild-type entry is drawn
        2 points larger.
    pos : tuple
        ``bbox_to_anchor`` position of the legend centre.
    ncol : int
        Number of legend columns.
    """
    colors = sns.color_palette("pastel")
    handles = [mlines.Line2D([], [], color='k', marker='s',
                             markersize=marker_s_leg + 2, linestyle='None',
                             label='Wild type')]
    handles.extend(
        mlines.Line2D([], [], color=colors[colour_idx], marker='o',
                      markersize=marker_s_leg, linestyle='None',
                      label=mutation)
        for mutation, colour_idx in _HIGHLIGHTED_MUTATIONS)
    ax.legend(handles=handles, loc='center', bbox_to_anchor=pos, ncol=ncol,
              frameon=False)


sim_style()

# plot setup
fig = plt.figure()
gs0 = fig.add_gridspec(1, 6, wspace=-0.2)
gsl = gs0[0:3].subgridspec(3, 3, wspace=0.9, hspace=0.8)
gsr = gs0[4:6].subgridspec(6, 1, wspace=0.6, hspace=1)

ax00 = fig.add_subplot(gsl[0, 0])
ax01 = fig.add_subplot(gsl[0, 1])
ax02 = fig.add_subplot(gsl[0, 2])
ax10 = fig.add_subplot(gsl[1, 0])
ax11 = fig.add_subplot(gsl[1, 1])
ax12 = fig.add_subplot(gsl[1, 2])
ax20 = fig.add_subplot(gsl[2, 0])
ax21 = fig.add_subplot(gsl[2, 1])
ax22 = fig.add_subplot(gsl[2, 2])
axr0 = fig.add_subplot(gsr[0:2, 0])
axr1 = fig.add_subplot(gsr[4:, 0])

# plot mutations in each model (panels follow the order of MODELS)
axs = [ax00, ax01, ax02, ax10, ax11, ax12, ax20, ax21, ax22]
for panel, model_key in zip(axs, MODELS):
    mutation_plot(panel, model=model_key)

marker_s_leg = 4
pos = (0.425, -0.7)
ncol = 5
mutation_legend(ax21, marker_s_leg, pos, ncol)

# plot correlation matrices
correlation_plot(axr1, df='AUC', title=r'Normalized $\Delta$AUC', cbar=False)
correlation_plot(axr0, df='rheo', title=r'$\Delta$ Rheobase', cbar=True)

# add subplot labels: A.. for the scatter panels, then the two heatmaps
for i, panel in enumerate(axs):
    panel.text(-0.48, 1.175, string.ascii_uppercase[i],
               transform=panel.transAxes, size=10, weight='bold')
axr0.text(-0.77, 1.1, string.ascii_uppercase[len(axs)],
          transform=axr0.transAxes, size=10, weight='bold')
axr1.text(-0.77, 1.1, string.ascii_uppercase[len(axs) + 1],
          transform=axr1.transAxes, size=10, weight='bold')

# save
fig.set_size_inches(cm2inch(22.2, 15))
fig.savefig('./Figures/simulation_model_comparison.pdf', dpi=fig.dpi)
plt.show()