67 lines
2.2 KiB
Python
67 lines
2.2 KiB
Python
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from plotstyle import *
|
|
|
|
def create_data():
    """Generate reproducible, noisy samples that follow a cubic size/weight law.

    The value ranges are loosely modelled on a Wikipedia quote kept from the
    original source: "Generally, males vary in total length from 250 to 390 cm
    and weigh between 90 and 306 kg" (presumably from the tiger article —
    confirm).

    Returns
    -------
    x : ndarray
        Sizes from 2.2 up to, but excluding, 3.9 in steps of 0.05.
    y : ndarray
        ``c * x**3`` plus Gaussian noise with standard deviation 50. The noise
        comes from a ``RandomState`` seeded with 32281, so every call returns
        the same data.
    c : int
        The cubic scaling factor (6) used to build ``y``.
    """
    factor = 6
    sizes = np.arange(2.2, 3.9, 0.05)
    # A fixed seed keeps the generated figure identical from run to run.
    generator = np.random.RandomState(32281)
    scatter = 50 * generator.randn(sizes.size)
    weights = factor * sizes**3.0 + scatter
    return sizes, weights, factor
|
|
|
|
|
|
def plot_data_errors(ax, x, y, c):
    """Scatter the data, overlay the cubic model and draw selected residuals.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes. The two-argument ``set_xlabel``/``set_ylabel`` calls
        (label, unit) presumably rely on the ``plotstyle`` star import
        extending Axes to accept a unit string — confirm in plotstyle.
    x, y : ndarray
        Sample sizes and weights. Fancy indexing with a list is used, so
        these must be NumPy arrays.
    c : number
        Cubic factor of the model ``c * x**3``.

    Notes
    -----
    At most the first 40 samples are drawn as the background cloud. The
    hand-picked indices in ``picked`` (largest 33) and the annotated sample 28
    must be valid for ``x``/``y``.
    """
    def model(v):
        # Cubic law shared by the smooth curve and the residual segments.
        return c * v**3.0

    ax.set_xlabel('Size x', 'm')
    ax.set_ylabel('Weight y', 'kg')
    ax.set_xlim(2, 4)
    ax.set_ylim(0, 400)
    ax.set_xticks(np.arange(2.0, 4.1, 0.5))
    ax.set_yticks(np.arange(0, 401, 100))

    # Point an "Error" label at the residual of sample 28.
    arrow_props = dict(arrowstyle="->", relpos=(0.9, 1.0),
                       connectionstyle="angle3,angleA=50,angleB=-30")
    ax.annotate('Error', xy=(x[28] + 0.05, y[28] + 60), xycoords='data',
                xytext=(3.4, 70), textcoords='data', ha='left',
                arrowprops=arrow_props)

    # Background cloud drawn underneath everything else.
    ax.plot(x[:40], y[:40], zorder=0, **psAm)

    # Highlighted subset whose residuals are drawn below; kept on top.
    picked = [3, 10, 11, 17, 18, 21, 28, 30, 33]
    ax.plot(x[picked], y[picked], zorder=10, **psA)

    grid = np.linspace(2.1, 3.9, 100)
    ax.plot(grid, model(grid), **lsBm)

    # One vertical segment per picked sample, from the model value to the data.
    for xi, yi in zip(x[picked], y[picked]):
        ax.plot([xi, xi], [model(xi), yi], zorder=5, **lsDm)
|
|
|
|
|
|
def plot_error_hist(ax, x, y, c):
    """Plot a histogram of squared residuals and mark their mean.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    x, y : ndarray
        Sample sizes and weights.
    c : number
        Cubic factor of the model ``c * x**3``.

    Notes
    -----
    Bin edges run from 0 to 10500 in steps of 750. Under ``np.histogram``
    semantics, a squared error above the last edge is not counted in any bar,
    but it still contributes to the mean squared error annotated on the plot.
    """
    edges = np.arange(0.0, 11000.0, 750)

    ax.set_xlabel('Squared error')
    ax.set_ylabel('Frequency')
    ax.set_xlim(edges[0], edges[-1])
    ax.set_ylim(0, 15)
    ax.set_xticks(np.arange(edges[0], edges[-1], 5000))
    ax.set_yticks(np.arange(0, 16, 5))

    residuals = y - c * x**3.0
    squared = residuals**2.0
    mse = squared.mean()

    # Arrow from the text down to the MSE position just above the x-axis.
    ax.annotate('Mean\nsquared\nerror',
                xy=(mse, 0.5), xycoords='data',
                xytext=(4500, 6), textcoords='data', ha='left',
                arrowprops=dict(arrowstyle="->", relpos=(0.0, 0.2),
                                connectionstyle="angle3,angleA=10,angleB=90"))
    ax.hist(squared, edges, **fsC)
|
|
|
|
|
|
if __name__ == "__main__":
    # Left panel: data, model and residuals. Right panel: histogram of the
    # squared residuals. Both panels share the same (x, y, c) data.
    data = create_data()
    size = cm_size(figure_width, 0.9*figure_height)
    fig, axes = plt.subplots(1, 2, figsize=size)
    fig.subplots_adjust(wspace=0.5, **adjust_fs(left=6.0, right=1.2))
    left_ax, right_ax = axes
    plot_data_errors(left_ax, *data)
    plot_error_hist(right_ax, *data)
    fig.savefig("cubicerrors.pdf")
    # Close exactly the figure created above.
    plt.close(fig)
|