import numpy as np
import matplotlib.pyplot as plt
import thunderfish.peakdetection as pd


def create_chirp(eodf, chirp_times=(0.0,), chirp_width=0.014, chirp_size=100.,
                 chirp_ampl=0.02, stepsize=0.00001):
    """Simulate a sinusoidal EOD that contains Gaussian-shaped chirps.

    The signal is a unit-amplitude sine wave of frequency `eodf`. Around
    each chirp time the frequency is transiently raised and the amplitude
    transiently lowered, both following the same Gaussian profile.

    Parameters
    ----------
    eodf: float
        Baseline EOD frequency of the simulated fish in Hertz.
    chirp_times: sequence of floats
        Times of the chirp centers in seconds. Must be sorted in ascending
        order, because the chirps are processed one after the other while
        stepping through time.
    chirp_width: float
        Width of a chirp in seconds, measured at 10% of its peak height.
        A chirp only acts on samples closer than `2 * chirp_width` to its
        center.
    chirp_size: float
        Frequency increase in Hertz at the center of a chirp.
    chirp_ampl: float
        Fractional amplitude reduction at the center of a chirp.
    stepsize: float
        Sampling interval in seconds.

    Returns
    -------
    time: 1-D array
        Sample times in seconds, covering -0.05 up to (excluding) 0.05.
    signal: 1-D array
        The simulated EOD, same shape as `time`.
    """
    # Only the plain Gaussian (kurtosis 1) is implemented: the kurtosis
    # enters the width conversion below but not the exponent of the profile.
    chirp_kurtosis = 1.

    time = np.arange(-0.05, 0.05, stepsize)
    signal = np.zeros(time.shape)

    # Standard deviation for which the Gaussian has dropped to 10% of its
    # peak at a distance of chirp_width / 2 from the center.
    csig = 0.5 * chirp_width / np.power(2.0*np.log(10.0), 0.5/chirp_kurtosis)

    ck = 0       # index of the chirp currently being processed
    cycles = 0.  # integrated frequency, i.e. phase in units of full cycles
    for k, t in enumerate(time):
        a = 1.
        f = eodf
        if ck < len(chirp_times):
            x = t - chirp_times[ck]
            if np.abs(x) < 2.0 * chirp_width:
                g = np.exp(-0.5 * (x/csig)**2)
                f = chirp_size * g + eodf
                a *= 1.0 - chirp_ampl * g
            elif x > 2.0 * chirp_width:
                # Left the window of this chirp, move on to the next one.
                ck += 1
        # The frequency changes over time, so the phase has to be
        # accumulated step by step instead of being computed as f * t.
        cycles += f * stepsize
        signal[k] = a * np.sin(2.0 * np.pi * cycles)

    return time, signal


def plot_chirp(eodf, eodf1, phase, axis):
    """Plot a chirping EOD superimposed on a second EOD into `axis`.

    Parameters
    ----------
    eodf: float
        EOD frequency in Hertz of the chirping fish, passed on to
        `create_chirp()`.
    eodf1: float
        Frequency in Hertz of the second, non-chirping sine wave.
    phase: float
        Phase offset in radians of the second sine wave.
    axis: matplotlib axes
        The axes to draw into.
    """
    time, chirp_eod = create_chirp(eodf)
    other_eod = np.sin(time * 2 * np.pi * eodf1 + phase)

    # The chirping fish contributes with 40% of the other fish's amplitude.
    superposition = chirp_eod * 0.4 + other_eod

    # NOTE(review): detect_peaks() presumably returns the indices of the
    # peaks and of the troughs for a threshold of 0.1 -- confirm against
    # the thunderfish documentation.
    peaks, troughs = pd.detect_peaks(superposition, 0.1)

    time_ms = time * 1000
    axis.plot(time_ms, superposition, color='royalblue')
    # Connect the detected extrema with black lines on top of the signal.
    for indices in (peaks, troughs):
        axis.plot(time_ms[indices], superposition[indices], lw=2, color='k')

    for side in ("top", "right"):
        axis.spines[side].set_visible(False)



# The figure size is given in centimeters, matplotlib expects inches.
inch_factor = 2.54

fig = plt.figure(figsize=(20 / inch_factor, 10 / inch_factor))
ax1 = fig.add_subplot(221)
ax2 = fig.add_subplot(222)
ax3 = fig.add_subplot(223)
ax4 = fig.add_subplot(224)

# Right column: frequency difference of 50 Hz,
# top with phase 0, bottom with phase pi.
plot_chirp(600, 650, 0, ax2)
plot_chirp(600, 650, np.pi, ax4)

# Left column: frequency difference of 20 Hz,
# top with phase 0, bottom with phase pi.
plot_chirp(600, 620, 0, ax1)
plot_chirp(600, 620, np.pi, ax3)

# Raw strings keep the backslash of the mathtext command '\Delta' without
# relying on Python passing through an invalid escape sequence.
ax1.set_ylabel('EOD [mV]', fontsize=22)
ax1.set_title(r'$\Delta$f = 20 Hz', fontsize=18)
ax1.yaxis.set_tick_params(labelsize=18)
ax1.set_xticklabels([])

ax2.set_title(r'$\Delta$f = 50 Hz', fontsize=18)
ax2.set_xticklabels([])
ax2.set_yticklabels([])

ax3.set_ylabel('EOD [mV]', fontsize=22)
ax3.xaxis.set_tick_params(labelsize=18)
ax3.yaxis.set_tick_params(labelsize=18)
ax3.set_xlabel('Time [ms]', fontsize=22)

ax4.set_xlabel('Time [ms]', fontsize=22)
ax4.xaxis.set_tick_params(labelsize=18)
ax4.set_yticklabels([])

fig.tight_layout()
#plt.show()
fig.savefig('chirps_while_beat.png')