[plotstyle] overwrite Axes constructer for show_spines

This commit is contained in:
Jan Benda 2020-01-04 19:06:52 +01:00
parent c49047fa67
commit b4cfd0d181
9 changed files with 27 additions and 39 deletions

View File

@ -88,7 +88,7 @@ def adjust_fs(fig=None, left=5.5, right=0.5, bottom=2.8, top=0.5):
'top': 1.0 - top*fs/h }
def show_spines(ax, spines):
def show_spines(ax, spines='lb'):
""" Show and hide spines.
Parameters
@ -150,7 +150,14 @@ def show_spines(ax, spines):
ax.yaxis.set_ticks_position('both')
def axis_label(label, unit=None):
def __axes__init__(ax, *args, **kwargs):
""" Set some default formatting for a new Axes instance.
"""
ax.__init__orig(*args, **kwargs)
ax.show_spines('lb')
def __axis_label(label, unit=None):
""" Format an axis label from a label and a unit
Parameters
@ -173,10 +180,10 @@ def axis_label(label, unit=None):
return '%s [%s]' % (label, unit)
def set_xlabel(ax, label, unit=None, **kwargs):
def __set_xlabel(ax, label, unit=None, **kwargs):
""" Format the xlabel from a label and an unit.
Uses the axis_label() function to format the axis label.
Uses the __axis_label() function to format the axis label.
Parameters
----------
@ -187,13 +194,13 @@ def set_xlabel(ax, label, unit=None, **kwargs):
kwargs: key-word arguments
Further arguments passed on to the set_xlabel() function.
"""
ax.set_xlabel_orig(axis_label(label, unit), **kwargs)
ax.set_xlabel_orig(__axis_label(label, unit), **kwargs)
def set_ylabel(ax, label, unit=None, **kwargs):
def __set_ylabel(ax, label, unit=None, **kwargs):
""" Format the ylabel from a label and an unit.
Uses the axis_label() function to format the axis label.
Uses the __axis_label() function to format the axis label.
Parameters
----------
@ -204,13 +211,13 @@ def set_ylabel(ax, label, unit=None, **kwargs):
kwargs: key-word arguments
Further arguments passed on to the set_ylabel() function.
"""
ax.set_ylabel_orig(axis_label(label, unit), **kwargs)
ax.set_ylabel_orig(__axis_label(label, unit), **kwargs)
def set_zlabel(ax, label, unit=None, **kwargs):
def __set_zlabel(ax, label, unit=None, **kwargs):
""" Format the zlabel from a label and an unit.
Uses the axis_label() function to format the axis label.
Uses the __axis_label() function to format the axis label.
Parameters
----------
@ -221,16 +228,21 @@ def set_zlabel(ax, label, unit=None, **kwargs):
kwargs: key-word arguments
Further arguments passed on to the set_zlabel() function.
"""
ax.set_zlabel_orig(axis_label(label, unit), **kwargs)
ax.set_zlabel_orig(__axis_label(label, unit), **kwargs)
# overwrite axes constructor:
mpl.axes.Subplot.__init__orig = mpl.axes.Subplot.__init__
mpl.axes.Subplot.__init__ = __axes__init__
mpl.axes.Axes.show_spines = show_spines
# overwrite set_[xy]label member functions:
# overwrite axes set_[xyz]label() member functions:
mpl.axes.Axes.set_xlabel_orig = mpl.axes.Axes.set_xlabel
mpl.axes.Axes.set_xlabel = set_xlabel
mpl.axes.Axes.set_xlabel = __set_xlabel
mpl.axes.Axes.set_ylabel_orig = mpl.axes.Axes.set_ylabel
mpl.axes.Axes.set_ylabel = set_ylabel
mpl.axes.Axes.set_ylabel = __set_ylabel
Axes3D.set_zlabel_orig = Axes3D.set_zlabel
Axes3D.set_zlabel = set_zlabel
Axes3D.set_zlabel = __set_zlabel
# initialization:
if xkcd_style:

View File

@ -21,8 +21,6 @@ def plot_data(ax, x, y, c):
ax.plot(xx, c*xx**3.0, color=colors['red'], lw=2, zorder=5)
for cc in [0.25*c, 0.5*c, 2.0*c, 4.0*c]:
ax.plot(xx, cc*xx**3.0, color=colors['orange'], lw=1.5, zorder=5)
show_spines(ax, 'lb')
ax.set_xlabel('Size x', 'm')
ax.set_ylabel('Weight y', 'kg')
ax.set_xlim(2, 4)
@ -32,7 +30,6 @@ def plot_data(ax, x, y, c):
def plot_data_errors(ax, x, y, c):
show_spines(ax, 'lb')
ax.set_xlabel('Size x', 'm')
#ax.set_ylabel('Weight y', 'kg')
ax.set_xlim(2, 4)
@ -56,7 +53,6 @@ def plot_data_errors(ax, x, y, c):
ax.plot(xx, yy, color=colors['orange'], lw=2, zorder=5)
def plot_error_hist(ax, x, y, c):
show_spines(ax, 'lb')
ax.set_xlabel('Squared error')
ax.set_ylabel('Frequency')
bins = np.arange(0.0, 1250.0, 100)

View File

@ -21,8 +21,6 @@ if __name__ == "__main__":
ax.plot(xx, c*xx**3.0, color=colors['red'], lw=3, zorder=5)
for cc in [0.25*c, 0.5*c, 2.0*c, 4.0*c]:
ax.plot(xx, cc*xx**3.0, color=colors['orange'], lw=2, zorder=5)
show_spines(ax, 'lb')
ax.set_xlabel('Size x', 'm')
ax.set_ylabel('Weight y', 'kg')
ax.set_xlim(2, 4)

View File

@ -48,9 +48,6 @@ def plot_mse(ax, x, y, c, cs):
xytext=(cs[i]+0.3, ms[i]+200), textcoords='data', ha='left',
arrowprops=dict(arrowstyle="->", relpos=(0.0,0.0),
connectionstyle="angle3,angleA=10,angleB=70") )
show_spines(ax, 'lb')
ax.set_xlabel('c')
ax.set_ylabel('Mean squared error')
ax.set_xlim(0, 10)
@ -60,8 +57,6 @@ def plot_mse(ax, x, y, c, cs):
def plot_descent(ax, cs, mses):
ax.plot(np.arange(len(mses))+1, mses, '-o', c=colors['red'], mew=0, ms=8)
show_spines(ax, 'lb')
ax.set_xlabel('Iteration')
#ax.set_ylabel('Mean squared error')
ax.set_xlim(0, 10.5)

View File

@ -15,7 +15,6 @@ def create_data():
def plot_data(ax, x, y):
ax.scatter(x, y, marker='o', color=colors['blue'], s=40)
show_spines(ax, 'lb')
ax.set_xlabel('Input x')
ax.set_ylabel('Output y')
ax.set_xlim(0, 120)
@ -29,7 +28,6 @@ def plot_data_slopes(ax, x, y, m, n):
xx = np.asarray([2, 118])
for i in np.linspace(0.3*m, 2.0*m, 5):
ax.plot(xx, i*xx+n, color=colors['red'], lw=2)
show_spines(ax, 'lb')
ax.set_xlabel('Input x')
#ax.set_ylabel('Output y')
ax.set_xlim(0, 120)
@ -43,7 +41,6 @@ def plot_data_intercepts(ax, x, y, m, n):
xx = np.asarray([2, 118])
for i in np.linspace(n-1*n, n+1*n, 5):
ax.plot(xx, m*xx + i, color=colors['red'], lw=2)
show_spines(ax, 'lb')
ax.set_xlabel('Input x')
#ax.set_ylabel('Output y')
ax.set_xlim(0, 120)

View File

@ -14,7 +14,6 @@ def create_data():
def plot_data(ax, x, y, m, n):
show_spines(ax, 'lb')
ax.set_xlabel('Input x')
ax.set_ylabel('Output y')
ax.set_xlim(0, 120)
@ -38,7 +37,6 @@ def plot_data(ax, x, y, m, n):
def plot_error_hist(ax, x, y, m, n):
show_spines(ax, 'lb')
ax.set_xlabel('Squared error')
ax.set_ylabel('Frequency')
bins = np.arange(0.0, 602.0, 50.0)

View File

@ -19,7 +19,6 @@ if __name__ == "__main__":
spec = gridspec.GridSpec(nrows=1, ncols=2, width_ratios=[3, 1], wspace=0.08,
**adjust_fs(fig, left=6.0))
ax1 = fig.add_subplot(spec[0, 0])
show_spines(ax1, 'lb')
ax1.scatter(indices, data, c=colors['blue'], edgecolor='white', s=50)
ax1.set_xlabel('Index')
ax1.set_ylabel('Weight', 'kg')
@ -28,7 +27,6 @@ if __name__ == "__main__":
ax1.set_yticks(np.arange(0, 351, 100))
ax2 = fig.add_subplot(spec[0, 1])
show_spines(ax2, 'lb')
xx = np.arange(0.0, 350.0, 0.5)
yy = st.norm.pdf(xx, mu, sigma)
ax2.plot(yy, xx, color=colors['red'], lw=2)

View File

@ -24,7 +24,6 @@ if __name__ == "__main__":
spec = gridspec.GridSpec(nrows=2, ncols=2, **adjust_fs(fig))
ax = fig.add_subplot(spec[0, 0])
show_spines(ax, 'lb')
ax.plot(indices, x1, c=colors['blue'], lw=1, zorder=10)
ax.scatter(indices, x1, c=colors['blue'], edgecolor='white', s=50, zorder=20)
ax.set_xlabel('Index')
@ -33,7 +32,6 @@ if __name__ == "__main__":
ax.set_ylim(-0.05, 1.05)
ax = fig.add_subplot(spec[0, 1])
show_spines(ax, 'lb')
ax.plot(indices, x2, c=colors['blue'], lw=1, zorder=10)
ax.scatter(indices, x2, c=colors['blue'], edgecolor='white', s=50, zorder=20)
ax.set_xlabel('Index')
@ -42,7 +40,6 @@ if __name__ == "__main__":
ax.set_ylim(-0.05, 1.05)
ax = fig.add_subplot(spec[1, 1])
show_spines(ax, 'lb')
ax.plot(indices, x3, c=colors['blue'], lw=1, zorder=10)
ax.scatter(indices, x3, c=colors['blue'], edgecolor='white', s=50, zorder=20)
ax.set_xlabel('Index')
@ -51,7 +48,6 @@ if __name__ == "__main__":
ax.set_ylim(-0.05, 1.05)
ax = fig.add_subplot(spec[1, 0])
show_spines(ax, 'lb')
ax.plot(lags, corrs, c=colors['red'], lw=1, zorder=10)
ax.scatter(lags, corrs, c=colors['red'], edgecolor='white', s=50, zorder=20)
ax.set_xlabel('Lag')

View File

@ -23,7 +23,6 @@ if __name__ == "__main__":
fig = plt.figure()
spec = gridspec.GridSpec(nrows=1, ncols=2, **adjust_fs(fig, left=4.5))
ax1 = fig.add_subplot(spec[0, 0])
show_spines(ax1, 'lb')
ax1.plot(xx, yy, colors['red'], lw=2)
ax1.scatter(x, y, c=colors['blue'], edgecolor='white', s=50)
ax1.set_xlabel('Hair deflection', 'nm')
@ -34,7 +33,6 @@ if __name__ == "__main__":
ax1.set_yticks(np.arange(0, 9, 2))
ax2 = fig.add_subplot(spec[0, 1])
show_spines(ax2, 'lb')
xg = np.linspace(-3.0, 3.01, 200)
yg = st.norm.pdf(xg, 0.0, sigma)
ax2.plot(xg, yg, colors['red'], lw=2)