[plotstyle] overwrite Axes constructer for show_spines
This commit is contained in:
parent
c49047fa67
commit
b4cfd0d181
42
plotstyle.py
42
plotstyle.py
@ -88,7 +88,7 @@ def adjust_fs(fig=None, left=5.5, right=0.5, bottom=2.8, top=0.5):
|
||||
'top': 1.0 - top*fs/h }
|
||||
|
||||
|
||||
def show_spines(ax, spines):
|
||||
def show_spines(ax, spines='lb'):
|
||||
""" Show and hide spines.
|
||||
|
||||
Parameters
|
||||
@ -150,7 +150,14 @@ def show_spines(ax, spines):
|
||||
ax.yaxis.set_ticks_position('both')
|
||||
|
||||
|
||||
def axis_label(label, unit=None):
|
||||
def __axes__init__(ax, *args, **kwargs):
|
||||
""" Set some default formatting for a new Axes instance.
|
||||
"""
|
||||
ax.__init__orig(*args, **kwargs)
|
||||
ax.show_spines('lb')
|
||||
|
||||
|
||||
def __axis_label(label, unit=None):
|
||||
""" Format an axis label from a label and a unit
|
||||
|
||||
Parameters
|
||||
@ -173,10 +180,10 @@ def axis_label(label, unit=None):
|
||||
return '%s [%s]' % (label, unit)
|
||||
|
||||
|
||||
def set_xlabel(ax, label, unit=None, **kwargs):
|
||||
def __set_xlabel(ax, label, unit=None, **kwargs):
|
||||
""" Format the xlabel from a label and an unit.
|
||||
|
||||
Uses the axis_label() function to format the axis label.
|
||||
Uses the __axis_label() function to format the axis label.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -187,13 +194,13 @@ def set_xlabel(ax, label, unit=None, **kwargs):
|
||||
kwargs: key-word arguments
|
||||
Further arguments passed on to the set_xlabel() function.
|
||||
"""
|
||||
ax.set_xlabel_orig(axis_label(label, unit), **kwargs)
|
||||
ax.set_xlabel_orig(__axis_label(label, unit), **kwargs)
|
||||
|
||||
|
||||
def set_ylabel(ax, label, unit=None, **kwargs):
|
||||
def __set_ylabel(ax, label, unit=None, **kwargs):
|
||||
""" Format the ylabel from a label and an unit.
|
||||
|
||||
Uses the axis_label() function to format the axis label.
|
||||
Uses the __axis_label() function to format the axis label.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -204,13 +211,13 @@ def set_ylabel(ax, label, unit=None, **kwargs):
|
||||
kwargs: key-word arguments
|
||||
Further arguments passed on to the set_ylabel() function.
|
||||
"""
|
||||
ax.set_ylabel_orig(axis_label(label, unit), **kwargs)
|
||||
ax.set_ylabel_orig(__axis_label(label, unit), **kwargs)
|
||||
|
||||
|
||||
def set_zlabel(ax, label, unit=None, **kwargs):
|
||||
def __set_zlabel(ax, label, unit=None, **kwargs):
|
||||
""" Format the zlabel from a label and an unit.
|
||||
|
||||
Uses the axis_label() function to format the axis label.
|
||||
Uses the __axis_label() function to format the axis label.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -221,16 +228,21 @@ def set_zlabel(ax, label, unit=None, **kwargs):
|
||||
kwargs: key-word arguments
|
||||
Further arguments passed on to the set_zlabel() function.
|
||||
"""
|
||||
ax.set_zlabel_orig(axis_label(label, unit), **kwargs)
|
||||
ax.set_zlabel_orig(__axis_label(label, unit), **kwargs)
|
||||
|
||||
|
||||
# overwrite axes constructor:
|
||||
mpl.axes.Subplot.__init__orig = mpl.axes.Subplot.__init__
|
||||
mpl.axes.Subplot.__init__ = __axes__init__
|
||||
mpl.axes.Axes.show_spines = show_spines
|
||||
|
||||
# overwrite set_[xy]label member functions:
|
||||
# overwrite axes set_[xyz]label() member functions:
|
||||
mpl.axes.Axes.set_xlabel_orig = mpl.axes.Axes.set_xlabel
|
||||
mpl.axes.Axes.set_xlabel = set_xlabel
|
||||
mpl.axes.Axes.set_xlabel = __set_xlabel
|
||||
mpl.axes.Axes.set_ylabel_orig = mpl.axes.Axes.set_ylabel
|
||||
mpl.axes.Axes.set_ylabel = set_ylabel
|
||||
mpl.axes.Axes.set_ylabel = __set_ylabel
|
||||
Axes3D.set_zlabel_orig = Axes3D.set_zlabel
|
||||
Axes3D.set_zlabel = set_zlabel
|
||||
Axes3D.set_zlabel = __set_zlabel
|
||||
|
||||
# initialization:
|
||||
if xkcd_style:
|
||||
|
@ -21,8 +21,6 @@ def plot_data(ax, x, y, c):
|
||||
ax.plot(xx, c*xx**3.0, color=colors['red'], lw=2, zorder=5)
|
||||
for cc in [0.25*c, 0.5*c, 2.0*c, 4.0*c]:
|
||||
ax.plot(xx, cc*xx**3.0, color=colors['orange'], lw=1.5, zorder=5)
|
||||
|
||||
show_spines(ax, 'lb')
|
||||
ax.set_xlabel('Size x', 'm')
|
||||
ax.set_ylabel('Weight y', 'kg')
|
||||
ax.set_xlim(2, 4)
|
||||
@ -32,7 +30,6 @@ def plot_data(ax, x, y, c):
|
||||
|
||||
|
||||
def plot_data_errors(ax, x, y, c):
|
||||
show_spines(ax, 'lb')
|
||||
ax.set_xlabel('Size x', 'm')
|
||||
#ax.set_ylabel('Weight y', 'kg')
|
||||
ax.set_xlim(2, 4)
|
||||
@ -56,7 +53,6 @@ def plot_data_errors(ax, x, y, c):
|
||||
ax.plot(xx, yy, color=colors['orange'], lw=2, zorder=5)
|
||||
|
||||
def plot_error_hist(ax, x, y, c):
|
||||
show_spines(ax, 'lb')
|
||||
ax.set_xlabel('Squared error')
|
||||
ax.set_ylabel('Frequency')
|
||||
bins = np.arange(0.0, 1250.0, 100)
|
||||
|
@ -21,8 +21,6 @@ if __name__ == "__main__":
|
||||
ax.plot(xx, c*xx**3.0, color=colors['red'], lw=3, zorder=5)
|
||||
for cc in [0.25*c, 0.5*c, 2.0*c, 4.0*c]:
|
||||
ax.plot(xx, cc*xx**3.0, color=colors['orange'], lw=2, zorder=5)
|
||||
|
||||
show_spines(ax, 'lb')
|
||||
ax.set_xlabel('Size x', 'm')
|
||||
ax.set_ylabel('Weight y', 'kg')
|
||||
ax.set_xlim(2, 4)
|
||||
|
@ -48,9 +48,6 @@ def plot_mse(ax, x, y, c, cs):
|
||||
xytext=(cs[i]+0.3, ms[i]+200), textcoords='data', ha='left',
|
||||
arrowprops=dict(arrowstyle="->", relpos=(0.0,0.0),
|
||||
connectionstyle="angle3,angleA=10,angleB=70") )
|
||||
|
||||
|
||||
show_spines(ax, 'lb')
|
||||
ax.set_xlabel('c')
|
||||
ax.set_ylabel('Mean squared error')
|
||||
ax.set_xlim(0, 10)
|
||||
@ -60,8 +57,6 @@ def plot_mse(ax, x, y, c, cs):
|
||||
|
||||
def plot_descent(ax, cs, mses):
|
||||
ax.plot(np.arange(len(mses))+1, mses, '-o', c=colors['red'], mew=0, ms=8)
|
||||
|
||||
show_spines(ax, 'lb')
|
||||
ax.set_xlabel('Iteration')
|
||||
#ax.set_ylabel('Mean squared error')
|
||||
ax.set_xlim(0, 10.5)
|
||||
|
@ -15,7 +15,6 @@ def create_data():
|
||||
|
||||
def plot_data(ax, x, y):
|
||||
ax.scatter(x, y, marker='o', color=colors['blue'], s=40)
|
||||
show_spines(ax, 'lb')
|
||||
ax.set_xlabel('Input x')
|
||||
ax.set_ylabel('Output y')
|
||||
ax.set_xlim(0, 120)
|
||||
@ -29,7 +28,6 @@ def plot_data_slopes(ax, x, y, m, n):
|
||||
xx = np.asarray([2, 118])
|
||||
for i in np.linspace(0.3*m, 2.0*m, 5):
|
||||
ax.plot(xx, i*xx+n, color=colors['red'], lw=2)
|
||||
show_spines(ax, 'lb')
|
||||
ax.set_xlabel('Input x')
|
||||
#ax.set_ylabel('Output y')
|
||||
ax.set_xlim(0, 120)
|
||||
@ -43,7 +41,6 @@ def plot_data_intercepts(ax, x, y, m, n):
|
||||
xx = np.asarray([2, 118])
|
||||
for i in np.linspace(n-1*n, n+1*n, 5):
|
||||
ax.plot(xx, m*xx + i, color=colors['red'], lw=2)
|
||||
show_spines(ax, 'lb')
|
||||
ax.set_xlabel('Input x')
|
||||
#ax.set_ylabel('Output y')
|
||||
ax.set_xlim(0, 120)
|
||||
|
@ -14,7 +14,6 @@ def create_data():
|
||||
|
||||
|
||||
def plot_data(ax, x, y, m, n):
|
||||
show_spines(ax, 'lb')
|
||||
ax.set_xlabel('Input x')
|
||||
ax.set_ylabel('Output y')
|
||||
ax.set_xlim(0, 120)
|
||||
@ -38,7 +37,6 @@ def plot_data(ax, x, y, m, n):
|
||||
|
||||
|
||||
def plot_error_hist(ax, x, y, m, n):
|
||||
show_spines(ax, 'lb')
|
||||
ax.set_xlabel('Squared error')
|
||||
ax.set_ylabel('Frequency')
|
||||
bins = np.arange(0.0, 602.0, 50.0)
|
||||
|
@ -19,7 +19,6 @@ if __name__ == "__main__":
|
||||
spec = gridspec.GridSpec(nrows=1, ncols=2, width_ratios=[3, 1], wspace=0.08,
|
||||
**adjust_fs(fig, left=6.0))
|
||||
ax1 = fig.add_subplot(spec[0, 0])
|
||||
show_spines(ax1, 'lb')
|
||||
ax1.scatter(indices, data, c=colors['blue'], edgecolor='white', s=50)
|
||||
ax1.set_xlabel('Index')
|
||||
ax1.set_ylabel('Weight', 'kg')
|
||||
@ -28,7 +27,6 @@ if __name__ == "__main__":
|
||||
ax1.set_yticks(np.arange(0, 351, 100))
|
||||
|
||||
ax2 = fig.add_subplot(spec[0, 1])
|
||||
show_spines(ax2, 'lb')
|
||||
xx = np.arange(0.0, 350.0, 0.5)
|
||||
yy = st.norm.pdf(xx, mu, sigma)
|
||||
ax2.plot(yy, xx, color=colors['red'], lw=2)
|
||||
|
@ -24,7 +24,6 @@ if __name__ == "__main__":
|
||||
spec = gridspec.GridSpec(nrows=2, ncols=2, **adjust_fs(fig))
|
||||
|
||||
ax = fig.add_subplot(spec[0, 0])
|
||||
show_spines(ax, 'lb')
|
||||
ax.plot(indices, x1, c=colors['blue'], lw=1, zorder=10)
|
||||
ax.scatter(indices, x1, c=colors['blue'], edgecolor='white', s=50, zorder=20)
|
||||
ax.set_xlabel('Index')
|
||||
@ -33,7 +32,6 @@ if __name__ == "__main__":
|
||||
ax.set_ylim(-0.05, 1.05)
|
||||
|
||||
ax = fig.add_subplot(spec[0, 1])
|
||||
show_spines(ax, 'lb')
|
||||
ax.plot(indices, x2, c=colors['blue'], lw=1, zorder=10)
|
||||
ax.scatter(indices, x2, c=colors['blue'], edgecolor='white', s=50, zorder=20)
|
||||
ax.set_xlabel('Index')
|
||||
@ -42,7 +40,6 @@ if __name__ == "__main__":
|
||||
ax.set_ylim(-0.05, 1.05)
|
||||
|
||||
ax = fig.add_subplot(spec[1, 1])
|
||||
show_spines(ax, 'lb')
|
||||
ax.plot(indices, x3, c=colors['blue'], lw=1, zorder=10)
|
||||
ax.scatter(indices, x3, c=colors['blue'], edgecolor='white', s=50, zorder=20)
|
||||
ax.set_xlabel('Index')
|
||||
@ -51,7 +48,6 @@ if __name__ == "__main__":
|
||||
ax.set_ylim(-0.05, 1.05)
|
||||
|
||||
ax = fig.add_subplot(spec[1, 0])
|
||||
show_spines(ax, 'lb')
|
||||
ax.plot(lags, corrs, c=colors['red'], lw=1, zorder=10)
|
||||
ax.scatter(lags, corrs, c=colors['red'], edgecolor='white', s=50, zorder=20)
|
||||
ax.set_xlabel('Lag')
|
||||
|
@ -23,7 +23,6 @@ if __name__ == "__main__":
|
||||
fig = plt.figure()
|
||||
spec = gridspec.GridSpec(nrows=1, ncols=2, **adjust_fs(fig, left=4.5))
|
||||
ax1 = fig.add_subplot(spec[0, 0])
|
||||
show_spines(ax1, 'lb')
|
||||
ax1.plot(xx, yy, colors['red'], lw=2)
|
||||
ax1.scatter(x, y, c=colors['blue'], edgecolor='white', s=50)
|
||||
ax1.set_xlabel('Hair deflection', 'nm')
|
||||
@ -34,7 +33,6 @@ if __name__ == "__main__":
|
||||
ax1.set_yticks(np.arange(0, 9, 2))
|
||||
|
||||
ax2 = fig.add_subplot(spec[0, 1])
|
||||
show_spines(ax2, 'lb')
|
||||
xg = np.linspace(-3.0, 3.01, 200)
|
||||
yg = st.norm.pdf(xg, 0.0, sigma)
|
||||
ax2.plot(xg, yg, colors['red'], lw=2)
|
||||
|
Reference in New Issue
Block a user