added figure plotting scripts *_letters.py to plot using models named by letter

This commit is contained in:
nkoch1
2022-09-24 16:35:13 +02:00
parent 3ba75f8ce2
commit b005b04937
10 changed files with 1837 additions and 21 deletions

Binary file not shown.

View File

@@ -389,6 +389,21 @@ color_dict = {'Cb stellate': '#40A787', # cyan'#
'STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': '#873770', # magenta
'STN': '#D03050' # pink
}
model_letter = {
'Cb stellate': 'A',
'RS Inhibitory': 'B',
'FS': 'C',
'RS Pyramidal': 'D',
'RS Inhibitory': 'E',
'Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':'F',
'FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':'G',
'RS Pyramidal +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':'H',
'STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':'I',
'Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':'J',
'STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':'K',
'STN':'L',
}
# plot setup
marker_s_leg = 2

View File

@@ -0,0 +1,543 @@
# -*- coding: utf-8 -*-
"""
Created on Sat Jul 3 19:52:04 2021
@author: nils
"""
import pandas as pd
import numpy as np
import string
import textwrap
import json
import matplotlib
import matplotlib.lines as mlines
from matplotlib import ticker
from matplotlib.ticker import NullFormatter
from Figures.plotstyle import boxplot_style
def cm2inch(*tupl):
inch = 2.54
if isinstance(tupl[0], tuple):
return tuple(i/inch for i in tupl[0])
else:
return tuple(i/inch for i in tupl)
#%% ##################### From https://stackoverflow.com/questions/52878845/swarmplot-with-hue-affecting-marker-beyond-color ##
# to change marker types in seaborn swarmplot
import seaborn as sns
import matplotlib.pyplot as plt
############## Begin hack ##############
from matplotlib.axes._axes import Axes
from matplotlib.markers import MarkerStyle
from numpy import ndarray
def GetColor2Marker(markers):
colorslist = ['#40A787', # cyan'#
'#F0D730', # yellow
'#C02717', # red
'#007030', # dark green
'#AAB71B', # lightgreen
'#008797', # light blue
'#F78017', # orange
'#478010', # green
'#53379B', # purple
'#2060A7', # blue
'#873770', # magenta
'#D03050' # pink
]
import matplotlib.colors
palette = [matplotlib.colors.to_rgb(c) for c in colorslist]
mkcolors = [(palette[i]) for i in range(len(markers))]
return dict(zip(mkcolors,markers))
def fixlegend(ax,markers,markersize=3,**kwargs):
# Fix Legend
legtitle = ax.get_legend().get_title().get_text()
_,l = ax.get_legend_handles_labels()
colorslist = ['#40A787', # cyan'#
'#F0D730', # yellow
'#C02717', # red
'#007030', # dark green
'#AAB71B', # lightgreen
'#008797', # light blue
'#F78017', # orange
'#478010', # green
'#53379B', # purple
'#2060A7', # blue
'#873770', # magenta
'#D03050' # pink
]
import matplotlib.colors
palette = [matplotlib.colors.to_rgb(c) for c in colorslist]
mkcolors = [(palette[i]) for i in range(len(markers))]
newHandles = [plt.Line2D([0],[0], ls="none", marker=m, color=c, mec="none", markersize=markersize,**kwargs) \
for m,c in zip(markers, mkcolors)]
ax.legend(newHandles,l)
leg = ax.get_legend()
leg.set_title(legtitle)
old_scatter = Axes.scatter
def new_scatter(self, *args, **kwargs):
colors = kwargs.get("c", None)
co2mk = kwargs.pop("co2mk",None)
FinalCollection = old_scatter(self, *args, **kwargs)
if co2mk is not None and isinstance(colors, ndarray):
Color2Marker = GetColor2Marker(co2mk)
paths=[]
for col in colors:
mk=Color2Marker[tuple(col)]
marker_obj = MarkerStyle(mk)
paths.append(marker_obj.get_path().transformed(marker_obj.get_transform()))
FinalCollection.set_paths(paths)
return FinalCollection
Axes.scatter = new_scatter
############## End hack. ##############
########################################################################################################################
#%% add gradient arrows
import matplotlib.pyplot as plt
import matplotlib.transforms
import matplotlib.path
from matplotlib.collections import LineCollection
def gradientaxis(ax, start, end, cmap, n=100,lw=1):
# Arrow shaft: LineCollection
x = np.linspace(start[0],end[0],n)
y = np.linspace(start[1],end[1],n)
points = np.array([x,y]).T.reshape(-1,1,2)
segments = np.concatenate([points[:-1],points[1:]], axis=1)
lc = LineCollection(segments, cmap=cmap, linewidth=lw,zorder=15)
lc.set_array(np.linspace(0,1,n))
ax.add_collection(lc)
return ax
#%%
#%%
def boxplot_with_markers(ax,max_width, alteration='shift', msize=3):
hlinewidth = 0.5
model_names = ['RS pyramidal','RS inhibitory','FS',
'RS pyramidal +$K_V1.1$','RS inhibitory +$K_V1.1$',
'FS +$K_V1.1$','Cb stellate','Cb stellate +$K_V1.1$',
'Cb stellate $\Delta$$K_V1.1$','STN','STN +$K_V1.1$',
'STN $\Delta$$K_V1.1$']
colorslist = ['#007030', # dark green
'#F0D730', # yellow
'#C02717', # red
'#478010', # green
'#AAB71B', # lightgreen
'#F78017', # orange
'#40A787', # cyan'#
'#008797', # light blue
'#2060A7', # blue
'#D03050', # pink
'#53379B', # purple
'#873770', # magenta
]
import matplotlib.colors
colors = [matplotlib.colors.to_rgb(c) for c in colorslist]
clr_dict = {}
for m in range(len(model_names)):
clr_dict[model_names[m]] = colors[m]
print(colors)
print(clr_dict)
Markers = ["o", "o", "o", "^", "^", "^", "D", "D", "D", "s", "s", "s"]
if alteration=='shift':
i = 2 # Kd act
ax.axvspan(i - 0.4, i + 0.4, fill=False, edgecolor = 'k')
df = pd.read_csv('./Figures/Data/AUC_shift_corr.csv')
sns.swarmplot(y="corr", x="$\Delta V_{1/2}$", hue="model", data=df,
palette=clr_dict, linewidth=0, orient='v', ax=ax, size=msize,
order=['Na activation', 'Na inactivation', 'K activation', '$K_V1.1$ activation',
'$K_V1.1$ inactivation', 'A activation', 'A inactivation'],
hue_order=model_names, co2mk=Markers)
lim = ax.get_xlim()
ax.plot([lim[0], lim[1]], [0, 0], ':r',linewidth=hlinewidth)
ax.plot([lim[0], lim[1]], [1, 1], ':k',linewidth=hlinewidth)
ax.plot([lim[0], lim[1]], [-1, -1], ':k',linewidth=hlinewidth)
ax.set_title("Shift ($\Delta V_{1/2}$)", y=1.05)
ax.set_xticklabels(['Na \nactivation', 'Na \ninactivation', 'K \nactivation', '$K_V1.1$ \nactivation',
'$K_V1.1$ \ninactivation', 'A \nactivation', 'A \ninactivation'])
elif alteration=='slope':
i = 3 # Kv1.1 act
ax.axvspan(i - 0.4, i + 0.4, fill=False, edgecolor='k')
df = pd.read_csv('./Figures/Data/AUC_scale_corr.csv')
# Add in points to show each observation
sns.swarmplot(y="corr", x="Slope (k)", hue="model", data=df,
palette=clr_dict, linewidth=0, orient='v', ax=ax, size=msize,
order=['Na activation', 'Na inactivation', 'K activation', '$K_V1.1$ activation',
'$K_V1.1$ inactivation', 'A activation', 'A inactivation'],
hue_order=model_names, co2mk=Markers)
lim = ax.get_xlim()
ax.plot([lim[0], lim[1]], [0, 0], ':r',linewidth=hlinewidth)
ax.plot([lim[0], lim[1]], [1, 1], ':k',linewidth=hlinewidth)
ax.plot([lim[0], lim[1]], [-1, -1], ':k',linewidth=hlinewidth)
ax.set_title("Slope (k)", y=1.05)
ax.set_xticklabels(['Na \nactivation', 'Na \ninactivation', 'K \nactivation', '$K_V1.1$ \nactivation',
'$K_V1.1$ \ninactivation', 'A \nactivation', 'A \ninactivation'])
elif alteration=='g':
i = 1 # Kd
ax.axvspan(i - 0.4, i + 0.4, fill=False, edgecolor='k')
df = pd.read_csv('./Figures/Data/AUC_g_corr.csv')
# Add in points to show each observation
sns.swarmplot(y="corr", x="g", hue="model", data=df,
palette=clr_dict, linewidth=0, orient='v', ax=ax, size=msize,
order=['Na', 'K', '$K_V1.1$', 'A', 'Leak'],
hue_order=model_names, co2mk=Markers)
lim = ax.get_xlim()
# ax.plot([lim[0], lim[1]], [0,0], ':k')
ax.plot([lim[0], lim[1]], [0, 0], ':r',linewidth=hlinewidth)
ax.plot([lim[0], lim[1]], [1, 1], ':k',linewidth=hlinewidth)
ax.plot([lim[0], lim[1]], [-1, -1], ':k',linewidth=hlinewidth)
# Tweak the visual presentation
ax.set_title("Conductance (g)", y=1.05)
ax.set_xticklabels(textwrap.fill(x.get_text(), max_width) for x in ax.get_xticklabels())
else:
print('Please chose "shift", "slope" or "g"')
ax.get_legend().remove()
ax.xaxis.grid(False)
sns.despine(trim=True, bottom=True, ax=ax)
ax.set(xlabel=None, ylabel=r'Kendall $\it{\tau}$')
def model_legend(ax, marker_s_leg, pos, ncol):
# colorslist = [ '#40A787', # cyan'#
# '#F0D730', # yellow
# '#C02717', # red
# '#007030', # dark green
# '#AAB71B', # lightgreen
# '#008797', # light blue
# '#F78017', # orange
# '#478010', # green
# '#53379B', # purple
# '#2060A7', # blue
# '#873770', # magenta
# '#D03050' # pink
# ]
colorslist = ['#007030', # dark green
'#F0D730', # yellow
'#C02717', # red
'#478010', # green
'#AAB71B', # lightgreen
'#F78017', # orange
'#40A787', # cyan'#
'#008797', # light blue
'#2060A7', # blue
'#D03050', # pink
'#53379B', # purple
'#873770', # magenta
]
import matplotlib.colors
colors = [matplotlib.colors.to_rgb(c) for c in colorslist]
model_pos = {'Cb stellate':0, 'RS Inhibitory':1, 'FS':2, 'RS Pyramidal':3,
'RS Inhibitory +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':4,
'Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':5, 'FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':6,
'RS Pyramidal +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':7, 'STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':8,
'Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':9,
'STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':10, 'STN':11}
Markers = ["o", "o", "o", "^", "^", "^", "D", "D", "D", "s", "s", "s"]
# RS_p = mlines.Line2D([], [], color=colors[model_pos['RS Pyramidal']], marker=Markers[model_pos['RS Pyramidal']], markersize=marker_s_leg, linestyle='None',
# label='Model D')
# RS_i = mlines.Line2D([], [], color=colors[model_pos['RS Inhibitory']], marker=Markers[model_pos['RS Inhibitory']], markersize=marker_s_leg, linestyle='None',
# label='Model B')
# FS = mlines.Line2D([], [], color=colors[model_pos['FS']], marker=Markers[model_pos['FS']], markersize=marker_s_leg, linestyle='None', label='Model C')
# RS_p_Kv = mlines.Line2D([], [], color=colors[model_pos['RS Pyramidal +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']], marker=Markers[model_pos['RS Pyramidal +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']], markersize=marker_s_leg, linestyle='None',
# label='Model H')
# RS_i_Kv = mlines.Line2D([], [], color=colors[model_pos['RS Inhibitory +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']], marker=Markers[model_pos['RS Inhibitory +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']], markersize=marker_s_leg, linestyle='None',
# label='Model E')
# FS_Kv = mlines.Line2D([], [], color=colors[model_pos['Cb stellate']], marker=Markers[model_pos['Cb stellate']], markersize=marker_s_leg, linestyle='None', label='Model G')
# Cb = mlines.Line2D([], [], color=colors[8], marker=Markers[8], markersize=marker_s_leg, linestyle='None',
# label='Model A')
# Cb_pl = mlines.Line2D([], [], color=colors[model_pos['Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']], marker=Markers[model_pos['Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']], markersize=marker_s_leg, linestyle='None',
# label='Model F')
# Cb_sw = mlines.Line2D([], [], color=colors[model_pos['Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']], marker=Markers[model_pos['Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']], markersize=marker_s_leg, linestyle='None',
# label='Model J')
# STN = mlines.Line2D([], [], color=colors[model_pos['STN']], marker=Markers[model_pos['STN']], markersize=marker_s_leg, linestyle='None', label='Model L')
# STN_pl = mlines.Line2D([], [], color=colors[model_pos['STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']], marker=Markers[model_pos['STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']], markersize=marker_s_leg, linestyle='None',
# label='Model I')
# STN_sw = mlines.Line2D([], [], color=colors[model_pos['STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']], marker=Markers[model_pos['STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']], markersize=marker_s_leg, linestyle='None',
# label='Model K')
RS_p = mlines.Line2D([], [], color='#007030', marker="^",
markersize=marker_s_leg, linestyle='None',
label='Model D')
RS_i = mlines.Line2D([], [], color='#F0D730', marker="o",
markersize=marker_s_leg, linestyle='None',
label='Model B')
FS = mlines.Line2D([], [], color='#C02717', marker="o", markersize=marker_s_leg,
linestyle='None', label='Model C')
RS_p_Kv = mlines.Line2D([], [], color='#478010',
marker="D",
markersize=marker_s_leg, linestyle='None',
label='Model H')
RS_i_Kv = mlines.Line2D([], [], color='#AAB71B',
marker="^",
markersize=marker_s_leg, linestyle='None',
label='Model E')
FS_Kv = mlines.Line2D([], [], color='#F78017',
marker="D", markersize=marker_s_leg,
linestyle='None', label='Model G')
Cb = mlines.Line2D([], [], color='#40A787', marker="o",
markersize=marker_s_leg, linestyle='None',
label='Model A')
Cb_pl = mlines.Line2D([], [], color='#008797',
marker="^",
markersize=marker_s_leg, linestyle='None',
label='Model F')
Cb_sw = mlines.Line2D([], [], color='#2060A7',
marker="s",
markersize=marker_s_leg, linestyle='None',
label='Model J')
STN = mlines.Line2D([], [], color='#D03050', marker="s", markersize=marker_s_leg,
linestyle='None', label='Model L')
STN_pl = mlines.Line2D([], [], color='#53379B',
marker="D",
markersize=marker_s_leg, linestyle='None',
label='Model I')
STN_sw = mlines.Line2D([], [], color='#873770',
marker="s",
markersize=marker_s_leg, linestyle='None',
label='Model K')
# ax.legend(handles=[RS_p, RS_i, FS, RS_p_Kv, RS_i_Kv, FS_Kv, Cb, Cb_pl, Cb_sw, STN, STN_pl, STN_sw], loc='center',
# bbox_to_anchor=pos, ncol=ncol, frameon=False)
ax.legend(handles=[Cb, RS_i, FS, RS_p, RS_i_Kv, Cb_pl, FS_Kv, RS_p_Kv, STN_pl, Cb_sw, STN_sw, STN], loc='center',
bbox_to_anchor=pos, ncol=ncol, frameon=False)
def plot_AUC_alt(ax, model='FS', color1='red', color2='dodgerblue', alteration='shift'):
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
model_names = ['RS Pyramidal','RS Inhibitory','FS',
'RS Pyramidal +$K_V1.1$','RS Inhibitory +$K_V1.1$',
'FS +$K_V1.1$','Cb stellate','Cb stellate +$K_V1.1$',
'Cb stellate $\Delta$$K_V1.1$','STN','STN +$K_V1.1$',
'STN $\Delta$$K_V1.1$']
model_name_dict = {'RS Pyramidal': 'RS Pyramidal',
'RS Inhibitory': 'RS Inhibitory',
'FS': 'FS',
'RS Pyramidal +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'RS Pyramidal +$K_V1.1$',
'RS Inhibitory +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'RS Inhibitory +$K_V1.1$',
'FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'FS +$K_V1.1$',
'Cb stellate': 'Cb stellate',
'Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'Cb stellate +$K_V1.1$',
'Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'Cb stellate $\Delta$$K_V1.1$',
'STN': 'STN',
'STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'STN +$K_V1.1$',
'STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'STN $\Delta$$K_V1.1$'}
colorslist = ['#007030', # dark green
'#F0D730', # yellow
'#C02717', # red
'#478010', # green
'#AAB71B', # lightgreen
'#F78017', # orange
'#40A787', # cyan'#
'#008797', # light blue
'#2060A7', # blue
'#D03050', # pink
'#53379B', # purple
'#873770', # magenta
]
import matplotlib.colors
colors = [matplotlib.colors.to_rgb(c) for c in colorslist]
clr_dict = {}
for m in range(len(model_names)):
clr_dict[model_names[m]] = colors[m]
if alteration=='shift':
df = pd.read_csv('./Figures/Data/AUC_shift_ex.csv')
df = df.sort_values('alteration')
ax.set_xlabel('$\Delta$$V_{1/2}$')
elif alteration=='slope':
df = pd.read_csv('./Figures/Data/AUC_slope_ex.csv')
ax.set_xscale("log")
ax.set_xticks([0.5, 1, 2])
ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
ax.xaxis.set_minor_formatter(NullFormatter())
ax.set_xlabel('$k$/$k_{WT}$')
elif alteration=='g':
df = pd.read_csv('./Figures/Data/AUC_g_ex.csv')
ax.set_xscale("log")
ax.set_xticks([0.5, 1, 2])
ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
ax.xaxis.set_minor_formatter(NullFormatter())
ax.set_xlabel('$g$/$g_{WT}$')
for mod in model_names:
if mod == model_name_dict[model]:
ax.plot(df['alteration'], df[mod], color=clr_dict[mod], alpha=1, zorder=10, linewidth=2)
else:
ax.plot(df['alteration'], df[mod], color=clr_dict[mod],alpha=0.5, zorder=1, linewidth=1)
if alteration=='shift':
ax.set_ylabel('Normalized $\Delta$AUC', labelpad=4)
else:
ax.set_ylabel('Normalized $\Delta$AUC', labelpad=0)
x = df['alteration']
y = df[model_name_dict[model]]
ax.set_xlim(x.min(), x.max())
ax.set_ylim(df[model_names].min().min(), df[model_names].max().max())
# x axis color gradient
cvals = [-2., 2]
colors = ['lightgrey', 'k']
norm = plt.Normalize(min(cvals), max(cvals))
tuples = list(zip(map(norm, cvals), colors))
cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", tuples)
(xstart, xend) = ax.get_xlim()
(ystart, yend) = ax.get_ylim()
print(ystart, yend)
start = (xstart, ystart * 1.0)
end = (xend, ystart * 1.0)
ax = gradientaxis(ax, start, end, cmap, n=200, lw=4)
ax.spines['bottom'].set_visible(False)
return ax
def plot_fI(ax, model='RS Pyramidal', type='shift', alt='m', color1='red', color2='dodgerblue'):
model_save_name = {'RS Pyramidal': 'RS_pyr_posp',
'RS Inhibitory': 'RS_inhib_posp',
'FS': 'FS_posp',
'RS Pyramidal +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'RS_pyr_Kv',
'RS Inhibitory +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'RS_inhib_Kv',
'FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'FS_Kv',
'Cb stellate': 'Cb_stellate',
'Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'Cb_stellate_Kv',
'Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'Cb_stellate_Kv_only',
'STN': 'STN',
'STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'STN_Kv',
'STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'STN_Kv_only'}
cvals = [-2., 2]
colors = [color1, color2]
norm = plt.Normalize(min(cvals), max(cvals))
tuples = list(zip(map(norm, cvals), colors))
cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", tuples)
colors = cmap(np.linspace(0, 1, 22))
df = pd.read_csv('./Figures/Data/Model_fI/{}_fI.csv'.format(model_save_name[model]))
df.drop(['Unnamed: 0'], axis=1)
newdf = df.loc[df.index[(df['alt'] == alt) & (df['type'] == type)], :]
newdf['mag'] = newdf['mag'].astype('float')
newdf = newdf.sort_values('mag').reset_index()
c = 0
for i in newdf.index:
ax.plot(json.loads(newdf.loc[i, 'I']), json.loads(newdf.loc[i, 'F']), color=colors[c])
c += 1
ax.set_ylabel('Frequency [Hz]')
ax.set_xlabel('Current [nA]')
if model == 'FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':
ax.set_title("Model G", x=0.2, y=1.0)
elif model == 'STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':
ax.set_title("Model I", x=0.2, y=1.0)
else:
ax.set_title("", x=0.2, y=1.0)
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
L = ax.get_ylim()
ax.set_ylim([0, L[1]])
return ax
#%%
boxplot_style()
color_dict = {'Cb stellate': '#40A787', # cyan'#
'RS Inhibitory': '#F0D730', # yellow
'FS': '#C02717', # red
'RS Pyramidal': '#007030', # dark green
'RS Inhibitory +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': '#AAB71B', # lightgreen
'Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': '#008797', # light blue
'FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': '#F78017', # orange
'RS Pyramidal +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': '#478010', # green
'STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': '#53379B', # purple
'Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': '#2060A7', # blue
'STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': '#873770', # magenta
'STN': '#D03050' # pink
}
model_letter = {
'Cb stellate': 'A',
'RS Inhibitory': 'B',
'FS': 'C',
'RS Pyramidal': 'D',
'RS Inhibitory': 'E',
'Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':'F',
'FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':'G',
'RS Pyramidal +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':'H',
'STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':'I',
'Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':'J',
'STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':'K',
'STN':'L',
}
# plot setup
marker_s_leg = 2
max_width = 20
pad_x = 0.85
pad_y= 0.4
pad_w = 1.1
pad_h = 0.7
fig = plt.figure()
gs = fig.add_gridspec(3, 7, wspace=1.2, hspace=1.)
ax0 = fig.add_subplot(gs[0,2:7])
ax0_ex = fig.add_subplot(gs[0,1])
ax0_fI = fig.add_subplot(gs[0,0])
ax1 = fig.add_subplot(gs[1,2:7])
ax1_ex = fig.add_subplot(gs[1,1])
ax1_fI = fig.add_subplot(gs[1,0])
ax2 = fig.add_subplot(gs[2,2:7])
ax2_ex = fig.add_subplot(gs[2,1])
ax2_fI = fig.add_subplot(gs[2,0])
line_width = 1
# plot fI examples
ax0_fI = plot_fI(ax0_fI, model='FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', type='shift', alt='s', color1='lightgrey', color2='k')
rec = plt.Rectangle((-pad_x, -pad_y), 1 + pad_w, 1 + pad_h, fill=False, lw=line_width,transform=ax0_fI.transAxes, color=color_dict['FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$'], alpha=1, zorder=-1)
rec = ax0_fI.add_patch(rec)
rec.set_clip_on(False)
ax1_fI = plot_fI(ax1_fI, model='FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', type='slope', alt='u', color1='lightgrey', color2='k')
rec = plt.Rectangle((-pad_x, -pad_y), 1 + pad_w, 1 + pad_h, fill=False, lw=line_width,transform=ax1_fI.transAxes, color=color_dict['FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$'], alpha=1, zorder=-1)
rec = ax1_fI.add_patch(rec)
rec.set_clip_on(False)
ax2_fI = plot_fI(ax2_fI, model='STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', type='g', alt='Leak', color1='lightgrey', color2='k')
rec = plt.Rectangle((-pad_x, -pad_y), 1 + pad_w, 1 + pad_h, fill=False, lw=line_width,transform=ax2_fI.transAxes, color=color_dict['STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$'], alpha=1, zorder=-1)
rec = ax2_fI.add_patch(rec)
rec.set_clip_on(False)
# plot boxplots
boxplot_with_markers(ax0,max_width, alteration='shift')
boxplot_with_markers(ax1,max_width, alteration='slope')
boxplot_with_markers(ax2,max_width, alteration='g')
# plot legend
pos = (0.225, -0.9)
ncol = 6
model_legend(ax2, marker_s_leg, pos, ncol)
# plot examples
plot_AUC_alt(ax0_ex,model='FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', color1='lightgrey', color2='k', alteration='shift')
plot_AUC_alt(ax1_ex,model='FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', color1='lightgrey', color2='k',alteration='slope')
plot_AUC_alt(ax2_ex, model='STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', color1='lightgrey', color2='k', alteration='g')
# label subplots with letters
ax0_fI.text(-0.875, 1.35, string.ascii_uppercase[0], transform=ax0_fI.transAxes, size=10, weight='bold')
ax0_ex.text(-0.8, 1.35, string.ascii_uppercase[1], transform=ax0_ex.transAxes, size=10, weight='bold')
ax0.text(-0.075, 1.35, string.ascii_uppercase[2], transform=ax0.transAxes, size=10, weight='bold')
ax1_fI.text(-0.875, 1.35, string.ascii_uppercase[3], transform=ax1_fI.transAxes,size=10, weight='bold')
ax1_ex.text(-0.8, 1.35, string.ascii_uppercase[4], transform=ax1_ex.transAxes, size=10, weight='bold')
ax1.text(-0.075, 1.35, string.ascii_uppercase[5], transform=ax1.transAxes, size=10, weight='bold')
ax2_fI.text(-0.875, 1.35, string.ascii_uppercase[6], transform=ax2_fI.transAxes,size=10, weight='bold')
ax2_ex.text(-0.8, 1.35, string.ascii_uppercase[7], transform=ax2_ex.transAxes, size=10, weight='bold')
ax2.text(-0.075, 1.35, string.ascii_uppercase[8], transform=ax2.transAxes, size=10, weight='bold')
#save
fig.set_size_inches(cm2inch(20.75,12))
fig.savefig('./Figures/AUC_correlation.pdf', dpi=fig.dpi) #pdf #eps
# fig.savefig('./Figures/AUC_correlation.png', dpi=fig.dpi) #pdf #eps
plt.show()

View File

@@ -0,0 +1,322 @@
# # plot ramp protocol and responses of each model to ramp
# import numpy as np
# import pandas as pd
# import matplotlib.pyplot as plt
# import matplotlib.gridspec as gridspec
# from matplotlib.transforms import Bbox
# import string
#
# def cm2inch(*tupl):
# inch = 2.54
# if isinstance(tupl[0], tuple):
# return tuple(i/inch for i in tupl[0])
# else:
# return tuple(i/inch for i in tupl)
#
# #### from https://gist.github.com/dmeliza/3251476 #####################################################################
# from matplotlib.offsetbox import AnchoredOffsetbox
# class AnchoredScaleBar(AnchoredOffsetbox):
# def __init__(self, transform, sizex=0, sizey=0, labelx=None, labely=None, loc=4,
# pad=0.1, borderpad=0.1, sep=2, prop=None, barcolor="black", barwidth=None,
# **kwargs):
# """
# Draw a horizontal and/or vertical bar with the size in data coordinate
# of the give axes. A label will be drawn underneath (center-aligned).
# - transform : the coordinate frame (typically axes.transData)
# - sizex,sizey : width of x,y bar, in data units. 0 to omit
# - labelx,labely : labels for x,y bars; None to omit
# - loc : position in containing axes
# - pad, borderpad : padding, in fraction of the legend font size (or prop)
# - sep : separation between labels and bars in points.
# - **kwargs : additional arguments passed to base class constructor
# """
# from matplotlib.patches import Rectangle
# from matplotlib.offsetbox import AuxTransformBox, VPacker, HPacker, TextArea, DrawingArea
# bars = AuxTransformBox(transform)
# if sizex:
# bars.add_artist(Rectangle((0, 0), sizex, 0, ec=barcolor, lw=barwidth, fc="none"))
# if sizey:
# bars.add_artist(Rectangle((0, 0), 0, sizey, ec=barcolor, lw=barwidth, fc="none"))
#
# if sizex and labelx:
# self.xlabel = TextArea(labelx)
# bars = VPacker(children=[bars, self.xlabel], align="center", pad=0, sep=sep)
# if sizey and labely:
# self.ylabel = TextArea(labely)
# bars = HPacker(children=[self.ylabel, bars], align="center", pad=0, sep=sep)
#
# AnchoredOffsetbox.__init__(self, loc, pad=pad, borderpad=borderpad,
# child=bars, prop=prop, frameon=False, **kwargs)
#
#
# def add_scalebar(ax, matchx=True, matchy=True, hidex=True, hidey=True, **kwargs):
# """ Add scalebars to axes
# Adds a set of scale bars to *ax*, matching the size to the ticks of the plot
# and optionally hiding the x and y axes
# - ax : the axis to attach ticks to
# - matchx,matchy : if True, set size of scale bars to spacing between ticks
# if False, size should be set using sizex and sizey params
# - hidex,hidey : if True, hide x-axis and y-axis of parent
# - **kwargs : additional arguments passed to AnchoredScaleBars
# Returns created scalebar object
# """
#
# def f(axis):
# l = axis.get_majorticklocs()
# return len(l) > 1 and (l[1] - l[0])
#
# if matchx:
# kwargs['sizex'] = f(ax.xaxis)
# kwargs['labelx'] = str(kwargs['sizex'])
# if matchy:
# kwargs['sizey'] = f(ax.yaxis)
# kwargs['labely'] = str(kwargs['sizey'])
#
# sb = AnchoredScaleBar(ax.transData, **kwargs)
# ax.add_artist(sb)
#
# if hidex: ax.xaxis.set_visible(False)
# if hidey: ax.yaxis.set_visible(False)
# if hidex and hidey: ax.set_frame_on(False)
#
# return sb
# ########################################################################################################################
#
#
# def plot_ramp_V(ax, model='RS Pyramidal'): # , stop=750
# model_ramp = pd.read_csv('./Figures/Data/model_ramp.csv')
# ax.plot(model_ramp['t'], model_ramp[model], 'k', linewidth=0.025)
# ax.set_ylabel('V')
# ax.set_xlabel('Time [s]')
# ax.set_ylim(-80, 60)
# ax.axis('off')
# ax.set_title(model)
#
# #% plot setup
# fig = plt.figure(figsize=cm2inch(17.6,17.6))
#
# gs0 = fig.add_gridspec(3, 2, wspace=0.1)
# gs00 = gs0[:,0].subgridspec(7, 2, wspace=0.6, hspace=1)
# gs01 = gs0[:,1].subgridspec(7, 2, wspace=0.6, hspace=1)
#
# ax1_ramp = fig.add_subplot(gs00[0,0:2])
# ax2_ramp = fig.add_subplot(gs01[0,0:2])
# ax3_ramp = fig.add_subplot(gs00[1,0:2])
# ax4_ramp = fig.add_subplot(gs01[1,0:2])
# ax5_ramp = fig.add_subplot(gs00[2, 0:2])
# ax6_ramp = fig.add_subplot(gs01[2, 0:2])
# ax7_ramp = fig.add_subplot(gs00[3,0:2])
# ax8_ramp = fig.add_subplot(gs01[3,0:2])
# ax9_ramp = fig.add_subplot(gs00[4,0:2])
# ax10_ramp = fig.add_subplot(gs01[4,0:2])
# ax11_ramp = fig.add_subplot(gs00[5,0:2])
# ax12_ramp = fig.add_subplot(gs01[5,0:2])
#
# ramp_axs = [ax1_ramp, ax2_ramp, ax3_ramp, ax4_ramp, ax5_ramp,ax6_ramp, ax7_ramp, ax8_ramp,
# ax9_ramp, ax10_ramp, ax11_ramp, ax12_ramp]
#
# # order of models
# models = ['Cb stellate','RS Inhibitory','FS', 'RS Pyramidal','RS Inhibitory +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$',
# 'Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', 'FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$',
# 'RS Pyramidal +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', 'STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$',
# 'Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$',
# 'STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', 'STN']
#
# # plot ramps
# for i in range(len(models)):
# plot_ramp_V(ramp_axs[i], model=models[i])
#
# # add scalebar
# plt.rcParams.update({'font.size': 6})
#
# add_scalebar(ax11_ramp, matchx=False, matchy=False, hidex=True, hidey=True, sizex=1000, sizey=50, labelx='1 s',
# labely='50 mV', loc=3, pad=-2, borderpad=0, barwidth=1, bbox_to_anchor=Bbox.from_bounds(-0.05, 0.1, 1, 1),
# bbox_transform=ax11_ramp.transAxes)
# # add_scalebar(ax12_ramp, matchx=False, matchy=False, hidex=True, hidey=True, sizex=1000, sizey=25, labelx='1 s',
# # labely='25 mV', loc=3, pad=-2, borderpad=0, barwidth=2, bbox_to_anchor=Bbox.from_bounds(-0.05, 0.1, 1, 1),
# # bbox_transform=ax12_ramp.transAxes)
#
# # add subplot labels
# for i in range(0,len(models)):
# ramp_axs[i].text(-0.05, 1.08, string.ascii_uppercase[i], transform=ramp_axs[i].transAxes, size=10, weight='bold')
#
# fig.savefig('./Figures/ramp_firing.pdf', dpi=3000)
# plt.show()
# plot ramp protocol and responses of each model to ramp
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.transforms import Bbox
import string
def cm2inch(*tupl):
inch = 2.54
if isinstance(tupl[0], tuple):
return tuple(i/inch for i in tupl[0])
else:
return tuple(i/inch for i in tupl)
#### from https://gist.github.com/dmeliza/3251476 #####################################################################
from matplotlib.offsetbox import AnchoredOffsetbox
class AnchoredScaleBar(AnchoredOffsetbox):
def __init__(self, transform, sizex=0, sizey=0, labelx=None, labely=None, loc=4,
pad=0.1, borderpad=0.1, sep=2, prop=None, barcolor="black", barwidth=None,
**kwargs):
"""
Draw a horizontal and/or vertical bar with the size in data coordinate
of the give axes. A label will be drawn underneath (center-aligned).
- transform : the coordinate frame (typically axes.transData)
- sizex,sizey : width of x,y bar, in data units. 0 to omit
- labelx,labely : labels for x,y bars; None to omit
- loc : position in containing axes
- pad, borderpad : padding, in fraction of the legend font size (or prop)
- sep : separation between labels and bars in points.
- **kwargs : additional arguments passed to base class constructor
"""
from matplotlib.patches import Rectangle
from matplotlib.offsetbox import AuxTransformBox, VPacker, HPacker, TextArea, DrawingArea
bars = AuxTransformBox(transform)
if sizex:
bars.add_artist(Rectangle((0, 0), sizex, 0, ec=barcolor, lw=barwidth, fc="none"))
if sizey:
bars.add_artist(Rectangle((0, 0), 0, sizey, ec=barcolor, lw=barwidth, fc="none"))
if sizex and labelx:
self.xlabel = TextArea(labelx)
bars = VPacker(children=[bars, self.xlabel], align="center", pad=0, sep=sep)
if sizey and labely:
self.ylabel = TextArea(labely)
bars = HPacker(children=[self.ylabel, bars], align="center", pad=0, sep=sep)
AnchoredOffsetbox.__init__(self, loc, pad=pad, borderpad=borderpad,
child=bars, prop=prop, frameon=False, **kwargs)
def add_scalebar(ax, matchx=True, matchy=True, hidex=True, hidey=True, **kwargs):
""" Add scalebars to axes
Adds a set of scale bars to *ax*, matching the size to the ticks of the plot
and optionally hiding the x and y axes
- ax : the axis to attach ticks to
- matchx,matchy : if True, set size of scale bars to spacing between ticks
if False, size should be set using sizex and sizey params
- hidex,hidey : if True, hide x-axis and y-axis of parent
- **kwargs : additional arguments passed to AnchoredScaleBars
Returns created scalebar object
"""
def f(axis):
l = axis.get_majorticklocs()
return len(l) > 1 and (l[1] - l[0])
if matchx:
kwargs['sizex'] = f(ax.xaxis)
kwargs['labelx'] = str(kwargs['sizex'])
if matchy:
kwargs['sizey'] = f(ax.yaxis)
kwargs['labely'] = str(kwargs['sizey'])
sb = AnchoredScaleBar(ax.transData, **kwargs)
ax.add_artist(sb)
if hidex: ax.xaxis.set_visible(False)
if hidey: ax.yaxis.set_visible(False)
if hidex and hidey: ax.set_frame_on(False)
return sb
########################################################################################################################
def plot_ramp_V(ax, model='RS Pyramidal'): # , stop=750
model_ramp = pd.read_csv('./Figures/Data/model_ramp.csv')
# ax.plot(model_ramp['t'], model_ramp[model], 'k', linewidth=0.0025)
ax.plot(model_ramp['t'], model_ramp[model], 'k', linewidth=0.1)
ax.set_ylabel('V')
ax.set_xlabel('Time [s]')
ax.set_ylim(-80, 60)
ax.axis('off')
ax.set_title(model, fontsize=8)
def plot_I_ramp(ax):
dt = 0.01
I_low = 0
I_high = 0.001
initial_period = 1000
sec = 4
ramp_len = int(4 * 1000 * 1 / dt)
stim_time = ramp_len * 2
I_amp = np.array([0])
I_amp = np.reshape(I_amp, (1, I_amp.shape[0]))
I_ramp = np.zeros((stim_time, 1)) @ I_amp
I_ramp[:, :] = np.ones((stim_time, 1)) @ I_amp
stim_num_step = I_ramp.shape[1]
start=0
I_ramp[start:int(start + ramp_len), 0] = np.linspace(0, I_high, ramp_len)
I_ramp[int(start + ramp_len):int(start + ramp_len * 2), 0] = np.linspace(I_high, 0, ramp_len)
t = np.arange(0, 4000 * 2, dt)
ax.plot(t, I_ramp)
ax.set_ylabel('I')
ax.set_xlabel('Time [s]')
ax.axis('off')
ax.set_title('Ramp Current', fontsize=8, x=0.5, y=-0.5)
return ax
#% plot setup
fig = plt.figure(figsize=cm2inch(17.6,25))
gs0 = fig.add_gridspec(2, 1, wspace=0.)
gs00 = gs0[:].subgridspec(13, 1, wspace=0.7, hspace=1.0)
ax1_ramp = fig.add_subplot(gs00[0])
ax2_ramp = fig.add_subplot(gs00[1])
ax3_ramp = fig.add_subplot(gs00[2])
ax4_ramp = fig.add_subplot(gs00[3])
ax5_ramp = fig.add_subplot(gs00[4])
ax6_ramp = fig.add_subplot(gs00[5])
ax7_ramp = fig.add_subplot(gs00[6])
ax8_ramp = fig.add_subplot(gs00[7])
ax9_ramp = fig.add_subplot(gs00[8])
ax10_ramp = fig.add_subplot(gs00[9])
ax11_ramp = fig.add_subplot(gs00[10])
ax12_ramp = fig.add_subplot(gs00[11])
ax13_I = fig.add_subplot(gs00[12])
ramp_axs = [ax1_ramp, ax2_ramp, ax3_ramp, ax4_ramp, ax5_ramp,ax6_ramp, ax7_ramp, ax8_ramp,
ax9_ramp, ax10_ramp, ax11_ramp, ax12_ramp]
# order of models
models = ['Cb stellate','RS Inhibitory','FS', 'RS Pyramidal','RS Inhibitory +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$',
'Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', 'FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$',
'RS Pyramidal +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', 'STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$',
'Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$',
'STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', 'STN']
# plot ramps
for i in range(len(models)):
plot_ramp_V(ramp_axs[i], model=models[i])
# add scalebar
plt.rcParams.update({'font.size': 6})
add_scalebar(ax12_ramp, matchx=False, matchy=False, hidex=True, hidey=True, sizex=1000, sizey=50, labelx='1 s',
labely='50 mV', loc=3, pad=-2, borderpad=0, barwidth=1, bbox_to_anchor=Bbox.from_bounds(0.01, 0.05, 1, 1),
bbox_transform=ax12_ramp.transAxes)
ax13_I = plot_I_ramp(ax13_I)
add_scalebar(ax13_I, matchx=False, matchy=False, hidex=True, hidey=True, sizex=1000, sizey=0.0005, labelx='1 s',
labely='0.5 $I_{max}$', loc=3, pad=-2, borderpad=0, barwidth=1,
bbox_to_anchor=Bbox.from_bounds(0.0, -0.01, 1, 1), bbox_transform=ax13_I.transAxes)
# add subplot labels
for i in range(0,len(models)):
ramp_axs[i].text(-0.01, 1.1, string.ascii_uppercase[i], transform=ramp_axs[i].transAxes, size=10, weight='bold')
#save
fig.set_size_inches(cm2inch(17.6,22))
fig.savefig('./Figures/ramp_firing.png', dpi=fig.dpi)#pdf #eps
plt.show()

Binary file not shown.

View File

@@ -157,6 +157,8 @@ def boxplot_with_markers(ax,max_width, alteration='shift', msize=2.2):
clr_dict = {}
for m in range(len(model_names)):
clr_dict[model_names[m]] = colors[m]
print(colors)
print(clr_dict)
Markers = ["o", "o", "o", "^", "^", "^", "D", "D", "D", "s", "s", "s"]
if alteration=='shift':
i = 3 # Kv1.1 act
@@ -217,27 +219,52 @@ def boxplot_with_markers(ax,max_width, alteration='shift', msize=2.2):
def model_legend(ax, marker_s_leg, pos, ncol):
colorslist = [ '#40A787', # cyan'#
'#F0D730', # yellow
'#C02717', # red
'#007030', # dark green
'#AAB71B', # lightgreen
'#008797', # light blue
'#F78017', # orange
'#478010', # green
'#53379B', # purple
'#2060A7', # blue
'#873770', # magenta
'#D03050' # pink
]
import matplotlib.colors
colors = [matplotlib.colors.to_rgb(c) for c in colorslist]
# colorslist = [ '#40A787', # cyan'#
# '#F0D730', # yellow
# '#C02717', # red
# '#007030', # dark green
# '#AAB71B', # lightgreen
# '#008797', # light blue
# '#F78017', # orange
# '#478010', # green
# '#53379B', # purple
# '#2060A7', # blue
# '#873770', # magenta
# '#D03050' # pink
# ]
# import matplotlib.colors
# colors = [matplotlib.colors.to_rgb(c) for c in colorslist]
model_pos = {'Cb stellate':0, 'RS Inhibitory':1, 'FS':2, 'RS Pyramidal':3,
'RS Inhibitory +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':4,
'Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':5, 'FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':6,
'RS Pyramidal +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':7, 'STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':8,
'Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':9,
'STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':10, 'STN':11}
# model_pos = {'Cb stellate': 0, 'RS Inhibitory': 1, 'FS': 2, 'RS Pyramidal': 3,
# 'RS Inhibitory +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 4,
# 'Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 5,
# 'FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 6,
# 'RS Pyramidal +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 7,
# 'STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 8,
# 'Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 9,
# 'STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 10, 'STN': 11}
colorslist = ['#007030', # dark green
'#F0D730', # yellow
'#C02717', # red
'#478010', # green
'#AAB71B', # lightgreen
'#F78017', # orange
'#40A787', # cyan'#
'#008797', # light blue
'#2060A7', # blue
'#D03050', # pink
'#53379B', # purple
'#873770', # magenta
]
import matplotlib.colors
colors = [matplotlib.colors.to_rgb(c) for c in colorslist]
Markers = ["o", "o", "o", "^", "^", "^", "D", "D", "D", "s", "s", "s"]
RS_p = mlines.Line2D([], [], color=colors[model_pos['RS Pyramidal']], marker=Markers[model_pos['RS Pyramidal']], markersize=marker_s_leg, linestyle='None',
label='RS pyramidal')

View File

@@ -0,0 +1,648 @@
# -*- coding: utf-8 -*-
"""
Created on Sat Jul 3 19:52:04 2021
@author: nils
"""
import pandas as pd
import numpy as np
import string
import textwrap
import json
import matplotlib
import matplotlib.lines as mlines
from matplotlib import ticker
from matplotlib.ticker import NullFormatter
from Figures.plotstyle import boxplot_style
def cm2inch(*tupl):
inch = 2.54
if isinstance(tupl[0], tuple):
return tuple(i/inch for i in tupl[0])
else:
return tuple(i/inch for i in tupl)
#%% ##################### From https://stackoverflow.com/questions/52878845/swarmplot-with-hue-affecting-marker-beyond-color ##
# to change marker types in seaborn swarmplot
import seaborn as sns
import matplotlib.pyplot as plt
############## Begin hack ##############
from matplotlib.axes._axes import Axes
from matplotlib.markers import MarkerStyle
from numpy import ndarray
def GetColor2Marker(markers):
colorslist = ['#40A787', # cyan'#
'#F0D730', # yellow
'#C02717', # red
'#007030', # dark green
'#AAB71B', # lightgreen
'#008797', # light blue
'#F78017', # orange
'#478010', # green
'#53379B', # purple
'#2060A7', # blue
'#873770', # magenta
'#D03050' # pink
]
import matplotlib.colors
palette = [matplotlib.colors.to_rgb(c) for c in colorslist] #
mkcolors = [(palette[i]) for i in range(len(markers))]
return dict(zip(mkcolors,markers))
def fixlegend(ax,markers,markersize=3,**kwargs):
# Fix Legend
legtitle = ax.get_legend().get_title().get_text()
_,l = ax.get_legend_handles_labels()
colorslist = ['#40A787', # cyan'#
'#F0D730', # yellow
'#C02717', # red
'#007030', # dark green
'#AAB71B', # lightgreen
'#008797', # light blue
'#F78017', # orange
'#478010', # green
'#53379B', # purple
'#2060A7', # blue
'#873770', # magenta
'#D03050' # pink
]
import matplotlib.colors
palette = [matplotlib.colors.to_rgb(c) for c in colorslist]
mkcolors = [(palette[i]) for i in range(len(markers))]
newHandles = [plt.Line2D([0],[0], ls="none", marker=m, color=c, mec="none", markersize=markersize,**kwargs) \
for m,c in zip(markers, mkcolors)]
ax.legend(newHandles,l)
leg = ax.get_legend()
leg.set_title(legtitle)
old_scatter = Axes.scatter
def new_scatter(self, *args, **kwargs):
colors = kwargs.get("c", None)
co2mk = kwargs.pop("co2mk",None)
FinalCollection = old_scatter(self, *args, **kwargs)
if co2mk is not None and isinstance(colors, ndarray):
Color2Marker = GetColor2Marker(co2mk)
paths=[]
for col in colors:
mk=Color2Marker[tuple(col)]
marker_obj = MarkerStyle(mk)
paths.append(marker_obj.get_path().transformed(marker_obj.get_transform()))
FinalCollection.set_paths(paths)
return FinalCollection
Axes.scatter = new_scatter
############## End hack. ##############
########################################################################################################################
#%% add gradient arrows
import matplotlib.pyplot as plt
import matplotlib.transforms
import matplotlib.path
from matplotlib.collections import LineCollection
def rainbowarrow(ax, start, end, cmap, n=50,lw=3):
# Arrow shaft: LineCollection
x = np.linspace(start[0],end[0],n)
y = np.linspace(start[1],end[1],n)
points = np.array([x,y]).T.reshape(-1,1,2)
segments = np.concatenate([points[:-1],points[1:]], axis=1)
lc = LineCollection(segments, cmap=cmap, linewidth=lw)
lc.set_array(np.linspace(0,1,n))
ax.add_collection(lc)
# Arrow head: Triangle
tricoords = [(0,-0.02),(0.025,0),(0,0.02),(0,-0.02)]
angle = np.arctan2(end[1]-start[1],end[0]-start[0])
rot = matplotlib.transforms.Affine2D().rotate(angle)
tricoords2 = rot.transform(tricoords)
tri = matplotlib.path.Path(tricoords2, closed=True)
ax.scatter(end[0],end[1], c=1, s=(4*lw)**2, marker=tri, cmap=cmap,vmin=0)
ax.autoscale_view()
return ax
def gradientaxis(ax, start, end, cmap, n=100,lw=1):
# Arrow shaft: LineCollection
x = np.linspace(start[0],end[0],n)
y = np.linspace(start[1],end[1],n)
points = np.array([x,y]).T.reshape(-1,1,2)
segments = np.concatenate([points[:-1],points[1:]], axis=1)
lc = LineCollection(segments, cmap=cmap, linewidth=lw,zorder=15)
lc.set_array(np.linspace(0,1,n))
ax.add_collection(lc)
return ax
#%%
def boxplot_with_markers(ax,max_width, alteration='shift', msize=2.2):
hlinewidth = 0.5
model_names = ['RS pyramidal','RS inhibitory','FS',
'RS pyramidal +$K_V1.1$','RS inhibitory +$K_V1.1$',
'FS +$K_V1.1$','Cb stellate','Cb stellate +$K_V1.1$',
'Cb stellate $\Delta$$K_V1.1$','STN','STN +$K_V1.1$',
'STN $\Delta$$K_V1.1$']
colorslist = ['#007030', # dark green
'#F0D730', # yellow
'#C02717', # red
'#478010', # green
'#AAB71B', # lightgreen
'#F78017', # orange
'#40A787', # cyan'#
'#008797', # light blue
'#2060A7', # blue
'#D03050', # pink
'#53379B', # purple
'#873770', # magenta
]
import matplotlib.colors
colors = [matplotlib.colors.to_rgb(c) for c in colorslist]
clr_dict = {}
for m in range(len(model_names)):
clr_dict[model_names[m]] = colors[m]
Markers = ["o", "o", "o", "^", "^", "^", "D", "D", "D", "s", "s", "s"]
if alteration=='shift':
i = 3 # Kv1.1 act
ax.axvspan(i - 0.4, i + 0.4, fill=False, edgecolor='k')
df = pd.read_csv('./Figures/Data/rheo_shift_corr.csv')
sns.swarmplot(y="corr", x="$\Delta V_{1/2}$", hue="model", data=df,
palette=clr_dict, linewidth=0, orient='v', ax=ax, size=msize,
order=['Na activation', 'Na inactivation', 'K activation', '$K_V1.1$ activation',
'$K_V1.1$ inactivation', 'A activation', 'A inactivation'],
hue_order=model_names, co2mk=Markers)
lim = ax.get_xlim()
ax.plot([lim[0], lim[1]], [0, 0], ':r',linewidth=hlinewidth)
ax.plot([lim[0], lim[1]], [1, 1], ':k',linewidth=hlinewidth)
ax.plot([lim[0], lim[1]], [-1, -1], ':k',linewidth=hlinewidth)
ax.set_title("Shift ($\Delta V_{1/2}$)", y=1.05)
ax.set_xticklabels(['Na \nactivation', 'Na \ninactivation', 'K \nactivation', '$K_V1.1$ \nactivation',
'$K_V1.1$ \ninactivation', 'A \nactivation', 'A \ninactivation'])
elif alteration=='slope':
i = 4 # Kv1.1 inact
ax.axvspan(i - 0.4, i + 0.4, fill=False, edgecolor='k')
df = pd.read_csv('./Figures/Data/rheo_scale_corr.csv')
# Add in points to show each observation
sns.swarmplot(y="corr", x="Slope (k)", hue="model", data=df,
palette=clr_dict, linewidth=0, orient='v', ax=ax, size=msize,
order=['Na activation', 'Na inactivation', 'K activation', '$K_V1.1$ activation',
'$K_V1.1$ inactivation', 'A activation', 'A inactivation'],
hue_order=model_names, co2mk=Markers)
lim = ax.get_xlim()
ax.plot([lim[0], lim[1]], [0, 0], ':r',linewidth=hlinewidth)
ax.plot([lim[0], lim[1]], [1, 1], ':k',linewidth=hlinewidth)
ax.plot([lim[0], lim[1]], [-1, -1], ':k',linewidth=hlinewidth)
ax.set_title("Slope (k)", y=1.05)
ax.set_xticklabels(['Na \nactivation', 'Na \ninactivation', 'K \nactivation', '$K_V1.1$ \nactivation',
'$K_V1.1$ \ninactivation', 'A \nactivation', 'A \ninactivation'])
elif alteration=='g':
i = 4 # Leak
ax.axvspan(i - 0.4, i + 0.4, fill=False, edgecolor='k')
df = pd.read_csv('./Figures/Data/rheo_g_corr.csv')
# Add in points to show each observation
sns.swarmplot(y="corr", x="g", hue="model", data=df,
palette=clr_dict, linewidth=0, orient='v', ax=ax, size=msize,
order=['Na', 'K', '$K_V1.1$', 'A', 'Leak'],
hue_order=model_names, co2mk=Markers)
lim = ax.get_xlim()
ax.plot([lim[0], lim[1]], [0, 0], ':r',linewidth=hlinewidth)
ax.plot([lim[0], lim[1]], [1, 1], ':k',linewidth=hlinewidth)
ax.plot([lim[0], lim[1]], [-1, -1], ':k',linewidth=hlinewidth)
ax.set_title("Conductance (g)", y=1.05)
ax.set_xticklabels(textwrap.fill(x.get_text(), max_width) for x in ax.get_xticklabels())
else:
print('Please chose "shift", "slope" or "g"')
ax.get_legend().remove()
ax.xaxis.grid(False)
sns.despine(trim=True, bottom=True, ax=ax)
ax.set(xlabel=None, ylabel=r'Kendall $\it{\tau}$')
def model_legend(ax, marker_s_leg, pos, ncol):
colorslist = [ '#40A787', # cyan'#
'#F0D730', # yellow
'#C02717', # red
'#007030', # dark green
'#AAB71B', # lightgreen
'#008797', # light blue
'#F78017', # orange
'#478010', # green
'#53379B', # purple
'#2060A7', # blue
'#873770', # magenta
'#D03050' # pink
]
model_names = ['RS pyramidal', 'RS inhibitory', 'FS',
'RS pyramidal +$K_V1.1$', 'RS inhibitory +$K_V1.1$',
'FS +$K_V1.1$', 'Cb stellate', 'Cb stellate +$K_V1.1$',
'Cb stellate $\Delta$$K_V1.1$', 'STN', 'STN +$K_V1.1$',
'STN $\Delta$$K_V1.1$']
# colorslist = ['#007030', # dark green
# '#F0D730', # yellow
# '#C02717', # red
# '#478010', # green
# '#AAB71B', # lightgreen
# '#F78017', # orange
# '#40A787', # cyan'#
# '#008797', # light blue
# '#2060A7', # blue
# '#D03050', # pink
# '#53379B', # purple
# '#873770', # magenta
# ]
import matplotlib.colors
colors = [matplotlib.colors.to_rgb(c) for c in colorslist]
clr_dict = {}
for m in range(len(model_names)):
clr_dict[model_names[m]] = colors[m]
import matplotlib.colors
colors = [matplotlib.colors.to_rgb(c) for c in colorslist]
# model_pos = {'Cb stellate':0, 'RS Inhibitory':1, 'FS':2, 'RS Pyramidal':3,
# 'RS Inhibitory +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':4,
# 'Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':5, 'FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':6,
# 'RS Pyramidal +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':7, 'STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':8,
# 'Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':9,
# 'STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':10, 'STN':11}
Markers = ["o", "o", "o", "^", "^", "^", "D", "D", "D", "s", "s", "s"]
model_pos = {'RS Pyramidal':0, 'RS Inhibitory':1, 'FS':2,
'RS Pyramidal +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':3, 'RS Inhibitory +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':4,
'FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':5, 'Cb stellate':6, 'Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':7,
'Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':8, 'STN':9, 'STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':10,
'STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':11}
RS_p = mlines.Line2D([], [], color='#007030', marker="^",
markersize=marker_s_leg, linestyle='None',
label='Model D')
RS_i = mlines.Line2D([], [], color='#F0D730', marker="o",
markersize=marker_s_leg, linestyle='None',
label='Model B')
FS = mlines.Line2D([], [], color='#C02717', marker="o", markersize=marker_s_leg,
linestyle='None', label='Model C')
RS_p_Kv = mlines.Line2D([], [], color='#478010',
marker="D",
markersize=marker_s_leg, linestyle='None',
label='Model H')
RS_i_Kv = mlines.Line2D([], [], color='#AAB71B',
marker="^",
markersize=marker_s_leg, linestyle='None',
label='Model E')
FS_Kv = mlines.Line2D([], [], color='#F78017',
marker="D", markersize=marker_s_leg,
linestyle='None', label='Model G')
Cb = mlines.Line2D([], [], color='#40A787', marker="o",
markersize=marker_s_leg, linestyle='None',
label='Model A')
Cb_pl = mlines.Line2D([], [], color='#008797',
marker="^",
markersize=marker_s_leg, linestyle='None',
label='Model F')
Cb_sw = mlines.Line2D([], [], color='#2060A7',
marker="s",
markersize=marker_s_leg, linestyle='None',
label='Model J')
STN = mlines.Line2D([], [], color='#D03050', marker="s", markersize=marker_s_leg,
linestyle='None', label='Model L')
STN_pl = mlines.Line2D([], [], color='#53379B',
marker="D",
markersize=marker_s_leg, linestyle='None',
label='Model I')
STN_sw = mlines.Line2D([], [], color='#873770',
marker="s",
markersize=marker_s_leg, linestyle='None',
label='Model K')
#
# RS_p = mlines.Line2D([], [], color=colors[model_pos['RS Pyramidal']], marker=Markers[model_pos['RS Pyramidal']],
# markersize=marker_s_leg, linestyle='None',
# label='Model D')
# RS_i = mlines.Line2D([], [], color=colors[model_pos['RS Inhibitory']], marker=Markers[model_pos['RS Inhibitory']],
# markersize=marker_s_leg, linestyle='None',
# label='Model B')
# FS = mlines.Line2D([], [], color=colors[model_pos['FS']], marker=Markers[model_pos['FS']], markersize=marker_s_leg,
# linestyle='None', label='Model C')
# RS_p_Kv = mlines.Line2D([], [], color=colors[model_pos['RS Pyramidal +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']],
# marker=Markers[model_pos['RS Pyramidal +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']],
# markersize=marker_s_leg, linestyle='None',
# label='Model H')
# RS_i_Kv = mlines.Line2D([], [], color=colors[model_pos['RS Inhibitory +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']],
# marker=Markers[model_pos['RS Inhibitory +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']],
# markersize=marker_s_leg, linestyle='None',
# label='Model E')
# FS_Kv = mlines.Line2D([], [], color=colors[model_pos['FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']],
# marker=Markers[model_pos['FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']], markersize=marker_s_leg,
# linestyle='None', label='Model G')
# Cb = mlines.Line2D([], [], color=colors[model_pos['Cb stellate']], marker=Markers[model_pos['Cb stellate']],
# markersize=marker_s_leg, linestyle='None',
# label='Model A')
# Cb_pl = mlines.Line2D([], [], color=colors[model_pos['Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']],
# marker=Markers[model_pos['Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']],
# markersize=marker_s_leg, linestyle='None',
# label='Model F')
# Cb_sw = mlines.Line2D([], [], color=colors[model_pos['Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']],
# marker=Markers[model_pos['Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']],
# markersize=marker_s_leg, linestyle='None',
# label='Model J')
# STN = mlines.Line2D([], [], color=colors[model_pos['STN']], marker=Markers[model_pos['STN']], markersize=marker_s_leg,
# linestyle='None', label='Model L')
# STN_pl = mlines.Line2D([], [], color=colors[model_pos['STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']],
# marker=Markers[model_pos['STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']],
# markersize=marker_s_leg, linestyle='None',
# label='Model I')
# STN_sw = mlines.Line2D([], [], color=colors[model_pos['STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']],
# marker=Markers[model_pos['STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']],
# markersize=marker_s_leg, linestyle='None',
# label='Model K')
# ax.legend(handles=[RS_p, RS_i, FS, RS_p_Kv, RS_i_Kv, FS_Kv, Cb, Cb_pl, Cb_sw, STN, STN_pl, STN_sw], loc='center',
# bbox_to_anchor=pos, ncol=ncol, frameon=False)
ax.legend(handles=[Cb, RS_i, FS, RS_p, RS_i_Kv, Cb_pl, FS_Kv, RS_p_Kv, STN_pl, Cb_sw, STN_sw, STN], loc='center',
bbox_to_anchor=pos, ncol=ncol, frameon=False)
def plot_rheo_alt(ax, model='FS', color1='red', color2='dodgerblue', alteration='shift'):
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
model_names = ['RS Pyramidal','RS Inhibitory','FS',
'RS Pyramidal +$K_V1.1$','RS Inhibitory +$K_V1.1$',
'FS +$K_V1.1$','Cb stellate','Cb stellate +$K_V1.1$',
'Cb stellate $\Delta$$K_V1.1$','STN','STN +$K_V1.1$',
'STN $\Delta$$K_V1.1$']
model_name_dict = {'RS Pyramidal': 'RS Pyramidal',
'RS Inhibitory': 'RS Inhibitory',
'FS': 'FS',
'RS Pyramidal +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'RS Pyramidal +$K_V1.1$',
'RS Inhibitory +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'RS Inhibitory +$K_V1.1$',
'FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'FS +$K_V1.1$',
'Cb stellate': 'Cb stellate',
'Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'Cb stellate +$K_V1.1$',
'Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'Cb stellate $\Delta$$K_V1.1$',
'STN': 'STN',
'STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'STN +$K_V1.1$',
'STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'STN $\Delta$$K_V1.1$'}
colorslist = ['#007030', # dark green
'#F0D730', # yellow
'#C02717', # red
'#478010', # green
'#AAB71B', # lightgreen
'#F78017', # orange
'#40A787', # cyan'#
'#008797', # light blue
'#2060A7', # blue
'#D03050', # pink
'#53379B', # purple
'#873770', # magenta
]
import matplotlib.colors
colors = [matplotlib.colors.to_rgb(c) for c in colorslist]
clr_dict = {}
for m in range(len(model_names)):
clr_dict[model_names[m]] = colors[m]
if alteration=='shift':
df = pd.read_csv('./Figures/Data/rheo_shift_ex.csv')
df = df.sort_values('alteration')
ax.set_xlabel('$\Delta$$V_{1/2}$')
elif alteration=='slope':
df = pd.read_csv('./Figures/Data/rheo_slope_ex.csv')
ax.set_xscale("log")
ax.set_xticks([0.5, 1, 2])
ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
ax.xaxis.set_minor_formatter(NullFormatter())
ax.set_xlabel('$k$/$k_{WT}$')
elif alteration=='g':
df = pd.read_csv('./Figures/Data/rheo_g_ex.csv')
ax.set_xscale("log")
ax.set_xticks([0.5, 1, 2])
ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
ax.xaxis.set_minor_formatter(NullFormatter())
ax.set_xlabel('$g$/$g_{WT}$')
for mod in model_names:
if mod == model_name_dict[model]:
ax.plot(df['alteration'], df[mod], color=clr_dict[mod], alpha=1, zorder=10, linewidth=2)
else:
ax.plot(df['alteration'], df[mod], color=clr_dict[mod], alpha=0.5, zorder=1, linewidth=1)
ax.set_ylabel('$\Delta$ Rheobase (nA)', labelpad=0)
x = df['alteration']
y = df[model_name_dict[model]]
ax.set_xlim(x.min(), x.max())
ax.set_ylim(df[model_names].min().min(), df[model_names].max().max())
# x axis color gradient
cvals = [-2., 2]
colors = ['lightgrey', 'k']
norm = plt.Normalize(min(cvals), max(cvals))
tuples = list(zip(map(norm, cvals), colors))
cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", tuples)
(xstart, xend) = ax.get_xlim()
(ystart, yend) = ax.get_ylim()
print(ystart, yend)
start = (xstart, ystart*1.0)
end = (xend, ystart*1.0)
ax = gradientaxis(ax, start, end, cmap, n=200,lw=4)
ax.spines['bottom'].set_visible(False)
# ax.set_ylim(ystart, yend)
#xlabel tick colors
# my_colors = ['lightgrey', 'grey', 'k']
# for ticklabel, tickcolor in zip(ax.get_xticklabels(), my_colors):
# ticklabel.set_color(tickcolor)
return ax
def plot_fI(ax, model='RS Pyramidal', type='shift', alt='m', color1='red', color2='dodgerblue'):
model_save_name = {'RS Pyramidal': 'RS_pyr_posp',
'RS Inhibitory': 'RS_inhib_posp',
'FS': 'FS_posp',
'RS Pyramidal +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'RS_pyr_Kv',
'RS Inhibitory +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'RS_inhib_Kv',
'FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'FS_Kv',
'Cb stellate': 'Cb_stellate',
'Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'Cb_stellate_Kv',
'Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'Cb_stellate_Kv_only',
'STN': 'STN',
'STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'STN_Kv',
'STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': 'STN_Kv_only'}
cvals = [-2., 2]
colors = [color1, color2]
norm = plt.Normalize(min(cvals), max(cvals))
tuples = list(zip(map(norm, cvals), colors))
cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", tuples)
colors = cmap(np.linspace(0, 1, 22))
df = pd.read_csv('./Figures/Data/Model_fI/{}_fI.csv'.format(model_save_name[model]))
df.drop(['Unnamed: 0'], axis=1)
newdf = df.loc[df.index[(df['alt'] == alt) & (df['type'] == type)], :]
newdf['mag'] = newdf['mag'].astype('float')
newdf = newdf.sort_values('mag').reset_index()
c = 0
for i in newdf.index:
ax.plot(json.loads(newdf.loc[i, 'I']), json.loads(newdf.loc[i, 'F']), color=colors[c])
c += 1
# colors2 = [colors[10, :], 'k']
# norm2 = plt.Normalize(min(cvals), max(cvals))
# tuples2 = list(zip(map(norm2, cvals), colors2))
# cmap2 = matplotlib.colors.LinearSegmentedColormap.from_list("", tuples2)
#
# colors3 = [colors[11, :], 'lightgrey']
# norm3 = plt.Normalize(min(cvals), max(cvals))
# tuples3 = list(zip(map(norm3, cvals), colors3))
# cmap3 = matplotlib.colors.LinearSegmentedColormap.from_list("", tuples3)
#
# start = (1.1, json.loads(newdf.loc[10, 'F'])[-1])
# end = (1.1, json.loads(newdf.loc[20, 'F'])[-1])#-json.loads(newdf.loc[20, 'F'])[-1]*0.1)
# ax = rainbowarrow(ax, start, end, cmap2, n=50, lw=1)
# ax.text(1.15, json.loads(newdf.loc[20, 'F'])[-1], '$+ \Delta V$', fontsize=4, color='k')
#
# start = (1.1, json.loads(newdf.loc[10, 'F'])[-1])
# end = (1.1, json.loads(newdf.loc[0, 'F'])[-1])#-json.loads(newdf.loc[0, 'F'])[-1]*0.1)
# ax = rainbowarrow(ax, start, end, cmap3, n=50, lw=1)
# ax.text(1.15, json.loads(newdf.loc[0, 'F'])[-1], '$- \Delta V$', fontsize=4, color='lightgrey')
ax.set_ylabel('Frequency [Hz]')
ax.set_xlabel('Current [nA]')
if model == 'FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':
ax.set_title("Model G", x=0.2, y=1.0)
elif model == 'Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$':
ax.set_title("Model F", x=0.2, y=1.0)
elif model == 'Cb stellate':
ax.set_title("Model A", x=0.2, y=1.0)
else:
ax.set_title("", x=0.2, y=1.0)
# plot_rheo_alt(ax0_ex, model='FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', color1='lightgrey', color2='k',
# alteration='shift')
# plot_rheo_alt(ax1_ex, model='Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', color1='lightgrey', color2='k',
# alteration='slope')
# plot_rheo_alt(ax2_ex, model='Cb stellate', color1='lightgrey', color2='k', alteration='g')
# ax.set_title(model, x=0.2, y=1.025)
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
L = ax.get_ylim()
ax.set_ylim([0, L[1]])
return ax
#%%
boxplot_style()
color_dict = {'Cb stellate': '#40A787', # cyan'#
'RS Inhibitory': '#F0D730', # yellow
'FS': '#C02717', # red
'RS Pyramidal': '#007030', # dark green
'RS Inhibitory +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': '#AAB71B', # lightgreen
'Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': '#008797', # light blue
'FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': '#F78017', # orange
'RS Pyramidal +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': '#478010', # green
'STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': '#53379B', # purple
'Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': '#2060A7', # blue
'STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': '#873770', # magenta
'STN': '#D03050' # pink
}
# plot setup
marker_s_leg = 2
max_width = 20
pad_x = 0.85
pad_y= 0.4
pad_w = 1.1
pad_h = 0.7
fig = plt.figure()
gs = fig.add_gridspec(3, 7, wspace=1.2, hspace=1.)
ax0 = fig.add_subplot(gs[0,2:7])
ax0_ex = fig.add_subplot(gs[0,1])
ax0_fI = fig.add_subplot(gs[0,0])
ax1 = fig.add_subplot(gs[1,2:7])
ax1_ex = fig.add_subplot(gs[1,1])
ax1_fI = fig.add_subplot(gs[1,0])
ax2 = fig.add_subplot(gs[2,2:7])
ax2_ex = fig.add_subplot(gs[2,1])
ax2_fI = fig.add_subplot(gs[2,0])
line_width = 1
# plot fI curves
ax0_fI = plot_fI(ax0_fI, model='FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', type='shift', alt='s', color1='lightgrey', color2='k')
rec = plt.Rectangle((-pad_x, -pad_y), 1 + pad_w, 1 + pad_h, fill=False, lw=line_width,transform=ax0_fI.transAxes, color=color_dict['FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$'], alpha=1, zorder=-1)
rec = ax0_fI.add_patch(rec)
rec.set_clip_on(False)
ax1_fI = plot_fI(ax1_fI, model='Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', type='slope', alt='u', color1='lightgrey', color2='k')
rec = plt.Rectangle((-pad_x, -pad_y), 1 + pad_w, 1 + pad_h, fill=False, lw=line_width,transform=ax1_fI.transAxes, color=color_dict['Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$'], alpha=1, zorder=-1)
rec = ax1_fI.add_patch(rec)
rec.set_clip_on(False)
ax2_fI = plot_fI(ax2_fI, model='Cb stellate', type='g', alt='Leak', color1='lightgrey', color2='k')
rec = plt.Rectangle((-pad_x, -pad_y), 1 + pad_w, 1 + pad_h, fill=False, lw=line_width,transform=ax2_fI.transAxes, color=color_dict['Cb stellate'], alpha=1, zorder=-1)
rec = ax2_fI.add_patch(rec)
rec.set_clip_on(False)
# plot boxplots
boxplot_with_markers(ax0,max_width, alteration='shift')
boxplot_with_markers(ax1,max_width, alteration='slope')
boxplot_with_markers(ax2,max_width, alteration='g')
# plot legend
pos = (0.225, -0.9)
ncol = 6
model_legend(ax2, marker_s_leg, pos, ncol)
# plot rheo across model for example alteration
plot_rheo_alt(ax0_ex,model='FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', color1='lightgrey', color2='k', alteration='shift')
plot_rheo_alt(ax1_ex,model='Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', color1='lightgrey', color2='k',alteration='slope')
plot_rheo_alt(ax2_ex, model='Cb stellate', color1='lightgrey', color2='k', alteration='g')
# label subplots with letters
ax0_fI.text(-0.875, 1.35, string.ascii_uppercase[0], transform=ax0_fI.transAxes, size=10, weight='bold')
ax0_ex.text(-0.8, 1.35, string.ascii_uppercase[1], transform=ax0_ex.transAxes, size=10, weight='bold')
ax0.text(-0.075, 1.35, string.ascii_uppercase[2], transform=ax0.transAxes, size=10, weight='bold')
ax1_fI.text(-0.875, 1.35, string.ascii_uppercase[3], transform=ax1_fI.transAxes,size=10, weight='bold')
ax1_ex.text(-0.8, 1.35, string.ascii_uppercase[4], transform=ax1_ex.transAxes, size=10, weight='bold')
ax1.text(-0.075, 1.35, string.ascii_uppercase[5], transform=ax1.transAxes, size=10, weight='bold')
ax2_fI.text(-0.875, 1.35, string.ascii_uppercase[6], transform=ax2_fI.transAxes,size=10, weight='bold')
ax2_ex.text(-0.8, 1.35, string.ascii_uppercase[7], transform=ax2_ex.transAxes, size=10, weight='bold')
ax2.text(-0.075, 1.35, string.ascii_uppercase[8], transform=ax2.transAxes, size=10, weight='bold')
# save
fig.set_size_inches(cm2inch(20.75,12))
fig.savefig('./Figures/rheobase_correlation.pdf', dpi=fig.dpi)
# fig.savefig('./Figures/rheobase_correlation.png', dpi=fig.dpi) #bbox_inches='tight', dpi=fig.dpi # eps # pdf
plt.show()
#%%
# fig, axs = plt.subplots(1,2)
# axs[0] = plot_fI(axs[0] , model='FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', type='shift', alt='s', color1='lightgrey', color2='k')
# plt.show()
#%%
#
#
# cvals = [-2., 2]
# colors = ['lightgrey', 'k']
#
# norm = plt.Normalize(min(cvals), max(cvals))
# tuples = list(zip(map(norm, cvals), colors))
# cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", tuples)
# colors = cmap(np.linspace(0, 1, 22))
#
# colors2 = [colors[10,:], 'k']
# norm2 = plt.Normalize(min(cvals), max(cvals))
# tuples2 = list(zip(map(norm2, cvals), colors2))
# cmap2 = matplotlib.colors.LinearSegmentedColormap.from_list("", tuples2)
#
# colors3 = [colors[11,:], 'lightgrey']
# norm3 = plt.Normalize(min(cvals), max(cvals))
# tuples3 = list(zip(map(norm3, cvals), colors3))
# cmap3 = matplotlib.colors.LinearSegmentedColormap.from_list("", tuples3)
#
# fig, axs = plt.subplots(1,2)
# start = (0,0)
# end = (1,1)
# axs[0] = rainbowarrow(axs[0], start, end, cmap2, n=50,lw=3)
# start = (0,0)
# end = (-1,-1)
# axs[0] = rainbowarrow(axs[0], start, end, cmap3, n=50,lw=3)
# plt.show()

View File

@@ -0,0 +1,259 @@
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import string
from Figures.plotstyle import sim_style
import seaborn as sns
import scipy.stats as stats
import matplotlib.lines as mlines
def cm2inch(*tupl):
inch = 2.54
if isinstance(tupl[0], tuple):
return tuple(i/inch for i in tupl[0])
else:
return tuple(i/inch for i in tupl)
def Kendall_tau(df):
tau = df.corr(method='kendall')
p = pd.DataFrame(columns=df.columns, index=df.columns)
for col in range((df.columns).shape[0]):
for col2 in range((df.columns).shape[0]):
if col != col2:
_, p.loc[df.columns[col], df.columns[col2]] = stats.kendalltau(
df[df.columns[col]], df[df.columns[col2]], nan_policy='omit')
return tau, p
def correlation_plot(ax, df='AUC', title='', cbar=False):
# cbar_ax = fig.add_axes([0.685, 0.44, .15, .01])
cbar_ax = fig.add_axes([0.685, 0.48, .15, .01])
cbar_ax.spines['left'].set_visible(False)
cbar_ax.spines['bottom'].set_visible(False)
cbar_ax.spines['right'].set_visible(False)
cbar_ax.spines['top'].set_visible(False)
cbar_ax.set_xticks([])
cbar_ax.set_yticks([])
if df == 'AUC':
df = pd.read_csv(os.path.join('./Figures/Data/sim_mut_AUC.csv'), index_col='Unnamed: 0')
elif df == 'rheo':
df = pd.read_csv(os.path.join('./Figures/Data/sim_mut_rheo.csv'), index_col='Unnamed: 0')
# array for names
cmap = sns.diverging_palette(220, 10, as_cmap=True)
models = ['RS_pyramidal', 'RS_inhib', 'FS', 'Cb_stellate', 'Cb_stellate_Kv', 'Cb_stellate_Kv_only', 'STN',
'STN_Kv', 'STN_Kv_only']
model_names = ['RS pyramidal', 'RS inhibitory', 'FS', 'Cb stellate',
'Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$',
'Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', 'STN',
'STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', 'STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']
# model_letter_names = ['Model H', 'Model E', 'Model G', 'Model A', 'Model F', 'Model J', 'Model L', 'Model I', 'Model K']
model_letter_names = ['H', 'E', 'G', 'A', 'F', 'J', 'L', 'I', 'K']
col_dict = {}
for m in range(len(models)):
col_dict[model_names[m]] = model_letter_names[m]
df.rename(columns=col_dict, inplace=True)
df = df[model_letter_names]
# calculate correlation matrix
tau, p = Kendall_tau(df)
# tau = tau.drop(columns='STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', index='RS pyramidal')
# mask to hide upper triangle of matrix
mask = np.zeros_like(tau, dtype=bool)
mask[np.triu_indices_from(mask)] = True
np.fill_diagonal(mask, False)
# models and renaming of tau
models = ['RS pyramidal', 'RS inhibitory', 'FS', 'Cb stellate',
'Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$',
'Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', 'STN',
'STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']
model_names = ['RS pyramidal +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$',
'RS inhibitory +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$',
'FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$',
'Cb stellate', 'Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$',
'Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', 'STN',
'STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']
# model_letter_names = ['Model H', 'Model E', 'Model G', 'Model A', 'Model F', 'Model J', 'Model L', 'Model I', 'Model K']
model_letter_names = ['H', 'E', 'G', 'A', 'F', 'J', 'L', 'I', 'K']
col_dict = {}
for m in range(len(models)):
col_dict[model_names[m]] = model_letter_names[m]
tau.rename(columns=col_dict, index=col_dict, inplace=True)
tau = tau[model_letter_names]
# plotting with or without colorbar
if cbar==False:
res = sns.heatmap(tau, annot=False, mask=mask, center=0, vmax=1, vmin=-1, linewidths=.5, square=True, ax=ax,
cbar=False, cmap=cmap, cbar_ax=cbar_ax, cbar_kws={"shrink": .52})
else:
res = sns.heatmap(tau, annot=False, mask=mask, center=0, vmax=1, vmin=-1, linewidths=.5, square=True, ax=ax,
cbar=True, cmap=cmap, cbar_ax=cbar_ax,
cbar_kws={"orientation": "horizontal",
"ticks": [-1,-0.5, 0, 0.5, 1]} )
cbar_ax.set_title(r'Kendall $\tau$', y=1.02, loc='center', fontsize=6)
cbar_ax.tick_params(length=3)
for tick in cbar_ax.xaxis.get_major_ticks():
tick.label.set_fontsize(6)
ax.set_title(title, fontsize=8)
ax.set_xlabel("Model")
ax.set_ylabel("Model")
def mutation_plot(ax, model='RS_pyramidal'):
models = ['RS_pyramidal', 'RS_inhib', 'FS', 'Cb_stellate', 'Cb_stellate_Kv', 'Cb_stellate_Kv_only', 'STN',
'STN_Kv', 'STN_Kv_only']
model_names = ['RS pyramidal', 'RS inhibitory', 'FS',
'Cb stellate', 'Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$',
'Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', 'STN', 'STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', 'STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']
model_display_names = ['RS Pyramidal +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', 'RS Inhibitory +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', 'FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', 'Cb stellate',
'Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$',
'Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', 'STN',
'STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$', 'STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$']
model_letter_names = ['Model H',
'Model E',
'Model G', 'Model A',
'Model F',
'Model J', 'Model L',
'Model I',
'Model K']
col_dict = {}
for m in range(len(models)):
col_dict[models[m]] = model_display_names[m]
ax_dict = {}
ax_dict['RS_pyramidal'] = (0, 0)
ax_dict['RS_inhib'] = (0, 1)
ax_dict['FS'] = (1, 0)
ax_dict['Cb_stellate'] = (2, 0)
ax_dict['Cb_stellate_Kv'] = (2, 1)
ax_dict['Cb_stellate_Kv_only'] = (3, 0)
ax_dict['STN'] = (3, 1)
ax_dict['STN_Kv'] = (4, 0)
ax_dict['STN_Kv_only'] = (4, 1)
ylim_dict = {}
ylim_dict['RS_pyramidal'] = (-0.1, 0.3)
ylim_dict['RS_inhib'] = (-0.6, 0.6)
ylim_dict['FS'] = (-0.06, 0.08)
ylim_dict['Cb_stellate'] = (-0.1, 0.4)
ylim_dict['Cb_stellate_Kv'] = (-0.1, 0.5)
ylim_dict['Cb_stellate_Kv_only'] = (-1, 0.8)
ylim_dict['STN'] = (-0.01, 0.015)
ylim_dict['STN_Kv'] = (-0.4, 0.6)
ylim_dict['STN_Kv_only'] = (-0.03, 0.3)
Marker_dict = {'Cb stellate': 'o', 'RS Inhibitory': 'o', 'FS': 'o', 'RS Pyramidal': "^",
'RS Inhibitory +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': "^",
'Cb stellate +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': "^",
'FS +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': "D",
'RS Pyramidal +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': "D",
'STN +$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': "D",
'Cb stellate $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': "s",
'STN $\Delta$$\mathrm{K}_{\mathrm{V}}\mathrm{1.1}$': "s", 'STN': "s"}
AUC = pd.read_csv(os.path.join('./Figures/Data/sim_mut_AUC.csv'), index_col='Unnamed: 0')
rheo = pd.read_csv(os.path.join('./Figures/Data/sim_mut_rheo.csv'), index_col='Unnamed: 0')
mod = models.index(model)
mut_names = AUC.index
ax.plot(rheo.loc[mut_names, model_names[mod]]*1000, AUC.loc[mut_names, model_names[mod]]*100, linestyle='',
markeredgecolor='grey', markerfacecolor='grey', marker=Marker_dict[model_display_names[mod]],
markersize=2)
ax.plot(rheo.loc['wt', model_names[mod]], AUC.loc['wt', model_names[mod]]*100, 'sk')
mut_col = sns.color_palette("pastel")
ax.plot(rheo.loc['V174F', model_names[mod]]*1000, AUC.loc['V174F', model_names[mod]]*100, linestyle='',
markeredgecolor=mut_col[0], markerfacecolor=mut_col[0], marker=Marker_dict[model_display_names[mod]],markersize=4)
ax.plot(rheo.loc['F414C', model_names[mod]]*1000, AUC.loc['F414C', model_names[mod]]*100, linestyle='',
markeredgecolor=mut_col[1], markerfacecolor=mut_col[1], marker=Marker_dict[model_display_names[mod]],markersize=4)
ax.plot(rheo.loc['E283K', model_names[mod]]*1000, AUC.loc['E283K', model_names[mod]]*100, linestyle='',
markeredgecolor=mut_col[2], markerfacecolor=mut_col[2], marker=Marker_dict[model_display_names[mod]],markersize=4)
ax.plot(rheo.loc['V404I', model_names[mod]]*1000, AUC.loc['V404I', model_names[mod]]*100, linestyle='',
markeredgecolor=mut_col[3], markerfacecolor=mut_col[5], marker=Marker_dict[model_display_names[mod]],markersize=4)
ax.set_title(model_letter_names[mod], pad=14)
ax.set_xlabel('$\Delta$Rheobase [pA]')
ax.set_ylabel('Normalized $\Delta$AUC (%)')
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
# ax.ticklabel_format(axis="y", style="sci", scilimits=(0, 0),useMathText=True)
xmin, xmax = ax.get_xlim()
ymin, ymax = ax.get_ylim()
ax.hlines(0, xmin, xmax, colors='lightgrey', linestyles='--')
ax.vlines(0, ymin,ymax, colors='lightgrey', linestyles='--')
return ax
def mutation_legend(ax, marker_s_leg, pos, ncol):
colors = sns.color_palette("pastel")
Markers = ["o", "o", "o", "o"]
V174F = mlines.Line2D([], [], color=colors[0], marker=Markers[0], markersize=marker_s_leg, linestyle='None',
label='V174F')
F414C = mlines.Line2D([], [], color=colors[1], marker=Markers[1], markersize=marker_s_leg, linestyle='None',
label='F414C')
E283K = mlines.Line2D([], [], color=colors[2], marker=Markers[2], markersize=marker_s_leg, linestyle='None', label='E283K')
V404I = mlines.Line2D([], [], color=colors[5], marker=Markers[3], markersize=marker_s_leg, linestyle='None',
label='V404I')
WT = mlines.Line2D([], [], color='k', marker='s', markersize=marker_s_leg+2, linestyle='None', label='Wild type')
ax.legend(handles=[WT, V174F, F414C, E283K, V404I], loc='center', bbox_to_anchor=pos, ncol=ncol, frameon=False)
sim_style()
# plot setup
fig = plt.figure()
gs0 = fig.add_gridspec(1, 6, wspace=-0.2)
gsl = gs0[0:3].subgridspec(3, 3, wspace=0.9, hspace=0.8)
gsr = gs0[4:6].subgridspec(7, 1, wspace=0.6, hspace=0.8)
ax00 = fig.add_subplot(gsl[0,0])
ax01 = fig.add_subplot(gsl[0,1])
ax02 = fig.add_subplot(gsl[0,2])
ax10 = fig.add_subplot(gsl[1,0])
ax11 = fig.add_subplot(gsl[1,1])
ax12 = fig.add_subplot(gsl[1,2])
ax20 = fig.add_subplot(gsl[2,0])
ax21 = fig.add_subplot(gsl[2,1])
ax22 = fig.add_subplot(gsl[2,2])
axr0 = fig.add_subplot(gsr[0:3,0])
axr1 = fig.add_subplot(gsr[4:,0])
# plot mutations in each model
ax00 = mutation_plot(ax00, model='RS_pyramidal')
ax01 = mutation_plot(ax01, model='RS_inhib')
ax02 = mutation_plot(ax02, model='FS')
ax10 = mutation_plot(ax10, model='Cb_stellate')
ax11 = mutation_plot(ax11, model='Cb_stellate_Kv')
ax12 = mutation_plot(ax12, model='Cb_stellate_Kv_only')
ax20 = mutation_plot(ax20, model='STN')
ax21 = mutation_plot(ax21, model='STN_Kv')
ax22 = mutation_plot(ax22, model='STN_Kv_only')
marker_s_leg = 4
pos = (0.425, -0.7)
ncol = 5
mutation_legend(ax21, marker_s_leg, pos, ncol)
# plot correlation matrices
correlation_plot(axr1,df = 'AUC', title='Normalized $\Delta$AUC', cbar=False)
correlation_plot(axr0,df = 'rheo', title='$\Delta$Rheobase', cbar=True)
# add subplot labels
axs = [ax00, ax01,ax02, ax10, ax11, ax12, ax20, ax21, ax22]
j=0
for i in range(0,9):
# axs[i].text(-0.48, 1.175, string.ascii_uppercase[i], transform=axs[i].transAxes, size=10, weight='bold')
axs[i].text(-0.625, 1.25, string.ascii_uppercase[i], transform=axs[i].transAxes, size=10, weight='bold')
j +=1
axr0.text(-0.77, 1.1, string.ascii_uppercase[j], transform=axr0.transAxes, size=10, weight='bold')
axr1.text(-0.77, 1.1, string.ascii_uppercase[j+1], transform=axr1.transAxes, size=10, weight='bold')
# save
fig.set_size_inches(cm2inch(22.2,15))
fig.savefig('./Figures/simulation_model_comparison.pdf', dpi=fig.dpi) #eps
# fig.savefig('./Figures/simulation_model_comparison.png', dpi=fig.dpi) #eps
plt.show()