diff --git a/extension/pybindings/cpp_extension.py b/extension/pybindings/cpp_extension.py index 60498ce8136..ea6c4eb515f 100644 --- a/extension/pybindings/cpp_extension.py +++ b/extension/pybindings/cpp_extension.py @@ -1,24 +1,32 @@ -from torch.utils import cpp_extension import os +from torch.utils import cpp_extension + _HERE = os.path.abspath(__file__) _EXECUTORCH_PATH = os.path.dirname(os.path.dirname(_HERE)) -def load_inline(name, - cpp_sources, - functions=None, - extra_cflags=None, - extra_ldflags=None, - extra_include_paths=None, - build_directory=None, - verbose=False, - is_python_module=True, - with_pytorch_error_handling=True, - keep_intermediates=True, - use_pch=False): + +def load_inline( + name, + cpp_sources, + functions=None, + extra_cflags=None, + extra_ldflags=None, + extra_include_paths=None, + build_directory=None, + verbose=False, + is_python_module=True, + with_pytorch_error_handling=True, + keep_intermediates=True, + use_pch=False, +): # Register the code into PyTorch aten_extra_cflags = ["-DUSE_ATEN_LIB"] + (extra_cflags if extra_cflags else []) - extra_ldflags = [f"-L{_EXECUTORCH_PATH}", f"-Wl,-rpath,{_EXECUTORCH_PATH}", "-lexecutorch"] + (extra_ldflags if extra_ldflags else []) + extra_ldflags = [ + f"-L{_EXECUTORCH_PATH}", + f"-Wl,-rpath,{_EXECUTORCH_PATH}", + "-lexecutorch", + ] + (extra_ldflags if extra_ldflags else []) module = cpp_extension.load_inline( name, cpp_sources, @@ -37,13 +45,13 @@ def load_inline(name, cpp_extension.load_inline( name, cpp_sources, - functions=None, # leave this out since we are not passing out any python module + functions=None, # leave this out since we are not passing out any python module extra_cflags=extra_cflags, extra_ldflags=extra_ldflags, extra_include_paths=extra_include_paths, build_directory=build_directory, verbose=verbose, - is_python_module=False, # don't register as a python module. Load shared library as a side effect. + is_python_module=False, # don't register as a python module. Load shared library as a side effect. with_pytorch_error_handling=with_pytorch_error_handling, keep_intermediates=keep_intermediates, use_pch=use_pch, diff --git a/runtime/kernel/make_boxed_from_unboxed_functor.h b/runtime/kernel/make_boxed_from_unboxed_functor.h index a024baad31f..7f261d6cfa7 100644 --- a/runtime/kernel/make_boxed_from_unboxed_functor.h +++ b/runtime/kernel/make_boxed_from_unboxed_functor.h @@ -19,8 +19,9 @@ /// return out; /// } /// -/// Kernel my_kernel = Kernel.make_boxed_kernel("my_ns::my_op", -/// EXECUTORCH_FN(my_op)); register_kernels({my_kernel}); +/// Kernel my_kernel = Kernel::make_boxed_kernel("my_ns::my_op", +/// EXECUTORCH_FN(my_op)); +/// static auto res = register_kernels({my_kernel}); /// ``` /// /// The trick here is to convert each EValue to inferred argument type. This @@ -34,11 +35,7 @@ #include #include -#include -#include -#include -#include -#include +#include namespace torch { namespace executor { @@ -46,97 +43,6 @@ namespace executor { class KernelRuntimeContext; // Forward declaration using RuntimeContext = KernelRuntimeContext; // TODO(T147221312): Remove -// Check if a given type is a function -template -struct is_function_type : std::false_type {}; -template -struct is_function_type : std::true_type {}; -template -using is_function_type_t = typename is_function_type::type; - -// A compile-time wrapper around a function pointer -template -struct CompileTimeFunctionPointer final { - static_assert( - is_function_type::value, - "EXECUTORCH_FN can only wrap function types."); - using FuncType = FuncType_; - - static constexpr FuncType* func_ptr() { - return func_ptr_; - } -}; - -// Check if a given type is a compile-time function pointer -template -struct is_compile_time_function_pointer : std::false_type {}; -template -struct is_compile_time_function_pointer< - CompileTimeFunctionPointer> : std::true_type {}; - -#define EXECUTORCH_FN_TYPE(func) \ - CompileTimeFunctionPointer< \ - std::remove_pointer_t>, \ - func> -#define EXECUTORCH_FN(func) EXECUTORCH_FN_TYPE(func)() - -/** - * strip_class: helper to remove the class type from pointers to `operator()`. - */ -template -struct strip_class {}; -template -struct strip_class { - using type = Result(Args...); -}; -template -struct strip_class { - using type = Result(Args...); -}; -template -using strip_class_t = typename strip_class::type; - -/** - * Access information about result type or arguments from a function type. - * Example: - * using A = function_traits::return_type // A == int - * using A = function_traits::parameter_types::tuple_type - * // A == tuple - */ -template -struct function_traits { - static_assert( - !std::is_same::value, - "In function_traits, Func must be a plain function type."); -}; -template -struct function_traits { - using func_type = Result(Args...); - using return_type = Result; - using parameter_types = typelist; - static constexpr auto number_of_parameters = sizeof...(Args); -}; - -/** - * infer_function_traits: creates a `function_traits` type for a simple - * function (pointer) or functor (lambda/struct). Currently does not support - * class methods. - */ -template -struct infer_function_traits { - using type = function_traits>; -}; -template -struct infer_function_traits { - using type = function_traits; -}; -template -struct infer_function_traits { - using type = function_traits; -}; -template -using infer_function_traits_t = typename infer_function_traits::type; - // evalue_to_arg template struct decay_if_not_tensor final { diff --git a/runtime/kernel/meta_programming.h b/runtime/kernel/meta_programming.h new file mode 100644 index 00000000000..b651176d9c8 --- /dev/null +++ b/runtime/kernel/meta_programming.h @@ -0,0 +1,114 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once +#if __cplusplus < 201703L +#error "This header requires C++17" +#endif + +#include +#include +#include +#include +#include + +namespace torch { +namespace executor { + +// Check if a given type is a function +template +struct is_function_type : std::false_type {}; +template +struct is_function_type : std::true_type {}; +template +using is_function_type_t = typename is_function_type::type; + +// A compile-time wrapper around a function pointer +template +struct CompileTimeFunctionPointer final { + static_assert( + is_function_type::value, + "EXECUTORCH_FN can only wrap function types."); + using FuncType = FuncType_; + + static constexpr FuncType* func_ptr() { + return func_ptr_; + } +}; + +// Check if a given type is a compile-time function pointer +template +struct is_compile_time_function_pointer : std::false_type {}; +template +struct is_compile_time_function_pointer< + CompileTimeFunctionPointer> : std::true_type {}; + +#define EXECUTORCH_FN_TYPE(func) \ + CompileTimeFunctionPointer< \ + std::remove_pointer_t>, \ + func> +#define EXECUTORCH_FN(func) EXECUTORCH_FN_TYPE(func)() + +/** + * strip_class: helper to remove the class type from pointers to `operator()`. + */ +template +struct strip_class {}; +template +struct strip_class { + using type = Result(Args...); +}; +template +struct strip_class { + using type = Result(Args...); +}; +template +using strip_class_t = typename strip_class::type; + +/** + * Access information about result type or arguments from a function type. + * Example: + * using A = function_traits::return_type // A == int + * using A = function_traits::parameter_types::tuple_type + * // A == tuple + */ +template +struct function_traits { + static_assert( + !std::is_same::value, + "In function_traits, Func must be a plain function type."); +}; +template +struct function_traits { + using func_type = Result(Args...); + using return_type = Result; + using parameter_types = typelist; + static constexpr auto number_of_parameters = sizeof...(Args); +}; + +/** + * infer_function_traits: creates a `function_traits` type for a simple + * function (pointer) or functor (lambda/struct). Currently does not support + * class methods. + */ +template +struct infer_function_traits { + using type = function_traits>; +}; +template +struct infer_function_traits { + using type = function_traits; +}; +template +struct infer_function_traits { + using type = function_traits; +}; +template +using infer_function_traits_t = typename infer_function_traits::type; + +} // namespace executor +} // namespace torch