import matplotlib.pyplot as plt
import numpy as np
from plotstyle import colors, show_spines

def create_data():
    """Generate reproducible noisy samples of a cubic size-weight relation.

    The size range is motivated by this quote from Wikipedia:
    "Generally, males vary in total length from 250 to 390 cm and
    weigh between 90 and 306 kg".

    Returns
    -------
    x : ndarray
        Sizes produced by ``np.arange(2.2, 3.9, 0.05)``.
    y : ndarray
        ``c * x**3`` plus Gaussian noise scaled by 50.
    c : int
        Coefficient of the cubic that generated the data (6).
    """
    c = 6
    sizes = np.arange(2.2, 3.9, 0.05)
    # Fixed seed: every call returns exactly the same noisy samples.
    rng = np.random.RandomState(32281)
    scatter = 50 * rng.randn(len(sizes))
    weights = c * sizes**3.0 + scatter
    return sizes, weights, c

    
def plot_data(ax, x, y, c):
    """Plot the samples, the generating cubic and four rescaled cubics.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw into.
    x, y : array-like
        Sample sizes and weights, drawn as blue markers.
    c : number
        Coefficient of the cubic ``c * x**3`` drawn in red; the orange
        curves use 0.25, 0.5, 2 and 4 times this coefficient.
    """
    ax.scatter(x, y, marker='o', color=colors['blue'], s=40, zorder=10)

    # All curves are evaluated on the same dense grid.
    grid = np.linspace(2.1, 3.9, 100)
    ax.plot(grid, c*grid**3.0, color=colors['red'], lw=2, zorder=5)
    for factor in (0.25, 0.5, 2.0, 4.0):
        coeff = factor*c
        ax.plot(grid, coeff*grid**3.0, color=colors['orange'], lw=1.5,
                zorder=5)

    # Axes decoration: left and bottom spines, fixed limits and ticks.
    show_spines(ax, 'lb')
    ax.set_xlabel('Size x / m')
    ax.set_ylabel('Weight y / kg')
    ax.set_xlim(2, 4)
    ax.set_ylim(0, 400)
    x_ticks = np.arange(2.0, 4.1, 0.5)
    y_ticks = np.arange(0, 401, 100)
    ax.set_xticks(x_ticks)
    ax.set_yticks(y_ticks)
    
    
def plot_data_errors(ax, x, y, c):
    """Plot the samples with the residuals of selected points to the cubic.

    All samples are drawn as small markers, a hard-coded subset is drawn
    with large markers, and each point of that subset is connected to the
    curve ``c * x**3`` by a vertical orange line. An arrow labelled
    'Error' points next to sample 28.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw into.
    x, y : ndarray
        Sample sizes and weights; must hold at least 34 elements because
        of the hard-coded indices used below.
    c : number
        Coefficient of the cubic drawn in red.
    """
    show_spines(ax, 'lb')
    ax.set_xlabel('Size x / m')
    # No y-axis label and no y tick labels here -- presumably because this
    # panel is shown right next to the one drawn by plot_data().
    ax.set_xlim(2, 4)
    ax.set_ylim(0, 400)
    ax.set_xticks(np.arange(2.0, 4.1, 0.5))
    ax.set_yticks(np.arange(0, 401, 100))
    ax.set_yticklabels([])

    # Arrow head is offset slightly from sample 28.
    arrow_style = dict(arrowstyle="->", relpos=(0.9,1.0),
                       connectionstyle="angle3,angleA=50,angleB=-30")
    arrow_tip = (x[28]+0.05, y[28]+60)
    ax.annotate('Error', xy=arrow_tip, xycoords='data',
                xytext=(3.4, 70), textcoords='data', ha='left',
                arrowprops=arrow_style)

    # Background: (at most) the first 40 samples as small markers.
    ax.scatter(x[:40], y[:40], color=colors['blue'], s=10, zorder=0)

    # Foreground: the highlighted samples as large markers.
    highlighted = [3, 10, 11, 17, 18, 21, 28, 30, 33]
    ax.scatter(x[highlighted], y[highlighted], color=colors['blue'], s=40,
               zorder=10)

    grid = np.linspace(2.1, 3.9, 100)
    ax.plot(grid, c*grid**3.0, color=colors['red'], lw=2)

    # One vertical segment per highlighted sample, from curve to data point.
    for xi, yi in zip(x[highlighted], y[highlighted]):
        ax.plot([xi, xi], [c*xi**3.0, yi], color=colors['orange'], lw=2,
                zorder=5)

def plot_error_hist(ax, x, y, c):
    """Plot a histogram of the squared errors between data and cubic.

    The squared error of each sample is ``(y - c * x**3)**2``; the mean
    of these values is marked by an arrow labelled 'Mean squared error'.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw into.
    x, y : ndarray
        Sample sizes and weights.
    c : number
        Coefficient of the cubic the data are compared to.
    """
    show_spines(ax, 'lb')
    ax.set_xlabel('Squared error')
    ax.set_ylabel('Frequency')

    # Bin edges 0, 100, ..., 1200; the x-axis spans exactly these edges.
    # NOTE(review): create_data() in this file scales its noise by 50, so
    # squared errors are expected around 50**2 = 2500, beyond the last bin
    # edge -- confirm the bin range before re-enabling this panel.
    edges = np.arange(0.0, 1250.0, 100)
    first_edge, last_edge = edges[0], edges[-1]
    ax.set_xlim(first_edge, last_edge)
    ax.set_xticks(np.arange(first_edge, last_edge, 200))
    # y limits and y ticks are not set here.

    residuals = y - (c*x**3.0)
    squared = residuals**2.0
    mean_squared = np.mean(squared)

    arrow_style = dict(arrowstyle="->", relpos=(0.0,0.2),
                       connectionstyle="angle3,angleA=10,angleB=90")
    ax.annotate('Mean\nsquared\nerror',
                xy=(mean_squared, 0.5), xycoords='data',
                xytext=(800, 3), textcoords='data', ha='left',
                arrowprops=arrow_style)
    ax.hist(squared, edges, color=colors['orange'])



if __name__ == "__main__":
    # Two-panel figure: data with candidate cubics in the left panel,
    # residuals of selected points in the right panel.
    x, y, c = create_data()
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(7.0, 2.6))
    left_ax, right_ax = axes
    plot_data(left_ax, x, y, c)
    plot_data_errors(right_ax, x, y, c)
    # Alternative right-hand panel, currently disabled:
    # plot_error_hist(right_ax, x, y, c)
    fig.tight_layout()
    fig.savefig("cubicerrors.pdf")
    plt.close()