import matplotlib.pyplot as plt
import numpy as np
from IPython import embed
import scipy.io as scio

def load_data(filename):
    """Read the time axis and EOD trace from a MATLAB ``.mat`` file.

    The file must contain the variables ``t`` and ``eod``. The first row of
    each is returned.

    Returns
    -------
    tuple of numpy.ndarray
        ``(t, eod)``.
    """
    contents = scio.loadmat(filename)
    time_rows, eod_rows = contents['t'], contents['eod']
    return time_rows[0], eod_rows[0]
    

def eod_model(p, t):
    """Evaluate a sum-of-harmonics model of an EOD waveform.

    The model is::

        eod(t) = b_0 + sum_{k=1..n} a_k * sin(2*pi*k*omega_0*t + phi_k)

    Parameters
    ----------
    p : sequence of float
        ``[b_0, omega_0, a_1, phi_1, a_2, phi_2, ...]``. These are the
        offset, the fundamental frequency (in cycles per unit of ``t``; the
        ``2*pi`` factor is applied here), then one (amplitude, phase) pair
        per harmonic. A trailing unpaired element is ignored.
    t : array_like
        Sample times.

    Returns
    -------
    numpy.ndarray
        Float model values with the same shape as ``t``.
    """
    t = np.asarray(t, dtype=float)
    b_0 = p[0]
    omega_0 = p[1]

    params = p[2:]
    # Floor division: ``len(params) / 2`` is a float on Python 3, which makes
    # ``range(n)`` raise TypeError. ``//`` behaves the same on Python 2 and 3.
    n = len(params) // 2
    eod = np.zeros(t.shape)
    for i in range(n):
        amplitude = params[2 * i]
        phase = params[2 * i + 1]
        # Harmonic i+1 oscillates at (i+1) times the fundamental frequency.
        eod += amplitude * np.sin(2 * np.pi * t * (i + 1) * omega_0 + phase)
    eod += b_0
    return eod


def fit_error(p, t, y):
    """Return the mean squared error between ``y`` and ``eod_model(p, t)``."""
    residuals = y - eod_model(p, t)
    return np.mean(residuals ** 2)


def gradient(p, t, y, scale=None, h=0.002):
    """Estimate the gradient of ``fit_error`` with forward differences.

    Parameters
    ----------
    p : sequence of float
        Model parameters in the layout used by ``eod_model``.
    t, y : numpy.ndarray
        Sample times and observed values passed through to ``fit_error``.
    scale : sequence of float, optional
        Per-parameter multiplier for the finite-difference step. Defaults to
        all ones.
    h : float, optional
        Base finite-difference step (default 0.002).

    Returns
    -------
    numpy.ndarray
        One entry per parameter. Parameter ``i`` is perturbed by
        ``scale[i] * h``, but the difference is divided by ``h`` only. Entry
        ``i`` therefore approximates ``scale[i] * d(error)/dp_i``.
    """
    if scale is None:
        scale = np.ones(len(p))
    grad = np.zeros(len(p))
    # The unperturbed error does not depend on the loop index: compute it once.
    base_error = fit_error(p, t, y)
    for i in range(len(p)):
        p_temp = list(p)
        p_temp[i] = p[i] + scale[i] * h
        grad[i] = (fit_error(p_temp, t, y) - base_error) / h
    return grad


def gradient_descent(t, y, count=80, omega_0=650, eps=0.01, tol=0.01,
                     max_iter=None):
    """Fit ``eod_model`` to the first ``count`` samples by gradient descent.

    The data and the current model are drawn in an interactive matplotlib
    figure, which is updated after every descent step.

    Parameters
    ----------
    t, y : numpy.ndarray
        Sample times and observed EOD values. Only ``[:count]`` is fitted,
        and ``t[count]`` sets the right edge of the x-axis, so ``t`` needs
        more than ``count`` entries.
    count : int, optional
        Number of leading samples used for the fit (default 80).
    omega_0 : float, optional
        Initial fundamental frequency (default 650).
    eps : float, optional
        Learning rate (default 0.01).
    tol : float, optional
        Stop once the gradient norm is no longer above ``tol`` (default 0.01).
    max_iter : int or None, optional
        Maximum number of descent steps. ``None`` (default) means no limit:
        the loop runs until the gradient-norm criterion is met.

    Returns
    -------
    tuple
        ``(params, errors)``. ``params`` is the fitted parameter array
        ``[b_0, omega_0, a_1, phi_1, ...]``. ``errors`` is the list of fit
        errors recorded before each step.
    """
    t_fit = t[:count]
    y_fit = y[:count]

    b_0 = np.mean(y)
    # NOTE(review): the amplitude guesses are fractions of min(y) + max(y);
    # a peak-to-peak estimate would be max(y) - min(y). Confirm intent.
    amp = np.min(y) + np.max(y)
    # Use a float array from the start so ``params -= ...`` below really
    # updates in place, instead of relying on list - ndarray coercion.
    params = np.array([b_0, omega_0,
                       amp, np.pi / 2,
                       amp / 2, np.pi / 3,
                       amp / 4, np.pi / 4,
                       amp / 5, np.pi], dtype=float)
    # The frequency (index 1) takes a much larger step than the other
    # parameters.
    scale = np.ones(len(params))
    scale[1] = 1000
    errors = []

    plt.axis([0, t[count], -3, 3.5])
    plt.ion()
    plt.plot(t_fit, y_fit)
    line = None
    grad = None
    n_iter = 0
    while (grad is None) or (np.linalg.norm(grad) > tol):
        if max_iter is not None and n_iter >= max_iter:
            break
        errors.append(fit_error(params, t_fit, y_fit))

        grad = gradient(params, t_fit, y_fit)
        params -= eps * scale * grad
        n_iter += 1

        model = eod_model(params, t_fit)
        if line is None:
            line = plt.plot(t_fit, model)[0]
        else:
            line.set_data(t_fit, model)
        plt.title("norm: %.2f, freq: %.2f" % (np.linalg.norm(grad), params[1]))
        plt.pause(0.005)
    # Return instead of embed()/exit() so callers can use the fit result.
    return params, errors
    


if __name__ == "__main__":
    # Script entry point: load the recording, run the interactive fit, then
    # drop into an IPython shell before quitting.
    datafile = '../data/EOD_data.mat'
    t, eod = load_data(datafile)

    # To test the fitter on a known synthetic waveform instead, use e.g.:
    #     eod = eod_model([0.0, 600, 1.0, 0.0], t[:100])

    gradient_descent(t, eod)

    embed()
    exit()