#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
// ${generated_comment}
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/EmptyTensor.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/FunctionalInverses.h>
#include <ATen/MemoryOverlap.h>
#include <torch/library.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Operators.h>
#include <ATen/NativeFunctions.h>
#else
// needed for the meta tensor calls to get stride info in functionalization
#include <ATen/ops/empty_strided_native.h>
// needed for special handling of copy_().
// See Note [functionalizating copy_() and not preserving strides]
#include <ATen/ops/to_ops.h>
#include <ATen/ops/expand_copy_ops.h>
$ops_headers
#endif
namespace at {
namespace functionalization {
// This keyset is used by functionalization when it calls into meta kernels
// to accurately propagate stride metadata.
// Exclude any modes: the purpose of calling into meta kernels is only as an implementation
// detail to perform shape inference, and we don't want any mode keys to run.
// Specifically, we want to prevent functionalization and Python modes from running.
constexpr auto exclude_keys_for_meta_dispatch =
c10::functorch_transforms_ks |
c10::DispatchKeySet({
c10::DispatchKey::FuncTorchDynamicLayerBackMode,
c10::DispatchKey::FuncTorchDynamicLayerFrontMode,
c10::DispatchKey::Python,
c10::DispatchKey::PreDispatch,
});
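// A minimal sketch of how this keyset is intended to be applied (the generated
// kernels below set up a guard like this before redispatching to a meta kernel):
//   c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
//   // ... call the operator on meta tensors to infer output sizes/strides ...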
// Helper around at::has_internal_overlap.
// The ATen util is used in hot-path eager mode: it's always fast,
// but might return MemOverlap::TooHard sometimes.
// During functionalization, we're ok taking a bit longer
// to detect memory overlap.
inline bool has_internal_overlap_helper(const at::Tensor& t) {
  auto has_overlap = at::has_internal_overlap(t);
  if (has_overlap == at::MemOverlap::Yes) return true;
  if (has_overlap == at::MemOverlap::No) return false;
  // MemOverlap::TooHard: conservatively report no overlap.
  return false;
}
inline Tensor to_meta(const Tensor& t) {
if (!t.defined()) return t;
return at::native::empty_strided_meta_symint(t.sym_sizes(), t.sym_strides(),
/*dtype=*/c10::make_optional(t.scalar_type()), /*layout=*/c10::make_optional(t.layout()),
/*device=*/c10::make_optional(c10::Device(kMeta)), /*pin_memory=*/c10::nullopt);
}
inline std::optional<Tensor> to_meta(const std::optional<Tensor>& t) {
if (t.has_value()) {
return c10::make_optional<Tensor>(to_meta(*t));
}
return c10::nullopt;
}
inline std::vector<Tensor> to_meta(at::ITensorListRef t_list) {
std::vector<Tensor> outputs;
outputs.reserve(t_list.size());
for (const auto& tensor : t_list) {
outputs.push_back(to_meta(tensor));
}
return outputs;
}
inline c10::List<Tensor> to_meta(const c10::List<Tensor>& t_list) {
c10::List<Tensor> outputs;
outputs.reserve(t_list.size());
for (const auto i : c10::irange(t_list.size())) {
outputs.push_back(to_meta(t_list[i]));
}
return outputs;
}
inline c10::List<::std::optional<Tensor>> to_meta(const c10::List<::std::optional<Tensor>>& t_list) {
c10::List<::std::optional<Tensor>> outputs;
outputs.reserve(t_list.size());
for (const auto i : c10::irange(t_list.size())) {
outputs.push_back(to_meta(t_list[i]));
}
return outputs;
}
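// A rough sketch of how the helpers above are combined when a generated kernel
// needs accurate output strides: convert the inputs to meta tensors and redispatch
// to the meta kernel with functionalization/mode keys excluded. Operator and
// variable names here are illustrative, not the exact generated code:
//
//   at::AutoDispatchSkipFunctionalize func_guard;
//   c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
//   auto self_meta = to_meta(self);
//   auto reference_out = at::_ops::some_op::call(self_meta);
//   // reference_out's sym_sizes()/sym_strides() are then used for the real output.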
${func_definitions}
} // namespace functionalization
namespace {
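// The registrations below expand to one m.impl(...) call per functionalized
// operator, binding the kernels defined above to the Functionalize dispatch key.
// Roughly (illustrative, not the exact generated output):
//   m.impl("add.Tensor", TORCH_FN(functionalization::add_Tensor));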
TORCH_LIBRARY_IMPL(aten, Functionalize, m) {
${func_registrations};
}
} // namespace
} // namespace at