# pylint: disable=blacklisted-name
from os.path import join, dirname
import numpy as np
from ..log import get_logger
logger = get_logger()
from .. import const
from ..os import makedirs, open_file
class Plot:
    """Matplotlib-backed plotter that renders bar, 3D scatter, and line plots
    to image files.

    Appearance options (legend, title, axis labels, limits, ticks, and grid)
    are fixed at construction time and applied to every plot the instance
    produces.
    """

    def __init__(
            self,
            figsize=None,
            legend_labels=None, legend_fontsize=None, legend_loc='best',
            figtitle=None, figtitle_fontsize=None,
            axis_labels=None, axis_label_fontsizes=None,
            axis_lim=None,
            axis_ticks=None, axis_tick_labels=None,
            axis_tick_label_fontsizes=None, axis_tick_label_rotations=None,
            grid=True):
        """Plotter.

        Args:
            figsize (tuple, optional): Width and height of the figure in inches.
            legend_labels (list, optional): Legend labels.
            legend_fontsize (int, optional): Legend font size.
            legend_loc (str, optional): Legend location: ``'best'``,
                ``'upper right'``, ``'lower left'``, ``'right'``,
                ``'center left'``, ``'lower center'``, ``'upper center'``,
                ``'center'``, etc. Effective only when ``legend_labels`` is not
                ``None``.
            figtitle (str, optional): Figure title.
            figtitle_fontsize (int, optional): Font size.
            axis_labels (dict, optional): Axis labels with ``'x'``, ``'y'``,
                and/or ``'z'`` as keys.
            axis_label_fontsizes (dict, optional): Axis label font sizes with
                ``'x'``, ``'y'``, and/or ``'z'`` as keys.
            axis_lim (dict, optional): Mapping ``'x'``, ``'y'``, or ``'z'`` to
                an ``array_like`` of axis min. and max.
            axis_ticks (dict, optional): Axis tick locations, mapping ``'x'``,
                ``'y'``, or ``'z'`` to an ``array_like`` of floats.
            axis_tick_labels (dict, optional): Axis tick labels, mapping
                ``'x'``, ``'y'``, or ``'z'`` to a list of strings.
            axis_tick_label_fontsizes (dict, optional): Axis tick label font
                sizes, mapping ``'x'``, ``'y'``, or ``'z'`` to a float.
            axis_tick_label_rotations (dict, optional): Axis tick label
                rotations in degrees, mapping ``'x'``, ``'y'``, or ``'z'``
                to a float.
            grid (bool, optional): Whether to draw grid.
        """
        # Select the Agg backend before pyplot is imported, then keep a handle
        # to pyplot on the instance for the other methods to use
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
        #
        self.plt = plt
        self.legend_labels = legend_labels
        self.legend_fontsize = legend_fontsize
        self.legend_loc = legend_loc
        self.figsize = figsize
        self.figtitle = figtitle
        self.figtitle_fontsize = figtitle_fontsize
        # Per-axis options are normalized so that 'x', 'y', and 'z' are always
        # present as keys (None meaning "not set")
        self.axis_labels = self._init_axis_dict(axis_labels)
        self.axis_label_fontsizes = self._init_axis_dict(axis_label_fontsizes)
        self.axis_lim = self._init_axis_dict(axis_lim)
        self.axis_ticks = self._init_axis_dict(axis_ticks)
        self.axis_tick_labels = self._init_axis_dict(axis_tick_labels)
        self.axis_tick_label_fontsizes = self._init_axis_dict(
            axis_tick_label_fontsizes)
        self.axis_tick_label_rotations = self._init_axis_dict(
            axis_tick_label_rotations)
        self.grid = grid

    @staticmethod
    def _init_axis_dict(overrides):
        """Returns a new dict with keys ``'x'``, ``'y'``, and ``'z'`` all
        defaulting to ``None``, updated with ``overrides`` (if not ``None``).
        The input dict is not modified.
        """
        axis_dict = {'x': None, 'y': None, 'z': None}
        if overrides is not None:
            axis_dict.update(overrides)
        return axis_dict

    def _create_fig(self):
        """Creates a new figure, honoring ``self.figsize`` when it is set."""
        if self.figsize is None:
            return self.plt.figure()
        return self.plt.figure(figsize=self.figsize)

    def _savefig(self, outpath, contents_only=False, dpi=None):
        """Saves the current figure to ``outpath``, creating its directory
        first.

        With ``contents_only``, the current axes are stretched over the whole
        figure and the axis decorations are turned off; otherwise, the figure
        is saved with a tight bounding box.
        """
        # Make directory, if necessary
        makedirs(dirname(outpath))
        #
        savefig_kwargs = {'dpi': dpi}
        if contents_only:
            ax = self.plt.gca()
            ax.set_position([0, 0, 1, 1])
            ax.set_axis_off()
        else:
            savefig_kwargs['bbox_inches'] = 'tight'
        with open_file(outpath, 'wb') as h:
            self.plt.savefig(h, **savefig_kwargs)

    def _add_legend(self, plot_objs):
        """Labels each plot object with its legend label and draws the legend.

        No-op when ``self.legend_labels`` is ``None``.

        Args:
            plot_objs (list): Plot objects, one per legend label.

        Raises:
            ValueError: If the numbers of labels and plot objects differ.
        """
        if self.legend_labels is None:
            return
        # An explicit raise (rather than an assert) so that the check is
        # still performed when Python runs with -O
        if len(self.legend_labels) != len(plot_objs):
            raise ValueError(
                "Number of legend labels must equal number of plot objects; "
                "use None for object without a legend label")
        for plot_obj, label in zip(plot_objs, self.legend_labels):
            plot_obj.set_label(label)
        legend_kwargs = {'loc': self.legend_loc}
        if self.legend_fontsize is not None:
            legend_kwargs['fontsize'] = self.legend_fontsize
        self.plt.legend(**legend_kwargs)

    def _add_axis_labels(self, ax):
        """Applies every axis label that is set, with its font size."""
        for axis, label in self.axis_labels.items():
            if label is None:
                continue
            # Resolves to ax.set_xlabel, ax.set_ylabel, or ax.set_zlabel
            set_func = getattr(ax, f'set_{axis}label')
            set_func(label, fontsize=self.axis_label_fontsizes[axis])

    def _set_axis_ticks(self, ax):
        """Applies tick locations, labels, label font sizes, and label
        rotations for every axis on which they are set.
        """
        # Tick locations
        for axis, ticks in self.axis_ticks.items():
            if ticks is not None:
                getattr(ax, f'set_{axis}ticks')(ticks)
        # Tick labels
        for axis, tick_labels in self.axis_tick_labels.items():
            if tick_labels is not None:
                getattr(ax, f'set_{axis}ticklabels')(tick_labels)
        # Tick label font size
        for axis, fontsize in self.axis_tick_label_fontsizes.items():
            if fontsize is not None:
                ax.tick_params(axis=axis, labelsize=fontsize)
        # Tick rotation
        for axis, rotation in self.axis_tick_label_rotations.items():
            if rotation is not None:
                ax.tick_params(axis=axis, labelrotation=rotation)

    def _set_axis_lim(self, ax):
        """Applies the (min., max.) limits for every axis on which they are
        set.
        """
        for axis, lim in self.axis_lim.items():
            if lim is None:
                continue
            # Resolves to ax.set_xlim, ax.set_ylim, or ax.set_zlim
            getattr(ax, f'set_{axis}lim')(*lim)

    @staticmethod
    def _set_axes_equal(ax, xyz):
        """Forces equal scale on all three axes of a 3D plot.

        Args:
            ax: 3D axes to draw on.
            xyz (numpy.ndarray): N-by-3 array of the points being plotted.
        """
        # plt.axis('equal') not working, hence the hack of creating a cubic
        # bounding box
        pts = xyz[:, :3]
        lo = pts.min(axis=0)
        hi = pts.max(axis=0)
        half_range = 0.5 * (hi - lo).max()
        centers = 0.5 * (hi + lo)
        # 3-by-8 array holding the (+/-1, +/-1, +/-1) corners of a cube
        corners = np.mgrid[-1:2:2, -1:2:2, -1:2:2].reshape(3, -1)
        bounds = half_range * corners + centers[:, None]
        # Plot the eight corners in white, so that the axes stretch to a cube
        for xb, yb, zb in bounds.T:
            ax.plot([xb], [yb], [zb], 'w')

    def _set_title(self, ax):
        """Sets the figure title on ``ax``, if a title was given."""
        if self.figtitle is None:
            return
        if self.figtitle_fontsize is None:
            ax.set_title(self.figtitle)
        else:
            ax.set_title(self.figtitle, fontsize=self.figtitle_fontsize)

    def bar(self, y, group_width=0.8, outpath=None):
        """Bar plot.

        Args:
            y (array_like): N-by-M array of N groups, each with M bars,
                or N-array of N groups, each with one bar.
            group_width (float, optional): Width allocated to each group,
                shared by all bars within the group.
            outpath (str, optional): Path to which the plot is saved. ``None``
                means a temporary file in ``const.Dir.tmp``.

        Returns:
            str: Path to the plot written.

        Writes
            - The bar plot.
        """
        if outpath is None:
            outpath = join(const.Dir.tmp, 'bar.png')
        # Ensure y is 2D, with columns representing values within groups
        # and rows across groups
        y = np.asarray(y)
        if y.ndim == 1:
            y = np.reshape(y, (-1, 1))
        n_groups, n_bars = y.shape
        # Group width is shared by all bars in the group
        bar_width = group_width / n_bars
        # Assume x is evenly spaced
        x = np.arange(n_groups)
        fig = self._create_fig()
        try:
            ax = fig.add_subplot(111)
            self._set_title(ax)
            # Plot the i-th bar of every group in one call, so that it
            # becomes one plot object that takes one legend label
            plot_objs = []
            for i in range(n_bars):
                x_ = x - 0.5 * group_width + 0.5 * bar_width + i * bar_width
                plot_objs.append(ax.bar(x_, y[:, i], bar_width))
            #
            self._add_legend(plot_objs)
            self.plt.grid(self.grid)
            self._add_axis_labels(ax)
            self._set_axis_ticks(ax)
            self._set_axis_lim(ax)
            self._savefig(outpath)
        finally:
            # Release the figure even if plotting or saving failed
            self.plt.close('all')
        return outpath

    def scatter3d(
            self, xyz, colors=None, size=None, equal_axes=False, views=None,
            outpath=None):
        """3D scatter plot.

        Args:
            xyz (array_like): N-by-3 array of N points.
            colors (array_like or list(str) or str, optional): If N-array, these
                values are colormapped. If N-list, its elements should be color
                strings. If a single color string, all points use that color.
            size (int, optional): Scatter size.
            equal_axes (bool, optional): Whether to have the same scale for all
                axes.
            views (list(tuple), optional): List of elevation-azimuth angle pairs
                (in degrees). A good set of views is ``[(30, 0), (30, 45),
                (30, 90), (30, 135)]``.
            outpath (str, optional): Path to which the plot is saved. ``None``
                means a temporary file in ``const.Dir.tmp``. Must end with
                ``'.png'``.

        Returns:
            list(str): Path(s) to the plot(s) written: just ``outpath`` when
            ``views`` is ``None``; otherwise, one path per view, with the
            view angles inserted before the ``'.png'`` extension.

        Raises:
            ValueError: If ``outpath`` does not end with ``'.png'``.

        Writes
            - One or multiple (if ``views`` is provided) views of the 3D plot.
        """
        if outpath is None:
            outpath = join(const.Dir.tmp, 'scatter3d.png')
        # Validate before any plotting, so that a bad path neither wastes
        # work nor leaves a figure open
        if not outpath.endswith('.png'):
            raise ValueError("`outpath` must end with '.png'")
        # Unused by name; imported for its side effect (presumably
        # registering the '3d' projection on older matplotlib versions)
        from mpl_toolkits.mplot3d import Axes3D # noqa; pylint: disable=unused-variable
        #
        xyz = np.asarray(xyz)
        # Prepare kwargs to scatter()
        kwargs = {}
        need_colorbar = False
        if isinstance(colors, np.ndarray):
            kwargs['c'] = colors # will be colormapped with color map
            kwargs['cmap'] = 'viridis'
            need_colorbar = True
        elif colors is not None:
            kwargs['c'] = colors
        if size is not None:
            kwargs['s'] = size
        fig = self._create_fig()
        try:
            ax = fig.add_subplot(111, projection='3d')
            self._set_title(ax)
            # Plot
            plot_obj = ax.scatter(xyz[:, 0], xyz[:, 1], xyz[:, 2], **kwargs)
            # scatter() returns a single collection, but _add_legend()
            # expects a list of plot objects
            self._add_legend([plot_obj])
            self.plt.grid(self.grid)
            self._add_axis_labels(ax)
            self._set_axis_ticks(ax)
            self._set_axis_lim(ax)
            if equal_axes:
                self._set_axes_equal(ax, xyz)
            if need_colorbar:
                self.plt.colorbar(plot_obj)
                # FIXME: This seems to mess up equal axes
            # Save plot
            outpaths = []
            if views is None:
                self._savefig(outpath)
                outpaths.append(outpath)
            else:
                stem = outpath[:-len('.png')]
                for elev, azim in views:
                    ax.view_init(elev, azim)
                    self.plt.draw()
                    outpath_ = stem + \
                        '_elev%03d_azim%03d.png' % (elev, azim)
                    self._savefig(outpath_)
                    outpaths.append(outpath_)
        finally:
            # Release the figure even if plotting or saving failed
            self.plt.close('all')
        return outpaths

    def line(self, xy, width=None, marker=None, marker_size=None, outpath=None):
        """Line/curve plot.

        Args:
            xy (array_like): N-by-M array of N x-values (first column) and
                their corresponding y-values (the remaining M-1 columns).
            width (float, optional): Line width.
            marker (str, optional): Marker.
            marker_size (float, optional): Marker size.
            outpath (str, optional): Path to which the plot is saved. ``None``
                means a temporary file in ``const.Dir.tmp``.

        Returns:
            str: Path to the plot written.

        Writes
            - The line plot.
        """
        if outpath is None:
            outpath = join(const.Dir.tmp, 'line.png')
        xy = np.asarray(xy)
        n_lines = xy.shape[1] - 1
        # Prepare kwargs to plot(); all lines share the same style options
        kwargs = {}
        if width is not None:
            kwargs['linewidth'] = width
        if marker is not None:
            kwargs['marker'] = marker
        if marker_size is not None:
            kwargs['markersize'] = marker_size
        fig = self._create_fig()
        try:
            ax = fig.add_subplot(111)
            self._set_title(ax)
            # Plot
            plot_objs = []
            for i in range(n_lines):
                lines = ax.plot(xy[:, 0], xy[:, 1 + i], **kwargs)
                # One x-y pair per call, so exactly one line is expected back
                assert len(lines) == 1
                plot_objs.append(lines[0])
            #
            self._add_legend(plot_objs)
            self.plt.grid(self.grid)
            self._add_axis_labels(ax)
            self._set_axis_ticks(ax)
            self._set_axis_lim(ax)
            self._savefig(outpath)
        finally:
            # Release the figure even if plotting or saving failed
            self.plt.close('all')
        return outpath