Source code for figanos.matplotlib.plot

# noqa: D100
from __future__ import annotations
import copy
import logging
import math
import string
import warnings
from collections.abc import Iterable
from inspect import signature
from pathlib import Path
from typing import Any

import cartopy.mpl.geoaxes
import geopandas as gpd
import matplotlib
import matplotlib.axes
import matplotlib.colors
import matplotlib.pyplot as plt
import mpl_toolkits.axisartist.grid_finder as gf
import numpy as np
import pandas as pd
import seaborn as sns
import xarray as xr
from cartopy import crs as ccrs
from matplotlib.cm import ScalarMappable
from matplotlib.lines import Line2D
from matplotlib.projections import PolarAxes
from matplotlib.tri import Triangulation
from mpl_toolkits.axisartist.floating_axes import FloatingSubplot, GridHelperCurveLinear

from figanos.matplotlib.utils import (  # masknan_sizes_key,
    add_cartopy_features,
    add_features_map,
    check_timeindex,
    convert_scen_name,
    create_cmap,
    custom_cmap_norm,
    empty_dict,
    fill_between_label,
    get_array_categ,
    get_attributes,
    get_localized_term,
    get_rotpole,
    get_scen_color,
    get_var_group,
    gpd_to_ccrs,
    norm2range,
    plot_coords,
    process_keys,
    set_plot_attrs,
    size_legend_elements,
    sort_lines,
    split_legend,
    wrap_text,
)


logger = logging.getLogger(__name__)


def _plot_realizations(
    ax: matplotlib.axes.Axes,
    da: xr.DataArray,
    name: str,
    plot_kw: dict[str, Any],
    non_dict_data: dict[str, Any],
) -> matplotlib.axes.Axes:
    """
    Plot realizations from a DataArray, inside or outside a Dataset.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The Matplotlib axis object.
    da : DataArray
        The DataArray containing the realizations.
    name : str
        The label to be used in the first part of a composite label.
        Can be the name of the parent Dataset or that of the DataArray.
    plot_kw : dict
        Dictionary of kwargs coming from the timeseries() input.
    non_dict_data : dict
        TBD.

    Returns
    -------
    matplotlib.axes.Axes
    """
    ignore_label = False

    for r in da.realization.values:
        if plot_kw[name]:  # if kwargs (all lines identical)
            if not ignore_label:  # if label not already in legend
                label = "" if non_dict_data is True else name
                ignore_label = True
            else:
                label = ""
        else:
            label = str(r) if non_dict_data is True else (name + "_" + str(r))

        ax.plot(
            da.sel(realization=r)["time"],
            da.sel(realization=r).values,
            label=label,
            **plot_kw[name],
        )

    return ax


def _plot_timeseries(
    ax: matplotlib.axes.Axes,
    name: str,
    arr: xr.DataArray | xr.Dataset,
    plot_kw: dict[str, Any],
    non_dict_data: bool,
    array_categ: dict[str, Any],
    legend: str,
) -> matplotlib.axes.Axes:
    """
    Plot figanos timeseries.

    Parameters
    ----------
    ax: matplotlib.axes.Axes
        Axe to be used for plotting.
    name : str
        Dictionary key of the plotted data.
    arr : Dataset/DataArray
        Data to be plotted.
    plot_kw : dict
        Dictionary of kwargs coming from the timeseries() input.
    non_dic_data : bool
        If True, plot_kw is not a dictionary.
    array_categ: dict
        Categories of data.
    legend: str
        Legend type.

    Returns
    -------
    matplotlib.axes.Axes
    """
    lines_dict = {}  # created to facilitate accessing line properties later
    # look for SSP, RCP, CMIP model color
    cat_colors = Path(__file__).parents[1] / "data/ipcc_colors/categorical_colors.json"
    if get_scen_color(name, cat_colors):
        plot_kw[name].setdefault("color", get_scen_color(name, cat_colors))

    #  remove 'label' to avoid error due to double 'label' args
    if "label" in plot_kw[name]:
        del plot_kw[name]["label"]
        warnings.warn(f'"label" entry in plot_kw[{name}] will be ignored.', stacklevel=2)

    if array_categ[name] == "ENS_REALS_DA":
        _plot_realizations(ax, arr, name, plot_kw, non_dict_data)

    elif array_categ[name] == "ENS_REALS_DS":
        if len(arr.data_vars) >= 2:
            raise TypeError(
                "To plot multiple ensembles containing realizations, use DataArrays outside a Dataset"
            )
        for sub_arr in arr.data_vars.values():
            _plot_realizations(ax, sub_arr, name, plot_kw, non_dict_data)

    elif array_categ[name] == "ENS_PCT_DIM_DS":
        for sub_arr in arr.data_vars.values():
            sub_name = (
                sub_arr.name if non_dict_data is True else (name + "_" + sub_arr.name)
            )

            # extract each percentile array from the dims
            array_data = {}
            for pct in sub_arr.percentiles.values:
                array_data[str(pct)] = sub_arr.sel(percentiles=pct)

            # create a dictionary labeling the middle, upper and lower line
            sorted_lines = sort_lines(array_data)

            # plot
            lines_dict[sub_name] = ax.plot(
                array_data[sorted_lines["middle"]]["time"],
                array_data[sorted_lines["middle"]].values,
                label=sub_name,
                **plot_kw[name],
            )

            ax.fill_between(
                array_data[sorted_lines["lower"]]["time"],
                array_data[sorted_lines["lower"]].values,
                array_data[sorted_lines["upper"]].values,
                color=lines_dict[sub_name][0].get_color(),
                linewidth=0.0,
                alpha=0.2,
                label=fill_between_label(sorted_lines, name, array_categ, legend),
            )

    # other ensembles
    elif array_categ[name] in [
        "ENS_PCT_VAR_DS",
        "ENS_STATS_VAR_DS",
        "ENS_PCT_DIM_DA",
    ]:
        # extract each array from the datasets
        array_data = {}
        if array_categ[name] == "ENS_PCT_DIM_DA":
            for pct in arr.percentiles:
                array_data[str(int(pct))] = arr.sel(percentiles=int(pct))
        else:
            for k, v in arr.data_vars.items():
                array_data[k] = v

        # create a dictionary labeling the middle, upper and lower line
        sorted_lines = sort_lines(array_data)

        # plot
        lines_dict[name] = ax.plot(
            array_data[sorted_lines["middle"]]["time"],
            array_data[sorted_lines["middle"]].values,
            label=name,
            **plot_kw[name],
        )

        ax.fill_between(
            array_data[sorted_lines["lower"]]["time"],
            array_data[sorted_lines["lower"]].values,
            array_data[sorted_lines["upper"]].values,
            color=lines_dict[name][0].get_color(),
            linewidth=0.0,
            alpha=0.2,
            label=fill_between_label(sorted_lines, name, array_categ, legend),
        )

    #  non-ensemble Datasets
    elif array_categ[name] == "DS":
        ignore_label = False
        for sub_arr in arr.data_vars.values():
            sub_name = (
                sub_arr.name if non_dict_data is True else (name + "_" + sub_arr.name)
            )

            #  if kwargs are specified by user, all lines are the same and we want one legend entry
            if plot_kw[name]:
                label = name if not ignore_label else ""
                ignore_label = True
            else:
                label = sub_name

            lines_dict[sub_name] = ax.plot(
                sub_arr["time"], sub_arr.values, label=label, **plot_kw[name]
            )

    #  non-ensemble DataArrays
    elif array_categ[name] in ["DA"]:
        lines_dict[name] = ax.plot(arr["time"], arr.values, label=name, **plot_kw[name])

    else:
        raise ValueError(
            "Data structure not supported"
        )  # can probably be removed along with elif logic above,
        # given that get_array_categ() also does this check
    return ax


[docs] def timeseries( data: dict[str, Any] | xr.DataArray | xr.Dataset, ax: matplotlib.axes.Axes | None = None, use_attrs: dict[str, Any] | None = None, fig_kw: dict[str, Any] | None = None, plot_kw: dict[str, Any] | None = None, legend: str = "lines", show_lat_lon: bool | str | int | tuple[float, float] = True, enumerate_subplots: bool = False, ) -> matplotlib.axes.Axes: """ Plot time series from 1D Xarray Datasets or DataArrays as line plots. Parameters ---------- data : dict or Dataset/DataArray Input data to plot. It can be a DataArray, Dataset or a dictionary of DataArrays and/or Datasets. ax : matplotlib.axes.Axes, optional Matplotlib axis on which to plot. use_attrs : dict, optional A dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description'). Default value is {'title': 'description', 'ylabel': 'long_name', 'yunits': 'units'}. Only the keys found in the default dict can be used. fig_kw : dict, optional Arguments to pass to `plt.subplots()`. Only works if `ax` is not provided. plot_kw : dict, optional Arguments to pass to the `plot()` function. Changes how the line looks. If 'data' is a dictionary, must be a nested dictionary with the same keys as 'data'. legend : str (default 'lines') or dict 'full' (lines and shading), 'lines' (lines only), 'in_plot' (end of lines), 'edge' (out of plot), 'facetgrid' under figure, 'none' (no legend). If dict, arguments to pass to ax.legend(). show_lat_lon : bool, tuple, str or int If True, show latitude and longitude at the bottom right of the figure. Can be a tuple of axis coordinates (from 0 to 1, as a fraction of the axis length) representing the location of the text. If a string or an int, the same values as those of the 'loc' parameter of matplotlib's legends are accepted. ================== ============= Location String Location Code ================== ============= 'upper right' 1 'upper left' 2 'lower left' 3 'lower right' 4 'right' 5 'center left' 6 'center right' 7 'lower center' 8 'upper center' 9 'center' 10 ================== ============= enumerate_subplots: bool If True, enumerate subplots with letters. Only works with facetgrids (pass `col` or `row` in plot_kw). Returns ------- matplotlib.axes.Axes """ # convert SSP, RCP, CMIP formats in keys if isinstance(data, dict): data = process_keys(data, convert_scen_name) if isinstance(plot_kw, dict): plot_kw = process_keys(plot_kw, convert_scen_name) # create empty dicts if None use_attrs = empty_dict(use_attrs) fig_kw = empty_dict(fig_kw) plot_kw = empty_dict(plot_kw) # if only one data input, insert in dict. non_dict_data = False if not isinstance(data, dict): non_dict_data = True data = {"_no_label": data} # mpl excludes labels starting with "_" from legend plot_kw = {"_no_label": empty_dict(plot_kw)} # assign keys to plot_kw if not there if non_dict_data is False: for name in data: if name not in plot_kw: plot_kw[name] = {} for key in plot_kw: if key not in data: raise KeyError( 'plot_kw must be a nested dictionary with keys corresponding to the keys in "data"' ) # check: type for arr in data.values(): if not isinstance(arr, xr.Dataset | xr.DataArray): raise TypeError( '"data" must be a xr.Dataset, a xr.DataArray or a dictionary of such objects.' ) # check: 'time' dimension and calendar format data = check_timeindex(data) # set fig, ax if not provided if ax is None and ( "row" not in list(plot_kw.values())[0].keys() and "col" not in list(plot_kw.values())[0].keys() ): fig, ax = plt.subplots(**fig_kw) elif ax is not None and ( "col" in list(plot_kw.values())[0].keys() or "row" in list(plot_kw.values())[0].keys() ): raise ValueError("Cannot use 'ax' and 'col'/'row' at the same time.") elif ax is None: cfig_kw = fig_kw.copy() if "figsize" in fig_kw: # add figsize to plot_kw for facetgrid list(plot_kw.values())[0].setdefault("figsize", fig_kw["figsize"]) cfig_kw.pop("figsize") if cfig_kw: for v in plot_kw.values(): {"subplots_kws": cfig_kw} | v warnings.warn( "Only figsize and figure.add_subplot() arguments can be passed to fig_kw when using facetgrid.", stacklevel=2 ) # set default use_attrs values if ax: use_attrs.setdefault("title", "description") else: use_attrs.setdefault("suptitle", "description") use_attrs.setdefault("ylabel", "long_name") use_attrs.setdefault("yunits", "units") # dict of array 'categories' array_categ = {name: get_array_categ(array) for name, array in data.items()} cp_plot_kw = copy.deepcopy(plot_kw) # get data and plot for name, arr in data.items(): if ax: _plot_timeseries(ax, name, arr, plot_kw, non_dict_data, array_categ, legend) else: if name == list(data.keys())[0]: # create empty DataArray with same dimensions as data first entry to create an empty xr.plot.FacetGrid if isinstance(arr, xr.Dataset): da = arr[list(arr.keys())[0]] else: da = arr da = da.where(da == np.nan) im = da.plot(**plot_kw[name], color="white") [ cp_plot_kw[name].pop(key) for key in ["row", "col", "figsize"] if key in cp_plot_kw[name].keys() ] # plot data in every axis of the facetgrid for i in range(0, im.axs.shape[0]): for j in range(0, im.axs.shape[1]): sel_arr = {} if "row" in plot_kw[name]: sel_arr[plot_kw[name]["row"]] = i if "col" in plot_kw[name]: sel_arr[plot_kw[name]["col"]] = j _plot_timeseries( im.axs[i, j], name, arr.isel(**sel_arr).squeeze(), cp_plot_kw, non_dict_data, array_categ, legend, ) # add/modify plot elements according to the first entry. if ax: set_plot_attrs( use_attrs, list(data.values())[0], ax, title_loc="left", wrap_kw={"min_line_len": 35, "max_line_len": 48}, ) ax.set_xlabel( get_localized_term("time").capitalize() ) # check_timeindex() already checks for 'time' # other plot elements if show_lat_lon: if show_lat_lon is True: plot_coords( ax, list(data.values())[0], param="location", loc="lower right", backgroundalpha=1, ) elif isinstance(show_lat_lon, str | tuple | int): plot_coords( ax, list(data.values())[0], param="location", loc=show_lat_lon, backgroundalpha=1, ) else: raise TypeError(" show_lat_lon must be a bool, string, int, or tuple") if legend is not None: if not ax.get_legend_handles_labels()[0]: # check if legend is empty pass elif legend == "in_plot": split_legend(ax, in_plot=True) elif legend == "edge": split_legend(ax, in_plot=False) elif isinstance(legend, dict): ax.legend(**legend) else: ax.legend() return ax else: if legend is not None: if not im.axs[-1, -1].get_legend_handles_labels()[ 0 ]: # check if legend is empty pass elif legend == "in_plot": split_legend(im.axs[-1, -1], in_plot=True) elif legend == "edge": split_legend(im.axs[-1, -1], in_plot=False) elif isinstance(legend, dict): handles, labels = im.axs[-1, -1].get_legend_handles_labels() legend = {"handles": handles, "labels": labels} | legend im.fig.legend(**legend) elif legend == "facetgrid": handles, labels = im.axs[-1, -1].get_legend_handles_labels() im.fig.legend( handles, labels, loc="lower center", ncol=len(im.axs[-1, -1].lines), bbox_to_anchor=(0.5, -0.05), ) if show_lat_lon: if show_lat_lon is True: plot_coords( None, list(data.values())[0].isel(lat=0, lon=0), param="location", loc="lower right", backgroundalpha=1, ) elif isinstance(show_lat_lon, str | tuple | int): plot_coords( None, list(data.values())[0].isel(lat=0, lon=0), param="location", loc=show_lat_lon, backgroundalpha=1, ) if enumerate_subplots and isinstance(im, xr.plot.facetgrid.FacetGrid): for idx, ax in enumerate(im.axs.flat): ax.set_title(f"{string.ascii_lowercase[idx]}) {ax.get_title()}") return im
[docs] def gridmap( data: dict[str, Any] | xr.DataArray | xr.Dataset, ax: matplotlib.axes.Axes | None = None, use_attrs: dict[str, Any] | None = None, fig_kw: dict[str, Any] | None = None, plot_kw: dict[str, Any] | None = None, projection: ccrs.Projection = ccrs.LambertConformal(), transform: ccrs.Projection | None = None, features: list[str] | dict[str, dict[str, Any]] | None = None, geometries_kw: dict[str, Any] | None = None, contourf: bool = False, cmap: str | matplotlib.colors.Colormap | None = None, levels: int | list | np.ndarray | None = None, divergent: bool | int | float = False, show_time: bool | str | int | tuple[float, float] = False, frame: bool = False, enumerate_subplots: bool = False, ) -> matplotlib.axes.Axes: """ Create map from 2D data. Parameters ---------- data : dict, DataArray or Dataset Input data do plot. If dictionary, must have only one entry. ax : matplotlib axis, optional Matplotlib axis on which to plot, with the same projection as the one specified. use_attrs : dict, optional Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description'). Default value is {'title': 'description', 'cbar_label': 'long_name', 'cbar_units': 'units'}. Only the keys found in the default dict can be used. fig_kw : dict, optional Arguments to pass to `plt.figure()`. plot_kw: dict, optional Arguments to pass to the `xarray.plot.pcolormesh()` or 'xarray.plot.contourf()' function. projection : ccrs.Projection The projection to use, taken from the cartopy.crs options. Ignored if ax is not None. transform : ccrs.Projection, optional Transform corresponding to the data coordinate system. If None, an attempt is made to find dimensions matching ccrs.PlateCarree() or ccrs.RotatedPole(). features : list or dict, optional Features to use, as a list or a nested dict containing kwargs. Options are the predefined features from cartopy.feature: ['coastline', 'borders', 'lakes', 'land', 'ocean', 'rivers', 'states']. geometries_kw : dict, optional Arguments passed to cartopy ax.add_geometry() which adds given geometries (GeoDataFrame geometry) to axis. contourf : bool By default False, use plt.pcolormesh(). If True, use plt.contourf(). cmap : matplotlib.colors.Colormap or str, optional Colormap to use. If str, can be a matplotlib or name of the file of an IPCC colormap (see data/ipcc_colors). If None, look for common variables (from data/ipcc_colors/varaibles_groups.json) in the name of the DataArray or its 'history' attribute and use corresponding colormap, aligned with the IPCC visual style guide 2022 (https://www.ipcc.ch/site/assets/uploads/2022/09/IPCC_AR6_WGI_VisualStyleGuide_2022.pdf). levels : int, list, np.ndarray, optional Number of levels to divide the colormap into or list of level boundaries (in data units). divergent : bool or int or float If int or float, becomes center of cmap. Default center is 0. show_time : bool, tuple, string or int. If True, show time (as date) at the bottom right of the figure. Can be a tuple of axis coordinates (0 to 1, as a fraction of the axis length) representing the location of the text. If a string or an int, the same values as those of the 'loc' parameter of matplotlib's legends are accepted. ================== ============= Location String Location Code ================== ============= 'upper right' 1 'upper left' 2 'lower left' 3 'lower right' 4 'right' 5 'center left' 6 'center right' 7 'lower center' 8 'upper center' 9 'center' 10 ================== ============= frame : bool Show or hide frame. Default False. enumerate_subplots: bool If True, enumerate subplots with letters. Only works with facetgrids (pass `col` or `row` in plot_kw). Returns ------- matplotlib.axes.Axes """ # create empty dicts if None use_attrs = empty_dict(use_attrs) fig_kw = empty_dict(fig_kw) plot_kw = empty_dict(plot_kw) # set default use_attrs values use_attrs = {"cbar_label": "long_name", "cbar_units": "units"} | use_attrs if "row" not in plot_kw and "col" not in plot_kw: use_attrs.setdefault("title", "description") # extract plot_kw from dict if needed if isinstance(data, dict) and plot_kw and list(data.keys())[0] in plot_kw.keys(): plot_kw = plot_kw[list(data.keys())[0]] # if data is dict, extract if isinstance(data, dict): if len(data) == 1: data = list(data.values())[0] else: raise ValueError("If `data` is a dict, it must be of length 1.") # select data to plot if isinstance(data, xr.DataArray): plot_data = data.squeeze() elif isinstance(data, xr.Dataset): if len(data.data_vars) > 1: warnings.warn( "data is xr.Dataset; only the first variable will be used in plot", stacklevel=2 ) plot_data = data[list(data.keys())[0]].squeeze() else: raise TypeError("`data` must contain a xr.DataArray or xr.Dataset") # setup transform if transform is None: if "lat" in data.dims and "lon" in data.dims: transform = ccrs.PlateCarree() if "rlat" in data.dims and "rlon" in data.dims: transform = get_rotpole(data) # setup fig, ax if ax is None and ("row" not in plot_kw.keys() and "col" not in plot_kw.keys()): fig, ax = plt.subplots(subplot_kw={"projection": projection}, **fig_kw) elif ax is not None and ("col" in plot_kw or "row" in plot_kw): raise ValueError("Cannot use 'ax' and 'col'/'row' at the same time.") elif ax is None: plot_kw = {"subplot_kws": {"projection": projection}} | plot_kw cfig_kw = fig_kw.copy() if "figsize" in fig_kw: # add figsize to plot_kw for facetgrid plot_kw.setdefault("figsize", fig_kw["figsize"]) cfig_kw.pop("figsize") if len(cfig_kw) >= 1: plot_kw = {"subplot_kws": {"projection": cfig_kw}} | plot_kw warnings.warn( "Only figsize and figure.add_subplot() arguments can be passed to fig_kw when using facetgrid.", stacklevel=2 ) # create cbar label if ( "cbar_units" in use_attrs and len(get_attributes(use_attrs["cbar_units"], data)) >= 1 ): # avoids '[]' as label cbar_label = ( get_attributes(use_attrs["cbar_label"], data) + " (" + get_attributes(use_attrs["cbar_units"], data) + ")" ) else: cbar_label = get_attributes(use_attrs["cbar_label"], data) # colormap if isinstance(cmap, str): if cmap not in plt.colormaps(): try: cmap = create_cmap(filename=cmap) except FileNotFoundError as e: logger.error(e) pass elif cmap is None: cmap = create_cmap( get_var_group(da=plot_data), divergent=divergent, ) plot_kw.setdefault("cmap", cmap) if levels is not None: if isinstance(levels, Iterable): lin = levels else: lin = custom_cmap_norm( cmap, np.nanmin(plot_data.values), np.nanmax(plot_data.values), levels=levels, divergent=divergent, linspace_out=True, ) plot_kw.setdefault("levels", lin) elif (divergent is not False) and ("levels" not in plot_kw): vmin = plot_kw.pop("vmin", np.nanmin(plot_data.values)) vmax = plot_kw.pop("vmax", np.nanmax(plot_data.values)) norm = custom_cmap_norm( cmap, vmin, vmax, levels=levels, divergent=divergent, ) plot_kw.setdefault("norm", norm) # set defaults if divergent is not False: if isinstance(divergent, int | float): plot_kw.setdefault("center", divergent) else: plot_kw.setdefault("center", 0) if "add_colorbar" not in plot_kw or plot_kw["add_colorbar"] is not False: plot_kw.setdefault("cbar_kwargs", {}) plot_kw["cbar_kwargs"].setdefault("label", wrap_text(cbar_label)) # bug xlim / ylim + transform in facetgrids # (see https://github.com/pydata/xarray/issues/8562#issuecomment-1865189766) if transform and ("xlim" in plot_kw and "ylim" in plot_kw): extent = [ plot_kw["xlim"][0], plot_kw["xlim"][1], plot_kw["ylim"][0], plot_kw["ylim"][1], ] plot_kw.pop("xlim") plot_kw.pop("ylim") elif transform and ("xlim" in plot_kw or "ylim" in plot_kw): extent = None warnings.warn( "Requires both xlim and ylim with 'transform'. Xlim or ylim was dropped", stacklevel=2 ) if "xlim" in plot_kw.keys(): plot_kw.pop("xlim") if "ylim" in plot_kw.keys(): plot_kw.pop("ylim") else: extent = None # plot if ax: plot_kw.setdefault("ax", ax) if transform: plot_kw.setdefault("transform", transform) if contourf is False: im = plot_data.plot.pcolormesh(**plot_kw) else: im = plot_data.plot.contourf(**plot_kw) if ax: if extent: ax.set_extent(extent) ax = add_features_map( data, ax, use_attrs, projection, features, geometries_kw, frame, ) if show_time: if isinstance(show_time, bool): plot_coords( ax, plot_data, param="time", loc="lower right", backgroundalpha=1, ) elif isinstance(show_time, str | tuple | int): plot_coords( ax, plot_data, param="time", loc=show_time, backgroundalpha=1, ) # when im is an ax, it has a colorbar attribute. If it is a facetgrid, it has a cbar attribute. if (frame is False) and ( (getattr(im, "colorbar", None) is not None) or (getattr(im, "cbar", None) is not None) ): im.colorbar.outline.set_visible(False) return ax else: for _i, fax in enumerate(im.axs.flat): add_features_map( data, fax, use_attrs, projection, features, geometries_kw, frame, ) if extent: fax.set_extent(extent) # when im is an ax, it has a colorbar attribute. If it is a facetgrid, it has a cbar attribute. if (frame is False) and ( (getattr(im, "colorbar", None) is not None) or (getattr(im, "cbar", None) is not None) ): im.cbar.outline.set_visible(False) if show_time: if isinstance(show_time, bool): plot_coords( None, plot_data, param="time", loc="lower right", backgroundalpha=1, ) elif isinstance(show_time, str | tuple | int): plot_coords( None, plot_data, param="time", loc=show_time, backgroundalpha=1, ) use_attrs.setdefault("suptitle", "long_name") im = set_plot_attrs(use_attrs, data, facetgrid=im) if enumerate_subplots and isinstance(im, xr.plot.facetgrid.FacetGrid): for idx, ax in enumerate(im.axs.flat): ax.set_title(f"{string.ascii_lowercase[idx]}) {ax.get_title()}") return im
[docs] def gdfmap( df: gpd.GeoDataFrame, df_col: str, ax: cartopy.mpl.geoaxes.GeoAxes | cartopy.mpl.geoaxes.GeoAxesSubplot | None = None, fig_kw: dict[str, Any] | None = None, plot_kw: dict[str, Any] | None = None, projection: ccrs.Projection = ccrs.LambertConformal(), features: list[str] | dict[str, dict[str, Any]] | None = None, cmap: str | matplotlib.colors.Colormap | None = None, levels: int | list[int | float] | None = None, divergent: bool | int | float = False, cbar: bool = True, frame: bool = False, ) -> matplotlib.axes.Axes: """ Create a map plot from geometries. Parameters ---------- df : geopandas.GeoDataFrame Dataframe containing the geometries and the data to plot. Must have a column named 'geometry'. df_col : str Name of the column of 'df' containing the data to plot using the colorscale. If `boundary`, only the boundary of the geometries is plotted, without colorscale. ax : cartopy.mpl.geoaxes.GeoAxes or cartopy.mpl.geoaxes.GeoaxesSubplot, optional Matplotlib axis built with a projection, on which to plot. fig_kw : dict, optional Arguments to pass to `plt.figure()`. plot_kw : dict, optional Arguments to pass to the GeoDataFrame.plot() method. projection : ccrs.Projection The projection to use, taken from the cartopy.crs options. Ignored if ax is not None. features : list or dict, optional Features to use, as a list or a nested dict containing kwargs. Options are the predefined features from cartopy.feature: ['coastline', 'borders', 'lakes', 'land', 'ocean', 'rivers', 'states']. cmap : matplotlib.colors.Colormap or str Colormap to use. If str, can be a matplotlib or name of the file of an IPCC colormap (see data/ipcc_colors). If None, look for common variables (from data/ipcc_colors/varaibles_groups.json) in the name of df_col and use corresponding colormap, aligned with the IPCC visual style guide 2022 (https://www.ipcc.ch/site/assets/uploads/2022/09/IPCC_AR6_WGI_VisualStyleGuide_2022.pdf). levels : int or list, optional Number of levels or list of level boundaries (in data units) to use to divide the colormap. divergent : bool or int or float If int or float, becomes center of cmap. Default center is 0. cbar : bool Show colorbar. Default 'True'. frame : bool Show or hide frame. Default False. Returns ------- matplotlib.axes.Axes """ # create empty dicts if None fig_kw = empty_dict(fig_kw) plot_kw = empty_dict(plot_kw) features = empty_dict(features) # checks if not isinstance(df, gpd.GeoDataFrame): raise TypeError("df myst be an instance of class geopandas.GeoDataFrame") if "geometry" not in df.columns: raise ValueError("column 'geometry' not found in GeoDataFrame") # convert to projection if ax is None: df = gpd_to_ccrs(df=df, proj=projection) else: df = gpd_to_ccrs(df=df, proj=ax.projection) # setup fig, ax if ax is None: fig, ax = plt.subplots(subplot_kw={"projection": projection}, **fig_kw) ax.set_aspect("equal") # recommended by geopandas # add features if features: add_cartopy_features(ax, features) if df_col == "boundary": plot = df.boundary.plot(ax=ax, **plot_kw) if cmap is not None or levels is not None or divergent is not False: warnings.warn("Colomap arguments are ignored when plotting 'boundary'.", stacklevel=2) else: # colormap if isinstance(cmap, str): if cmap in plt.colormaps(): cmap = matplotlib.colormaps[cmap] else: try: cmap = create_cmap(filename=cmap) except FileNotFoundError: warnings.warn("invalid cmap, using default", stacklevel=2) cmap = create_cmap(filename="slev_seq") elif cmap is None: cmap = create_cmap( get_var_group(unique_str=df_col), divergent=divergent, ) # create normalization for colormap plot_kw.setdefault("vmin", df[df_col].min()) plot_kw.setdefault("vmax", df[df_col].max()) if (levels is not None) or (divergent is not False): norm = custom_cmap_norm( cmap, plot_kw["vmin"], plot_kw["vmax"], levels=levels, divergent=divergent, ) plot_kw.setdefault("norm", norm) # colorbar if cbar: plot_kw.setdefault("legend", True) plot_kw.setdefault("legend_kwds", {}) plot_kw["legend_kwds"].setdefault("label", df_col) plot_kw["legend_kwds"].setdefault("orientation", "horizontal") plot_kw["legend_kwds"].setdefault("pad", 0.02) # plot plot = df.plot(column=df_col, ax=ax, cmap=cmap, **plot_kw) if frame is False: # cbar if len(plot.figure.axes) > 1: # only if it exists plot.figure.axes[1].spines["outline"].set_visible(False) plot.figure.axes[1].tick_params(size=0) # main axes ax.spines["geo"].set_visible(False) return ax
[docs] def violin( data: dict[str, Any] | xr.DataArray | xr.Dataset, ax: matplotlib.axes.Axes | None = None, use_attrs: dict[str, Any] | None = None, fig_kw: dict[str, Any] | None = None, plot_kw: dict[str, Any] | None = None, color: str | int | list[str | int] | None = None, ) -> matplotlib.axes.Axes: """ Make violin plot using seaborn. Parameters ---------- data : dict or Dataset/DataArray Input data to plot. If a dict, must contain DataArrays and/or Datasets. ax : matplotlib.axes.Axes, optional Matplotlib axis on which to plot. use_attrs : dict, optional A dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description'). Default value is {'title': 'description', 'ylabel': 'long_name', 'yunits': 'units'}. Only the keys found in the default dict can be used. fig_kw : dict, optional Arguments to pass to `plt.subplots()`. Only works if `ax` is not provided. plot_kw : dict, optional Arguments to pass to the `seaborn.violinplot()` function. color : str, int or list, optional Unique color or list of colors to use. Integers point to the applied stylesheet's colors, in zero-indexed order. Passing 'color' or 'palette' in plot_kw overrides this argument. Returns ------- matplotlib.axes.Axes """ # create empty dicts if None use_attrs = empty_dict(use_attrs) fig_kw = empty_dict(fig_kw) plot_kw = empty_dict(plot_kw) # if data is dict, assemble into one DataFrame non_dict_data = True if isinstance(data, dict): non_dict_data = False df = pd.DataFrame() for key, xr_obj in data.items(): if isinstance(xr_obj, xr.Dataset): # if one data var, use key if len(list(xr_obj.data_vars)) == 1: df[key] = xr_obj[list(xr_obj.data_vars)[0]].values # if more than one data var, use key + name of var else: for data_var in list(xr_obj.data_vars): df[key + "_" + data_var] = xr_obj[data_var].values elif isinstance(xr_obj, xr.DataArray): df[key] = xr_obj.values else: raise TypeError( '"data" must be a xr.Dataset, a xr.DataArray or a dictionary of such objects.' ) elif isinstance(data, xr.Dataset): # create dataframe df = data.to_dataframe() df = df[data.data_vars] elif isinstance(data, xr.DataArray): # create dataframe df = data.to_dataframe() for coord in list(data.coords): if coord in df.columns: df = df.drop(columns=coord) else: raise TypeError( '"data" must be a xr.Dataset, a xr.DataArray or a dictionary of such objects.' ) # set fig, ax if not provided if ax is None: fig, ax = plt.subplots(**fig_kw) # set default use_attrs values if "orient" in plot_kw and plot_kw["orient"] == "h": use_attrs = {"xlabel": "long_name", "xunits": "units"} | use_attrs else: use_attrs = {"ylabel": "long_name", "yunits": "units"} | use_attrs # add/modify plot elements according to the first entry. if non_dict_data: set_plot_obj = data else: set_plot_obj = list(data.values())[0] set_plot_attrs( use_attrs, xr_obj=set_plot_obj, ax=ax, title_loc="left", wrap_kw={"min_line_len": 35, "max_line_len": 48}, ) # color if color: style_colors = matplotlib.rcParams["axes.prop_cycle"].by_key()["color"] if isinstance(color, str): plot_kw.setdefault("color", color) elif isinstance(color, int): try: plot_kw.setdefault("color", style_colors[color]) except IndexError as err: raise IndexError("Index out of range of stylesheet colors") from err elif isinstance(color, list): for c, i in zip(color, np.arange(len(color)), strict=False): if isinstance(c, int): try: color[i] = style_colors[c] except IndexError as err: raise IndexError("Index out of range of stylesheet colors") from err plot_kw.setdefault("palette", color) # plot sns.violinplot(df, ax=ax, **plot_kw) # grid if "orient" in plot_kw and plot_kw["orient"] == "h": ax.grid(visible=True, axis="x") return ax
[docs] def stripes( data: dict[str, Any] | xr.DataArray | xr.Dataset, ax: matplotlib.axes.Axes | None = None, fig_kw: dict[str, Any] | None = None, divide: int | None = None, cmap: str | matplotlib.colors.Colormap | None = None, cmap_center: int | float = 0, cbar: bool = True, cbar_kw: dict[str, Any] | None = None, ) -> matplotlib.axes.Axes: """ Create stripes plot with or without multiple scenarios. Parameters ---------- data : dict or DataArray or Dataset Data to plot. If a dictionary of xarray objects, each will correspond to a scenario. ax : matplotlib.axes.Axes, optional Matplotlib axis on which to plot. fig_kw : : dict, optional Arguments to pass to `plt.subplots()`. Only works if `ax` is not provided. divide : int, optional Year at which the plot is divided into scenarios. If not provided, the horizontal separators will extend over the full time axis. cmap : matplotlib.colors.Colormap or str, optional Colormap to use. If str, can be a matplotlib or name of the file of an IPCC colormap (see data/ipcc_colors). If None, look for common variables (from data/ipcc_colors/variables_groups.json) in the name of the DataArray or its 'history' attribute and use corresponding diverging colormap, aligned with the IPCC Visual Style Guide 2022 (https://www.ipcc.ch/site/assets/uploads/2022/09/IPCC_AR6_WGI_VisualStyleGuide_2022.pdf). cmap_center : int or float Center of the colormap in data coordinates. Default is 0. cbar : bool Show colorbar. cbar_kw : dict, optional Arguments to pass to plt.colorbar. Returns ------- matplotlib.axes.Axes """ # create empty dicts if None fig_kw = empty_dict(fig_kw) cbar_kw = empty_dict(cbar_kw) # init main (figure) axis if ax is None: fig_kw.setdefault("figsize", (10, 5)) fig, ax = plt.subplots(**fig_kw) ax.set_yticks([]) ax.set_xticks([]) ax.spines[["top", "bottom", "left", "right"]].set_visible(False) # init plot axis ax_0 = ax.inset_axes([0, 0.15, 1, 0.75]) # handle non-dict data if not isinstance(data, dict): data = {"_no_label": data} # convert SSP, RCP, CMIP formats in keys data = process_keys(data, convert_scen_name) n = len(data) # extract DataArrays from datasets for key, obj in data.items(): if isinstance(obj, xr.DataArray): pass elif isinstance(obj, xr.Dataset): data[key] = obj[list(obj.data_vars)[0]] else: raise TypeError("data must contain xarray DataArrays or Datasets") # get time interval time_index = list(data.values())[0].time.dt.year.values delta_time = [ time_index[i] - time_index[i - 1] for i in np.arange(1, len(time_index), 1) ] if all(i == delta_time[0] for i in delta_time): dtime = delta_time[0] else: raise ValueError("Time delta between each array element must be constant") # modify axes ax.set_xlim(min(time_index) - 0.5 * dtime, max(time_index) + 0.5 * dtime) ax_0.set_xlim(min(time_index) - 0.5 * dtime, max(time_index) + 0.5 * dtime) ax_0.set_ylim(0, 1) ax_0.set_yticks([]) ax_0.xaxis.set_ticks_position("top") ax_0.tick_params(axis="x", direction="out", zorder=10) ax_0.spines[["top", "left", "right", "bottom"]].set_visible(False) # width of bars, to fill x axis limits width = (max(time_index) + 0.5 - min(time_index) - 0.5) / len(time_index) # create historical/projection divide if divide is not None: # convert divide year to transAxes divide_disp = ax_0.transData.transform( (divide - width * 0.5, 1) ) # left limit of stripe, 1 is placeholder divide_ax = ax_0.transAxes.inverted().transform(divide_disp) divide_ax = divide_ax[0] else: divide_ax = 0 # create an inset ax for each da in data subaxes = {} for i in np.arange(n): name = "subax_" + str(i) y = (1 / n) * i subaxes[name] = ax_0.inset_axes([0, y, 1, 1 / n], transform=ax_0.transAxes) subaxes[name].set(xlim=ax_0.get_xlim(), ylim=(0, 1), xticks=[], yticks=[]) subaxes[name].spines[["top", "bottom", "left", "right"]].set_visible(False) # lines separating axes if i > 0: subaxes[name].spines["bottom"].set_visible(True) subaxes[name].spines["bottom"].set( lw=2, color="w", bounds=(divide_ax, 1), transform=subaxes[name].transAxes, ) # circles if divide: circle = matplotlib.patches.Ellipse( xy=(divide_ax, y), width=0.01, height=0.03, color="w", transform=ax_0.transAxes, zorder=10, ) ax_0.add_patch(circle) # get max and min of all data data_min = 1e6 data_max = -1e6 for da in data.values(): if min(da.values) < data_min: data_min = min(da.values) if max(da.values) > data_max: data_max = max(da.values) # colormap if isinstance(cmap, str): if cmap in plt.colormaps(): cmap = matplotlib.colormaps[cmap] else: try: cmap = create_cmap(filename=cmap) except FileNotFoundError as e: logger.error(e) pass elif cmap is None: cmap = create_cmap( get_var_group(da=list(data.values())[0]), divergent=True, ) # create cmap norm if cmap_center is not None: norm = matplotlib.colors.TwoSlopeNorm(cmap_center, vmin=data_min, vmax=data_max) else: norm = matplotlib.colors.Normalize(data_min, data_max) # plot for (_name, subax), (key, da) in zip(subaxes.items(), data.items(), strict=False): subax.bar(da.time.dt.year, height=1, width=dtime, color=cmap(norm(da.values))) if divide: if key != "_no_label": subax.text( 0.99, 0.5, key, transform=subax.transAxes, fontsize=14, ha="right", va="center", c="w", weight="bold", ) # colorbar if cbar is True: sm = ScalarMappable(cmap=cmap, norm=norm) cax = ax.inset_axes([0.01, 0.05, 0.35, 0.06]) cbar_tcks = np.arange(math.floor(data_min), math.ceil(data_max), 2) # label da = list(data.values())[0] label = get_attributes("long_name", da) if label != "": if "units" in da.attrs: u = da.units label += f" ({u})" label = wrap_text(label, max_line_len=40) cbar_kw = { "cax": cax, "orientation": "horizontal", "ticks": cbar_tcks, "label": label, } | cbar_kw plt.colorbar(sm, **cbar_kw) cax.spines["outline"].set_visible(False) cax.set_xscale("linear") return ax
[docs] def heatmap( data: xr.DataArray | xr.Dataset | dict[str, Any], ax: matplotlib.axes.Axes | None = None, use_attrs: dict[str, Any] | None = None, fig_kw: dict[str, Any] | None = None, plot_kw: dict[str, Any] | None = None, transpose: bool = False, cmap: str | matplotlib.colors.Colormap | None = "RdBu", divergent: bool | int | float = False, ) -> matplotlib.axes.Axes: """ Create heatmap from a DataArray. Parameters ---------- data : dict or DataArray or Dataset Input data do plot. If dictionary, must have only one entry. ax : matplotlib axis, optional Matplotlib axis on which to plot, with the same projection as the one specified. use_attrs : dict, optional Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description'). Default value is {'cbar_label': 'long_name'}. Only the keys found in the default dict can be used. fig_kw : dict, optional Arguments to pass to `plt.figure()`. plot_kw : dict, optional Arguments to pass to the 'seaborn.heatmap()' function. If 'data' is a dictionary, can be a nested dictionary with the same key as 'data'. transpose : bool If true, the 2D data will be transposed, so that the original x-axis becomes the y-axis and vice versa. cmap : matplotlib.colors.Colormap or str, optional Colormap to use. If str, can be a matplotlib or name of the file of an IPCC colormap (see data/ipcc_colors). If None, look for common variables (from data/ipcc_colors/variables_groups.json) in the name of the DataArray or its 'history' attribute and use corresponding colormap, aligned with the IPCC Visual Style Guide 2022 (https://www.ipcc.ch/site/assets/uploads/2022/09/IPCC_AR6_WGI_VisualStyleGuide_2022.pdf). divergent : bool or int or float If int or float, becomes center of cmap. Default center is 0. Returns ------- matplotlib.axes.Axes """ # create empty dicts if None use_attrs = empty_dict(use_attrs) fig_kw = empty_dict(fig_kw) plot_kw = empty_dict(plot_kw) # set default use_attrs values use_attrs.setdefault("cbar_label", "long_name") # if data is dict, extract if isinstance(data, dict): if plot_kw and list(data.keys())[0] in plot_kw.keys(): plot_kw = plot_kw[list(data.keys())[0]] if len(data) == 1: data = list(data.values())[0] else: raise ValueError("If `data` is a dict, it must be of length 1.") # select data to plot if isinstance(data, xr.DataArray): da = data elif isinstance(data, xr.Dataset): if len(data.data_vars) > 1: warnings.warn( "data is xr.Dataset; only the first variable will be used in plot", stacklevel=2 ) da = list(data.values())[0] else: raise TypeError("`data` must contain a xr.DataArray or xr.Dataset") # setup fig, axis if ax is None and ("row" not in plot_kw.keys() and "col" not in plot_kw.keys()): fig, ax = plt.subplots(**fig_kw) elif ax is not None and ("col" in plot_kw or "row" in plot_kw): raise ValueError("Cannot use 'ax' and 'col'/'row' at the same time.") elif ax is None: if any([k != "figsize" for k in fig_kw.keys()]): warnings.warn( "Only figsize arguments can be passed to fig_kw when using facetgrid.", stacklevel=2 ) plot_kw.setdefault("col", None) plot_kw.setdefault("row", None) plot_kw.setdefault("margin_titles", True) heatmap_dims = list( set(da.dims) - {d for d in [plot_kw["col"], plot_kw["row"]] if d is not None} ) if da.name is None: da = da.to_dataset(name="data").data da_name = da.name # create cbar label if ( "cbar_units" in use_attrs and len(get_attributes(use_attrs["cbar_units"], data)) >= 1 ): # avoids '()' as label cbar_label = ( get_attributes(use_attrs["cbar_label"], data) + " (" + get_attributes(use_attrs["cbar_units"], data) + ")" ) else: cbar_label = get_attributes(use_attrs["cbar_label"], data) # colormap if isinstance(cmap, str): if cmap not in plt.colormaps(): try: cmap = create_cmap(filename=cmap) except FileNotFoundError as e: logger.error(e) pass elif cmap is None: cmap = create_cmap( get_var_group(da=da), divergent=divergent, ) # convert data to DataFrame if transpose: da = da.transpose() if "col" not in plot_kw and "row" not in plot_kw: if len(da.dims) != 2: raise ValueError("DataArray must have exactly two dimensions") df = da.to_pandas() else: if len(heatmap_dims) != 2: raise ValueError("DataArray must have exactly two dimensions") df = da.to_dataframe().reset_index() # set defaults if divergent is not False: if isinstance(divergent, int | float): plot_kw.setdefault("center", divergent) else: plot_kw.setdefault("center", 0) if "cbar" not in plot_kw or plot_kw["cbar"] is not False: plot_kw.setdefault("cbar_kws", {}) plot_kw["cbar_kws"].setdefault("label", wrap_text(cbar_label)) plot_kw.setdefault("cmap", cmap) # plot def draw_heatmap(*args, **kwargs): data = kwargs.pop("data") d = ( data if len(args) == 0 # Any sorting should be performed before sending a DataArray in `fg.heatmap` else data.pivot_table( index=args[1], columns=args[0], values=args[2], sort=False ) ) ax = sns.heatmap(d, **kwargs) ax.set_xticklabels( ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor", ) ax.tick_params(axis="both", direction="out") set_plot_attrs( use_attrs, da, ax, title_loc="center", wrap_kw={"min_line_len": 35, "max_line_len": 44}, ) return ax if ax is not None: ax = draw_heatmap(data=df, ax=ax, **plot_kw) return ax elif "col" in plot_kw or "row" in plot_kw: # When using xarray's FacetGrid, `plot_kw` can be used in the FacetGrid and in the plotting function # With Seaborn, we need to be more careful and separate keywords. plot_kw_hm = { k: v for k, v in plot_kw.items() if k in signature(sns.heatmap).parameters } plot_kw_fg = { k: v for k, v in plot_kw.items() if k in signature(sns.FacetGrid).parameters } unused_keys = ( set(plot_kw.keys()) - set(plot_kw_fg.keys()) - set(plot_kw_hm.keys()) ) if unused_keys != set(): raise ValueError( f"`heatmap` got unexpected keywords in `plot_kw`: {unused_keys}. Keywords in `plot_kw` should be keywords " "allowed in `sns.heatmap` or `sns.FacetGrid`. " ) g = sns.FacetGrid(df, **plot_kw_fg) cax = g.fig.add_axes([0.95, 0.05, 0.02, 0.9]) g.map_dataframe( draw_heatmap, *heatmap_dims, da_name, **plot_kw_hm, cbar=True, cbar_ax=cax, ) g.fig.subplots_adjust(right=0.9) if "figsize" in fig_kw.keys(): g.fig.set_size_inches(*fig_kw["figsize"]) return g
[docs] def scattermap( data: dict[str, Any] | xr.DataArray | xr.Dataset, ax: matplotlib.axes.Axes | None = None, use_attrs: dict[str, Any] | None = None, fig_kw: dict[str, Any] | None = None, plot_kw: dict[str, Any] | None = None, projection: ccrs.Projection = ccrs.LambertConformal(), transform: ccrs.Projection | None = None, features: list[str] | dict[str, dict[str, Any]] | None = None, geometries_kw: dict[str, Any] | None = None, sizes: str | bool | None = None, size_range: tuple = (10, 60), cmap: str | matplotlib.colors.Colormap | None = None, levels: int | None = None, divergent: bool | int | float = False, legend_kw: dict[str, Any] | None = None, show_time: bool | str | int | tuple[float, float] = False, frame: bool = False, enumerate_subplots: bool = False, ) -> matplotlib.axes.Axes: """ Make a scatter plot of georeferenced data on a map. Parameters ---------- data : dict, DataArray or Dataset Input data do plot. If dictionary, must have only one entry. ax : matplotlib axis, optional Matplotlib axis on which to plot, with the same projection as the one specified. use_attrs : dict, optional Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description'). Default value is {'title': 'description', 'cbar_label': 'long_name', 'cbar_units': 'units'}. Only the keys found in the default dict can be used. fig_kw : dict, optional Arguments to pass to `plt.figure()`. plot_kw : dict, optional Arguments to pass to `plt.scatter()`. If 'data' is a dictionary, can be a dictionary with the same key as 'data'. projection : ccrs.Projection The projection to use, taken from the cartopy.crs options. Ignored if ax is not None. transform : ccrs.Projection, optional Transform corresponding to the data coordinate system. If None, an attempt is made to find dimensions matching ccrs.PlateCarree() or ccrs.RotatedPole(). features : list or dict, optional Features to use, as a list or a nested dict containing kwargs. Options are the predefined features from cartopy.feature: ['coastline', 'borders', 'lakes', 'land', 'ocean', 'rivers', 'states']. geometries_kw : dict, optional Arguments passed to cartopy ax.add_geometry() which adds given geometries (GeoDataFrame geometry) to axis. sizes : bool or str, optional String name of the coordinate to use for determining point size. If True, use the same data as in the colorbar. size_range : tuple Tuple of the minimum and maximum size of the points. cmap : matplotlib.colors.Colormap or str, optional Colormap to use. If str, can be a matplotlib or name of the file of an IPCC colormap (see data/ipcc_colors). If None, look for common variables (from data/ipcc_colors/variables_groups.json) in the name of the DataArray or its 'history' attribute and use corresponding colormap, aligned with the IPCC Visual Style Guide 2022 (https://www.ipcc.ch/site/assets/uploads/2022/09/IPCC_AR6_WGI_VisualStyleGuide_2022.pdf). levels : int, optional Number of levels to divide the colormap into. divergent : bool or int or float If int or float, becomes center of cmap. Default center is 0. legend_kw : dict, optional Arguments to pass to plt.legend(). Some defaults {"loc": "lower left", "facecolor": "w", "framealpha": 1, "edgecolor": "w", "bbox_to_anchor": (-0.05, 0)} show_time : bool, tuple, string or int. If True, show time (as date) at the bottom right of the figure. Can be a tuple of axis coordinates (0 to 1, as a fraction of the axis length) representing the location of the text. If a string or an int, the same values as those of the 'loc' parameter of matplotlib's legends are accepted. ================== ============= Location String Location Code ================== ============= 'upper right' 1 'upper left' 2 'lower left' 3 'lower right' 4 'right' 5 'center left' 6 'center right' 7 'lower center' 8 'upper center' 9 'center' 10 ================== ============= frame : bool Show or hide frame. Default False. enumerate_subplots: bool If True, enumerate subplots with letters. Only works with facetgrids (pass `col` or `row` in plot_kw). Returns ------- matplotlib.axes.Axes """ # create empty dicts if None use_attrs = empty_dict(use_attrs) fig_kw = empty_dict(fig_kw) plot_kw = empty_dict(plot_kw) legend_kw = empty_dict(legend_kw) # set default use_attrs values use_attrs = {"cbar_label": "long_name", "cbar_units": "units"} | use_attrs if "row" not in plot_kw and "col" not in plot_kw: use_attrs.setdefault("title", "description") # extract plot_kw from dict if needed if isinstance(data, dict) and plot_kw and list(data.keys())[0] in plot_kw.keys(): plot_kw = plot_kw[list(data.keys())[0]] # figanos does not use xr.plot.scatter default markersize if "markersize" in plot_kw.keys(): if not sizes: sizes = plot_kw["markersize"] plot_kw.pop("markersize") # if data is dict, extract if isinstance(data, dict): if len(data) == 1: data = list(data.values())[0].squeeze() if len(data.data_vars) > 1: warnings.warn( "data is xr.Dataset; only the first variable will be used in plot", stacklevel=2 ) else: raise ValueError("If `data` is a dict, it must be of length 1.") # select data to plot and its xr.Dataset if isinstance(data, xr.DataArray): plot_data = data data = xr.Dataset({plot_data.name: plot_data}) elif isinstance(data, xr.Dataset): if len(data.data_vars) > 1: warnings.warn( "data is xr.Dataset; only the first variable will be used in plot", stacklevel=2 ) plot_data = data[list(data.keys())[0]] else: raise TypeError("`data` must contain a xr.DataArray or xr.Dataset") # setup transform if transform is None: if "rlat" in data.dims and "rlon" in data.dims: transform = get_rotpole(data) elif ( "lat" in data.coords and "lon" in data.coords ): # need to work with station dims transform = ccrs.PlateCarree() # setup fig, ax if ax is None and ("row" not in plot_kw.keys() and "col" not in plot_kw.keys()): fig, ax = plt.subplots(subplot_kw={"projection": projection}, **fig_kw) elif ax is not None and ("col" in plot_kw or "row" in plot_kw): raise ValueError("Cannot use 'ax' and 'col'/'row' at the same time.") elif ax is None: plot_kw = {"subplot_kws": {"projection": projection}} | plot_kw cfig_kw = fig_kw.copy() if "figsize" in fig_kw: # add figsize to plot_kw for facetgrid plot_kw.setdefault("figsize", fig_kw["figsize"]) cfig_kw.pop("figsize") if len(cfig_kw) >= 1: plot_kw = {"subplot_kws": {"projection": projection}} | plot_kw warnings.warn( "Only figsize and figure.add_subplot() arguments can be passed to fig_kw when using facetgrid.", stacklevel=2 ) # create cbar label if ( "cbar_units" in use_attrs and len(get_attributes(use_attrs["cbar_units"], data)) >= 1 ): # avoids '[]' as label cbar_label = ( get_attributes(use_attrs["cbar_label"], data) + " (" + get_attributes(use_attrs["cbar_units"], data) + ")" ) else: cbar_label = get_attributes(use_attrs["cbar_label"], data) if "add_colorbar" not in plot_kw or plot_kw["add_colorbar"] is not False: plot_kw.setdefault("cbar_kwargs", {}) plot_kw["cbar_kwargs"].setdefault("label", wrap_text(cbar_label)) plot_kw["cbar_kwargs"].setdefault("pad", 0.015) # colormap if isinstance(cmap, str): if cmap not in plt.colormaps(): try: cmap = create_cmap(filename=cmap) except FileNotFoundError as e: logger.error(e) pass elif cmap is None: cmap = create_cmap( get_var_group(da=plot_data), divergent=divergent, ) # nans (not required for plotting since xarray.plot handles np.nan, but needs to be found for sizes legend and to # inform user on how many stations were dropped) mask = ~np.isnan(plot_data.values) if np.sum(mask) < len(mask): warnings.warn( f"{len(mask) - np.sum(mask)} nan values were dropped when plotting the color values", stacklevel=2 ) # point sizes if sizes: if sizes is True: sdata = plot_data elif isinstance(sizes, str): if hasattr(data, "name") and data.name == sizes: sdata = plot_data elif sizes in list(data.coords.keys()): sdata = plot_data[sizes] else: raise ValueError(f"{sizes} not found") else: raise TypeError("sizes must be a string or a bool") # nans sizes smask = ~np.isnan(sdata.values) & mask if np.sum(smask) < np.sum(mask): warnings.warn( f"{np.sum(mask) - np.sum(smask)} nan values were dropped when setting the point size", stacklevel=2 ) mask = smask pt_sizes = norm2range( data=sdata.where(mask).values, target_range=size_range, data_range=None, ) plot_kw.setdefault("add_legend", False) if ax: plot_kw.setdefault("s", pt_sizes) else: plot_kw.setdefault("s", pt_sizes[0]) # norm plot_kw.setdefault("vmin", np.nanmin(plot_data.values[mask])) plot_kw.setdefault("vmax", np.nanmax(plot_data.values[mask])) if levels is not None: if isinstance(levels, Iterable): lin = levels else: lin = custom_cmap_norm( cmap, np.nanmin(plot_data.values[mask]), np.nanmax(plot_data.values[mask]), levels=levels, divergent=divergent, linspace_out=True, ) plot_kw.setdefault("levels", lin) elif (divergent is not False) and ("levels" not in plot_kw): vmin = plot_kw.pop("vmin", np.nanmin(plot_data.values[mask])) vmax = plot_kw.pop("vmax", np.nanmax(plot_data.values[mask])) norm = custom_cmap_norm( cmap, vmin, vmax, levels=levels, divergent=divergent, ) plot_kw.setdefault("norm", norm) # matplotlib.pyplot.scatter treats "edgecolor" and "edgecolors" as aliases so we accept "edgecolor" and convert it if "edgecolor" in plot_kw and "edgecolors" not in plot_kw: plot_kw["edgecolors"] = plot_kw["edgecolor"] plot_kw.pop("edgecolor") # set defaults and create copy without vmin, vmax (conflicts with norm) plot_kw = { "cmap": cmap, "transform": transform, "zorder": 8, "marker": "o", } | plot_kw # check if edgecolors in plot_kw and match len of plot_data if "edgecolors" in plot_kw: if matplotlib.colors.is_color_like(plot_kw["edgecolors"]): plot_kw["edgecolors"] = np.repeat( plot_kw["edgecolors"], len(plot_data.where(mask).values) ) elif len(plot_kw["edgecolors"]) != len(plot_data.values): plot_kw["edgecolors"] = np.repeat( plot_kw["edgecolors"][0], len(plot_data.where(mask).values) ) warnings.warn( "Length of edgecolors does not match length of data. Only first edgecolor is used for plotting.", stacklevel=2 ) else: if isinstance(plot_kw["edgecolors"], list): plot_kw["edgecolors"] = np.array(plot_kw["edgecolors"]) plot_kw["edgecolors"] = plot_kw["edgecolors"][mask] else: plot_kw.setdefault("edgecolors", "none") for key in ["vmin", "vmax"]: plot_kw.pop(key, None) # plot plot_kw = {"x": "lon", "y": "lat", "hue": plot_data.name} | plot_kw if ax: plot_kw.setdefault("ax", ax) plot_data_masked = plot_data.where(mask).to_dataset() im = plot_data_masked.plot.scatter(**plot_kw) # add features if ax: ax = add_features_map( data, ax, use_attrs, projection, features, geometries_kw, frame, ) if show_time: if isinstance(show_time, bool): plot_coords( ax, plot_data, param="time", loc="lower right", backgroundalpha=1, ) elif isinstance(show_time, str | tuple | int): plot_coords( ax, plot_data, param="time", loc=show_time, backgroundalpha=1, ) if (frame is False) and (im.colorbar is not None): im.colorbar.outline.set_visible(False) else: for i, fax in enumerate(im.axs.flat): fax = add_features_map( data, fax, use_attrs, projection, features, geometries_kw, frame, ) if sizes: # correct markersize for facetgrid scat = fax.collections[0] scat.set_sizes(pt_sizes[i]) if (frame is False) and (im.cbar is not None): im.cbar.outline.set_visible(False) if show_time: if isinstance(show_time, bool): plot_coords( None, plot_data, param="time", loc="lower right", backgroundalpha=1, ) elif isinstance(show_time, str | tuple | int): plot_coords( None, plot_data, param="time", loc=show_time, backgroundalpha=1, ) # size legend if sizes: legend_elements = size_legend_elements( np.resize(sdata.values[mask], (sdata.values[mask].size, 1)), np.resize(pt_sizes[mask], (pt_sizes[mask].size, 1)), max_entries=6, marker=plot_kw["marker"], ) # legend spacing if size_range[1] > 200: ls = 0.5 + size_range[1] / 100 * 0.125 else: ls = 0.5 legend_kw = { "loc": "lower left", "facecolor": "w", "framealpha": 1, "edgecolor": "w", "labelspacing": ls, "handles": legend_elements, "bbox_to_anchor": (-0.05, -0.1), } | legend_kw if "title" not in legend_kw: if hasattr(sdata, "long_name"): lgd_title = wrap_text( sdata.long_name, min_line_len=1, max_line_len=15 ) if hasattr(sdata, "units"): lgd_title += f" ({sdata.units})" else: lgd_title = sizes legend_kw.setdefault("title", lgd_title) if ax: lgd = ax.legend(**legend_kw) lgd.set_zorder(11) else: im.figlegend = im.fig.legend(**legend_kw) # im._adjust_fig_for_guide(im.figlegend) if ax: return ax else: im.fig.suptitle(get_attributes("long_name", data)) im.set_titles(template="{value}") if enumerate_subplots and isinstance(im, xr.plot.facetgrid.FacetGrid): for idx, ax in enumerate(im.axs.flat): ax.set_title(f"{string.ascii_lowercase[idx]}) {ax.get_title()}") return im
[docs] def taylordiagram( data: xr.DataArray | dict[str, xr.DataArray], plot_kw: dict[str, Any] | None = None, fig_kw: dict[str, Any] | None = None, std_range: tuple = (0, 1.5), contours: int | None = 4, contours_kw: dict[str, Any] | None = None, ref_std_line: bool = False, legend_kw: dict[str, Any] | None = None, std_label: str | None = None, corr_label: str | None = None, colors_key: str | None = None, markers_key: str | None = None, ): """ Build a Taylor diagram. Based on the following code: https://gist.github.com/ycopin/3342888. Parameters ---------- data : xr.DataArray or dict DataArray or dictionary of DataArrays created by xsdba.measures.taylordiagram, each corresponding to a point on the diagram. The dictionary keys will become their labels. plot_kw : dict, optional Arguments to pass to the `plot()` function. Changes how the markers look. If 'data' is a dictionary, must be a nested dictionary with the same keys as 'data'. fig_kw : dict, optional Arguments to pass to `plt.figure()`. std_range : tuple Range of the x and y axes, in units of the highest standard deviation in the data. contours : int, optional Number of rsme contours to plot. contours_kw : dict, optional Arguments to pass to `plt.contour()` for the rmse contours. ref_std_line : bool, optional If True, draws a circular line on radius `std = ref_std`. Default: False legend_kw : dict, optional Arguments to pass to `plt.legend()`. std_label : str, optional Label for the standard deviation (x and y) axes. corr_label : str, optional Label for the correlation axis. colors_key : str, optional Attribute or dimension of DataArrays used to separate DataArrays into groups with different colors. If present, it overrides the "color" key in `plot_kw`. markers_key : str, optional Attribute or dimension of DataArrays used to separate DataArrays into groups with different markers. If present, it overrides the "marker" key in `plot_kw`. Returns ------- (plt.figure, mpl_toolkits.axisartist.floating_axes.FloatingSubplot, plt.legend) """ plot_kw = empty_dict(plot_kw) fig_kw = empty_dict(fig_kw) contours_kw = empty_dict(contours_kw) legend_kw = empty_dict(legend_kw) # preserve order of dimensions if used for marker/color ordered_markers_type = None ordered_colors_type = None # convert SSP, RCP, CMIP formats in keys if isinstance(data, dict): data = process_keys(data, convert_scen_name) if isinstance(plot_kw, dict): plot_kw = process_keys(plot_kw, convert_scen_name) # if only one data input, insert in dict. if not isinstance(data, dict): data = {"_no_label": data} # mpl excludes labels starting with "_" from legend plot_kw = {"_no_label": empty_dict(plot_kw)} elif not plot_kw: plot_kw = {k: {} for k in data.keys()} # check type for key, v in data.items(): if not isinstance(v, xr.DataArray): raise TypeError("All objects in 'data' must be xarray DataArrays.") if "taylor_param" not in v.dims: raise ValueError("All DataArrays must contain a 'taylor_param' dimension.") if key == "reference": raise ValueError("'reference' is not allowed as a key in data.") # If there are other dimensions than 'taylor_param', create a bigger dict with them data_keys = list(data.keys()) for data_key in data_keys: da = data[data_key] dims = list(set(da.dims) - {"taylor_param"}) if dims != []: if markers_key in dims: ordered_markers_type = da[markers_key].values if colors_key in dims: ordered_colors_type = da[colors_key].values da = da.stack(pl_dims=dims) for i, dim_key in enumerate(da.pl_dims.values): if isinstance(dim_key, list) or isinstance(dim_key, tuple): dim_key = "-".join([str(k) for k in dim_key]) da0 = da.isel(pl_dims=i) # if colors_key/markers_key is a dimension, add it as an attribute for later use if markers_key in dims: da0.attrs[markers_key] = da0[markers_key].values.item() if colors_key in dims: da0.attrs[colors_key] = da0[colors_key].values.item() new_data_key = ( f"{data_key}-{dim_key}" if data_key != "_no_label" else dim_key ) data[new_data_key] = da0 plot_kw[new_data_key] = empty_dict(plot_kw[f"{data_key}"]) data.pop(data_key) plot_kw.pop(data_key) # remove negative correlations initial_len = len(data) removed = [ key for key, da in data.items() if da.sel(taylor_param="corr").values < 0 ] data = { key: da for key, da in data.items() if da.sel(taylor_param="corr").values >= 0 } if len(data) != initial_len: warnings.warn( f"{initial_len - len(data)} points with negative correlations will not be plotted: {', '.join(removed)}", stacklevel=2 ) # add missing keys to plot_kw for key in data.keys(): if key not in plot_kw: plot_kw[key] = {} # extract ref to be used in plot ref_std = list(data.values())[0].sel(taylor_param="ref_std").values # check if ref is the same in all DataArrays and get the highest std (for ax limits) if len(data) > 1: for da in data.values(): if da.sel(taylor_param="ref_std").values != ref_std: raise ValueError( "All reference standard deviation values must be identical" ) # get highest std for axis limits max_std = [ref_std] for da in data.values(): max_std.extend( [ max( da.sel(taylor_param="ref_std").values, da.sel(taylor_param="sim_std").values, ).astype(float) ] ) # make labels if not std_label: try: units = list(data.values())[0].units std_label = get_localized_term("standard deviation") std_label = std_label if units == "" else f"{std_label} ({units})" except AttributeError: std_label = get_localized_term("standard deviation").capitalize() if not corr_label: try: if "Pearson" in list(data.values())[0].correlation_type: corr_label = get_localized_term("pearson correlation").capitalize() else: corr_label = get_localized_term("correlation").capitalize() except AttributeError: corr_label = get_localized_term("correlation").capitalize() # build diagram transform = PolarAxes.PolarTransform() # Setup the axis, here we map angles in degrees to angles in radius # Correlation labels rlocs = np.array([0, 0.2, 0.4, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 1]) tlocs = np.arccos(rlocs) # Conversion to polar angles gl1 = gf.FixedLocator(tlocs) # Positions tf1 = gf.DictFormatter(dict(zip(tlocs, map(str, rlocs), strict=False))) # Standard deviation axis extent radius_min = std_range[0] * max(max_std) radius_max = std_range[1] * max(max_std) # Set up the axes range in the parameter "extremes" ghelper = GridHelperCurveLinear( transform, extremes=(0, np.pi / 2, radius_min, radius_max), grid_locator1=gl1, tick_formatter1=tf1, ) fig = plt.figure(**fig_kw) floating_ax = FloatingSubplot(fig, 111, grid_helper=ghelper) fig.add_subplot(floating_ax) # Adjust axes floating_ax.axis["top"].set_axis_direction("bottom") # "Angle axis" floating_ax.axis["top"].toggle(ticklabels=True, label=True) floating_ax.axis["top"].major_ticklabels.set_axis_direction("top") floating_ax.axis["top"].label.set_axis_direction("top") floating_ax.axis["top"].label.set_text(corr_label) floating_ax.axis["left"].set_axis_direction("bottom") # "X axis" floating_ax.axis["left"].label.set_text(std_label) floating_ax.axis["right"].set_axis_direction("top") # "Y axis" floating_ax.axis["right"].toggle(ticklabels=True, label=True) floating_ax.axis["right"].major_ticklabels.set_axis_direction("left") floating_ax.axis["right"].label.set_text(std_label) floating_ax.axis["bottom"].set_visible(False) # Useless # Contours along standard deviations floating_ax.grid(visible=True, alpha=0.4) floating_ax.set_title("") ax = floating_ax.get_aux_axes(transform) # return the axes that can be plotted on # plot reference if "reference" in plot_kw: ref_kw = plot_kw.pop("reference") else: ref_kw = {} ref_kw = { "color": "#154504", "marker": "s", "label": get_localized_term("reference"), } | ref_kw ref_pt = ax.scatter(0, ref_std, **ref_kw) points = [ref_pt] # set up for later # plot a circular line along `ref_std` if ref_std_line: angles_for_line = np.linspace(0, np.pi / 2, 100) radii_for_line = np.full_like(angles_for_line, ref_std) ax.plot( angles_for_line, radii_for_line, color=ref_kw["color"], linewidth=0.5, linestyle="-", ) # rmse contours from reference standard deviation if contours: radii, angles = np.meshgrid( np.linspace(radius_min, radius_max), np.linspace(0, np.pi / 2), ) # Compute centered RMS difference rms = np.sqrt(ref_std**2 + radii**2 - 2 * ref_std * radii * np.cos(angles)) contours_kw = {"linestyles": "--", "linewidths": 0.5} | contours_kw ct = ax.contour(angles, radii, rms, levels=contours, **contours_kw) ax.clabel(ct, ct.levels, fontsize=8) # points.append(ct_line) ct_line = ax.plot( [0], [0], ls=contours_kw["linestyles"], lw=1, c="k" if "colors" not in contours_kw else contours_kw["colors"], label="rmse", ) points.append(ct_line[0]) # get color options style_colors = matplotlib.rcParams["axes.prop_cycle"].by_key()["color"] if len(data) > len(style_colors): style_colors = style_colors * math.ceil(len(data) / len(style_colors)) cat_colors = Path(__file__).parents[1] / "data/ipcc_colors/categorical_colors.json" # get marker options (only used if `markers_key` is set) style_markers = "oDv^<>p*hH+x|_" if len(data) > len(style_markers): style_markers = style_markers * math.ceil(len(data) / len(style_markers)) # set colors and markers styles based on discrimnating attributes (if specified) if colors_key or markers_key: if colors_key: # get_scen_color : look for SSP, RCP, CMIP model color colors_type = ( ordered_colors_type if ordered_colors_type is not None else {da.attrs[colors_key] for da in data.values()} ) colorsd = { k: get_scen_color(k, cat_colors) or style_colors[i] for i, k in enumerate(colors_type) } if markers_key: markers_type = ( ordered_markers_type if ordered_markers_type is not None else {da.attrs[markers_key] for da in data.values()} ) markersd = {k: style_markers[i] for i, k in enumerate(markers_type)} for key, da in data.items(): if colors_key: plot_kw[key]["color"] = colorsd[da.attrs[colors_key]] if markers_key: plot_kw[key]["marker"] = markersd[da.attrs[markers_key]] # plot scatter for (key, da), i in zip(data.items(), range(len(data)), strict=False): # look for SSP, RCP, CMIP model color if colors_key is None: plot_kw[key].setdefault( "color", get_scen_color(key, cat_colors) or style_colors[i] ) # set defaults plot_kw[key] = {"label": key} | plot_kw[key] # legend will be handled later in this case if markers_key or colors_key: plot_kw[key]["label"] = "" # plot pt = ax.scatter( np.arccos(da.sel(taylor_param="corr").values), da.sel(taylor_param="sim_std").values, **plot_kw[key], ) points.append(pt) # legend legend_kw.setdefault("loc", "upper right") legend = fig.legend(points, [pt.get_label() for pt in points], **legend_kw) # plot new legend if markers/colors represent a certain dimension if colors_key or markers_key: handles = list(floating_ax.get_legend_handles_labels()[0]) if markers_key: for k, m in markersd.items(): handles.append(Line2D([0], [0], color="k", label=k, marker=m, ls="")) if colors_key: for k, c in colorsd.items(): handles.append(Line2D([0], [0], color=c, label=k, ls="-")) legend.remove() legend = fig.legend(handles=handles, **legend_kw) return fig, floating_ax, legend
[docs] def hatchmap( data: dict[str, Any] | xr.DataArray | xr.Dataset, ax: matplotlib.axes.Axes | None = None, use_attrs: dict[str, Any] | None = None, fig_kw: dict[str, Any] | None = None, plot_kw: dict[str, Any] | None = None, projection: ccrs.Projection = ccrs.LambertConformal(), transform: ccrs.Projection | None = None, features: list[str] | dict[str, dict[str, Any]] | None = None, geometries_kw: dict[str, Any] | None = None, levels: int | None = None, legend_kw: dict[str, Any] | bool = True, show_time: bool | str | int | tuple[float, float] = False, frame: bool = False, enumerate_subplots: bool = False, ) -> matplotlib.axes.Axes: """ Create map of hatches from 2D data. Parameters ---------- data : dict, DataArray or Dataset Input data do plot. ax : matplotlib axis, optional Matplotlib axis on which to plot, with the same projection as the one specified. use_attrs : dict, optional Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description'). Default value is {'title': 'description'}. Only the keys found in the default dict can be used. fig_kw : dict, optional Arguments to pass to `plt.figure()`. plot_kw: dict, optional Arguments to pass to 'xarray.plot.contourf()' function. If 'data' is a dictionary, can be a nested dictionary with the same keys as 'data'. projection : ccrs.Projection The projection to use, taken from the cartopy.ccrs options. Ignored if ax is not None. transform : ccrs.Projection, optional Transform corresponding to the data coordinate system. If None, an attempt is made to find dimensions matching ccrs.PlateCarree() or ccrs.RotatedPole(). features : list or dict, optional Features to use, as a list or a nested dict containing kwargs. Options are the predefined features from cartopy.feature: ['coastline', 'borders', 'lakes', 'land', 'ocean', 'rivers', 'states']. geometries_kw : dict, optional Arguments passed to cartopy ax.add_geometry() which adds given geometries (GeoDataFrame geometry) to axis. legend_kw : dict or boolean, optional Arguments to pass to `ax.legend()`. No legend is added if legend_kw == False. show_time : bool, tuple, string or int. If True, show time (as date) at the bottom right of the figure. Can be a tuple of axis coordinates (0 to 1, as a fraction of the axis length) representing the location of the text. If a string or an int, the same values as those of the 'loc' parameter of matplotlib's legends are accepted. ================== ============= Location String Location Code ================== ============= 'upper right' 1 'upper left' 2 'lower left' 3 'lower right' 4 'right' 5 'center left' 6 'center right' 7 'lower center' 8 'upper center' 9 'center' 10 ================== ============= frame : bool Show or hide frame. Default False. enumerate_subplots: bool If True, enumerate subplots with letters. Only works with facetgrids (pass `col` or `row` in plot_kw). Returns ------- matplotlib.axes.Axes """ # default hatches dfh = [ "/", "\\", "|", "-", "+", "x", "o", "O", ".", "*", "//", "\\\\", "||", "--", "++", "xx", "oo", "OO", "..", "**", ] # create empty dicts if None use_attrs = empty_dict(use_attrs) fig_kw = empty_dict(fig_kw) plot_kw = empty_dict(plot_kw) legend_kw = empty_dict(legend_kw) dattrs = None plot_data = {} # convert data to dict (if not one) if not isinstance(data, dict): if isinstance(data, xr.DataArray): plot_data = {data.name: data} if data.name not in plot_kw.keys(): plot_kw = {data.name: plot_kw} elif isinstance(data, xr.Dataset): dattrs = data plot_data = {var: data[var] for var in data.data_vars} for v in plot_data.keys(): if v not in plot_kw.keys(): plot_kw[v] = plot_kw else: for k, v in data.items(): if k not in plot_kw.keys(): plot_kw[k] = plot_kw if isinstance(v, xr.Dataset): dattrs = k plot_data[k] = v[list(v.data_vars)[0]] warnings.warn("Only first variable of Dataset is plotted.", stacklevel=2) else: plot_data[k] = v # setup transform from first data entry trdata = list(plot_data.values())[0] if transform is None: if "lat" in trdata.dims and "lon" in trdata.dims: transform = ccrs.PlateCarree() elif "rlat" in trdata.dims and "rlon" in trdata.dims: transform = get_rotpole(list(plot_data.values())[0]) # bug xlim / ylim + transform in facetgrids # (see https://github.com/pydata/xarray/issues/8562#issuecomment-1865189766) if transform and ( "xlim" in list(plot_kw.values())[0] and "ylim" in list(plot_kw.values())[0] ): extent = [ list(plot_kw.values())[0]["xlim"][0], list(plot_kw.values())[0]["xlim"][1], list(plot_kw.values())[0]["ylim"][0], list(plot_kw.values())[0]["ylim"][1], ] [v.pop(lim) for lim in ["xlim", "ylim"] for v in plot_kw.values() if lim in v] elif transform and ( "xlim" in list(plot_kw.values())[0] or "ylim" in list(plot_kw.values())[0] ): extent = None warnings.warn( "Requires both xlim and ylim with 'transform'. Xlim or ylim was dropped", stacklevel=2 ) [v.pop(lim) for lim in ["xlim", "ylim"] for v in plot_kw.values() if lim in v] else: extent = None # setup fig, ax if ax is None and ( "row" not in list(plot_kw.values())[0].keys() and "col" not in list(plot_kw.values())[0].keys() ): fig, ax = plt.subplots(subplot_kw={"projection": projection}, **fig_kw) elif ax is not None and ( "col" in list(plot_kw.values())[0].keys() or "row" in list(plot_kw.values())[0].keys() ): raise ValueError("Cannot use 'ax' and 'col'/'row' at the same time.") elif ax is None: [ v.setdefault("subplot_kws", {}).setdefault("projection", projection) for v in plot_kw.values() ] cfig_kw = copy.deepcopy(fig_kw) if "figsize" in fig_kw: # add figsize to plot_kw for facetgrid plot_kw[0].setdefault("figsize", fig_kw["figsize"]) cfig_kw.pop("figsize") if cfig_kw: for v in plot_kw.values(): {"subplots_kws": cfig_kw} | v warnings.warn( "Only figsize and figure.add_subplot() arguments can be passed to fig_kw when using facetgrid.", stacklevel=2 ) pat_leg = [] n = 0 for k, v in plot_data.items(): # if levels plot multiple hatching from one data entry if "levels" in plot_kw[k] and len(plot_data) == 1: # nans mask = ~np.isnan(v.values) if np.sum(mask) < len(mask): warnings.warn( f"{len(mask) - np.sum(mask)} nan values were dropped when plotting the pattern values", stacklevel=2 ) if "hatches" in plot_kw[k] and plot_kw[k]["levels"] != len( plot_kw[k]["hatches"] ): warnings.warn("Hatches number is not equivalent to number of levels", stacklevel=2) hatches = dfh[0:levels] if "hatches" not in plot_kw[k]: hatches = dfh[0:levels] plot_kw[k] = { "hatches": hatches, "colors": "none", "add_colorbar": False, } | plot_kw[k] if "lat" in v.dims: v.coords["mask"] = (("lat", "lon"), mask) else: v.coords["mask"] = (("rlat", "rlon"), mask) plot_kw[k].setdefault("transform", transform) if ax: plot_kw[k].setdefault("ax", ax) im = v.where(mask is not True).plot.contourf(**plot_kw[k]) artists, labels = im.legend_elements(str_format="{:2.1f}".format) if ax and legend_kw: ax.legend(artists, labels, **legend_kw) elif legend_kw: im.figlegend = im.fig.legend(**legend_kw) elif len(plot_data) > 1 and "levels" in plot_kw[k]: raise TypeError( "To plot levels only one xr.DataArray or xr.Dataset accepted" ) else: # since pattern remove colors and colorbar from plotting (done by gridmap) plot_kw[k] = {"colors": "none", "add_colorbar": False} | plot_kw[k] if "hatches" not in plot_kw[k].keys(): plot_kw[k]["hatches"] = dfh[n] n += 1 elif isinstance( plot_kw[k]["hatches"], str ): # make sure the hatches are in a list warnings.warn( "Hatches argument must be of type 'list'. Wrapping string argument as list.", stacklevel=2 ) plot_kw[k]["hatches"] = [plot_kw[k]["hatches"]] plot_kw[k].setdefault("transform", transform) if ax: im = v.plot.contourf(ax=ax, **plot_kw[k]) if not ax: if k == list(plot_data.keys())[0]: c_pkw = plot_kw[k].copy() if "col" in plot_kw[k].keys() or "row" in plot_kw[k].keys(): if c_pkw["colors"] == "none": c_pkw.pop("colors") im = v.plot.contourf(**c_pkw) for i, fax in enumerate(im.axs.flat): if ( k == list(plot_data.keys())[0] and plot_kw[k]["colors"] == "none" ): fax.clear() if len(plot_data) > 1: # select data to plot from DataSet in loop to plot on facetgrids axis c_pkw = plot_kw[k].copy() c_pkw.pop("subplot_kws") sel = {} if "row" in c_pkw.keys(): sel[c_pkw["row"]] = i c_pkw.pop("row") elif "col" in c_pkw.keys(): sel[c_pkw["col"]] = i c_pkw.pop("col") v.isel(sel).plot.contourf(ax=fax, **c_pkw) if k == list(plot_data.keys())[-1]: add_features_map( dattrs, fax, use_attrs, projection, features, geometries_kw, frame, ) if extent: fax.set_extent(extent) pat_leg.append( matplotlib.patches.Patch( hatch=plot_kw[k]["hatches"][0], fill=False, label=k ) ) if pat_leg and legend_kw: legend_kw = { "loc": "lower right", "handleheight": 2, "handlelength": 4, } | legend_kw if ax and legend_kw: ax.legend(handles=pat_leg, **legend_kw) elif legend_kw: im.figlegend = im.fig.legend(handles=pat_leg, **legend_kw) # add features if ax: if extent: ax.set_extent(extent) if dattrs: use_attrs.setdefault("title", "description") ax = add_features_map( dattrs, ax, use_attrs, projection, features, geometries_kw, frame, ) if show_time: if isinstance(show_time, bool): plot_coords( ax, plot_data, param="time", loc="lower right", backgroundalpha=1, ) elif isinstance(show_time, str | tuple | int): plot_coords( ax, plot_data, param="time", loc=show_time, backgroundalpha=1, ) # when im is an ax, it has a colorbar attribute. If it is a facetgrid, it has a cbar attribute. if (frame is False) and ( (getattr(im, "colorbar", None) is not None) or (getattr(im, "cbar", None) is not None) ): im.colorbar.outline.set_visible(False) set_plot_attrs(use_attrs, dattrs, ax, wrap_kw={"max_line_len": 60}) return ax else: # when im is an ax, it has a colorbar attribute. If it is a facetgrid, it has a cbar attribute. if (frame is False) and ( (getattr(im, "colorbar", None) is not None) or (getattr(im, "cbar", None) is not None) ): im.cbar.outline.set_visible(False) if show_time: if show_time is True: plot_coords( None, dattrs, param="time", loc="lower right", backgroundalpha=1, ) elif isinstance(show_time, str | tuple | int): plot_coords( None, dattrs, param="time", loc=show_time, backgroundalpha=1 ) if dattrs: use_attrs.setdefault("suptitle", "long_name") set_plot_attrs(use_attrs, dattrs, facetgrid=im) if enumerate_subplots and isinstance(im, xr.plot.facetgrid.FacetGrid): for idx, ax in enumerate(im.axs.flat): ax.set_title(f"{string.ascii_lowercase[idx]}) {ax.get_title()}") return im
def _add_lead_time_coord(da, ref): """Add a lead time coordinate to the data. Modifies da in-place.""" lead_time = da.time.dt.year - int(ref) da["Lead time"] = lead_time da["Lead time"].attrs["units"] = f"years from {ref}" return lead_time
[docs] def partition( data: xr.DataArray | xr.Dataset, ax: matplotlib.axes.Axes | None = None, start_year: str | None = None, show_num: bool = True, fill_kw: dict[str, Any] | None = None, line_kw: dict[str, Any] | None = None, fig_kw: dict[str, Any] | None = None, legend_kw: dict[str, Any] | None = None, ) -> matplotlib.axes.Axes: """ Figure of the partition of total uncertainty by components. Uncertainty fractions can be computed with xclim (https://xclim.readthedocs.io/en/stable/api.html#uncertainty-partitioning). Make sure the use `fraction=True` in the xclim function call. Parameters ---------- data : xr.DataArray or xr.Dataset Variance over time of the different components of uncertainty. Output of a `xclim.ensembles._partitioning` function. ax : matplotlib axis, optional Matplotlib axis on which to plot. start_year : str If None, the x-axis will be the time in year. If str, the x-axis will show the number of year since start_year. show_num : bool If True, show the number of elements for each uncertainty components in parentheses in the legend. `data` should have attributes named after the components with a list of its the elements. fill_kw : dict Keyword arguments passed to `ax.fill_between`. It is possible to pass a dictionary of keywords for each component (uncertainty coordinates). line_kw : dict Keyword arguments passed to `ax.plot` for the lines in between the components. The default is {color="k", lw=2}. We recommend always using lw>=2. fig_kw : dict Keyword arguments passed to `plt.subplots`. legend_kw : dict Keyword arguments passed to `ax.legend`. Returns ------- mpl.axes.Axes """ if isinstance(data, xr.Dataset): if len(data.data_vars) > 1: warnings.warn( "data is xr.Dataset; only the first variable will be used in plot", stacklevel=2 ) data = data[list(data.keys())[0]].squeeze() if data.attrs["units"] != "%": raise ValueError( "The units are not %. Use `fraction=True` in the xclim function call." ) fill_kw = empty_dict(fill_kw) line_kw = empty_dict(line_kw) fig_kw = empty_dict(fig_kw) legend_kw = empty_dict(legend_kw) # select data to plot if isinstance(data, xr.DataArray): data = data.squeeze() elif isinstance(data, xr.Dataset): # in case, it was saved to disk before plotting. if len(data.data_vars) > 1: warnings.warn( "data is xr.Dataset; only the first variable will be used in plot", stacklevel=2 ) data = data[list(data.keys())[0]].squeeze() else: raise TypeError("`data` must contain a xr.DataArray or xr.Dataset") if ax is None: fig, ax = plt.subplots(**fig_kw) # Select data from reference year onward if start_year: data = data.sel(time=slice(start_year, None)) # Lead time coordinate time = _add_lead_time_coord(data, start_year) ax.set_xlabel(f"Lead time (years from {start_year})") else: time = data.time.dt.year # fill_kw that are direct (not with uncertainty as key) fk_direct = {k: v for k, v in fill_kw.items() if (k not in data.uncertainty.values)} # Draw areas past_y = 0 black_lines = [] for u in data.uncertainty.values: if u not in ["total", "variability"]: present_y = past_y + data.sel(uncertainty=u) num = len(data.attrs.get(u, [])) # compatible with pre PR PR #1529 label = f"{u} ({num})" if show_num and num else u ax.fill_between( time, past_y, present_y, label=label, **fill_kw.get(u, fk_direct), ) black_lines.append(present_y) past_y = present_y ax.fill_between( time, past_y, 100, label="variability", **fill_kw.get("variability", fk_direct), ) # Draw black lines line_kw.setdefault("color", "k") line_kw.setdefault("lw", 2) ax.plot(time, np.array(black_lines).T, **line_kw) ax.xaxis.set_major_locator(matplotlib.ticker.MultipleLocator(20)) ax.xaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator(n=5)) ax.yaxis.set_major_locator(matplotlib.ticker.MultipleLocator(10)) ax.yaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator(n=2)) ax.set_ylabel(f"{data.attrs['long_name']} ({data.attrs['units']})") # ax.set_ylim(0, 100) ax.legend(**legend_kw) return ax
[docs] def triheatmap( data: xr.DataArray | xr.Dataset, z: str, ax: matplotlib.axes.Axes | None = None, use_attrs: dict[str, Any] | None = None, fig_kw: dict[str, Any] | None = None, plot_kw: dict[str, Any] | None | list = None, cmap: str | matplotlib.colors.Colormap | None = None, divergent: bool | int | float = False, cbar: bool | str = "unique", cbar_kw: dict[str, Any] | None | list = None, ) -> matplotlib.axes.Axes: """ Create a triangle heatmap from a DataArray. Note that most of the code comes from: https://stackoverflow.com/questions/66048529/how-to-create-a-heatmap-where-each-cell-is-divided-into-4-triangles Parameters ---------- data : DataArray or Dataset Input data do plot. z: str Dimension to plot on the triangles. Its length should be 2 or 4. ax : matplotlib axis, optional Matplotlib axis on which to plot, with the same projection as the one specified. use_attrs : dict, optional Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description'). Default value is {'cbar_label': 'long_name',"cbar_units": "units"}. Valid keys are: 'title', 'xlabel', 'ylabel', 'cbar_label', 'cbar_units'. fig_kw : dict, optional Arguments to pass to `plt.figure()`. plot_kw : dict, optional Arguments to pass to the 'plt.tripcolor()' function. It can be a list of dictionaries to pass different arguments to each type of triangles (upper/lower or north/east/south/west). cmap : matplotlib.colors.Colormap or str, optional Colormap to use. If str, can be a matplotlib or name of the file of an IPCC colormap (see data/ipcc_colors). If None, look for common variables (from data/ipcc_colors/variables_groups.json) in the name of the DataArray or its 'history' attribute and use corresponding colormap, aligned with the IPCC Visual Style Guide 2022 (https://www.ipcc.ch/site/assets/uploads/2022/09/IPCC_AR6_WGI_VisualStyleGuide_2022.pdf). divergent : bool or int or float If int or float, becomes center of cmap. Default center is 0. cbar : {False, True, 'unique', 'each'} If False, don't show the colorbar. If True or 'unique', show a unique colorbar for all triangle types. (The cbar of the first triangle is used). If 'each', show a colorbar for each triangle type. cbar_kw : dict or list Arguments to pass to 'fig.colorbar()'. It can be a list of dictionaries to pass different arguments to each type of triangles (upper/lower or north/east/south/west). Returns ------- matplotlib.axes.Axes """ # create empty dicts if None use_attrs = empty_dict(use_attrs) fig_kw = empty_dict(fig_kw) plot_kw = empty_dict(plot_kw) cbar_kw = empty_dict(cbar_kw) # select data to plot if isinstance(data, xr.DataArray): da = data elif isinstance(data, xr.Dataset): if len(data.data_vars) > 1: warnings.warn( "data is xr.Dataset; only the first variable will be used in plot", stacklevel=2 ) da = list(data.values())[0] else: raise TypeError("`data` must contain a xr.DataArray or xr.Dataset") # setup fig, axis if ax is None: fig, ax = plt.subplots(**fig_kw) # colormap if isinstance(cmap, str): if cmap not in plt.colormaps(): try: cmap = create_cmap(filename=cmap) except FileNotFoundError: pass logging.log("Colormap not found. Using default.") elif cmap is None: cmap = create_cmap( get_var_group(da=da), divergent=divergent, ) # prep data d = [da.sel(**{z: v}).values for v in da[z].values] other_dims = [di for di in da.dims if di != z] if len(other_dims) > 2: warnings.warn( "More than 3 dimensions in data. The first two after dim will be used as the dimensions of the heatmap.", stacklevel=2 ) if len(other_dims) < 2: raise ValueError( "Data must have 3 dimensions. If you only have 2 dimensions, use fg.heatmap." ) if plot_kw == {} and cbar in ["unique", True]: warnings.warn( 'With cbar="unique" only the colorbar of the first triangle' " will be shown. No `plot_kw` was passed. vmin and vmax will be set the max" " and min of data.", stacklevel=2 ) plot_kw = {"vmax": da.max().values, "vmin": da.min().values} if isinstance(plot_kw, dict): plot_kw.setdefault("cmap", cmap) plot_kw.setdefault("ec", "white") plot_kw = [plot_kw for _ in range(len(d))] labels_x = da[other_dims[0]].values labels_y = da[other_dims[1]].values m, n = d[0].shape[0], d[0].shape[1] # plot if len(d) == 2: x = np.arange(m + 1) y = np.arange(n + 1) xss, ys = np.meshgrid(x, y) (xss * ys) % 10 triangles1 = [ (i + j * (m + 1), i + 1 + j * (m + 1), i + (j + 1) * (m + 1)) for j in range(n) for i in range(m) ] triangles2 = [ ( i + 1 + j * (m + 1), i + 1 + (j + 1) * (m + 1), i + (j + 1) * (m + 1), ) for j in range(n) for i in range(m) ] triang1 = Triangulation(xss.ravel(), ys.ravel(), triangles1) triang2 = Triangulation(xss.ravel(), ys.ravel(), triangles2) triangul = [triang1, triang2] imgs = [ ax.tripcolor(t, np.ravel(val), **plotkw) for t, val, plotkw in zip(triangul, d, plot_kw, strict=False) ] ax.set_xticks(np.array(range(m)) + 0.5, labels=labels_x, rotation=45) ax.set_yticks(np.array(range(n)) + 0.5, labels=labels_y, rotation=90) elif len(d) == 4: xv, yv = np.meshgrid( np.arange(-0.5, m), np.arange(-0.5, n) ) # vertices of the little squares xc, yc = np.meshgrid( np.arange(0, m), np.arange(0, n) ) # centers of the little squares x = np.concatenate([xv.ravel(), xc.ravel()]) y = np.concatenate([yv.ravel(), yc.ravel()]) cstart = (m + 1) * (n + 1) # indices of the centers triangles_n = [ (i + j * (m + 1), i + 1 + j * (m + 1), cstart + i + j * m) for j in range(n) for i in range(m) ] triangles_e = [ (i + 1 + j * (m + 1), i + 1 + (j + 1) * (m + 1), cstart + i + j * m) for j in range(n) for i in range(m) ] triangles_s = [ ( i + 1 + (j + 1) * (m + 1), i + (j + 1) * (m + 1), cstart + i + j * m, ) for j in range(n) for i in range(m) ] triangles_w = [ (i + (j + 1) * (m + 1), i + j * (m + 1), cstart + i + j * m) for j in range(n) for i in range(m) ] triangul = [ Triangulation(x, y, triangles) for triangles in [ triangles_n, triangles_e, triangles_s, triangles_w, ] ] imgs = [ ax.tripcolor(t, np.ravel(val), **plotkw) for t, val, plotkw in zip(triangul, d, plot_kw, strict=False) ] ax.set_xticks(np.array(range(m)), labels=labels_x, rotation=45) ax.set_yticks(np.array(range(n)), labels=labels_y, rotation=90) else: raise ValueError( f"The length of the dimensiondim ({z},{len(d)}) should be either 2 or 4. It represents the number of triangles." ) ax.set_title(get_attributes(use_attrs.get("title", None), data)) ax.set_xlabel(other_dims[0]) ax.set_ylabel(other_dims[1]) if "xlabel" in use_attrs: ax.set_xlabel(get_attributes(use_attrs["xlabel"], data)) if "ylabel" in use_attrs: ax.set_ylabel(get_attributes(use_attrs["ylabel"], data)) ax.set_aspect("equal", "box") ax.invert_yaxis() ax.tick_params(left=False, bottom=False) ax.spines["bottom"].set_visible(False) ax.spines["left"].set_visible(False) # create cbar label # set default use_attrs values use_attrs.setdefault("cbar_label", "long_name") use_attrs.setdefault("cbar_units", "units") if ( "cbar_units" in use_attrs and len(get_attributes(use_attrs["cbar_units"], data)) >= 1 ): # avoids '()' as label cbar_label = ( get_attributes(use_attrs["cbar_label"], data) + " (" + get_attributes(use_attrs["cbar_units"], data) + ")" ) else: cbar_label = get_attributes(use_attrs["cbar_label"], data) if isinstance(cbar_kw, dict): cbar_kw.setdefault("label", cbar_label) cbar_kw = [cbar_kw for _ in range(len(d))] if cbar == "unique": plt.colorbar(imgs[0], ax=ax, **cbar_kw[0]) elif (cbar == "each") or (cbar is True): for i in reversed(range(len(d))): # switch order of colour bars plt.colorbar(imgs[i], ax=ax, **cbar_kw[i]) return ax