import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

xkcd_style = False

# default size of figure:
figure_width = 15.0   # cm, should be set according to \textwidth in the latex document
figure_height = 6.0   # cm, for a 1 x 2 figure

# points per inch:
ppi = 72.0

# colors:
colors = {
    'red': '#CC0000',
    'orange': '#FF9900',
    'lightorange': '#FFCC00',
    'yellow': '#FFFF66',
    'green': '#99FF00',
    'blue': '#0000CC',
}


def cm_size(*args):
    """ Convert dimensions from cm to inch.

    Use this function to set the size of a figure in centimeter:
    ```
    fig = plt.figure(figsize=cm_size(16.0, 10.0))
    ```

    Parameters
    ----------
    args: one or many float
        Size in centimeter.

    Returns
    -------
    inches: float or list of floats
        Input arguments converted to inch. A single argument yields a
        single float, otherwise a list with one value per argument.
    """
    cm_per_inch = 2.54
    if len(args) == 1:
        return args[0]/cm_per_inch
    return [v/cm_per_inch for v in args]


def _margins(width, height, fontsize, left, right, bottom, top):
    """ Convert margins given in character units into subplot parameters.

    Shared by `adjust_fs()` and the module initialization so that both
    use the same conversion.

    Parameters
    ----------
    width: float
        Figure width in points.
    height: float
        Figure height in points.
    fontsize: float
        Font size in points.
    left, right, bottom, top: float
        Margins in multiples of the character width (left, right) or
        character height (bottom, top), see `adjust_fs()`.

    Returns
    -------
    params: dict
        'left', 'right', 'bottom', and 'top' as fractions of the figure
        size, suitable for `Figure.subplots_adjust()`.
    """
    # the width of a character is approximated as 60% of the font size
    return {
        'left': left*0.6*fontsize/width,
        'right': 1.0 - right*0.6*fontsize/width,
        'bottom': bottom*fontsize/height,
        'top': 1.0 - top*fontsize/height,
    }


def adjust_fs(fig=None, left=5.5, right=0.5, bottom=2.8, top=0.5):
    """ Compute plot margins from multiples of the current font size.

    Parameters
    ----------
    fig: matplotlib.figure or None
        The figure from which the figure size is taken. If None use the current figure.
    left: float
        the left margin of the plots given in multiples of the width of a character
        (in fact, simply 60% of the current font size).
    right: float
        the right margin of the plots given in multiples of the width of a character
        (in fact, simply 60% of the current font size).
        *Note:* in contrast to the matplotlib `right` parameters, this specifies the
        width of the right margin, not its position relative to the origin.
    bottom: float
        the bottom margin of the plots given in multiples of the height of a character
        (the current font size).
    top: float
        the top margin of the plots given in multiples of the height of a character
        (the current font size).
        *Note:* in contrast to the matplotlib `top` parameters, this specifies the
        height of the top margin, not its position relative to the origin.

    Returns
    -------
    params: dict
        'left', 'right', 'bottom', and 'top' keyword arguments for
        `Figure.subplots_adjust()`.

    Example
    -------
    ```
    fig, axs = plt.subplots(2, 2, figsize=(10, 5))
    fig.subplots_adjust(**adjust_fs(fig, left=4.5))   # no matter what the figsize is!
    ```
    """
    if fig is None:
        fig = plt.gcf()
    w, h = fig.get_size_inches()*ppi
    return _margins(w, h, plt.rcParams['font.size'], left, right, bottom, top)


def show_spines(ax, spines):
    """ Show and hide spines.

    Parameters
    ----------
    ax: matplotlib figure, matplotlib axis, or list or array of matplotlib axes
        Axis whose spines and ticks are manipulated.
        If figure, then apply manipulations on all axes of the figure.
        If list, tuple, or numpy array of axes (e.g. as returned by
        `plt.subplots()`), apply manipulations on each of the given axes.
    spines: string
        Specify which spines and ticks should be shown. All other ones are hidden.
        'l' is the left spine, 'r' the right spine, 't' the top one and 'b' the bottom one.
        E.g. 'lb' shows the left and bottom spine, and hides the top and right spines,
        as well as their tick marks and labels.
        '' shows no spines at all.
        'lrtb' shows all spines and tick marks.
    """
    # collect spine visibility (order matters when exactly one is shown):
    xspines = [name for code, name in (('t', 'top'), ('b', 'bottom')) if code in spines]
    yspines = [name for code, name in (('l', 'left'), ('r', 'right')) if code in spines]
    shown = xspines + yspines
    # collect axes:
    if isinstance(ax, (list, tuple)):
        axs = list(ax)
    elif hasattr(ax, 'ravel'):
        # numpy array of axes, e.g. from plt.subplots(n, m)
        axs = list(ax.ravel())
    elif isinstance(ax, mpl.axes.Axes):
        # a single axes; Axes.get_axes() no longer exists in recent matplotlib
        axs = [ax]
    else:
        # a figure: manipulate all of its axes
        axs = ax.get_axes()
        if not isinstance(axs, (list, tuple)):
            axs = [axs]
    for a in axs:
        # hide spines:
        for side in ('top', 'bottom', 'left', 'right'):
            if side not in shown:
                a.spines[side].set_visible(False)
        # ticks:
        if not xspines:
            a.xaxis.set_ticks_position('none')
            a.set_xticks([])
        elif len(xspines) == 1:
            a.xaxis.set_ticks_position(xspines[0])
        else:
            a.xaxis.set_ticks_position('both')
        if not yspines:
            a.yaxis.set_ticks_position('none')
            a.set_yticks([])
        elif len(yspines) == 1:
            a.yaxis.set_ticks_position(yspines[0])
        else:
            a.yaxis.set_ticks_position('both')


def axis_label(label, unit=None):
    """ Format an axis label from a label and a unit.

    Parameters
    ----------
    label: string
        The name of the axis.
    unit: string
        The unit of the axis values.

    Returns
    -------
    label: string
        An axis label formatted from `label` and `unit`:
        just `label` if no unit is given, 'label / unit' in xkcd style,
        and 'label [unit]' otherwise.
    """
    if not unit:
        return label
    elif xkcd_style:
        return '%s / %s' % (label, unit)
    else:
        return '%s [%s]' % (label, unit)


def set_xlabel(ax, label, unit=None, **kwargs):
    """ Format the xlabel from a label and a unit.

    Uses the axis_label() function to format the axis label.
    Installed as `matplotlib.axes.Axes.set_xlabel`.

    Parameters
    ----------
    ax: matplotlib axes
        The axes whose xlabel is set.
    label: string
        The name of the axis.
    unit: string
        The unit of the axis values.
    kwargs: key-word arguments
        Further arguments passed on to the original set_xlabel() function.
    """
    ax.set_xlabel_orig(axis_label(label, unit), **kwargs)


def set_ylabel(ax, label, unit=None, **kwargs):
    """ Format the ylabel from a label and a unit.

    Uses the axis_label() function to format the axis label.
    Installed as `matplotlib.axes.Axes.set_ylabel`.

    Parameters
    ----------
    ax: matplotlib axes
        The axes whose ylabel is set.
    label: string
        The name of the axis.
    unit: string
        The unit of the axis values.
    kwargs: key-word arguments
        Further arguments passed on to the original set_ylabel() function.
    """
    ax.set_ylabel_orig(axis_label(label, unit), **kwargs)


def set_zlabel(ax, label, unit=None, **kwargs):
    """ Format the zlabel from a label and a unit.

    Uses the axis_label() function to format the axis label.
    Installed as `Axes3D.set_zlabel`.

    Parameters
    ----------
    ax: Axes3D
        The 3D axes whose zlabel is set.
    label: string
        The name of the axis.
    unit: string
        The unit of the axis values.
    kwargs: key-word arguments
        Further arguments passed on to the original set_zlabel() function.
    """
    ax.set_zlabel_orig(axis_label(label, unit), **kwargs)


def _override(cls, name, wrapper):
    """ Replace method `name` of `cls` by `wrapper`.

    The original method stays accessible as `name + '_orig'`. It is stored
    only once, so re-executing this module (e.g. via importlib.reload())
    does not turn the wrapper into its own "original" and recurse forever.
    """
    orig_name = name + '_orig'
    if not hasattr(cls, orig_name):
        setattr(cls, orig_name, getattr(cls, name))
    setattr(cls, name, wrapper)


# overwrite set_[xyz]label member functions:
_override(mpl.axes.Axes, 'set_xlabel', set_xlabel)
_override(mpl.axes.Axes, 'set_ylabel', set_ylabel)
_override(Axes3D, 'set_zlabel', set_zlabel)

# initialization:
# NOTE(review): the branch membership of the statements below was
# reconstructed from a whitespace-mangled original; confirm that
# legend.fontsize belongs to the non-xkcd branch only.
if xkcd_style:
    plt.xkcd()   # sets rcParams globally (also usable as context manager)
    bar_fac = 0.9
    mpl.rcParams['xtick.major.size'] = 6
    mpl.rcParams['ytick.major.size'] = 6
else:
    bar_fac = 1.0
    mpl.rcParams['font.family'] = 'sans-serif'
    mpl.rcParams['xtick.labelsize'] = 'small'
    mpl.rcParams['ytick.labelsize'] = 'small'
    mpl.rcParams['xtick.major.size'] = 2.5
    mpl.rcParams['ytick.major.size'] = 2.5
    mpl.rcParams['legend.fontsize'] = 'x-small'
mpl.rcParams['figure.figsize'] = cm_size(figure_width, figure_height)
# default subplot margins, same defaults as adjust_fs(); computed after
# the style branch because plt.xkcd() may change the font size:
mpl.rcParams.update({
    'figure.subplot.' + key: value
    for key, value in _margins(cm_size(figure_width)*ppi,
                               cm_size(figure_height)*ppi,
                               mpl.rcParams['font.size'],
                               left=5.5, right=0.5, bottom=2.8, top=0.5).items()
})
mpl.rcParams['figure.subplot.wspace'] = 0.4
mpl.rcParams['figure.subplot.hspace'] = 0.6
mpl.rcParams['figure.facecolor'] = 'white'
# newer matplotlib only:
#mpl.rcParams['axes.spines.left'] = True
#mpl.rcParams['axes.spines.bottom'] = True
#mpl.rcParams['axes.spines.top'] = False
#mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['xtick.direction'] = 'out'
mpl.rcParams['ytick.direction'] = 'out'
mpl.rcParams['xtick.major.width'] = 1.25
mpl.rcParams['ytick.major.width'] = 1.25