import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import ttest_ind, mannwhitneyu


def auc(n, dx, uniform=False, plot=False):
    """Simulate a loser and a winner group and compute separation measures.

    Draws ``n`` "loser" samples and ``n`` "winner" samples, the winners
    shifted by ``dx``. With ``uniform=True`` the samples come from
    ``np.random.rand`` (uniform on [0, 1)); otherwise from
    ``np.random.randn`` scaled by 0.3. The global numpy RNG is used, so
    results are reproducible via ``np.random.seed``.

    Parameters
    ----------
    n : int
        Number of samples drawn for each group.
    dx : float
        Offset added to the winner samples.
    uniform : bool
        Draw uniform instead of normal samples.
    plot : bool
        If true, show a figure with the data, the true/false positive
        rates versus threshold, and the ROC curve.

    Returns
    -------
    area : float
        Area under the ROC curve (trapezoidal rule), in [0, 1].
    overlap : float
        Fraction of all data points from the first point (in sorted order)
        whose label differs from the smallest value's label up to the last
        point whose label differs from the largest value's label; 0 for
        perfectly separated groups.
    cohensd : float
        Difference of group means (winners - losers) divided by the square
        root of the mean of the two (ddof=0) group variances.
    tstat : float
        Student's t statistic of winners versus losers.
    mannu : float
        Mann-Whitney U statistic as returned by ``scipy.stats.mannwhitneyu``
        for the argument order (losers, winners).
    """
    # loser and winner samples (draw order x0 then x1, as before):
    if uniform:
        x0 = np.random.rand(n)
        x1 = np.random.rand(n) + dx
    else:
        x0 = np.random.randn(n)*0.3
        x1 = np.random.randn(n)*0.3 + dx

    # combine into a single table: column 0 values, column 1 labels
    # (0 = lose, 1 = win):
    data = np.zeros((len(x0) + len(x1), 2))
    data[:len(x0), 0] = x0
    data[len(x0):, 0] = x1
    data[len(x0):, 1] = 1.0
    is_win = data[:, 1] > 0.5
    losers = data[~is_win, 0]
    winners = data[is_win, 0]

    # fraction of overlapping data values: the span between the first
    # label change seen from the left and the first seen from the right.
    si = np.argsort(data[:, 0])
    labels_sorted = data[si, 1]
    i0 = np.argmax(labels_sorted != labels_sorted[0])
    i1 = len(data) - 1 - np.argmax(labels_sorted[::-1] != labels_sorted[-1])
    overlap = (i1 - i0 + 1)/len(data)

    # Cohen's d (equal group sizes, so the plain variance mean pools them):
    m0, v0 = np.mean(losers), np.var(losers)
    m1, v1 = np.mean(winners), np.var(winners)
    cohensd = (m1 - m0)/np.sqrt(0.5*(v0 + v1))

    # t-test (winners first, so t > 0 when winners are larger):
    tstat, _ = ttest_ind(winners, losers)

    # Mann-Whitney U:
    mannu, _ = mannwhitneyu(losers, winners)

    # ROC: rates of points above each threshold, via one broadcast
    # comparison of shape (len(thresh), len(data)):
    thresh = np.arange(np.min(data[:, 0]) - 0.1, np.max(data[:, 0]) + 0.2, 0.01)
    above = data[:, 0] > thresh[:, np.newaxis]
    true_pos = np.sum(above & is_win, axis=1)/np.sum(is_win)
    false_pos = np.sum(above & ~is_win, axis=1)/np.sum(~is_win)

    # AUC: resample the ROC on a regular FP grid (np.interp needs
    # increasing xp, hence the reversal) and integrate with the
    # trapezoidal rule. A plain sum(yroc)*step counts one sample too many
    # and yields values above 1 for perfectly separated groups.
    xroc = np.linspace(0.0, 1.0, 1001)
    yroc = np.interp(xroc, false_pos[::-1], true_pos[::-1])
    area = np.sum(0.5*(yroc[1:] + yroc[:-1])*np.diff(xroc))

    if plot:
        x_i0 = data[si[i0], 0]
        x_i1 = data[si[i1], 0]
        fig = plt.figure()
        ax = fig.add_subplot(211)
        ax.axvline(x_i0, color='k')
        ax.axvline(x_i1, color='k', lw=2)
        ax.plot(data[:, 0], data[:, 1], 'o')
        ax.plot(losers, np.full(len(losers), -0.5), 'or')
        ax.plot(winners, np.full(len(winners), -0.5), 'og')
        ax.text(0.5*(x_i0 + x_i1), 0.65, 'overlap=%.0f%%' % (100.0*overlap), ha='center')
        ax.text(0.5*(x_i0 + x_i1), 0.35, "Cohen's d=%.2f" % cohensd, ha='center')
        ax.set_xlabel('x')
        ax.set_yticks([0, 1])
        ax.set_yticklabels(['Lose', 'Win'])
        if uniform:
            ax.set_title('Uniformly distributed data')
        else:
            ax.set_title('Normally distributed data')

        ax = fig.add_subplot(223)
        ax.plot(thresh, true_pos, '-og', label='TP')
        ax.plot(thresh, false_pos, '-or', label='FP')
        ax.legend()
        ax.set_xlabel('threshold')

        ax = fig.add_subplot(224)
        ax.plot(false_pos, true_pos, '-o')
        ax.fill_between(xroc, yroc)
        ax.text(0.5, 0.5, 'AUC=%.0f%%' % (100.0*area))
        ax.set_xlabel('FP')
        ax.set_ylabel('TP')
        fig.tight_layout()
        plt.show()

    return area, overlap, cohensd, tstat, mannu


if __name__ == "__main__":
    # demo: one example figure each for uniform and normal data.
    auc(20, 0.5, uniform=True, plot=True)
    auc(20, 0.5, uniform=False, plot=True)

    # AUC versus overlap: sweep the winner offset and collect the
    # (auc, overlap, cohensd, t, U) tuples for both distributions.
    # Calls alternate uniform/normal per offset, as in the original sweep.
    n = 100
    rows_uni = []
    rows_norm = []
    for frac in np.arange(-1.5, 1.5, 0.02):
        rows_uni.append(auc(n, frac, uniform=True, plot=False))
        rows_norm.append(auc(n, frac, uniform=False, plot=False))
    # transpose the rows into one list per measure:
    aucs_uni, overlaps_uni, cohensd_uni, ttest_uni, mannu_uni = map(list, zip(*rows_uni))
    aucs_norm, overlaps_norm, cohensd_norm, ttest_norm, mannu_norm = map(list, zip(*rows_norm))

    # summary figure: AUC against each of the other measures.
    fig, axs = plt.subplots(2, 2)
    panels = [
        (axs[0, 0], overlaps_uni, overlaps_norm, 'fraction of overlapping data'),
        (axs[0, 1], cohensd_uni, cohensd_norm, "Cohen's d"),
        (axs[1, 1], ttest_uni, ttest_norm, "Student t"),
        (axs[1, 0], mannu_uni, mannu_norm, "Mann-Whitney U"),
    ]
    # guide line from (0, 0) through (1, 0.5) back to (0, 1):
    axs[0, 0].plot([0.0, 1.0, 0.0], [0.0, 0.5, 1.0], 'k')
    for ax, xs_uni, xs_norm, xlabel in panels:
        ax.plot(xs_uni, aucs_uni, 'o', label='uniform pdfs')
        ax.plot(xs_norm, aucs_norm, 'o', label='normal pdfs')
        ax.set_ylim(0, 1)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('AUC')
    axs[0, 0].legend(loc='center left')
    fig.savefig('aucoverlap.pdf')
    plt.show()