1818#include " arrow/compute/kernels/common.h"
1919#include " arrow/util/int_util.h"
2020
21+ #ifndef __has_builtin
22+ #define __has_builtin (x ) 0
23+ #endif
24+
2125namespace arrow {
2226namespace compute {
2327
@@ -35,6 +39,10 @@ using enable_if_signed_integer = enable_if_t<is_signed_integer<T>::value, T>;
3539template <typename T>
3640using enable_if_unsigned_integer = enable_if_t <is_unsigned_integer<T>::value, T>;
3741
42+ template <typename T>
43+ using enable_if_integer =
44+ enable_if_t <is_signed_integer<T>::value || is_unsigned_integer<T>::value, T>;
45+
3846template <typename T>
3947using enable_if_floating_point = enable_if_t <std::is_floating_point<T>::value, T>;
4048
@@ -60,6 +68,42 @@ struct Add {
6068 }
6169};
6270
71+ struct AddChecked {
72+ #if __has_builtin(__builtin_add_overflow)
73+ template <typename T>
74+ static enable_if_integer<T> Call (KernelContext* ctx, T left, T right) {
75+ T result;
76+ if (__builtin_add_overflow (left, right, &result)) {
77+ ctx->SetStatus (Status::Invalid (" overflow" ));
78+ }
79+ return result;
80+ }
81+ #else
82+ template <typename T>
83+ static enable_if_unsigned_integer<T> Call (KernelContext* ctx, T left, T right) {
84+ if (arrow::internal::HasAdditionOverflow (left, right)) {
85+ ctx->SetStatus (Status::Invalid (" overflow" ));
86+ }
87+ return left + right;
88+ }
89+
90+ template <typename T>
91+ static enable_if_signed_integer<T> Call (KernelContext* ctx, T left, T right) {
92+ auto unsigned_left = to_unsigned (left);
93+ auto unsigned_right = to_unsigned (right);
94+ if (arrow::internal::HasAdditionOverflow (unsigned_left, unsigned_right)) {
95+ ctx->SetStatus (Status::Invalid (" overflow" ));
96+ }
97+ return unsigned_left + unsigned_right;
98+ }
99+ #endif
100+
101+ template <typename T>
102+ static constexpr enable_if_floating_point<T> Call (KernelContext*, T left, T right) {
103+ return left + right;
104+ }
105+ };
106+
63107struct Subtract {
64108 template <typename T>
65109 static constexpr enable_if_floating_point<T> Call (KernelContext*, T left, T right) {
@@ -77,6 +121,40 @@ struct Subtract {
77121 }
78122};
79123
124+ struct SubtractChecked {
125+ #if __has_builtin(__builtin_sub_overflow)
126+ template <typename T>
127+ static enable_if_integer<T> Call (KernelContext* ctx, T left, T right) {
128+ T result;
129+ if (__builtin_sub_overflow (left, right, &result)) {
130+ ctx->SetStatus (Status::Invalid (" overflow" ));
131+ }
132+ return result;
133+ }
134+ #else
135+ template <typename T>
136+ static enable_if_unsigned_integer<T> Call (KernelContext* ctx, T left, T right) {
137+ if (arrow::internal::HasSubtractionOverflow (left, right)) {
138+ ctx->SetStatus (Status::Invalid (" overflow" ));
139+ }
140+ return left - right;
141+ }
142+
143+ template <typename T>
144+ static enable_if_signed_integer<T> Call (KernelContext* ctx, T left, T right) {
145+ if (arrow::internal::HasSubtractionOverflow (left, right)) {
146+ ctx->SetStatus (Status::Invalid (" overflow" ));
147+ }
148+ return to_unsigned (left) - to_unsigned (right);
149+ }
150+ #endif
151+
152+ template <typename T>
153+ static constexpr enable_if_floating_point<T> Call (KernelContext*, T left, T right) {
154+ return left - right;
155+ }
156+ };
157+
80158struct Multiply {
81159 static_assert (std::is_same<decltype (int8_t () * int8_t ()), int32_t>::value, "");
82160 static_assert (std::is_same<decltype (uint8_t () * uint8_t ()), int32_t>::value, "");
@@ -116,6 +194,29 @@ struct Multiply {
116194 }
117195};
118196
197+ struct MultiplyChecked {
198+ template <typename T>
199+ static enable_if_integer<T> Call (KernelContext* ctx, T left, T right) {
200+ T result;
201+ #if __has_builtin(__builtin_mul_overflow)
202+ if (__builtin_mul_overflow (left, right, &result)) {
203+ ctx->SetStatus (Status::Invalid (" overflow" ));
204+ }
205+ #else
206+ result = Multiply::Call (ctx, left, right);
207+ if (left != 0 && result / left != right) {
208+ ctx->SetStatus (Status::Invalid (" overflow" ));
209+ }
210+ #endif
211+ return result;
212+ }
213+
214+ template <typename T>
215+ static constexpr enable_if_floating_point<T> Call (KernelContext*, T left, T right) {
216+ return left * right;
217+ }
218+ };
219+
119220namespace codegen {
120221
121222// Generate a kernel given an arithmetic functor
@@ -168,8 +269,11 @@ namespace internal {
168269
169270void RegisterScalarArithmetic (FunctionRegistry* registry) {
170271 codegen::AddBinaryFunction<Add>(" add" , registry);
272+ codegen::AddBinaryFunction<AddChecked>(" add_checked" , registry);
171273 codegen::AddBinaryFunction<Subtract>(" subtract" , registry);
274+ codegen::AddBinaryFunction<SubtractChecked>(" subtract_checked" , registry);
172275 codegen::AddBinaryFunction<Multiply>(" multiply" , registry);
276+ codegen::AddBinaryFunction<MultiplyChecked>(" multiply_checked" , registry);
173277}
174278
175279} // namespace internal
0 commit comments