"""Figure ``cubicerrors.pdf``: a cubic size-weight relation and its errors.

The left panel shows noisy synthetic weight-vs-size data together with the
generating cubic ``y = c * x**3`` (red) and copies of it scaled by 0.25, 0.5,
2 and 4 (orange). The right panel marks the vertical deviations ("errors")
of a subset of the data points from the generating cubic.
"""

import matplotlib.pyplot as plt
import numpy as np
from plotstyle import *


def _cubic(x, c):
    """Return the model value ``c * x**3.0`` for a scalar or an array ``x``."""
    return c * x**3.0


def _format_size_axis(ax, *, show_ylabel):
    """Apply the axis styling shared by both size-vs-weight panels.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to style.
    show_ylabel : bool
        Whether to draw the 'Weight y' label on the y axis.
    """
    show_spines(ax, 'lb')
    # NOTE(review): the second argument looks like a unit string. Stock
    # matplotlib would interpret it as ``fontdict``, so this presumably relies
    # on plotstyle extending set_xlabel/set_ylabel -- confirm.
    ax.set_xlabel('Size x', 'm')
    if show_ylabel:
        ax.set_ylabel('Weight y', 'kg')
    ax.set_xlim(2, 4)
    ax.set_ylim(0, 400)
    ax.set_xticks(np.arange(2.0, 4.1, 0.5))
    ax.set_yticks(np.arange(0, 401, 100))


def create_data():
    """Generate noisy synthetic weight-vs-size data that follow a cubic law.

    Sizes start at 2.2 and increase in steps of 0.05, nominally stopping before
    3.9. Weights are ``c * x**3`` with ``c = 6`` plus Gaussian noise of standard
    deviation 50. The noise is drawn from a fixed seed, so every call returns
    the same data.

    Returns
    -------
    x : ndarray
        Sizes.
    y : ndarray
        Noisy weights.
    c : int
        Factor of the generating cubic.
    """
    # The value ranges loosely follow Wikipedia: "Generally, males vary in
    # total length from 250 to 390 cm and weigh between 90 and 306 kg".
    factor = 6
    sizes = np.arange(2.2, 3.9, 0.05)
    generator = np.random.RandomState(32281)
    weights = _cubic(sizes, factor) + generator.randn(len(sizes)) * 50
    return sizes, weights, factor


def plot_data(ax, x, y, c):
    """Scatter the data with the generating cubic and four rescaled cubics.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    x, y : ndarray
        Sizes and weights, e.g. as returned by ``create_data``.
    c : float
        Factor of the generating cubic. It is drawn in red; cubics with
        factors 0.25*c, 0.5*c, 2*c and 4*c are drawn in orange.
    """
    ax.scatter(x, y, marker='o', color=colors['blue'], s=40, zorder=10)
    grid = np.linspace(2.1, 3.9, 100)
    ax.plot(grid, _cubic(grid, c), color=colors['red'], lw=2, zorder=5)
    for scale in (0.25, 0.5, 2.0, 4.0):
        ax.plot(grid, _cubic(grid, scale * c), color=colors['orange'],
                lw=1.5, zorder=5)
    _format_size_axis(ax, show_ylabel=True)


def plot_data_errors(ax, x, y, c):
    """Show the vertical errors of selected data points from the cubic.

    All points are drawn small. A fixed subset is drawn large and connected
    to the generating cubic by vertical orange lines, and an arrow labelled
    'Error' points next to one of these lines. The y tick labels are
    suppressed and no y label is drawn.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    x, y : ndarray
        Sizes and weights. Must contain at least 34 points, because index 33
        is among the highlighted ones.
    c : float
        Factor of the generating cubic.
    """
    _format_size_axis(ax, show_ylabel=False)
    ax.set_yticklabels([])
    ax.annotate('Error', xy=(x[28] + 0.05, y[28] + 60), xycoords='data',
                xytext=(3.4, 70), textcoords='data', ha='left',
                arrowprops={'arrowstyle': "->",
                            'relpos': (0.9, 1.0),
                            'connectionstyle': "angle3,angleA=50,angleB=-30"})
    # Only the first 40 points are drawn; create_data yields fewer than that.
    ax.scatter(x[:40], y[:40], color=colors['blue'], s=10, zorder=0)
    highlighted = [3, 10, 11, 17, 18, 21, 28, 30, 33]
    ax.scatter(x[highlighted], y[highlighted], color=colors['blue'], s=40,
               zorder=10)
    grid = np.linspace(2.1, 3.9, 100)
    ax.plot(grid, _cubic(grid, c), color=colors['red'], lw=2)
    for i in highlighted:
        # Vertical segment from the model value up/down to the data point.
        ax.plot([x[i], x[i]], [_cubic(x[i], c), y[i]],
                color=colors['orange'], lw=2, zorder=5)


def plot_error_hist(ax, x, y, c):
    """Histogram the squared errors and mark their mean with an arrow.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    x, y : ndarray
        Sizes and weights.
    c : float
        Factor of the generating cubic.
    """
    show_spines(ax, 'lb')
    ax.set_xlabel('Squared error')
    ax.set_ylabel('Frequency')
    # NOTE(review): the bins end at 1200. With the noise scale used in
    # create_data (std 50), squared errors average around 2500, and values
    # outside the bin edges are not counted by hist. Check the range before
    # re-enabling this panel.
    bins = np.arange(0.0, 1250.0, 100)
    ax.set_xlim(bins[0], bins[-1])
    ax.set_xticks(np.arange(bins[0], bins[-1], 200))
    squared_errors = (y - _cubic(x, c))**2.0
    mean_squared_error = np.mean(squared_errors)
    ax.annotate('Mean\nsquared\nerror', xy=(mean_squared_error, 0.5),
                xycoords='data', xytext=(800, 3), textcoords='data', ha='left',
                arrowprops={'arrowstyle': "->",
                            'relpos': (0.0, 0.2),
                            'connectionstyle': "angle3,angleA=10,angleB=90"})
    ax.hist(squared_errors, bins, color=colors['orange'])


def main():
    """Render the two-panel figure and save it to ``cubicerrors.pdf``."""
    x, y, c = create_data()
    fig, (ax_data, ax_errors) = plt.subplots(1, 2)
    # adjust_fs comes from plotstyle; presumably it turns the margins into
    # subplots_adjust keyword arguments -- confirm.
    fig.subplots_adjust(wspace=0.2, **adjust_fs(left=6.0, right=1.2))
    plot_data(ax_data, x, y, c)
    plot_data_errors(ax_errors, x, y, c)
    # Alternative right panel: histogram of the squared errors.
    # plot_error_hist(ax_errors, x, y, c)
    fig.savefig("cubicerrors.pdf")
    plt.close()


if __name__ == "__main__":
    main()