        compared.
    check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride.
    msg (Optional[Union[str, Callable[[str], str]]]): Optional error message to use in case a failure occurs during
        the comparison. Can also be passed as a callable, in which case it will be called with the generated message
        and should return the new message.

Raises:
    ValueError: If no :class:`torch.Tensor` can be constructed from an input.
    ValueError: If only ``rtol`` or ``atol`` is specified.
    AssertionError: If corresponding inputs are not Python scalars and are not directly related.
    AssertionError: If ``allow_subclasses`` is ``False``, but corresponding inputs are not Python scalars and have
        different types.
    AssertionError: If the inputs are :class:`~collections.abc.Sequence`'s, but their lengths do not match.
    AssertionError: If the inputs are :class:`~collections.abc.Mapping`'s, but their sets of keys do not match.
    AssertionError: If corresponding tensors do not have the same :attr:`~torch.Tensor.shape`.
    AssertionError: If ``check_layout`` is ``True``, but corresponding tensors do not have the same
        :attr:`~torch.Tensor.layout`.
    AssertionError: If only one of corresponding tensors is quantized.
    AssertionError: If corresponding tensors are quantized, but have different :meth:`~torch.Tensor.qscheme`'s.
    AssertionError: If ``check_device`` is ``True``, but corresponding tensors are not on the same
        :attr:`~torch.Tensor.device`.
    AssertionError: If ``check_dtype`` is ``True``, but corresponding tensors do not have the same ``dtype``.
    AssertionError: If ``check_stride`` is ``True``, but corresponding strided tensors do not have the same stride.
    AssertionError: If the values of corresponding tensors are not close according to the definition above.

The following table displays the default ``rtol`` and ``atol`` for different ``dtype``'s. In case of mismatching
``dtype``'s, the maximum of both tolerances is used.

+---------------------------+------------+----------+
| ``dtype``                 | ``rtol``   | ``atol`` |
+===========================+============+==========+
| :attr:`~torch.float16`    | ``1e-3``   | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.bfloat16`   | ``1.6e-2`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.float32`    | ``1.3e-6`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.float64`    | ``1e-7``   | ``1e-7`` |
+---------------------------+------------+----------+
| :attr:`~torch.complex32`  | ``1e-3``   | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.complex64`  | ``1.3e-6`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.complex128` | ``1e-7``   | ``1e-7`` |
+---------------------------+------------+----------+
| :attr:`~torch.quint8`     | ``1.3e-6`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.quint2x4`   | ``1.3e-6`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.quint4x2`   | ``1.3e-6`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.qint8`      | ``1.3e-6`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.qint32`     | ``1.3e-6`` | ``1e-5`` |
+---------------------------+------------+----------+
| other                     | ``0.0``    | ``0.0``  |
+---------------------------+------------+----------+

.. note::

    :func:`~torch.testing.assert_close` is highly configurable with strict default settings. Users are encouraged
    to :func:`~functools.partial` it to fit their use case. For example, if an equality check is needed, one might
    define an ``assert_equal`` that uses zero tolerances for every ``dtype`` by default:

    >>> import functools
    >>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
    >>> assert_equal(1e-9, 1e-10)
    Traceback (most recent call last):
    ...
    AssertionError: Scalars are not equal!
    <BLANKLINE>
    Expected 1e-10 but got 1e-09.
    Absolute difference: 9.000000000000001e-10
    Relative difference: 9.0

Examples:
    >>> # tensor to tensor comparison
    >>> expected = torch.tensor([1e0, 1e-1, 1e-2])
    >>> actual = torch.acos(torch.cos(expected))
    >>> torch.testing.assert_close(actual, expected)

    >>> # scalar to scalar comparison
    >>> import math
    >>> expected = math.sqrt(2.0)
    >>> actual = 2.0 / math.sqrt(2.0)
    >>> torch.testing.assert_close(actual, expected)

    >>> # numpy array to numpy array comparison
    >>> import numpy as np
    >>> expected = np.array([1e0, 1e-1, 1e-2])
    >>> actual = np.arccos(np.cos(expected))
    >>> torch.testing.assert_close(actual, expected)

    >>> # sequence to sequence comparison
    >>> import numpy as np
    >>> # The types of the sequences do not have to match. They only have to have the same
    >>> # length and their elements have to match.
    >>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)]
    >>> actual = tuple(expected)
    >>> torch.testing.assert_close(actual, expected)

    >>> # mapping to mapping comparison
    >>> from collections import OrderedDict
    >>> import numpy as np
    >>> foo = torch.tensor(1.0)
    >>> bar = 2.0
    >>> baz = np.array(3.0)
    >>> # The types and a possible ordering of mappings do not have to match. They only
    >>> # have to have the same set of keys and their elements have to match.
    >>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)])
    >>> actual = {"baz": baz, "bar": bar, "foo": foo}
    >>> torch.testing.assert_close(actual, expected)

    >>> expected = torch.tensor([1.0, 2.0, 3.0])
    >>> actual = expected.clone()
    >>> # By default, directly related instances can be compared
    >>> torch.testing.assert_close(torch.nn.Parameter(actual), expected)
    >>> # This check can be made more strict with allow_subclasses=False
    >>> torch.testing.assert_close(
    ...     torch.nn.Parameter(actual), expected, allow_subclasses=False
    ... )
    Traceback (most recent call last):
    ...
    TypeError: No comparison pair was able to handle inputs of type
    <class 'torch.nn.parameter.Parameter'> and <class 'torch.Tensor'>.
    >>> # If the inputs are not directly related, they are never considered close
    >>> torch.testing.assert_close(actual.numpy(), expected)
    Traceback (most recent call last):
    ...
    TypeError: No comparison pair was able to handle inputs of type <class 'numpy.ndarray'>
    and <class 'torch.Tensor'>.
    >>> # Exceptions to these rules are Python scalars. They can be checked regardless of
    >>> # their type if check_dtype=False.
    >>> torch.testing.assert_close(1.0, 1, check_dtype=False)

    >>> # NaN != NaN by default.
    >>> expected = torch.tensor(float("Nan"))
    >>> actual = expected.clone()
    >>> torch.testing.assert_close(actual, expected)
    Traceback (most recent call last):
    ...
    AssertionError: Scalars are not close!
    <BLANKLINE>
    Expected nan but got nan.
    Absolute difference: nan (up to 1e-05 allowed)
    Relative difference: nan (up to 1.3e-06 allowed)
    >>> torch.testing.assert_close(actual, expected, equal_nan=True)

    >>> expected = torch.tensor([1.0, 2.0, 3.0])
    >>> actual = torch.tensor([1.0, 4.0, 5.0])
    >>> # The default error message can be overwritten.
    >>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!")
    Traceback (most recent call last):
    ...
    AssertionError: Argh, the tensors are not close!
    >>> # If msg is a callable, it can be used to augment the generated message with
    >>> # extra information
    >>> torch.testing.assert_close(
    ...     actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter"
    ... )
    Traceback (most recent call last):
    ...
    AssertionError: Header
    <BLANKLINE>
    Tensor-likes are not close!
    <BLANKLINE>
    Mismatched elements: 2 / 3 (66.7%)
    Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed)
    Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed)
    <BLANKLINE>
    Footer
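
The default-tolerance table and the "maximum of both tolerances" rule for mismatching ``dtype``'s can be illustrated with a small plain-Python sketch. This is not the library's implementation: the names ``DEFAULT_TOLERANCES``, ``default_tolerances`` and ``is_close`` are hypothetical, dtypes are represented by plain strings, and the closeness rule ``|actual - expected| <= atol + rtol * |expected|`` is assumed here, since the definition this section refers to lies outside it.

```python
import math

# (rtol, atol) pairs copied from the table above, keyed by dtype name.
# Hypothetical lookup structure, used only for illustration.
DEFAULT_TOLERANCES = {
    "float16": (1e-3, 1e-5),
    "bfloat16": (1.6e-2, 1e-5),
    "float32": (1.3e-6, 1e-5),
    "float64": (1e-7, 1e-7),
    "complex32": (1e-3, 1e-5),
    "complex64": (1.3e-6, 1e-5),
    "complex128": (1e-7, 1e-7),
    "quint8": (1.3e-6, 1e-5),
    "quint2x4": (1.3e-6, 1e-5),
    "quint4x2": (1.3e-6, 1e-5),
    "qint8": (1.3e-6, 1e-5),
    "qint32": (1.3e-6, 1e-5),
}


def default_tolerances(*dtype_names):
    """Return (rtol, atol), taking the maximum of each over all given dtypes.

    Dtypes missing from the table fall under "other" and contribute (0.0, 0.0).
    """
    pairs = [DEFAULT_TOLERANCES.get(name, (0.0, 0.0)) for name in dtype_names]
    rtols, atols = zip(*pairs)
    return max(rtols), max(atols)


def is_close(actual, expected, rtol, atol, equal_nan=False):
    """Scalar closeness check under the assumed rule stated in the lead-in."""
    if math.isnan(actual) or math.isnan(expected):
        # NaNs only match each other, and only when equal_nan is set.
        return equal_nan and math.isnan(actual) and math.isnan(expected)
    if actual == expected:
        # Covers exact matches, including infinities of the same sign.
        return True
    return abs(actual - expected) <= atol + rtol * abs(expected)


# Mismatching dtypes: the larger rtol and the larger atol are used.
print(default_tolerances("float16", "float32"))  # (0.001, 1e-05)

# With zero tolerances, 1e-9 and 1e-10 are not equal, as in the assert_equal note.
print(is_close(1e-9, 1e-10, rtol=0.0, atol=0.0))  # False
```

Under this sketch the same pair of scalars passes with the ``float64`` defaults, because the absolute difference of roughly ``9e-10`` is far below ``atol=1e-7``.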