Source code for graphinglib.multifigure
from shutil import which
from string import ascii_lowercase
from typing import Literal, Optional
import matplotlib.pyplot as plt
from matplotlib import rcParamsDefault
from matplotlib.collections import LineCollection
from matplotlib.gridspec import GridSpec
from matplotlib.legend_handler import HandlerPatch
from matplotlib.patches import Polygon
from matplotlib.transforms import ScaledTranslation
from graphinglib.file_manager import FileLoader, get_default_style
from graphinglib.graph_elements import GraphingException, Plottable
from graphinglib.legend_artists import (
HandlerMultipleLines,
HandlerMultipleVerticalLines,
VerticalLineCollection,
histogram_legend_artist,
)
from .figure import Figure
try:
from typing import Self
except ImportError:
from typing_extensions import Self
[docs]
class MultiFigure:
"""
This class implements the "canvas" on which multiple plots are displayed.
The canvas consists of a grid of a specified size on which the
:class:`~graphinglib.figure.Figure` objects are displayed.
Parameters
----------
num_rows, num_cols : int
Number of rows and columns for the grid. These parameters determine the
the number of "squares" on which a plot can be placed.
.. note::
Note that a single plot can span multiple squares.
See :py:meth:`~graphinglib.multifigure.MultiFigure.add_SubFigure`.
size : tuple[float, float]
Overall size of the multifigure.
Default depends on the ``figure_style`` configuration.
title : str, optional
General title of the figure.
reference_labels : bool
Whether or not to add reference labels to the SubFigures.
Defaults to ``True``.
.. note::
The reference labels are in the form of "a)", "b)", etc. and are used to refer
to a particular SubFigure in a caption accompanying the MultiFigure.
reflabel_loc : str
Location of the reference labels of the SubFigures. Either "inside" or "outside".
Defaults to "outside".
figure_style : str
The figure style to use for the figure.
Default can be set using ``gl.set_default_style()``.
"""
[docs]
def __init__(
self,
num_rows: int,
num_cols: int,
size: tuple[float, float] | Literal["default"] = "default",
title: Optional[str] = None,
reference_labels: bool = True,
reflabel_loc: str = "outside",
figure_style: str = "default",
) -> None:
"""
This class implements the "canvas" on which multiple plots are displayed.
The canvas consists of a grid of a specified size on which the individual
:class:`~graphinglib.figure.Figure` objects are displayed.
Parameters
----------
num_rows, num_cols : int
Number of rows and columns for the grid. These parameters determine the
the number of "squares" on which a plot can be placed.
.. note::
Note that a single plot can span multiple squares.
See :py:meth:`~graphinglib.multifigure.MultiFigure.add_SubFigure`.
size : tuple[float, float]
Overall size of the figure.
Default depends on the ``figure_style`` configuration.
title : str, optional
General title of the figure.
reference_labels : bool
Whether or not to add reference labels to the SubFigures.
Defaults to ``True``.
.. note::
The reference labels are in the form of "a)", "b)", etc. and are used to refer
to a particular SubFigure in a caption accompanying the MultiFigure.
reflabel_loc : str
Location of the reference labels of the SubFigures. Either "inside" or "outside".
Defaults to "outside".
figure_style : str
The figure style to use for the figure.
Default can be set using ``gl.set_default_style()``.
"""
if type(num_rows) != int or type(num_cols) != int:
raise TypeError("The number of rows and columns must be integers.")
if num_rows < 1 or num_cols < 1:
raise ValueError("The number of rows and columns must be greater than 0.")
self._num_rows = num_rows
self._num_cols = num_cols
self._title = title
self._reference_labels = reference_labels
self._reflabel_loc = reflabel_loc
self._figure_style = figure_style
self._size = size
self._sub_figures = []
self._rc_dict = {}
self._user_rc_dict = {}
@property
def num_rows(self) -> int:
return self._num_rows
@property
def num_cols(self) -> int:
return self._num_cols
@property
def title(self) -> Optional[str]:
return self._title
@title.setter
def title(self, title: Optional[str]) -> None:
self._title = title
@property
def reference_labels(self) -> bool:
return self._reference_labels
@reference_labels.setter
def reference_labels(self, reference_labels: bool) -> None:
self._reference_labels = reference_labels
@property
def reflabel_loc(self) -> str:
return self._reflabel_loc
@reflabel_loc.setter
def reflabel_loc(self, reflabel_loc: str) -> None:
self._reflabel_loc = reflabel_loc
@property
def figure_style(self) -> str:
return self._figure_style
@figure_style.setter
def figure_style(self, figure_style: str) -> None:
self._figure_style = figure_style
@property
def size(self) -> tuple[float, float] | Literal["default"]:
return self._size
@size.setter
def size(self, size: tuple[float, float] | Literal["default"]) -> None:
self._size = size
[docs]
@classmethod
def from_row(
cls,
figures: list[Figure],
size: tuple[float, float] | Literal["default"] = "default",
title: Optional[str] = None,
reference_labels: bool = True,
reflabel_loc: str = "outside",
figure_style: str = "default",
) -> Self:
"""Creates a MultiFigure with the specified :class:`~graphinglib.figure.Figure` objects in a horizontal configuration.
Parameters
----------
figures : list[Figure]
The :class:`~graphinglib.figure.Figure` objects to add to the MultiFigure, from left to right.
size : tuple[float, float]
Overall size of the figure.
Default depends on the ``figure_style`` configuration.
title : str, optional
Title of the MultiFigure.
Defaults to ``None``.
reference_labels : bool
Whether or not to add reference labels to the SubFigures.
Defaults to ``True``.
reflabel_loc : str
Location of the reference labels of the SubFigures. Either "inside" or "outside".
Defaults to "outside".
figure_style : str
The figure style to use for the figure.
Default can be set using ``gl.set_default_style()``.
Returns
-------
A new MultiFigure object.
"""
multi_fig = cls(
num_rows=1,
num_cols=len(figures),
size=size,
title=title,
reference_labels=reference_labels,
reflabel_loc=reflabel_loc,
figure_style=figure_style,
)
for i, figure in enumerate(figures):
multi_fig.add_figure(figure, 0, i, 1, 1)
return multi_fig
[docs]
@classmethod
def from_stack(
cls,
figures: list[Figure],
size: tuple[float, float] | Literal["default"] = "default",
title: Optional[str] = None,
reference_labels: bool = True,
reflabel_loc: str = "outside",
figure_style: str = "default",
) -> Self:
"""Creates a MultiFigure with the specified :class:`~graphinglib.figure.Figure` objects in a vertical configuration.
Parameters
----------
figures : list[Figure]
The :class:`~graphinglib.figure.Figure` objects to add to the MultiFigure, from top to bottom.
size : tuple[float, float]
Overall size of the figure.
Default depends on the ``figure_style`` configuration.
title : str, optional
Title of the MultiFigure.
Defaults to ``None``.
reference_labels : bool
Whether or not to add reference labels to the SubFigures.
Defaults to ``True``.
reflabel_loc : str
Location of the reference labels of the SubFigures. Either "inside" or "outside".
Defaults to "outside".
figure_style : str
The figure style to use for the figure.
Default can be set using ``gl.set_default_style()``.
Returns
-------
A new MultiFigure object.
"""
multi_fig = cls(
num_rows=len(figures),
num_cols=1,
size=size,
title=title,
reference_labels=reference_labels,
reflabel_loc=reflabel_loc,
figure_style=figure_style,
)
for i, figure in enumerate(figures):
multi_fig.add_figure(figure, i, 0, 1, 1)
return multi_fig
[docs]
@classmethod
def from_grid(
cls,
figures: list[Figure],
dimensions: tuple[int, int],
size: tuple[float, float] | Literal["default"] = "default",
title: Optional[str] = None,
reference_labels: bool = True,
reflabel_loc: str = "outside",
figure_style: str = "default",
) -> Self:
"""Creates a MultiFigure with the specified :class:`~graphinglib.figure.Figure` objects in a grid configuration.
Parameters
----------
figures : list[Figure]
The :class:`~graphinglib.figure.Figure` objects to add to the MultiFigure, from top-left to bottom-right.
dimensions : tuple[int, int]
The number of rows and columns of the grid (product should equal the number of figures).
size : tuple[float, float]
Overall size of the figure.
Default depends on the ``figure_style`` configuration.
title : str, optional
Title of the MultiFigure.
Defaults to ``None``.
reference_labels : bool
Whether or not to add reference labels to the SubFigures.
Defaults to ``True``.
reflabel_loc : str
Location of the reference labels of the SubFigures. Either "inside" or "outside".
Defaults to "outside".
figure_style : str
The figure style to use for the figure.
Default can be set using ``gl.set_default_style()``.
Returns
-------
A new MultiFigure object.
"""
num_rows, num_cols = dimensions
if num_rows * num_cols < len(figures):
raise ValueError(
f"The product of the dimensions ({num_rows} x {num_cols}) must be greater than or equal to the number of figures ({len(figures)})."
)
multi_fig = cls(
num_rows=num_rows,
num_cols=num_cols,
size=size,
title=title,
reference_labels=reference_labels,
reflabel_loc=reflabel_loc,
figure_style=figure_style,
)
for i, figure in enumerate(figures):
row = i // num_cols
col = i % num_cols
multi_fig.add_figure(figure, row, col, 1, 1)
return multi_fig
[docs]
def add_figure(
self,
figure: Figure,
row_start: int,
col_start: int,
row_span: int,
col_span: int,
) -> None:
"""
Adds a :class:`~graphinglib.figure.Figure` to a :class:`~graphinglib.multifigure.MultiFigure`.
Parameters
----------
figure : Figure
The :class:`~graphinglib.figure.Figure` to add to the MultiFigure.
row_start : int
The row where to set the upper-left corner of the SubFigure.
col_start : int
The column where to set the upper-left corner of the SubFigure.
row_span : int
The number of rows spanned by the SubFigure.
col_span : int
The number of columns spanned by the SubFigure.
"""
if type(row_start) != int or type(col_start) != int:
raise TypeError("The placement values must be integers.")
if row_start < 0 or col_start < 0:
raise ValueError("The placement values cannot be negative.")
if type(row_span) != int or type(col_span) != int:
raise TypeError("The span values must be integers.")
if row_span < 1 or col_span < 1:
raise ValueError("The span values must be greater than 0.")
if (
row_start + row_span > self._num_rows
or col_start + col_span > self._num_cols
):
raise ValueError(
"The placement values and span values must be inside the size of the MultiFigure."
)
# Add location and span to the SubFigure (create new attributes)
figure._row_start = row_start
figure._col_start = col_start
figure._row_span = row_span
figure._col_span = col_span
self._sub_figures.append(figure)
[docs]
def show(
self,
general_legend: bool = False,
legend_loc: str = "outside lower center",
legend_cols: int = 1,
) -> None:
"""
Displays the :class:`~graphinglib.multifigure.MultiFigure`.
Parameters
----------
general_legend : bool
Whether or not to display an overall legend for the :class:`~graphinglib.multifigure.MultiFigure` containing
the labels of every :class:`~graphinglib.Figure.Figure` inside it. Note that enabling this option will
disable the individual legends for every :class:`~graphinglib.multifigure.SubFigure`.
Defaults to ``False``.
legend_loc : str
The location of the legend in the MultiFigure. Possible placement keywords are: for vertical placement: ``{"upper", "center", "lower"}``, for horizontal placement: ``{"left", "center", "right"}``. The keyword ``"outside"`` can be added to put the legend outside of the axes. Defaults to ``"outside lower center"``.
legend_cols : int
Number of colums in which to arrange the legend items. Defaults to 1.
"""
self._prepare_multi_figure(
general_legend=general_legend,
legend_loc=legend_loc,
legend_cols=legend_cols,
)
plt.show()
plt.rcParams.update(rcParamsDefault)
[docs]
def save(
self,
file_name: str,
general_legend: bool = False,
legend_loc: str = "outside lower center",
legend_cols: int = 1,
dpi: Optional[int] = None,
) -> None:
"""
Saves the :class:`~graphinglib.multifigure.MultiFigure` to a file.
Parameters
----------
file_name : str
File name or path at which to save the figure.
general_legend : bool
Whether or not to display an overall legend for the :class:`~graphinglib.multifigure.MultiFigure` containing
the labels of every :class:`~graphinglib.figure.Figure` inside it. Note that enabling this option will
disable the individual legends for every :class:`~graphinglib.figure.Figure`.
Defaults to ``False``.
legend_loc : str
The location of the legend in the MultiFigure. Possible placement keywords are: for vertical placement: ``{"upper", "center", "lower"}``, for horizontal placement: ``{"left", "center", "right"}``. The keyword ``"outside"`` can be added to put the legend outside of the axes. Defaults to ``"outside lower center"``.
legend_cols : int
Number of colums in which to arrange the legend items. Defaults to 1.
dpi : int, optional
The resolution of the saved MultiFigure. Only used for raster formats (e.g. PNG, JPG, etc.).
Default depends on the ``figure_style`` configuration.
"""
self._prepare_multi_figure(
general_legend=general_legend,
legend_loc=legend_loc,
legend_cols=legend_cols,
)
if dpi is not None:
plt.savefig(file_name, dpi=dpi, bbox_inches="tight")
else:
plt.savefig(file_name, bbox_inches="tight")
plt.close()
plt.rcParams.update(rcParamsDefault)
def _prepare_multi_figure(
self,
general_legend: bool = False,
legend_loc: str = "outside lower center",
legend_cols: int = 1,
) -> None:
"""
Prepares the :class:`~graphinglib.multifigure.MultiFigure` to be displayed.
"""
if self._figure_style == "default":
self._figure_style = get_default_style()
try:
file_loader = FileLoader(self._figure_style)
self._default_params = file_loader.load()
is_matplotlib_style = False
except FileNotFoundError:
is_matplotlib_style = True
try:
if self._figure_style == "matplotlib":
plt.style.use("default")
else:
plt.style.use(self._figure_style)
file_loader = FileLoader("plain")
self._default_params = file_loader.load()
except OSError:
raise GraphingException(
f"The figure style {self._figure_style} was not found. Please choose a different style."
)
multi_figure_params_to_reset = self._fill_in_missing_params(self)
self._fill_in_rc_params(is_matplotlib_style)
self._figure = plt.figure(layout="constrained", figsize=self._size)
MultiFigure_grid = GridSpec(self._num_rows, self._num_cols, figure=self._figure)
if self._reflabel_loc == "outside":
trans = ScaledTranslation(-5 / 72, 10 / 72, self._figure.dpi_scale_trans)
elif self._reflabel_loc == "inside":
trans = ScaledTranslation(10 / 72, -15 / 72, self._figure.dpi_scale_trans)
else:
raise ValueError(
"Invalid reference label location. Please specify either 'inside' or 'outside'."
)
sub_figures_do_legend = True if not general_legend else False
labels, handles = [], []
for i, sub_figure in enumerate(self._sub_figures):
self._fill_in_rc_params(is_matplotlib_style)
sub_figure_labels, sub_figure_handles = self._prepare_sub_figure(
sub_figure,
MultiFigure_grid,
transformation=trans,
reference_label=ascii_lowercase[i] + ")",
legend=sub_figures_do_legend,
is_matplotlib_style=is_matplotlib_style,
)
labels += sub_figure_labels
handles += sub_figure_handles
self._fill_in_rc_params(is_matplotlib_style)
if general_legend:
try:
self._figure.legend(
handles=handles,
labels=labels,
handleheight=1.3,
handler_map={
Polygon: HandlerPatch(patch_func=histogram_legend_artist),
LineCollection: HandlerMultipleLines(),
VerticalLineCollection: HandlerMultipleVerticalLines(),
},
draggable=True,
loc=legend_loc,
ncols=legend_cols,
)
except:
self._figure.legend(
handles=handles,
labels=labels,
handleheight=1.3,
handler_map={
Polygon: HandlerPatch(patch_func=histogram_legend_artist),
LineCollection: HandlerMultipleLines(),
VerticalLineCollection: HandlerMultipleVerticalLines(),
},
loc=legend_loc,
ncols=legend_cols,
)
self._figure.suptitle(self._title)
self._reset_params_to_default(self, multi_figure_params_to_reset)
self._rc_dict = {}
def _prepare_sub_figure(
self,
sub_figure: Figure,
grid: GridSpec,
transformation: ScaledTranslation,
reference_label: str,
legend: bool,
is_matplotlib_style: bool,
):
"""
Prepares a single subfigure.
"""
sub_rcs = sub_figure._user_rc_dict
plt.rcParams.update(sub_rcs)
axes = plt.subplot(
grid.new_subplotspec(
(sub_figure._row_start, sub_figure._col_start),
rowspan=sub_figure._row_span,
colspan=sub_figure._col_span,
)
)
if self._reference_labels:
axes.text(
0,
1,
reference_label,
transform=axes.transAxes + transformation,
)
default_params_copy = self._default_params.copy()
default_params_copy.update(is_a_subfigure=True)
default_params_copy["Figure"]["_figure_style"] = self._figure_style
labels, handles = sub_figure._prepare_figure(
legend=legend,
axes=axes,
default_params=default_params_copy,
is_matplotlib_style=is_matplotlib_style,
)
return labels, handles
def _fill_in_missing_params(self, element: Plottable) -> list[str]:
"""
Fills in the missing parameters from the specified ``figure_style``.
"""
params_to_reset = []
object_type = type(element).__name__
for property, value in vars(element).items():
if (type(value) == str) and (value == "default"):
params_to_reset.append(property)
if self._default_params[object_type][property] == "same as curve":
element.__dict__["_errorbars_color"] = self._default_params[
object_type
]["_color"]
element.__dict__["_errorbars_line_width"] = self._default_params[
object_type
]["_line_width"]
element.__dict__["_cap_thickness"] = self._default_params[
object_type
]["_line_width"]
elif self._default_params[object_type][property] == "same as scatter":
element.__dict__["_errorbars_color"] = self._default_params[
object_type
]["_face_color"]
else:
element.__dict__[property] = self._default_params[object_type][
property
]
return params_to_reset
def _reset_params_to_default(
self, element: Plottable, params_to_reset: list[str]
) -> None:
"""
Resets the parameters that were set to default in the _fill_in_missing_params method.
"""
for param in params_to_reset:
setattr(element, param, "default")
def _fill_in_rc_params(self, is_matplotlib_style: bool = False) -> None:
"""
Fills in and sets the missing rc parameters from the specified ``figure_style``.
If ``is_matplotlib_style`` is ``True``, the rc parameters are reset to the default values for the specified ``figure_style``.
If ``is_matplotlib_style`` is ``False``, the rc parameters are updated with the missing parameters from the specified ``figure_style``.
In both cases, the rc parameters are then updated with the user-specified parameters.
"""
if is_matplotlib_style:
if self._figure_style == "matplotlib":
plt.style.use("default")
else:
plt.style.use(self._figure_style)
plt.rcParams.update(self._user_rc_dict)
else:
params = self._default_params["rc_params"]
for property, value in params.items():
# add to rc_dict if not already in there
if (property not in self._rc_dict) and (
property not in self._user_rc_dict
):
self._rc_dict[property] = value
all_rc_params = {**self._rc_dict, **self._user_rc_dict}
try:
if all_rc_params["text.usetex"] and which("latex") is None:
all_rc_params["text.usetex"] = False
except KeyError:
pass
plt.rcParams.update(all_rc_params)
[docs]
def set_rc_params(
self,
rc_params_dict: dict[str, str | float] = {},
reset: bool = False,
) -> None:
"""
Customize the visual style of the :class:`~graphinglib.multifigure.MultiFigure`.
Any rc parameter that is not specified in the dictionary will be set to the default value for the specified ``figure_style``.
Parameters
----------
rc_params_dict : dict[str, str | float]
Dictionary of rc parameters to update.
Defaults to empty dictionary.
reset : bool
Whether or not to reset the rc parameters to the default values for the specified ``figure_style``.
Defaults to ``False``.
"""
if reset:
self._user_rc_dict = {}
for property, value in rc_params_dict.items():
self._user_rc_dict[property] = value
[docs]
def set_visual_params(
self,
reset: bool = False,
figure_face_color: str | None = None,
axes_face_color: str | None = None,
axes_edge_color: str | None = None,
axes_label_color: str | None = None,
axes_line_width: float | None = None,
color_cycle: list[str] | None = None,
x_tick_color: str | None = None,
y_tick_color: str | None = None,
legend_face_color: str | None = None,
legend_edge_color: str | None = None,
font_family: str | None = None,
font_size: float | None = None,
font_weight: str | None = None,
text_color: str | None = None,
use_latex: bool | None = None,
grid_line_style: str | None = None,
grid_line_width: float | None = None,
grid_color: str | None = None,
grid_alpha: float | None = None,
) -> None:
"""
Customize the visual style of the :class:`~graphinglib.figure.Figure`.
Any parameter that is not specified (None) will be set to the default value for the specified ``figure_style``.
Parameters
----------
reset : bool
Whether or not to reset the rc parameters to the default values for the specified ``figure_style``.
Defaults to ``False``.
figure_face_color : str
The color of the figure face.
Defaults to ``None``.
axes_face_color : str
The color of the axes face.
Defaults to ``None``.
axes_edge_color : str
The color of the axes edge.
Defaults to ``None``.
axes_label_color : str
The color of the axes labels.
Defaults to ``None``.
axes_line_width : float
The width of the axes lines.
Defaults to ``None``.
color_cycle : list[str]
A list of colors to use for the color cycle.
Defaults to ``None``.
x_tick_color : str
The color of the x-axis ticks.
Defaults to ``None``.
y_tick_color : str
The color of the y-axis ticks.
Defaults to ``None``.
legend_face_color : str
The color of the legend face.
Defaults to ``None``.
legend_edge_color : str
The color of the legend edge.
Defaults to ``None``.
font_family : str
The font family to use.
Defaults to ``None``.
font_size : float
The font size to use.
Defaults to ``None``.
font_weight : str
The font weight to use.
Defaults to ``None``.
text_color : str
The color of the text.
Defaults to ``None``.
use_latex : bool
Whether or not to use latex.
Defaults to ``None``.
grid_line_style : str
The style of the grid lines.
Defaults to ``None``.
grid_line_width : float
The width of the grid lines.
Defaults to ``None``.
grid_color : str
The color of the grid lines.
Defaults to ``None``.
grid_alpha : float
The alpha of the grid lines.
Defaults to ``None``.
"""
if color_cycle is not None:
color_cycle = plt.cycler(color=color_cycle)
rc_params_dict = {
"figure.facecolor": figure_face_color,
"axes.facecolor": axes_face_color,
"axes.edgecolor": axes_edge_color,
"axes.labelcolor": axes_label_color,
"axes.linewidth": axes_line_width,
"axes.prop_cycle": color_cycle,
"xtick.color": x_tick_color,
"ytick.color": y_tick_color,
"legend.facecolor": legend_face_color,
"legend.edgecolor": legend_edge_color,
"font.family": font_family,
"font.size": font_size,
"font.weight": font_weight,
"text.color": text_color,
"text.usetex": use_latex,
"grid.linestyle": grid_line_style,
"grid.linewidth": grid_line_width,
"grid.color": grid_color,
"grid.alpha": grid_alpha,
}
rc_params_dict = {
key: value for key, value in rc_params_dict.items() if value is not None
}
self.set_rc_params(rc_params_dict, reset=reset)