import numpy as np
import matplotlib.pyplot as plt
from IPython import embed

def despine(axis, spines=None, hide_ticks=True):
    """Hide the spines (and optionally the ticks) of a matplotlib axis.

    Args:
        axis: the axes object whose spines should be hidden.
        spines (container of str, optional): names of the spines to hide, e.g. "top"
            or "left". Only spines present on the axis and contained in ``spines``
            are hidden. If None, every spine of the axis is hidden. Defaults to None.
        hide_ticks (bool, optional): if True, remove all x and y ticks. Defaults to True.
    """
    all_names = axis.spines.keys()
    if spines is None:
        selected = all_names
    else:
        selected = [name for name in all_names if name in spines]

    for name in selected:
        axis.spines[name].set_visible(False)

    if not hide_ticks:
        return
    axis.xaxis.set_ticks([])
    axis.yaxis.set_ticks([])


def create_chirp(eodf=500, chirpsize=100, chirpduration=0.015, ampl_reduction=0.05, chirptimes=None, kurtosis=1.0, duration=1., dt=0.00001):
    """Create a fake fish EOD that contains chirps at the given times.

    The EOD is a simple sine wave. Each chirp is modeled by a profile
    gg(x) = exp(-0.5 * ((x / sigma)**2)**kurtosis) (a Gaussian for kurtosis == 1)
    that transiently raises the frequency by ``chirpsize * gg`` and scales the
    amplitude by ``1 - ampl_reduction * gg``.

    A chirp only influences samples within +/- 2 * chirpduration of its center.
    If the windows of neighboring chirps overlap, a sample is assigned to the
    earliest chirp whose window has not yet ended.

    Args:
        eodf (float, optional): The chirping fish's EOD frequency. Defaults to 500 Hz.
        chirpsize (float, optional): The size of the chirp's frequency excursion. Defaults to 100 Hz.
        chirpduration (float, optional): Full width of the chirp profile at 10% of its peak.
            Defaults to 0.015 s.
        ampl_reduction (float, optional): Relative amplitude reduction at the chirp peak.
            Defaults to 0.05, i.e. 5%.
        chirptimes (array-like, optional): Times of the chirp centers in seconds; the order
            does not matter. Defaults to [0.05, 0.2].
        kurtosis (float, optional): Exponent shaping the chirp profile (see above); values
            larger than 1 give a flatter top. Defaults to 1.0.
        duration (float, optional): Duration of the signal. Defaults to 1.0 s.
        dt (float, optional): The step size of the simulation. Defaults to 0.00001 s.

    Returns:
        np.ndarray: the time
        np.ndarray: the EOD
        np.ndarray: the amplitude profile
        np.ndarray: the frequency profile
    """
    if chirptimes is None:
        chirptimes = [0.05, 0.2]
    # Sorting guarantees that chirps passed out of order are not silently skipped.
    centers = np.sort(np.asarray(chirptimes, dtype=float).ravel())

    time = np.arange(0.0, duration, dt)

    # Profile width chosen such that gg drops to 10% of its peak at +/- chirpduration / 2.
    csig = 0.5 * chirpduration / np.power(2.0 * np.log(10.0), 0.5 / kurtosis)
    # A chirp only influences samples closer than this to its center.
    window = 2.0 * chirpduration

    gg = np.zeros(time.shape)
    if centers.size > 0:
        # For every sample: index of the first chirp whose window has not ended yet.
        current = np.searchsorted(centers + window, time, side="left")
        in_range = current < centers.size
        x = time - centers[np.minimum(current, centers.size - 1)]
        active = in_range & (np.abs(x) < window)
        gg[active] = np.exp(-0.5 * np.power((x[active] / csig) ** 2, kurtosis))

    ampl = 1.0 - ampl_reduction * gg
    freq = eodf + chirpsize * gg
    # Integrate the instantaneous frequency to get the phase (running sum of f * dt).
    phase = np.cumsum(freq * dt)
    signal = ampl * np.sin(2 * np.pi * phase)

    return time, signal, ampl, freq


def sender_receiver_simulation(sender_eodf=250, receiver_eodf=255, sender_intensity=1.0, receiver_intensity=0.2, sender_chirps=True, duration=0.5, chirp_count=4):
    """Simulate the EODs of a sender and a receiver fish, one of which chirps.

    The chirping fish's EOD is generated with ``create_chirp``; the other fish's
    EOD is a plain sine wave. Both are scaled by their respective intensities.

    Args:
        sender_eodf (float, optional): EOD frequency of the sender. Defaults to 250 Hz.
        receiver_eodf (float, optional): EOD frequency of the receiver. Defaults to 255 Hz.
        sender_intensity (float, optional): Amplitude scaling of the sender's EOD. Defaults to 1.0.
        receiver_intensity (float, optional): Amplitude scaling of the receiver's EOD. Defaults to 0.2.
        sender_chirps (bool, optional): If True the sender chirps, otherwise the receiver
            chirps. Defaults to True.
        duration (float, optional): Duration of the simulation in seconds. Defaults to 0.5.
        chirp_count (int, optional): Number of chirps, evenly spaced between 5% and 85% of
            the duration. Values <= 0 produce no chirps. Defaults to 4.

    Returns:
        np.ndarray: the time axis
        np.ndarray: the sender's EOD
        np.ndarray: the receiver's EOD
    """
    # linspace places exactly chirp_count chirps; an arange with step
    # 0.8 * duration / (chirp_count - 2) yields one chirp too few and fails for chirp_count == 2.
    chirp_times = np.linspace(0.05 * duration, 0.85 * duration, max(int(chirp_count), 0))

    if sender_chirps:
        time, sender_eod, _, _ = create_chirp(sender_eodf, duration=duration, chirptimes=chirp_times)
        sender_eod *= sender_intensity
        receiver_eod = np.sin(time * 2 * np.pi * receiver_eodf) * receiver_intensity
    else:
        time, receiver_eod, _, _ = create_chirp(receiver_eodf, duration=duration, chirptimes=chirp_times)
        receiver_eod *= receiver_intensity
        sender_eod = np.sin(time * 2 * np.pi * sender_eodf) * sender_intensity

    return time, sender_eod, receiver_eod

 
def plot_simulation():
    """Plot the sender's EOD, the receiver's EOD and their superposition.

    Runs ``sender_receiver_simulation`` with its default settings (the sender
    chirps) and shows, in a 3x2 subplot grid, the sender's EOD (top left), the
    combined signal (middle right) and the receiver's EOD (bottom left). All
    panels share the same symmetric y-limits and have their spines and ticks
    removed.
    """
    time, sender_eod, receiver_eod = sender_receiver_simulation()
    combined_signal = sender_eod + receiver_eod

    # Symmetric y-limits, rounded up to the next integer, shared by all panels.
    ylim = np.ceil(np.max(combined_signal))
    ylims = [-ylim, ylim]

    fig = plt.figure()
    # Zig-zag layout: sender (321) -> combined (324) -> receiver (325).
    panels = ((321, sender_eod), (324, combined_signal), (325, receiver_eod))
    for position, trace in panels:
        ax = fig.add_subplot(position)
        ax.plot(time, trace)
        ax.set_ylim(ylims)
        despine(ax, ["top", "bottom", "right", "left"])
    plt.show()
    
    
def _plot_chirp_profiles():
    """Plot a single fake EOD together with its amplitude and frequency profiles.

    Debugging aid for ``create_chirp`` (called with kurtosis=0.01); not called by ``main``.
    """
    time, eod, ampl, freq = create_chirp(kurtosis=0.01)
    fig = plt.figure()
    ax1 = fig.add_subplot(311)
    ax2 = fig.add_subplot(312)
    ax3 = fig.add_subplot(313)

    ax1.plot(time, eod)
    ax2.plot(time, ampl)
    ax3.plot(time, freq)

    ax3.set_xlabel("time [s]")
    ax3.set_ylabel("frequency [Hz]")
    ax2.set_ylabel("amplitude [rel]")
    ax1.set_ylabel("fake fish field [rel]")
    plt.show()


def main():
    """Entry point: plot the sender/receiver simulation."""
    plot_simulation()


if __name__ == "__main__":
    main()