TM, if user passes ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): Stash and restore the RNG state during
            each checkpoint. Set to ``False`` to omit this step.
            Default: ``True``
        use_reentrant(bool, optional): Use the checkpointing implementation that
            requires re-entrant autograd. If ``use_reentrant=False`` is specified,
            ``checkpoint`` will use an implementation that does not require
            re-entrant autograd. This allows ``checkpoint`` to support additional
            functionality, such as working as expected with
            ``torch.autograd.grad`` and support for keyword arguments passed to
            the checkpointed function. Note that future versions of PyTorch will
            default to ``use_reentrant=False``.
            Default: ``True``
        context_fn(Callable, optional): A callable returning a tuple of two
            context managers. The function and its recomputation will be run
            under the first and second context managers respectively.
            This argument is only supported if ``use_reentrant=False``.
        determinism_check(str, optional): A string specifying the determinism
            check to perform. By default it is set to ``"default"``, which
            compares the shapes, dtypes, and devices of the recomputed tensors
            against those of the saved tensors. To turn off this check, specify
            ``"none"``. Currently these are the only two supported values.
            Please open an issue if you would like to see more determinism
            checks. This argument is only supported if ``use_reentrant=False``;
            if ``use_reentrant=True``, the determinism check is always disabled.
        debug(bool, optional): If ``True``, error messages will also include
            a trace of the operators run during the original forward computation
            as well as the recomputation. This argument is only supported if
            ``use_reentrant=False``.
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`
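The arguments above can be illustrated with a minimal sketch. It assumes a PyTorch version in which ``checkpoint`` accepts ``use_reentrant=False``; the function ``step`` and the tensor shapes are hypothetical and chosen only for illustration. The sketch passes two positional inputs, runs the non-reentrant implementation, and compares gradients obtained through ``torch.autograd.grad`` against an uncheckpointed reference.

```python
import torch
from torch.utils.checkpoint import checkpoint


def step(activation, hidden):
    # Hypothetical stand-in for one recurrent step: the first input is
    # treated as ``activation`` and the second as ``hidden``.
    return torch.tanh(activation + hidden)


torch.manual_seed(0)
activation = torch.randn(4, 8, requires_grad=True)
hidden = torch.randn(4, 8, requires_grad=True)

# Non-reentrant checkpointing: intermediates of ``step`` are not kept and
# are recomputed during the backward pass.
out = checkpoint(
    step,
    activation,
    hidden,
    use_reentrant=False,
    preserve_rng_state=True,  # stash and restore RNG state for recomputation
)

# With use_reentrant=False, torch.autograd.grad is expected to work.
grad_ckpt = torch.autograd.grad(out.sum(), (activation, hidden))

# Reference gradients from the same function without checkpointing.
grad_ref = torch.autograd.grad(step(activation, hidden).sum(), (activation, hidden))

grads_match = all(torch.allclose(g, r) for g, r in zip(grad_ckpt, grad_ref))
```

Because ``step`` here is deterministic and uses no random operations, ``preserve_rng_state`` has no visible effect in this sketch; it matters when the checkpointed function contains operations such as dropout.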