57 lines
1.8 KiB
Python
57 lines
1.8 KiB
Python
import numpy as np
|
|
from mpl_toolkits.mplot3d import Axes3D
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.cm as cm
|
|
from plotstyle import *
|
|
|
|
def create_data():
    """Generate reproducible noisy samples of the line y = 0.75 * x - 40.

    The x values run from 10 to 107.5 in steps of 2.5. Gaussian noise with
    standard deviation 15 is added to y, drawn from a ``RandomState`` with a
    fixed seed so every call returns the same data.

    Returns
    -------
    x : ndarray
        Sample positions.
    y : ndarray
        Noisy line values at ``x``.
    m : float
        True slope used to generate the data (0.75).
    n : int
        True intercept used to generate the data (-40).
    """
    slope, intercept = 0.75, -40
    xs = np.arange(10., 110., 2.5)
    generator = np.random.RandomState(37281)
    noise = generator.randn(len(xs)) * 15
    ys = slope * xs + intercept + noise
    return xs, ys, slope, intercept
|
|
|
|
|
|
def plot_error_plane(ax, x, y, m, n):
    """Draw the mean-squared-error surface of a straight-line fit to (x, y).

    For every (slope, intercept) pair on a 40 x 40 grid spanning slopes
    [-4.5, 5] and intercepts [-60, -20], the mean squared error
    ``mean((y - slope * x - intercept) ** 2)`` over the data points is
    computed and drawn as a surface on the 3-D axes ``ax``. The grid point
    with the smallest error is marked with a red dot at z = 0.

    Parameters
    ----------
    ax : 3-D matplotlib axes (created with ``projection='3d'``)
        Axes to draw into.
    x, y : array_like
        Data points the error is evaluated on; must have the same shape
        and be non-empty.
    m, n : float
        True slope and intercept of the data. Not used by this function;
        kept so existing callers keep working.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` differ in shape or are empty.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.shape != y.shape:
        raise ValueError("x and y must have the same shape")
    if x.size == 0:
        raise ValueError("x and y must not be empty")
    x = x.ravel()
    y = y.ravel()

    slopes = np.linspace(-4.5, 5, 40)
    intercepts = np.linspace(-60, -20, 40)
    # Keep the parameter grids under their own names so the data arrays
    # x and y are not overwritten.
    slope_grid, intercept_grid = np.meshgrid(slopes, intercepts)
    # Broadcast over a trailing data axis: error_surf[j, i] is the MSE of
    # the line with slope slopes[i] and intercept intercepts[j].
    residuals = (y - slope_grid[..., np.newaxis] * x
                 - intercept_grid[..., np.newaxis])
    error_surf = np.mean(residuals ** 2, axis=-1)

    ax.set_xlabel('Slope m')
    ax.set_ylabel('Intercept b')
    ax.set_zlabel('Mean squared error')
    ax.set_xlim(-4.5, 5.0)
    ax.set_ylim(-60.0, -20.0)
    # The error range depends on the data, so size the z axis to the
    # computed surface instead of a fixed limit.
    ax.set_zlim(0.0, float(error_surf.max()))
    ax.set_xticks(np.arange(-4, 5, 2))
    ax.set_yticks(np.arange(-60, -19, 10))
    ax.grid(True)
    # Axes3D.w_xaxis/w_yaxis/w_zaxis were removed in Matplotlib 3.8;
    # xaxis/yaxis/zaxis are the supported accessors.
    white = (1.0, 1.0, 1.0, 1.0)
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color(white)
    ax.invert_xaxis()
    ax.view_init(25, 40)

    ax.plot_surface(slope_grid, intercept_grid, error_surf, rstride=1,
                    cstride=1, cmap='coolwarm', linewidth=0, shade=True)

    # Minimum of the sampled surface, drawn on the floor of the plot.
    j_min, i_min = np.unravel_index(np.argmin(error_surf), error_surf.shape)
    ax.scatter(slopes[i_min], intercepts[j_min], [0.0], color='#cc0000', s=60)
|
|
|
|
|
|
if __name__ == "__main__":
    # Plot the error surface of the synthetic line-fit data on a single
    # 3-D axes and save it as a 7 x 5 inch PDF.
    xdata, ydata, slope, intercept = create_data()
    figure = plt.figure()
    axes3d = figure.add_subplot(111, projection='3d')
    plot_error_plane(axes3d, xdata, ydata, slope, intercept)
    figure.set_size_inches(7., 5.)
    # Subplot margins come from adjust_fs (presumably provided by the
    # plotstyle star import -- confirm).
    margins = adjust_fs(figure, 1.0, 0.0, 0.0, 0.0)
    figure.subplots_adjust(**margins)
    figure.savefig("error_surface.pdf")
    plt.close(figure)
|