# pylint: disable=blacklisted-name
from os.path import join, dirname
import numpy as np
from ..log import get_logger
logger = get_logger()
from .. import const
from ..os import makedirs, open_file
[docs]class Plot:
[docs] def __init__(
self,
legend_fontsize=20,
legend_loc=0,
figsize=(14, 14),
figtitle=None,
figtitle_fontsize=20,
xlabel=None,
xlabel_fontsize=20,
ylabel=None,
ylabel_fontsize=20,
zlabel=None,
zlabel_fontsize=20,
xlim=None,
ylim=None,
zlim=None,
xticks=None,
xticks_fontsize=10,
xticks_rotation=0,
yticks=None,
yticks_fontsize=10,
yticks_rotation=0,
zticks=None,
zticks_fontsize=10,
zticks_rotation=0,
grid=True,
labels=None,
outpath=None):
"""Plotter.
Args:
legend_fontsize (int, optional): Legend font size.
legend_loc (str, optional): Legend location: ``'best'``,
``'upper right'``, ``'lower left'``, ``'right'``,
``'center left'``, ``'lower center'``, ``'upper center'``,
``'center'``, etc. Effective only when ``labels`` is not
``None``.
figsize (tuple, optional): Width and height of the figure in inches.
figtitle (str, optional): Figure title.
*_fontsize (int, optional): Font size.
?label (str, optional): Axis labels.
?lim (array_like, optional): Axis min. and max. ``None`` means auto.
?ticks (array_like, optional): Axis tick values. ``None`` means
auto.
?ticks_rotation (float, optional): Tick rotation in degrees.
grid (bool, optional): Whether to draw grid.
labels (list, optional): Labels.
outpath (str, optional): Path to which the plot is saved to. Should
end with ``'.png'``, and ``None`` means to ``const.Dir.tmp``.
"""
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
#
self.plt = plt
self.legend_fontsize = legend_fontsize
self.legend_loc = legend_loc
self.figsize = figsize
self.figtitle = figtitle
self.figtitle_fontsize = figtitle_fontsize
self.xlabel = xlabel
self.xlabel_fontsize = xlabel_fontsize
self.ylabel = ylabel
self.ylabel_fontsize = ylabel_fontsize
self.zlabel = zlabel
self.zlabel_fontsize = zlabel_fontsize
self.xlim = xlim
self.ylim = ylim
self.zlim = zlim
self.xticks = xticks
self.xticks_rotation = xticks_rotation
self.xticks_fontsize = xticks_fontsize
self.yticks = yticks
self.yticks_rotation = yticks_rotation
self.yticks_fontsize = yticks_fontsize
self.zticks = zticks
self.zticks_rotation = zticks_rotation
self.zticks_fontsize = zticks_fontsize
self.grid = grid
self.labels = labels
self.outpath = outpath
def _savefig(self, outpath, contents_only=False, dpi=None):
# Make directory, if necessary
outdir = dirname(outpath)
makedirs(outdir)
#
if contents_only:
ax = self.plt.gca()
ax.set_position([0, 0, 1, 1])
ax.set_axis_off()
with open_file(outpath, 'wb') as h:
self.plt.savefig(h, dpi=dpi)
else:
with open_file(outpath, 'wb') as h:
self.plt.savefig(h, bbox_inches='tight', dpi=dpi)
def _add_legend(self, plot_objs):
if self.labels is None:
return
n_plot_objs = len(plot_objs)
assert (len(self.labels) == n_plot_objs), (
"Number of labels must equal number of plot objects; "
"use None for object without a label")
for i in range(n_plot_objs):
plot_objs[i].set_label(self.labels[i])
self.plt.legend(fontsize=self.legend_fontsize, loc=self.legend_loc)
def _add_axis_labels(self, ax):
if self.xlabel is not None:
ax.set_xlabel(self.xlabel, fontsize=self.xlabel_fontsize)
if self.ylabel is not None:
ax.set_ylabel(self.ylabel, fontsize=self.ylabel_fontsize)
if self.zlabel is not None:
ax.set_zlabel(self.zlabel, fontsize=self.zlabel_fontsize)
def _set_axis_ticks(self, ax):
# FIXME: if xticks is not provided, xticks_fontsize and xticks_rotation have
# no effect, which shouldn't be the case
if self.xticks is not None:
ax.set_xticklabels(
self.xticks, fontsize=self.xticks_fontsize,
rotation=self.xticks_rotation)
if self.yticks is not None:
ax.set_yticklabels(
self.yticks, fontsize=self.yticks_fontsize,
rotation=self.yticks_rotation)
if self.zticks is not None:
ax.set_zticklabels(
self.zticks, fontsize=self.zticks_fontsize,
rotation=self.zticks_rotation)
def _set_axis_lim(self, ax):
if self.xlim is not None:
ax.set_xlim(*self.xlim)
if self.ylim is not None:
ax.set_ylim(*self.ylim)
if self.zlim is not None:
ax.set_zlim(*self.zlim)
@staticmethod
def _set_axes_equal(ax, xyz):
# plt.axis('equal') not working, hence the hack of creating a cubic
# bounding box
x_data, y_data, z_data = xyz[:, 0], xyz[:, 1], xyz[:, 2]
max_range = np.array([
x_data.max() - x_data.min(),
y_data.max() - y_data.min(),
z_data.max() - z_data.min()]).max()
xb = 0.5 * max_range * np.mgrid[-1:2:2, -1:2:2, -1:2:2][0].flatten() \
+ 0.5 * (x_data.max() + x_data.min())
yb = 0.5 * max_range * np.mgrid[-1:2:2, -1:2:2, -1:2:2][1].flatten() \
+ 0.5 * (y_data.max() + y_data.min())
zb = 0.5 * max_range * np.mgrid[-1:2:2, -1:2:2, -1:2:2][2].flatten() \
+ 0.5 * (z_data.max() + z_data.min())
for xb_, yb_, zb_ in zip(xb, yb, zb):
ax.plot([xb_], [yb_], [zb_], 'w')
def _set_title(self, ax):
if self.figtitle is not None:
ax.set_title(self.figtitle, fontsize=self.figtitle_fontsize)
[docs] def bar(self, y, group_width=0.8):
"""Bar plot.
Args:
y (array_like): N-by-M array of N groups, each with M bars,
or N-array of N groups, each with one bar.
group_width (float, optional): Width allocated to each group,
shared by all bars within the group.
Writes
- The bar plot.
"""
outpath = join(const.Dir.tmp, 'bar.png') if self.outpath is None \
else self.outpath
fig = self.plt.figure(figsize=self.figsize)
ax = fig.add_subplot(111)
self._set_title(ax)
# Ensure y is 2D, with columns representing values within groups
# and rows across groups
if y.ndim == 1:
y = np.reshape(y, (-1, 1))
n, n_grp = y.shape
# Group width is shared by all groups
bar_width = group_width / n_grp
# Assume x is evenly spaced
x = np.arange(n)
# Plot
plot_objs = []
for i in range(n_grp):
x_ = x - 0.5 * group_width + 0.5 * bar_width + i * bar_width
plot_obj = ax.bar(x_, y[:, i], bar_width)
plot_objs.append(plot_obj)
#
self._add_legend(plot_objs)
self.plt.grid(self.grid)
self._add_axis_labels(ax)
self._set_axis_ticks(ax)
self._set_axis_lim(ax)
self._savefig(outpath)
self.plt.close('all')
return outpath
[docs] def scatter3d(
self, xyz, colors=None, size=None, equal_axes=False, views=None):
"""3D scatter plot.
Args:
xyz (array_like): N-by-3 array of N points.
colors (array_like or list(str) or str, optional): If N-array, these
values are colormapped. If N-list, its elements should be color
strings. If a single color string, all points use that color.
size (int, optional): Scatter size.
equal_axes (bool, optional): Whether to have the same scale for all
axes.
views (list(tuple), optional): List of elevation-azimuth angle pairs
(in degrees). A good set of views is ``[(30, 0), (30, 45),
(30, 90), (30, 135)]``.
Writes
- One or multiple (if ``views`` is provided) views of the 3D plot.
"""
from mpl_toolkits.mplot3d import Axes3D # noqa; pylint: disable=unused-import
#
outpath = join(const.Dir.tmp, 'scatter3d.png') if self.outpath is None \
else self.outpath
fig = self.plt.figure(figsize=self.figsize)
ax = fig.add_subplot(111, projection='3d')
self._set_title(ax)
# Prepare kwargs to scatter()
kwargs = {}
need_colorbar = False
if isinstance(colors, np.ndarray):
kwargs['c'] = colors # will be colormapped with color map
kwargs['cmap'] = 'viridis'
need_colorbar = True
elif colors is not None:
kwargs['c'] = colors
if size is not None:
kwargs['s'] = size
# Plot
plot_objs = ax.scatter(xyz[:, 0], xyz[:, 1], xyz[:, 2], **kwargs)
#
self._add_legend(plot_objs)
self.plt.grid(self.grid)
self._add_axis_labels(ax)
self._set_axis_ticks(ax)
self._set_axis_lim(ax)
if equal_axes:
self._set_axes_equal(ax, xyz)
if need_colorbar:
self.plt.colorbar(plot_objs)
# TODO: this seems to mess up equal axes
# Save plot
outpaths = []
if outpath.endswith('.png'):
if views is None:
self._savefig(outpath)
outpaths.append(outpath)
else:
for elev, azim in views:
ax.view_init(elev, azim)
self.plt.draw()
outpath_ = outpath[:-len('.png')] + \
'_elev%03d_azim%03d.png' % (elev, azim)
self._savefig(outpath_)
outpaths.append(outpath_)
else:
raise ValueError("`outpath` must end with '.png'")
self.plt.close('all')
return outpaths
[docs] def line(self, xy, width=None, marker=None, marker_size=None):
"""Line/curve plot.
Args:
xy (array_like): N-by-M array of N x-values (first column) and
their corresponding y-values (the remaining M-1 columns).
width (float, optional): Line width.
marker (str, optional): Marker.
marker_size (float, optional): Marker size.
Writes
- The line plot.
"""
outpath = join(const.Dir.tmp, 'line.png') if self.outpath is None \
else self.outpath
fig = self.plt.figure(figsize=self.figsize)
ax = fig.add_subplot(111)
self._set_title(ax)
# Prepare kwargs to scatter()
kwargs_list = []
n_lines = xy.shape[1] - 1
for i in range(n_lines):
kwargs = {}
if width is not None:
kwargs['linewidth'] = width
if marker is not None:
kwargs['marker'] = marker
if marker_size is not None:
kwargs['markersize'] = marker_size
kwargs_list.append(kwargs)
# Plot
plot_objs = []
for i in range(n_lines):
plot_obj = self.plt.plot(xy[:, 0], xy[:, 1 + i], **kwargs_list[i])
assert len(plot_obj) == 1
plot_obj = plot_obj[0]
plot_objs.append(plot_obj)
#
self._add_legend(plot_objs)
self.plt.grid(self.grid)
self._add_axis_labels(ax)
self._set_axis_ticks(ax)
self._set_axis_lim(ax)
self._savefig(outpath)
self.plt.close('all')
return outpath