diff --git a/legate/core/launcher.py b/legate/core/launcher.py index 40dddc7b17..c609e12ab1 100644 --- a/legate/core/launcher.py +++ b/legate/core/launcher.py @@ -747,7 +747,7 @@ def __del__(self) -> None: def add_scalar_arg( self, value: Any, - dtype: DTType, + dtype: Union[DTType, tuple[DTType]], untyped: bool = True, ) -> None: self._scalars.append( diff --git a/legate/core/operation.py b/legate/core/operation.py index 90b7f3c910..a122bb3b77 100644 --- a/legate/core/operation.py +++ b/legate/core/operation.py @@ -106,7 +106,7 @@ def all_unknowns(self) -> list[PartSym]: class TaskProtocol(OperationProtocol, Protocol): _task_id: int - _scalar_args: list[tuple[Any, DTType]] + _scalar_args: list[tuple[Any, Union[DTType, tuple[DTType]]]] _comm_args: list[Communicator] @@ -225,7 +225,7 @@ def __init__( ) -> None: super().__init__(**kwargs) self._task_id = task_id - self._scalar_args: list[tuple[Any, DTType]] = [] + self._scalar_args: list[tuple[Any, Union[DTType, tuple[DTType]]]] = [] self._comm_args: list[Communicator] = [] self._exn_types: list[type] = [] self._tb: Union[None, TracebackType] = None @@ -238,7 +238,9 @@ def get_name(self) -> str: libname = self.context.library.get_name() return f"{libname}.Task(tid:{self._task_id}, uid:{self._op_id})" - def add_scalar_arg(self, value: Any, dtype: DTType) -> None: + def add_scalar_arg( + self, value: Any, dtype: Union[DTType, tuple[DTType]] + ) -> None: self._scalar_args.append((value, dtype)) def add_dtype_arg(self, dtype: DTType) -> None: