Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions examples/topic/writer_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,23 @@ def send_message_without_block_if_internal_buffer_is_full(writer: ydb.TopicWrite
return False


def writer_with_buffer_limit(db: ydb.Driver, topic_path: str):
    """Demonstrate writer backpressure.

    The writer blocks inside ``write`` while its internal buffer is full and
    raises ``TopicWriterBufferFullError`` if no space frees up within
    ``buffer_wait_timeout_sec``.
    """
    limited_writer = db.topic_client.writer(
        topic_path,
        producer_id="producer-id",
        max_buffer_size_bytes=10 * 1024 * 1024,  # cap the internal buffer at 10 MB
        buffer_wait_timeout_sec=30.0,
    )
    try:
        limited_writer.write(ydb.TopicWriterMessage("data"))
    except ydb.TopicWriterBufferFullError:
        # No buffer space freed within the timeout (e.g. server slow or
        # disconnected). A real application would retry, drop, or back off.
        pass
    finally:
        limited_writer.close()


def send_messages_with_manual_seqno(writer: ydb.TopicWriter):
writer.write(ydb.TopicWriterMessage("mess")) # send text

Expand Down
16 changes: 16 additions & 0 deletions ydb/_topic_writer/topic_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ class PublicWriterSettings:
encoder_executor: Optional[concurrent.futures.Executor] = None # default shared client executor pool
encoders: Optional[typing.Mapping[PublicCodec, typing.Callable[[bytes], bytes]]] = None
update_token_interval: Union[int, float] = 3600
# Backpressure: if set, writer waits for buffer space before adding messages; raises on timeout
max_buffer_size_bytes: Optional[int] = None # None = unlimited (previous behavior)
buffer_wait_timeout_sec: float = 30.0 # used only when max_buffer_size_bytes is set

def __post_init__(self):
if self.producer_id is None:
Expand Down Expand Up @@ -218,6 +221,12 @@ def __init__(self):
super(TopicWriterStopped, self).__init__("topic writer was stopped by call close")


class TopicWriterBufferFullError(TopicWriterError):
    """Raised when write cannot proceed: buffer is full and timeout expired waiting for free space."""


def default_serializer_message_content(data: Any) -> bytes:
if data is None:
return bytes()
Expand Down Expand Up @@ -299,6 +308,13 @@ def get_message_size(msg: InternalMessage):
return _split_messages_by_size(messages, connection._DEFAULT_MAX_GRPC_MESSAGE_SIZE, get_message_size)


def internal_message_size_bytes(msg: "InternalMessage") -> int:
    """Approximate size in bytes for buffer accounting (data + metadata + overhead).

    Accounting must be stable between reservation (enqueue) and release (ack):
    ``msg.data`` can be mutated in place by compression after enqueue, so we
    prefer ``msg.uncompressed_size`` when it is available. Metadata keys are
    counted as UTF-8 bytes, not characters, so non-ASCII keys are not
    under-counted.
    """
    # Prefer a stable uncompressed size; fall back to the current data length.
    uncompressed_size = getattr(msg, "uncompressed_size", None)
    if isinstance(uncompressed_size, int) and uncompressed_size >= 0:
        data_len = uncompressed_size
    else:
        data_len = len(msg.data)

    meta_len = 0
    if msg.metadata_items:
        for key, value in msg.metadata_items.items():
            value_len = len(value.encode("utf-8")) if isinstance(value, str) else len(value)
            meta_len += len(key.encode("utf-8")) + value_len

    # 64 bytes overhead per message (seq_no, timestamps, etc.)
    return data_len + meta_len + 64


Comment on lines +313 to +317
Copy link

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

internal_message_size_bytes() currently uses len(msg.data) and len(k) for metadata keys. This can drift from the actual bytes held in the internal queues because msg.data is later mutated during compression (_encode_data_inplace), and len(k) counts characters not UTF-8 bytes. Consider accounting against a stable byte length (e.g., msg.uncompressed_size plus len(k.encode('utf-8')) + len(v)), or store the accounted size on the message at enqueue time and subtract the same value on ack.

Suggested change
data_len = len(msg.data)
meta_len = sum(len(k) + len(v) for k, v in msg.metadata_items.items()) if msg.metadata_items else 0
return data_len + meta_len + 64 # 64 bytes overhead per message (seq_no, timestamps, etc.)
# Prefer a stable uncompressed size for accounting, fall back to current data length.
uncompressed_size = getattr(msg, "uncompressed_size", None)
data_len = uncompressed_size if isinstance(uncompressed_size, int) and uncompressed_size >= 0 else len(msg.data)
if msg.metadata_items:
meta_len = 0
for k, v in msg.metadata_items.items():
key_bytes_len = len(k.encode("utf-8"))
if isinstance(v, str):
value_bytes_len = len(v.encode("utf-8"))
else:
value_bytes_len = len(v)
meta_len += key_bytes_len + value_bytes_len
else:
meta_len = 0
# 64 bytes overhead per message (seq_no, timestamps, etc.)
return data_len + meta_len + 64

Copilot uses AI. Check for mistakes.
def _split_messages_by_size(
messages: List[InternalMessage],
split_size: int,
Expand Down
34 changes: 33 additions & 1 deletion ydb/_topic_writer/topic_writer_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
InternalMessage,
TopicWriterStopped,
TopicWriterError,
TopicWriterBufferFullError,
internal_message_size_bytes,
messages_to_proto_requests,
PublicWriteResult,
PublicWriteResultTypes,
Expand Down Expand Up @@ -277,6 +279,8 @@ class WriterAsyncIOReconnector:
else:
_stop_reason: asyncio.Future
_init_info: Optional[PublicWriterInitInfo]
_buffer_bytes: int
_buffer_updated: asyncio.Event

def __init__(
self, driver: SupportedDriverType, settings: WriterSettings, tx: Optional["BaseQueryTxContext"] = None
Expand Down Expand Up @@ -317,6 +321,8 @@ def __init__(
self._messages = deque()
self._messages_future = deque()
self._new_messages = asyncio.Queue()
self._buffer_bytes = 0
self._buffer_updated = asyncio.Event()
self._stop_reason = self._loop.create_future()
connection_task = asyncio.create_task(self._connection_loop())
connection_task.set_name("connection_loop")
Expand Down Expand Up @@ -371,7 +377,6 @@ async def wait_stop(self) -> BaseException:
return stop_reason

async def write_with_ack_future(self, messages: List[PublicMessage]) -> List[asyncio.Future]:
# todo check internal buffer limit
self._check_stop()

if self._settings.auto_seqno:
Expand All @@ -380,6 +385,29 @@ async def write_with_ack_future(self, messages: List[PublicMessage]) -> List[asy
internal_messages = self._prepare_internal_messages(messages)
messages_future = [self._loop.create_future() for _ in internal_messages]

max_buf = self._settings.max_buffer_size_bytes
if max_buf is not None:
new_bytes = sum(internal_message_size_bytes(m) for m in internal_messages)
timeout_sec = self._settings.buffer_wait_timeout_sec
deadline = self._loop.time() + timeout_sec
while True:
self._buffer_updated.clear()
if self._buffer_bytes + new_bytes <= max_buf:
break
if self._loop.time() >= deadline:
raise TopicWriterBufferFullError(
"Topic writer buffer full: no free space within %.1f s (buffer=%d, need=%d, max=%d)"
% (timeout_sec, self._buffer_bytes, new_bytes, max_buf)
)
try:
await asyncio.wait_for(
self._buffer_updated.wait(),
timeout=min(0.5, max(0.01, deadline - self._loop.time())),
)
except asyncio.TimeoutError:
pass
Comment on lines +392 to +408
Copy link

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The backpressure wait loop doesn’t react to writer stop/close while waiting for _buffer_updated (it only calls _check_stop() once at the start). If the writer is stopped while a caller is blocked here, it will typically wait until buffer_wait_timeout_sec and then raise TopicWriterBufferFullError, masking the real stop reason. Consider checking _stop_reason inside the loop (or waiting on both _buffer_updated and _stop_reason) and/or setting _buffer_updated in _stop() to wake waiters immediately.

Copilot uses AI. Check for mistakes.
self._buffer_bytes += new_bytes

Comment on lines +388 to +410
Copy link

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Buffer reservation is computed before encoding (new_bytes = sum(internal_message_size_bytes(m) ...)), but buffer release on ack recomputes size from the (possibly mutated) InternalMessage. When codec selection/compression changes message.data, _buffer_bytes can drift (over- or under-count), breaking backpressure. A robust approach is to compute a single accounted size per message at enqueue time (or use a stable field like uncompressed_size) and subtract that same value on ack.

Copilot uses AI. Check for mistakes.
self._messages_future.extend(messages_future)

if self._codec is not None and self._codec == PublicCodec.RAW:
Expand Down Expand Up @@ -629,6 +657,9 @@ async def _read_loop(self, writer: "WriterAsyncIOStream"):
def _handle_receive_ack(self, ack):
current_message = self._messages.popleft()
message_future = self._messages_future.popleft()
if self._settings.max_buffer_size_bytes is not None:
self._buffer_bytes = max(0, self._buffer_bytes - internal_message_size_bytes(current_message))
self._buffer_updated.set()
if current_message.seq_no != ack.seq_no:
raise TopicWriterError(
"internal error - receive unexpected ack. Expected seqno: %s, received seqno: %s"
Expand Down Expand Up @@ -695,6 +726,7 @@ def _stop(self, reason: BaseException):

for f in self._messages_future:
f.set_exception(reason)
f.exception() # mark as retrieved so asyncio does not log "Future exception was never retrieved"
Comment on lines 728 to +729
Copy link

Copilot AI Mar 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_stop() unconditionally calls f.set_exception(reason) and then f.exception() for every future. If a caller cancels a returned ack future (or it’s already done), set_exception() can raise InvalidStateError, and f.exception() can raise CancelledError, potentially breaking shutdown/error propagation. Guard with if not f.done(): ... (and handle cancelled futures) before setting/reading exceptions.

Suggested change
f.set_exception(reason)
f.exception() # mark as retrieved so asyncio does not log "Future exception was never retrieved"
if not f.done():
try:
f.set_exception(reason)
except asyncio.InvalidStateError:
# Future might have been completed or cancelled concurrently; ignore.
pass
try:
# Mark exception as retrieved so asyncio does not log
# "Future exception was never retrieved".
f.exception()
except asyncio.CancelledError:
# It is valid for callers to cancel ack futures; ignore.
pass

Copilot uses AI. Check for mistakes.

self._state_changed.set()
logger.info("Stop topic writer %s: %s" % (self._id, reason))
Expand Down
92 changes: 92 additions & 0 deletions ydb/_topic_writer/topic_writer_asyncio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
PublicWriterInitInfo,
PublicWriteResult,
TopicWriterError,
TopicWriterBufferFullError,
)
from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec
from .._topic_common.test_helpers import StreamMock, wait_for_fast
Expand Down Expand Up @@ -515,6 +516,97 @@ async def test_write_message(self, reconnector: WriterAsyncIOReconnector, get_st

await reconnector.close(flush=False)

async def test_buffer_full_timeout_raises(self, default_driver, get_stream_writer):
    """A second write must fail with TopicWriterBufferFullError once the buffer is full."""
    writer_settings = WriterSettings(
        PublicWriterSettings(
            topic="/local/topic",
            producer_id="test-producer",
            auto_seqno=False,
            auto_created_at=False,
            codec=PublicCodec.RAW,
            max_buffer_size_bytes=100,  # first message nearly fills this
            buffer_wait_timeout_sec=0.1,  # short timeout keeps the test fast
        )
    )
    reconnector = WriterAsyncIOReconnector(default_driver, writer_settings)
    stream = get_stream_writer()

    # First message is accepted and sent, but never acked, so it stays accounted.
    await reconnector.write_with_ack_future([PublicMessage(data=b"x" * 10, seqno=1)])
    await stream.from_client.get()

    # Second message cannot fit; the buffer wait must time out and raise.
    with pytest.raises(TopicWriterBufferFullError, match="buffer full"):
        await reconnector.write_with_ack_future([PublicMessage(data=b"y" * 10, seqno=2)])

    await reconnector.close(flush=False)

async def test_buffer_freed_by_ack_allows_next_write(self, default_driver, get_stream_writer):
    """An ack that frees buffer space must unblock a subsequent write."""
    buffer_settings = WriterSettings(
        PublicWriterSettings(
            topic="/local/topic",
            producer_id="test-producer",
            auto_seqno=False,
            auto_created_at=False,
            codec=PublicCodec.RAW,
            max_buffer_size_bytes=100,
            buffer_wait_timeout_sec=5.0,  # generous: the ack should arrive well before this
        )
    )
    reconnector = WriterAsyncIOReconnector(default_driver, buffer_settings)
    stream = get_stream_writer()

    # Occupy most of the buffer with the first message.
    await reconnector.write_with_ack_future([PublicMessage(data=b"x" * 10, seqno=1)])
    await stream.from_client.get()

    # Ack the first message to free buffer space
    stream.from_server.put_nowait(self.make_default_ack_message(seq_no=1))

    # Second write must succeed once buffer is freed
    await reconnector.write_with_ack_future([PublicMessage(data=b"y" * 10, seqno=2)])

    stream.from_server.put_nowait(self.make_default_ack_message(seq_no=2))
    await reconnector.close(flush=True)

async def test_concurrent_writers_only_one_proceeds_after_ack(self, default_driver, get_stream_writer):
    """Two writes blocked on a full buffer: after one ack frees space, exactly one proceeds.

    NOTE(review): this test relies on precise event-loop scheduling (the two
    bare ``asyncio.sleep(0)`` calls) to park both tasks at their buffer-wait
    point before the ack arrives — confirm it stays stable across Python versions.
    """
    settings = WriterSettings(
        PublicWriterSettings(
            topic="/local/topic",
            producer_id="test-producer",
            auto_seqno=False,
            auto_created_at=False,
            codec=PublicCodec.RAW,
            max_buffer_size_bytes=150,
            buffer_wait_timeout_sec=5.0,
        )
    )
    reconnector = WriterAsyncIOReconnector(default_driver, settings)
    stream_writer = get_stream_writer()

    # Fill 74 bytes of the 150-byte buffer (10 data + 64 overhead)
    await reconnector.write_with_ack_future([PublicMessage(data=b"x" * 10, seqno=1)])
    await stream_writer.from_client.get()

    # Both tasks need 94 bytes (30 + 64); 76 free bytes can't fit either, so both block.
    # After ack, buffer drops to 0: exactly one fits (94 ≤ 150), the other stays blocked (94+94 > 150).
    task2 = asyncio.create_task(reconnector.write_with_ack_future([PublicMessage(data=b"y" * 30, seqno=2)]))
    task3 = asyncio.create_task(reconnector.write_with_ack_future([PublicMessage(data=b"z" * 30, seqno=3)]))

    # Let both tasks start and reach their buffer-wait await point
    await asyncio.sleep(0)
    await asyncio.sleep(0)
    assert not task2.done()
    assert not task3.done()

    # Free the buffer: ack msg1 drops _buffer_bytes from 74 → 0
    stream_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=1))

    done, pending = await asyncio.wait([task2, task3], timeout=1.0, return_when=asyncio.FIRST_COMPLETED)
    assert len(done) == 1, "exactly one write should proceed after ack"
    assert len(pending) == 1, "other write should still be waiting for buffer space"
    assert not next(iter(pending)).done()

    # Cancel the still-blocked write before closing so shutdown is clean.
    next(iter(pending)).cancel()
    await reconnector.close(flush=False)

async def test_auto_seq_no(self, default_driver, default_settings, get_stream_writer):
last_seq_no = 100
with mock.patch.object(TestWriterAsyncIOReconnector, "init_last_seqno", last_seq_no):
Expand Down
108 changes: 107 additions & 1 deletion ydb/_topic_writer/topic_writer_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
import asyncio
import threading
from typing import List
from unittest import mock

import pytest

from .topic_writer import _split_messages_by_size
from .topic_writer import (
PublicMessage,
PublicWriterSettings,
TopicWriterBufferFullError,
_split_messages_by_size,
)
from .topic_writer_asyncio import WriterAsyncIOReconnector
from .topic_writer_sync import WriterSync


@pytest.mark.parametrize(
Expand Down Expand Up @@ -48,3 +58,99 @@
def test_split_messages_by_size(messages: List[int], split_size: int, expected: List[List[int]]):
res = _split_messages_by_size(messages, split_size, lambda x: x) # noqa
assert res == expected


@pytest.fixture
def background_loop():
    """Run an asyncio event loop on a daemon thread for the duration of a test."""
    event_loop = asyncio.new_event_loop()
    started = threading.Event()

    def _spin():
        asyncio.set_event_loop(event_loop)
        event_loop.call_soon(started.set)
        event_loop.run_forever()

    worker = threading.Thread(target=_spin, daemon=True)
    worker.start()
    started.wait()  # don't yield until the loop is actually running
    yield event_loop
    # Teardown: stop the loop from its own thread, then reap the worker.
    event_loop.call_soon_threadsafe(event_loop.stop)
    worker.join(timeout=2)
    event_loop.close()


@pytest.fixture
def mock_reconnector(monkeypatch):
    """Return a factory that makes WriterAsyncIOReconnector construction yield a given stub."""

    def _install(stub):
        # Any subsequent WriterAsyncIOReconnector(...) call now returns the stub.
        monkeypatch.setattr(WriterAsyncIOReconnector, "__new__", lambda cls, *args, **kwargs: stub)
        return stub

    return _install


class TestWriterSyncBuffer:
    """Tests for how WriterSync surfaces buffer backpressure from the async reconnector."""

    def _make_writer(self, background_loop, reconnector, mock_reconnector):
        # Install the stub reconnector, then build a WriterSync driven by the background loop.
        mock_reconnector(reconnector)
        settings = PublicWriterSettings(topic="/local/topic", producer_id="test-producer")
        return WriterSync(mock.Mock(), settings, eventloop=background_loop)

    def test_buffer_full_error_propagates(self, background_loop, mock_reconnector):
        """TopicWriterBufferFullError raised in the async layer must reach the sync caller."""

        class ImmediateFullReconnector:
            async def write_with_ack_future(self, messages):
                raise TopicWriterBufferFullError("buffer full")

            async def close(self, flush):
                pass

        writer = self._make_writer(background_loop, ImmediateFullReconnector(), mock_reconnector)
        with pytest.raises(TopicWriterBufferFullError):
            writer.write(PublicMessage(data=b"hello", seqno=1))
        writer.close(flush=False)

    def test_write_blocks_until_buffer_freed(self, background_loop, mock_reconnector):
        """A sync write blocked on buffer space must complete once space is released."""
        write_started = threading.Event()

        class BlockingReconnector:
            # Class attribute acts as a "not created yet" sentinel; the real
            # asyncio.Event is created on the background loop inside write.
            _release_event = None

            async def write_with_ack_future(self, messages):
                self._release_event = asyncio.Event()
                write_started.set()
                await self._release_event.wait()
                loop = asyncio.get_running_loop()
                futures = [loop.create_future() for _ in messages]
                for f in futures:
                    f.set_result(None)
                return futures

            async def release(self):
                if self._release_event:
                    self._release_event.set()

            async def close(self, flush):
                pass

        reconnector = BlockingReconnector()
        writer = self._make_writer(background_loop, reconnector, mock_reconnector)

        # Collect any exception from the writer thread for assertion on the main thread.
        write_errors = []

        def do_write():
            try:
                writer.write(PublicMessage(data=b"hello", seqno=1))
            except Exception as e:
                write_errors.append(e)

        write_thread = threading.Thread(target=do_write, daemon=True)
        write_thread.start()

        assert write_started.wait(timeout=1.0), "write did not start"

        # Write thread is now blocked; release the mock to simulate buffer freed
        asyncio.run_coroutine_threadsafe(reconnector.release(), background_loop).result(timeout=1.0)

        write_thread.join(timeout=1.0)
        assert not write_thread.is_alive(), "write should have completed after buffer was freed"
        assert not write_errors, f"unexpected error: {write_errors}"

        writer.close(flush=False)
2 changes: 2 additions & 0 deletions ydb/topic.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"TopicWriterInitInfo",
"TopicWriterMessage",
"TopicWriterSettings",
"TopicWriterBufferFullError",
]

import concurrent.futures
Expand Down Expand Up @@ -72,6 +73,7 @@
RetryPolicy as TopicWriterRetryPolicy,
PublicWriterInitInfo as TopicWriterInitInfo,
PublicWriteResult as TopicWriteResult,
TopicWriterBufferFullError,
)

from ydb._topic_writer.topic_writer_asyncio import TxWriterAsyncIO as TopicTxWriterAsyncIO
Expand Down
Loading