import sys
sys.path.insert(0, 'ephys')  # for analysing data
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from spectral import diag_projection, peakedness
from plotstyle import plot_style


# Cell shown in the upper two rows of the figure and the indices of the
# two noise-stimulus runs that are compared for it:
cell_name = '2020-10-27-ag-invivo-1'
run1, run2 = 0, 1

# Further cells shown in the bottom row, each entry is [cell name, run index].
# The trailing remarks are the author's notes on each recording.
# Entries that are commented out are candidates that are currently not shown.
example_cells = [
    ['2021-06-18-ae-invivo-1', 3],  # 98Hz, 1%, ok
    ['2012-03-30-ah', 2],  # 177Hz, 2.5%, 2.0, nice
    ##['2012-07-03-ak', 0],  # 120Hz, 2.5%, 1.8, broader
    ##['2012-12-20-ac', 0],  # 213Hz, 2.5%, 2.1, ok
    #['2017-07-18-ai-invivo-1', 1],  # 78Hz, 5%, 2.3, weak
    ##['2019-06-28-ae', 0],  # 477Hz, 10%, 2.6, weak
    ##['2020-10-27-aa-invivo-1', 4],  # 259Hz, 0.5%, 2.0, ok
    ##['2020-10-27-ae-invivo-1', 4],  # 375Hz, 0.5%, 4.3, nice, additional low freq line
    ###['2020-10-27-ag-invivo-1', 2],  # 405Hz, 5%, 3.9, strong, is already the example
    ##['2021-08-03-ab-invivo-1', 1],  # 140Hz, 0.5%, ok
    ['2020-10-29-ag-invivo-1', 2],  # 164Hz, 5%, 1.6, no diagonal
    ##['2010-08-31-ag', 1],  # 269Hz, 5%, no diagonal
    ['2018-08-24-ak', 1],  # 145Hz, 5%, no diagonal
    ##['2018-08-29-af', 1],  # 383Hz, 5%, no diagonal
]

# Folder holding the npz files of all cells:
data_path = Path('data', 'cells')


def load_baseline(path, cell_name):
    """Load the baseline activity measures of a cell.

    Reads the archive `<path>/<cell_name>-baseline.npz`.

    Parameters
    ----------
    path: pathlib.Path
        Folder containing the data files.
    cell_name: str
        Name of the cell, used as the prefix of the file name.

    Returns
    -------
    eodf: float
        Entry 'eodf' of the archive.
    rate: float
        Entry 'ratebase/Hz' of the archive.
    cv: float
        Entry 'cvbase' of the archive.
    isis: ndarray
        Entry 'isis' of the archive.
    pdf: ndarray
        Entry 'isih' of the archive.
    freqs: ndarray
        Entry 'freqs' of the archive.
    prr: ndarray
        Entry 'prr' of the archive.

    Raises
    ------
    FileNotFoundError
        If the archive does not exist.
    KeyError
        If one of the entries is missing from the archive.
    """
    file_path = path / f'{cell_name}-baseline.npz'
    # The context manager closes the archive again. Entries are read
    # into memory when they are indexed, so the returned arrays stay
    # valid after the file is closed.
    with np.load(file_path) as data:
        eodf = float(data['eodf'])
        rate = float(data['ratebase/Hz'])
        cv = float(data['cvbase'])
        isis = data['isis']
        pdf = data['isih']
        freqs = data['freqs']
        prr = data['prr']
    return eodf, rate, cv, isis, pdf, freqs, prr
    

def load_noise(path, cell_name, run):
    """Load stimulus and spike trains of one noise-stimulus run.

    Reads the archive `<path>/<cell_name>-spectral-data-s<run>.npz`,
    where the run index is zero padded to two digits.

    Parameters
    ----------
    path: pathlib.Path
        Folder containing the data files.
    cell_name: str
        Name of the cell, used as the prefix of the file name.
    run: int
        Index of the stimulus run.

    Returns
    -------
    contrast: ndarray
        Entry 'contrast' of the archive, returned as stored.
    time: ndarray
        Entry 'time' of the archive.
    stimulus: ndarray
        Entry 'stimulus' of the archive.
    spikes: list of ndarray
        The entries 'spikes_000', 'spikes_001', ... in this order,
        up to the first index that is missing from the archive.

    Raises
    ------
    FileNotFoundError
        If the archive does not exist.
    KeyError
        If 'contrast', 'time' or 'stimulus' is missing from the archive.
    """
    file_path = path / f'{cell_name}-spectral-data-s{run:02d}.npz'
    # The context manager closes the archive again; indexed entries
    # are loaded into memory and stay valid afterwards.
    with np.load(file_path) as data:
        contrast = data['contrast']
        time = data['time']
        stimulus = data['stimulus']
        # Set of entry names for fast membership tests:
        names = set(data.files)
        spikes = []
        key = 'spikes_000'
        # Collect consecutively numbered trials without an upper limit:
        while key in names:
            spikes.append(data[key])
            key = f'spikes_{len(spikes):03d}'
    return contrast, time, stimulus, spikes


def load_spectra(path, cell_name, run=None):
    """Load spectra of a noise-stimulus run and compute gain and chi2.

    Parameters
    ----------
    path: pathlib.Path
        Folder containing the data files. Only used if `run` is given.
    cell_name: str or pathlib.Path
        If `run` is None, the path of the npz archive to be loaded.
        Otherwise the name of the cell, used as prefix in the
        file pattern `<cell_name>-spectral*-s<run>.npz`.
    run: int or None
        Index of the stimulus run (zero padded to two digits in the
        file name).

    Returns
    -------
    fcutoff: float
        Entry 'fcutoff' of the archive.
    contrast: float
        Entry 'alpha' of the archive.
    freqs: ndarray
        Entry 'freqs' of the archive.
    gain: ndarray
        `|prs|/pss` computed from the entries 'prs' and 'pss'.
    chi2: ndarray
        `0.5*|prss|/sqrt(pss[i]*pss[j])` computed from the
        entries 'prss' and 'pss'.

    Raises
    ------
    FileNotFoundError
        If no file matches the pattern, or the explicitly given
        file does not exist.
    KeyError
        If none of the candidate files holds all required entries.
    """
    required = ('alpha', 'fcutoff', 'freqs', 'pss', 'prs', 'prss')
    if run is None:
        candidates = [cell_name]
    else:
        pattern = f'{cell_name}-spectral*-s{run:02d}.npz'
        # Sort for a reproducible choice, glob order is arbitrary:
        candidates = sorted(path.glob(pattern))
        if not candidates:
            raise FileNotFoundError(f'no file matching {pattern} in {path}')
    for file_path in candidates:
        # The context manager closes the archive again; indexed
        # entries are loaded into memory and stay valid afterwards.
        with np.load(file_path) as data:
            # The pattern also matches the '-spectral-data-' archive
            # read by load_noise(); skip files lacking the spectra.
            if not all(key in data.files for key in required):
                continue
            contrast = float(data['alpha'])
            fcutoff = float(data['fcutoff'])
            freqs = data['freqs']
            pss = data['pss']
            prs = data['prs']
            prss = data['prss']
        gain = np.abs(prs)/pss
        # Normalize by the geometric mean of pss at both frequencies:
        chi2 = np.abs(prss)*0.5/np.sqrt(pss.reshape(1, -1)*pss.reshape(-1, 1))
        return fcutoff, contrast, freqs, gain, chi2
    raise KeyError(f'none of the files {[str(c) for c in candidates]} '
                   f'contains all of the entries {required}')


def plot_isih(ax, s, rate, cv, isis, pdf):
    """Plot the interspike-interval histogram of the baseline activity.

    Parameters
    ----------
    ax: axes
        Axes to draw into. Needs to provide `show_spines()`,
        `set_xticks_delta()` and a `set_xlabel()` taking a unit as
        second argument (presumably added by the plot style module).
    s: plot style
        Only `s.cell_color1` is used.
    rate: float
        Baseline firing rate, printed with unit Hz above the plot.
    cv: float
        Baseline CV, printed above the plot.
    isis: ndarray
        Interspike intervals for the x-axis. They are multiplied by
        1000 and labeled as milliseconds.
    pdf: ndarray
        Histogram values for each element of `isis`.
    """
    color = s.cell_color1
    header_y = 1.08
    summary = f'$r={rate:.0f}$Hz, CV$_{{\\rm base}}$={cv:.2f}'
    ax.show_spines('b')
    ax.fill_between(1000*isis, pdf, facecolor=color)
    ax.set_xlim(0, 8)
    ax.set_xticks_delta(2)
    ax.set_xlabel('ISI', 'ms')
    # Header line above the axes:
    ax.text(0, header_y, 'P-unit:', transform=ax.transAxes,
            color=color, fontsize='large')
    ax.text(0.6, header_y, summary, transform=ax.transAxes)


def plot_response_spectrum(ax, s, eodf, rate, freqs, prr):
    """Plot the power spectrum of the baseline response in decibel.

    The spectrum is normalized to its maximum and drawn for
    frequencies below 890 Hz. Two peaks are marked and labeled:
    the largest one above 500 Hz as f_EOD and the largest one
    below 0.7*eodf as r.

    Parameters
    ----------
    ax: axes
        Axes to draw into. Needs to provide `show_spines()`,
        `set_xticks_delta()`, `yscalebar()` and a `set_xlabel()`
        taking a unit as second argument.
    s: plot style
        Provides the styles `lsC1`, `psA1c` and `psA2c`.
    eodf: float
        EOD frequency, sets the upper bound for the search of the
        peak labeled r.
    rate: float
        Not used.
    freqs: ndarray
        Frequencies of the spectrum. The peak indices are only
        correct if they are sorted in ascending order
        (NOTE(review): assumed, not checked here).
    prr: ndarray
        Power for each element of `freqs`.
    """
    low_mask = freqs < 0.7*eodf
    high_mask = freqs > 500
    rate_i = np.argmax(prr[low_mask])
    # Index of the first frequency above 500 Hz shifts the peak index
    # found in the masked array back into the full array:
    eod_i = np.argmax(prr[high_mask]) + np.argmax(high_mask)
    power_db = 10*np.log10(prr/np.max(prr))
    shown = freqs < 890
    ax.show_spines('b')
    ax.plot(freqs[shown], power_db[shown], **s.lsC1)
    for idx, marker in ((eod_i, s.psA1c), (rate_i, s.psA2c)):
        ax.plot(freqs[idx], power_db[idx], **marker)
    ax.set_xlim(0, 900)
    ax.set_ylim(-25, 5)
    ax.set_xticks_delta(300)
    ax.set_xlabel('$f$', 'Hz')
    for idx, label in ((eod_i, '$f_{\\rm EOD}$'), (rate_i, '$r$')):
        ax.text(freqs[idx], power_db[idx] + 2, label, ha='center')
    ax.yscalebar(1.05, 0, 10, 'dB', ha='right')


def plot_response(ax, s, eodf, time1, stimulus1, contrast1, spikes1, contrast2, spikes2):
    """Plot stimulus waveform and spike rasters for two contrasts.

    Shows the time window from 0.3 to 0.4. On top the amplitude
    modulation `1 + contrast1*stimulus1` and a sine wave of frequency
    `eodf` multiplied with it are drawn. Below, eight trials
    (trials 2 to 9) of `spikes1` and, further down, of `spikes2`
    are shown as rasters, each labeled by its contrast in percent.

    Parameters
    ----------
    ax: axes
        Axes to draw into. Needs to provide `show_spines()` and
        `xscalebar()`.
    s: plot style
        Provides `cell_color1`, `cell_color2`, `lsEOD` and `lsAM`.
    eodf: float
        Frequency of the sine wave.
    time1: ndarray
        Time axis of the stimulus.
    stimulus1: ndarray
        Stimulus waveform for each element of `time1`.
    contrast1: float
        Contrast of the first run, also scales the drawn stimulus.
    spikes1: list of ndarray
        Spike times of the trials of the first run.
    contrast2: float
        Contrast of the second run.
    spikes2: list of ndarray
        Spike times of the trials of the second run.
    """
    t0, t1 = 0.3, 0.4
    maxtrials = 8
    first = 2
    trials = np.arange(maxtrials)
    rasters = ((spikes1, 1 - maxtrials, s.cell_color1),
               (spikes2, -2*maxtrials, s.cell_color2))
    ax.show_spines('')
    for spikes, shift, color in rasters:
        ax.eventplot(spikes[first:first + maxtrials],
                     lineoffsets=trials + shift,
                     linelength=0.8, linewidths=1, color=color)
    # Amplitude modulation and the modulated carrier above the rasters:
    am = 1 + contrast1*stimulus1
    eod = np.sin(2*np.pi*eodf*time1) * am
    ax.plot(time1, 4*eod + 7, **s.lsEOD)
    ax.plot(time1, 4*am + 7, **s.lsAM)
    ax.set_xlim(t0, t1)
    ax.set_ylim(-2*maxtrials - 0.5, 14)
    ax.xscalebar(1, -0.05, 0.01, None, '10\\,ms', ha='right')
    # Contrast labels to the right of each raster:
    label_x = t1 + 0.003
    labels = ((contrast1, -0.5*maxtrials, s.cell_color1),
              (contrast2, -1.55*maxtrials, s.cell_color2))
    for contrast, label_y, color in labels:
        ax.text(label_x, label_y, f'${100*contrast:.0f}$\\,\\%',
                va='center', color=color)
    

def plot_gain(ax, s, contrast1, freqs1, gain1, contrast2, freqs2, gain2, fcutoff):
    """Plot the gain of two runs as a function of frequency.

    The second run is drawn first, the first run on top of it.

    Parameters
    ----------
    ax: axes
        Axes to draw into. Needs to provide `set_xticks_delta()` and
        `set_xlabel()`/`set_ylabel()` taking a unit as second argument.
    s: plot style
        Provides the line styles `lsC1` and `lsC2`.
    contrast1, contrast2: float
        Contrasts of the two runs, used in percent as line labels.
    freqs1, freqs2: ndarray
        Frequencies of the two gain curves.
    gain1, gain2: ndarray
        Gain values for each element of the respective frequencies.
    fcutoff: float
        Upper limit of the x-axis.
    """
    curves = ((contrast2, freqs2, gain2, s.lsC2),
              (contrast1, freqs1, gain1, s.lsC1))
    for contrast, freqs, gain, line_style in curves:
        ax.plot(freqs, gain, label=f'{100*contrast:.0f}', **line_style)
    ax.set_xlim(0, fcutoff)
    ax.set_ylim(0, 800)
    ax.set_xticks_delta(100)
    ax.set_xlabel('$f$', 'Hz')
    ax.set_ylabel(r'$|\chi_1|$', 'Hz')
    

def plot_colorbar(ax, pc, dc=None):
    """Add a color bar without outline to the right of an axes.

    Parameters
    ----------
    ax: axes
        Axes to which the color bar is attached as an inset axes
        at `[1.04, 0, 0.05, 1]` in axes coordinates.
    pc: mappable
        The plotted object the color bar refers to, as returned
        by `plot_chi2()`.
    dc: float or None
        If not None, passed to `set_yticks_delta()` of the
        color bar axes.
    """
    cax = ax.inset_axes([1.04, 0, 0.05, 1])
    cax.set_spines_outward('lrbt', 0)
    fig = cax.get_figure()
    cb = fig.colorbar(pc, cax=cax, label=r'$|\chi_2|$ [kHz]')
    # Hide the frame around the color bar:
    cb.outline.set_color('none')
    cb.outline.set_linewidth(0)
    if dc is None:
        return
    cax.set_yticks_delta(dc)

    
def plot_chi2(ax, s, contrast, freqs, chi2, fcutoff, vmax):
    """Plot a chi2 matrix as a color map with equal aspect ratio.

    Parameters
    ----------
    ax: axes
        Axes to draw into. Needs to provide `set_xticks_delta()`,
        `set_yticks_delta()` and `set_xlabel()`/`set_ylabel()`
        taking a unit as second argument.
    s: plot style
        Not used.
    contrast: float
        Not used.
    freqs: ndarray
        Frequencies for both axes of the matrix.
    chi2: 2-D ndarray
        Matrix to be plotted. It is multiplied by 1e-3 for plotting.
    fcutoff: float
        Upper limit of both axes. Ticks are spaced by 100 if it
        equals 300, by 50 otherwise.
    vmax: float or None
        Upper limit of the color scale. If None, the 99% quantile
        of the scaled matrix is used.

    Returns
    -------
    pc: mappable
        Return value of `ax.pcolormesh()`, to be used for a color bar.
    """
    ax.set_aspect('equal')
    scaled = 1e-3*chi2
    color_max = np.quantile(scaled, 0.99) if vmax is None else vmax
    pc = ax.pcolormesh(freqs, freqs, scaled, vmin=0, vmax=color_max,
                       cmap='viridis', rasterized=True, zorder=10)
    ax.set_xlim(0, fcutoff)
    ax.set_ylim(0, fcutoff)
    if fcutoff == 300:
        tick_step = 100
    else:
        tick_step = 50
    ax.set_xticks_delta(tick_step)
    ax.set_yticks_delta(tick_step)
    ax.set_xlabel('$f_1$', 'Hz')
    ax.set_ylabel('$f_2$', 'Hz')
    return pc


def plot_diagonals(ax, s, fbase, contrast1, freqs1, chi21, contrast2, freqs2, chi22, fcutoff):
    """Plot the projections of two chi2 matrices onto the diagonal.

    For each matrix `diag_projection()` is called with an upper
    frequency of `2*fcutoff`, and `peakedness()` is evaluated on the
    result at `fbase`. The returned index is printed and written as
    'SI' next to the marked peak of each projection.

    Parameters
    ----------
    ax: axes
        Axes to draw into. Needs to provide `set_xticks_delta()`,
        `set_yticks_delta()` and a `set_xlabel()` taking a unit as
        second argument.
    s: plot style
        Provides the styles `lsGrid`, `lsC1`, `lsC2`, `psC1`
        and `psC2`.
    fbase: float
        Frequency passed to `peakedness()`, marked by a vertical
        line labeled r.
    contrast1, contrast2: float
        Contrasts of the two runs, used in percent as labels.
    freqs1, freqs2: ndarray
        Frequencies of the two matrices.
    chi21, chi22: 2-D ndarray
        The two chi2 matrices.
    fcutoff: float
        Cutoff frequency; the x-axis extends to twice this value.
    """
    runs = ((contrast1, freqs1, chi21), (contrast2, freqs2, chi22))
    results = []
    for contrast, freqs, chi2 in runs:
        dfreqs, diag = diag_projection(freqs, chi2, 2*fcutoff)
        nli, nlif = peakedness(dfreqs, diag, fbase, median=False)
        # Value of the projection closest to the reported frequency:
        nlip = diag[np.argmin(np.abs(dfreqs - nlif))]
        print(f'    SI at {100*contrast:.1f}% contrast: {nli:.2f}')
        results.append((dfreqs, diag, nli, nlif, nlip))
    first, second = results
    dfreqs1, diag1, si1, peakf1, peak1 = first
    dfreqs2, diag2, si2, peakf2, peak2 = second
    ax.axvline(fbase, **s.lsGrid)
    # The second run is drawn first, the first run on top of it:
    ax.plot(dfreqs2, 1e-3*diag2, **s.lsC2)
    ax.plot(dfreqs1, 1e-3*diag1, **s.lsC1)
    ax.plot(peakf2, 1e-3*peak2, **s.psC2)
    ax.plot(peakf1, 1e-3*peak1, **s.psC1)
    ax.set_xlim(0, 2*fcutoff)
    ax.set_ylim(0, 4.2)
    ax.set_xticks_delta(300)
    ax.set_yticks_delta(1)
    ax.set_xlabel('$f_1 + f_2$', 'Hz')
    # Contrast to the left and SI to the right of each marked peak:
    annotations = ((contrast2, peakf2, peak2, si2),
                   (contrast1, peakf1, peak1, si1))
    for contrast, peakf, peak, si in annotations:
        ax.text(peakf - 50, 1e-3*peak, f'{100*contrast:.0f}\\%',
                ha='right')
        ax.text(peakf + 70, 1e-3*peak, f'SI={si:.1f}')
    ax.text(fbase, 4.3, '$r$', ha='center')

    
# Script body: assemble the figure from the example cell (top and middle
# row) and the additional example cells (bottom row) and save it.
if __name__ == '__main__':
    # Load baseline activity, both noise runs and their spectra of
    # the example cell:
    print('Example P-unit:', cell_name)
    eodf, rate, cv, isis, pdf, freqs, prr = load_baseline(data_path, cell_name)
    print(f'    baseline firing rate: {rate:.0f}Hz')
    print(f'    baseline firing CV  : {cv:.2f}')
    # time2 and stimulus2 are not used below. contrast1 and contrast2
    # are overwritten by the values returned from load_spectra().
    contrast1, time1, stimulus1, spikes1 = load_noise(data_path, cell_name, run1)
    contrast2, time2, stimulus2, spikes2 = load_noise(data_path, cell_name, run2)
    fcutoff1, contrast1, freqs1, gain1, chi21 = load_spectra(data_path, cell_name, run1)
    fcutoff2, contrast2, freqs2, gain2, chi22 = load_spectra(data_path, cell_name, run2)
    
    # The plot functions use the generic attributes cell_color1/2,
    # lsC1/2 and psC1/2. Map them onto the P-unit styles:
    s = plot_style()
    s.cell_color1 = s.punit_color1
    s.cell_color2 = s.punit_color2
    s.lsC1 = s.lsP1
    s.lsC2 = s.lsP2
    s.psC1 = s.psP1
    s.psC2 = s.psP2
    # Figure layout with three rows that are further subdivided.
    # NOTE(review): `cmsize`, `leftm`/`rightm`/`topm`/`bottomm`,
    # `ax.subplots()` and ratio lists with more entries than rows or
    # columns are not plain matplotlib. Presumably they are provided
    # by the extensions installed via plot_style() and the extra
    # entries define spacings - confirm against that module.
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, height_ratios=[3, 0, 3, 0.5, 3],
                                        cmsize=(s.plot_width, 0.8*s.plot_width))
    fig.subplots_adjust(leftm=8, rightm=9, topm=2, bottomm=4,
                        wspace=0.4, hspace=0.5)
    # Top row: ISI histogram, response spectrum, stimulus and rasters.
    axi, axp, axr = ax1.subplots(1, 3, width_ratios=[2, 3, 0, 10])
    # Middle row: gain, chi2 of both runs, projections onto diagonal.
    axg, axc1, axc2, axd = ax2.subplots(1, 4, wspace=0.4)
    axg = axg.subplots(1, 1, width_ratios=[1, 0.1])
    axd = axd.subplots(1, 1, width_ratios=[0.2, 1])
    # Bottom row: one panel for each of the additional example cells.
    axs = ax3.subplots(1, 4, wspace=0.4)
    
    # Top row:
    plot_isih(axi, s, rate, cv, isis, pdf)
    plot_response_spectrum(axp, s, eodf, rate, freqs, prr)
    plot_response(axr, s, eodf, time1, stimulus1, contrast1, spikes1,
                  contrast2, spikes2)
    
    # Middle row. The run with contrast2 is shown left of the one with
    # contrast1. Both chi2 panels share the color scale maximum of 4,
    # so the single color bar attached to the right panel holds for both.
    # Gain and projection panels use fcutoff1 for both runs.
    plot_gain(axg, s, contrast1, freqs1, gain1,
              contrast2, freqs2, gain2, fcutoff1)
    pc = plot_chi2(axc1, s, contrast2, freqs2, chi22, fcutoff2, 4)
    axc1.plot([0, fcutoff2], [0, fcutoff2], zorder=20, **s.lsDiag)
    axc1.set_title(f'$c$={100*contrast2:g}\\,\\%',
                   fontsize='medium', color=s.cell_color2)
    pc = plot_chi2(axc2, s, contrast1, freqs1, chi21, fcutoff1, 4)
    axc2.set_title(f'$c$={100*contrast1:g}\\,\\%',
                   fontsize='medium', color=s.cell_color1)
    axc2.plot([0, fcutoff1], [0, fcutoff1], zorder=20, **s.lsDiag)
    plot_colorbar(axc2, pc)
    # The baseline firing rate is the frequency at which the
    # projections are evaluated:
    plot_diagonals(axd, s, rate, contrast1, freqs1, chi21,
                   contrast2, freqs2, chi22, fcutoff1)
    
    fig.common_yticks(axc1, axc2)
    fig.tag([axi, axp, axr], xoffs=-3, yoffs=-1)
    fig.tag([axg, axc1, axc2, axd], xoffs=-3, yoffs=2)

    # Bottom row: chi2 matrices of the additional cells on a common
    # color scale with maximum 1.3. This overwrites eodf, rate and cv
    # of the example cell, which are no longer needed at this point.
    # Indexing axs[k] requires that example_cells has no more entries
    # than there are panels in the bottom row.
    print('Additional example cells:')
    for k, (cell, run) in enumerate(example_cells):
        eodf, rate, cv, _, _, _, _ = load_baseline(data_path, cell)
        fcutoff, contrast, freqs, gain, chi2 = load_spectra(data_path, cell, run)
        dfreqs, diag = diag_projection(freqs, chi2, 2*fcutoff)
        nli, nlif = peakedness(dfreqs, diag, rate, median=False)
        print(f'    {cell:<22s}: run={run:2d}, fbase={rate:3.0f}Hz, CV={cv:.2f}, SI={nli:3.1f}')
        pc = plot_chi2(axs[k], s, contrast, freqs, chi2, fcutoff, 1.3)
        axs[k].set_title(f'$r={rate:.0f}$Hz, CV$_{{\\rm base}}$={cv:.2f}', fontsize='medium')
        axs[k].text(0.95, 0.9, f'SI($r$)={nli:.1f}', ha='right', zorder=50,
                    color='white', fontsize='medium',
                    transform=axs[k].transAxes)
    # pc is the mappable of the last plotted cell:
    plot_colorbar(axs[-1], pc)
    fig.common_yticks(axs)
    fig.tag(axs, xoffs=-3, yoffs=2)
    
    # NOTE(review): savefig() is called without a file name; presumably
    # the extended savefig derives it from the script name - confirm.
    fig.savefig()