From 13d7703864f516494b4a41a9a3ad06dd204ef8c7 Mon Sep 17 00:00:00 2001 From: Bryan Van de Ven Date: Wed, 20 Jul 2022 11:41:49 -0700 Subject: [PATCH 1/2] Broaden add_scalar_arg type --- legate/core/launcher.py | 4 ++-- legate/core/operation.py | 10 +++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/legate/core/launcher.py b/legate/core/launcher.py index 40dddc7b17..c0af4dfab6 100644 --- a/legate/core/launcher.py +++ b/legate/core/launcher.py @@ -120,7 +120,7 @@ def __init__( self, core_types: ty.TypeSystem, value: Any, - dtype: Union[DTType, tuple[DTType]], + dtype: Union[DTType, tuple[DTType, ...]], untyped: bool = True, ) -> None: self._core_types = core_types @@ -747,7 +747,7 @@ def __del__(self) -> None: def add_scalar_arg( self, value: Any, - dtype: DTType, + dtype: Union[DTType, tuple[DTType, ...]], untyped: bool = True, ) -> None: self._scalars.append( diff --git a/legate/core/operation.py b/legate/core/operation.py index 90b7f3c910..e69a0cca24 100644 --- a/legate/core/operation.py +++ b/legate/core/operation.py @@ -106,7 +106,7 @@ def all_unknowns(self) -> list[PartSym]: class TaskProtocol(OperationProtocol, Protocol): _task_id: int - _scalar_args: list[tuple[Any, DTType]] + _scalar_args: list[tuple[Any, Union[DTType, tuple[DTType, ...]]]] _comm_args: list[Communicator] @@ -225,7 +225,9 @@ def __init__( ) -> None: super().__init__(**kwargs) self._task_id = task_id - self._scalar_args: list[tuple[Any, DTType]] = [] + self._scalar_args: list[ + tuple[Any, Union[DTType, tuple[DTType, ...]]] + ] = [] self._comm_args: list[Communicator] = [] self._exn_types: list[type] = [] self._tb: Union[None, TracebackType] = None @@ -238,7 +240,9 @@ def get_name(self) -> str: libname = self.context.library.get_name() return f"{libname}.Task(tid:{self._task_id}, uid:{self._op_id})" - def add_scalar_arg(self, value: Any, dtype: DTType) -> None: + def add_scalar_arg( + self, value: Any, dtype: Union[DTType, tuple[DTType, ...]] + ) -> None: self._scalar_args.append((value, dtype)) def add_dtype_arg(self, dtype: DTType) -> None: From 97c35c87714929bf02c28b9f26901e5480b4cacf Mon Sep 17 00:00:00 2001 From: Bryan Van de Ven Date: Wed, 20 Jul 2022 12:42:24 -0700 Subject: [PATCH 2/2] only need to accept 1-tuples --- legate/core/launcher.py | 4 ++-- legate/core/operation.py | 8 +++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/legate/core/launcher.py b/legate/core/launcher.py index c0af4dfab6..c609e12ab1 100644 --- a/legate/core/launcher.py +++ b/legate/core/launcher.py @@ -120,7 +120,7 @@ def __init__( self, core_types: ty.TypeSystem, value: Any, - dtype: Union[DTType, tuple[DTType, ...]], + dtype: Union[DTType, tuple[DTType]], untyped: bool = True, ) -> None: self._core_types = core_types @@ -747,7 +747,7 @@ def __del__(self) -> None: def add_scalar_arg( self, value: Any, - dtype: Union[DTType, tuple[DTType, ...]], + dtype: Union[DTType, tuple[DTType]], untyped: bool = True, ) -> None: self._scalars.append( diff --git a/legate/core/operation.py b/legate/core/operation.py index e69a0cca24..a122bb3b77 100644 --- a/legate/core/operation.py +++ b/legate/core/operation.py @@ -106,7 +106,7 @@ def all_unknowns(self) -> list[PartSym]: class TaskProtocol(OperationProtocol, Protocol): _task_id: int - _scalar_args: list[tuple[Any, Union[DTType, tuple[DTType, ...]]]] + _scalar_args: list[tuple[Any, Union[DTType, tuple[DTType]]]] _comm_args: list[Communicator] @@ -225,9 +225,7 @@ def __init__( ) -> None: super().__init__(**kwargs) self._task_id = task_id - self._scalar_args: list[ - tuple[Any, Union[DTType, tuple[DTType, ...]]] - ] = [] + self._scalar_args: list[tuple[Any, Union[DTType, tuple[DTType]]]] = [] self._comm_args: list[Communicator] = [] self._exn_types: list[type] = [] self._tb: Union[None, TracebackType] = None @@ -241,7 +239,7 @@ def get_name(self) -> str: return f"{libname}.Task(tid:{self._task_id}, uid:{self._op_id})" def add_scalar_arg( - self, value: Any, dtype: Union[DTType, tuple[DTType, ...]] + self, value: Any, dtype: Union[DTType, tuple[DTType]] ) -> None: self._scalar_args.append((value, dtype))