[plotstyle] set_[xyz]label() functions are set as members of axes
This commit is contained in:
parent
1432d4fcd2
commit
c49047fa67
90
plotstyle.py
90
plotstyle.py
@ -1,5 +1,6 @@
|
||||
import matplotlib as mpl
|
||||
import matplotlib.pyplot as plt
|
||||
from mpl_toolkits.mplot3d import Axes3D
|
||||
|
||||
xkcd_style = False
|
||||
|
||||
@ -149,33 +150,88 @@ def show_spines(ax, spines):
|
||||
ax.yaxis.set_ticks_position('both')
|
||||
|
||||
|
||||
def set_xlabel(ax, label, unit=None, **kwargs):
|
||||
def axis_label(label, unit=None):
|
||||
""" Format an axis label from a label and a unit
|
||||
|
||||
Parameters
|
||||
----------
|
||||
label: string
|
||||
The name of the axis.
|
||||
unit: string
|
||||
The unit of the axis values.
|
||||
|
||||
Returns
|
||||
-------
|
||||
label: string
|
||||
An axis label formatted from `label` and `unit`.
|
||||
"""
|
||||
if not unit:
|
||||
ax.set_xlabel(label, **kwargs)
|
||||
return label
|
||||
elif xkcd_style:
|
||||
ax.set_xlabel('%s / %s' % (label, unit), **kwargs)
|
||||
return '%s / %s' % (label, unit)
|
||||
else:
|
||||
ax.set_xlabel('%s [%s]' % (label, unit), **kwargs)
|
||||
return '%s [%s]' % (label, unit)
|
||||
|
||||
|
||||
def set_xlabel(ax, label, unit=None, **kwargs):
|
||||
""" Format the xlabel from a label and an unit.
|
||||
|
||||
Uses the axis_label() function to format the axis label.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
label: string
|
||||
The name of the axis.
|
||||
unit: string
|
||||
The unit of the axis values.
|
||||
kwargs: key-word arguments
|
||||
Further arguments passed on to the set_xlabel() function.
|
||||
"""
|
||||
ax.set_xlabel_orig(axis_label(label, unit), **kwargs)
|
||||
|
||||
|
||||
def set_ylabel(ax, label, unit=None, **kwargs):
|
||||
if not unit:
|
||||
ax.set_ylabel(label, **kwargs)
|
||||
elif xkcd_style:
|
||||
ax.set_ylabel('%s / %s' % (label, unit), **kwargs)
|
||||
else:
|
||||
ax.set_ylabel('%s [%s]' % (label, unit), **kwargs)
|
||||
""" Format the ylabel from a label and an unit.
|
||||
|
||||
Uses the axis_label() function to format the axis label.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
label: string
|
||||
The name of the axis.
|
||||
unit: string
|
||||
The unit of the axis values.
|
||||
kwargs: key-word arguments
|
||||
Further arguments passed on to the set_ylabel() function.
|
||||
"""
|
||||
ax.set_ylabel_orig(axis_label(label, unit), **kwargs)
|
||||
|
||||
|
||||
def set_zlabel(ax, label, unit=None, **kwargs):
|
||||
if not unit:
|
||||
ax.set_zlabel(label, **kwargs)
|
||||
elif xkcd_style:
|
||||
ax.set_zlabel('%s / %s' % (label, unit), **kwargs)
|
||||
else:
|
||||
ax.set_zlabel('%s [%s]' % (label, unit), **kwargs)
|
||||
""" Format the zlabel from a label and an unit.
|
||||
|
||||
Uses the axis_label() function to format the axis label.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
label: string
|
||||
The name of the axis.
|
||||
unit: string
|
||||
The unit of the axis values.
|
||||
kwargs: key-word arguments
|
||||
Further arguments passed on to the set_zlabel() function.
|
||||
"""
|
||||
ax.set_zlabel_orig(axis_label(label, unit), **kwargs)
|
||||
|
||||
|
||||
# overwrite set_[xy]label member functions:
|
||||
mpl.axes.Axes.set_xlabel_orig = mpl.axes.Axes.set_xlabel
|
||||
mpl.axes.Axes.set_xlabel = set_xlabel
|
||||
mpl.axes.Axes.set_ylabel_orig = mpl.axes.Axes.set_ylabel
|
||||
mpl.axes.Axes.set_ylabel = set_ylabel
|
||||
Axes3D.set_zlabel_orig = Axes3D.set_zlabel
|
||||
Axes3D.set_zlabel = set_zlabel
|
||||
|
||||
|
||||
# initialization:
|
||||
if xkcd_style:
|
||||
plt.xkcd()
|
||||
|
@ -23,8 +23,8 @@ def plot_data(ax, x, y, c):
|
||||
ax.plot(xx, cc*xx**3.0, color=colors['orange'], lw=1.5, zorder=5)
|
||||
|
||||
show_spines(ax, 'lb')
|
||||
set_xlabel(ax, 'Size x', 'm')
|
||||
set_ylabel(ax, 'Weight y', 'kg')
|
||||
ax.set_xlabel('Size x', 'm')
|
||||
ax.set_ylabel('Weight y', 'kg')
|
||||
ax.set_xlim(2, 4)
|
||||
ax.set_ylim(0, 400)
|
||||
ax.set_xticks(np.arange(2.0, 4.1, 0.5))
|
||||
@ -33,8 +33,8 @@ def plot_data(ax, x, y, c):
|
||||
|
||||
def plot_data_errors(ax, x, y, c):
|
||||
show_spines(ax, 'lb')
|
||||
set_xlabel(ax, 'Size x', 'm')
|
||||
#set_ylabel(ax, 'Weight y', 'kg')
|
||||
ax.set_xlabel('Size x', 'm')
|
||||
#ax.set_ylabel('Weight y', 'kg')
|
||||
ax.set_xlim(2, 4)
|
||||
ax.set_ylim(0, 400)
|
||||
ax.set_xticks(np.arange(2.0, 4.1, 0.5))
|
||||
@ -57,8 +57,8 @@ def plot_data_errors(ax, x, y, c):
|
||||
|
||||
def plot_error_hist(ax, x, y, c):
|
||||
show_spines(ax, 'lb')
|
||||
set_xlabel(ax, 'Squared error')
|
||||
set_ylabel(ax, 'Frequency')
|
||||
ax.set_xlabel('Squared error')
|
||||
ax.set_ylabel('Frequency')
|
||||
bins = np.arange(0.0, 1250.0, 100)
|
||||
ax.set_xlim(bins[0], bins[-1])
|
||||
#ax.set_ylim(0, 35)
|
||||
|
@ -23,8 +23,8 @@ if __name__ == "__main__":
|
||||
ax.plot(xx, cc*xx**3.0, color=colors['orange'], lw=2, zorder=5)
|
||||
|
||||
show_spines(ax, 'lb')
|
||||
set_xlabel(ax, 'Size x', 'm')
|
||||
set_ylabel(ax, 'Weight y', 'kg')
|
||||
ax.set_xlabel('Size x', 'm')
|
||||
ax.set_ylabel('Weight y', 'kg')
|
||||
ax.set_xlim(2, 4)
|
||||
ax.set_ylim(0, 400)
|
||||
ax.set_xticks(np.arange(2.0, 4.1, 0.5))
|
||||
|
@ -51,8 +51,8 @@ def plot_mse(ax, x, y, c, cs):
|
||||
|
||||
|
||||
show_spines(ax, 'lb')
|
||||
set_xlabel(ax, 'c')
|
||||
set_ylabel(ax, 'Mean squared error')
|
||||
ax.set_xlabel('c')
|
||||
ax.set_ylabel('Mean squared error')
|
||||
ax.set_xlim(0, 10)
|
||||
ax.set_ylim(0, 25000)
|
||||
ax.set_xticks(np.arange(0.0, 10.1, 2.0))
|
||||
@ -62,8 +62,8 @@ def plot_descent(ax, cs, mses):
|
||||
ax.plot(np.arange(len(mses))+1, mses, '-o', c=colors['red'], mew=0, ms=8)
|
||||
|
||||
show_spines(ax, 'lb')
|
||||
set_xlabel(ax, 'Iteration')
|
||||
#set_ylabel(ax, 'Mean squared error')
|
||||
ax.set_xlabel('Iteration')
|
||||
#ax.set_ylabel('Mean squared error')
|
||||
ax.set_xlim(0, 10.5)
|
||||
ax.set_ylim(0, 25000)
|
||||
ax.set_xticks(np.arange(0.0, 10.1, 2.0))
|
||||
@ -75,7 +75,7 @@ if __name__ == "__main__":
|
||||
x, y, c = create_data()
|
||||
cs, mses = gradient_descent(x, y)
|
||||
fig, (ax1, ax2) = plt.subplots(1, 2)
|
||||
fig.subplots_adjust(wspace=0.2, **adjust_fs(left=7.5, right=0.5))
|
||||
fig.subplots_adjust(wspace=0.2, **adjust_fs(left=8.0, right=0.5))
|
||||
plot_mse(ax1, x, y, c, cs)
|
||||
plot_descent(ax2, cs, mses)
|
||||
fig.savefig("cubicmse.pdf")
|
||||
|
@ -16,9 +16,9 @@ def create_data():
|
||||
|
||||
|
||||
def plot_error_plane(ax, x, y, m, n):
|
||||
set_xlabel(ax, 'Slope m')
|
||||
set_ylabel(ax, 'Intercept b')
|
||||
set_zlabel(ax, 'Mean squared error')
|
||||
ax.set_xlabel('Slope m')
|
||||
ax.set_ylabel('Intercept b')
|
||||
ax.set_zlabel('Mean squared error')
|
||||
ax.set_xlim(-4.5, 5.0)
|
||||
ax.set_ylim(-60.0, -20.0)
|
||||
ax.set_zlim(0.0, 700.0)
|
||||
|
@ -18,8 +18,8 @@ x = np.linspace(x1, x2, 200)
|
||||
y = np.linspace(x1, x2, 200)
|
||||
X, Y = np.meshgrid(x, y)
|
||||
Z = gaussian(X, Y)
|
||||
set_xlabel(ax, 'x')
|
||||
set_ylabel(ax, 'y', rotation='horizontal')
|
||||
ax.set_xlabel('x')
|
||||
ax.set_ylabel('y', rotation='horizontal')
|
||||
ax.set_xticks(np.arange(x1, x2+0.1, 1.0))
|
||||
ax.set_yticks(np.arange(x1, x2+0.1, 1.0))
|
||||
ax.contour(X, Y, Z, linewidths=2, zorder=0)
|
||||
|
@ -16,8 +16,8 @@ def create_data():
|
||||
def plot_data(ax, x, y):
|
||||
ax.scatter(x, y, marker='o', color=colors['blue'], s=40)
|
||||
show_spines(ax, 'lb')
|
||||
set_xlabel(ax, 'Input x')
|
||||
set_ylabel(ax, 'Output y')
|
||||
ax.set_xlabel('Input x')
|
||||
ax.set_ylabel('Output y')
|
||||
ax.set_xlim(0, 120)
|
||||
ax.set_ylim(-80, 80)
|
||||
ax.set_xticks(np.arange(0,121, 40))
|
||||
@ -30,8 +30,8 @@ def plot_data_slopes(ax, x, y, m, n):
|
||||
for i in np.linspace(0.3*m, 2.0*m, 5):
|
||||
ax.plot(xx, i*xx+n, color=colors['red'], lw=2)
|
||||
show_spines(ax, 'lb')
|
||||
set_xlabel(ax, 'Input x')
|
||||
#set_ylabel(ax, 'Output y')
|
||||
ax.set_xlabel('Input x')
|
||||
#ax.set_ylabel('Output y')
|
||||
ax.set_xlim(0, 120)
|
||||
ax.set_ylim(-80, 80)
|
||||
ax.set_xticks(np.arange(0,121, 40))
|
||||
@ -44,8 +44,8 @@ def plot_data_intercepts(ax, x, y, m, n):
|
||||
for i in np.linspace(n-1*n, n+1*n, 5):
|
||||
ax.plot(xx, m*xx + i, color=colors['red'], lw=2)
|
||||
show_spines(ax, 'lb')
|
||||
set_xlabel(ax, 'Input x')
|
||||
#set_ylabel(ax, 'Output y')
|
||||
ax.set_xlabel('Input x')
|
||||
#ax.set_ylabel('Output y')
|
||||
ax.set_xlim(0, 120)
|
||||
ax.set_ylim(-80, 80)
|
||||
ax.set_xticks(np.arange(0,121, 40))
|
||||
|
@ -15,8 +15,8 @@ def create_data():
|
||||
|
||||
def plot_data(ax, x, y, m, n):
|
||||
show_spines(ax, 'lb')
|
||||
set_xlabel(ax, 'Input x')
|
||||
set_ylabel(ax, 'Output y')
|
||||
ax.set_xlabel('Input x')
|
||||
ax.set_ylabel('Output y')
|
||||
ax.set_xlim(0, 120)
|
||||
ax.set_ylim(-80, 80)
|
||||
ax.set_xticks(np.arange(0,121, 40))
|
||||
|
@ -21,8 +21,8 @@ if __name__ == "__main__":
|
||||
ax1 = fig.add_subplot(spec[0, 0])
|
||||
show_spines(ax1, 'lb')
|
||||
ax1.scatter(indices, data, c=colors['blue'], edgecolor='white', s=50)
|
||||
set_xlabel(ax1, 'Index')
|
||||
set_ylabel(ax1, 'Weight', 'kg')
|
||||
ax1.set_xlabel('Index')
|
||||
ax1.set_ylabel('Weight', 'kg')
|
||||
ax1.set_xlim(-10, 310)
|
||||
ax1.set_ylim(0, 370)
|
||||
ax1.set_yticks(np.arange(0, 351, 100))
|
||||
@ -35,7 +35,7 @@ if __name__ == "__main__":
|
||||
bw = 20.0
|
||||
h, b = np.histogram(data, np.arange(0, 401, bw))
|
||||
ax2.barh(b[:-1], h/np.sum(h)/(b[1]-b[0]), fc=colors['yellow'], height=bar_fac*bw, align='edge')
|
||||
set_xlabel(ax2, 'Pdf', '1/kg')
|
||||
ax2.set_xlabel('Pdf', '1/kg')
|
||||
ax2.set_xlim(0, 0.012)
|
||||
ax2.set_xticks([0, 0.005, 0.01])
|
||||
ax2.set_xticklabels(['0', '0.005', '0.01'])
|
||||
|
@ -26,8 +26,8 @@ if __name__ == "__main__":
|
||||
show_spines(ax1, 'lb')
|
||||
ax1.plot(xx, yy, colors['red'], lw=2)
|
||||
ax1.scatter(x, y, c=colors['blue'], edgecolor='white', s=50)
|
||||
set_xlabel(ax1, 'Hair deflection', 'nm')
|
||||
set_ylabel(ax1, 'Conductance', 'nS')
|
||||
ax1.set_xlabel('Hair deflection', 'nm')
|
||||
ax1.set_ylabel('Conductance', 'nS')
|
||||
ax1.set_xlim(-20, 20)
|
||||
ax1.set_ylim(-1.5, 9.5)
|
||||
ax1.set_xticks(np.arange(-20.0, 21.0, 10.0))
|
||||
@ -41,8 +41,8 @@ if __name__ == "__main__":
|
||||
bw = 0.25
|
||||
h, b = np.histogram(y-boltzmann(x, x0, k), np.arange(-3.0, 3.01, bw))
|
||||
ax2.bar(b[:-1], h/np.sum(h)/(b[1]-b[0]), fc=colors['yellow'], width=bar_fac*bw, align='edge')
|
||||
set_xlabel(ax2, 'Residuals', 'nS')
|
||||
set_ylabel(ax2, 'Pdf', '1/nS')
|
||||
ax2.set_xlabel('Residuals', 'nS')
|
||||
ax2.set_ylabel('Pdf', '1/nS')
|
||||
ax2.set_xlim(-2.5, 2.5)
|
||||
ax2.set_ylim(0, 0.75)
|
||||
ax2.set_yticks(np.arange(0, 0.75, 0.2))
|
||||
|
Reference in New Issue
Block a user