diff --git a/kernels/optimized/cpu/op_where.cpp b/kernels/optimized/cpu/op_where.cpp
new file mode 100644
index 00000000000..4d897ea6281
--- /dev/null
+++ b/kernels/optimized/cpu/op_where.cpp
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+Tensor& opt_where_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& cond,
+    const Tensor& a,
+    const Tensor& b,
+    Tensor& out) {
+  // Common Dtype
+  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
+
+  // Check Common Dtype
+  ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out);
+
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(cond, a, b, out), InvalidArgument, out);
+
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_to_broadcast_target_size(a, b, cond, out) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "where.self_out";
+
+  if (a.scalar_type() == b.scalar_type() &&
+      a.scalar_type() == out.scalar_type() && a.scalar_type() == compute_type &&
+      // Using a Byte tensor for cond has been deprecated for a long time.
+      cond.scalar_type() == ScalarType::Bool) {
+    auto out_numel = out.numel();
+    ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+      const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
+      const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
+      const bool cond_is_broadcasted = !out.sizes().equals(cond.sizes());
+      const bool any_is_broadcasted =
+          (a_is_broadcasted || b_is_broadcasted || cond_is_broadcasted);
+      const CTYPE_COMPUTE* const data_a = a.const_data_ptr<CTYPE_COMPUTE>();
+      const CTYPE_COMPUTE* const data_b = b.const_data_ptr<CTYPE_COMPUTE>();
+      const bool* const data_cond = cond.const_data_ptr<bool>();
+      CTYPE_COMPUTE* const data_out = out.data_ptr<CTYPE_COMPUTE>();
+      if (any_is_broadcasted) {
+        for (const auto [out_index, a_index, b_index, cond_index] :
+             BroadcastIndexesRange<3>(out, a, b, cond)) {
+          data_out[out_index] =
+              data_cond[cond_index] ? data_a[a_index] : data_b[b_index];
+        }
+      } else {
+        for (const auto i : c10::irange(out_numel)) {
+          data_out[i] = data_cond[i] ? data_a[i] : data_b[i];
+        }
+      }
+    });
+  } else {
+    // Fall back for mixed dtype to keep code size and compile time
+    // reasonable.
+    ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+      utils::apply_tritensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+          [](const CTYPE_COMPUTE val_a,
+             const CTYPE_COMPUTE val_b,
+             const CTYPE_COMPUTE val_c) { return val_c ? val_a : val_b; },
+          ctx,
+          a,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          b,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          cond,
+          utils::SupportedTensorDtypes::BOOL_OR_BYTE,
+          out,
+          utils::SupportedTensorDtypes::SAME_AS_COMMON);
+    });
+  }
+
+  return out;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl
index 83b2c320266..dc189708992 100644
--- a/kernels/optimized/cpu/targets.bzl
+++ b/kernels/optimized/cpu/targets.bzl
@@ -95,6 +95,12 @@ _OPTIMIZED_ATEN_OPS = (
             "//executorch/kernels/portable/cpu/util:broadcast_util",
         ],
     ),
+    op_target(
+        name = "op_where",
+        deps = [
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
+        ],
+    ),
 )
diff --git a/kernels/optimized/optimized.yaml b/kernels/optimized/optimized.yaml
index fd5143b1511..4f90059aa93 100644
--- a/kernels/optimized/optimized.yaml
+++ b/kernels/optimized/optimized.yaml
@@ -101,3 +101,8 @@
   kernels:
     - arg_meta: null
       kernel_name: torch::executor::opt_sub_scalar_out
+
+- op: where.self_out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::opt_where_out
diff --git a/kernels/portable/cpu/util/broadcast_util.h b/kernels/portable/cpu/util/broadcast_util.h
index 10bd07baee2..f6bfae9bdaa 100644
--- a/kernels/portable/cpu/util/broadcast_util.h
+++ b/kernels/portable/cpu/util/broadcast_util.h
@@ -9,6 +9,7 @@
 #pragma once
 
 #include <c10/util/irange.h>
+#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
 #include <executorch/runtime/core/exec_aten/exec_aten.h>
 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
 
@@ -290,23 +291,18 @@ inline void apply_binary_elementwise_fn(
   const CTYPE_B* const data_b = b.const_data_ptr<CTYPE_B>();
   CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
-  for (const auto i : c10::irange(out.numel())) {
-    size_t a_linear_index = i;
-    size_t b_linear_index = i;
-
-    if (any_is_broadcasted) {
-      size_t out_indexes[kTensorDimensionLimit];
-      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
-
-      if (a_is_broadcasted) {
-        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
-      }
-      if (b_is_broadcasted) {
-        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
-      }
+  if (any_is_broadcasted) {
+    for (const auto [out_index, a_index, b_index] :
+         BroadcastIndexesRange<2>(out, a, b)) {
+      data_out[out_index] = compute_fun(data_a[a_index], data_b[b_index]);
     }
+  } else {
+    for (const auto i : c10::irange(out.numel())) {
+      size_t a_linear_index = i;
+      size_t b_linear_index = i;
 
-    data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]);
+      data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]);
+    }
   }
 }
 
@@ -338,28 +334,16 @@ inline void apply_ternary_elementwise_fn(
   const CTYPE_C* const data_c = c.const_data_ptr<CTYPE_C>();
   CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
-  for (const auto i : c10::irange(out.numel())) {
-    size_t a_linear_index = i;
-    size_t b_linear_index = i;
-    size_t c_linear_index = i;
-
-    if (any_is_broadcasted) {
-      size_t out_indexes[kTensorDimensionLimit];
-      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
-
-      if (a_is_broadcasted) {
-        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
-      }
-      if (b_is_broadcasted) {
-        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
-      }
-      if (c_is_broadcasted) {
-        c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c);
-      }
+  if (any_is_broadcasted) {
+    for (const auto [out_index, a_index, b_index, c_index] :
+         BroadcastIndexesRange<3>(out, a, b, c)) {
+      data_out[out_index] =
+          compute_fun(data_a[a_index], data_b[b_index], data_c[c_index]);
+    }
+  } else {
+    for (const auto i : c10::irange(out.numel())) {
+      data_out[i] = compute_fun(data_a[i], data_b[i], data_c[i]);
     }
-
-    data_out[i] = compute_fun(
-        data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]);
   }
 }
 
diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h
index 778006f1b99..09db5f7180d 100644
--- a/kernels/portable/cpu/util/elementwise_util.h
+++ b/kernels/portable/cpu/util/elementwise_util.h
@@ -9,6 +9,7 @@
 #pragma once
 
 #include <c10/util/irange.h>
+#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/kernels/portable/cpu/util/dtype_util.h>
 #include <executorch/runtime/kernel/kernel_runtime_context.h>
@@ -121,26 +122,24 @@ inline void apply_bitensor_elementwise_fn(
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
 
   auto out_numel = out.numel();
-  for (const auto i : c10::irange(out_numel)) {
-    size_t a_linear_index = i;
-    size_t b_linear_index = i;
-
-    if (any_is_broadcasted) {
-      size_t out_indexes[kTensorDimensionLimit];
-      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
-
-      if (a_is_broadcasted) {
-        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
-      }
-      if (b_is_broadcasted) {
-        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
-      }
+  if (any_is_broadcasted) {
+    for (const auto [out_index, a_index, b_index] :
+         BroadcastIndexesRange<2>(out, a, b)) {
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_index * a_element_size]),
+          load_b_to_common(&data_b[b_index * b_element_size]));
+      store_common_to_out(result, &data_out[out_index * out_element_size]);
+    }
+  } else {
+    for (const auto i : c10::irange(out_numel)) {
+      size_t a_linear_index = i;
+      size_t b_linear_index = i;
+
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_linear_index * a_element_size]),
+          load_b_to_common(&data_b[b_linear_index * b_element_size]));
+      store_common_to_out(result, &data_out[i * out_element_size]);
     }
-
-    auto result = compute_fun(
-        load_a_to_common(&data_a[a_linear_index * a_element_size]),
-        load_b_to_common(&data_b[b_linear_index * b_element_size]));
-    store_common_to_out(result, &data_out[i * out_element_size]);
   }
 }
 
@@ -211,31 +210,27 @@ inline void apply_tritensor_elementwise_fn(
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
 
   auto out_numel = out.numel();
-  for (const auto i : c10::irange(out_numel)) {
-    size_t a_linear_index = i;
-    size_t b_linear_index = i;
-    size_t c_linear_index = i;
-
-    if (any_is_broadcasted) {
-      size_t out_indexes[kTensorDimensionLimit];
-      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
-
-      if (a_is_broadcasted) {
-        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
-      }
-      if (b_is_broadcasted) {
-        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
-      }
-      if (c_is_broadcasted) {
-        c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c);
-      }
+  if (any_is_broadcasted) {
+    for (const auto [out_index, a_index, b_index, c_index] :
+         BroadcastIndexesRange<3>(out, a, b, c)) {
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_index * a_element_size]),
+          load_b_to_common(&data_b[b_index * b_element_size]),
+          load_c_to_common(&data_c[c_index * c_element_size]));
+      store_common_to_out(result, &data_out[out_index * out_element_size]);
+    }
+  } else {
+    for (const auto i : c10::irange(out_numel)) {
+      size_t a_linear_index = i;
+      size_t b_linear_index = i;
+      size_t c_linear_index = i;
+
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_linear_index * a_element_size]),
+          load_b_to_common(&data_b[b_linear_index * b_element_size]),
+          load_c_to_common(&data_c[c_linear_index * c_element_size]));
+      store_common_to_out(result, &data_out[i * out_element_size]);
     }
-
-    auto result = compute_fun(
-        load_a_to_common(&data_a[a_linear_index * a_element_size]),
-        load_b_to_common(&data_b[b_linear_index * b_element_size]),
-        load_c_to_common(&data_c[c_linear_index * c_element_size]));
-    store_common_to_out(result, &data_out[i * out_element_size]);
   }
 }
 
diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl
index c42f38fd8b0..739bc117fbf 100644
--- a/kernels/portable/cpu/util/targets.bzl
+++ b/kernels/portable/cpu/util/targets.bzl
@@ -70,6 +70,9 @@ def define_common_targets():
         exported_headers = [
             "broadcast_util.h",
         ],
+        exported_deps = [
+            ":broadcast_indexes_range",
+        ],
         deps = [
             ":repeat_util",
             "//executorch/runtime/kernel:kernel_includes",
diff --git a/kernels/test/CMakeLists.txt b/kernels/test/CMakeLists.txt
index 24adb8d9c80..394ec241698 100644
--- a/kernels/test/CMakeLists.txt
+++ b/kernels/test/CMakeLists.txt
@@ -275,6 +275,7 @@ set(_optimized_kernels_test_sources
   "op_native_layer_norm_test.cpp"
   "op_neg_test.cpp"
   "op_sub_test.cpp"
+  "op_where_test.cpp"
   "UnaryUfuncRealHBBF16ToFloatHBF16Test.cpp"
   ${CMAKE_CURRENT_BINARY_DIR}/include/optimized/executorch/kernels/test/supported_features.cpp
 )