83 lines
2.8 KiB
Python
83 lines
2.8 KiB
Python
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.ticker as mt
|
|
from plotstyle import *
|
|
|
|
def create_data():
    """Generate reproducible noisy samples of a cubic law ``y = c * x**3``.

    The value ranges loosely follow figures quoted from Wikipedia: males
    generally vary in total length from 250 to 390 cm and weigh between
    90 and 306 kg.  NOTE(review): x presumably stands for a length in
    metres and y for a weight in kg; the code itself states no units.

    Returns
    -------
    x : ndarray
        Positions from 2.2 up to (excluding) 3.9 in steps of 0.05.
    y : ndarray
        ``c * x**3`` plus Gaussian noise with standard deviation 50.
    c : int
        The true scale factor used to generate ``y`` (6).
    """
    true_c = 6
    lengths = np.arange(2.2, 3.9, 0.05)
    # A fixed seed yields the same noise realisation on every run, so the
    # generated figure is reproducible.
    generator = np.random.RandomState(32281)
    weights = true_c * lengths**3.0 + generator.randn(len(lengths))*50
    return lengths, weights, true_c
|
|
|
|
|
|
def plot_mse(ax, x, y, c):
    """Plot the mean squared error of a cubic model as a function of c.

    For each candidate factor ``cc`` on a grid of 200 values from 0.5 to
    10.0, the mean squared error between ``y`` and ``cc * x**3`` is
    computed and drawn as a curve.  The grid point with the smallest error
    and the true parameter value ``c`` are marked and annotated.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw into.
    x, y : array_like
        Data points of equal length.
    c : float
        True parameter value, marked at height 500 on the cost axis.
    """
    ccs = np.linspace(0.5, 10.0, 200)
    # Evaluate the cost for all candidates at once by broadcasting:
    # rows index the candidate factor, columns index the data points.
    x3 = np.asarray(x)**3.0
    residuals = np.asarray(y)[np.newaxis, :] - ccs[:, np.newaxis]*x3
    mses = np.mean(residuals**2.0, axis=1)
    imin = np.argmin(mses)

    # lsAm / psB / psC are style keyword dicts from the plotstyle module.
    ax.plot(ccs, mses, **lsAm)
    ax.plot(c, 500.0, **psB)
    ax.plot(ccs[imin], mses[imin], **psC)
    ax.annotate('Minimum of\ncost\nfunction',
                xy=(ccs[imin], mses[imin]*1.2), xycoords='data',
                xytext=(4, 7000), textcoords='data', ha='left',
                arrowprops=dict(arrowstyle="->", relpos=(0.2, 0.0),
                                connectionstyle="angle3,angleA=10,angleB=90"))
    ax.text(2.2, 500, 'True\nparameter\nvalue')
    # Arrow from the label towards the true-parameter marker at (c, 500).
    ax.annotate('', xy=(c-0.2, 500), xycoords='data',
                xytext=(4.1, 700), textcoords='data', ha='left',
                arrowprops=dict(arrowstyle="->", relpos=(1.0, 0.0),
                                connectionstyle="angle3,angleA=-10,angleB=0"))
    ax.set_xlabel('c')
    ax.set_ylabel('Mean squared error')
    ax.set_xlim(2, 8.2)
    ax.set_ylim(0, 10000)
    ax.set_xticks(np.arange(2.0, 8.1, 2.0))
    ax.set_yticks(np.arange(0, 10001, 5000))
|
|
|
|
|
|
def plot_mse_min(ax, x, y, c):
    """Illustrate estimating the cost minimum from a coarse grid search.

    The full cost curve is drawn: the mean squared error of ``cc * x**3``
    against ``y``, for cc from 0.5 to 10.0 in steps of 0.05.  Every 25th
    candidate, starting at index 16, is marked as a sampled point.  The
    smallest sampled point is highlighted as the estimated minimum.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw into.  Its y tick labels are suppressed.
    x, y : array_like
        Data points of equal length.
    c : float
        True parameter value, marked at height 500 on the cost axis.
    """
    ccs = np.arange(0.5, 10.0, 0.05)
    # Broadcast over all candidates at once: rows index the candidate
    # factor, columns index the data points.
    x3 = np.asarray(x)**3.0
    residuals = np.asarray(y)[np.newaxis, :] - ccs[:, np.newaxis]*x3
    mses = np.mean(residuals**2.0, axis=1)

    # Coarse search grid: every di-th candidate, starting at index i0.
    di = 25
    i0 = 16
    # argmin indexes the subsampled array; map it back into the full grid.
    dimin = np.argmin(mses[i0::di])*di + i0

    # psB / lsAm / psAm / psD are style keyword dicts from plotstyle.
    ax.plot(c, 500.0, **psB)
    ax.plot(ccs, mses, **lsAm)
    ax.plot(ccs[i0::di], mses[i0::di], **psAm)
    ax.plot(ccs[dimin], mses[dimin], **psD)
    ax.annotate('Estimated\nminimum of\ncost\nfunction',
                xy=(ccs[dimin], mses[dimin]*1.2), xycoords='data',
                xytext=(4, 6700), textcoords='data', ha='left',
                arrowprops=dict(arrowstyle="->", relpos=(0.8, 0.0),
                                connectionstyle="angle3,angleA=0,angleB=85"))
    ax.set_xlabel('c')
    ax.set_xlim(2, 8.2)
    ax.set_ylim(0, 10000)
    ax.set_xticks(np.arange(2.0, 8.1, 2.0))
    ax.set_yticks(np.arange(0, 10001, 5000))
    # Hide y tick labels on this panel (ticks themselves remain).
    ax.yaxis.set_major_formatter(mt.NullFormatter())
|
|
|
|
|
|
if __name__ == "__main__":
    # Generate the synthetic data once and show it in two side-by-side
    # panels: the full cost function (left) and its coarse-grid
    # estimate (right).  The figure is written to cubiccost.pdf.
    data_x, data_y, true_c = create_data()
    fig, axes = plt.subplots(1, 2, figsize=cm_size(figure_width, 1.1*figure_height))
    fig.subplots_adjust(**adjust_fs(left=8.0, right=1.2))
    left_ax, right_ax = axes
    plot_mse(left_ax, data_x, data_y, true_c)
    plot_mse_min(right_ax, data_x, data_y, true_c)
    fig.savefig("cubiccost.pdf")
    plt.close()
|