import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from scipy.stats import pearsonr, gaussian_kde
from scipy.stats import mannwhitneyu
from thunderlab.tabledata import TableData
from plotstyle import plot_style, lighter, significance_str


# Directory holding the CSV tables, relative to the current working directory:
data_path = Path('.', 'data')

from noisesplit import model_cell as model_split_example
from modelsusceptcontrasts import model_cells as model_contrast_examples
from modelsusceptlown import model_cell as model_lown_example
from punitexamplecell import example_cell as punit_example
from punitexamplecell import example_cells as punit_examples
from noisesplit import example_cell as punit_split_example
from ampullaryexamplecell import example_cell as ampul_example
from ampullaryexamplecell import example_cells as ampul_examples

# Highlighted example points for the scatter plots, each given as the triple
# (example, split_example, examples) that is unpacked into plot_corr().
# Simulations are listed as [cell, contrast] pairs:
model_examples = (
    [[model_lown_example, alpha] for alpha in [0.01, 0.03, 0.1]],
    [[model_split_example, 0.01]],
    [[cell, alpha]
     for cell in model_contrast_examples
     for alpha in [0.01, 0.03, 0.1]],
)
# Recordings; plot_corr() matches their [cell, run] entries against the
# 'cell' and (if present) 'stimindex' columns. Note that the imported
# `punit_examples` and `ampul_examples` lists are rebound to triples here.
punit_examples = punit_example, [punit_split_example], punit_examples
ampul_examples = ampul_example, [], ampul_examples


def plot_corr(ax, data, xcol, ycol, zcol, zmin, zmax, xpdfmax, cmap, color,
              si_thresh, example=None, split_example=None, examples=None):
    """Scatter `ycol` against `xcol` colored by `zcol`, with marginal pdfs.

    The limits of `ax` must be set before calling this function: points
    of the plain scatter beyond the current x- and y-limits are dropped,
    and the marginal pdfs are evaluated from zero up to these limits.
    Highlighted examples are excluded from the plain scatter and drawn on
    top of it with black edges (triangle for `example`, square for
    `split_example`, circle for `examples`). Also adds a color bar, the
    Pearson correlation coefficient, and, for baseline-CV x-axes, the
    percentages of rows above and below `si_thresh` and the number of rows.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to plot into. Its x- and y-limits must already be set.
    data: TableData
        Table with columns `xcol`, `ycol`, `zcol`, and 'cell'. Rows are
        identified by ('cell', 'stimindex') if a 'stimindex' column
        exists (recordings), otherwise by ('cell', 'contrast')
        (simulations).
    xcol, ycol, zcol: str
        Columns plotted on the x-axis, the y-axis, and as point color.
    zmin, zmax: float
        Range of the color map.
    xpdfmax: float
        Upper limit of the pdf drawn above the x-axis.
    cmap: str
        Name of the color map.
    color: matplotlib color
        Fill color of the marginal pdfs.
    si_thresh: float
        Threshold marked by a dotted horizontal line.
    example, split_example, examples: list of [cell, run] or None
        Highlighted points, `run` being the stimulus index for
        recordings or the contrast for simulations.

    Returns
    -------
    cax: matplotlib axes
        Inset axes holding the color bar.
    """
    # highlighted examples together with their markers, drawn in this order
    # (None instead of a mutable [] default):
    highlights = [(marker, [] if cells is None else list(cells))
                  for marker, cells in (('^', example),
                                        ('s', split_example),
                                        ('o', examples))]
    # recordings are identified by stimulus index, simulations by contrast:
    runcol = 'stimindex' if 'stimindex' in data else 'contrast'

    def select(cell, run):
        # boolean mask of the row(s) of `cell` at stimulus index/contrast `run`
        return (data['cell'] == cell) & (data[runcol] == run)

    ax.axhline(si_thresh, color='k', ls=':', lw=0.5)
    xmax = ax.get_xlim()[1]
    ymax = ax.get_ylim()[1]
    # plain scatter of all points within the axes limits except the examples:
    mask = (data[xcol] < xmax) & (data[ycol] < ymax)
    for _, cells in highlights:
        for cell, run in cells:
            mask &= ~select(cell, run)
    sc = ax.scatter(data[mask, xcol], data[mask, ycol], c=data[mask, zcol],
                    s=4, marker='o', linewidth=0, edgecolors='none',
                    clip_on=False, cmap=cmap, vmin=zmin, vmax=zmax, zorder=20)
    # examples on top, slightly larger and with black edges:
    elw = 0.3
    for marker, cells in highlights:
        for cell, run in cells:
            emask = select(cell, run)
            ax.scatter(data[emask, xcol], data[emask, ycol],
                       c=data[emask, zcol], s=6, marker=marker,
                       linewidth=elw, edgecolors='black', clip_on=False,
                       cmap=cmap, vmin=zmin, vmax=zmax, zorder=20)
    # color bar:
    fig = ax.get_figure()
    cax = ax.inset_axes([1.3, 0, 0.04, 1])
    cax.set_spines_outward('lrbt', 0)
    cb = fig.colorbar(sc, cax=cax)
    cb.outline.set_color('none')
    cb.outline.set_linewidth(0)
    # pdf of the x-values above the axes. The scalar bandwidth factor
    # scales the data's standard deviation (ddof=1), so the Gaussian kernel
    # has a standard deviation of 2% of the x-axis' upper limit:
    kde = gaussian_kde(data[xcol], 0.02*xmax/np.std(data[xcol], ddof=1))
    xx = np.linspace(0, ax.get_xlim()[1], 400)
    xax = ax.inset_axes([0, 1.05, 1, 0.2])
    xax.show_spines('')
    xax.fill_between(xx, kde(xx), facecolor=color, edgecolor='none')
    xax.set_ylim(0, xpdfmax)
    # pdf of the y-values right of the axes, same relative bandwidth:
    kde = gaussian_kde(data[ycol], 0.02*ymax/np.std(data[ycol], ddof=1))
    yy = np.linspace(0, ax.get_ylim()[1], 400)
    yax = ax.inset_axes([1.05, 0, 0.2, 1])
    yax.show_spines('')
    yax.fill_betweenx(yy, kde(yy), facecolor=color, edgecolor='none')
    yax.set_xlim(left=0)
    # percentages of all rows above and below the threshold:
    if 'cvbase' in xcol:
        ax.text(xmax, 0.4*ymax, f'{100*np.sum(data[ycol] > si_thresh)/len(data):.0f}\\%',
                ha='right', va='bottom', fontsize='small')
        ax.text(xmax, 0.3, f'{100*np.sum(data[ycol] < si_thresh)/len(data):.0f}\\%',
                ha='right', va='center', fontsize='small')
    # statistics:
    r, p = pearsonr(data[xcol], data[ycol])
    # NOTE(review): the '**' is hard-coded and not derived from `p`;
    # confirm it agrees with the p-value printed below.
    ax.text(1, 0.9, f'$R={r:.2f}$ **', ha='right',
            transform=ax.transAxes, fontsize='small')
    if 'cvbase' in xcol:
        ax.text(1, 0.77, f'$n={data.rows()}$', ha='right',
                transform=ax.transAxes, fontsize='small')
    print(f'    correlation {xcol:<8s} - {ycol}: r={r:5.2f}, p={p:.2g}')
    return cax


def si_stats(title, data, sicol, si_thresh, nsegscol):
    """Print summary statistics of a table of SI values.

    Prints, below `title`, the number of cells and recordings, how many
    of them have an SI above `si_thresh`, and the distributions of the
    number of segments, of recordings per cell, of the cutoff frequencies,
    and of the contrasts.

    Parameters
    ----------
    title: str
        Headline printed first.
    data: TableData
        One row per recording (or simulation) with the columns 'cell',
        `sicol`, `nsegscol`, 'fcutoff' (printed in Hz), and 'contrast'
        (printed multiplied by 100 as percent).
    sicol: str
        Name of the column holding the SI values.
    si_thresh: float
        Rows with an SI above this threshold count as high SI.
    nsegscol: str
        Name of the column holding the number of segments.
    """
    print(title)
    # unique cells and how many rows (recordings) each of them contributes:
    cells, recs_per_cell = np.unique(data['cell'], return_counts=True)
    ncells = len(cells)
    nrecs = len(data)
    print(f'    cells:              {ncells}')
    print(f'    recordings:         {nrecs}')
    print(f'    SI threshold:       {si_thresh:.1f}')
    # index the column like everywhere else in this file (was `data(sicol)`):
    high_si = data[sicol] > si_thresh
    nhigh = np.sum(high_si)
    hcells = np.unique(data[high_si, 'cell'])
    print(f'    high SI cells:      n={len(hcells):3d}, {100*len(hcells)/ncells:4.1f}%')
    print(f'    high SI recordings: n={nhigh:3d}, {100*nhigh/nrecs:4.1f}%')
    nsegs = data[nsegscol]
    print(f'    number of segments: {np.min(nsegs):4.0f} - {np.max(nsegs):4.0f}, median={np.median(nsegs):4.0f}, mean={np.mean(nsegs):4.0f}, std={np.std(nsegs):4.0f}')
    print(f'    number of recordings per cell: {np.min(recs_per_cell):4.0f} - {np.max(recs_per_cell):4.0f}, median={np.median(recs_per_cell):4.0f}, mean={np.mean(recs_per_cell):4.0f}, std={np.std(recs_per_cell):4.0f}')
    fcutoff = data['fcutoff']
    freqs = np.unique(fcutoff)
    print('    cutoff frequencies:', ' '.join([f'{f:3.0f}Hz' for f in freqs]))
    print('    cutoff frequencies:', '   '.join([f'{np.sum(fcutoff == f):3d}' for f in freqs]))
    print(f'    cutoff frequencies: {np.min(fcutoff):.0f}Hz - {np.max(fcutoff):.0f}Hz, median={np.median(fcutoff):.0f}Hz, mean={np.mean(fcutoff):.0f}Hz, std={np.std(fcutoff):.0f}Hz')
    contrasts = 100*data['contrast']
    print('    contrasts:         ', ' '.join([f'{c:.2g}%' for c in np.unique(contrasts)]))
    print(f'    contrasts:          {np.min(contrasts):.2g}% - {np.max(contrasts):.2g}%, median={np.median(contrasts):.2g}%, mean={np.mean(contrasts):.2g}%, std={np.std(contrasts):.2g}%')

    
def plot_cvbase_si_punit(ax, data, ycol, si_thresh, color):
    """Scatter the SI against the baseline CV of P-units.

    Points are colored by response modulation. The highlighted examples
    are `punit_examples` if `data` has a 'stimindex' column (recordings)
    and `model_examples` otherwise (simulations).
    """
    ax.set_ylabel('SI($r$)')
    ax.set_xlabel('CV$_{\\rm base}$')
    ax.set_xlim(0, 1.5)
    ax.set_ylim(0, 7.2)
    ax.set_yticks_delta(2)
    if 'stimindex' in data:
        example, split_example, others = punit_examples
    else:
        example, split_example, others = model_examples
    cax = plot_corr(ax, data, 'cvbase', ycol, 'respmod2', 0, 250, 3,
                    'coolwarm', color, si_thresh, example=example,
                    split_example=split_example, examples=others)
    cax.set_ylabel('Response mod.', 'Hz')

    
def plot_cvstim_si_punit(ax, data, ycol, si_thresh, color):
    """Scatter the SI against the CV during stimulation of P-units.

    Points are colored by baseline CV. The highlighted examples are
    `punit_examples` if `data` has a 'stimindex' column (recordings) and
    `model_examples` otherwise (simulations).
    """
    ax.set_ylabel('SI($r$)')
    ax.set_xlabel('CV$_{\\rm stim}$')
    ax.set_xlim(0, 1.6)
    ax.set_ylim(0, 7.2)
    ax.set_xticks_delta(0.5)
    ax.set_yticks_delta(2)
    if 'stimindex' in data:
        example, split_example, others = punit_examples
    else:
        example, split_example, others = model_examples
    # Alternative color codings (zcol, zmin, zmax, xpdfmax -> color bar label):
    #   'respmod2', 0, 250, 2      -> 'Response mod.', 'Hz'
    #   'ratebase', 50, 450, 2     -> '$r$', 'Hz'
    #   'serialcorr1', -0.6, 0, 2  -> '$\\rho_1$'
    cax = plot_corr(ax, data, 'cvstim', ycol, 'cvbase', 0, 1.5, 2,
                    'coolwarm', color, si_thresh, example=example,
                    split_example=split_example, examples=others)
    cax.set_ylabel('CV$_{\\rm base}$')

    
def plot_rmod_si_punit(ax, data, ycol, si_thresh, color):
    """Scatter the SI against the response modulation of P-units.

    Points are colored by baseline CV. The highlighted examples are
    `punit_examples` if `data` has a 'stimindex' column (recordings) and
    `model_examples` otherwise (simulations).
    """
    ax.set_ylabel('SI($r$)')
    ax.set_xlabel('Response modulation', 'Hz')
    ax.set_xlim(0, 250)
    ax.set_ylim(0, 7.2)
    ax.set_yticks_delta(2)
    if 'stimindex' in data:
        example, split_example, others = punit_examples
    else:
        example, split_example, others = model_examples
    cax = plot_corr(ax, data, 'respmod2', ycol, 'cvbase', 0, 1.5, 0.016,
                    'coolwarm', color, si_thresh, example=example,
                    split_example=split_example, examples=others)
    cax.set_ylabel('CV$_{\\rm base}$')

    
def plot_cvbase_si_ampul(ax, data, ycol, si_thresh, color):
    """Scatter the SI against the baseline CV of ampullary cells.

    Points are colored by response modulation; `ampul_examples` are
    highlighted.
    """
    example, split_example, others = ampul_examples
    ax.set_ylabel('SI($r$)')
    ax.set_xlabel('CV$_{\\rm base}$')
    ax.set_xlim(0, 0.2)
    ax.set_ylim(0, 15)
    ax.set_xticks_delta(0.1)
    ax.set_yticks_delta(5)
    cax = plot_corr(ax, data, 'cvbase', ycol, 'respmod2', 0, 80, 20,
                    'coolwarm', color, si_thresh, example=example,
                    split_example=split_example, examples=others)
    cax.set_ylabel('Response mod.', 'Hz')

    
def plot_cvstim_si_ampul(ax, data, ycol, si_thresh, color):
    """Scatter the SI against the CV during stimulation of ampullary cells.

    Points are colored by baseline CV; `ampul_examples` are highlighted.
    """
    example, split_example, others = ampul_examples
    ax.set_ylabel('SI($r$)')
    ax.set_xlabel('CV$_{\\rm stim}$')
    ax.set_xlim(0, 0.85)
    ax.set_ylim(0, 15)
    ax.set_xticks_delta(0.2)
    ax.set_yticks_delta(5)
    # Alternative color codings (zcol, zmin, zmax, xpdfmax -> color bar label):
    #   'respmod2', 0, 80, 6   -> 'Response mod.', 'Hz'
    #   'ratebase', 90, 180, 6 -> '$r$', 'Hz' (color bar ticks every 30)
    cax = plot_corr(ax, data, 'cvstim', ycol, 'cvbase', 0, 0.2, 6,
                    'coolwarm', color, si_thresh, example=example,
                    split_example=split_example, examples=others)
    cax.set_ylabel('CV$_{\\rm base}$')
    cax.set_yticks_delta(0.1)

    
def plot_rmod_si_ampul(ax, data, ycol, si_thresh, color):
    """Scatter the SI against the response modulation of ampullary cells.

    Points are colored by baseline CV; `ampul_examples` are highlighted.
    """
    example, split_example, others = ampul_examples
    ax.set_ylabel('SI($r$)')
    ax.set_xlabel('Response modulation', 'Hz')
    ax.set_xlim(0, 80)
    ax.set_ylim(0, 15)
    ax.set_xticks_delta(20)
    ax.set_yticks_delta(5)
    cax = plot_corr(ax, data, 'respmod2', ycol, 'cvbase', 0, 0.2, 0.06,
                    'coolwarm', color, si_thresh, example=example,
                    split_example=split_example, examples=others)
    cax.set_ylabel('CV$_{\\rm base}$')
    cax.set_yticks_delta(0.1)

    
if __name__ == '__main__':

    def load_table(file_name):
        """Read a semicolon-separated table from `data_path`."""
        return TableData(data_path / file_name, sep=';')

    def compare(title, model_values, data_values):
        """Run a Mann-Whitney U test and print `title` and its result."""
        u, p = mannwhitneyu(model_values, data_values)
        print(title)
        print(f'    U={u:g}, p={p:.2g}')

    punit_model = load_table('Apteronotus_leptorhynchus-Punit-models.csv')
    print(punit_model.keys())
    # keep only simulations with a non-zero stimulus contrast:
    punit_model = punit_model[punit_model['contrast'] > 1e-6, :]
    punit_data = load_table('Apteronotus_leptorhynchus-Punit-data.csv')
    ampul_data = load_table('Apteronotus_leptorhynchus-Ampullary-data.csv')
    # suffix selecting the SI and segment-count columns of the recordings
    # (the alternative is '', i.e. the columns without suffix):
    si = '_nmax'
    si_thresh = 1.8
    sicol = 'sinorm' + si
    nsegscol = 'nsegs' + si

    # compare models against recorded P-units:
    cvmodel = punit_model['cvbase']
    cvdata = punit_data['cvbase']
    compare('CV differs between P-unit models and data:', cvmodel, cvdata)
    print(f'    CV model: min={np.min(cvmodel):4.2f} max={np.max(cvmodel):4.2f} median={np.median(cvmodel):4.2f}')
    print(f'    CV data:  min={np.min(cvdata):4.2f} max={np.max(cvdata):.2f} median={np.median(cvdata):4.2f}')
    print()

    rmmodel = punit_model['respmod2']
    rmdata = punit_data['respmod2']
    compare('Response modulation differs between P-unit models and data:',
            rmmodel, rmdata)
    print(f'    response modulation model: min={np.min(rmmodel):3.0f}Hz max={np.max(rmmodel):3.0f}Hz median={np.median(rmmodel):3.0f}Hz')
    print(f'    response modulation data:  min={np.min(rmdata):3.0f}Hz max={np.max(rmdata):3.0f}Hz median={np.median(rmdata):3.0f}Hz')
    print()

    simodel = punit_model['dsinorm100']
    sidata = punit_data[sicol]
    compare('SI does not differ between P-unit models and data:',
            simodel, sidata)
    print(f'    SI model: min={np.min(simodel):4.1f} max={np.max(simodel):4.1f} median={np.median(simodel):4.1f}')
    print(f'    SI data:  min={np.min(sidata):4.1f} max={np.max(sidata):4.1f} median={np.median(sidata):4.1f}')
    print()

    s = plot_style()
    fig, axs = plt.subplots(3, 3, cmsize=(s.plot_width, 0.75*s.plot_width),
                            height_ratios=[1, 0, 1, 0.3, 1])
    fig.subplots_adjust(leftm=6.5, rightm=13.5, topm=4.5, bottomm=4,
                        wspace=1.1, hspace=0.6)

    # one row of panels per data set, columns: CV_base, response
    # modulation, CV_stim:
    punit_plots = (plot_cvbase_si_punit, plot_rmod_si_punit,
                   plot_cvstim_si_punit)
    ampul_plots = (plot_cvbase_si_ampul, plot_rmod_si_ampul,
                   plot_cvstim_si_ampul)
    panel_rows = (
        ('P-unit model:', 'P-unit models', punit_model, 'dsinorm100',
         'nsegs100', s.model_color1, s.model_color2, punit_plots),
        ('P-unit data:', 'P-unit data', punit_data, sicol, nsegscol,
         s.punit_color1, s.punit_color2, punit_plots),
        ('Ampullary data:', 'Ampullary data', ampul_data, sicol, nsegscol,
         s.ampul_color1, s.ampul_color2, ampul_plots),
    )
    for row, (title, label, table, col_si, col_nsegs, label_color,
              fill_color, plotters) in enumerate(panel_rows):
        si_stats(title, table, col_si, si_thresh, col_nsegs)
        first_ax = axs[row, 0]
        first_ax.text(0, 1.35, label, transform=first_ax.transAxes,
                      color=label_color)
        for col, plot_panel in enumerate(plotters):
            plot_panel(axs[row, col], table, col_si, si_thresh, fill_color)
        print()

    for col in range(3):
        fig.common_xticks(axs[:2, col])
    for row in range(3):
        fig.common_yticks(axs[row, :])
    fig.tag(axs, xoffs=-3.5, yoffs=2)
    fig.savefig()