83 lines
2.4 KiB
Python
83 lines
2.4 KiB
Python
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from plotstyle import *
|
|
|
|
def create_data():
    """Generate a reproducible noisy data set that follows a cubic law.

    The x values run from 2.2 up to (but excluding) 3.9 in steps of 0.05.
    The y values are ``c * x**3`` plus Gaussian noise with standard
    deviation 50. The noise comes from a ``RandomState`` with a fixed seed,
    so every call returns identical data.

    Returns
    -------
    x : ndarray
        Sample positions.
    y : ndarray
        Noisy cubic values, one per entry of ``x``.
    c : int
        The coefficient used to generate ``y`` (6).
    """
    # Scale motivated by Wikipedia: "Generally, males vary in total length
    # from 250 to 390 cm and weigh between 90 and 306 kg".
    coeff = 6
    xs = np.arange(2.2, 3.9, 0.05)
    generator = np.random.RandomState(32281)
    ys = coeff * xs**3.0 + generator.randn(xs.size) * 50
    return xs, ys, coeff
|
|
|
|
def gradient_descent(x, y, n=20, dc=0.01, eps=0.0001, c0=1.1):
    """Fit the coefficient c of ``y = c * x**3`` by gradient descent on the MSE.

    The derivative of the mean squared error with respect to c is
    approximated by a forward difference with step ``dc``. Each iteration
    then moves c by ``-eps`` times that estimate.

    Parameters
    ----------
    x, y : array_like
        Data points (converted with ``np.asarray``).
    n : int, optional
        Number of iterations (default 20).
    dc : float, optional
        Step of the forward-difference derivative (default 0.01).
    eps : float, optional
        Learning rate (default 0.0001).
    c0 : float, optional
        Initial value of c (default 1.1).

    Returns
    -------
    cs : list
        Value of c at the start of each iteration (``cs[0] == c0``).
    mses : list
        Mean squared error evaluated at the corresponding entry of ``cs``.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    # x**3 does not depend on c: compute it once instead of twice per step.
    x3 = x**3.0
    cc = c0
    cs = []
    mses = []
    for _ in range(n):
        m0 = np.mean((y - cc*x3)**2.0)
        m1 = np.mean((y - (cc + dc)*x3)**2.0)
        # Forward-difference estimate of d(MSE)/dc.
        dmdc = (m1 - m0)/dc
        cs.append(cc)
        mses.append(m0)
        cc -= eps*dmdc
    return cs, mses
|
|
|
|
def plot_mse(ax, x, y, c, cs):
    """Plot the mean squared error of ``y = c * x**3`` as a function of c.

    The function draws four things:

    - the MSE curve for c in [0.5, 10];
    - the MSE of every value in ``cs`` in red;
    - the last of those in orange;
    - arrows between the first few consecutive steps.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into.
    x, y : array_like
        Data points.
    c : number
        Not used by this function; kept for interface compatibility.
    cs : sequence of float
        Values of c visited by the gradient descent, in order.
    """
    x3 = np.asarray(x)**3.0
    y = np.asarray(y)

    def mse(cvals):
        # Broadcast each candidate c against all data points: one MSE per c.
        cvals = np.asarray(cvals, dtype=float)
        return np.mean((y - cvals[:, np.newaxis]*x3)**2.0, axis=-1)

    cs = np.asarray(cs, dtype=float)
    ms = mse(cs)
    ccs = np.linspace(0.5, 10.0, 200)
    mses = mse(ccs)

    ax.plot(ccs, mses, colors['blue'], lw=2, zorder=10)
    ax.scatter(cs, ms, color=colors['red'], s=40, zorder=20)
    if len(cs) > 0:
        ax.scatter(cs[-1], ms[-1], color=colors['orange'], s=60, zorder=30)
    # Arrows for at most the first four steps; fewer when cs is shorter,
    # so that short descents do not index past the end of cs.
    for i in range(min(4, len(cs) - 1)):
        ax.annotate('',
                    xy=(cs[i+1]+0.2, ms[i+1]), xycoords='data',
                    xytext=(cs[i]+0.3, ms[i]+200), textcoords='data', ha='left',
                    arrowprops=dict(arrowstyle="->", relpos=(0.0, 0.0),
                                    connectionstyle="angle3,angleA=10,angleB=70"))

    show_spines(ax, 'lb')
    ax.set_xlabel('c')
    ax.set_ylabel('Mean squared error')
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 25000)
    ax.set_xticks(np.arange(0.0, 10.1, 2.0))
    # NOTE(review): this places a tick at 30000. Matplotlib's Axis.set_ticks
    # expands the view limits to include all given ticks, so the effective
    # upper y-limit is likely 30000 rather than the 25000 set above —
    # confirm which is intended.
    ax.set_yticks(np.arange(0, 30001, 10000))
|
|
|
|
def plot_descent(ax, cs, mses):
    """Plot the mean squared error against the gradient-descent iteration.

    Iterations are numbered starting at 1. Only iterations up to 10 fall
    inside the x-range set here. The y tick labels are hidden and no
    y-label is drawn.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into.
    cs : sequence
        Not used by this function.
    mses : sequence of float
        Mean squared error recorded at each iteration.
    """
    steps = np.arange(1, len(mses) + 1)
    ax.plot(steps, mses, '-o', c=colors['red'], mew=0, ms=8)
    show_spines(ax, 'lb')
    ax.set_xlabel('Iteration')
    # No y-label: presumably this panel shares its y scale with a
    # neighbouring plot that carries the label — confirm.
    ax.set_xlim(0, 10.5)
    ax.set_xticks(np.arange(0.0, 10.1, 2.0))
    # The y-limit must be set before the y-ticks (set_yticks may widen it).
    ax.set_ylim(0, 25000)
    ax.set_yticks(np.arange(0, 30001, 10000))
    ax.set_yticklabels([])
|
|
|
|
|
|
if __name__ == "__main__":
    # Noisy cubic data and the sequence of c values visited while fitting it.
    x, y, c = create_data()
    cs, mses = gradient_descent(x, y)

    # Two side-by-side panels: MSE as a function of c (left) and MSE per
    # iteration of the descent (right).
    fig, axs = plt.subplots(1, 2)
    fig.subplots_adjust(wspace=0.2, **adjust_fs(left=8.0, right=0.5))
    plot_mse(axs[0], x, y, c, cs)
    plot_descent(axs[1], cs, mses)

    fig.savefig("cubicmse.pdf")
    plt.close(fig)
|