"""Illustrate the squared-error criterion behind linear least squares.

Run as a script, this module draws noisy samples around a straight line,
marks the vertical deviations ("errors") of a few samples from that line,
and plots a histogram of the squared errors with their mean annotated.
The figure is saved as ``linear_least_squares.pdf``.
"""

import matplotlib.pyplot as plt
import numpy as np


def create_data():
    """Generate reproducible noisy samples of the line ``y = m*x + n``.

    Returns
    -------
    x : ndarray
        100 input values: the 40 points of ``arange(10, 110, 2.5)``
        followed by the 60 points of ``arange(0, 120, 2)``.
    y : ndarray
        ``m*x + n`` plus Gaussian noise with standard deviation 15.
        A fixed seed is used, so every call returns the same data.
    m : float
        Slope of the underlying line (0.75).
    n : int
        Intercept of the underlying line (-40).
    """
    m = 0.75
    n = -40
    x = np.concatenate((np.arange(10.0, 110.0, 2.5),
                        np.arange(0.0, 120.0, 2.0)))
    y = m * x + n
    # A private, fixed-seed generator keeps the figure reproducible
    # without touching NumPy's global random state.
    rng = np.random.RandomState(37281)
    noise = rng.randn(len(x)) * 15
    y += noise
    return x, y, m, n


def _style_axes(ax):
    """Hide the top/right spines and put outward ticks on left/bottom."""
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    ax.tick_params(direction="out", width=1.25)


def plot_data(ax, x, y, m, n):
    """Scatter the data, draw the line ``y = m*x + n`` and a few errors.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into.
    x, y : ndarray
        Sample coordinates, e.g. as returned by `create_data`.
        Only the first 40 samples are scattered. The samples at the
        hard-coded indices below are highlighted, and their vertical
        deviation from the line is drawn.
    m, n : float
        Slope and intercept of the line to draw.
    """
    _style_axes(ax)
    ax.set_xlabel('Input x')
    ax.set_ylabel('Output y')
    ax.set_xlim(0, 120)
    ax.set_ylim(-80, 80)
    ax.set_xticks(np.arange(0, 121, 40))
    ax.set_yticks(np.arange(-80, 81, 40))
    # The "Error" label's arrow points next to sample 34, which is one
    # of the highlighted samples whose error bar is drawn below.
    ax.annotate('Error',
                xy=(x[34] + 1, y[34] + 15), xycoords='data',
                xytext=(80, -50), textcoords='data', ha='left',
                arrowprops=dict(arrowstyle="->", relpos=(0.9, 1.0),
                                connectionstyle="angle3,angleA=50,angleB=-30"))
    ax.scatter(x[:40], y[:40], color='b', s=10, zorder=0)
    # Samples drawn larger, with their deviation from the line shown.
    inxs = [3, 13, 16, 19, 25, 34, 36]
    ax.scatter(x[inxs], y[inxs], color='b', s=40, zorder=10)
    xx = np.asarray([2, 118])
    ax.plot(xx, m * xx + n, color='#CC0000', lw=2)
    for i in inxs:
        # Vertical segment from the line value to the observed value.
        xx = [x[i], x[i]]
        yy = [m * x[i] + n, y[i]]
        ax.plot(xx, yy, color='#FF9900', lw=2, zorder=5)


def plot_error_hist(ax, x, y, m, n):
    """Plot a histogram of squared errors and annotate their mean.

    The squared error of each sample is ``(y - (m*x + n))**2``, using
    all samples. The histogram bins span 0 to 600 in steps of 50.
    Errors outside that range are not counted in the bars, but they
    still contribute to the annotated mean.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into.
    x, y : ndarray
        Sample coordinates.
    m, n : float
        Slope and intercept of the reference line.
    """
    _style_axes(ax)
    ax.set_xlabel('Squared error')
    ax.set_ylabel('Frequency')
    bins = np.arange(0.0, 602.0, 50.0)
    ax.set_xlim(bins[0], bins[-1])
    ax.set_ylim(0, 35)
    ax.set_xticks(np.arange(bins[0], bins[-1], 100))
    ax.set_yticks(np.arange(0, 36, 10))
    errors = (y - (m * x + n)) ** 2.0
    mls = np.mean(errors)
    ax.annotate('Mean\nsquared\nerror',
                xy=(mls, 0.5), xycoords='data',
                xytext=(350, 20), textcoords='data', ha='left',
                arrowprops=dict(arrowstyle="->", relpos=(0.0, 0.2),
                                connectionstyle="angle3,angleA=10,angleB=90"))
    ax.hist(errors, bins, color='#FF9900')


if __name__ == "__main__":
    x, y, m, n = create_data()
    # Apply the xkcd sketch style only while building and saving this
    # figure; the context manager restores the previous rcParams.
    with plt.xkcd():
        fig = plt.figure()
        ax = fig.add_subplot(1, 2, 1)
        plot_data(ax, x, y, m, n)
        ax = fig.add_subplot(1, 2, 2)
        plot_error_hist(ax, x, y, m, n)
        fig.set_facecolor("white")
        fig.set_size_inches(7., 2.6)
        fig.tight_layout()
        fig.savefig("linear_least_squares.pdf")
    plt.close(fig)