import matplotlib.pyplot as plt
import numpy as np

def create_data():
    """Generate a reproducible, noisy straight-line data set.

    The samples come from two sampling grids laid end to end. The first has
    40 points from 10 to 107.5 in steps of 2.5. The second has 60 points from
    0 to 118 in steps of 2. Each y value is the line ``m * x + n`` plus
    Gaussian noise with standard deviation 15. The noise is drawn from a
    fixed-seed generator, so every call returns the same arrays.

    Returns
    -------
    x : ndarray
        The 100 input values.
    y : ndarray
        Noisy outputs, one per input value.
    m : float
        Slope of the underlying line (0.75).
    n : int
        Intercept of the underlying line (-40).
    """
    slope, intercept = 0.75, -40
    grid = np.hstack([np.arange(10., 110., 2.5), np.arange(0., 120., 2.0)])
    # Fixed seed keeps the figure identical between runs.
    generator = np.random.RandomState(37281)
    observed = slope * grid + intercept + generator.randn(grid.size) * 15
    return grid, observed, slope, intercept

    
def plot_data(ax, x, y, m, n):
    """Draw the data panel: samples, the line, and highlighted residuals.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into.
    x, y : ndarray
        Sample coordinates. Indices up to 39 are used (``x[:40]`` is
        scattered and index 36 is highlighted), so at least 37 samples are
        required.
    m, n : float
        Slope and intercept of the line drawn through the data.
    """
    # Minimal frame: hide the top/right spines and keep outward-pointing
    # ticks only on the left and bottom axes.
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    ax.tick_params(direction="out", width=1.25)
    ax.set_xlabel('Input x')
    ax.set_ylabel('Output y')
    ax.set_xlim(0, 120)
    ax.set_ylim(-80, 80)
    ax.set_xticks(np.arange(0,121, 40))
    ax.set_yticks(np.arange(-80,81, 40))
    # Label one residual: the arrow ends just right of and above sample 34,
    # which is also one of the highlighted samples below.
    ax.annotate('Error',
                xy=(x[34]+1, y[34]+15), xycoords='data',
                xytext=(80, -50), textcoords='data', ha='left',
                arrowprops=dict(arrowstyle="->", relpos=(0.9,1.0),
                connectionstyle="angle3,angleA=50,angleB=-30") )
    # Small background markers for the first 40 samples (drawn lowest).
    ax.scatter(x[:40], y[:40], color='b', s=10, zorder=0)
    # Larger markers for the samples whose residuals are shown (drawn on top).
    inxs = [3, 13, 16, 19, 25, 34, 36]
    ax.scatter(x[inxs], y[inxs], color='b', s=40, zorder=10)
    # The line y = m*x + n across nearly the full x range.
    xx = np.asarray([2, 118])
    ax.plot(xx, m*xx+n, color='#CC0000', lw=2)
    # Vertical residual segments from the line to each highlighted sample,
    # layered between the background points and the highlighted markers.
    for i in inxs :
        xx = [x[i], x[i]]
        yy = [m*x[i]+n, y[i]]
        ax.plot(xx, yy, color='#FF9900', lw=2, zorder=5)
    

def plot_error_hist(ax, x, y, m, n):
    """Draw a histogram of squared residuals and mark their mean.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into.
    x, y : ndarray
        Sample coordinates (equal length).
    m, n : float
        Slope and intercept of the line the residuals are measured from.
    """
    # Minimal frame: hide the top/right spines and keep outward-pointing
    # ticks only on the left and bottom axes.
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    ax.tick_params(direction="out", width=1.25)
    ax.set_xlabel('Squared error')
    ax.set_ylabel('Frequency')
    bins = np.arange(0.0, 602.0, 50.0)  # bin edges 0, 50, ..., 600
    ax.set_xlim(bins[0], bins[-1])
    ax.set_ylim(0, 35)
    # NOTE(review): arange excludes bins[-1], so ticks stop at 500 even though
    # the axis runs to 600 -- confirm a tick at 600 is deliberately omitted.
    ax.set_xticks(np.arange(bins[0], bins[-1], 100))
    ax.set_yticks(np.arange(0, 36, 10))
    errors = (y-(m*x+n))**2.0
    # The mean uses every sample, including any beyond the last bin edge.
    mls = np.mean(errors)
    ax.annotate('Mean\nsquared\nerror',
                xy=(mls, 0.5), xycoords='data',
                xytext=(350, 20), textcoords='data', ha='left',
                arrowprops=dict(arrowstyle="->", relpos=(0.0,0.2),
                connectionstyle="angle3,angleA=10,angleB=90") )
    # With explicit bin edges, squared errors above bins[-1] lie outside the
    # bin range and are not counted in any bar.
    ax.hist(errors, bins, color='#FF9900')



if __name__ == "__main__":
    # Build the synthetic data once; both panels share it.
    xs, ys, slope, intercept = create_data()

    # Switch pyplot to the hand-drawn "xkcd" style before building the figure.
    plt.xkcd()

    fig = plt.figure()
    data_ax = fig.add_subplot(1, 2, 1)   # left: samples, line, residuals
    plot_data(data_ax, xs, ys, slope, intercept)
    hist_ax = fig.add_subplot(1, 2, 2)   # right: squared-error histogram
    plot_error_hist(hist_ax, xs, ys, slope, intercept)

    # Final figure styling, export to PDF, then release the figure.
    fig.set_facecolor("white")
    fig.set_size_inches(7., 2.6)
    fig.tight_layout()
    fig.savefig("linear_least_squares.pdf")
    plt.close()