Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 48 additions & 40 deletions cunumeric/deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

import legate.core.types as ty
import numpy as np
from legate.core import Future, ReductionOp, Store
from legate.core import Annotation, Future, ReductionOp, Store
from numpy.core.numeric import normalize_axis_tuple # type: ignore
from typing_extensions import ParamSpec

Expand Down Expand Up @@ -2961,20 +2961,21 @@ def unary_op(
lhs = self.base
rhs = src._broadcast(lhs.shape)

task = self.context.create_auto_task(CuNumericOpCode.UNARY_OP)
task.add_output(lhs)
task.add_input(rhs)
task.add_scalar_arg(op.value, ty.int32)
self.add_arguments(task, args)
with Annotation(self.context, {"OpCode": op.name}):
task = self.context.create_auto_task(CuNumericOpCode.UNARY_OP)
task.add_output(lhs)
task.add_input(rhs)
task.add_scalar_arg(op.value, ty.int32)
self.add_arguments(task, args)

task.add_alignment(lhs, rhs)
task.add_alignment(lhs, rhs)

if multiout is not None:
for out in multiout:
task.add_output(out.base)
task.add_alignment(out.base, rhs)
if multiout is not None:
for out in multiout:
task.add_output(out.base)
task.add_alignment(out.base, rhs)

task.execute()
task.execute()

# Perform a unary reduction operation from one set of dimensions down to
# fewer
Expand Down Expand Up @@ -3010,10 +3011,6 @@ def unary_reduction(
0 if keepdims else lhs_array.ndim
)

task = self.context.create_auto_task(
CuNumericOpCode.SCALAR_UNARY_RED
)

if initial is not None:
assert not argred
fill_value = initial
Expand All @@ -3026,14 +3023,21 @@ def unary_reduction(
while lhs.ndim > 1:
lhs = lhs.project(0, 0)

task.add_reduction(lhs, _UNARY_RED_TO_REDUCTION_OPS[op])
task.add_input(rhs_array.base)
task.add_scalar_arg(op, ty.int32)
task.add_scalar_arg(rhs_array.shape, (ty.int64,))
with Annotation(
self.context, {"OpCode": op.name, "ArgRed?": str(argred)}
):
task = self.context.create_auto_task(
CuNumericOpCode.SCALAR_UNARY_RED
)

self.add_arguments(task, args)
task.add_reduction(lhs, _UNARY_RED_TO_REDUCTION_OPS[op])
task.add_input(rhs_array.base)
task.add_scalar_arg(op, ty.int32)
task.add_scalar_arg(rhs_array.shape, (ty.int64,))

task.execute()
self.add_arguments(task, args)

task.execute()

else:
# Before we perform region reduction, make sure to have the lhs
Expand Down Expand Up @@ -3062,18 +3066,21 @@ def unary_reduction(
"Need support for reducing multiple dimensions"
)

task = self.context.create_auto_task(CuNumericOpCode.UNARY_RED)
with Annotation(
self.context, {"OpCode": op.name, "ArgRed?": str(argred)}
):
task = self.context.create_auto_task(CuNumericOpCode.UNARY_RED)

task.add_input(rhs_array.base)
task.add_reduction(result, _UNARY_RED_TO_REDUCTION_OPS[op])
task.add_scalar_arg(axis, ty.int32)
task.add_scalar_arg(op, ty.int32)
task.add_input(rhs_array.base)
task.add_reduction(result, _UNARY_RED_TO_REDUCTION_OPS[op])
task.add_scalar_arg(axis, ty.int32)
task.add_scalar_arg(op, ty.int32)

self.add_arguments(task, args)
self.add_arguments(task, args)

task.add_alignment(result, rhs_array.base)
task.add_alignment(result, rhs_array.base)

task.execute()
task.execute()

if argred:
self.unary_op(
Expand Down Expand Up @@ -3107,18 +3114,19 @@ def binary_op(
rhs1 = src1._broadcast(lhs.shape)
rhs2 = src2._broadcast(lhs.shape)

# Populate the Legate launcher
task = self.context.create_auto_task(CuNumericOpCode.BINARY_OP)
task.add_output(lhs)
task.add_input(rhs1)
task.add_input(rhs2)
task.add_scalar_arg(op_code.value, ty.int32)
self.add_arguments(task, args)
with Annotation(self.context, {"OpCode": op_code.name}):
# Populate the Legate launcher
task = self.context.create_auto_task(CuNumericOpCode.BINARY_OP)
task.add_output(lhs)
task.add_input(rhs1)
task.add_input(rhs2)
task.add_scalar_arg(op_code.value, ty.int32)
self.add_arguments(task, args)

task.add_alignment(lhs, rhs1)
task.add_alignment(lhs, rhs2)
task.add_alignment(lhs, rhs1)
task.add_alignment(lhs, rhs2)

task.execute()
task.execute()

@auto_convert("src1", "src2")
def binary_reduction(
Expand Down