#define TORCH_ASSERT_ONLY_METHOD_OPERATORS // ${generated_comment} #include #include #include #include #include #ifndef AT_PER_OPERATOR_HEADERS #include #include #else // needed for the meta tensor calls to get stride info in functionalization #include // needed for special handling of copy_(). // See Note [functionalizating copy_() and not preserving strides] #include #include $ops_headers #endif namespace at { namespace functionalization { // This keyset is used by functionalization when it calls into meta kernels // to accurately propagate stride metadata. // Exclude any modes: the purpose of calling into meta kernels is only as an implementation // detail to perform shape inference, and we don't want any modal keys to run. // Specifically, we want to prevent functionalization and Python modes from running. constexpr auto exclude_keys_for_meta_dispatch = c10::functorch_transforms_ks | c10::DispatchKeySet({ c10::DispatchKey::FuncTorchDynamicLayerBackMode, c10::DispatchKey::FuncTorchDynamicLayerFrontMode, c10::DispatchKey::Python }); inline Tensor to_meta(const Tensor& t) { if (!t.defined()) return t; return at::native::empty_strided_meta_symint(t.sym_sizes(), t.sym_strides(), /*dtype=*/c10::make_optional(t.scalar_type()), /*layout=*/c10::make_optional(t.layout()), /*device=*/c10::make_optional(c10::Device(kMeta)), /*pin_memory=*/c10::nullopt); } inline c10::optional to_meta(const c10::optional& t) { if (t.has_value()) { return c10::make_optional(to_meta(*t)); } return c10::nullopt; } inline std::vector to_meta(at::ITensorListRef t_list) { std::vector outputs; outputs.reserve(t_list.size()); for (const auto& tensor : t_list) { outputs.push_back(to_meta(tensor)); } return outputs; } inline c10::List to_meta(const c10::List& t_list) { c10::List outputs; outputs.reserve(t_list.size()); for (const auto i : c10::irange(t_list.size())) { outputs.push_back(to_meta(t_list[i])); } return outputs; } inline c10::List> to_meta(const c10::List>& t_list) { c10::List> outputs; outputs.reserve(t_list.size()); for (const auto i : c10::irange(t_list.size())) { outputs.push_back(to_meta(t_list[i])); } return outputs; } ${func_definitions} } // namespace functionalization namespace { TORCH_LIBRARY_IMPL(aten, Functionalize, m) { ${func_registrations}; } } // namespace } // namespace at