From 6e607369a19ffab34ba956b11b93069fac5f1bd3 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 1 Apr 2026 11:32:39 +0300 Subject: [PATCH] Topic Writer Backpressure --- examples/topic/writer_example.py | 17 +++ ydb/_topic_writer/topic_writer.py | 16 +++ ydb/_topic_writer/topic_writer_asyncio.py | 34 +++++- .../topic_writer_asyncio_test.py | 92 +++++++++++++++ ydb/_topic_writer/topic_writer_test.py | 108 +++++++++++++++++- ydb/topic.py | 2 + 6 files changed, 267 insertions(+), 2 deletions(-) diff --git a/examples/topic/writer_example.py b/examples/topic/writer_example.py index 99346a27..8bcacb7d 100644 --- a/examples/topic/writer_example.py +++ b/examples/topic/writer_example.py @@ -71,6 +71,23 @@ def send_message_without_block_if_internal_buffer_is_full(writer: ydb.TopicWrite return False +def writer_with_buffer_limit(db: ydb.Driver, topic_path: str): + """Writer with backpressure: waits for buffer space, raises TopicWriterBufferFullError on timeout.""" + writer = db.topic_client.writer( + topic_path, + producer_id="producer-id", + max_buffer_size_bytes=10 * 1024 * 1024, # 10 MB + buffer_wait_timeout_sec=30.0, + ) + try: + writer.write(ydb.TopicWriterMessage("data")) + except ydb.TopicWriterBufferFullError: + # Buffer did not free up within timeout (e.g. server slow or disconnected) + pass # handle: retry, drop, or back off + finally: + writer.close() + + def send_messages_with_manual_seqno(writer: ydb.TopicWriter): writer.write(ydb.TopicWriterMessage("mess")) # send text diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py index 4ce63a91..ca199a7c 100644 --- a/ydb/_topic_writer/topic_writer.py +++ b/ydb/_topic_writer/topic_writer.py @@ -37,6 +37,9 @@ class PublicWriterSettings: encoder_executor: Optional[concurrent.futures.Executor] = None # default shared client executor pool encoders: Optional[typing.Mapping[PublicCodec, typing.Callable[[bytes], bytes]]] = None update_token_interval: Union[int, float] = 3600 + # Backpressure: if set, writer waits for buffer space before adding messages; raises on timeout + max_buffer_size_bytes: Optional[int] = None # None = unlimited (previous behavior) + buffer_wait_timeout_sec: float = 30.0 # used only when max_buffer_size_bytes is set def __post_init__(self): if self.producer_id is None: @@ -218,6 +221,12 @@ def __init__(self): super(TopicWriterStopped, self).__init__("topic writer was stopped by call close") +class TopicWriterBufferFullError(TopicWriterError): + """Raised when write cannot proceed: buffer is full and timeout expired waiting for free space.""" + + pass + + def default_serializer_message_content(data: Any) -> bytes: if data is None: return bytes() @@ -299,6 +308,13 @@ def get_message_size(msg: InternalMessage): return _split_messages_by_size(messages, connection._DEFAULT_MAX_GRPC_MESSAGE_SIZE, get_message_size) +def internal_message_size_bytes(msg: InternalMessage) -> int: + """Approximate size in bytes for buffer accounting (data + metadata + overhead).""" + data_len = len(msg.data) + meta_len = sum(len(k) + len(v) for k, v in msg.metadata_items.items()) if msg.metadata_items else 0 + return data_len + meta_len + 64 # 64 bytes overhead per message (seq_no, timestamps, etc.) + + def _split_messages_by_size( messages: List[InternalMessage], split_size: int, diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index b80537dc..142e8f0b 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -19,6 +19,8 @@ InternalMessage, TopicWriterStopped, TopicWriterError, + TopicWriterBufferFullError, + internal_message_size_bytes, messages_to_proto_requests, PublicWriteResult, PublicWriteResultTypes, @@ -277,6 +279,8 @@ class WriterAsyncIOReconnector: else: _stop_reason: asyncio.Future _init_info: Optional[PublicWriterInitInfo] + _buffer_bytes: int + _buffer_updated: asyncio.Event def __init__( self, driver: SupportedDriverType, settings: WriterSettings, tx: Optional["BaseQueryTxContext"] = None @@ -317,6 +321,8 @@ def __init__( self._messages = deque() self._messages_future = deque() self._new_messages = asyncio.Queue() + self._buffer_bytes = 0 + self._buffer_updated = asyncio.Event() self._stop_reason = self._loop.create_future() connection_task = asyncio.create_task(self._connection_loop()) connection_task.set_name("connection_loop") @@ -371,7 +377,6 @@ async def wait_stop(self) -> BaseException: return stop_reason async def write_with_ack_future(self, messages: List[PublicMessage]) -> List[asyncio.Future]: - # todo check internal buffer limit self._check_stop() if self._settings.auto_seqno: @@ -380,6 +385,29 @@ async def write_with_ack_future(self, messages: List[PublicMessage]) -> List[asy internal_messages = self._prepare_internal_messages(messages) messages_future = [self._loop.create_future() for _ in internal_messages] + max_buf = self._settings.max_buffer_size_bytes + if max_buf is not None: + new_bytes = sum(internal_message_size_bytes(m) for m in internal_messages) + timeout_sec = self._settings.buffer_wait_timeout_sec + deadline = self._loop.time() + timeout_sec + while True: + self._buffer_updated.clear() + if self._buffer_bytes + new_bytes <= max_buf: + break + if self._loop.time() >= deadline: + raise TopicWriterBufferFullError( + "Topic writer buffer full: no free space within %.1f s (buffer=%d, need=%d, max=%d)" + % (timeout_sec, self._buffer_bytes, new_bytes, max_buf) + ) + try: + await asyncio.wait_for( + self._buffer_updated.wait(), + timeout=min(0.5, max(0.01, deadline - self._loop.time())), + ) + except asyncio.TimeoutError: + pass + self._buffer_bytes += new_bytes + self._messages_future.extend(messages_future) if self._codec is not None and self._codec == PublicCodec.RAW: @@ -629,6 +657,9 @@ async def _read_loop(self, writer: "WriterAsyncIOStream"): def _handle_receive_ack(self, ack): current_message = self._messages.popleft() message_future = self._messages_future.popleft() + if self._settings.max_buffer_size_bytes is not None: + self._buffer_bytes = max(0, self._buffer_bytes - internal_message_size_bytes(current_message)) + self._buffer_updated.set() if current_message.seq_no != ack.seq_no: raise TopicWriterError( "internal error - receive unexpected ack. Expected seqno: %s, received seqno: %s" @@ -695,6 +726,7 @@ def _stop(self, reason: BaseException): for f in self._messages_future: f.set_exception(reason) + f.exception() # mark as retrieved so asyncio does not log "Future exception was never retrieved" self._state_changed.set() logger.info("Stop topic writer %s: %s" % (self._id, reason)) diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index a616b0b6..3298a343 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -31,6 +31,7 @@ PublicWriterInitInfo, PublicWriteResult, TopicWriterError, + TopicWriterBufferFullError, ) from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec from .._topic_common.test_helpers import StreamMock, wait_for_fast @@ -515,6 +516,97 @@ async def test_write_message(self, reconnector: WriterAsyncIOReconnector, get_st await reconnector.close(flush=False) + async def test_buffer_full_timeout_raises(self, default_driver, get_stream_writer): + settings = WriterSettings( + PublicWriterSettings( + topic="/local/topic", + producer_id="test-producer", + auto_seqno=False, + auto_created_at=False, + codec=PublicCodec.RAW, + max_buffer_size_bytes=100, + buffer_wait_timeout_sec=0.1, + ) + ) + reconnector = WriterAsyncIOReconnector(default_driver, settings) + stream_writer = get_stream_writer() + + await reconnector.write_with_ack_future([PublicMessage(data=b"x" * 10, seqno=1)]) + await stream_writer.from_client.get() + + with pytest.raises(TopicWriterBufferFullError, match="buffer full"): + await reconnector.write_with_ack_future([PublicMessage(data=b"y" * 10, seqno=2)]) + + await reconnector.close(flush=False) + + async def test_buffer_freed_by_ack_allows_next_write(self, default_driver, get_stream_writer): + settings = WriterSettings( + PublicWriterSettings( + topic="/local/topic", + producer_id="test-producer", + auto_seqno=False, + auto_created_at=False, + codec=PublicCodec.RAW, + max_buffer_size_bytes=100, + buffer_wait_timeout_sec=5.0, + ) + ) + reconnector = WriterAsyncIOReconnector(default_driver, settings) + stream_writer = get_stream_writer() + + await reconnector.write_with_ack_future([PublicMessage(data=b"x" * 10, seqno=1)]) + await stream_writer.from_client.get() + + # Ack the first message to free buffer space + stream_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=1)) + + # Second write must succeed once buffer is freed + await reconnector.write_with_ack_future([PublicMessage(data=b"y" * 10, seqno=2)]) + + stream_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=2)) + await reconnector.close(flush=True) + + async def test_concurrent_writers_only_one_proceeds_after_ack(self, default_driver, get_stream_writer): + settings = WriterSettings( + PublicWriterSettings( + topic="/local/topic", + producer_id="test-producer", + auto_seqno=False, + auto_created_at=False, + codec=PublicCodec.RAW, + max_buffer_size_bytes=150, + buffer_wait_timeout_sec=5.0, + ) + ) + reconnector = WriterAsyncIOReconnector(default_driver, settings) + stream_writer = get_stream_writer() + + # Fill 74 bytes of the 150-byte buffer (10 data + 64 overhead) + await reconnector.write_with_ack_future([PublicMessage(data=b"x" * 10, seqno=1)]) + await stream_writer.from_client.get() + + # Both tasks need 94 bytes (30 + 64); 76 free bytes can't fit either, so both block. + # After ack, buffer drops to 0: exactly one fits (94 ≤ 150), the other stays blocked (94+94 > 150). + task2 = asyncio.create_task(reconnector.write_with_ack_future([PublicMessage(data=b"y" * 30, seqno=2)])) + task3 = asyncio.create_task(reconnector.write_with_ack_future([PublicMessage(data=b"z" * 30, seqno=3)])) + + # Let both tasks start and reach their buffer-wait await point + await asyncio.sleep(0) + await asyncio.sleep(0) + assert not task2.done() + assert not task3.done() + + # Free the buffer: ack msg1 drops _buffer_bytes from 74 → 0 + stream_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=1)) + + done, pending = await asyncio.wait([task2, task3], timeout=1.0, return_when=asyncio.FIRST_COMPLETED) + assert len(done) == 1, "exactly one write should proceed after ack" + assert len(pending) == 1, "other write should still be waiting for buffer space" + assert not next(iter(pending)).done() + + next(iter(pending)).cancel() + await reconnector.close(flush=False) + async def test_auto_seq_no(self, default_driver, default_settings, get_stream_writer): last_seq_no = 100 with mock.patch.object(TestWriterAsyncIOReconnector, "init_last_seqno", last_seq_no): diff --git a/ydb/_topic_writer/topic_writer_test.py b/ydb/_topic_writer/topic_writer_test.py index 0e829255..4a324e3f 100644 --- a/ydb/_topic_writer/topic_writer_test.py +++ b/ydb/_topic_writer/topic_writer_test.py @@ -1,8 +1,18 @@ +import asyncio +import threading from typing import List +from unittest import mock import pytest -from .topic_writer import _split_messages_by_size +from .topic_writer import ( + PublicMessage, + PublicWriterSettings, + TopicWriterBufferFullError, + _split_messages_by_size, +) +from .topic_writer_asyncio import WriterAsyncIOReconnector +from .topic_writer_sync import WriterSync @pytest.mark.parametrize( @@ -48,3 +58,99 @@ def test_split_messages_by_size(messages: List[int], split_size: int, expected: List[List[int]]): res = _split_messages_by_size(messages, split_size, lambda x: x) # noqa assert res == expected + + +@pytest.fixture +def background_loop(): + loop = asyncio.new_event_loop() + ready = threading.Event() + + def run(): + asyncio.set_event_loop(loop) + loop.call_soon(ready.set) + loop.run_forever() + + t = threading.Thread(target=run, daemon=True) + t.start() + ready.wait() + yield loop + loop.call_soon_threadsafe(loop.stop) + t.join(timeout=2) + loop.close() + + +@pytest.fixture +def mock_reconnector(monkeypatch): + def factory(reconnector_instance): + monkeypatch.setattr(WriterAsyncIOReconnector, "__new__", lambda cls, *a, **kw: reconnector_instance) + return reconnector_instance + + return factory + + +class TestWriterSyncBuffer: + def _make_writer(self, background_loop, reconnector, mock_reconnector): + mock_reconnector(reconnector) + settings = PublicWriterSettings(topic="/local/topic", producer_id="test-producer") + return WriterSync(mock.Mock(), settings, eventloop=background_loop) + + def test_buffer_full_error_propagates(self, background_loop, mock_reconnector): + class ImmediateFullReconnector: + async def write_with_ack_future(self, messages): + raise TopicWriterBufferFullError("buffer full") + + async def close(self, flush): + pass + + writer = self._make_writer(background_loop, ImmediateFullReconnector(), mock_reconnector) + with pytest.raises(TopicWriterBufferFullError): + writer.write(PublicMessage(data=b"hello", seqno=1)) + writer.close(flush=False) + + def test_write_blocks_until_buffer_freed(self, background_loop, mock_reconnector): + write_started = threading.Event() + + class BlockingReconnector: + _release_event = None + + async def write_with_ack_future(self, messages): + self._release_event = asyncio.Event() + write_started.set() + await self._release_event.wait() + loop = asyncio.get_running_loop() + futures = [loop.create_future() for _ in messages] + for f in futures: + f.set_result(None) + return futures + + async def release(self): + if self._release_event: + self._release_event.set() + + async def close(self, flush): + pass + + reconnector = BlockingReconnector() + writer = self._make_writer(background_loop, reconnector, mock_reconnector) + + write_errors = [] + + def do_write(): + try: + writer.write(PublicMessage(data=b"hello", seqno=1)) + except Exception as e: + write_errors.append(e) + + write_thread = threading.Thread(target=do_write, daemon=True) + write_thread.start() + + assert write_started.wait(timeout=1.0), "write did not start" + + # Write thread is now blocked; release the mock to simulate buffer freed + asyncio.run_coroutine_threadsafe(reconnector.release(), background_loop).result(timeout=1.0) + + write_thread.join(timeout=1.0) + assert not write_thread.is_alive(), "write should have completed after buffer was freed" + assert not write_errors, f"unexpected error: {write_errors}" + + writer.close(flush=False) diff --git a/ydb/topic.py b/ydb/topic.py index 1faf4659..ad51bb7e 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -32,6 +32,7 @@ "TopicWriterInitInfo", "TopicWriterMessage", "TopicWriterSettings", + "TopicWriterBufferFullError", ] import concurrent.futures @@ -72,6 +73,7 @@ RetryPolicy as TopicWriterRetryPolicy, PublicWriterInitInfo as TopicWriterInitInfo, PublicWriteResult as TopicWriteResult, + TopicWriterBufferFullError, ) from ydb._topic_writer.topic_writer_asyncio import TxWriterAsyncIO as TopicTxWriterAsyncIO