diff --git a/extension/kernel_util/README.md b/extension/kernel_util/README.md new file mode 100644 index 00000000000..a3a1e653bdb --- /dev/null +++ b/extension/kernel_util/README.md @@ -0,0 +1,23 @@ +This header file `make_boxed_from_unboxed_functor.h` defines a template that can be used to create a boxed version of an unboxed functor. It is part of the executorch extension in the torch namespace. +## Requirements +This header requires C++17 or later. +## Usage +The template takes an unboxed function pointer and wraps it into a functor that takes `RuntimeContext` and `EValues` as inputs and returns void. The wrapped functor will unbox all inputs and forward them to the unboxed kernel. +Here is an example of how to use the template: +```C++ +Tensor& my_op(RuntimeContext& ctx, const Tensor& self, const Tensor& other, Tensor& out) { + // ... + return out; +} +Kernel my_kernel = Kernel::make_boxed_kernel("my_ns::my_op", EXECUTORCH_FN(my_op)); +static auto res = register_kernels({my_kernel}); +``` +Alternatively, you can use the EXECUTORCH_LIBRARY macro to simplify the process: +```C++ +EXECUTORCH_LIBRARY(my_ns, "my_op", my_op); +``` +## Details +The template uses a lot of C++17 features to convert each EValue to the inferred argument type. It checks if the first argument is `RuntimeContext`, and if so, it removes it. The call method of the `WrapUnboxedIntoFunctor` struct calls the unboxed function with the corresponding arguments. +The `EXECUTORCH_LIBRARY` macro registers the kernel for the operation and stores the result in a static variable. +## Note +The `RuntimeContext` is a placeholder for a context that will be passed to kernels. It is currently empty, but it is planned to be used for kernel temp memory allocation and error handling in the future. diff --git a/extension/kernel_util/TARGETS b/extension/kernel_util/TARGETS new file mode 100644 index 00000000000..2341af9282f --- /dev/null +++ b/extension/kernel_util/TARGETS @@ -0,0 +1,8 @@ +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. + +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/extension/kernel_util/make_boxed_from_unboxed_functor.h b/extension/kernel_util/make_boxed_from_unboxed_functor.h new file mode 100644 index 00000000000..fa69ed944a7 --- /dev/null +++ b/extension/kernel_util/make_boxed_from_unboxed_functor.h @@ -0,0 +1,145 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +//===----------------------------------------------------------------------===// +/// \file extension/kernel_util/make_boxed_from_unboxed_functor.h +/// Defines a template that can be used to create a boxed version of an unboxed +/// functor. +/// Example usage: +/// ``` +/// Tensor& +/// my_op(RuntimeContext& ctx, const Tensor& self, const Tensor& other, Tensor& +/// out) { +/// // ... +/// return out; +/// } +/// +/// Kernel my_kernel = Kernel::make_boxed_kernel("my_ns::my_op", +/// EXECUTORCH_FN(my_op)); +/// static auto res = register_kernels({my_kernel}); +/// ``` +/// Or simply: +/// ``` +/// EXECUTORCH_LIBRARY(my_ns, "my_op", my_op); +/// ``` +/// +/// The trick here is to convert each EValue to inferred argument type. This +/// uses a lot of C++17 features. +//===----------------------------------------------------------------------===// + +#pragma once +#if __cplusplus < 201703L +#error "This header requires C++17" +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace executor { + +class KernelRuntimeContext; // Forward declaration +using RuntimeContext = KernelRuntimeContext; // TODO(T147221312): Remove + +// evalue_to_arg +template +struct decay_if_not_tensor final { + using type = std::decay_t; +}; +template <> +struct decay_if_not_tensor final { + using type = exec_aten::Tensor&; +}; +template <> +struct decay_if_not_tensor final { + using type = const exec_aten::Tensor&; +}; + +template +struct evalue_to_arg final { + static T call(EValue& v) { + return std::move(v).to(); + } +}; + +template <> +struct evalue_to_arg final { + static exec_aten::Tensor& call(EValue& v) { + return v.toTensor(); + } +}; + +template <> +struct evalue_to_arg final { + static const exec_aten::Tensor& call(EValue& v) { + return v.toTensor(); + } +}; +// Call functor with args from stack + +template +void call_functor_with_args_from_stack_( + RuntimeContext& ctx, + EValue** stack, + std::index_sequence, + typelist*) { + (*Functor::func_ptr())( + ctx, + evalue_to_arg::type>::call( + *stack[evalue_arg_indices])...); +} + +/** + * WrapUnboxedIntoFunctor: Given a function pointer, wrap it into a functor that + * takes EValues as input and returns void. The wrapped functor will unbox all + * inputs and forward them to unboxed kernel. + */ +template +struct WrapUnboxedIntoFunctor { + static_assert( + is_compile_time_function_pointer::value, + "Can't handle function other than EXECUTORCH_FN"); + using TrueType = typename FuncType::FuncType; + using ReturnType = typename infer_function_traits_t::return_type; + using ArgsType = typename infer_function_traits_t::parameter_types; + // check if the first argument is RuntimeContext, if so, remove it + static constexpr bool first_arg_is_context = std::is_same< + RuntimeContext, + std::remove_reference_t>>::value; + using ContextRemovedArgsType = std::conditional_t< + first_arg_is_context, + drop_if_nonempty_t, + ArgsType>; + + static void call(RuntimeContext& ctx, EValue** stack) { + constexpr size_t num_inputs = size::value; + return call_functor_with_args_from_stack_( + ctx, + stack, + std::make_index_sequence(), + static_cast(nullptr)); + } +}; + +template +static Kernel make_boxed_kernel(const char* name, FuncType) { + return Kernel(name, WrapUnboxedIntoFunctor::call); +} + +#define EXECUTORCH_LIBRARY(ns, op_name, func) \ + static auto res_##ns = register_kernels( \ + make_boxed_kernel(#ns "::" op_name, EXECUTORCH_FN(func))) +} // namespace executor +} // namespace torch diff --git a/extension/kernel_util/meta_programming.h b/extension/kernel_util/meta_programming.h new file mode 100644 index 00000000000..46262b843ea --- /dev/null +++ b/extension/kernel_util/meta_programming.h @@ -0,0 +1,115 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#if __cplusplus < 201703L +#error "This header requires C++17" +#endif + +#include +#include +#include +#include +#include + +namespace torch { +namespace executor { + +// Check if a given type is a function +template +struct is_function_type : std::false_type {}; +template +struct is_function_type : std::true_type {}; +template +using is_function_type_t = typename is_function_type::type; + +// A compile-time wrapper around a function pointer +template +struct CompileTimeFunctionPointer final { + static_assert( + is_function_type::value, + "EXECUTORCH_FN can only wrap function types."); + using FuncType = FuncType_; + + static constexpr FuncType* func_ptr() { + return func_ptr_; + } +}; + +// Check if a given type is a compile-time function pointer +template +struct is_compile_time_function_pointer : std::false_type {}; +template +struct is_compile_time_function_pointer< + CompileTimeFunctionPointer> : std::true_type {}; + +#define EXECUTORCH_FN_TYPE(func) \ + CompileTimeFunctionPointer< \ + std::remove_pointer_t>, \ + func> +#define EXECUTORCH_FN(func) EXECUTORCH_FN_TYPE(func)() + +/** + * strip_class: helper to remove the class type from pointers to `operator()`. + */ +template +struct strip_class {}; +template +struct strip_class { + using type = Result(Args...); +}; +template +struct strip_class { + using type = Result(Args...); +}; +template +using strip_class_t = typename strip_class::type; + +/** + * Access information about result type or arguments from a function type. + * Example: + * using A = function_traits::return_type // A == int + * using A = function_traits::parameter_types::tuple_type + * // A == tuple + */ +template +struct function_traits { + static_assert( + !std::is_same::value, + "In function_traits, Func must be a plain function type."); +}; +template +struct function_traits { + using func_type = Result(Args...); + using return_type = Result; + using parameter_types = typelist; + static constexpr auto number_of_parameters = sizeof...(Args); +}; + +/** + * infer_function_traits: creates a `function_traits` type for a simple + * function (pointer) or functor (lambda/struct). Currently does not support + * class methods. + */ +template +struct infer_function_traits { + using type = function_traits>; +}; +template +struct infer_function_traits { + using type = function_traits; +}; +template +struct infer_function_traits { + using type = function_traits; +}; +template +using infer_function_traits_t = typename infer_function_traits::type; + +} // namespace executor +} // namespace torch diff --git a/extension/kernel_util/targets.bzl b/extension/kernel_util/targets.bzl new file mode 100644 index 00000000000..81d4da10d15 --- /dev/null +++ b/extension/kernel_util/targets.bzl @@ -0,0 +1,29 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + """Defines targets that should be shared between fbcode and xplat. + + The directory containing this targets.bzl file should also contain both + TARGETS and BUCK files that call this function. + """ + + runtime.cxx_library( + name = "kernel_util", + srcs = [], + exported_headers = [ + "make_boxed_from_unboxed_functor.h", + "meta_programming.h", + "type_list.h", + ], + visibility = [ + "//executorch/...", + "@EXECUTORCH_CLIENTS", + ], + exported_deps = [ + "//executorch/runtime/core:core", + "//executorch/runtime/core:evalue", + "//executorch/runtime/kernel:kernel_includes", + "//executorch/runtime/kernel:kernel_runtime_context", + "//executorch/runtime/kernel:operator_registry", + ], + ) diff --git a/extension/kernel_util/test/TARGETS b/extension/kernel_util/test/TARGETS new file mode 100644 index 00000000000..2341af9282f --- /dev/null +++ b/extension/kernel_util/test/TARGETS @@ -0,0 +1,8 @@ +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. + +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/extension/kernel_util/test/make_boxed_from_unboxed_functor_test.cpp b/extension/kernel_util/test/make_boxed_from_unboxed_functor_test.cpp new file mode 100644 index 00000000000..8bc534ca329 --- /dev/null +++ b/extension/kernel_util/test/make_boxed_from_unboxed_functor_test.cpp @@ -0,0 +1,69 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +using namespace ::testing; +using RuntimeContext = torch::executor::KernelRuntimeContext; +using namespace torch::executor; + +Tensor& my_op_out(RuntimeContext& ctx, const Tensor& a, Tensor& out) { + (void)ctx; + (void)a; + return out; +} + +Tensor& set_1_out(RuntimeContext& ctx, Tensor& out) { + (void)ctx; + out.mutable_data_ptr()[0] = 1; + return out; +} + +class MakeBoxedFromUnboxedFunctorTest : public ::testing::Test { + public: + void SetUp() override { + torch::executor::runtime_init(); + } +}; + +TEST_F(MakeBoxedFromUnboxedFunctorTest, Basic) { + EXECUTORCH_LIBRARY(my_ns, "my_op.out", my_op_out); + EXPECT_TRUE(hasOpsFn("my_ns::my_op.out")); +} + +TEST_F(MakeBoxedFromUnboxedFunctorTest, UnboxLogicWorks) { + EXECUTORCH_LIBRARY(my_ns, "set_1.out", set_1_out); + EXPECT_TRUE(hasOpsFn("my_ns::set_1.out")); + + // prepare out tensor + TensorImpl::SizesType sizes[1] = {5}; + TensorImpl::DimOrderType dim_order[1] = {0}; + int32_t data[5] = {0, 0, 0, 0, 0}; + auto a_impl = TensorImpl(ScalarType::Int, 1, sizes, data, dim_order, nullptr); + auto a = Tensor(&a_impl); + + // get boxed callable + auto fn = getOpsFn("my_ns::set_1.out"); + + // run it + RuntimeContext context; + EValue values[1]; + values[0] = a; + EValue* stack[1]; + stack[0] = &values[0]; + + fn(context, stack); + + // check result + EXPECT_EQ(a.const_data_ptr()[0], 1); +} diff --git a/extension/kernel_util/test/targets.bzl b/extension/kernel_util/test/targets.bzl new file mode 100644 index 00000000000..122d4392c68 --- /dev/null +++ b/extension/kernel_util/test/targets.bzl @@ -0,0 +1,17 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + """Defines targets that should be shared between fbcode and xplat. + + The directory containing this targets.bzl file should also contain both + TARGETS and BUCK files that call this function. + """ + runtime.cxx_test( + name = "make_boxed_from_unboxed_functor_test", + srcs = [ + "make_boxed_from_unboxed_functor_test.cpp", + ], + deps = [ + "//executorch/extension/kernel_util:kernel_util", + ], + ) diff --git a/extension/kernel_util/type_list.h b/extension/kernel_util/type_list.h new file mode 100644 index 00000000000..f832ab9f267 --- /dev/null +++ b/extension/kernel_util/type_list.h @@ -0,0 +1,143 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/// +/// \file runtime/kernel/type_list.h +/// Forked from pytorch/c10/util/TypeList.h +/// \brief Utilities for working with type lists. +#pragma once +#if __cplusplus < 201703L +#error "This header requires C++17" +#endif + +#include +#include +#include +#include + +namespace torch { +namespace executor { +/** + * Type holding a list of types for compile time type computations + * constexpr size_t num = size>::value; + * static_assert(num == 2, ""); + */ +template +struct false_t : std::false_type {}; + +template +struct typelist final { + public: + typelist() = delete; // not for instantiation +}; +template +struct size final { + static_assert( + false_t::value, + "In typelist::size, T must be typelist<...>."); +}; +template +struct size> final { + static constexpr size_t value = sizeof...(Types); +}; + +/** + * is_instantiation_of is true_type iff I is a template instantiation of T + * (e.g. vector is an instantiation of vector) Example: + * is_instantiation_of_t> // true + * is_instantiation_of_t> // true + * is_instantiation_of_t> // false + */ +template