# Source code for graphinglib.figure

from .inherit import INHERIT, Inherit, is_inherit

from copy import deepcopy
from shutil import which
from typing import Literal, Optional
from warnings import warn

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib.collections import LineCollection
from matplotlib.legend_handler import HandlerPatch
from matplotlib.patches import Polygon

from .file_manager import FileLoader, FileUpdater, get_default_style
from .graph_elements import GraphingException, Plottable
from .legend_artists import (
    HandlerMultipleLines,
    HandlerMultipleVerticalLines,
    VerticalLineCollection,
    histogram_legend_artist,
)
from .tools import _copy_with_overrides

try:
    from typing import Self
except ImportError:
    from typing_extensions import Self


class Figure:
    """
    This class implements a general figure object.

    Parameters
    ----------
    x_label, y_label : str, optional
        The identification for the x-axis and y-axis. Defaults to ``None``
        (no axis label).
    x_lim, y_lim : tuple[float, float], optional
        The limits for the x-axis and y-axis.
    size : tuple[float, float]
        Overall size of the figure. Figure size is in inches; typical width is
        ``4`` to ``12`` and typical height is ``3`` to ``8``.
        Default depends on the ``figure_style`` configuration.
    title: str, optional
        The title of the figure.
    log_scale_x, log_scale_y : bool
        Whether or not to set the scale of the x- or y-axis to logarithmic
        scale. Default depends on the ``figure_style`` configuration.
    remove_axes : bool
        Whether or not to show the axes. Useful for adding tables or text to
        the subfigure. Defaults to ``False``.
    aspect_ratio : float, str
        The aspect ratio of the axis scaling. Values are ``"equal"``,
        ``"auto"``, or a positive float. Defaults to ``"auto"``.
    figure_style : str
        The figure style to use for the figure.

    Notes
    -----
    The grid is hidden by default. It can be enabled with
    :meth:`~graphinglib.figure.Figure.set_grid` or by setting the
    ``show_grid`` property to ``True``.
    """

    def __init__(
        self,
        x_label: Optional[str] = None,
        y_label: Optional[str] = None,
        size: tuple[float, float] | Inherit = INHERIT,
        title: Optional[str] = None,
        x_lim: Optional[tuple[float, float]] = None,
        y_lim: Optional[tuple[float, float]] = None,
        log_scale_x: bool | Inherit = INHERIT,
        log_scale_y: bool | Inherit = INHERIT,
        remove_axes: bool = False,
        aspect_ratio: float | str = "auto",
        figure_style: str | Inherit = INHERIT,
    ) -> None:
        """
        This class implements a general figure object.

        Parameters
        ----------
        x_label, y_label : str, optional
            The identification for the x-axis and y-axis. Defaults to ``None``
            (no axis label).
        x_lim, y_lim : tuple[float, float], optional
            The limits for the x-axis and y-axis.
        size : tuple[float, float]
            Overall size of the figure. Figure size is in inches; typical width
            is ``4`` to ``12`` and typical height is ``3`` to ``8``.
            Default depends on the ``figure_style`` configuration.
        title: str, optional
            The title of the figure.
        log_scale_x, log_scale_y : bool
            Whether or not to set the scale of the x- or y-axis to logarithmic
            scale. Default depends on the ``figure_style`` configuration.
        remove_axes : bool
            Whether or not to show the axes. Useful for adding tables or text
            to the subfigure. Defaults to ``False``.
        aspect_ratio : float, str
            The aspect ratio of the axis scaling. Values are ``"equal"``,
            ``"auto"``, or a positive float. Defaults to ``"auto"``.
        figure_style : str
            The figure style to use for the figure. Default can be set using
            ``gl.set_default_style()``.
        """
        self._figure_style = figure_style
        self._size = size
        self._title = title
        self._log_scale_x = log_scale_x
        self._log_scale_y = log_scale_y
        self._show_grid = False
        # Same defaults as ``set_grid`` so that turning the grid on through the
        # ``show_grid`` property alone does not hit missing attributes in
        # ``_prepare_figure``.
        self._grid_vis_x = True
        self._grid_vis_y = True
        self._grid_which_x = "both"
        self._grid_which_y = "both"
        self._elements: list[Plottable] = []
        self._labels: list[str | None] = []
        self._handles = []
        self._x_axis_name = x_label
        self._y_axis_name = y_label
        self._x_lim = x_lim
        self._y_lim = y_lim
        # rc params coming from the style file vs. the ones set by the user;
        # user values take precedence (see ``_fill_in_rc_params``).
        self._rc_dict = {}
        self._user_rc_dict = {}
        self._custom_ticks = False
        self._remove_axes = remove_axes
        self._twin_x_axis = None
        self._twin_y_axis = None
        # Goes through the validating property setter.
        self.aspect_ratio = aspect_ratio

    @property
    def figure_style(self) -> str | Inherit:
        return self._figure_style

    @figure_style.setter
    def figure_style(self, value: str | Inherit):
        self._figure_style = value

    @property
    def size(self) -> tuple[float, float]:
        return self._size

    @size.setter
    def size(self, value: tuple[float, float]):
        self._size = value

    @property
    def title(self) -> Optional[str]:
        return self._title

    @title.setter
    def title(self, value: Optional[str]):
        self._title = value

    @property
    def log_scale_x(self) -> bool:
        return self._log_scale_x

    @log_scale_x.setter
    def log_scale_x(self, value: bool):
        self._log_scale_x = value

    @property
    def log_scale_y(self) -> bool:
        return self._log_scale_y

    @log_scale_y.setter
    def log_scale_y(self, value: bool):
        self._log_scale_y = value

    @property
    def show_grid(self) -> bool:
        return self._show_grid

    @show_grid.setter
    def show_grid(self, value: bool):
        self._show_grid = value

    @property
    def x_axis_name(self) -> Optional[str]:
        return self._x_axis_name

    @x_axis_name.setter
    def x_axis_name(self, value: Optional[str]):
        self._x_axis_name = value

    @property
    def y_axis_name(self) -> Optional[str]:
        return self._y_axis_name

    @y_axis_name.setter
    def y_axis_name(self, value: Optional[str]):
        self._y_axis_name = value

    @property
    def x_lim(self) -> Optional[tuple[float, float]]:
        return self._x_lim

    @x_lim.setter
    def x_lim(self, value: Optional[tuple[float, float]]):
        self._x_lim = value

    @property
    def y_lim(self) -> Optional[tuple[float, float]]:
        return self._y_lim

    @y_lim.setter
    def y_lim(self, value: Optional[tuple[float, float]]):
        self._y_lim = value

    @property
    def remove_axes(self) -> bool:
        return self._remove_axes

    @remove_axes.setter
    def remove_axes(self, value: bool):
        self._remove_axes = value

    @property
    def aspect_ratio(self) -> float | str:
        return self._aspect_ratio

    @aspect_ratio.setter
    def aspect_ratio(self, value: float | str):
        if isinstance(value, str):
            if value not in ["equal", "auto"]:
                raise GraphingException(
                    "Aspect ratio must be either 'equal', 'auto' or a float."
                )
        elif isinstance(value, (int, float)) and not isinstance(value, bool):
            if value <= 0:
                raise GraphingException("Aspect ratio must be a positive float.")
        else:
            raise GraphingException(
                "Aspect ratio must be either 'equal', 'auto' or a float."
            )
        self._aspect_ratio = value

    def add_elements(self, *elements: Plottable) -> None:
        """
        Adds one or more :class:`~graphinglib.graph_elements.Plottable`
        elements to the :class:`~graphinglib.figure.Figure`.

        Parameters
        ----------
        elements : :class:`~graphinglib.graph_elements.Plottable`
            Elements to plot in the :class:`~graphinglib.figure.Figure`.
        """
        self._elements.extend(elements)

    def copy(self) -> Self:
        """
        Returns a deep copy of the :class:`~graphinglib.figure.Figure` object.
        """
        return deepcopy(self)

    def copy_with(self, **kwargs) -> Self:
        """
        Returns a deep copy of the Figure with specified attributes overridden.

        Parameters
        ----------
        **kwargs
            Public writable properties to override in the copied Figure. The
            keys should be property names to modify and the values are the new
            values for those properties.

        Returns
        -------
        Figure
            A new Figure instance with the specified attributes overridden.
        """
        return _copy_with_overrides(self, **kwargs)

    def _prepare_figure(
        self,
        legend: bool = True,
        legend_loc: Optional[str | tuple] = None,
        legend_cols: int = 1,
        axes: plt.Axes = None,
        default_params: dict = None,
        is_matplotlib_style: bool = False,
    ):
        """
        Prepares the :class:`~graphinglib.figure.Figure` to be displayed.

        Loads the style defaults (from ``default_params`` when given, otherwise
        from ``figure_style``), configures the axes, plots every element and
        draws the legend. Inherited attributes of the figure and of each
        element are filled in temporarily and reset to ``INHERIT`` afterwards.

        Returns
        -------
        tuple[list, list]
            The collected legend labels and handles.

        Raises
        ------
        GraphingException
            If the figure style cannot be found or the figure has no elements.
        """
        if default_params is not None:
            self._default_params = default_params
            if not default_params.get("is_a_subfigure", False):
                self._fill_in_rc_params()
        else:
            if is_inherit(self._figure_style):
                self._figure_style = get_default_style()
            try:
                file_loader = FileLoader(self._figure_style)
                self._default_params = file_loader.load()
                self._fill_in_rc_params()
            except FileNotFoundError:
                # Not a GraphingLib style file: try it as a Matplotlib style and
                # take element defaults from the "plain" style.
                try:
                    is_matplotlib_style = True
                    if self._figure_style == "matplotlib":
                        # "matplotlib" means Matplotlib's own default style.
                        plt.style.use("default")
                    else:
                        plt.style.use(self._figure_style)
                    file_loader = FileLoader("plain")
                    self._default_params = file_loader.load()
                except OSError as err:
                    raise GraphingException(
                        f"The figure style {self._figure_style} was not found. Please choose a different style."
                    ) from err
        # Fill the figure's own inherited attributes exactly once. A second call
        # would find nothing left to fill, return an empty list and the
        # attributes would never be reset to INHERIT after rendering.
        figure_params_to_reset = self._fill_in_missing_params(self)

        if axes is not None:
            self._axes = axes
            if self._title is not None:
                self._axes.set_title(self._title, fontdict={"fontsize": "medium"})
        else:
            self._figure, self._axes = plt.subplots(
                figsize=self._size, layout="constrained"
            )
            if self._title is not None:
                self._axes.set_title(self._title)

        if self._show_grid:
            self._axes.grid(self._grid_vis_x, self._grid_which_x, "x")
            self._axes.grid(self._grid_vis_y, self._grid_which_y, "y")

        self._axes.set_xlabel(self._x_axis_name)
        self._axes.set_ylabel(self._y_axis_name)
        self._axes.set_aspect(self._aspect_ratio)

        if self._custom_ticks:
            if self._xticks:
                self._axes.set_xticks(self._xticks, self._xticklabels)
            if self._xticklabels_rotation:
                self._axes.tick_params("x", labelrotation=self._xticklabels_rotation)
            if self._xtick_spacing:
                self._axes.xaxis.set_major_locator(
                    ticker.MultipleLocator(self._xtick_spacing)
                )
            if self._yticks:
                self._axes.set_yticks(self._yticks, self._yticklabels)
            if self._yticklabels_rotation:
                self._axes.tick_params("y", labelrotation=self._yticklabels_rotation)
            if self._ytick_spacing:
                self._axes.yaxis.set_major_locator(
                    ticker.MultipleLocator(self._ytick_spacing)
                )

        if self._x_lim:
            self._axes.set_xlim(*self._x_lim)
        if self._y_lim:
            self._axes.set_ylim(*self._y_lim)
        if self._log_scale_x:
            self._axes.set_xscale("log")
        if self._log_scale_y:
            self._axes.set_yscale("log")
        if self._remove_axes:
            self._axes.axis("off")
            warn("Axes on this figure have been removed.")

        for twin_axis in (self._twin_x_axis, self._twin_y_axis):
            if twin_axis:
                labels, handles = twin_axis._prepare_twin_axis(
                    self._axes,
                    is_matplotlib_style,
                    self._default_params,
                    self._figure_style,
                )
                self._handles += handles
                self._labels += labels

        if not self._elements:
            # Leave the figure in a clean state before signalling the error.
            self._reset_params_to_default(self, figure_params_to_reset)
            self._handles = []
            self._labels = []
            self._rc_dict = {}
            raise GraphingException("No curves to be plotted!")

        cycle_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
        num_cycle_colors = len(cycle_colors)
        z_order = 2
        for index, element in enumerate(self._elements):
            params_to_reset = []
            if not is_matplotlib_style:
                params_to_reset = self._fill_in_missing_params(element)
            try:
                element._plot_element(
                    self._axes,
                    z_order,
                    cycle_color=cycle_colors[index % num_cycle_colors],
                )
            finally:
                # Restore INHERIT values even if plotting fails.
                if not is_matplotlib_style:
                    self._reset_params_to_default(element, params_to_reset)
            try:
                if element.label is not None:
                    self._handles.append(element.handle)
                    self._labels.append(element.label)
            except AttributeError:
                # Elements without label/handle are simply not in the legend.
                pass
            # Always advance, so later elements are drawn above earlier ones
            # even when an element has no legend entry.
            z_order += 5

        if not self._labels:
            legend = False
        if legend:
            if legend_loc is not None and "outside" in legend_loc:
                outside_coords = {
                    "outside upper center": (0.5, 1),
                    "outside center right": (1, 0.5),
                    "outside lower center": (0.5, -0.1),
                }
                outside_keyword = {
                    "outside upper center": "lower center",
                    "outside center right": "center left",
                    "outside lower center": "upper center",
                }
                legend_params = {
                    "loc": outside_keyword[legend_loc],
                    "bbox_to_anchor": outside_coords[legend_loc],
                }
            else:
                legend_params = {"loc": legend_loc}
            legend_kwargs = dict(
                handles=self._handles,
                labels=self._labels,
                handleheight=1.3,
                handler_map={
                    Polygon: HandlerPatch(patch_func=histogram_legend_artist),
                    LineCollection: HandlerMultipleLines(),
                    VerticalLineCollection: HandlerMultipleVerticalLines(),
                },
                ncols=legend_cols,
                **legend_params,
            )
            try:
                _legend = self._axes.legend(draggable=True, **legend_kwargs)
            except TypeError:
                # Fall back to a non-draggable legend when ``draggable`` is
                # rejected (presumably by older Matplotlib versions).
                _legend = self._axes.legend(**legend_kwargs)
            _legend.set_zorder(10000)

        self._reset_params_to_default(self, figure_params_to_reset)
        temp_handles = self._handles
        temp_labels = self._labels
        self._handles = []
        self._labels = []
        self._rc_dict = {}
        return temp_labels, temp_handles

    def show(
        self,
        legend: bool = True,
        legend_loc: str | tuple = "best",
        legend_cols: int = 1,
    ) -> None:
        """
        Displays the :class:`~graphinglib.figure.Figure`.

        Parameters
        ----------
        legend : bool
            Whether or not to display the legend. The legend is always set to
            be draggable. Defaults to ``True``.
        legend_loc : str or tuple
            Legend location keyword or tuple. Any value in {"best", "upper right",
            "upper left", "lower left", "lower right", "right", "center left",
            "center right", "lower center", "upper center", "center"} or
            {"outside upper center", "outside center right", "outside lower center"}.
            Tuple contains floats relative to the plots size; ``(1, 1)`` is
            equivalent to the upper right corner. Defaults to "best".
        legend_cols : int
            Number of columns in the legend. Defaults to 1.
        """
        self._prepare_figure(
            legend=legend, legend_loc=legend_loc, legend_cols=legend_cols
        )
        plt.show()
        plt.rcParams.update(plt.rcParamsDefault)

    def save(
        self,
        file_name: str,
        legend: bool = True,
        legend_loc: str | tuple = "best",
        legend_cols: int = 1,
        dpi: Optional[int] = None,
    ) -> None:
        """
        Saves the :class:`~graphinglib.figure.Figure`.

        Parameters
        ----------
        file_name : str
            The name of the file to save the figure to (including the file
            extension).
        legend : bool
            Whether or not to display the legend. Defaults to ``True``.
        legend_loc : str or tuple
            Legend location keyword or tuple. Any value in {"best", "upper right",
            "upper left", "lower left", "lower right", "right", "center left",
            "center right", "lower center", "upper center", "center"} or
            {"outside upper center", "outside center right", "outside lower center"}.
            Tuple contains floats relative to the plots size; ``(1, 1)`` is
            equivalent to the upper right corner. Defaults to "best".
        legend_cols : int
            Number of columns in the legend. Defaults to 1.
        dpi : int
            The resolution of the saved figure. Only used for raster formats
            (e.g. PNG, JPG, etc.). Default depends on the ``figure_style``
            configuration.
        """
        self._prepare_figure(
            legend=legend, legend_loc=legend_loc, legend_cols=legend_cols
        )
        if dpi is not None:
            plt.savefig(file_name, bbox_inches="tight", dpi=dpi)
        else:
            plt.savefig(file_name, bbox_inches="tight")
        plt.close()
        plt.rcParams.update(plt.rcParamsDefault)

    def _fill_in_missing_params(self, element: Plottable) -> list[str]:
        """
        Fills in the missing parameters from the specified ``figure_style``.

        Every attribute of ``element`` whose value is ``INHERIT`` is replaced
        by ``self._default_params[<class name>][<attribute name>]``. If a key is
        missing, the style file is updated and reloaded once before giving up.

        Returns
        -------
        list[str]
            Names of the attributes that were filled in, for
            ``_reset_params_to_default``.

        Raises
        ------
        GraphingException
            If a default is still missing after updating the style file.
        """
        params_to_reset = []
        object_type = type(element).__name__
        tries = 0
        while tries < 2:
            try:
                for attr_name, value in vars(element).items():
                    if is_inherit(value):
                        default_value = self._default_params[object_type][attr_name]
                        setattr(element, attr_name, default_value)
                        # Recorded only once filled, so a retry does not
                        # produce duplicate entries.
                        params_to_reset.append(attr_name)
                break
            except KeyError as e:
                tries += 1
                if tries >= 2:
                    raise GraphingException(
                        f"There was an error auto updating your {self._figure_style} style file following the recent GraphingLib update. Please notify the developers by creating an issue on GraphingLib's GitHub page. In the meantime, you can manually add the following parameter to your {self._figure_style} style file:\n {e.args[0]}"
                    )
                file_updater = FileUpdater(self._figure_style)
                file_updater.update()
                file_loader = FileLoader(self._figure_style)
                self._default_params = file_loader.load()
        return params_to_reset

    def _reset_params_to_default(
        self, element: Plottable, params_to_reset: list[str]
    ) -> None:
        """
        Resets the parameters that were set to default in the
        ``_fill_in_missing_params`` method back to ``INHERIT``.
        """
        for param in params_to_reset:
            setattr(element, param, INHERIT)

    def set_rc_params(
        self,
        rc_params_dict: Optional[dict[str, str | float]] = None,
        reset: bool = False,
    ):
        """
        Customize the visual style of the :class:`~graphinglib.figure.Figure`.

        Any rc parameter that is not specified in the dictionary will be set to
        the default value for the specified ``figure_style``.

        Parameters
        ----------
        rc_params_dict : dict[str, str | float], optional
            Dictionary of rc parameters to update. Defaults to no parameters.
        reset : bool
            Whether or not to reset the rc parameters to the default values for
            the specified ``figure_style``. Defaults to ``False``.
        """
        if reset:
            self._user_rc_dict = {}
        if rc_params_dict:
            self._user_rc_dict.update(rc_params_dict)

    def set_visual_params(
        self,
        reset: bool = False,
        figure_face_color: str | None = None,
        axes_face_color: str | None = None,
        axes_edge_color: str | None = None,
        axes_label_color: str | None = None,
        axes_line_width: float | None = None,
        color_cycle: list[str] | None = None,
        x_tick_color: str | None = None,
        y_tick_color: str | None = None,
        legend_face_color: str | None = None,
        legend_edge_color: str | None = None,
        font_family: str | None = None,
        font_size: float | None = None,
        font_weight: str | None = None,
        text_color: str | None = None,
        use_latex: bool | None = None,
    ):
        """
        Customize the visual style of the :class:`~graphinglib.figure.Figure`.

        Any parameter that is not specified (None) will be set to the default
        value for the specified ``figure_style``.

        Parameters
        ----------
        reset : bool
            Whether or not to reset the rc parameters to the default values for
            the specified ``figure_style``. Defaults to ``False``.
        figure_face_color : str
            The color of the figure face. Defaults to ``None``.
        axes_face_color : str
            The color of the axes face. Defaults to ``None``.
        axes_edge_color : str
            The color of the axes edge. Defaults to ``None``.
        axes_label_color : str
            The color of the axes labels. Defaults to ``None``.
        axes_line_width : float
            The width of the axes lines. Typical range is ``0.5`` to ``3``.
            Defaults to ``None``.
        color_cycle : list[str]
            A list of colors to use for the color cycle. Defaults to ``None``.
        x_tick_color : str
            The color of the x-axis ticks. Defaults to ``None``.
        y_tick_color : str
            The color of the y-axis ticks. Defaults to ``None``.
        legend_face_color : str
            The color of the legend face. Defaults to ``None``.
        legend_edge_color : str
            The color of the legend edge. Defaults to ``None``.
        font_family : str
            The font family to use. Defaults to ``None``.
        font_size : float
            The font size to use. Typical range is ``8`` to ``20``. Defaults to
            ``None``.
        font_weight : str
            The font weight to use. Values include ``"normal"``, ``"bold"``,
            ``"light"``, ``"ultralight"``, ``"heavy"``, and ``"black"``.
            Defaults to ``None``.
        text_color : str
            The color of the text. Defaults to ``None``.
        use_latex : bool
            Whether or not to use latex. Defaults to ``None``.

        Notes
        -----
        Color parameters accept Matplotlib color formats: named colors
        (``"blue"``), short color strings (``"b"``), hex strings
        (``"#0000ff"``), grayscale strings (``"0.5"``), and RGB/RGBA tuples with
        values between ``0`` and ``1`` (``(0, 0, 1)`` or ``(0, 0, 1, 0.5)``).
        """
        if color_cycle is not None:
            color_cycle = plt.cycler(color=color_cycle)
        rc_params_dict = {
            "figure.facecolor": figure_face_color,
            "axes.facecolor": axes_face_color,
            "axes.edgecolor": axes_edge_color,
            "axes.labelcolor": axes_label_color,
            "axes.linewidth": axes_line_width,
            "axes.prop_cycle": color_cycle,
            "xtick.color": x_tick_color,
            "ytick.color": y_tick_color,
            "legend.facecolor": legend_face_color,
            "legend.edgecolor": legend_edge_color,
            "font.family": font_family,
            "font.size": font_size,
            "font.weight": font_weight,
            "text.color": text_color,
            "text.usetex": use_latex,
        }
        rc_params_dict = {
            key: value for key, value in rc_params_dict.items() if value is not None
        }
        self.set_rc_params(rc_params_dict, reset=reset)

    def _fill_in_rc_params(self):
        """
        Fills in the missing rc parameters from the specified ``figure_style``
        and applies them (user-set values win) to ``plt.rcParams``.

        LaTeX rendering is disabled if requested but no ``latex`` executable is
        found on the ``PATH``.
        """
        for name, value in self._default_params["rc_params"].items():
            if name not in self._rc_dict and name not in self._user_rc_dict:
                self._rc_dict[name] = value
        all_rc_params = {**self._rc_dict, **self._user_rc_dict}
        if all_rc_params.get("text.usetex") and which("latex") is None:
            all_rc_params["text.usetex"] = False
        plt.rcParams.update(all_rc_params)

    def set_ticks(
        self,
        xticks: Optional[list[float]] = None,
        xticklabels: Optional[list[str]] = None,
        xticklabels_rotation: Optional[float] = None,
        xtick_spacing: Optional[float] = None,
        yticks: Optional[list[float]] = None,
        yticklabels: Optional[list[str]] = None,
        yticklabels_rotation: Optional[float] = None,
        ytick_spacing: Optional[float] = None,
    ):
        """
        Sets custom ticks and ticks labels.

        Parameters
        ----------
        xticks : list[float], optional
            Tick positions for the x axis.
        xticklabels : list[str], optional
            Tick labels for the x axis.
        xticklabels_rotation : float, optional
            Rotation value for xtick labels.
        xtick_spacing : float, optional
            Spacing between ticks on the x axis.
        yticks : list[float], optional
            Tick positions for the y axis.
        yticklabels : list[str], optional
            Tick labels for the y axis.
        yticklabels_rotation : float, optional
            Rotation value for ytick labels.
        ytick_spacing : float, optional
            Spacing between ticks on the y axis.

        Raises
        ------
        GraphingException
            If labels are given without positions, or if positions and spacing
            are given for the same axis. The figure is left unchanged.
        """
        # Validate before storing anything so a rejected call does not leave
        # the figure half-configured.
        if (yticklabels and not yticks) or (xticklabels and not xticks):
            raise GraphingException(
                "Ticks position must be specified when ticks labels are specified"
            )
        if xticks is not None and xtick_spacing is not None:
            raise GraphingException(
                "Tick spacing and tick positions cannot be set simultaneously"
            )
        if yticks is not None and ytick_spacing is not None:
            raise GraphingException(
                "Tick spacing and tick positions cannot be set simultaneously"
            )
        self._custom_ticks = True
        self._xticks = xticks
        self._xticklabels = xticklabels
        self._xticklabels_rotation = xticklabels_rotation
        self._xtick_spacing = xtick_spacing
        self._yticks = yticks
        self._yticklabels = yticklabels
        self._yticklabels_rotation = yticklabels_rotation
        self._ytick_spacing = ytick_spacing

    def set_grid(
        self,
        visible_x: bool = True,
        visible_y: bool = True,
        which_x: Literal["both", "major", "minor"] = "both",
        which_y: Literal["both", "major", "minor"] = "both",
        color: str | Inherit = INHERIT,
        alpha: float | Inherit = INHERIT,
        line_style: str | Inherit = INHERIT,
        line_width: float | Inherit = INHERIT,
    ) -> None:
        """
        Sets the grid parameters for the figure.

        Parameters
        ----------
        visible_x : bool, optional
            If ``True``, sets the x-axis grid visible. Defaults to ``True``.
        visible_y : bool, optional
            If ``True``, sets the y-axis grid visible. Defaults to ``True``.
        which_x : {"both", "major", "minor"}, optional
            Sets whether both, only major or only minor grid lines are shown
            for the x-axis. Defaults to ``"both"``.
        which_y : {"both", "major", "minor"}, optional
            Sets whether both, only major or only minor grid lines are shown
            for the y-axis. Defaults to ``"both"``.
        color : str, optional
            Sets the color of the grid lines. Default depends on the
            ``figure_style`` configuration.
        alpha : float, optional
            Sets the alpha value for the grid lines. Range is ``0``
            (transparent) to ``1`` (opaque). Default depends on the
            ``figure_style`` configuration.
        line_style : str, optional
            Sets the line style of the grid lines. Values include ``"-"``,
            ``"--"``, ``"-."``, ``":"``, ``"solid"``, ``"dashed"``,
            ``"dashdot"``, and ``"dotted"``. Default depends on the
            ``figure_style`` configuration.
        line_width : float, optional
            Sets the line width of the grid lines. Typical range is ``0.5`` to
            ``3``. Default depends on the ``figure_style`` configuration.

        Notes
        -----
        Color parameters accept Matplotlib color formats: named colors
        (``"blue"``), short color strings (``"b"``), hex strings
        (``"#0000ff"``), grayscale strings (``"0.5"``), and RGB/RGBA tuples with
        values between ``0`` and ``1`` (``(0, 0, 1)`` or ``(0, 0, 1, 0.5)``).
        """
        self._show_grid = True
        self._grid_vis_x = visible_x
        self._grid_vis_y = visible_y
        self._grid_which_x = which_x
        self._grid_which_y = which_y
        rc_params_dict = {
            "grid.color": color,
            "grid.alpha": alpha,
            "grid.linestyle": line_style,
            "grid.linewidth": line_width,
        }
        # Only explicitly given values override the style's grid settings.
        rc_params_dict = {
            key: value for key, value in rc_params_dict.items() if not is_inherit(value)
        }
        self.set_rc_params(rc_params_dict)

    def create_twin_axis(
        self,
        is_y: bool = True,
        label: str = None,
        log_scale: bool = False,
        axis_lim: Optional[tuple[float, float]] = None,
    ) -> "TwinAxis":
        """
        Creates a twin axis for the :class:`~graphinglib.figure.Figure` object.

        Parameters
        ----------
        is_y : bool
            If ``True``, the twin axis will be a y-axis, otherwise it will be
            an x-axis.
        label : str
            The identification label for the twin axis.
        log_scale : bool
            Whether or not to set the scale of the twin axis to logarithmic
            scale. Defaults to ``False``.
        axis_lim : tuple[float, float], optional
            The limits for the axis.

        Returns
        -------
        :class:`~graphinglib.figure.TwinAxis`
            The created twin axis. It replaces any previous twin axis of the
            same orientation.

        Raises
        ------
        GraphingException
            If the axes of this figure were removed.
        """
        if self._remove_axes:
            raise GraphingException(
                "Axis in this figure were removed, therefore twin-axis can't be added."
            )
        twin = TwinAxis(is_y, label, log_scale, axis_lim)
        if is_y:
            self._twin_y_axis = twin
        else:
            self._twin_x_axis = twin
        return twin
class TwinAxis:
    """
    This class implements a twin axis for the
    :class:`~graphinglib.figure.Figure` class. Behaves like a
    :class:`~graphinglib.figure.Figure` object, but is not meant to be used on
    its own.

    Elements can be added to the twin axis using the
    :meth:`~graphinglib.figure.TwinAxis.add_elements` method, the visual style
    can be customized using the
    :meth:`~graphinglib.figure.TwinAxis.set_visual_params` method, and tick
    labels can be customized using the
    :meth:`~graphinglib.figure.TwinAxis.set_ticks` method.

    Parameters
    ----------
    is_y : bool
        If ``True``, the twin axis will be a y-axis, otherwise it will be an
        x-axis.
    label : str
        The identification for the twin axis.
    log_scale : bool
        Whether or not to set the scale of the twin axis to logarithmic scale.
        Defaults to ``False``.
    axis_lim : tuple[float, float], optional
        The limits for the axis.
    """

    def __init__(
        self,
        is_y: bool = True,
        label: Optional[str] = None,
        log_scale: bool = False,
        axis_lim: Optional[tuple[float, float]] = None,
    ):
        """
        This class implements a twin axis for the
        :class:`~graphinglib.figure.Figure` class. Behaves like a
        :class:`~graphinglib.figure.Figure` object, but is not meant to be used
        on its own.

        Elements can be added to the twin axis using the
        :meth:`~graphinglib.figure.TwinAxis.add_elements` method, the visual
        style can be customized using the
        :meth:`~graphinglib.figure.TwinAxis.set_visual_params` method, and tick
        labels can be customized using the
        :meth:`~graphinglib.figure.TwinAxis.set_ticks` method.

        Parameters
        ----------
        is_y : bool
            If ``True``, the twin axis will be a y-axis, otherwise it will be
            an x-axis.
        label : str, optional
            The identification for the twin axis.
        log_scale : bool
            Whether or not to set the scale of the twin axis to logarithmic
            scale. Defaults to ``False``.
        axis_lim : tuple[float, float], optional
            The limits for the axis.
        """
        self._is_y = is_y
        self._label = label
        self._log_scale = log_scale
        self._elements: list[Plottable] = []
        self._custom_ticks = False
        self._labels: list[str | None] = []
        self._handles = []
        # Style information is handed over by the parent Figure in
        # ``_prepare_twin_axis``.
        self._figure_style = None
        self._default_params = None
        self._tick_color = None
        self._axes_label_color = None
        self._axes_edge_color = None
        self._axis_lim = axis_lim

    @property
    def label(self) -> Optional[str]:
        return self._label

    @label.setter
    def label(self, value: Optional[str]):
        self._label = value

    @property
    def log_scale(self) -> bool:
        return self._log_scale

    @log_scale.setter
    def log_scale(self, value: bool):
        self._log_scale = value

    @property
    def axis_lim(self) -> Optional[tuple[float, float]]:
        return self._axis_lim

    @axis_lim.setter
    def axis_lim(self, value: Optional[tuple[float, float]]):
        # Writable like ``label`` and ``log_scale`` so ``copy_with`` can
        # override it.
        self._axis_lim = value

    def _prepare_twin_axis(
        self,
        fig_axes: plt.Axes,
        is_matplotlib_style: bool = False,
        default_params: dict = None,
        figure_style: str | Inherit = INHERIT,
    ):
        """
        Prepares the :class:`~graphinglib.figure.TwinAxis` to be displayed.

        Creates the twin axes from ``fig_axes``, applies label, limits, colors,
        custom ticks and log scale, then plots every element.

        Returns
        -------
        tuple[list, list]
            The collected legend labels and handles, to be merged into the
            parent figure's legend.
        """
        self._default_params = default_params
        self._figure_style = (
            get_default_style() if is_inherit(figure_style) else figure_style
        )

        if self._is_y:
            self._axes = fig_axes.twinx()
            self._axes.set_ylabel(self._label)
            if self._axis_lim:
                self._axes.set_ylim(*self._axis_lim)
            axis_name, label_axis, spine = "y", self._axes.yaxis, "right"
        else:
            self._axes = fig_axes.twiny()
            self._axes.set_xlabel(self._label)
            if self._axis_lim:
                self._axes.set_xlim(*self._axis_lim)
            axis_name, label_axis, spine = "x", self._axes.xaxis, "top"

        if self._tick_color:
            self._axes.tick_params(axis=axis_name, colors=self._tick_color)
        if self._axes_label_color:
            label_axis.label.set_color(self._axes_label_color)
        if self._axes_edge_color:
            self._axes.spines[spine].set_color(self._axes_edge_color)

        if self._custom_ticks and self._ticks:
            if self._is_y:
                self._axes.set_yticks(self._ticks, self._ticklabels)
            else:
                self._axes.set_xticks(self._ticks, self._ticklabels)
            if self._ticklabels_rotation:
                self._axes.tick_params(
                    axis_name, labelrotation=self._ticklabels_rotation
                )

        if self._log_scale:
            if self._is_y:
                self._axes.set_yscale("log")
            else:
                self._axes.set_xscale("log")

        z_order = 1
        for element in self._elements:
            params_to_reset = []
            if not is_matplotlib_style:
                params_to_reset = self._fill_in_missing_params(element)
            try:
                element._plot_element(self._axes, z_order)
            finally:
                # Restore INHERIT values even if plotting fails.
                if not is_matplotlib_style:
                    self._reset_params_to_default(element, params_to_reset)
            try:
                if element.label is not None:
                    self._handles.append(element.handle)
                    self._labels.append(element.label)
            except AttributeError:
                # Elements without label/handle are simply not in the legend.
                pass
            # Always advance so later elements stay above earlier ones.
            z_order += 2

        temp_handles = self._handles
        temp_labels = self._labels
        self._handles = []
        self._labels = []
        self._rc_dict = {}
        return temp_labels, temp_handles

    def set_ticks(
        self,
        ticks: list[float],
        ticklabels: list[str],
        ticklabels_rotation: Optional[float] = None,
    ):
        """
        Sets custom ticks and labels for the twin axis.

        Parameters
        ----------
        ticks : list[float]
            Tick positions for the axis. Required and must not be empty.
        ticklabels : list[str]
            Tick labels for the axis, one per tick position.
        ticklabels_rotation : float, optional
            Rotation value for the tick labels.

        Raises
        ------
        GraphingException
            If ``ticks`` or ``ticklabels`` is empty, or their lengths differ.
        """
        if not ticks or not ticklabels:
            raise GraphingException(
                "Ticks position and corresponding labels must both be specified for the twin axis."
            )
        if len(ticks) != len(ticklabels):
            raise GraphingException(
                f"Number of ticks ({len(ticks)}) and number of tick labels ({len(ticklabels)}) must be the same."
            )
        self._custom_ticks = True
        self._ticks = ticks
        self._ticklabels = ticklabels
        self._ticklabels_rotation = ticklabels_rotation

    def add_elements(self, *elements: Plottable) -> None:
        """
        Adds one or more :class:`~graphinglib.graph_elements.Plottable`
        elements to the :class:`~graphinglib.figure.TwinAxis`.

        Parameters
        ----------
        elements : :class:`~graphinglib.graph_elements.Plottable`
            Elements to plot on the :class:`~graphinglib.figure.TwinAxis`.
        """
        self._elements.extend(elements)

    def copy(self) -> Self:
        """
        Returns a deep copy of the :class:`~graphinglib.figure.TwinAxis` object.
        """
        return deepcopy(self)

    def copy_with(self, **kwargs) -> Self:
        """
        Returns a deep copy of the TwinAxis with specified attributes overridden.

        Parameters
        ----------
        **kwargs
            Public writable properties to override in the copied TwinAxis. The
            keys should be property names to modify and the values are the new
            values for those properties.

        Returns
        -------
        TwinAxis
            A new TwinAxis instance with the specified attributes overridden.
        """
        return _copy_with_overrides(self, **kwargs)

    def set_visual_params(
        self,
        axes_label_color: str | None = None,
        tick_color: str | None = None,
        axes_edge_color: str | None = None,
    ):
        """
        Customize the visual style of the :class:`~graphinglib.figure.TwinAxis`.

        Parameters left as ``None`` are not applied, so the appearance coming
        from the figure's style is kept for them.

        Parameters
        ----------
        axes_edge_color : str
            The color of the twin axis spine. Defaults to ``None``.
        axes_label_color : str
            The color of the twin axis label. Defaults to ``None``.
        tick_color : str
            The color of the twin axis ticks. Defaults to ``None``.

        Notes
        -----
        Color parameters accept Matplotlib color formats: named colors
        (``"blue"``), short color strings (``"b"``), hex strings
        (``"#0000ff"``), grayscale strings (``"0.5"``), and RGB/RGBA tuples with
        values between ``0`` and ``1`` (``(0, 0, 1)`` or ``(0, 0, 1, 0.5)``).
        """
        self._axes_label_color = axes_label_color
        self._tick_color = tick_color
        self._axes_edge_color = axes_edge_color

    def _fill_in_missing_params(self, element: Plottable) -> list[str]:
        """
        Fills in the missing parameters from the specified ``figure_style``.

        Every ``INHERIT`` attribute of ``element`` gets the style default for
        its class. The special defaults ``"same as curve"`` and
        ``"same as scatter"`` copy the value of a related attribute of the same
        element instead. If a key is missing, the style file is updated and
        reloaded once before giving up.

        Returns
        -------
        list[str]
            Names of the attributes that were filled in.

        Raises
        ------
        GraphingException
            If a default is still missing after updating the style file.
        """
        params_to_reset = []
        object_type = type(element).__name__
        tries = 0
        # Attribute to copy from for defaults declared as "same as curve".
        curve_defaults = {
            "_errorbars_color": "_color",
            "_errorbars_line_width": "_line_width",
            "_cap_thickness": "_line_width",
            "_fill_under_color": "_color",
        }
        while tries < 2:
            try:
                for attr_name, value in vars(element).items():
                    if is_inherit(value):
                        default_value = self._default_params[object_type][attr_name]
                        if default_value == "same as curve":
                            setattr(
                                element,
                                attr_name,
                                getattr(element, curve_defaults[attr_name]),
                            )
                        elif default_value == "same as scatter":
                            element.errorbars_color = element._face_color
                        else:
                            setattr(element, attr_name, default_value)
                        # Recorded only once filled, so a retry does not
                        # produce duplicate entries.
                        params_to_reset.append(attr_name)
                break
            except KeyError as e:
                tries += 1
                if tries >= 2:
                    raise GraphingException(
                        f"There was an error auto updating your {self._figure_style} style file following the recent GraphingLib update. Please notify the developers by creating an issue on GraphingLib's GitHub page. In the meantime, you can manually add the following parameter to your {self._figure_style} style file:\n {e.args[0]}"
                    )
                file_updater = FileUpdater(self._figure_style)
                file_updater.update()
                file_loader = FileLoader(self._figure_style)
                self._default_params = file_loader.load()
        return params_to_reset

    def _reset_params_to_default(
        self, element: Plottable, params_to_reset: list[str]
    ) -> None:
        """
        Resets the parameters that were set to default in the
        ``_fill_in_missing_params`` method back to ``INHERIT``.
        """
        for param in params_to_reset:
            setattr(element, param, INHERIT)