# Source code for astropy.cosmology._src.funcs.comparison

# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Comparison functions for `astropy.cosmology.Cosmology`.

This module is **NOT** public API. To use these functions, import them from
the top-level namespace -- :mod:`astropy.cosmology`. This module will be
moved.
"""

import functools
import inspect
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, TypeAlias

import numpy as np
from numpy import False_, True_, ndarray

from astropy import table

# isort: split
from astropy.cosmology._src.core import Cosmology

__all__: list[str] = []  # Nothing is scoped here


##############################################################################
# PARAMETERS

# A single format hint: `False` forbids conversion, `True`/`None` request
# auto-identification of the format, and a `str` names the format explicitly.
_FormatType: TypeAlias = bool | None | str
# Either one hint shared by every input or a tuple of per-input hints.
_FormatsType: TypeAlias = _FormatType | tuple[_FormatType, ...]

# `format` values that are accepted when the input already is a Cosmology
# (membership is checked in `_parse_format`).
_COSMO_AOK: set[Any] = {None, True_, False_, "astropy.cosmology"}
# The numpy bool also catches real bool for ops "==" and "in"


##############################################################################
# UTILITIES


# Instances of these types are wrapped in `_CosmologyWrapper` by
# `_parse_formats` before being handed to the numpy ufunc `_parse_format`.
_CANT_BROADCAST: tuple[type, ...] = (table.Row, table.Table)
"""Things that cannot broadcast.

Have to deal with things that do not broadcast well. e.g.
`~astropy.table.Row` cannot be used in an array, even if ``dtype=object``
and will raise a segfault when used in a `numpy.ufunc`.
"""


@dataclass(frozen=True)
class _CosmologyWrapper:
    """Immutable, opaque box used to shield an object from :mod:`numpy`.

    The wrapper holds exactly one value, ``wrapped``, and nothing else. It is
    an internal implementation detail and must never reach the user.
    """

    # The object being hidden from numpy.
    wrapped: Any

    # Slots are declared by hand; the single slot mirrors the single field.
    __slots__ = ("wrapped",)


@functools.partial(np.frompyfunc, nin=2, nout=1)
def _parse_format(cosmo: Any, format: _FormatType, /) -> Cosmology:
    """Parse Cosmology-like input into a Cosmology, given a format hint.

    The function is turned into a two-input, one-output `numpy.ufunc` with
    `numpy.frompyfunc`, so it is applied element-wise to its arguments.

    Parameters
    ----------
    cosmo : |Cosmology|-like, positional-only
        |Cosmology| to parse. A `_CosmologyWrapper` is unwrapped first.
    format : bool or None or str, positional-only
        Whether to allow, before equivalence is checked, the object to be
        converted to a |Cosmology|. This allows, e.g. a |Table| to be equivalent
        to a |Cosmology|. `False` (default) will not allow conversion. `True` or
        `None` will, and will use the auto-identification to try to infer the
        correct format. A `str` is assumed to be the correct format to use when
        converting.

    Returns
    -------
    |Cosmology|
        ``cosmo`` itself if it already is a |Cosmology|, otherwise the result
        of ``Cosmology.from_format(cosmo, format=format)``.

    Raises
    ------
    ValueError
        If ``cosmo`` is a |Cosmology| and ``format`` is not one of the values
        in ``_COSMO_AOK``.
    TypeError
        If ``cosmo`` is not a |Cosmology| and ``format`` equals `False`.

    Notes
    -----
    Any exception raised by ``Cosmology.from_format`` is propagated unchanged.
    """
    # Deal with private wrapper
    if isinstance(cosmo, _CosmologyWrapper):
        cosmo = cosmo.wrapped

    # Shortcut if already a cosmology: nothing to convert, only validate.
    if isinstance(cosmo, Cosmology):
        if format not in _COSMO_AOK:
            # Sorted so the message does not depend on set iteration order,
            # which varies between interpreter runs for a set holding a str.
            allowed = "/".join(sorted(map(str, _COSMO_AOK)))
            raise ValueError(
                f"for parsing a Cosmology, 'format' must be {allowed}, not {format}"
            )
        return cosmo

    # Not a Cosmology, so conversion has to be permitted.
    if format == False_:  # catches False and False_
        raise TypeError(
            f"if 'format' is False, arguments must be a Cosmology, not {cosmo}"
        )

    format = None if format == True_ else format  # str->str, None/True/True_->None
    return Cosmology.from_format(cosmo, format=format)  # this can error!


def _parse_formats(*cosmos: object, format: _FormatsType) -> ndarray:
    """Convert each Cosmology-like argument to a |Cosmology|.

    ``format`` is turned into an object array and broadcast to one entry per
    positional argument. The arguments themselves are never broadcast against
    ``format``, so ``format`` cannot change the output shape.

    Parameters
    ----------
    *cosmos : |Cosmology|-like
        The objects to parse. Each must be convertible to |Cosmology|, as
        specified by the corresponding ``format``.

    format : bool or None or str or array-like thereof, keyword-only
        Whether to allow the object to be converted to a |Cosmology|. This
        allows, e.g. a |Table| to be equivalent to a |Cosmology|. `False` will
        not allow conversion. `True` or `None` will, and will use the
        auto-identification to try to infer the correct format. A `str` is
        assumed to be the correct format to use when converting.

    Returns
    -------
    ndarray
        The result of applying ``_parse_format`` element-wise to the
        arguments and their formats.

    Raises
    ------
    TypeError
        If any in ``cosmos`` is not a |Cosmology| and the corresponding
        ``format`` equals `False`.
    """
    n_cosmos = len(cosmos)
    format_per_cosmo = np.broadcast_to(np.array(format, dtype=object), n_cosmos)

    # Some inputs do not broadcast well: e.g. `astropy.table.Row` cannot be
    # placed in an array, even with dtype=object, and segfaults inside a ufunc.
    # Hide those behind the private wrapper; `_parse_format` unwraps them.
    shielded = []
    for obj in cosmos:
        if isinstance(obj, _CANT_BROADCAST):
            obj = _CosmologyWrapper(obj)
        shielded.append(obj)

    # Parse each (cosmo, format) pair.
    return _parse_format(shielded, format_per_cosmo)


def _comparison_decorator(pyfunc: Callable[..., Any]) -> Callable[..., Any]:
    """Decorator to make wrapper function that parses |Cosmology|-like inputs.

    The positional-only parameters of ``pyfunc`` are taken to be its
    cosmology inputs. The returned wrapper accepts an extra keyword-only
    ``format`` argument, converts the positional arguments with
    ``_parse_formats`` and then calls ``pyfunc`` with the converted objects
    and the remaining keyword arguments.

    Parameters
    ----------
    pyfunc : Python function object
        An arbitrary Python function.

    Returns
    -------
    callable[..., Any]
        Wrapped `pyfunc`, as described above.

    Raises
    ------
    TypeError
        Raised by the wrapper if it receives more positional arguments than
        ``pyfunc`` has positional-only parameters.

    Notes
    -----
    All decorated functions should add the following to 'Parameters'.

    format : bool or None or str or array-like thereof, optional keyword-only
        Whether to allow the arguments to be converted to a |Cosmology|. This
        allows, e.g. a |Table| to be given instead of a |Cosmology|. `False`
        (default) will not allow conversion. `True` or `None` will, and will use
        the auto-identification to try to infer the correct format. A `str` is
        assumed to be the correct format to use when converting. Note ``format``
        is broadcast as an object array to match the shape of ``cosmos`` so
        ``format`` cannot determine the output shape.
    """
    sig = inspect.signature(pyfunc)
    # Number of cosmology inputs == number of positional-only parameters.
    nin = sum(
        p.kind == inspect.Parameter.POSITIONAL_ONLY for p in sig.parameters.values()
    )

    # Make wrapper function that parses cosmology-like inputs
    @functools.wraps(pyfunc)
    def wrapper(*cosmos: Any, format: _FormatsType = False, **kwargs: Any) -> bool:
        # Too many inputs are rejected before any (possibly costly) parsing.
        if len(cosmos) > nin:
            raise TypeError(
                f"{wrapper.__wrapped__.__name__} takes {nin} positional"
                f" arguments but {len(cosmos)} were given"
            )
        # Parse cosmologies to format. Only do specified number.
        cosmos = _parse_formats(*cosmos, format=format)
        # Evaluate pyfunc; it raises on its own if too few inputs were given.
        return wrapper.__wrapped__(*cosmos, **kwargs)

    return wrapper


##############################################################################
# COMPARISON FUNCTIONS


@_comparison_decorator
def cosmology_equal(
    cosmo1: Any, cosmo2: Any, /, *, allow_equivalent: bool = False
) -> bool:
    r"""Return element-wise equality check on the cosmologies.

    .. note::

        Cosmologies are currently scalar in their parameters.

    Parameters
    ----------
    cosmo1, cosmo2 : |Cosmology|-like
        The objects to compare. Must be convertible to |Cosmology|, as specified
        by ``format``.
    format : bool or None or str or tuple thereof, optional keyword-only
        Whether to allow the arguments to be converted to a |Cosmology|. This
        allows, e.g. a |Table| to be given instead of a |Cosmology|. `False`
        (default) will not allow conversion. `True` or `None` will, and will use
        the auto-identification to try to infer the correct format. A `str` is
        assumed to be the correct format to use when converting. Note ``format``
        is broadcast as an object array to match the shape of ``cosmos`` so
        ``format`` cannot determine the output shape.
    allow_equivalent : bool, optional keyword-only
        Whether to allow cosmologies to be equal even if not of the same class.
        For example, an instance of |LambdaCDM| might have :math:`\Omega_0=1`
        and :math:`\Omega_k=0` and therefore be flat, like |FlatLambdaCDM|.

    Returns
    -------
    bool
        The result of ``cosmo1 == cosmo2`` or, if ``allow_equivalent`` is
        true, of the ``__equiv__`` check, where `NotImplemented` from both
        operands counts as `False`.

    Examples
    --------
    Assuming the following imports

        >>> import astropy.units as u
        >>> from astropy.cosmology import FlatLambdaCDM

    Two identical cosmologies are equal.

        >>> cosmo1 = FlatLambdaCDM(70 * (u.km/u.s/u.Mpc), 0.3)
        >>> cosmo2 = FlatLambdaCDM(70 * (u.km/u.s/u.Mpc), 0.3)
        >>> cosmology_equal(cosmo1, cosmo2)
        True

    And cosmologies with different parameters are not.

        >>> cosmo3 = FlatLambdaCDM(70 * (u.km/u.s/u.Mpc), 0.4)
        >>> cosmology_equal(cosmo1, cosmo3)
        False

    Two cosmologies may be equivalent even if not of the same class. In these
    examples the |LambdaCDM| has :attr:`~astropy.cosmology.LambdaCDM.Ode0` set
    to the same value calculated in |FlatLambdaCDM|.

        >>> from astropy.cosmology import LambdaCDM
        >>> cosmo3 = LambdaCDM(70 * (u.km/u.s/u.Mpc), 0.3, 0.7)
        >>> cosmology_equal(cosmo1, cosmo3)
        False
        >>> cosmology_equal(cosmo1, cosmo3, allow_equivalent=True)
        True

    While in this example, the cosmologies are not equivalent.

        >>> cosmo4 = FlatLambdaCDM(70 * (u.km/u.s/u.Mpc), 0.3, Tcmb0=3 * u.K)
        >>> cosmology_equal(cosmo3, cosmo4, allow_equivalent=True)
        False

    Also, using the keyword argument, the notion of equality is extended to any
    Python object that can be converted to a |Cosmology|.

        >>> mapping = cosmo2.to_format("mapping")
        >>> cosmology_equal(cosmo1, mapping, format=True)
        True

    Either (or both) arguments can be |Cosmology|-like.

        >>> cosmology_equal(mapping, cosmo2, format=True)
        True

    The list of valid formats, e.g. the mapping in this example, may be
    checked with ``Cosmology.from_format.list_formats()``.

    As can be seen in the list of formats, not all formats can be
    auto-identified by ``Cosmology.from_format.registry``. Objects of these
    kinds can still be checked for equality, but the correct format string must
    be used.

        >>> yml = cosmo2.to_format("yaml")
        >>> cosmology_equal(cosmo1, yml, format=(None, "yaml"))
        True

    This also works with an array of ``format`` matching the number of
    cosmologies.

        >>> cosmology_equal(mapping, yml, format=[True, "yaml"])
        True
    """
    # TODO! include equality check of metadata

    # Strict comparison: plain parameter equality.
    if not allow_equivalent:
        return cosmo1 == cosmo2

    # Equivalence comparison. The options are: 1) same class & parameters;
    # 2) same class, different parameters; 3) different classes, equivalent
    # parameters; 4) different classes, different parameters.
    # (1) & (3) => True, (2) & (4) => False.
    verdict = cosmo1.__equiv__(cosmo2)
    if verdict is NotImplemented:
        # First operand could not decide; ask the other one.
        verdict = cosmo2.__equiv__(cosmo1)

    # Neither side could decide: treat as not equivalent.
    if verdict is NotImplemented:
        return False
    return verdict
@_comparison_decorator
def _cosmology_not_equal(
    cosmo1: Any, cosmo2: Any, /, *, allow_equivalent: bool = False
) -> bool:
    r"""Return element-wise cosmology non-equality check.

    .. note::

        Cosmologies are currently scalar in their parameters.

    Parameters
    ----------
    cosmo1, cosmo2 : |Cosmology|-like
        The objects to compare. Must be convertible to |Cosmology|, as specified
        by ``format``.
    format : bool or None or str or tuple thereof, optional keyword-only
        Whether to allow the arguments to be converted to a |Cosmology|. This
        allows, e.g. a |Table| to be given instead of a Cosmology. `False`
        (default) will not allow conversion. `True` or `None` will, and will use
        the auto-identification to try to infer the correct format. A `str` is
        assumed to be the correct format to use when converting. ``format`` is
        broadcast to match the shape of the cosmology arguments. Note that the
        cosmology arguments are not broadcast against ``format``, so it cannot
        determine the output shape.
    allow_equivalent : bool, optional keyword-only
        Whether to allow cosmologies to be equal even if not of the same class.
        For example, an instance of |LambdaCDM| might have :math:`\Omega_0=1`
        and :math:`\Omega_k=0` and therefore be flat, like |FlatLambdaCDM|.

    Returns
    -------
    bool
        The logical negation of ``cosmology_equal`` for the same arguments.

    See Also
    --------
    astropy.cosmology.cosmology_equal
        Element-wise equality check, with argument conversion to Cosmology.
    """
    # TODO! it might eventually be worth the speed boost to implement some of
    #       the internals of cosmology_equal here, but for now it's a hassle.
    are_equal = cosmology_equal(cosmo1, cosmo2, allow_equivalent=allow_equivalent)
    return not are_equal