Skip to content

Commit dfb3909

Browse files
kszucs authored and wesm committed
ARROW-9022: [C++] Add/Sub/Mul arithmetic kernels with overflow check
Quick draft of checked arithmetic kernels. TODOs: - [x] more portable overflow checks - [x] consolidate the tests - [x] add arithmetic options to let the user choose which variant to run (thereby removing the `*Checked` functions). Closes #7420 from kszucs/ARROW-9022. Authored-by: Krisztián Szűcs <szucs.krisztian@gmail.com> Signed-off-by: Wes McKinney <wesm@apache.org>
1 parent 285844e commit dfb3909

6 files changed

Lines changed: 240 additions & 60 deletions

File tree

cpp/src/arrow/compute/api_scalar.cc

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,16 @@ namespace compute {
4141
// ----------------------------------------------------------------------
4242
// Arithmetic
4343

44-
SCALAR_EAGER_BINARY(Add, "add")
45-
SCALAR_EAGER_BINARY(Subtract, "subtract")
46-
SCALAR_EAGER_BINARY(Multiply, "multiply")
44+
// Defines the eager wrapper NAME(left, right, options, ctx). When
// options.check_overflow is set the wrapper calls the function registered as
// REGISTRY_CHECKED_NAME, otherwise the one registered as REGISTRY_NAME; both
// operands are forwarded to CallFunction unchanged.
#define SCALAR_ARITHMETIC_BINARY(NAME, REGISTRY_NAME, REGISTRY_CHECKED_NAME)           \
  Result<Datum> NAME(const Datum& left, const Datum& right, ArithmeticOptions options, \
                     ExecContext* ctx) {                                               \
    if (options.check_overflow) {                                                      \
      return CallFunction(REGISTRY_CHECKED_NAME, {left, right}, ctx);                  \
    }                                                                                  \
    return CallFunction(REGISTRY_NAME, {left, right}, ctx);                            \
  }

SCALAR_ARITHMETIC_BINARY(Add, "add", "add_checked")
SCALAR_ARITHMETIC_BINARY(Subtract, "subtract", "subtract_checked")
SCALAR_ARITHMETIC_BINARY(Multiply, "multiply", "multiply_checked")
4754

4855
// ----------------------------------------------------------------------
4956
// Set-related operations

cpp/src/arrow/compute/api_scalar.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ namespace compute {
3535

3636
// ----------------------------------------------------------------------
3737

38+
/// \brief Options controlling the Add, Subtract and Multiply functions.
struct ArithmeticOptions : public FunctionOptions {
  ArithmeticOptions() {}

  /// When true, the overflow-checking variant of the function is requested;
  /// defaults to false.
  bool check_overflow = false;
};
42+
3843
/// \brief Add two values together. Array values must be the same length. If
3944
/// either addend is null the result will be null.
4045
///
@@ -43,7 +48,9 @@ namespace compute {
4348
/// \param[in] options arithmetic options (overflow checking), optional
/// \param[in] ctx the function execution context, optional
4449
/// \return the elementwise sum
4550
ARROW_EXPORT
46-
Result<Datum> Add(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR);
51+
Result<Datum> Add(const Datum& left, const Datum& right,
52+
ArithmeticOptions options = ArithmeticOptions(),
53+
ExecContext* ctx = NULLPTR);
4754

4855
/// \brief Subtract two values. Array values must be the same length. If the
4956
/// minuend or subtrahend is null the result will be null.
@@ -53,7 +60,9 @@ Result<Datum> Add(const Datum& left, const Datum& right, ExecContext* ctx = NULL
5360
/// \param[in] options arithmetic options (overflow checking), optional
/// \param[in] ctx the function execution context, optional
5461
/// \return the elementwise difference
5562
ARROW_EXPORT
56-
Result<Datum> Subtract(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR);
63+
Result<Datum> Subtract(const Datum& left, const Datum& right,
64+
ArithmeticOptions options = ArithmeticOptions(),
65+
ExecContext* ctx = NULLPTR);
5766

5867
/// \brief Multiply two values. Array values must be the same length. If either
5968
/// factor is null the result will be null.
@@ -63,7 +72,9 @@ Result<Datum> Subtract(const Datum& left, const Datum& right, ExecContext* ctx =
6372
/// \param[in] options arithmetic options (overflow checking), optional
/// \param[in] ctx the function execution context, optional
6473
/// \return the elementwise product
6574
ARROW_EXPORT
66-
Result<Datum> Multiply(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR);
75+
Result<Datum> Multiply(const Datum& left, const Datum& right,
76+
ArithmeticOptions options = ArithmeticOptions(),
77+
ExecContext* ctx = NULLPTR);
6778

6879
enum CompareOperator {
6980
EQUAL,

cpp/src/arrow/compute/kernels/scalar_arithmetic.cc

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
#include "arrow/compute/kernels/common.h"
1919
#include "arrow/util/int_util.h"
2020

21+
// Compilers that do not provide __has_builtin get a stub reporting every
// builtin as unavailable, so the `#if __has_builtin(...)` checks in this file
// fall through to their portable branches.
#if !defined(__has_builtin)
#define __has_builtin(x) 0
#endif
24+
2125
namespace arrow {
2226
namespace compute {
2327

@@ -35,6 +39,10 @@ using enable_if_signed_integer = enable_if_t<is_signed_integer<T>::value, T>;
3539
template <typename T>
3640
using enable_if_unsigned_integer = enable_if_t<is_unsigned_integer<T>::value, T>;
3741

42+
template <typename T>
43+
using enable_if_integer =
44+
enable_if_t<is_signed_integer<T>::value || is_unsigned_integer<T>::value, T>;
45+
3846
template <typename T>
3947
using enable_if_floating_point = enable_if_t<std::is_floating_point<T>::value, T>;
4048

@@ -60,6 +68,42 @@ struct Add {
6068
}
6169
};
6270

71+
struct AddChecked {
72+
#if __has_builtin(__builtin_add_overflow)
73+
template <typename T>
74+
static enable_if_integer<T> Call(KernelContext* ctx, T left, T right) {
75+
T result;
76+
if (__builtin_add_overflow(left, right, &result)) {
77+
ctx->SetStatus(Status::Invalid("overflow"));
78+
}
79+
return result;
80+
}
81+
#else
82+
template <typename T>
83+
static enable_if_unsigned_integer<T> Call(KernelContext* ctx, T left, T right) {
84+
if (arrow::internal::HasAdditionOverflow(left, right)) {
85+
ctx->SetStatus(Status::Invalid("overflow"));
86+
}
87+
return left + right;
88+
}
89+
90+
template <typename T>
91+
static enable_if_signed_integer<T> Call(KernelContext* ctx, T left, T right) {
92+
auto unsigned_left = to_unsigned(left);
93+
auto unsigned_right = to_unsigned(right);
94+
if (arrow::internal::HasAdditionOverflow(unsigned_left, unsigned_right)) {
95+
ctx->SetStatus(Status::Invalid("overflow"));
96+
}
97+
return unsigned_left + unsigned_right;
98+
}
99+
#endif
100+
101+
template <typename T>
102+
static constexpr enable_if_floating_point<T> Call(KernelContext*, T left, T right) {
103+
return left + right;
104+
}
105+
};
106+
63107
struct Subtract {
64108
template <typename T>
65109
static constexpr enable_if_floating_point<T> Call(KernelContext*, T left, T right) {
@@ -77,6 +121,40 @@ struct Subtract {
77121
}
78122
};
79123

124+
struct SubtractChecked {
125+
#if __has_builtin(__builtin_sub_overflow)
126+
template <typename T>
127+
static enable_if_integer<T> Call(KernelContext* ctx, T left, T right) {
128+
T result;
129+
if (__builtin_sub_overflow(left, right, &result)) {
130+
ctx->SetStatus(Status::Invalid("overflow"));
131+
}
132+
return result;
133+
}
134+
#else
135+
template <typename T>
136+
static enable_if_unsigned_integer<T> Call(KernelContext* ctx, T left, T right) {
137+
if (arrow::internal::HasSubtractionOverflow(left, right)) {
138+
ctx->SetStatus(Status::Invalid("overflow"));
139+
}
140+
return left - right;
141+
}
142+
143+
template <typename T>
144+
static enable_if_signed_integer<T> Call(KernelContext* ctx, T left, T right) {
145+
if (arrow::internal::HasSubtractionOverflow(left, right)) {
146+
ctx->SetStatus(Status::Invalid("overflow"));
147+
}
148+
return to_unsigned(left) - to_unsigned(right);
149+
}
150+
#endif
151+
152+
template <typename T>
153+
static constexpr enable_if_floating_point<T> Call(KernelContext*, T left, T right) {
154+
return left - right;
155+
}
156+
};
157+
80158
struct Multiply {
81159
static_assert(std::is_same<decltype(int8_t() * int8_t()), int32_t>::value, "");
82160
static_assert(std::is_same<decltype(uint8_t() * uint8_t()), int32_t>::value, "");
@@ -116,6 +194,29 @@ struct Multiply {
116194
}
117195
};
118196

197+
struct MultiplyChecked {
198+
template <typename T>
199+
static enable_if_integer<T> Call(KernelContext* ctx, T left, T right) {
200+
T result;
201+
#if __has_builtin(__builtin_mul_overflow)
202+
if (__builtin_mul_overflow(left, right, &result)) {
203+
ctx->SetStatus(Status::Invalid("overflow"));
204+
}
205+
#else
206+
result = Multiply::Call(ctx, left, right);
207+
if (left != 0 && result / left != right) {
208+
ctx->SetStatus(Status::Invalid("overflow"));
209+
}
210+
#endif
211+
return result;
212+
}
213+
214+
template <typename T>
215+
static constexpr enable_if_floating_point<T> Call(KernelContext*, T left, T right) {
216+
return left * right;
217+
}
218+
};
219+
119220
namespace codegen {
120221

121222
// Generate a kernel given an arithmetic functor
@@ -168,8 +269,11 @@ namespace internal {
168269

169270
// Registers every arithmetic functor under its function name; each plain
// function is immediately followed by its "_checked" counterpart.
void RegisterScalarArithmetic(FunctionRegistry* registry) {
  // Addition
  codegen::AddBinaryFunction<Add>("add", registry);
  codegen::AddBinaryFunction<AddChecked>("add_checked", registry);

  // Subtraction
  codegen::AddBinaryFunction<Subtract>("subtract", registry);
  codegen::AddBinaryFunction<SubtractChecked>("subtract_checked", registry);

  // Multiplication
  codegen::AddBinaryFunction<Multiply>("multiply", registry);
  codegen::AddBinaryFunction<MultiplyChecked>("multiply_checked", registry);
}
174278

175279
} // namespace internal

cpp/src/arrow/compute/kernels/scalar_arithmetic_benchmark.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ namespace compute {
3030

3131
constexpr auto kSeed = 0x94378165;
3232

33-
using BinaryOp = Result<Datum>(const Datum&, const Datum&, ExecContext*);
33+
// Function type of the arithmetic entry points exercised by the benchmarks:
// two operands, the arithmetic options and an execution context.
using BinaryOp = Result<Datum>(const Datum&, const Datum&,
                               ArithmeticOptions, ExecContext*);
3435

3536
template <BinaryOp& Op, typename ArrowType, typename CType = typename ArrowType::c_type>
3637
static void ArrayScalarKernel(benchmark::State& state) {
@@ -46,7 +47,7 @@ static void ArrayScalarKernel(benchmark::State& state) {
4647

4748
Datum fifteen(CType(15));
4849
for (auto _ : state) {
49-
ABORT_NOT_OK(Op(lhs, fifteen, nullptr).status());
50+
ABORT_NOT_OK(Op(lhs, fifteen, ArithmeticOptions(), nullptr).status());
5051
}
5152
state.SetItemsProcessed(state.iterations() * array_size);
5253
}
@@ -66,7 +67,7 @@ static void ArrayArrayKernel(benchmark::State& state) {
6667
rand.Numeric<ArrowType>(array_size, min, max, args.null_proportion));
6768

6869
for (auto _ : state) {
69-
ABORT_NOT_OK(Op(lhs, rhs, nullptr).status());
70+
ABORT_NOT_OK(Op(lhs, rhs, ArithmeticOptions(), nullptr).status());
7071
}
7172
state.SetItemsProcessed(state.iterations() * array_size);
7273
}

0 commit comments

Comments
 (0)