import glob
import os

import matplotlib.pyplot as plt
import nixio as nix
import numpy as np
from matplotlib.collections import PatchCollection
from matplotlib.patches import ConnectionPatch
from matplotlib.patches import Rectangle

from nix_util import sort_blocks, read_baseline, get_signals
from util import despine

# Relative directories used by the plotting functions below: ``*.nix``
# recordings are looked up in ``data_folder`` and rendered figures are saved
# into ``figure_folder``. Both are resolved against the working directory.
data_folder = "data"
figure_folder = "figures"


def plot_comparisons(current_df=20):
    """Draw a sketch illustrating which stimulus conditions are compared.

    Opens the first ``*.nix`` file (in sorted order) found in ``data_folder``.
    For each condition ("no-other", "self", "other") it plots the self/other
    frequency traces of the block keyed by the first contrast returned by
    ``sort_blocks`` and ``current_df``, restricted to ``0.5 < time < 1.0``.
    Dashed boxes mark the compared time windows and arrows connect them.
    The figure is saved as ``<figure_folder>/comparisons.pdf``.

    Parameters
    ----------
    current_df : int, optional
        Difference-frequency key used to select blocks from ``block_map``
        (default 20).

    Returns
    -------
    None
        Prints a message and returns early when no ``*.nix`` file is found.
    """
    files = sorted(glob.glob(os.path.join(data_folder, "*.nix")))
    if not files:
        print("plot comparisons: no data found!")
        return

    conditions = ["no-other", "self", "other"]
    min_time = 0.5
    max_time = min_time + 0.5

    nf = nix.File.open(files[0], nix.FileMode.ReadOnly)
    fig = None
    try:
        block_map, all_contrasts, _, _ = sort_blocks(nf)

        fig = plt.figure(figsize=(6.5, 2.))
        # each plotted condition uses 3 grid columns plus a 1-column gap; size
        # the grid from the list actually plotted, not from all_conditions
        fig_grid = (3, len(conditions) * 3 + 2)
        axes = []
        for i, condition in enumerate(conditions):
            block = block_map[(all_contrasts[0], current_df, condition)]
            _, self_freq, other_freq, time = get_signals(block)
            eodfs = block.metadata["stimulus parameter"]["eodfs"]
            self_eodf = eodfs["self"]
            other_eodf = eodfs["other"]
            window = (time > min_time) & (time < max_time)

            # frequency traces, labelled with the EOD frequencies on the left
            ax = plt.subplot2grid(fig_grid, (0, i * 4), rowspan=2, colspan=3, fig=fig)
            ax.plot(time[window], self_freq[window], color="#ff7f0e", label="%iHz" % self_eodf)
            ax.text(min_time - 0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e",
                    va="center", ha="right", fontsize=9)
            if other_freq is not None:
                ax.plot(time[window], other_freq[window], color="#1f77b4",
                        label="%iHz" % other_eodf)
                ax.text(min_time - 0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4",
                        va="center", ha="right", fontsize=9)
            ax.set_ylim([735, 885])
            despine(ax, ["top", "bottom", "left", "right"], True)
            axes.append(ax)

        # dashed boxes marking the compared time windows in each panel
        # NOTE(review): the middle panel uses 0.575 where the others use 0.57 --
        # confirm whether that offset is intentional.
        _add_comparison_windows(axes[0], [0.675, 0.57])
        _add_comparison_windows(axes[1], [0.675, 0.575])
        _add_comparison_windows(axes[2], [0.57])

        # arrows connecting the compared windows across panels
        con = ConnectionPatch(xyA=(0.625, 735), xyB=(0.625, 740), coordsA="data", coordsB="data",
                              axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5,
                              connectionstyle="arc3,rad=.35")
        axes[1].add_artist(con)
        con = ConnectionPatch(xyA=(0.725, 885), xyB=(0.725, 880), coordsA="data", coordsB="data",
                              axesA=axes[0], axesB=axes[1], arrowstyle="<->", shrinkB=5,
                              connectionstyle="arc3,rad=-.25")
        axes[1].add_artist(con)
        con = ConnectionPatch(xyA=(0.725, 735), xyB=(0.625, 740), coordsA="data", coordsB="data",
                              axesA=axes[1], axesB=axes[2], arrowstyle="<->", shrinkB=5,
                              connectionstyle="arc3,rad=.35")
        axes[1].add_artist(con)

        # numbers labelling the arrows
        axes[0].text(1., 660, "2.")
        axes[1].text(1.05, 660, "3.")
        axes[0].text(1.1, 890, "1.")
        fig.subplots_adjust(bottom=0.1, top=0.8, left=0.1, right=0.9)
        os.makedirs(figure_folder, exist_ok=True)
        fig.savefig(os.path.join(figure_folder, "comparisons.pdf"))
    finally:
        # release the figure and the file even if plotting fails
        if fig is not None:
            plt.close(fig)
        nf.close()


def _add_comparison_windows(ax, starts, bottom=740, width=0.098, height=140):
    """Add dashed, semi-transparent boxes starting at the x positions *starts* to *ax*."""
    rects = [Rectangle((start, bottom), width, height) for start in starts]
    ax.add_collection(PatchCollection(rects, facecolor=None, alpha=0.15,
                                      edgecolor="k", ls="--"))


def create_response_plot(filename=None, current_df=20, figure_name=None):
    """Plot example chirp responses side by side for three stimulus conditions.

    One column is drawn per condition ("no-other", "self", "other"):
      * the self/other frequency traces,
      * the signal together with its amplitude modulation (``extract_am``),
      * one row per contrast with the mean +/- std (over axis 0) of the
        rates returned by ``get_firing_rate`` (presumably across trials).

    All panels are restricted to ``0.5 < t < 1.0``. The time tick labels show
    that window in ms relative to its start.

    Parameters
    ----------
    filename : str, optional
        ``.nix`` file to read. If ``None`` (default), the first ``*.nix`` file
        (in sorted order) in ``data_folder`` is used.
    current_df : int, optional
        Difference-frequency key used to select blocks from ``block_map``
        (default 20).
    figure_name : str, optional
        File name inside ``figure_folder``. ``.pdf`` is appended when missing.
        Defaults to ``"chirp_responses.pdf"``.

    Returns
    -------
    None
        Prints a message and returns early when no data file is found.
    """
    if filename is None:
        files = sorted(glob.glob(os.path.join(data_folder, "*.nix")))
        if not files:
            print("create response plot: no data found!")
            return
        filename = files[0]

    conditions = ["no-other", "self", "other"]
    condition_labels = ["soliloquy", "self chirping", "other chirping"]
    min_time = 0.5
    max_time = min_time + 0.5
    major_xticks = np.arange(min_time, max_time + .01, 0.250)
    minor_xticks = np.arange(min_time, max_time + .01, 0.125)
    # tick labels in ms relative to the start of the plotted window
    xtick_labels = [int(round(v)) for v in (major_xticks - min_time) * 1000]

    nf = nix.File.open(filename, nix.FileMode.ReadOnly)
    fig = None
    try:
        block_map, all_contrasts, _, _ = sort_blocks(nf)
        # largest contrast first: it selects the example block and is the top rate row
        all_contrasts = sorted(all_contrasts, reverse=True)

        fig = plt.figure(figsize=(6.5, 5.5))
        # rows: 2 frequency + 1 gap + 2 signal/AM + 1 gap, then 2 per contrast;
        # columns: 3 per plotted condition plus 1-column gaps
        fig_grid = (len(all_contrasts) * 2 + 6, len(conditions) * 3 + 2)

        for i, condition in enumerate(conditions):
            col = i * 4
            block = block_map[(all_contrasts[0], current_df, condition)]
            signal, self_freq, other_freq, time = get_signals(block)
            # NOTE(review): extract_am is neither imported nor defined in this
            # module as visible here -- confirm where it should come from.
            am = extract_am(signal)
            eodfs = block.metadata["stimulus parameter"]["eodfs"]
            self_eodf = eodfs["self"]
            other_eodf = eodfs["other"]
            window = (time > min_time) & (time < max_time)

            # frequency traces, labelled with the EOD frequencies on the left
            ax = plt.subplot2grid(fig_grid, (0, col), rowspan=2, colspan=3, fig=fig)
            ax.plot(time[window], self_freq[window], color="#ff7f0e", label="%iHz" % self_eodf)
            ax.text(min_time - 0.05, self_eodf, "%iHz" % self_eodf, color="#ff7f0e",
                    va="center", ha="right", fontsize=9)
            if other_freq is not None:
                ax.plot(time[window], other_freq[window], color="#1f77b4",
                        label="%iHz" % other_eodf)
                ax.text(min_time - 0.05, other_eodf, "%iHz" % other_eodf, color="#1f77b4",
                        va="center", ha="right", fontsize=9)
            ax.set_title(condition_labels[i])
            despine(ax, ["top", "bottom", "left", "right"], True)

            # signal and its amplitude modulation
            ax = plt.subplot2grid(fig_grid, (3, col), rowspan=2, colspan=3, fig=fig)
            ax.plot(time[window], signal[window], color="#2ca02c", label="signal")
            ax.plot(time[window], am[window], color="#d62728", label="am")
            despine(ax, ["top", "bottom", "left", "right"], True)
            ax.set_ylim([-1.25, 1.25])
            ax.legend(ncol=2, loc=(0.01, -0.5), fontsize=7, markerscale=0.5, frameon=False)

            # firing rate (mean +/- std) for each contrast
            for j, contrast in enumerate(all_contrasts):
                # NOTE(review): get_firing_rate is neither imported nor defined in
                # this module as visible here -- confirm where it should come from.
                t, rates, _ = get_firing_rate(block_map, current_df, contrast, condition)
                avg_resp = np.mean(rates, axis=0)
                error = np.std(rates, axis=0)
                rate_window = (t > min_time) & (t < max_time)
                ax = plt.subplot2grid(fig_grid, (j * 2 + 6, col), rowspan=2, colspan=3, fig=fig)
                ax.plot(t[rate_window], avg_resp[rate_window], color="k", lw=0.5)
                ax.fill_between(t[rate_window], (avg_resp - error)[rate_window],
                                (avg_resp + error)[rate_window], color="k", lw=0.0, alpha=0.25)
                ax.set_ylim([0, 750])
                ax.set_xlabel("")
                ax.set_ylabel("")
                ax.set_xticks(major_xticks)
                ax.set_xticklabels(xtick_labels)
                ax.set_xticks(minor_xticks, minor=True)
                if j < len(all_contrasts) - 1:
                    ax.set_xticklabels([])  # only the bottom row keeps time labels
                ax.set_yticks(np.arange(0.0, 751., 500))
                ax.set_yticks(np.arange(0.0, 751., 125), minor=True)
                if i > 0:
                    ax.set_yticklabels([])  # only the first column keeps y labels
                despine(ax, ["top", "right"], False)
                if i == len(conditions) - 1:
                    ax.text(max_time + 0.025 * max_time, 350, "c=%.3f" % contrast,
                            color="#d62728", ha="left", fontsize=7)

            # axis labels are attached to the bottom firing-rate panel
            if i == 1:
                ax.set_xlabel("time [ms]")
            if i == 0:
                ax.set_ylabel("frequency [Hz]", va="center")
                ax.yaxis.set_label_coords(-0.45, 3.5)

        name = figure_name if figure_name is not None else "chirp_responses.pdf"
        if not name.endswith(".pdf"):
            name += ".pdf"
        os.makedirs(figure_folder, exist_ok=True)
        fig.savefig(os.path.join(figure_folder, name))
    finally:
        # release the figure and the file even if plotting fails
        if fig is not None:
            plt.close(fig)
        nf.close()

def response_examples(filename=None, dfs=(20, -100)):
    """Create one chirp-response figure per difference frequency.

    Each figure is produced by ``create_response_plot`` and named
    ``<cell>_df_<df>Hz.pdf``, where ``<cell>`` is the base name of *filename*
    without its ``.nix`` extension.

    Parameters
    ----------
    filename : str, optional
        ``.nix`` file to plot. If ``None`` (default), the first ``*.nix`` file
        (in sorted order) in ``data_folder`` is used.
    dfs : iterable of int, optional
        Difference frequencies to plot (default ``(20, -100)``).

    Returns
    -------
    None
        Prints a message and returns early when no data file is found.
    """
    if filename is None:
        files = sorted(glob.glob(os.path.join(data_folder, "*.nix")))
        if not files:
            print("response examples: no data found!")
            return
        filename = files[0]
    cell_name = os.path.basename(filename).split(".nix")[0]
    for df in dfs:
        fig_name = "%s_df_%iHz.pdf" % (cell_name, df)
        create_response_plot(filename, current_df=df, figure_name=fig_name)
    
def main(task=None, parameter=None):
    """Run one plotting task, or every task when *task* is ``None``.

    Parameters
    ----------
    task : str, optional
        ``"comparisons"`` (runs ``plot_comparisons``) or ``"response_examples"``
        (runs ``create_response_plot``). ``None`` runs all of them.
    parameter : dict or sequence, optional
        Arguments forwarded to the task function. A dict is passed as keyword
        arguments; any other iterable is passed as positional arguments.
        Defaults to no arguments.
    """
    plot_tasks = {"comparisons": plot_comparisons,
                  "response_examples": create_response_plot}
    if parameter is None:
        parameter = {}
    if task is None:
        selected = list(plot_tasks.values())
    elif task in plot_tasks:
        selected = [plot_tasks[task]]
    else:
        print("main: unknown task %r, expected one of %s" % (task, sorted(plot_tasks)))
        return
    for plot_task in selected:
        # unpacking a dict with * would pass its keys positionally; use ** instead
        if isinstance(parameter, dict):
            plot_task(**parameter)
        else:
            plot_task(*parameter)


if __name__ == "__main__":
    # Executed as a script: render only the comparison sketch.
    main(task="comparisons")


def plot_examples(filename, dfs=None, contrasts=None, conditions=None):
    """Placeholder for producing all example figures of one recording.

    Currently a no-op that returns ``None``; all arguments are accepted but
    unused. The list-valued parameters default to ``None`` rather than shared
    mutable ``[]`` defaults.
    """
    # TODO: implement. An earlier commented-out draft called create_response_plot
    # and plot_comparisons with arguments (block_map, all_dfs, ...) that no longer
    # match their signatures, plus foreign_fish_detection / plot_detection_results,
    # which this module (as visible here) neither defines nor imports.
    return None