#pragma once #include <c10/core/SafePyObject.h> #include <c10/macros/Macros.h> namespace at::impl { enum TorchFunctionDisabledState { ENABLED, SUBCLASSES_DISABLED, ALL_DISABLED }; struct TORCH_API PythonTorchFunctionTLS { static void set_disabled_state(TorchFunctionDisabledState disabled_state_); static TorchFunctionDisabledState get_disabled_state(); static void push_onto_stack(std::shared_ptr<SafePyObject> mode); static const std::shared_ptr<SafePyObject> pop_stack(); static const std::shared_ptr<SafePyObject>& get_stack_at(int64_t idx); static int64_t stack_len(); static const PythonTorchFunctionTLS& get_state(); static void set_state(const PythonTorchFunctionTLS& state); private: // The mode TLS is split into // - disabled_state, which says which part of torch function are disabled // - stack_, which is a vector of modes representing the stack of user // defined modes TorchFunctionDisabledState disabled_state_ = TorchFunctionDisabledState::ENABLED; std::vector<std::shared_ptr<c10::SafePyObject>> stack_; }; TORCH_API bool torch_function_mode_enabled(); } // namespace at::impl
Memory