Source code for graphinglib.figure
from .inherit import INHERIT, Inherit, is_inherit
from copy import deepcopy
from shutil import which
from typing import Literal, Optional
from warnings import warn
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib.collections import LineCollection
from matplotlib.legend_handler import HandlerPatch
from matplotlib.patches import Polygon
from .file_manager import FileLoader, FileUpdater, get_default_style
from .graph_elements import GraphingException, Plottable
from .legend_artists import (
HandlerMultipleLines,
HandlerMultipleVerticalLines,
VerticalLineCollection,
histogram_legend_artist,
)
from .tools import _copy_with_overrides
try:
from typing import Self
except ImportError:
from typing_extensions import Self
[docs]
class Figure:
"""
This class implements a general figure object.
Parameters
----------
x_label, y_label : str
The indentification for the x-axis and y-axis.
Defaults to ``"x axis"`` and ``"y axis"``.
x_lim, y_lim : tuple[float, float], optional
The limits for the x-axis and y-axis.
size : tuple[float, float]
Overall size of the figure.
Figure size is in inches; typical width is ``4`` to ``12`` and typical height is ``3`` to ``8``.
Default depends on the ``figure_style`` configuration.
title: str, optional
The title of the figure.
log_scale_x, log_scale_y : bool
Whether or not to set the scale of the x- or y-axis to logaritmic scale.
Default depends on the ``figure_style`` configuration.
show_grid : bool
Whether or not to show the grid.
Default depends on the ``figure_style`` configuration.
remove_axes : bool
Whether or not to show the axes. Useful for adding tables or text to
the subfigure. Defaults to ``False``.
aspect_ratio : float, str
The aspect ratio of the axis scaling. Values are ``"equal"``, ``"auto"``, or a positive float.
Defaults to ``"auto"``.
figure_style : str
The figure style to use for the figure.
"""
[docs]
def __init__(
self,
x_label: Optional[str] = None,
y_label: Optional[str] = None,
size: tuple[float, float] | Inherit = INHERIT,
title: Optional[str] = None,
x_lim: Optional[tuple[float, float]] = None,
y_lim: Optional[tuple[float, float]] = None,
log_scale_x: bool | Inherit = INHERIT,
log_scale_y: bool | Inherit = INHERIT,
remove_axes: bool = False,
aspect_ratio: float | str = "auto",
figure_style: str | Inherit = INHERIT,
) -> None:
"""
This class implements a general figure object.
Parameters
----------
x_label, y_label : str, optional
The indentification for the x-axis and y-axis.
Defaults to ``"x axis"`` and ``"y axis"``.
x_lim, y_lim : tuple[float, float], optional
The limits for the x-axis and y-axis.
size : tuple[float, float]
Overall size of the figure.
Figure size is in inches; typical width is ``4`` to ``12`` and typical height is ``3`` to ``8``.
Default depends on the ``figure_style`` configuration.
title: str, optional
The title of the figure.
log_scale_x, log_scale_y : bool
Whether or not to set the scale of the x- or y-axis to logaritmic scale.
Default depends on the ``figure_style`` configuration.
remove_axes : bool
Whether or not to show the axes. Useful for adding tables or text to
the subfigure. Defaults to ``False``.
aspect_ratio : float, str
The aspect ratio of the axis scaling. Values are ``"equal"``, ``"auto"``, or a positive float.
Defaults to ``"auto"``.
figure_style : str
The figure style to use for the figure.
Default can be set using ``gl.set_default_style()``.
"""
self._figure_style = figure_style
self._size = size
self._title = title
self._log_scale_x = log_scale_x
self._log_scale_y = log_scale_y
self._show_grid = False
self._elements: list[Plottable] = []
self._labels: list[str | None] = []
self._handles = []
self._x_axis_name = x_label
self._y_axis_name = y_label
self._x_lim = x_lim
self._y_lim = y_lim
self._rc_dict = {}
self._user_rc_dict = {}
self._custom_ticks = False
self._remove_axes = remove_axes
self._twin_x_axis = None
self._twin_y_axis = None
self.aspect_ratio = aspect_ratio
@property
def figure_style(self) -> str | Inherit:
return self._figure_style
@figure_style.setter
def figure_style(self, value: str | Inherit):
self._figure_style = value
@property
def size(self) -> tuple[float, float]:
return self._size
@size.setter
def size(self, value: tuple[float, float]):
self._size = value
@property
def title(self) -> str:
return self._title
@title.setter
def title(self, value: str):
self._title = value
@property
def log_scale_x(self) -> bool:
return self._log_scale_x
@log_scale_x.setter
def log_scale_x(self, value: bool):
self._log_scale_x = value
@property
def log_scale_y(self) -> bool:
return self._log_scale_y
@log_scale_y.setter
def log_scale_y(self, value: bool):
self._log_scale_y = value
@property
def show_grid(self) -> bool:
return self._show_grid
@show_grid.setter
def show_grid(self, value: bool):
self._show_grid = value
@property
def x_axis_name(self) -> str:
return self._x_axis_name
@x_axis_name.setter
def x_axis_name(self, value: str):
self._x_axis_name = value
@property
def y_axis_name(self) -> str:
return self._y_axis_name
@y_axis_name.setter
def y_axis_name(self, value: str):
self._y_axis_name = value
@property
def x_lim(self) -> tuple[float, float]:
return self._x_lim
@x_lim.setter
def x_lim(self, value: tuple[float, float]):
self._x_lim = value
@property
def y_lim(self) -> tuple[float, float]:
return self._y_lim
@y_lim.setter
def y_lim(self, value: tuple[float, float]):
self._y_lim = value
@property
def remove_axes(self) -> bool:
return self._remove_axes
@remove_axes.setter
def remove_axes(self, value: bool):
self._remove_axes = value
@property
def aspect_ratio(self) -> float | str:
return self._aspect_ratio
@aspect_ratio.setter
def aspect_ratio(self, value: float | str):
if isinstance(value, str):
if value not in ["equal", "auto"]:
raise GraphingException(
"Aspect ratio must be either 'equal', 'auto' or a float."
)
elif isinstance(value, (int, float)) and not isinstance(value, bool):
if value <= 0:
raise GraphingException("Aspect ratio must be a positive float.")
else:
raise GraphingException(
"Aspect ratio must be either 'equal', 'auto' or a float."
)
self._aspect_ratio = value
[docs]
def add_elements(self, *elements: Plottable) -> None:
"""
Adds one or more :class:`~graphinglib.graph_elements.Plottable` elements to the :class:`~graphinglib.figure.Figure`.
Parameters
----------
elements : :class:`~graphinglib.graph_elements.Plottable`
Elements to plot in the :class:`~graphinglib.figure.Figure`.
"""
for element in elements:
self._elements.append(element)
[docs]
def copy(self) -> Self:
"""
Returns a deep copy of the :class:`~graphinglib.figure.Figure` object.
"""
return deepcopy(self)
[docs]
def copy_with(self, **kwargs) -> Self:
"""
Returns a deep copy of the Figure with specified attributes overridden.
Parameters
----------
**kwargs
Public writable properties to override in the copied Figure. The keys should be property names to modify
and the values are the new values for those properties.
Returns
-------
Figure
A new Figure instance with the specified attributes overridden.
"""
return _copy_with_overrides(self, **kwargs)
def _prepare_figure(
self,
legend: bool = True,
legend_loc: str = None,
legend_cols: int = 1,
axes: plt.Axes = None,
default_params: dict = None,
is_matplotlib_style: bool = False,
):
"""
Prepares the :class:`~graphinglib.figure.Figure` to be displayed.
"""
if default_params is not None:
self._default_params = default_params
is_a_subfigure = default_params.get("is_a_subfigure", False)
if not is_a_subfigure:
self._fill_in_rc_params()
figure_params_to_reset = self._fill_in_missing_params(self)
else:
if self._figure_style == INHERIT:
self._figure_style = get_default_style()
try:
file_loader = FileLoader(self._figure_style)
self._default_params = file_loader.load()
self._fill_in_rc_params()
except FileNotFoundError:
# set the style use matplotlib style
try:
is_matplotlib_style = True
if self._figure_style == "matplotlib":
# set the style to default
plt.style.use("default")
else:
plt.style.use(self._figure_style)
file_loader = FileLoader("plain")
self._default_params = file_loader.load()
except OSError:
raise GraphingException(
f"The figure style {self._figure_style} was not found. Please choose a different style."
)
figure_params_to_reset = self._fill_in_missing_params(self)
if axes is not None:
self._axes = axes
if self._title is not None:
self._axes.set_title(self._title, fontdict={"fontsize": "medium"})
else:
self._figure, self._axes = plt.subplots(
figsize=self._size, layout="constrained"
)
if self._title is not None:
self._axes.set_title(self._title)
if self._show_grid:
self._axes.grid(self._grid_vis_x, self._grid_which_x, "x")
self._axes.grid(self._grid_vis_y, self._grid_which_y, "y")
self._axes.set_xlabel(self._x_axis_name)
self._axes.set_ylabel(self._y_axis_name)
self._axes.set_aspect(self._aspect_ratio)
if self._custom_ticks:
if self._xticks:
self._axes.set_xticks(self._xticks, self._xticklabels)
if self._xticklabels_rotation:
self._axes.tick_params("x", labelrotation=self._xticklabels_rotation)
if self._xtick_spacing:
self._axes.xaxis.set_major_locator(
ticker.MultipleLocator(self._xtick_spacing)
)
if self._yticks:
self._axes.set_yticks(self._yticks, self._yticklabels)
if self._yticklabels_rotation:
self._axes.tick_params("y", labelrotation=self._yticklabels_rotation)
if self._ytick_spacing:
self._axes.yaxis.set_major_locator(
ticker.MultipleLocator(self._ytick_spacing)
)
if self._x_lim:
self._axes.set_xlim(*self._x_lim)
if self._y_lim:
self._axes.set_ylim(*self._y_lim)
if self._log_scale_x:
self._axes.set_xscale("log")
if self._log_scale_y:
self._axes.set_yscale("log")
if self._remove_axes:
self._axes.axis("off")
warn("Axes on this figure have been removed.")
if self._twin_x_axis:
labels, handles = self._twin_x_axis._prepare_twin_axis(
self._axes,
is_matplotlib_style,
self._default_params,
self._figure_style,
)
self._handles += handles
self._labels += labels
if self._twin_y_axis:
labels, handles = self._twin_y_axis._prepare_twin_axis(
self._axes,
is_matplotlib_style,
self._default_params,
self._figure_style,
)
self._handles += handles
self._labels += labels
cycle_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
num_cycle_colors = len(cycle_colors)
if self._elements:
z_order = 2
for index, element in enumerate(self._elements):
params_to_reset = []
if not is_matplotlib_style:
params_to_reset = self._fill_in_missing_params(element)
element._plot_element(
self._axes,
z_order,
cycle_color=cycle_colors[index % num_cycle_colors],
)
if not is_matplotlib_style:
self._reset_params_to_default(element, params_to_reset)
try:
if element.label is not None:
self._handles.append(element.handle)
self._labels.append(element.label)
except AttributeError:
continue
z_order += 5
if not self._labels:
legend = False
if legend:
if legend_loc is not None and "outside" in legend_loc:
outside_coords = {
"outside upper center": (0.5, 1),
"outside center right": (1, 0.5),
"outside lower center": (0.5, -0.1),
}
outside_keyword = {
"outside upper center": "lower center",
"outside center right": "center left",
"outside lower center": "upper center",
}
legend_params = {
"loc": outside_keyword[legend_loc],
"bbox_to_anchor": outside_coords[legend_loc],
}
else:
legend_params = {"loc": legend_loc}
try:
_legend = self._axes.legend(
handles=self._handles,
labels=self._labels,
handleheight=1.3,
handler_map={
Polygon: HandlerPatch(patch_func=histogram_legend_artist),
LineCollection: HandlerMultipleLines(),
VerticalLineCollection: HandlerMultipleVerticalLines(),
},
draggable=True,
**legend_params,
ncols=legend_cols,
)
_legend.set_zorder(10000)
except TypeError:
_legend = self._axes.legend(
handles=self._handles,
labels=self._labels,
handleheight=1.3,
handler_map={
Polygon: HandlerPatch(patch_func=histogram_legend_artist),
LineCollection: HandlerMultipleLines(),
VerticalLineCollection: HandlerMultipleVerticalLines(),
},
**legend_params,
ncols=legend_cols,
)
_legend.set_zorder(10000)
else:
raise GraphingException("No curves to be plotted!")
self._reset_params_to_default(self, figure_params_to_reset)
temp_handles = self._handles
temp_labels = self._labels
self._handles = []
self._labels = []
self._rc_dict = {}
return temp_labels, temp_handles
[docs]
def show(
self,
legend: bool = True,
legend_loc: str | tuple = "best",
legend_cols: int = 1,
) -> None:
"""
Displays the :class:`~graphinglib.figure.Figure`.
Parameters
----------
legend : bool
Whether or not to display the legend. The legend is always set to be
draggable.
Defaults to ``True``.
legend_loc : str or tuple
Legend location keyword or tuple. Any value in {"best", "upper right", "upper left",
"lower left", "lower right", "right", "center left", "center right", "lower center",
"upper center", "center"} or {"outside upper center", "outside center right",
"outside lower center"}. Tuple contains floats relative to the plots size; ``(1, 1)``
is equivalent to the upper right corner. Defaults to "best".
legend_cols : int
Number of columns in the legend.
Defaults to 1.
"""
self._prepare_figure(
legend=legend, legend_loc=legend_loc, legend_cols=legend_cols
)
plt.show()
plt.rcParams.update(plt.rcParamsDefault)
[docs]
def save(
self,
file_name: str,
legend: bool = True,
legend_loc: str | tuple = "best",
legend_cols: int = 1,
dpi: Optional[int] = None,
) -> None:
"""
Saves the :class:`~graphinglib.figure.Figure`.
Parameters
----------
file_name : str
The name of the file to save the figure to (including the file extension).
legend : bool
Wheter or not to display the legend.
Defaults to ``True``.
legend_loc : str or tuple
Legend location keyword or tuple. Any value in {"best", "upper right", "upper left",
"lower left", "lower right", "right", "center left", "center right", "lower center",
"upper center", "center"} or {"outside upper center", "outside center right",
"outside lower center"}. Tuple contains floats relative to the plots size; ``(1, 1)``
is equivalent to the upper right corner. Defaults to "best".
legend_cols : int
Number of columns in the legend.
Defaults to 1.
dpi : int
The resolution of the saved figure. Only used for raster formats (e.g. PNG, JPG, etc.).
Default depends on the ``figure_style`` configuration.
"""
self._prepare_figure(
legend=legend, legend_loc=legend_loc, legend_cols=legend_cols
)
if dpi is not None:
plt.savefig(file_name, bbox_inches="tight", dpi=dpi)
else:
plt.savefig(file_name, bbox_inches="tight")
plt.close()
plt.rcParams.update(plt.rcParamsDefault)
def _fill_in_missing_params(self, element: Plottable) -> list[str]:
"""
Fills in the missing parameters from the specified ``figure_style``.
"""
params_to_reset = []
object_type = type(element).__name__
tries = 0
while tries < 2:
try:
for property, value in vars(element).items():
if is_inherit(value):
params_to_reset.append(property)
default_value = self._default_params[object_type][property]
setattr(element, property, default_value)
break
except KeyError as e:
tries += 1
if tries >= 2:
raise GraphingException(
f"There was an error auto updating your {self._figure_style} style file following the recent GraphingLib update. Please notify the developers by creating an issue on GraphingLib's GitHub page. In the meantime, you can manually add the following parameter to your {self._figure_style} style file:\n {e.args[0]}"
)
file_updater = FileUpdater(self._figure_style)
file_updater.update()
file_loader = FileLoader(self._figure_style)
self._default_params = file_loader.load()
return params_to_reset
def _reset_params_to_default(
self, element: Plottable, params_to_reset: list[str]
) -> None:
"""
Resets the parameters that were set to default in the _fill_in_missing_params method.
"""
for param in params_to_reset:
setattr(
element,
param,
INHERIT,
)
[docs]
def set_rc_params(
self,
rc_params_dict: dict[str, str | float] = {},
reset: bool = False,
):
"""
Customize the visual style of the :class:`~graphinglib.figure.Figure`.
Any rc parameter that is not specified in the dictionary will be set to the default value for the specified ``figure_style``.
Parameters
----------
rc_params_dict : dict[str, str | float]
Dictionary of rc parameters to update.
Defaults to empty dictionary.
reset : bool
Whether or not to reset the rc parameters to the default values for the specified ``figure_style``.
Defaults to ``False``.
"""
if reset:
self._user_rc_dict = {}
for property, value in rc_params_dict.items():
self._user_rc_dict[property] = value
[docs]
def set_visual_params(
self,
reset: bool = False,
figure_face_color: str | None = None,
axes_face_color: str | None = None,
axes_edge_color: str | None = None,
axes_label_color: str | None = None,
axes_line_width: float | None = None,
color_cycle: list[str] | None = None,
x_tick_color: str | None = None,
y_tick_color: str | None = None,
legend_face_color: str | None = None,
legend_edge_color: str | None = None,
font_family: str | None = None,
font_size: float | None = None,
font_weight: str | None = None,
text_color: str | None = None,
use_latex: bool | None = None,
):
"""
Customize the visual style of the :class:`~graphinglib.figure.Figure`.
Any parameter that is not specified (None) will be set to the default value for the specified ``figure_style``.
Parameters
----------
reset : bool
Whether or not to reset the rc parameters to the default values for the specified ``figure_style``.
Defaults to ``False``.
figure_face_color : str
The color of the figure face.
Defaults to ``None``.
axes_face_color : str
The color of the axes face.
Defaults to ``None``.
axes_edge_color : str
The color of the axes edge.
Defaults to ``None``.
axes_label_color : str
The color of the axes labels.
Defaults to ``None``.
axes_line_width : float
The width of the axes lines.
Typical range is ``0.5`` to ``3``.
Defaults to ``None``.
color_cycle : list[str]
A list of colors to use for the color cycle.
Defaults to ``None``.
x_tick_color : str
The color of the x-axis ticks.
Defaults to ``None``.
y_tick_color : str
The color of the y-axis ticks.
Defaults to ``None``.
legend_face_color : str
The color of the legend face.
Defaults to ``None``.
legend_edge_color : str
The color of the legend edge.
Defaults to ``None``.
font_family : str
The font family to use.
Defaults to ``None``.
font_size : float
The font size to use.
Typical range is ``8`` to ``20``.
Defaults to ``None``.
font_weight : str
The font weight to use.
Values include ``"normal"``, ``"bold"``, ``"light"``, ``"ultralight"``, ``"heavy"``, and
``"black"``.
Defaults to ``None``.
text_color : str
The color of the text.
Defaults to ``None``.
use_latex : bool
Whether or not to use latex.
Defaults to ``None``.
Notes
-----
Color parameters accept Matplotlib color formats: named colors (``"blue"``), short color strings
(``"b"``), hex strings (``"#0000ff"``), grayscale strings (``"0.5"``), and RGB/RGBA tuples with
values between ``0`` and ``1`` (``(0, 0, 1)`` or ``(0, 0, 1, 0.5)``).
"""
if color_cycle is not None:
color_cycle = plt.cycler(color=color_cycle)
rc_params_dict = {
"figure.facecolor": figure_face_color,
"axes.facecolor": axes_face_color,
"axes.edgecolor": axes_edge_color,
"axes.labelcolor": axes_label_color,
"axes.linewidth": axes_line_width,
"axes.prop_cycle": color_cycle,
"xtick.color": x_tick_color,
"ytick.color": y_tick_color,
"legend.facecolor": legend_face_color,
"legend.edgecolor": legend_edge_color,
"font.family": font_family,
"font.size": font_size,
"font.weight": font_weight,
"text.color": text_color,
"text.usetex": use_latex,
}
rc_params_dict = {
key: value for key, value in rc_params_dict.items() if value is not None
}
self.set_rc_params(rc_params_dict, reset=reset)
def _fill_in_rc_params(self):
"""
Fills in the missing rc parameters from the specified ``figure_style``.
"""
params = self._default_params["rc_params"]
for property, value in params.items():
# add to rc_dict if not already in there
if (property not in self._rc_dict) and (property not in self._user_rc_dict):
self._rc_dict[property] = value
all_rc_params = {**self._rc_dict, **self._user_rc_dict}
try:
if all_rc_params["text.usetex"] and which("latex") is None:
all_rc_params["text.usetex"] = False
except KeyError:
pass
plt.rcParams.update(all_rc_params)
[docs]
def set_ticks(
self,
xticks: Optional[list[float]] = None,
xticklabels: Optional[list[str]] = None,
xticklabels_rotation: Optional[float] = None,
xtick_spacing: Optional[float] = None,
yticks: Optional[list[float]] = None,
yticklabels: Optional[list[str]] = None,
yticklabels_rotation: Optional[float] = None,
ytick_spacing: Optional[float] = None,
):
"""
Sets custom ticks and ticks labels.
Parameters
----------
xticks : list[float], optional
Tick positions for the x axis.
xticklabels : list[str], optional
Tick labels for the x axis.
xticklabels_rotation : float, optional
Rotation value for xtick labels.
xtick_spacing : float, optional
Spacing between ticks on the x axis.
yticks : list[float], optional
Tick positions for the y axis.
yticklabels : list[str], optional
Tick labels for the y axis.
yticklabels_rotation : float, optional
Rotation value for ytick labels.
ytick_spacing : float, optional
Spacing between ticks on the y axis.
"""
self._custom_ticks = True
self._xticks = xticks
self._xticklabels = xticklabels
self._xticklabels_rotation = xticklabels_rotation
self._xtick_spacing = xtick_spacing
self._yticks = yticks
self._yticklabels = yticklabels
self._yticklabels_rotation = yticklabels_rotation
self._ytick_spacing = ytick_spacing
if self._xticklabels is not None or self._yticklabels is not None:
if self._yticklabels and not self._yticks:
raise GraphingException(
"Ticks position must be specified when ticks labels are specified"
)
if self._xticklabels and not self._xticks:
raise GraphingException(
"Ticks position must be specified when ticks labels are specified"
)
if self._xticks is not None and self._xtick_spacing is not None:
raise GraphingException(
"Tick spacing and tick positions cannot be set simultaneously"
)
if self._yticks is not None and self._ytick_spacing is not None:
raise GraphingException(
"Tick spacing and tick positions cannot be set simultaneously"
)
[docs]
def set_grid(
self,
visible_x: bool = True,
visible_y: bool = True,
which_x: Literal["both", "major", "minor"] = "both",
which_y: Literal["both", "major", "minor"] = "both",
color: str | Inherit = INHERIT,
alpha: float | Inherit = INHERIT,
line_style: str | Inherit = INHERIT,
line_width: float | Inherit = INHERIT,
) -> None:
"""
Sets the grid parameters for the figure.
Parameters
----------
visible_x : bool, optional
If ``True``, sets the x-axis grid visible. Defaults to ``True``.
visible_y : bool, optional
If ``True``, sets the y-axis grid visible. Defaults to ``True``.
which_x : {"both", "major", "minor"}, optional
Sets whether both, only major or only minor grid lines are shown for the
x-axis. Defaults to ``"both"``.
which_y : {"both", "major", "minor"}, optional
Sets whether both, only major or only minor grid lines are shown for the
y-axis. Defaults to ``"both"``.
color : str, optional
sets the color of the grid lines.
Default depends on the ``figure_style`` configuration.
alpha : float, optional
Sets the alpha value for the grid lines.
Range is ``0`` (transparent) to ``1`` (opaque).
Default depends on the ``figure_style`` configuration.
line_style : str, optional
Sets the line style of the grid lines.
Values include ``"-"``, ``"--"``, ``"-."``, ``":"``, ``"solid"``, ``"dashed"``, ``"dashdot"``, and
``"dotted"``.
Default depends on the ``figure_style`` configuration.
line_width : float, optional
Sets the line width of the grid lines.
Typical range is ``0.5`` to ``3``.
Default depends on the ``figure_style`` configuration.
Notes
-----
Color parameters accept Matplotlib color formats: named colors (``"blue"``), short color strings
(``"b"``), hex strings (``"#0000ff"``), grayscale strings (``"0.5"``), and RGB/RGBA tuples with
values between ``0`` and ``1`` (``(0, 0, 1)`` or ``(0, 0, 1, 0.5)``).
"""
self._show_grid = True
self._grid_vis_x = visible_x
self._grid_vis_y = visible_y
self._grid_which_x = which_x
self._grid_which_y = which_y
rc_params_dict = {
"grid.color": color,
"grid.alpha": alpha,
"grid.linestyle": line_style,
"grid.linewidth": line_width,
}
rc_params_dict = {k: v for k, v in rc_params_dict.items() if v != INHERIT}
self.set_rc_params(rc_params_dict)
[docs]
def create_twin_axis(
self,
is_y: bool = True,
label: str = None,
log_scale: bool = False,
axis_lim: Optional[tuple[float, float]] = None,
) -> "TwinAxis":
"""
Creates a twin axis for the :class:`~graphinglib.figure.Figure` object.
Parameters
----------
is_y : bool
If ``True``, the twin axis will be a y-axis, otherwise it will be an x-axis.
label : str
The identification label for the twin axis.
log_scale : bool
Whether or not to set the scale of the twin axis to logaritmic scale.
Defaults to ``False``.
axis_lim : tuple[float, float], optional
The limits for the axis.
Returns
-------
:class:`~graphinglib.figure.TwinAxis`
The created twin axis.
"""
if self._remove_axes:
raise GraphingException(
"Axis in this figure were removed, therefore twin-axis can't be added."
)
twin = TwinAxis(is_y, label, log_scale, axis_lim)
if is_y:
self._twin_y_axis = twin
else:
self._twin_x_axis = twin
return twin
class TwinAxis:
"""
This class implements a twin axis for the :class:`~graphinglib.figure.Figure` class.
Behaves like a :class:`~graphinglib.figure.Figure` object, but is not meant to be used on its own.
Elements can be added to the twin axis using the :meth:`~graphinglib.figure.TwinAxis.add_element` method,
the visual style can be customized using the :meth:`~graphinglib.figure.TwinAxis.customize_visual_style` method,
and tick labels can be customized using the :meth:`~graphinglib.figure.TwinAxis.set_ticks` method.
Parameters
----------
is_y : bool
If ``True``, the twin axis will be a y-axis, otherwise it will be an x-axis.
label : str
The identification for the twin axis.
log_scale : bool
Whether or not to set the scale of the twin axis to logaritmic scale.
Defaults to ``False``.
axis_lim : tuple[float, float], optional
The limits for the axis.
"""
def __init__(
self,
is_y: bool = True,
label: Optional[str] = None,
log_scale: bool = False,
axis_lim: Optional[tuple[float, float]] = None,
):
"""
This class implements a twin axis for the :class:`~graphinglib.figure.Figure` class.
Behaves like a :class:`~graphinglib.figure.Figure` object, but is not meant to be used on its own.
Elements can be added to the twin axis using the :meth:`~graphinglib.figure.TwinAxis.add_element` method,
the visual style can be customized using the :meth:`~graphinglib.figure.TwinAxis.customize_visual_style` method,
and tick labels can be customized using the :meth:`~graphinglib.figure.TwinAxis.set_ticks` method.
Parameters
----------
is_y : bool
If ``True``, the twin axis will be a y-axis, otherwise it will be an x-axis.
label : str, optional
The identification for the twin axis.
log_scale : bool
Whether or not to set the scale of the twin axis to logaritmic scale.
Defaults to ``False``.
axis_lim : tuple[float, float], optional
The limits for the axis.
"""
self._is_y = is_y
self._label = label
self._log_scale = log_scale
self._elements: list[Plottable] = []
self._custom_ticks = False
self._labels: list[str | None] = []
self._handles = []
self._figure_style = None
self._default_params = None
self._tick_color = None
self._axes_label_color = None
self._axes_edge_color = None
self._axis_lim = axis_lim
@property
def label(self) -> str:
return self._label
@label.setter
def label(self, value: str):
self._label = value
@property
def log_scale(self) -> bool:
return self._log_scale
@log_scale.setter
def log_scale(self, value: bool):
self._log_scale = value
@property
def axis_lim(self) -> tuple[float, float]:
return self._axis_lim
def _prepare_twin_axis(
self,
fig_axes: plt.Axes,
is_matplotlib_style: bool = False,
default_params: dict = None,
figure_style: str | Inherit = INHERIT,
):
"""
Prepares the :class:`~graphinglib.figure.TwinAxis` to be displayed.
"""
self._default_params = default_params
self._figure_style = (
figure_style if figure_style != INHERIT else get_default_style()
)
if self._is_y:
self._axes = fig_axes.twinx()
self._axes.set_ylabel(self._label)
if self._axis_lim:
self._axes.set_ylim(*self._axis_lim)
else:
self._axes = fig_axes.twiny()
self._axes.set_xlabel(self._label)
if self._axis_lim:
self._axes.set_xlim(*self._axis_lim)
if self._is_y:
if self._tick_color:
self._axes.tick_params(axis="y", colors=self._tick_color)
if self._axes_label_color:
self._axes.yaxis.label.set_color(self._axes_label_color)
if self._axes_edge_color:
self._axes.spines["right"].set_color(self._axes_edge_color)
else:
if self._tick_color:
self._axes.tick_params(axis="x", colors=self._tick_color)
if self._axes_label_color:
self._axes.xaxis.label.set_color(self._axes_label_color)
if self._axes_edge_color:
self._axes.spines["top"].set_color(self._axes_edge_color)
if self._custom_ticks:
if self._ticks:
if self._is_y:
self._axes.set_yticks(self._ticks, self._ticklabels)
if self._ticklabels_rotation:
self._axes.tick_params(
"y", labelrotation=self._ticklabels_rotation
)
else:
self._axes.set_xticks(self._ticks, self._ticklabels)
if self._ticklabels_rotation:
self._axes.tick_params(
"x", labelrotation=self._ticklabels_rotation
)
if self._log_scale:
if self._is_y:
self._axes.set_yscale("log")
else:
self._axes.set_xscale("log")
z_order = 1
for element in self._elements:
params_to_reset = []
if not is_matplotlib_style:
params_to_reset = self._fill_in_missing_params(element)
element._plot_element(self._axes, z_order)
if not is_matplotlib_style:
self._reset_params_to_default(element, params_to_reset)
try:
if element.label is not None:
self._handles.append(element.handle)
self._labels.append(element.label)
except AttributeError:
continue
z_order += 2
temp_handles = self._handles
temp_labels = self._labels
self._handles = []
self._labels = []
self._rc_dict = {}
return temp_labels, temp_handles
def set_ticks(
self,
ticks: list[float],
ticklabels: list[str],
ticklabels_rotation: Optional[float] = None,
):
"""
Sets custom ticks and labels for the twin axis.
Parameters
----------
ticks : list[float], optional
Tick positions for the axis.
ticklabels : list[str], optional
Tick labels for the axis.
ticklabels_rotation : float, optional
Rotation value for the tick labels.
"""
if not ticks or not ticklabels:
raise GraphingException(
"Ticks position and corresponding labels must both be specified for the twin axis."
)
if len(ticks) != len(ticklabels):
raise GraphingException(
f"Number of ticks ({len(ticks)}) and number of tick labels ({len(ticklabels)}) must be the same."
)
self._custom_ticks = True
self._ticks = ticks
self._ticklabels = ticklabels
self._ticklabels_rotation = ticklabels_rotation
def add_elements(self, *elements: Plottable) -> None:
"""
Adds one or more :class:`~graphinglib.graph_elements.Plottable` elements to the :class:`~graphinglib.figure.Figure`.
Parameters
----------
elements : :class:`~graphinglib.graph_elements.Plottable`
Elements to plot in the :class:`~graphinglib.figure.Figure`.
"""
for element in elements:
self._elements.append(element)
def copy(self) -> Self:
"""
Returns a deep copy of the :class:`~graphinglib.figure.TwinAxis` object.
"""
return deepcopy(self)
def copy_with(self, **kwargs) -> Self:
"""
Returns a deep copy of the TwinAxis with specified attributes overridden.
Parameters
----------
**kwargs
Public writable properties to override in the copied TwinAxis. The keys should be property names to modify
and the values are the new values for those properties.
Returns
-------
TwinAxis
A new TwinAxis instance with the specified attributes overridden.
"""
return _copy_with_overrides(self, **kwargs)
def set_visual_params(
self,
axes_label_color: str | None = None,
tick_color: str | None = None,
axes_edge_color: str | None = None,
):
"""
Customize the visual style of the :class:`~graphinglib.figure.Figure`.
Any parameter that is not specified (None) will be set to the default value for the specified ``figure_style``.
Parameters
----------
axes_edge_color : str
The color of the axes edge.
Defaults to ``None``.
axes_label_color : str
The color of the axes labels.
Defaults to ``None``.
tick_color : str
The color of the axis ticks.
Defaults to ``None``.
Notes
-----
Color parameters accept Matplotlib color formats: named colors (``"blue"``), short color strings
(``"b"``), hex strings (``"#0000ff"``), grayscale strings (``"0.5"``), and RGB/RGBA tuples with
values between ``0`` and ``1`` (``(0, 0, 1)`` or ``(0, 0, 1, 0.5)``).
"""
self._axes_label_color = axes_label_color
self._tick_color = tick_color
self._axes_edge_color = axes_edge_color
def _fill_in_missing_params(self, element: Plottable) -> list[str]:
"""
Fills in the missing parameters from the specified ``figure_style``.
"""
params_to_reset = []
object_type = type(element).__name__
tries = 0
curve_defaults = {
"_errorbars_color": "_color",
"_errorbars_line_width": "_line_width",
"_cap_thickness": "_line_width",
"_fill_under_color": "_color",
}
while tries < 2:
try:
for property, value in vars(element).items():
if is_inherit(value):
params_to_reset.append(property)
default_value = self._default_params[object_type][property]
if default_value == "same as curve":
setattr(
element,
property,
getattr(element, curve_defaults[property]),
)
elif default_value == "same as scatter":
element.errorbars_color = getattr(element, "_face_color")
else:
setattr(element, property, default_value)
break
except KeyError as e:
tries += 1
if tries >= 2:
raise GraphingException(
f"There was an error auto updating your {self._figure_style} style file following the recent GraphingLib update. Please notify the developers by creating an issue on GraphingLib's GitHub page. In the meantime, you can manually add the following parameter to your {self._figure_style} style file:\n {e.args[0]}"
)
file_updater = FileUpdater(self._figure_style)
file_updater.update()
file_loader = FileLoader(self._figure_style)
self._default_params = file_loader.load()
return params_to_reset
def _reset_params_to_default(
self, element: Plottable, params_to_reset: list[str]
) -> None:
"""
Resets the parameters that were set to default in the _fill_in_missing_params method.
"""
for param in params_to_reset:
setattr(
element,
param,
INHERIT,
)