From 6ba7cbe075bb3df96706c40efb3760509a1d8a04 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Thu, 23 Nov 2023 11:13:06 -0800 Subject: [PATCH 1/3] [kernel] Add template based unboxing Adding a new feature to allow users to bypass codegen and register their kernels directly. This is very useful for custom kernels for custom ops. Example usage: ``` Tensor& my_op(RuntimeContext& ctx, const Tensor& self, const Tensor& other, Tensor& out) { // ... return out; } Kernel my_kernel = Kernel.make_boxed_kernel("my_ns::my_op",EXECUTORCH_FN(my_op)); register_kernels({my_kernel}); ``` [ghstack-poisoned] --- .../kernel/make_boxed_from_unboxed_functor.h | 221 ++++++++++++++++++ runtime/kernel/operator_registry.h | 10 + runtime/kernel/targets.bzl | 6 +- .../make_boxed_from_unboxed_functor_test.cpp | 77 ++++++ runtime/kernel/test/targets.bzl | 12 + runtime/kernel/type_list.h | 142 +++++++++++ 6 files changed, 467 insertions(+), 1 deletion(-) create mode 100644 runtime/kernel/make_boxed_from_unboxed_functor.h create mode 100644 runtime/kernel/test/make_boxed_from_unboxed_functor_test.cpp create mode 100644 runtime/kernel/type_list.h diff --git a/runtime/kernel/make_boxed_from_unboxed_functor.h b/runtime/kernel/make_boxed_from_unboxed_functor.h new file mode 100644 index 00000000000..a024baad31f --- /dev/null +++ b/runtime/kernel/make_boxed_from_unboxed_functor.h @@ -0,0 +1,221 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +//===----------------------------------------------------------------------===// +/// \file runtime/kernel/make_boxed_from_unboxed_functor.h +/// Defines a template that can be used to create a boxed version of an unboxed +/// functor. +/// Example usage: +/// ``` +/// Tensor& +/// my_op(RuntimeContext& ctx, const Tensor& self, const Tensor& other, Tensor& +/// out) { +/// // ... +/// return out; +/// } +/// +/// Kernel my_kernel = Kernel.make_boxed_kernel("my_ns::my_op", +/// EXECUTORCH_FN(my_op)); register_kernels({my_kernel}); +/// ``` +/// +/// The trick here is to convert each EValue to inferred argument type. This +/// uses a lot of C++17 features. +//===----------------------------------------------------------------------===// + +#pragma once +#if __cplusplus < 201703L +#error "This header requires C++17" +#endif + +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace executor { + +class KernelRuntimeContext; // Forward declaration +using RuntimeContext = KernelRuntimeContext; // TODO(T147221312): Remove + +// Check if a given type is a function +template +struct is_function_type : std::false_type {}; +template +struct is_function_type : std::true_type {}; +template +using is_function_type_t = typename is_function_type::type; + +// A compile-time wrapper around a function pointer +template +struct CompileTimeFunctionPointer final { + static_assert( + is_function_type::value, + "EXECUTORCH_FN can only wrap function types."); + using FuncType = FuncType_; + + static constexpr FuncType* func_ptr() { + return func_ptr_; + } +}; + +// Check if a given type is a compile-time function pointer +template +struct is_compile_time_function_pointer : std::false_type {}; +template +struct is_compile_time_function_pointer< + CompileTimeFunctionPointer> : std::true_type {}; + +#define EXECUTORCH_FN_TYPE(func) \ + CompileTimeFunctionPointer< \ + std::remove_pointer_t>, \ + func> +#define EXECUTORCH_FN(func) EXECUTORCH_FN_TYPE(func)() + +/** + * strip_class: helper to remove the class type from pointers to `operator()`. + */ +template +struct strip_class {}; +template +struct strip_class { + using type = Result(Args...); +}; +template +struct strip_class { + using type = Result(Args...); +}; +template +using strip_class_t = typename strip_class::type; + +/** + * Access information about result type or arguments from a function type. + * Example: + * using A = function_traits::return_type // A == int + * using A = function_traits::parameter_types::tuple_type + * // A == tuple + */ +template +struct function_traits { + static_assert( + !std::is_same::value, + "In function_traits, Func must be a plain function type."); +}; +template +struct function_traits { + using func_type = Result(Args...); + using return_type = Result; + using parameter_types = typelist; + static constexpr auto number_of_parameters = sizeof...(Args); +}; + +/** + * infer_function_traits: creates a `function_traits` type for a simple + * function (pointer) or functor (lambda/struct). Currently does not support + * class methods. + */ +template +struct infer_function_traits { + using type = function_traits>; +}; +template +struct infer_function_traits { + using type = function_traits; +}; +template +struct infer_function_traits { + using type = function_traits; +}; +template +using infer_function_traits_t = typename infer_function_traits::type; + +// evalue_to_arg +template +struct decay_if_not_tensor final { + using type = std::decay_t; +}; +template <> +struct decay_if_not_tensor final { + using type = exec_aten::Tensor&; +}; +template <> +struct decay_if_not_tensor final { + using type = const exec_aten::Tensor&; +}; + +template +struct evalue_to_arg final { + static T call(EValue& v) { + return std::move(v).to(); + } +}; + +template <> +struct evalue_to_arg final { + static exec_aten::Tensor& call(EValue& v) { + return v.toTensor(); + } +}; + +template <> +struct evalue_to_arg final { + static const exec_aten::Tensor& call(EValue& v) { + return v.toTensor(); + } +}; +// Call functor with args from stack + +template +void call_functor_with_args_from_stack_( + RuntimeContext& ctx, + EValue** stack, + std::index_sequence, + typelist*) { + (*Functor::func_ptr())( + ctx, + evalue_to_arg::type>::call( + *stack[evalue_arg_indices])...); +} + +/** + * WrapUnboxedIntoFunctor: Given a function pointer, wrap it into a functor that + * takes EValues as input and returns void. The wrapped functor will unbox all + * inputs and forward them to unboxed kernel. + */ +template +struct WrapUnboxedIntoFunctor { + static_assert( + is_compile_time_function_pointer::value, + "Can't handle function other than EXECUTORCH_FN"); + using TrueType = typename FuncType::FuncType; + using ReturnType = typename infer_function_traits_t::return_type; + using ArgsType = typename infer_function_traits_t::parameter_types; + // check if the first argument is RuntimeContext, if so, remove it + static constexpr bool first_arg_is_context = std::is_same< + RuntimeContext, + std::remove_reference_t>>::value; + using ContextRemovedArgsType = std::conditional_t< + first_arg_is_context, + drop_if_nonempty_t, + ArgsType>; + + static void call(RuntimeContext& ctx, EValue** stack) { + constexpr size_t num_inputs = size::value; + return call_functor_with_args_from_stack_( + ctx, + stack, + std::make_index_sequence(), + static_cast(nullptr)); + } +}; + +} // namespace executor +} // namespace torch diff --git a/runtime/kernel/operator_registry.h b/runtime/kernel/operator_registry.h index 55cb4164715..855a1dda7f7 100644 --- a/runtime/kernel/operator_registry.h +++ b/runtime/kernel/operator_registry.h @@ -17,6 +17,9 @@ #include #include #include +#if __cplusplus >= 201703L +#include +#endif // Debug switch for operator registry #if defined(ET_OP_REGISTRY_DEBUG) #include @@ -200,6 +203,13 @@ struct Kernel { explicit Kernel(const char* name, KernelKey key, OpFunction func) : name_(name), kernel_key_(key), op_(func) {} +#if __cplusplus >= 201703L + template + static inline Kernel make_boxed_kernel(const char* name, FuncType) { + return Kernel(name, WrapUnboxedIntoFunctor::call); + } +#endif + Kernel() {} }; diff --git a/runtime/kernel/targets.bzl b/runtime/kernel/targets.bzl index 88a5bd61989..62f051c2ea0 100644 --- a/runtime/kernel/targets.bzl +++ b/runtime/kernel/targets.bzl @@ -11,7 +11,11 @@ def define_common_targets(): runtime.cxx_library( name = "operator_registry", srcs = ["operator_registry.cpp"], - exported_headers = ["operator_registry.h"], + exported_headers = [ + "operator_registry.h", + "make_boxed_from_unboxed_functor.h", + "type_list.h", + ], visibility = [ "//executorch/...", "@EXECUTORCH_CLIENTS", diff --git a/runtime/kernel/test/make_boxed_from_unboxed_functor_test.cpp b/runtime/kernel/test/make_boxed_from_unboxed_functor_test.cpp new file mode 100644 index 00000000000..4dcaca805d9 --- /dev/null +++ b/runtime/kernel/test/make_boxed_from_unboxed_functor_test.cpp @@ -0,0 +1,77 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +using namespace ::testing; +using RuntimeContext = torch::executor::KernelRuntimeContext; +using namespace torch::executor; + +Tensor& my_op_out(RuntimeContext& ctx, const Tensor& a, Tensor& out) { + (void) ctx; + (void) a; + return out; +} + +Tensor& set_1_out(RuntimeContext& ctx, Tensor& out) { + (void) ctx; + out.mutable_data_ptr()[0] = 1; + return out; +} + +class MakeBoxedFromUnboxedFunctorTest : public ::testing::Test { + public: + void SetUp() override { + torch::executor::runtime_init(); + } +}; + +TEST_F(MakeBoxedFromUnboxedFunctorTest, Basic) { + Kernel my_kernel = + Kernel::make_boxed_kernel("my_ns::my_op.out", EXECUTORCH_FN(my_op_out)); + ArrayRef kernels_array = ArrayRef(my_kernel); + // patternlint-disable-next-line clang-diagnostic-unused-variable + auto s1 = register_kernels(kernels_array); + EXPECT_TRUE(hasOpsFn("my_ns::my_op.out")); +} + +TEST_F(MakeBoxedFromUnboxedFunctorTest, UnboxLogicWorks) { + Kernel my_kernel = + Kernel::make_boxed_kernel("my_ns::set_1.out", EXECUTORCH_FN(set_1_out)); + ArrayRef kernels_array = ArrayRef(my_kernel); + // patternlint-disable-next-line clang-diagnostic-unused-variable + auto s1 = register_kernels(kernels_array); + EXPECT_TRUE(hasOpsFn("my_ns::set_1.out")); + + // prepare out tensor + TensorImpl::SizesType sizes[1] = {5}; + TensorImpl::DimOrderType dim_order[1] = {0}; + int32_t data[5] = {0, 0, 0, 0, 0}; + auto a_impl = TensorImpl(ScalarType::Int, 1, sizes, data, dim_order, nullptr); + auto a = Tensor(&a_impl); + + // get boxed callable + auto fn = getOpsFn("my_ns::set_1.out"); + + // run it + RuntimeContext context; + EValue values[1]; + values[0] = a; + EValue* stack[1]; + stack[0] = &values[0]; + + fn(context, stack); + + // check result + EXPECT_EQ(a.const_data_ptr()[0], 1); +} diff --git a/runtime/kernel/test/targets.bzl b/runtime/kernel/test/targets.bzl index e35510a72e5..ee8c136642c 100644 --- a/runtime/kernel/test/targets.bzl +++ b/runtime/kernel/test/targets.bzl @@ -30,6 +30,18 @@ def define_common_targets(): ], ) + runtime.cxx_test( + name = "make_boxed_from_unboxed_functor_test", + srcs = [ + "make_boxed_from_unboxed_functor_test.cpp", + ], + deps = [ + "//executorch/runtime/kernel:operator_registry", + "//executorch/runtime/kernel:kernel_runtime_context", + "//executorch/runtime/core/exec_aten:lib", + ], + ) + et_operator_library( name = "executorch_all_ops", include_all_operators = True, diff --git a/runtime/kernel/type_list.h b/runtime/kernel/type_list.h new file mode 100644 index 00000000000..8f94b40311d --- /dev/null +++ b/runtime/kernel/type_list.h @@ -0,0 +1,142 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/// +/// \file runtime/kernel/type_list.h +/// Forked from pytorch/c10/util/TypeList.h +/// \brief Utilities for working with type lists. +#pragma once +#if __cplusplus < 201703L +#error "This header requires C++17" +#endif + +#include +#include +#include + +namespace torch { +namespace executor { +/** + * Type holding a list of types for compile time type computations + * constexpr size_t num = size>::value; + * static_assert(num == 2, ""); + */ +template +struct false_t : std::false_type {}; + +template +struct typelist final { + public: + typelist() = delete; // not for instantiation +}; +template +struct size final { + static_assert( + false_t::value, + "In typelist::size, T must be typelist<...>."); +}; +template +struct size> final { + static constexpr size_t value = sizeof...(Types); +}; + +/** + * is_instantiation_of is true_type iff I is a template instantiation of T + * (e.g. vector is an instantiation of vector) Example: + * is_instantiation_of_t> // true + * is_instantiation_of_t> // true + * is_instantiation_of_t> // false + */ +template