Skip to content

Commit 822f016

Browse files
committed
code refine
1 parent 395099b commit 822f016

2 files changed

Lines changed: 56 additions & 125 deletions

File tree

cpp/src/arrow/compute/api_scalar.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -205,15 +205,14 @@ Result<Datum> Divide(const Datum& left, const Datum& right,
205205
ExecContext* ctx = NULLPTR);
206206

207207
/// \brief Raise the values of base array to the power of the exponent array values.
208-
/// Array values must be the same length. Base value 0 with negative integer exponent
209-
/// will raise divide by zero error. Nulls can be removed when base is 0 or 1 or when
210-
/// exponent is 0.
208+
/// Array values must be the same length. If either base or exponent is null the result
209+
/// will be null.
211210
///
212211
/// \param[in] left the base
213212
/// \param[in] right the exponent
214-
/// \param[in] options arithmetic options (enable/disable overflow checking and null
215-
/// removal), optional \param[in] ctx the function execution context, optional \return the
216-
/// elementwise base value raised to the power of exponent
213+
/// \param[in] options arithmetic options (enable/disable overflow checking), optional
214+
/// \param[in] ctx the function execution context, optional
215+
/// \return the elementwise base value raised to the power of exponent
217216
ARROW_EXPORT
218217
Result<Datum> Power(const Datum& left, const Datum& right,
219218
ArithmeticOptions options = ArithmeticOptions(),

cpp/src/arrow/compute/kernels/scalar_arithmetic.cc

Lines changed: 51 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -235,132 +235,67 @@ struct DivideChecked {
235235
}
236236
};
237237

238-
template <typename T>
239-
inline T integer_power(KernelContext* ctx, T left, T right) {
240-
if (right < 0) {
241-
ctx->SetStatus(
242-
Status::Invalid("integers to negative integer powers are not allowed"));
243-
}
244-
T result = 1;
245-
if (left == 0 && right != 0) {
246-
return 0;
247-
}
248-
while (true) {
249-
if (right % 2) {
250-
result *= left;
251-
}
252-
right /= 2;
253-
if (!right) {
254-
break;
255-
}
256-
left *= left;
257-
}
258-
return result;
259-
}
260-
261-
template <typename T>
262-
inline T signed_integer_power(KernelContext* ctx, T left, T right) {
263-
if (right < 0) {
264-
ctx->SetStatus(
265-
Status::Invalid("integers to negative integer powers are not allowed"));
266-
}
267-
T result = 1;
268-
if (left == 0 && right != 0) {
269-
return 0;
270-
}
271-
while (true) {
272-
if (right % 2) {
273-
result = to_unsigned(result) * to_unsigned(left);
274-
}
275-
right /= 2;
276-
if (!right) {
277-
break;
278-
}
279-
left = to_unsigned(left) * to_unsigned(left);
280-
}
281-
return result;
282-
}
283-
284-
template <typename T, typename Arg0, typename Arg1>
285-
inline T integer_power_checked(KernelContext* ctx, Arg0 left, Arg1 right) {
286-
if (right < 0) {
287-
ctx->SetStatus(
288-
Status::Invalid("integers to negative integer powers are not allowed"));
289-
}
290-
T result = 1;
291-
if (left == 0 && right != 0) {
292-
return 0;
293-
}
294-
while (true) {
295-
if (right % 2) {
296-
if (ARROW_PREDICT_FALSE(MultiplyWithOverflow(result, left, &result))) {
297-
ctx->SetStatus(Status::Invalid("overflow"));
298-
}
299-
}
300-
right /= 2;
301-
if (!right) {
302-
break;
303-
}
304-
if (ARROW_PREDICT_FALSE(MultiplyWithOverflow(left, left, &left))) {
305-
ctx->SetStatus(Status::Invalid("overflow"));
306-
}
307-
}
308-
return result;
309-
}
310-
311-
template <typename T, typename Arg0, typename Arg1>
312-
inline T power(KernelContext* ctx, Arg0 left, Arg1 right) {
313-
return std::pow(left, right);
314-
}
315-
316238
struct Power {
317-
template <typename T>
318-
static enable_if_unsigned_integer<T> Call(KernelContext* ctx, T left, T right) {
319-
return integer_power<T>(ctx, left, right);
320-
}
321-
322-
template <typename T>
323-
static enable_if_signed_integer<T> Call(KernelContext* ctx, T left, T right) {
324-
return signed_integer_power<T>(ctx, left, right);
239+
ARROW_NOINLINE
240+
static uint64_t IntegerPower(uint64_t base, uint64_t exp) {
241+
// right to left O(logn) power
242+
uint64_t pow = 1;
243+
while (exp) {
244+
pow *= (exp & 1) ? base : 1;
245+
base *= base;
246+
exp >>= 1;
247+
}
248+
return pow;
325249
}
326250

327251
template <typename T>
328-
static enable_if_floating_point<T> Call(KernelContext* ctx, T left, T right) {
329-
return power<T>(ctx, left, right);
330-
}
331-
332-
// See comment about 16 bit integer multiplication in Multiply kernel.
333-
template <typename T = void>
334-
static int16_t Call(KernelContext* ctx, int16_t left, int16_t right) {
335-
if (right < 0) {
252+
static enable_if_integer<T> Call(KernelContext* ctx, T base, T exp) {
253+
if (exp < 0) {
336254
ctx->SetStatus(
337255
Status::Invalid("integers to negative integer powers are not allowed"));
256+
return 0;
338257
}
339-
return integer_power(ctx, static_cast<uint32_t>(left), static_cast<uint32_t>(right));
258+
return static_cast<T>(IntegerPower(base, exp));
340259
}
341-
template <typename T = void>
342-
static uint16_t Call(KernelContext* ctx, uint16_t left, uint16_t right) {
343-
return integer_power(ctx, static_cast<uint32_t>(left), static_cast<uint32_t>(right));
260+
261+
template <typename T>
262+
static enable_if_floating_point<T> Call(KernelContext* ctx, T base, T exp) {
263+
return std::pow(base, exp);
344264
}
345265
};
346266

347267
struct PowerChecked {
348-
template <typename T = void, typename Arg0 = void, typename Arg1 = void>
349-
static enable_if_unsigned_integer<T> Call(KernelContext* ctx, Arg0 left, Arg1 right) {
350-
static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
351-
return integer_power_checked<T>(ctx, left, right);
352-
}
353-
354-
template <typename T = void, typename Arg0 = void, typename Arg1 = void>
355-
static enable_if_signed_integer<T> Call(KernelContext* ctx, Arg0 left, Arg1 right) {
356-
static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
357-
return integer_power_checked<T>(ctx, left, right);
268+
template <typename T, typename Arg0, typename Arg1>
269+
static enable_if_integer<T> Call(KernelContext* ctx, Arg0 base, Arg1 exp) {
270+
if (exp < 0) {
271+
ctx->SetStatus(
272+
Status::Invalid("integers to negative integer powers are not allowed"));
273+
return 0;
274+
} else if (exp == 0) {
275+
return 1;
276+
}
277+
// left to right O(logn) power with overflow checks
278+
bool overflow = false;
279+
uint64_t bitmask =
280+
1ULL << (63 - BitUtil::CountLeadingZeros(static_cast<uint64_t>(exp)));
281+
T pow = 1;
282+
while (bitmask) {
283+
overflow |= MultiplyWithOverflow(pow, pow, &pow);
284+
if (exp & bitmask) {
285+
overflow |= MultiplyWithOverflow(pow, base, &pow);
286+
}
287+
bitmask >>= 1;
288+
}
289+
if (overflow) {
290+
ctx->SetStatus(Status::Invalid("overflow"));
291+
}
292+
return pow;
358293
}
359294

360295
template <typename T, typename Arg0, typename Arg1>
361-
static enable_if_floating_point<T> Call(KernelContext* ctx, Arg0 left, Arg1 right) {
296+
static enable_if_floating_point<T> Call(KernelContext* ctx, Arg0 base, Arg1 exp) {
362297
static_assert(std::is_same<T, Arg0>::value && std::is_same<T, Arg1>::value, "");
363-
return power<T>(ctx, left, right);
298+
return std::pow(base, exp);
364299
}
365300
};
366301

@@ -492,18 +427,15 @@ const FunctionDoc div_checked_doc{
492427

493428
const FunctionDoc pow_doc{
494429
"Raise arguments to power element-wise",
495-
("Raising zero to negative integer returns an error. However, integer overflow\n"
496-
"wraps around, and floating-point raising zero to negative integer returns an"
497-
"infinite.\n"
498-
"Use function \"power_checked_propagate_nulls\" if you want to get an error\n"
499-
"in all the aforementioned cases."),
500-
{"base", "power"}};
430+
("Integer to negative integer power returns an error. However, integer overflow\n"
431+
"wraps around. Floating poing power follows std::pow() behaviour.\n"),
432+
{"base", "exponent"}};
501433

502434
const FunctionDoc pow_checked_doc{
503435
"Raise arguments to power element-wise",
504-
("An error is returned when trying to raise zero to negative integer, or when\n"
505-
"integer overflow is encountered."),
506-
{"base", "power"}};
436+
("An error is returned when integer to negative integer power is encountered,\n"
437+
"or integer overflow is encountered."),
438+
{"base", "exponent"}};
507439

508440
} // namespace
509441

0 commit comments

Comments
 (0)