import matplotlib as mpl
import plottools.plottools as pt
from plottools.spines import spines_params
from plottools.labels import labels_params
from plottools.colors import lighter, darker


def significance_str(p):
    """Format a p-value as a LaTeX math string for annotating plots.

    Non-significant values (``p >= 0.05``) are printed with two decimals.
    Significant values are reported against the thresholds 0.05, 0.01
    and 0.001, choosing the smallest threshold that ``p`` is strictly
    below, so the emitted inequality is always a true statement.

    Parameters
    ----------
    p: float
        The p-value.

    Returns
    -------
    s: str
        LaTeX string such as ``'$p=0.23$'`` or ``'$p<0.01$'``.
        A NaN p-value is printed numerically (``'$p=nan$'``) instead
        of being reported as significant.
    """
    import math

    # NaN compares False with everything; without this check it would
    # fall through all thresholds and be labeled as p<0.001.
    if math.isnan(p) or p >= 0.05:
        return f'$p={p:.2f}$'
    if p >= 0.01:
        return '$p<0.05$'
    if p >= 0.001:
        return '$p<0.01$'
    return '$p<0.001$'


def plot_style():
    """Build the plotting styles and apply the global matplotlib settings.

    A namespace class is filled with a few line widths and colors and is
    handed to the plottools style generators (combined line/point/fill
    styles for the categories A1-A3, B1-B4, C1-C4, plain line styles,
    point styles and arrow styles). Some style dictionaries (the ``lsF*``,
    ``psF*``, ``lsStim``, ``lsRaster``, ``lsPower`` entries) are assigned
    directly. Afterwards matplotlib's rc parameters are reset to their
    defaults and configured (axes, colormap 'YR', color cycle, figure
    output, labels, legend, scale bars, spines, tags, LaTeX text, ticks).

    Returns
    -------
    ns: class
        Namespace whose attributes hold the styles and constants.
    """
    palette = pt.palettes['muted']
    lwthick, lwmid, lwthin = 1.6, 1.0, 0.5
    lwspines = 1.0

    # One row per category: (name, palette key, dash, (marker, factor)).
    categories = [
        ('A1', 'red', '-', ('o', 1.0)),
        ('A2', 'orange', '-', ('p', 1.1)),
        ('A3', 'yellow', '-', ('h', 1.1)),
        ('B1', 'blue', '-', ((3, 1, 60), 1.25)),
        ('B2', 'purple', '-', ((3, 1, 0), 1.25)),
        ('B3', 'magenta', '-', ((3, 1, 90), 1.25)),
        ('B4', 'lightblue', '-.', ((3, 1, 30), 1.25)),
        ('C1', 'lightgreen', '-', ('s', 0.9)),
        ('C2', 'green', '-', ('D', 0.85)),
        ('C3', 'darkgreen', '-', ('*', 1.6)),
        ('C4', 'cyan', '-', ((4, 1, 45), 1.4)),
    ]
    names, color_keys, dashes, markers = map(list, zip(*categories))
    colors = [palette[key] for key in color_keys]

    class ns:
        pass
    ns.colors = palette
    ns.lwthick = lwthick
    ns.lwmid = lwmid
    ns.lwthin = lwthin
    ns.plot_width = 16.5
    pt.make_linepointfill_styles(ns, names, colors, dashes, markers,
                                 lwthick=lwthick, lwthin=lwthin,
                                 markerlarge=5, markersmall=4,
                                 mec=0.0, mew=0.5, fillalpha=0.4)

    # Plain line styles with prefix 'ls':
    # (name, palette key, dash, line width, extra keyword arguments).
    line_table = (
        ('Spine', 'black', '-', lwspines, dict(clip_on=False)),
        ('Grid', 'gray', '--', 0.7*lwthin, {}),
        ('Dotted', 'gray', ':', 0.7*lwthin, {}),
        ('Marker', 'black', '-', lwthick, dict(clip_on=False)),
        ('Line', 'black', '-', lwthin, {}),
        ('EOD', 'gray', '-', lwthin, {}),
        ('AM', 'red', '-', lwmid, {}),
        ('AMsplit', 'orange', '-', lwmid, {}),
        ('Noise', 'gray', '-', lwmid, {}),
        ('Median', 'black', '-', lwthick, {}),
        ('Max', 'black', '-', lwmid, {}),
    )
    for style_name, ckey, dash, width, extra in line_table:
        pt.make_line_styles(ns, 'ls', style_name, '', palette[ckey], dash,
                            width, **extra)

    # Directly assigned style dictionaries using matplotlib named colors.
    ns.lsStim = dict(color='gray', lw=ns.lwmid)
    ns.lsRaster = dict(color='black', lw=ns.lwthin)
    ns.lsPower = dict(color='gray', lw=ns.lwmid)
    f_colors = (('F0', 'blue'), ('F01', 'green'), ('F02', 'purple'),
                ('F012', 'orange'), ('F01_2', 'red'))
    thin_colors = {
        'F0': lighter('blue', 0.5),
        'F01': lighter('green', 0.6),
        'F02': lighter('purple', 0.5),
        'F012': darker('orange', 0.9),
        'F01_2': darker('red', 0.9),
    }
    for key, cname in f_colors:
        setattr(ns, f'ls{key}', dict(color=cname, lw=ns.lwthick))
        setattr(ns, f'ls{key}m', dict(color=thin_colors[key], lw=ns.lwthin))
    for key, cname in (('FEOD', 'black'),) + f_colors:
        setattr(ns, f'ps{key}', dict(color=cname, marker='o',
                                     linestyle='none', markersize=5,
                                     mec='none', mew=0))

    # Base color plus a 0.6-lighter shade for each of the three groups.
    groups = (('M%d', 'model', 'purple'), ('P%d', 'punit', 'blue'),
              ('A%d', 'ampul', 'green'))
    for _, group, ckey in groups:
        setattr(ns, f'{group}_color1', palette[ckey])
        setattr(ns, f'{group}_color2', lighter(palette[ckey], 0.6))
    for pattern, group, _ in groups:
        first = getattr(ns, f'{group}_color1')
        second = getattr(ns, f'{group}_color2')
        pt.make_line_styles(ns, 'ls', pattern, '', [first, second], '-',
                            lwmid)
        pt.make_point_styles(ns, 'ps', pattern, '', [first, second], '-',
                             markers=('o', 1), markersizes=4)
    pt.make_line_styles(ns, 'ls', 'Diag', '', palette['white'], '--', lwthin)

    black = palette['black']
    pt.arrow_style(ns, 'Line', dist=3.0, style='>', shrink=0, lw=0.6,
                   color=black, head_length=4, head_width=4,
                   bbox=dict(boxstyle='round,pad=0.1', facecolor='white',
                             edgecolor='none', alpha=1.0))
    pt.arrow_style(ns, 'Hertz', dist=3.0, style='>', shrink=0, lw=0.6,
                   color=black, head_length=3, head_width=3,
                   heads='<>', text='%.0f\u2009Hz', rotation='vertical',
                   fontsize='x-small',
                   bbox=dict(boxstyle='round,pad=0.1', facecolor='white',
                             edgecolor='none', alpha=0.6))
    pt.arrow_style(ns, 'Point', dist=3.0, style='>>', shrink=0,
                   lw=1.1, color=black, head_length=5, head_width=4)
    pt.arrow_style(ns, 'PointSmall', dist=3.0, style='>>', shrink=0,
                   lw=1, color=black, head_length=5, head_width=5,
                   fontsize='small')
    pt.arrow_style(ns, 'Marker', dist=3.0, style='>>', shrink=0,
                   lw=0.9, color=black, head_length=5, head_width=5,
                   fontsize='small', ha='center', va='center',
                   bbox=dict(boxstyle='round, pad=0.1, rounding_size=0.4',
                             facecolor=palette['white'],
                             edgecolor='none', alpha=0.6))

    # Global rc settings: reset to matplotlib defaults first, then configure.
    mpl.rcdefaults()
    pt.axes_params(0.0, 0.0, 0.0, color='none')
    yr_colors = [pt.lighter(palette['yellow'], 0.2),
                 pt.lighter(palette['orange'], 0.5), palette['orange'],
                 palette['red'], palette['red']]
    pt.colormap('YR', yr_colors, [0.0, 0.3, 0.7, 0.95, 1.0])
    pt.colors_params(palette,
                     ['blue', 'red', 'orange', 'lightgreen', 'magenta',
                      'yellow', 'cyan', 'pink'],
                     'viridis')
    pt.figure_params(palette['white'], format='pdf', compression=6,
                     fonttype=3, stripfonts=True)
    pt.labels_params('{label} [{unit}]', labelsize='medium', labelpad=6)
    pt.legend_params(fontsize='medium', frameon=False, borderpad=0.0,
                     handlelength=1.5, handletextpad=1,
                     numpoints=1, scatterpoints=1,
                     labelspacing=0.5, columnspacing=0.5)
    pt.scalebars_params(format_large='%.0f', format_small='%.1f',
                        lw=2.2, color=black, capsize=0, clw=0.5,
                        font=dict(fontsize='medium', fontstyle='normal'))
    pt.spines_params(spines='lb', spines_offsets={'lrtb': 3},
                     spines_bounds={'lrtb': 'full'})
    pt.tag_params(xoffs='auto', yoffs='auto', label='%A',
                  minor_label=r'%A$_{\text{%mi}}$',
                  font=dict(fontsize='x-large', fontstyle='normal',
                            fontweight='normal'))
    pt.text_params(font_size=7, font_family='sans-serif', latex=True,
                   preamble=('p:sfmath', 'p:marvosym', 'p:xfrac',
                             'p:SIunits'))
    pt.ticks_params(xtick_dir='out', xtick_size=3)
    return ns