import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mt
from plotstyle import *


def power_law(x, c, a):
    """Evaluate the power law c*x**a.

    Parameters
    ----------
    x:
        Value(s) at which the power law is evaluated.
    c:
        Multiplicative factor.
    a:
        Exponent applied to x.

    Returns
    -------
    The product of c and x raised to the power a.
    """
    raised = x**a
    return c*raised


def create_data():
    """Generate reproducible noisy data that follow a cubic power law.

    The x values run from 2.2 up to (but excluding) 3.9 in steps of 0.05.
    The y values are power_law(x, 6.0, 3.0) plus Gaussian noise with a
    standard deviation of 50, drawn from a generator with a fixed seed so
    that every call returns the same data.

    Returns
    -------
    x: ndarray
        The x values.
    y: ndarray
        The noisy power-law values, same length as x.
    c: float
        The factor of the power law used to generate the data.
    """
    # wikipedia:
    # Generally, males vary in total length from 250 to 390 cm and
    # weigh between 90 and 306 kg
    # NOTE(review): presumably x is a body length in meters and y a
    # weight in kilograms -- confirm against the accompanying text.
    factor = 6.0
    sizes = np.arange(2.2, 3.9, 0.05)
    generator = np.random.RandomState(32281)
    scatter = 50*generator.randn(sizes.size)
    weights = power_law(sizes, factor, 3.0) + scatter
    return sizes, weights, factor


def gradient_descent(x, y, func, p0, n=20000, h=1e-7, eps=0.00001):
    """Minimize the mean squared error of func on (x, y) by gradient descent.

    In every iteration the gradient of the mean squared error with respect
    to the parameters is approximated by forward finite differences, and
    the parameters are then moved by -eps times that gradient.

    Parameters
    ----------
    x:
        Data passed as first argument to func.
    y:
        Data the output of func is compared with.
    func: callable
        Model function, called as func(x, *p).
    p0: sequence of numbers
        Initial parameter values. The sequence is copied and never modified.
    n: int
        Number of iterations.
    h: float
        Parameter increment used for the finite-difference gradient.
    eps: float
        Factor by which the gradient is scaled for the parameter update.

    Returns
    -------
    ps: ndarray of shape (n, len(p0))
        Row k holds the parameters at the beginning of iteration k, i.e.
        before the k-th update is applied.
    mses: ndarray of shape (n,)
        Mean squared error for the parameters in the corresponding row
        of ps.
    """
    # Private float copy: the in-place update below must neither write
    # into the caller's array nor fail on integer-valued start values.
    p = np.array(p0, dtype=float)
    # Each row of ph increments exactly one parameter by h.
    ph = np.identity(len(p))*h
    ps = np.zeros((n, len(p)))
    mses = np.zeros(n)
    for k in range(n):
        m0 = np.mean((y-func(x, *p))**2.0)
        gradient = np.array([(np.mean((y-func(x, *(p+dp)))**2.0) - m0)/h
                             for dp in ph])
        # Record the state before the update, so ps[0] is the start value.
        ps[k,:] = p
        mses[k] = m0
        p -= eps*gradient
    return ps, mses

    
def plot_gradient_descent(ax, x, y, c, ps, mses):
    """Plot the error surface of the power law together with a descent path.

    The mean squared error between y and power_law(x, factor, exponent)
    is evaluated on a grid of 300 factors in [0, 10] and 180 exponents in
    [1, 5.5]. Its decadic logarithm is drawn as filled contours, the
    factor on the horizontal and the exponent on the vertical axis. On top
    of it every fifth row of ps is plotted as a line and the last row of
    ps as a single marker.

    Parameters
    ----------
    ax:
        Matplotlib axes to draw into.
    x, y:
        Data used to compute the error surface.
    c:
        Not used by this function.
    ps: 2-D array
        Parameter path; column 0 is plotted against the horizontal axis,
        column 1 against the vertical axis.
    mses:
        Not used by this function.
    """
    factors = np.linspace(0.0, 10.0, 300)
    exponents = np.linspace(1.0, 5.5, 180)

    def mean_squared_error(factor, exponent):
        return np.mean((y-power_law(x, factor, exponent))**2.0)

    # Rows follow the exponents, columns follow the factors.
    errors = np.array([[mean_squared_error(factor, exponent)
                        for factor in factors]
                       for exponent in exponents], dtype=float)
    log_errors = np.log10(errors)
    contour_levels = (3.3, 3.36, 3.5, 4.0, 4.5, 5.5, 6.5, 7.5, 8.5)
    ax.contourf(factors, exponents, log_errors, levels=contour_levels,
                cmap='Blues_r')
    # lsBm and psC are plot-style keyword dictionaries defined outside
    # this function (presumably provided by the plotstyle module).
    path = ps[::5]
    ax.plot(path[:,0], path[:,1], **lsBm)
    final = ps[-1]
    ax.plot(final[0], final[1], **psC)
    ax.set_xlabel('c')
    ax.set_ylabel('a')
    ax.yaxis.set_major_locator(mt.MultipleLocator(1.0))
    ax.set_aspect('equal')


def main():
    """Fit the power law by gradient descent and save the resulting figure.

    Generates the data, runs the gradient descent starting at factor 1
    and exponent 1, plots the error surface with the descent path, and
    writes the figure to 'powergradientdescent.pdf'.
    """
    x, y, c = create_data()
    start = [1.0, 1.0]
    ps, mses = gradient_descent(x, y, power_law, start)
    # cm_size, figure_width, figure_height and adjust_fs are not defined
    # in this file (presumably provided by the plotstyle module).
    size = cm_size(figure_width, 1.3*figure_height)
    fig, ax = plt.subplots(figsize=size)
    margins = adjust_fs(left=4.5, right=1.0)
    fig.subplots_adjust(**margins)
    plot_gradient_descent(ax, x, y, c, ps, mses)
    fig.savefig("powergradientdescent.pdf")
    plt.close()


if __name__ == "__main__":
    main()