[plotstyle] set_[xyz]label() functions are set as members of axes

This commit is contained in:
Jan Benda 2020-01-04 18:32:52 +01:00
parent 1432d4fcd2
commit c49047fa67
10 changed files with 106 additions and 50 deletions

View File

@ -1,5 +1,6 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
xkcd_style = False
@ -149,33 +150,88 @@ def show_spines(ax, spines):
ax.yaxis.set_ticks_position('both')
def set_xlabel(ax, label, unit=None, **kwargs):
def axis_label(label, unit=None):
""" Format an axis label from a label and a unit
Parameters
----------
label: string
The name of the axis.
unit: string
The unit of the axis values.
Returns
-------
label: string
An axis label formatted from `label` and `unit`.
"""
if not unit:
ax.set_xlabel(label, **kwargs)
return label
elif xkcd_style:
ax.set_xlabel('%s / %s' % (label, unit), **kwargs)
return '%s / %s' % (label, unit)
else:
ax.set_xlabel('%s [%s]' % (label, unit), **kwargs)
return '%s [%s]' % (label, unit)
def set_xlabel(ax, label, unit=None, **kwargs):
""" Format the xlabel from a label and an unit.
Uses the axis_label() function to format the axis label.
Parameters
----------
label: string
The name of the axis.
unit: string
The unit of the axis values.
kwargs: key-word arguments
Further arguments passed on to the set_xlabel() function.
"""
ax.set_xlabel_orig(axis_label(label, unit), **kwargs)
def set_ylabel(ax, label, unit=None, **kwargs):
if not unit:
ax.set_ylabel(label, **kwargs)
elif xkcd_style:
ax.set_ylabel('%s / %s' % (label, unit), **kwargs)
else:
ax.set_ylabel('%s [%s]' % (label, unit), **kwargs)
""" Format the ylabel from a label and an unit.
Uses the axis_label() function to format the axis label.
Parameters
----------
label: string
The name of the axis.
unit: string
The unit of the axis values.
kwargs: key-word arguments
Further arguments passed on to the set_ylabel() function.
"""
ax.set_ylabel_orig(axis_label(label, unit), **kwargs)
def set_zlabel(ax, label, unit=None, **kwargs):
if not unit:
ax.set_zlabel(label, **kwargs)
elif xkcd_style:
ax.set_zlabel('%s / %s' % (label, unit), **kwargs)
else:
ax.set_zlabel('%s [%s]' % (label, unit), **kwargs)
""" Format the zlabel from a label and an unit.
Uses the axis_label() function to format the axis label.
Parameters
----------
label: string
The name of the axis.
unit: string
The unit of the axis values.
kwargs: key-word arguments
Further arguments passed on to the set_zlabel() function.
"""
ax.set_zlabel_orig(axis_label(label, unit), **kwargs)
# overwrite set_[xy]label member functions:
mpl.axes.Axes.set_xlabel_orig = mpl.axes.Axes.set_xlabel
mpl.axes.Axes.set_xlabel = set_xlabel
mpl.axes.Axes.set_ylabel_orig = mpl.axes.Axes.set_ylabel
mpl.axes.Axes.set_ylabel = set_ylabel
Axes3D.set_zlabel_orig = Axes3D.set_zlabel
Axes3D.set_zlabel = set_zlabel
# initialization:
if xkcd_style:
plt.xkcd()

View File

@ -23,8 +23,8 @@ def plot_data(ax, x, y, c):
ax.plot(xx, cc*xx**3.0, color=colors['orange'], lw=1.5, zorder=5)
show_spines(ax, 'lb')
set_xlabel(ax, 'Size x', 'm')
set_ylabel(ax, 'Weight y', 'kg')
ax.set_xlabel('Size x', 'm')
ax.set_ylabel('Weight y', 'kg')
ax.set_xlim(2, 4)
ax.set_ylim(0, 400)
ax.set_xticks(np.arange(2.0, 4.1, 0.5))
@ -33,8 +33,8 @@ def plot_data(ax, x, y, c):
def plot_data_errors(ax, x, y, c):
show_spines(ax, 'lb')
set_xlabel(ax, 'Size x', 'm')
#set_ylabel(ax, 'Weight y', 'kg')
ax.set_xlabel('Size x', 'm')
#ax.set_ylabel('Weight y', 'kg')
ax.set_xlim(2, 4)
ax.set_ylim(0, 400)
ax.set_xticks(np.arange(2.0, 4.1, 0.5))
@ -57,8 +57,8 @@ def plot_data_errors(ax, x, y, c):
def plot_error_hist(ax, x, y, c):
show_spines(ax, 'lb')
set_xlabel(ax, 'Squared error')
set_ylabel(ax, 'Frequency')
ax.set_xlabel('Squared error')
ax.set_ylabel('Frequency')
bins = np.arange(0.0, 1250.0, 100)
ax.set_xlim(bins[0], bins[-1])
#ax.set_ylim(0, 35)

View File

@ -23,8 +23,8 @@ if __name__ == "__main__":
ax.plot(xx, cc*xx**3.0, color=colors['orange'], lw=2, zorder=5)
show_spines(ax, 'lb')
set_xlabel(ax, 'Size x', 'm')
set_ylabel(ax, 'Weight y', 'kg')
ax.set_xlabel('Size x', 'm')
ax.set_ylabel('Weight y', 'kg')
ax.set_xlim(2, 4)
ax.set_ylim(0, 400)
ax.set_xticks(np.arange(2.0, 4.1, 0.5))

View File

@ -51,8 +51,8 @@ def plot_mse(ax, x, y, c, cs):
show_spines(ax, 'lb')
set_xlabel(ax, 'c')
set_ylabel(ax, 'Mean squared error')
ax.set_xlabel('c')
ax.set_ylabel('Mean squared error')
ax.set_xlim(0, 10)
ax.set_ylim(0, 25000)
ax.set_xticks(np.arange(0.0, 10.1, 2.0))
@ -62,8 +62,8 @@ def plot_descent(ax, cs, mses):
ax.plot(np.arange(len(mses))+1, mses, '-o', c=colors['red'], mew=0, ms=8)
show_spines(ax, 'lb')
set_xlabel(ax, 'Iteration')
#set_ylabel(ax, 'Mean squared error')
ax.set_xlabel('Iteration')
#ax.set_ylabel('Mean squared error')
ax.set_xlim(0, 10.5)
ax.set_ylim(0, 25000)
ax.set_xticks(np.arange(0.0, 10.1, 2.0))
@ -75,7 +75,7 @@ if __name__ == "__main__":
x, y, c = create_data()
cs, mses = gradient_descent(x, y)
fig, (ax1, ax2) = plt.subplots(1, 2)
fig.subplots_adjust(wspace=0.2, **adjust_fs(left=7.5, right=0.5))
fig.subplots_adjust(wspace=0.2, **adjust_fs(left=8.0, right=0.5))
plot_mse(ax1, x, y, c, cs)
plot_descent(ax2, cs, mses)
fig.savefig("cubicmse.pdf")

View File

@ -16,9 +16,9 @@ def create_data():
def plot_error_plane(ax, x, y, m, n):
set_xlabel(ax, 'Slope m')
set_ylabel(ax, 'Intercept b')
set_zlabel(ax, 'Mean squared error')
ax.set_xlabel('Slope m')
ax.set_ylabel('Intercept b')
ax.set_zlabel('Mean squared error')
ax.set_xlim(-4.5, 5.0)
ax.set_ylim(-60.0, -20.0)
ax.set_zlim(0.0, 700.0)

View File

@ -18,8 +18,8 @@ x = np.linspace(x1, x2, 200)
y = np.linspace(x1, x2, 200)
X, Y = np.meshgrid(x, y)
Z = gaussian(X, Y)
set_xlabel(ax, 'x')
set_ylabel(ax, 'y', rotation='horizontal')
ax.set_xlabel('x')
ax.set_ylabel('y', rotation='horizontal')
ax.set_xticks(np.arange(x1, x2+0.1, 1.0))
ax.set_yticks(np.arange(x1, x2+0.1, 1.0))
ax.contour(X, Y, Z, linewidths=2, zorder=0)

View File

@ -16,8 +16,8 @@ def create_data():
def plot_data(ax, x, y):
ax.scatter(x, y, marker='o', color=colors['blue'], s=40)
show_spines(ax, 'lb')
set_xlabel(ax, 'Input x')
set_ylabel(ax, 'Output y')
ax.set_xlabel('Input x')
ax.set_ylabel('Output y')
ax.set_xlim(0, 120)
ax.set_ylim(-80, 80)
ax.set_xticks(np.arange(0,121, 40))
@ -30,8 +30,8 @@ def plot_data_slopes(ax, x, y, m, n):
for i in np.linspace(0.3*m, 2.0*m, 5):
ax.plot(xx, i*xx+n, color=colors['red'], lw=2)
show_spines(ax, 'lb')
set_xlabel(ax, 'Input x')
#set_ylabel(ax, 'Output y')
ax.set_xlabel('Input x')
#ax.set_ylabel('Output y')
ax.set_xlim(0, 120)
ax.set_ylim(-80, 80)
ax.set_xticks(np.arange(0,121, 40))
@ -44,8 +44,8 @@ def plot_data_intercepts(ax, x, y, m, n):
for i in np.linspace(n-1*n, n+1*n, 5):
ax.plot(xx, m*xx + i, color=colors['red'], lw=2)
show_spines(ax, 'lb')
set_xlabel(ax, 'Input x')
#set_ylabel(ax, 'Output y')
ax.set_xlabel('Input x')
#ax.set_ylabel('Output y')
ax.set_xlim(0, 120)
ax.set_ylim(-80, 80)
ax.set_xticks(np.arange(0,121, 40))

View File

@ -15,8 +15,8 @@ def create_data():
def plot_data(ax, x, y, m, n):
show_spines(ax, 'lb')
set_xlabel(ax, 'Input x')
set_ylabel(ax, 'Output y')
ax.set_xlabel('Input x')
ax.set_ylabel('Output y')
ax.set_xlim(0, 120)
ax.set_ylim(-80, 80)
ax.set_xticks(np.arange(0,121, 40))

View File

@ -21,8 +21,8 @@ if __name__ == "__main__":
ax1 = fig.add_subplot(spec[0, 0])
show_spines(ax1, 'lb')
ax1.scatter(indices, data, c=colors['blue'], edgecolor='white', s=50)
set_xlabel(ax1, 'Index')
set_ylabel(ax1, 'Weight', 'kg')
ax1.set_xlabel('Index')
ax1.set_ylabel('Weight', 'kg')
ax1.set_xlim(-10, 310)
ax1.set_ylim(0, 370)
ax1.set_yticks(np.arange(0, 351, 100))
@ -35,7 +35,7 @@ if __name__ == "__main__":
bw = 20.0
h, b = np.histogram(data, np.arange(0, 401, bw))
ax2.barh(b[:-1], h/np.sum(h)/(b[1]-b[0]), fc=colors['yellow'], height=bar_fac*bw, align='edge')
set_xlabel(ax2, 'Pdf', '1/kg')
ax2.set_xlabel('Pdf', '1/kg')
ax2.set_xlim(0, 0.012)
ax2.set_xticks([0, 0.005, 0.01])
ax2.set_xticklabels(['0', '0.005', '0.01'])

View File

@ -26,8 +26,8 @@ if __name__ == "__main__":
show_spines(ax1, 'lb')
ax1.plot(xx, yy, colors['red'], lw=2)
ax1.scatter(x, y, c=colors['blue'], edgecolor='white', s=50)
set_xlabel(ax1, 'Hair deflection', 'nm')
set_ylabel(ax1, 'Conductance', 'nS')
ax1.set_xlabel('Hair deflection', 'nm')
ax1.set_ylabel('Conductance', 'nS')
ax1.set_xlim(-20, 20)
ax1.set_ylim(-1.5, 9.5)
ax1.set_xticks(np.arange(-20.0, 21.0, 10.0))
@ -41,8 +41,8 @@ if __name__ == "__main__":
bw = 0.25
h, b = np.histogram(y-boltzmann(x, x0, k), np.arange(-3.0, 3.01, bw))
ax2.bar(b[:-1], h/np.sum(h)/(b[1]-b[0]), fc=colors['yellow'], width=bar_fac*bw, align='edge')
set_xlabel(ax2, 'Residuals', 'nS')
set_ylabel(ax2, 'Pdf', '1/nS')
ax2.set_xlabel('Residuals', 'nS')
ax2.set_ylabel('Pdf', '1/nS')
ax2.set_xlim(-2.5, 2.5)
ax2.set_ylim(0, 0.75)
ax2.set_yticks(np.arange(0, 0.75, 0.2))