from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np

def create_data():
    """Generate reproducible noisy samples of the line y = m*x + n.

    Returns
    -------
    x : ndarray
        Sample positions 10, 12.5, ..., 107.5 (40 values).
    y : ndarray
        ``m*x + n`` plus Gaussian noise scaled by 15, drawn from a
        fixed-seed ``RandomState`` so every call returns the same data.
    m : float
        True slope of the line (0.75).
    n : int
        True intercept of the line (-40).
    """
    slope, intercept = 0.75, -40
    xs = np.arange(10., 110., 2.5)
    generator = np.random.RandomState(37281)
    noise = 15 * generator.randn(xs.size)
    ys = slope * xs + intercept + noise
    return xs, ys, slope, intercept
    

def plot_error_plane(ax, x, y, m, n):
    """Draw the mean-squared-error surface of a straight-line fit.

    The error ``mean((y - s*x - b)**2)`` of the data (x, y) is evaluated on a
    40 x 40 grid of slopes ``s`` in [-4.5, 5] and intercepts ``b`` in
    [-60, -20] and drawn as a surface. The grid point with the smallest
    error is marked with a red dot at z = 0.

    Parameters
    ----------
    ax : Axes3D
        Axes created with ``projection='3d'``. Labels, limits, ticks, pane
        colors and the view angle are set on it.
    x, y : array_like
        Data points the error is evaluated on; must have the same length.
    m, n : float
        True slope and intercept of the data. Not used by this function;
        kept so existing callers keep working.
    """
    ax.set_xlabel('Slope m')
    ax.set_ylabel('Intercept b')
    ax.set_zlabel('Mean squared error')
    ax.set_xlim(-4.5, 5.0)
    ax.set_ylim(-60.0, -20.0)
    # NOTE(review): this fixed z-range may clip the surface far from the
    # minimum, where the error of the data grows large -- confirm limits.
    ax.set_zlim(0.0, 700.0)
    ax.set_xticks(np.arange(-4, 5, 2))
    ax.set_yticks(np.arange(-60, -19, 10))
    ax.set_zticks(np.arange(0, 700, 200))
    ax.grid(True)
    # White panes. ``ax.w_xaxis`` etc. were deprecated aliases of
    # ``ax.xaxis`` etc. and no longer exist in current Matplotlib.
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color((1.0, 1.0, 1.0, 1.0))
    ax.invert_xaxis()
    ax.view_init(25, 40)

    slopes = np.linspace(-4.5, 5, 40)
    intercepts = np.linspace(-60, -20, 40)
    # Keep the parameter grid separate from the data: rebinding x and y to
    # the meshgrid would make the "error" independent of the data.
    slope_grid, intercept_grid = np.meshgrid(slopes, intercepts)
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Broadcast against a trailing data axis: error_surf[j, i] is the MSE for
    # slopes[i] and intercepts[j] (meshgrid's default 'xy' indexing).
    residuals = (y - slope_grid[..., np.newaxis] * x
                 - intercept_grid[..., np.newaxis])
    error_surf = np.mean(residuals**2, axis=-1)
    ax.plot_surface(slope_grid, intercept_grid, error_surf, rstride=1,
                    cstride=1, cmap=cm.coolwarm, linewidth=0, shade=True)
    # Mark the grid point with the smallest error on the floor of the plot.
    j_min, i_min = np.unravel_index(np.argmin(error_surf), error_surf.shape)
    ax.scatter(slopes[i_min], intercepts[j_min], [0.0], color='#cc0000', s=60)


if __name__ == "__main__":
    x, y, m, n = create_data()
    # Scope the xkcd rc settings to this figure instead of changing global
    # Matplotlib state; saving happens inside the context so the style is
    # still active while the figure is drawn.
    with plt.xkcd():
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1, projection='3d')
        plot_error_plane(ax, x, y, m, n)
        fig.set_size_inches(7., 5.)
        fig.tight_layout()
        fig.savefig("error_surface.pdf")
    plt.close(fig)