"""
Common plotting style: colors, line styles, rc parameters, and axes helpers.

Importing this module calls `plot_style()` at the very end. That modifies
matplotlib's global rcParams and monkey-patches matplotlib classes:
- the axes constructor, so new axes show only the left and bottom spines;
- `Axes.set_xlabel()`, `Axes.set_ylabel()` and `Axes3D.set_zlabel()`, so
  they accept an optional unit;
- adds an `Axes.show_spines()` method.
"""

import matplotlib as mpl
import matplotlib.pyplot as plt
from cycler import cycler
from mpl_toolkits.mplot3d import Axes3D


# Select the xkcd sketch style (True) or the plain style (False).
# Read by axis_label() and plot_style().
xkcd_style = False

# default size of figure:
figure_width = 15.0   # cm, should be set according to \textwidth in the latex document
figure_height = 6.0   # cm, for a 1 x 2 figure

# points per inch:
ppi = 72.0


# colors:

def _hex_to_rgb(color):
    """ Split a '#rrggbb' string into its integer red, green and blue channels.
    """
    return int(color[1:3], 16), int(color[3:5], 16), int(color[5:7], 16)


def _rgb_to_hex(r, g, b):
    """ Format (possibly fractional) RGB channels as a '#RRGGBB' string.

    Each channel is rounded to the nearest integer and clipped to 0..255,
    because '%X' formatting only accepts integers in Python 3.
    """
    def channel(value):
        return min(max(int(round(value)), 0), 0xff)
    return '#%02X%02X%02X' % (channel(r), channel(g), channel(b))


def lighter(color, lightness):
    """ Make a color lighter.

    Parameters
    ----------
    color: string
        An RGB color as a hexadecimal string (e.g. '#rrggbb').
    lightness: float
        The smaller the lightness, the lighter the returned color.
        A lightness of 1 leaves the color untouched.
        A lightness of 0 returns white.

    Returns
    -------
    color: string
        The lighter color as a hexadecimal RGB string (e.g. '#rrggbb').
        Channels are rounded to the nearest integer and clipped to 0..255.
    """
    # move each channel the fraction (1 - lightness) of the way towards 0xff:
    mixed = [c + (1.0 - lightness)*(0xff - c) for c in _hex_to_rgb(color)]
    return _rgb_to_hex(*mixed)


def darker(color, saturation):
    """ Make a color darker.

    Parameters
    ----------
    color: string
        An RGB color as a hexadecimal string (e.g. '#rrggbb').
    saturation: float
        The smaller the saturation, the darker the returned color.
        A saturation of 1 leaves the color untouched.
        A saturation of 0 returns black.

    Returns
    -------
    color: string
        The darker color as a hexadecimal RGB string (e.g. '#rrggbb').
        Channels are rounded to the nearest integer and clipped to 0..255.
    """
    # scale each channel towards 0:
    scaled = [c*saturation for c in _hex_to_rgb(color)]
    return _rgb_to_hex(*scaled)


# Muted colors used by the Benda-lab.
colors_bendalab = {
    'red': '#C02010',
    'orange': '#F78010',
    'yellow': '#F0D730',
    'green': '#A0B717',
    'cyan': '#40A787',
    'blue': '#2757A0',
    'purple': '#573790',
    'pink': '#C72750',
    'grey': '#A0A0A0',
    'black': '#000000',
}

# Vivid colors used by the Benda-lab.
colors_bendalab_vivid = {
    'red': '#D71000',
    'orange': '#FF9000',
    'yellow': '#FFF700',
    'green': '#30D700',
    'cyan': '#00F0B0',
    'blue': '#0020C0',
    'purple': '#B000B0',
    'pink': '#F00080',
    'grey': '#A7A7A7',
    'black': '#000000',
}

# colors for the plots of the script. This is a copy, so the adjustments
# below do not modify the colors_bendalab_vivid palette itself:
colors = dict(colors_bendalab_vivid)
colors['lightorange'] = colors['yellow']
# fixed lighter yellow (used instead of lighter(colors['yellow'], 0.65)):
colors['yellow'] = '#FFFF55'

# line styles for plot():
lsSpine = {'c': colors['black'], 'linestyle': '-', 'linewidth': 1}
lsGrid = {'c': colors['grey'], 'linestyle': '--', 'linewidth': 1}

# 'B1': prominent line with first color and style from color group 'B'
# 'C2m': minor line with second color and style from color group 'C'
ls = {
    'A1': {'c': colors['red'], 'linestyle': '-', 'linewidth': 3},
    'A2': {'c': colors['orange'], 'linestyle': '-', 'linewidth': 3},
    'A3': {'c': colors['lightorange'], 'linestyle': '-', 'linewidth': 3},
    'B1': {'c': colors['orange'], 'linestyle': '-', 'linewidth': 3},
    'B2': {'c': colors['lightorange'], 'linestyle': '-', 'linewidth': 3},
    'B3': {'c': colors['yellow'], 'linestyle': '-', 'linewidth': 3},
    'C1': {'c': colors['green'], 'linestyle': '-', 'linewidth': 3},
    'D1': {'c': colors['blue'], 'linestyle': '-', 'linewidth': 3},
    'A1m': {'c': colors['red'], 'linestyle': '-', 'linewidth': 2},
    'A2m': {'c': colors['orange'], 'linestyle': '-', 'linewidth': 2},
    'A3m': {'c': colors['lightorange'], 'linestyle': '-', 'linewidth': 2},
    'B1m': {'c': colors['orange'], 'linestyle': '-', 'linewidth': 2},
    'B2m': {'c': colors['lightorange'], 'linestyle': '-', 'linewidth': 2},
    'B3m': {'c': colors['yellow'], 'linestyle': '-', 'linewidth': 2},
    'C1m': {'c': colors['green'], 'linestyle': '-', 'linewidth': 2},
    'D1m': {'c': colors['blue'], 'linestyle': '-', 'linewidth': 2},
}

# factor for scaling widths of bars in a bar plot
# (set by sketch_style() and plain_style()):
bar_fac = 1.0


def cm_size(*args):
    """ Convert dimensions from cm to inch.

    Use this function to set the size of a figure in centimeter:
    ```
    fig = plt.figure(figsize=cm_size(16.0, 10.0))
    ```

    Parameters
    ----------
    args: one or many float
        Size in centimeter.

    Returns
    -------
    inches: float or list of floats
        Input arguments converted to inch.
        A single argument returns a float, several arguments a list.
    """
    cm_per_inch = 2.54
    if len(args) == 1:
        return args[0]/cm_per_inch
    return [v/cm_per_inch for v in args]


def _margins(width, height, left, right, bottom, top):
    """ Convert margins given in font-size units into subplot parameters.

    `width` and `height` are the figure size in inches. Horizontal margins
    are in multiples of 60% of the current font size, vertical margins in
    multiples of the font size. Returns a dict with the keys 'left',
    'right', 'bottom', 'top' as expected by `Figure.subplots_adjust()`.
    """
    w = width*ppi
    h = height*ppi
    fs = mpl.rcParams['font.size']
    return {
        'left': left*0.6*fs/w,
        'right': 1.0 - right*0.6*fs/w,
        'bottom': bottom*fs/h,
        'top': 1.0 - top*fs/h,
    }


def adjust_fs(fig=None, left=5.5, right=0.5, bottom=2.8, top=0.5):
    """ Compute plot margins from multiples of the current font size.

    Parameters
    ----------
    fig: matplotlib.figure or None
        The figure from which the figure size is taken.
        If None use the current figure.
    left: float
        the left margin of the plots given in multiples of the width of a
        character (in fact, simply 60% of the current font size).
    right: float
        the right margin of the plots given in multiples of the width of a
        character (in fact, simply 60% of the current font size).
        *Note:* in contrast to the matplotlib `right` parameters, this
        specifies the width of the right margin, not its position relative
        to the origin.
    bottom: float
        the bottom margin of the plots given in multiples of the height of a
        character (the current font size).
    top: float
        the top margin of the plots given in multiples of the height of a
        character (the current font size).
        *Note:* in contrast to the matplotlib `top` parameters, this
        specifies the height of the top margin, not its position relative
        to the origin.

    Returns
    -------
    margins: dict
        'left', 'right', 'bottom', 'top' values for `subplots_adjust()`.

    Example
    -------
    ```
    fig, axs = plt.subplots(2, 2, figsize=(10, 5))
    fig.subplots_adjust(**adjust_fs(fig, left=4.5))   # no matter what the figsize is!
    ```
    """
    if fig is None:
        fig = plt.gcf()
    width, height = fig.get_size_inches()
    return _margins(width, height, left, right, bottom, top)


def show_spines(ax, spines='lb'):
    """ Show and hide spines.

    Parameters
    ----------
    ax: matplotlib figure, matplotlib axis, or list/array of matplotlib axes
        Axis whose spines and ticks are manipulated.
        If figure, then apply manipulations on all axes of the figure.
        If list, tuple, or array (e.g. as returned by `plt.subplots()`)
        of axes, apply manipulations on each of the given axes.
    spines: string
        Specify which spines and ticks should be shown.
        All other ones are hidden.
        'l' is the left spine, 'r' the right spine,
        't' the top one and 'b' the bottom one.
        E.g. 'lb' shows the left and bottom spine, and hides the top
        and right spines, as well as their tick marks and labels.
        '' shows no spines at all.
        'lrtb' shows all spines and tick marks.
        Spines an axes does not have (e.g. on polar axes) are skipped.
    """
    # collect spine visibility (order matters for the ticks below):
    xspines = [name for code, name in (('t', 'top'), ('b', 'bottom'))
               if code in spines]
    yspines = [name for code, name in (('l', 'left'), ('r', 'right'))
               if code in spines]
    shown = xspines + yspines
    # collect axes:
    if isinstance(ax, (list, tuple)):
        axs = ax
    elif hasattr(ax, 'flat'):
        # numpy array of axes as returned by plt.subplots():
        axs = list(ax.flat)
    elif isinstance(ax, mpl.axes.Axes):
        # Axes no longer provide get_axes() in current matplotlib:
        axs = [ax]
    else:
        axs = ax.get_axes()
        if not isinstance(axs, (list, tuple)):
            axs = [axs]
    for a in axs:
        # show requested spines, hide all others:
        for side in ('top', 'bottom', 'left', 'right'):
            if side in a.spines:
                a.spines[side].set_visible(side in shown)
        # ticks:
        if not xspines:
            a.xaxis.set_ticks_position('none')
            a.set_xticks([])
        elif len(xspines) == 1:
            a.xaxis.set_ticks_position(xspines[0])
        else:
            a.xaxis.set_ticks_position('both')
        if not yspines:
            a.yaxis.set_ticks_position('none')
            a.set_yticks([])
        elif len(yspines) == 1:
            a.yaxis.set_ticks_position(yspines[0])
        else:
            a.yaxis.set_ticks_position('both')


def __axes__init__(ax, *args, **kwargs):
    """ Set some default formatting for a new Axes instance.

    Installed as the axes constructor by common_format(). Calls the saved
    original constructor and then shows only the left and bottom spines.
    """
    ax.__init__orig(*args, **kwargs)
    ax.show_spines('lb')


def axis_label(label, unit=None):
    """ Format an axis label from a label and a unit

    Parameters
    ----------
    label: string
        The name of the axis.
    unit: string
        The unit of the axis values.

    Returns
    -------
    label: string
        An axis label formatted from `label` and `unit`:
        'label / unit' if `xkcd_style` is set, 'label [unit]' otherwise,
        and just `label` if no unit is given.
    """
    if not unit:
        return label
    elif xkcd_style:
        return '%s / %s' % (label, unit)
    else:
        return '%s [%s]' % (label, unit)


def set_xlabel(ax, label, unit=None, **kwargs):
    """ Format the xlabel from a label and an unit.

    Uses the axis_label() function to format the axis label.
    Installed as `Axes.set_xlabel()` by common_format().

    Parameters
    ----------
    label: string
        The name of the axis.
    unit: string
        The unit of the axis values.
    kwargs: key-word arguments
        Further arguments passed on to the original set_xlabel() function.
    """
    ax.set_xlabel_orig(axis_label(label, unit), **kwargs)


def set_ylabel(ax, label, unit=None, **kwargs):
    """ Format the ylabel from a label and an unit.

    Uses the axis_label() function to format the axis label.
    Installed as `Axes.set_ylabel()` by common_format().

    Parameters
    ----------
    label: string
        The name of the axis.
    unit: string
        The unit of the axis values.
    kwargs: key-word arguments
        Further arguments passed on to the original set_ylabel() function.
    """
    ax.set_ylabel_orig(axis_label(label, unit), **kwargs)


def set_zlabel(ax, label, unit=None, **kwargs):
    """ Format the zlabel from a label and an unit.

    Uses the axis_label() function to format the axis label.
    Installed as `Axes3D.set_zlabel()` by common_format().

    Parameters
    ----------
    label: string
        The name of the axis.
    unit: string
        The unit of the axis values.
    kwargs: key-word arguments
        Further arguments passed on to the original set_zlabel() function.
    """
    ax.set_zlabel_orig(axis_label(label, unit), **kwargs)


def common_format():
    """ Set some rc parameter and patch the matplotlib axes classes.

    Sets the default figure size and margins from `figure_width`,
    `figure_height` and the current font size, as well as tick, grid,
    spine, and color-cycle parameters. The axes constructor and the
    set_[xyz]label() methods are replaced only once; the originals are
    kept as `__init__orig` and `set_[xyz]label_orig`.
    """
    mpl.rcParams['figure.figsize'] = cm_size(figure_width, figure_height)
    margins = _margins(cm_size(figure_width), cm_size(figure_height),
                       left=5.5, right=0.5, bottom=2.8, top=0.5)
    for key, value in margins.items():
        mpl.rcParams['figure.subplot.' + key] = value
    mpl.rcParams['figure.subplot.wspace'] = 0.4
    mpl.rcParams['figure.subplot.hspace'] = 0.6
    mpl.rcParams['figure.facecolor'] = 'white'
    mpl.rcParams['xtick.direction'] = 'out'
    mpl.rcParams['ytick.direction'] = 'out'
    mpl.rcParams['xtick.major.width'] = 1.25
    mpl.rcParams['ytick.major.width'] = 1.25
    mpl.rcParams['grid.color'] = lsGrid['c']
    mpl.rcParams['grid.linestyle'] = lsGrid['linestyle']
    mpl.rcParams['grid.linewidth'] = lsGrid['linewidth']
    mpl.rcParams['axes.facecolor'] = 'none'
    mpl.rcParams['axes.edgecolor'] = lsSpine['c']
    mpl.rcParams['axes.linewidth'] = lsSpine['linewidth']
    cycle_colors = [colors[name] for name in
                    ('blue', 'red', 'orange', 'green',
                     'purple', 'yellow', 'cyan', 'pink')]
    if 'axes.prop_cycle' in mpl.rcParams:
        mpl.rcParams['axes.prop_cycle'] = cycler(color=cycle_colors)
    else:
        # older matplotlib versions without 'axes.prop_cycle':
        mpl.rcParams['axes.color_cycle'] = cycle_colors
    # overwrite axes constructor (only once, keep the original):
    if not hasattr(mpl.axes.Subplot, '__init__orig'):
        mpl.axes.Subplot.__init__orig = mpl.axes.Subplot.__init__
        mpl.axes.Subplot.__init__ = __axes__init__
    mpl.axes.Axes.show_spines = show_spines
    # overwrite axes set_[xyz]label() member functions:
    if not hasattr(mpl.axes.Axes, 'set_xlabel_orig'):
        mpl.axes.Axes.set_xlabel_orig = mpl.axes.Axes.set_xlabel
        mpl.axes.Axes.set_xlabel = set_xlabel
    if not hasattr(mpl.axes.Axes, 'set_ylabel_orig'):
        mpl.axes.Axes.set_ylabel_orig = mpl.axes.Axes.set_ylabel
        mpl.axes.Axes.set_ylabel = set_ylabel
    if not hasattr(Axes3D, 'set_zlabel_orig'):
        Axes3D.set_zlabel_orig = Axes3D.set_zlabel
        Axes3D.set_zlabel = set_zlabel


def sketch_style():
    """ Activate xkcd style and adapt some rc parameter.

    Sets `bar_fac` to 0.9. `plt.xkcd()` is called outside a `with` block,
    so its rc changes stay in effect. Does not change `xkcd_style`.
    """
    global bar_fac
    bar_fac = 0.9
    plt.xkcd()
    common_format()
    mpl.rcParams['legend.fontsize'] = 'medium'
    mpl.rcParams['xtick.labelsize'] = 'medium'
    mpl.rcParams['ytick.labelsize'] = 'medium'
    mpl.rcParams['xtick.major.size'] = 6
    mpl.rcParams['ytick.major.size'] = 6


def plain_style():
    """ Deactivate xkcd style and adapt some rc parameter.

    Resets all rc parameters to matplotlib's defaults first and sets
    `bar_fac` to 1.0. Does not change `xkcd_style`.
    """
    global bar_fac
    bar_fac = 1.0
    plt.rcdefaults()
    common_format()
    mpl.rcParams['font.family'] = 'sans-serif'
    mpl.rcParams['legend.fontsize'] = 'x-small'
    mpl.rcParams['xtick.labelsize'] = 'small'
    mpl.rcParams['ytick.labelsize'] = 'small'
    mpl.rcParams['xtick.major.size'] = 2.5
    mpl.rcParams['ytick.major.size'] = 2.5


def plot_style():
    """ Set rc parameter in dependence on xkcd_style.
    """
    if xkcd_style:
        sketch_style()
    else:
        plain_style()


# automatic initialization:
plot_style()