diff --git a/distributed/core.py b/distributed/core.py index 84e6c8aebfc..6d80706570b 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -31,6 +31,7 @@ from tornado.ioloop import IOLoop import dask +from dask.typing import NoDefault, no_default from dask.utils import parse_timedelta from distributed import profile, protocol @@ -55,6 +56,7 @@ has_keyword, import_file, iscoroutinefunction, + log_errors, offload, recursive_to_dict, truncate_exception, @@ -65,6 +67,7 @@ if TYPE_CHECKING: from typing_extensions import ParamSpec, Self + from distributed.batched import BatchedSend from distributed.counter import Digest P = ParamSpec("P") @@ -99,6 +102,11 @@ class Status(Enum): Status.lookup = {s.name: s for s in Status} # type: ignore +class RPCCall: + def __getattr__(self, key: str) -> Callable[..., Awaitable]: + raise NotImplementedError() + + class RPCClosed(IOError): pass @@ -427,6 +435,7 @@ def __init__( "echo": self.echo, "connection_stream": self.handle_stream, "dump_state": self._to_dict, + "_ordered_send_payload": self._handle_ordered_send_payload, } self.handlers.update(handlers) if blocked_handlers is None: @@ -434,7 +443,12 @@ def __init__( "distributed.%s.blocked-handlers" % type(self).__name__.lower(), [] ) self.blocked_handlers = blocked_handlers - self.stream_handlers = {} + self.stream_handlers = { + "__ordered_send": self._handle_ordered_send, + "__ordered_rcv": self._handle_ordered_rcv, + } + self._side_channel_payload = {} + self._side_channel_arrived = defaultdict(asyncio.Event) self.stream_handlers.update(stream_handlers or {}) self.id = type(self).__name__ + "-" + str(uuid.uuid4()) @@ -532,7 +546,15 @@ def set_thread_ident(): timeout=timeout, server=self, ) + import itertools + + self._counter = itertools.count() + self._responses = {} + self._waiting_for = deque() + self._ensure_order = asyncio.Condition() + self._batched_comms = {} + self._batched_comms_locks = defaultdict(asyncio.Lock) self.__stopped = False async def upload_file( @@ -1063,6 +1085,135 @@ async def handle_stream( await comm.close() assert comm.closed() + async def _handle_ordered_send_payload(self, sig, payload, origin): + # FIXME: If something goes wrong, this can leak memory + # We'd need a callback for when the incoming connection is closed to + # clean this up + key = (origin, sig) + self._side_channel_payload[key] = payload + self._side_channel_arrived[key].set() + + async def _handle_ordered_send( + self, sig, user_op, origin, user_kwargs, use_side_channel, **extra + ): + # Note: The backchannel is currently unique. It's currently unclear if + # we need more control here + bcomm = await self._get_bcomm(origin) + try: + if use_side_channel: + assert user_kwargs is None + key = (origin, sig) + await self._side_channel_arrived[key].wait() + user_kwargs = self._side_channel_payload.pop(key) + result = self.handlers[user_op](**merge(extra, user_kwargs)) + if inspect.isawaitable(result): + result = await result + bcomm.send({"op": "__ordered_rcv", "sig": sig, "result": result}) + except Exception as e: + exc_info = error_message(e) + bcomm.send({"op": "__ordered_rcv", "sig": sig, "exc_info": exc_info}) + + async def _handle_ordered_rcv(self, sig, result=no_default, exc_info=no_default): + fut = self._responses[sig] + if result is not no_default: + assert exc_info is no_default + fut.set_result(result) + elif exc_info is not no_default: + assert result is no_default + _, exc, tb = clean_exception(**exc_info) + fut.set_exception(exc.with_traceback(tb)) + else: + raise RuntimeError("Unreachable") + + @log_errors + async def ordered_rpc( + self, + addr: str | NoDefault = no_default, + bcomm: BatchedSend | NoDefault = no_default, + use_side_channel: bool = False, + ) -> RPCCall: + # TODO: Allow different channels? + if addr is not no_default: + assert bcomm is no_default + bcomm = await self._get_bcomm(addr) + else: + assert bcomm is not no_default + addr = bcomm.comm.peer_address + + server = self + + class OrderedRPC(RPCCall): + def __init__(self, bcomm): + self._bcomm = bcomm + + def __getattr__(self, key): + async def send_recv_from_rpc(**kwargs): + sig = next(server._counter) + msg = { + "op": "__ordered_send", + "sig": sig, + "user_op": key, + "origin": server.address, + "use_side_channel": use_side_channel, + } + if not use_side_channel: + msg["user_kwargs"] = kwargs + else: + msg["user_kwargs"] = None + self._bcomm.send(msg) + fut = asyncio.Future() + server._responses[sig] = fut + server._waiting_for.append(sig) + if use_side_channel: + # Note: We may even want to consider moving this to a + # background task + async def _(): + await server.rpc(addr)._ordered_send_payload( + sig=sig, + payload=kwargs, + origin=server.address, + ) + + server._ongoing_background_tasks.call_soon(_) + + async def watch_comm(): + while True: + if self._bcomm.comm.closed(): + fut.set_exception(CommClosedError) + break + await asyncio.sleep(0.1) + + t = asyncio.create_task(watch_comm()) + + def is_next(): + return server._waiting_for[0] == sig + + try: + async with server._ensure_order: + await server._ensure_order.wait_for(is_next) + return await fut + finally: + t.cancel() + server._waiting_for.popleft() + + return send_recv_from_rpc + + return OrderedRPC(bcomm) + + async def _get_bcomm(self, addr): + async with self._batched_comms_locks[addr]: + if addr in self._batched_comms: + bcomm = self._batched_comms[addr] + if not bcomm.comm.closed(): + return bcomm + from distributed.batched import BatchedSend + + comm = await self.rpc.connect(addr) + await comm.write({"op": "connection_stream"}) + self._batched_comms[addr] = bcomm = BatchedSend(interval=0.01) + bcomm.start(comm) + return bcomm + async def close(self, timeout: float | None = None, reason: str = "") -> None: try: for pc in self.periodic_callbacks.values(): @@ -1365,7 +1516,7 @@ def __repr__(self): return "" % (self.address, len(self.comms)) -class PooledRPCCall: +class PooledRPCCall(RPCCall): """The result of ConnectionPool()('host:port') See Also: diff --git a/distributed/scheduler.py b/distributed/scheduler.py index b173a478330..d6a2aadcf2d 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4173,7 +4173,7 @@ async def log_errors(func): def heartbeat_worker( self, *, - address: str, + worker: str, resolve_address: bool = True, now: float | None = None, resources: dict[str, float] | None = None, @@ -4182,7 +4182,7 @@ def heartbeat_worker( executing: dict[Key, float] | None = None, extensions: dict | None = None, ) -> dict[str, Any]: - address = self.coerce_address(address, resolve_address) + address = self.coerce_address(worker, resolve_address) address = normalize_address(address) ws = self.workers.get(address) if ws is None: @@ -4361,7 +4361,7 @@ async def add_worker( self.aliases[name] = address self.heartbeat_worker( - address=address, + worker=address, resolve_address=resolve_address, now=now, resources=resources, diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index 486e782f025..e6697069539 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -313,7 +313,7 @@ async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None: if assigned_worker != self.local_address: result = await self.scheduler.shuffle_restrict_task( - id=self.id, run_id=self.run_id, key=key, worker=assigned_worker + id=self.id, run_id=self.run_id, key=key, assigned_worker=assigned_worker ) if result["status"] == "error": raise RuntimeError(result["message"]) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index a7a081d7239..3f70a592605 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -774,7 +774,7 @@ def create_run_on_worker( local_address=plugin.worker.address, rpc=plugin.worker.rpc, digest_metric=plugin.worker.digest_metric, - scheduler=plugin.worker.scheduler, + scheduler=plugin.worker.scheduler_ordered, # type: ignore memory_limiter_disk=plugin.memory_limiter_disk, memory_limiter_comms=plugin.memory_limiter_comms, disk=self.disk, diff --git a/distributed/shuffle/_scheduler_plugin.py b/distributed/shuffle/_scheduler_plugin.py index 09d97fffc9a..a4b736c87a4 100644 --- a/distributed/shuffle/_scheduler_plugin.py +++ b/distributed/shuffle/_scheduler_plugin.py @@ -78,7 +78,9 @@ async def start(self, scheduler: Scheduler) -> None: def shuffle_ids(self) -> set[ShuffleId]: return set(self.active_shuffles) - async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None: + async def barrier( + self, id: ShuffleId, run_id: int, consistent: bool, worker: None + ) -> None: shuffle = self.active_shuffles[id] if shuffle.run_id != run_id: raise ValueError(f"{run_id=} does not match {shuffle}") @@ -98,7 +100,9 @@ async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None: workers=list(shuffle.participating_workers), ) - def restrict_task(self, id: ShuffleId, run_id: int, key: Key, worker: str) -> dict: + def restrict_task( + self, id: ShuffleId, run_id: int, key: Key, assigned_worker: str, worker: str + ) -> dict: shuffle = self.active_shuffles[id] if shuffle.run_id > run_id: return { @@ -111,7 +115,7 @@ def restrict_task(self, id: ShuffleId, run_id: int, key: Key, worker: str) -> di "message": f"Request invalid, expected {run_id=} for {shuffle}", } ts = self.scheduler.tasks[key] - self._set_restriction(ts, worker) + self._set_restriction(ts, assigned_worker) return {"status": "OK"} def heartbeat(self, ws: WorkerState, data: dict) -> None: diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index 7d0247336ef..c3138845c3c 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -574,7 +574,7 @@ def create_run_on_worker( local_address=plugin.worker.address, rpc=plugin.worker.rpc, digest_metric=plugin.worker.digest_metric, - scheduler=plugin.worker.scheduler, + scheduler=plugin.worker.scheduler_ordered, # type: ignore memory_limiter_disk=plugin.memory_limiter_disk if self.disk else ResourceLimiter(None), diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index b9864d62dad..5ae3cb65e93 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -2486,10 +2486,12 @@ def __init__(self, scheduler: Scheduler): self.in_barrier = asyncio.Event() self.block_barrier = asyncio.Event() - async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None: + async def barrier( + self, id: ShuffleId, run_id: int, consistent: bool, worker: None + ) -> None: self.in_barrier.set() await self.block_barrier.wait() - return await super().barrier(id, run_id, consistent) + return await super().barrier(id, run_id, consistent, worker) @gen_cluster(client=True) diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 85a502bafee..e5db1402dcc 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -1481,3 +1481,103 @@ def sync_handler(val): assert ledger == list(range(n)) finally: await comm.close() + + +@pytest.mark.parametrize( + "use_side_channel", + [False, True], +) +@gen_test() +async def test_ordered_rpc(use_side_channel): + entered_sleep = asyncio.Event() + i = 0 + + async def sleep(duration): + nonlocal i + entered_sleep.set() + await asyncio.sleep(duration) + try: + return i + finally: + i += 1 + + class MyServer(Server): + def __init__(self, *args, **kwargs): + handlers = { + "sleep": sleep, + "do_work": self.do_work, + } + super().__init__(handlers, *args, **kwargs) + + async def do_work(self, other_addr, ordered=False): + if ordered: + r = await self.ordered_rpc( + other_addr, use_side_channel=use_side_channel + ) + else: + r = self.rpc(other_addr) + + t1 = asyncio.create_task(r.sleep(duration=0.1)) + + async def wait_to_unblock(error=False): + await entered_sleep.wait() + if error: + raise RuntimeError("error") + return await r.sleep(duration=0) + + t2 = asyncio.create_task(wait_to_unblock(error=True)) + t3 = asyncio.create_task(wait_to_unblock()) + + await asyncio.wait([t1, t2, t3]) + assert t2.exception + r1, r3 = await asyncio.gather(t1, t3) + try: + return r1 == 0 and r3 == 1 + finally: + nonlocal i + entered_sleep.clear() + i = 0 + + async with MyServer() as s1, MyServer() as s2: + await s1.listen() + await s2.listen() + async with rpc(s2.address) as r: + assert not await r.do_work(other_addr=s1.address) + assert await r.do_work(other_addr=s1.address, ordered=True) + + +@pytest.mark.parametrize( + "use_side_channel", + [False, True], +) +@gen_test() +async def test_ordered_rpc_comm_closed(use_side_channel): + async def sleep(duration): + await asyncio.sleep(duration) + + class MyServer(Server): + def __init__(self, *args, **kwargs): + handlers = { + "sleep": sleep, + "do_work": self.do_work, + "kill": self.kill, + } + super().__init__(handlers, *args, **kwargs) + + async def kill(self): + await self.close() + + async def do_work(self, other_addr): + r = await self.ordered_rpc(other_addr, use_side_channel=use_side_channel) + t1 = asyncio.create_task(r.sleep(duration=100000)) + with contextlib.suppress(OSError): + await self.rpc(other_addr).kill() + with pytest.raises(CommClosedError): + await t1 + return True + + async with MyServer() as s1, MyServer() as s2: + await s1.listen() + await s2.listen() + async with rpc(s2.address) as r: + assert await r.do_work(other_addr=s1.address) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 2510da5b5af..7da329aafdf 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1751,39 +1751,6 @@ async def test_shutdown_on_scheduler_comm_closed(s, a): assert f"Connection to {s.address} has been closed" in logger.getvalue() -@gen_cluster(nthreads=[]) -async def test_heartbeat_comm_closed(s, monkeypatch): - with captured_logger("distributed.worker", level=logging.WARNING) as logger: - - def bad_heartbeat_worker(*args, **kwargs): - raise CommClosedError() - - async with Worker(s.address) as w: - # Trigger CommClosedError during worker heartbeat - monkeypatch.setattr(w.scheduler, "heartbeat_worker", bad_heartbeat_worker) - - await w.heartbeat() - assert w.status == Status.running - logs = logger.getvalue() - assert "Failed to communicate with scheduler during heartbeat" in logs - assert "Traceback" in logs - - -@gen_cluster(nthreads=[("", 1)], worker_kwargs={"heartbeat_interval": "100s"}) -async def test_heartbeat_missing(s, a, monkeypatch): - async def missing_heartbeat_worker(*args, **kwargs): - return {"status": "missing"} - - with captured_logger("distributed.worker", level=logging.WARNING) as wlogger: - monkeypatch.setattr(a.scheduler, "heartbeat_worker", missing_heartbeat_worker) - await a.heartbeat() - assert a.status == Status.closed - assert "Scheduler was unaware of this worker" in wlogger.getvalue() - - while s.workers: - await asyncio.sleep(0.01) - - @gen_cluster(nthreads=[("", 1)], worker_kwargs={"heartbeat_interval": "100s"}) async def test_heartbeat_missing_real_cluster(s, a): # The idea here is to create a situation where `s.workers[a.address]`, diff --git a/distributed/tests/test_worker_metrics.py b/distributed/tests/test_worker_metrics.py index e12c4902b15..b1b9a82158f 100644 --- a/distributed/tests/test_worker_metrics.py +++ b/distributed/tests/test_worker_metrics.py @@ -597,7 +597,7 @@ async def test_new_metrics_during_heartbeat(c, s, a): a.digest_metric(("execute", span.id, "x", "test", "test"), 1) await asyncio.sleep(0) await hb_task - assert n > 9 + assert n > 1 await a.heartbeat() assert a.digests_total["execute", span.id, "x", "test", "test"] == n diff --git a/distributed/worker.py b/distributed/worker.py index 55dd5a7724f..dd318dd058f 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -469,6 +469,7 @@ class Worker(BaseWorker, ServerNode): execution_state: dict[str, Any] plugins: dict[str, WorkerPlugin] _pending_plugins: tuple[WorkerPlugin, ...] + scheduler_ordered: object def __init__( self, @@ -786,6 +787,7 @@ def __init__( BaseWorker.__init__(self, state) self.scheduler = self.rpc(scheduler_addr) + self.scheduler_ordered = None self.execution_state = { "scheduler": self.scheduler.address, "ioloop": self.loop, @@ -1225,6 +1227,7 @@ async def _register_with_scheduler(self) -> None: raise ValueError(f"Unexpected response from register: {response!r}") self.batched_stream.start(comm) + self.scheduler_ordered = await self.ordered_rpc(bcomm=self.batched_stream) self.status = Status.running await asyncio.gather( @@ -1249,9 +1252,7 @@ async def heartbeat(self) -> None: logger.debug("Heartbeat: %s", self.address) try: start = time() - response = await retry_operation( - self.scheduler.heartbeat_worker, - address=self.contact_address, + response = await self.scheduler_ordered.heartbeat_worker( # type: ignore now=start, metrics=await self.get_metrics(), executing={ @@ -1286,8 +1287,6 @@ async def heartbeat(self) -> None: ) self.bandwidth_workers.clear() self.bandwidth_types.clear() - except OSError: - logger.exception("Failed to communicate with scheduler during heartbeat.") except Exception: logger.exception("Unexpected exception during heartbeat. Closing worker.") await self.close()