
Commit bb41e19

xi-db authored and hvanhovell committed
[SPARK-53525][CONNECT] Spark Connect ArrowBatch Result Chunking
### What changes were proposed in this pull request?

Currently, we enforce gRPC message limits on both the client and the server. These limits are largely meant to protect both sides from potential OOMs by rejecting abnormally large messages. However, there are cases in which the server incorrectly sends oversized messages that exceed these limits and cause execution failures. Specifically, the server-to-client issue addressed here comes from the Arrow batch data in an ExecutePlanResponse being too large: a single Arrow row exceeds the 128MB message limit, Arrow cannot partition the batch any further, and the server has to return the single large row in one gRPC message.

To improve Spark Connect stability, this PR chunks large Arrow batches when returning query results from the server to the client, so that each ExecutePlanResponse chunk remains within the size limit; the client reassembles the chunks of a batch before parsing them as an Arrow batch. (Scala client changes are being implemented in a follow-up PR.)

To reproduce the existing issue, run this code on Spark Connect:

```python
repeat_num_per_mb = 1024 * 1024 // len('Apache Spark ')
res = spark.sql(
    f"select repeat('Apache Spark ', {repeat_num_per_mb * 300}) as huge_col from range(1)"
).collect()
print(len(res))
```

It fails with a `StatusCode.RESOURCE_EXHAUSTED` error and the message `Received message larger than max (314570608 vs. 134217728)`, because the server tries to send an ExecutePlanResponse of ~300MB to the client. With the improvement introduced by this PR, the code above runs successfully and prints the expected result.

### Why are the changes needed?

It improves Spark Connect stability when returning large rows.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New tests on both the server side and the client side.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #52271 from xi-db/arrow-batch-chunking.

Authored-by: Xi Lyu <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
1 parent a8f56d4 commit bb41e19
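
How the two new client-side options are threaded end to end is shown by the diffs below. As a quick orientation, they can be exercised roughly as follows — a minimal sketch assuming direct construction of `SparkConnectClient`; the endpoint URL and the 4MB hint are illustrative values, not part of this commit:

```python
from pyspark.sql.connect.client import SparkConnectClient

# allow_arrow_batch_chunking: let the server split oversized Arrow batches into chunks.
# preferred_arrow_chunk_size: hint for the chunk size in bytes; the server only honors
# it within [1KB, max batch size on server].
client = SparkConnectClient(
    connection="sc://localhost:15002",  # placeholder endpoint
    allow_arrow_batch_chunking=True,
    preferred_arrow_chunk_size=4 * 1024 * 1024,
)
```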

9 files changed: 644 additions & 147 deletions

python/pyspark/sql/connect/client/core.py

Lines changed: 81 additions & 19 deletions
```diff
@@ -607,6 +607,8 @@ def __init__(
         retry_policy: Optional[Dict[str, Any]] = None,
         use_reattachable_execute: bool = True,
         session_hooks: Optional[list["SparkSession.Hook"]] = None,
+        allow_arrow_batch_chunking: bool = True,
+        preferred_arrow_chunk_size: Optional[int] = None,
     ):
         """
         Creates a new SparkSession for the Spark Connect interface.
@@ -639,6 +641,21 @@ def __init__(
             Enable reattachable execution.
         session_hooks: list[SparkSession.Hook], optional
             List of session hooks to call.
+        allow_arrow_batch_chunking: bool
+            Whether to allow the server to split large Arrow batches into smaller chunks.
+            Although Arrow results are split into batches with a size limit according to estimation,
+            the size of the batches is not guaranteed to be less than the limit, especially when a
+            single row is larger than the limit, in which case the server will fail to split it
+            further into smaller batches. As a result, the client may encounter a gRPC error stating
+            "Received message larger than max" when a batch is too large.
+            If true, the server will split large Arrow batches into smaller chunks, and the client
+            is expected to handle the chunked Arrow batches.
+            If false, the server will not chunk large Arrow batches.
+        preferred_arrow_chunk_size: Optional[int]
+            Optional preferred Arrow batch size in bytes for the server to use when sending Arrow
+            results.
+            The server will attempt to use this size if it is set and within the valid range
+            ([1KB, max batch size on server]). Otherwise, the server's maximum batch size is used.
         """
         self.thread_local = threading.local()
@@ -678,6 +695,8 @@ def __init__(
             self._user_id, self._session_id, self._channel, self._builder.metadata()
         )
         self._use_reattachable_execute = use_reattachable_execute
+        self._allow_arrow_batch_chunking = allow_arrow_batch_chunking
+        self._preferred_arrow_chunk_size = preferred_arrow_chunk_size
         self._session_hooks = session_hooks or []
         # Configure logging for the SparkConnect client.

@@ -1235,6 +1254,15 @@ def _execute_plan_request_with_metadata(
             req.client_observed_server_side_session_id = self._server_session_id
         if self._user_id:
             req.user_context.user_id = self._user_id
+        # Add request option to allow result chunking.
+        req.request_options.append(
+            pb2.ExecutePlanRequest.RequestOption(
+                result_chunking_options=pb2.ResultChunkingOptions(
+                    allow_arrow_batch_chunking=self._allow_arrow_batch_chunking,
+                    preferred_arrow_chunk_size=self._preferred_arrow_chunk_size,
+                )
+            )
+        )
         if operation_id is not None:
             try:
                 uuid.UUID(operation_id, version=4)
@@ -1408,6 +1436,7 @@ def _execute_and_fetch_as_iterator(
             req = hook.on_execute_plan(req)

         num_records = 0
+        arrow_batch_chunks_to_assemble: List[bytes] = []

         def handle_response(
             b: pb2.ExecutePlanResponse,
@@ -1495,32 +1524,65 @@ def handle_response(
             if b.HasField("arrow_batch"):
                 logger.debug(
                     f"Received arrow batch rows={b.arrow_batch.row_count} "
+                    f"Number of chunks in batch={b.arrow_batch.num_chunks_in_batch} "
+                    f"Chunk index={b.arrow_batch.chunk_index} "
                     f"size={len(b.arrow_batch.data)}"
                 )

+                if arrow_batch_chunks_to_assemble:
+                    # Expect next chunk of the same batch
+                    if b.arrow_batch.chunk_index != len(arrow_batch_chunks_to_assemble):
+                        raise SparkConnectException(
+                            f"Expected chunk index {len(arrow_batch_chunks_to_assemble)} of the "
+                            f"arrow batch but got {b.arrow_batch.chunk_index}."
+                        )
+                else:
+                    # Expect next batch
+                    if (
+                        b.arrow_batch.HasField("start_offset")
+                        and num_records != b.arrow_batch.start_offset
+                    ):
+                        # Expect next batch
+                        raise SparkConnectException(
+                            f"Expected arrow batch to start at row offset {num_records} in "
+                            + "results, but received arrow batch starting at offset "
+                            + f"{b.arrow_batch.start_offset}."
+                        )
+                    if b.arrow_batch.chunk_index != 0:
+                        raise SparkConnectException(
+                            f"Expected chunk index 0 of the next arrow batch "
+                            f"but got {b.arrow_batch.chunk_index}."
+                        )
+
+                arrow_batch_chunks_to_assemble.append(b.arrow_batch.data)
+                # Assemble the chunks to an arrow batch to process if
+                # (a) chunking is not enabled (num_chunks_in_batch is not set or is 0,
+                # in this case, it is the single chunk in the batch)
+                # (b) or the client has received all chunks of the batch.
                 if (
-                    b.arrow_batch.HasField("start_offset")
-                    and num_records != b.arrow_batch.start_offset
+                    not b.arrow_batch.HasField("num_chunks_in_batch")
+                    or b.arrow_batch.num_chunks_in_batch == 0
+                    or len(arrow_batch_chunks_to_assemble) == b.arrow_batch.num_chunks_in_batch
                 ):
-                    raise SparkConnectException(
-                        f"Expected arrow batch to start at row offset {num_records} in results, "
-                        + "but received arrow batch starting at offset "
-                        + f"{b.arrow_batch.start_offset}."
+                    arrow_batch_data = b"".join(arrow_batch_chunks_to_assemble)
+                    arrow_batch_chunks_to_assemble.clear()
+                    logger.debug(
+                        f"Assembling arrow batch of size {len(arrow_batch_data)} from "
+                        f"{b.arrow_batch.num_chunks_in_batch} chunks."
                     )

-                num_records_in_batch = 0
-                with pa.ipc.open_stream(b.arrow_batch.data) as reader:
-                    for batch in reader:
-                        assert isinstance(batch, pa.RecordBatch)
-                        num_records_in_batch += batch.num_rows
-                        yield batch
-
-                if num_records_in_batch != b.arrow_batch.row_count:
-                    raise SparkConnectException(
-                        f"Expected {b.arrow_batch.row_count} rows in arrow batch but got "
-                        + f"{num_records_in_batch}."
-                    )
-                num_records += num_records_in_batch
+                    num_records_in_batch = 0
+                    with pa.ipc.open_stream(arrow_batch_data) as reader:
+                        for batch in reader:
+                            assert isinstance(batch, pa.RecordBatch)
+                            num_records_in_batch += batch.num_rows
+                    if num_records_in_batch != b.arrow_batch.row_count:
+                        raise SparkConnectException(
+                            f"Expected {b.arrow_batch.row_count} rows in arrow batch but "
+                            + f"got {num_records_in_batch}."
+                        )
+                    num_records += num_records_in_batch
+                    yield batch
             if b.HasField("create_resource_profile_command_result"):
                 profile_id = b.create_resource_profile_command_result.profile_id
                 yield {"create_resource_profile_command_result": profile_id}
```
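
The reassembly contract implemented in `handle_response` above can be distilled into a standalone sketch. `Chunk` and `reassemble` below are hypothetical helpers for illustration only; the field names mirror the `ArrowBatch` proto fields used in the diff:

```python
from dataclasses import dataclass
from typing import Iterable, Iterator, List, Optional


@dataclass
class Chunk:
    # Hypothetical stand-in for ExecutePlanResponse.ArrowBatch.
    data: bytes
    chunk_index: int
    num_chunks_in_batch: Optional[int]  # None/0 when chunking is disabled


def reassemble(chunks: Iterable[Chunk]) -> Iterator[bytes]:
    """Yield one fully assembled Arrow IPC payload per batch."""
    pending: List[bytes] = []
    for c in chunks:
        # Chunks of a batch must arrive in order, starting at index 0.
        if c.chunk_index != len(pending):
            raise ValueError(f"Expected chunk index {len(pending)}, got {c.chunk_index}")
        pending.append(c.data)
        # A batch is complete when chunking is off (single chunk) or all chunks arrived.
        if not c.num_chunks_in_batch or len(pending) == c.num_chunks_in_batch:
            yield b"".join(pending)
            pending.clear()


# A batch split into three chunks reassembles into a single payload.
parts = [Chunk(b"a", 0, 3), Chunk(b"b", 1, 3), Chunk(b"c", 2, 3)]
assert list(reassemble(parts)) == [b"abc"]
```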

python/pyspark/sql/connect/proto/base_pb2.py

Lines changed: 113 additions & 111 deletions
Large diffs are not rendered by default.

python/pyspark/sql/connect/proto/base_pb2.pyi

Lines changed: 109 additions & 2 deletions
```diff
@@ -1093,16 +1093,20 @@ class ExecutePlanRequest(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor

         REATTACH_OPTIONS_FIELD_NUMBER: builtins.int
+        RESULT_CHUNKING_OPTIONS_FIELD_NUMBER: builtins.int
         EXTENSION_FIELD_NUMBER: builtins.int
         @property
         def reattach_options(self) -> global___ReattachOptions: ...
         @property
+        def result_chunking_options(self) -> global___ResultChunkingOptions: ...
+        @property
         def extension(self) -> google.protobuf.any_pb2.Any:
             """Extension type for request options"""
         def __init__(
             self,
             *,
             reattach_options: global___ReattachOptions | None = ...,
+            result_chunking_options: global___ResultChunkingOptions | None = ...,
             extension: google.protobuf.any_pb2.Any | None = ...,
         ) -> None: ...
         def HasField(
@@ -1114,6 +1118,8 @@ class ExecutePlanRequest(google.protobuf.message.Message):
                 b"reattach_options",
                 "request_option",
                 b"request_option",
+                "result_chunking_options",
+                b"result_chunking_options",
             ],
         ) -> builtins.bool: ...
         def ClearField(
@@ -1125,11 +1131,16 @@ class ExecutePlanRequest(google.protobuf.message.Message):
                 b"reattach_options",
                 "request_option",
                 b"request_option",
+                "result_chunking_options",
+                b"result_chunking_options",
             ],
         ) -> None: ...
         def WhichOneof(
             self, oneof_group: typing_extensions.Literal["request_option", b"request_option"]
-        ) -> typing_extensions.Literal["reattach_options", "extension"] | None: ...
+        ) -> (
+            typing_extensions.Literal["reattach_options", "result_chunking_options", "extension"]
+            | None
+        ): ...

     SESSION_ID_FIELD_NUMBER: builtins.int
     CLIENT_OBSERVED_SERVER_SIDE_SESSION_ID_FIELD_NUMBER: builtins.int
@@ -1308,38 +1319,78 @@ class ExecutePlanResponse(google.protobuf.message.Message):
         ROW_COUNT_FIELD_NUMBER: builtins.int
         DATA_FIELD_NUMBER: builtins.int
         START_OFFSET_FIELD_NUMBER: builtins.int
+        CHUNK_INDEX_FIELD_NUMBER: builtins.int
+        NUM_CHUNKS_IN_BATCH_FIELD_NUMBER: builtins.int
         row_count: builtins.int
         """Count rows in `data`. Must match the number of rows inside `data`."""
         data: builtins.bytes
         """Serialized Arrow data."""
         start_offset: builtins.int
         """If set, row offset of the start of this ArrowBatch in execution results."""
+        chunk_index: builtins.int
+        """Index of this chunk in the batch if chunking is enabled. The index starts from 0."""
+        num_chunks_in_batch: builtins.int
+        """Total number of chunks in this batch if chunking is enabled.
+        It is missing when chunking is disabled - the batch is returned whole
+        and client will treat this response as the batch.
+        """
         def __init__(
             self,
             *,
             row_count: builtins.int = ...,
             data: builtins.bytes = ...,
             start_offset: builtins.int | None = ...,
+            chunk_index: builtins.int | None = ...,
+            num_chunks_in_batch: builtins.int | None = ...,
         ) -> None: ...
         def HasField(
             self,
             field_name: typing_extensions.Literal[
-                "_start_offset", b"_start_offset", "start_offset", b"start_offset"
+                "_chunk_index",
+                b"_chunk_index",
+                "_num_chunks_in_batch",
+                b"_num_chunks_in_batch",
+                "_start_offset",
+                b"_start_offset",
+                "chunk_index",
+                b"chunk_index",
+                "num_chunks_in_batch",
+                b"num_chunks_in_batch",
+                "start_offset",
+                b"start_offset",
             ],
         ) -> builtins.bool: ...
         def ClearField(
             self,
             field_name: typing_extensions.Literal[
+                "_chunk_index",
+                b"_chunk_index",
+                "_num_chunks_in_batch",
+                b"_num_chunks_in_batch",
                 "_start_offset",
                 b"_start_offset",
+                "chunk_index",
+                b"chunk_index",
                 "data",
                 b"data",
+                "num_chunks_in_batch",
+                b"num_chunks_in_batch",
                 "row_count",
                 b"row_count",
                 "start_offset",
                 b"start_offset",
             ],
         ) -> None: ...
+        @typing.overload
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_chunk_index", b"_chunk_index"]
+        ) -> typing_extensions.Literal["chunk_index"] | None: ...
+        @typing.overload
+        def WhichOneof(
+            self,
+            oneof_group: typing_extensions.Literal["_num_chunks_in_batch", b"_num_chunks_in_batch"],
+        ) -> typing_extensions.Literal["num_chunks_in_batch"] | None: ...
+        @typing.overload
         def WhichOneof(
             self, oneof_group: typing_extensions.Literal["_start_offset", b"_start_offset"]
         ) -> typing_extensions.Literal["start_offset"] | None: ...
@@ -2942,6 +2993,62 @@ class ReattachOptions(google.protobuf.message.Message):

 global___ReattachOptions = ReattachOptions

+class ResultChunkingOptions(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    ALLOW_ARROW_BATCH_CHUNKING_FIELD_NUMBER: builtins.int
+    PREFERRED_ARROW_CHUNK_SIZE_FIELD_NUMBER: builtins.int
+    allow_arrow_batch_chunking: builtins.bool
+    """Although Arrow results are split into batches with a size limit according to estimation, the
+    size of the batches is not guaranteed to be less than the limit, especially when a single row
+    is larger than the limit, in which case the server will fail to split it further into smaller
+    batches. As a result, the client may encounter a gRPC error stating "Received message larger
+    than max" when a batch is too large.
+    If allow_arrow_batch_chunking=true, the server will split large Arrow batches into smaller chunks,
+    and the client is expected to handle the chunked Arrow batches.
+
+    If false, the server will not chunk large Arrow batches.
+    """
+    preferred_arrow_chunk_size: builtins.int
+    """Optional preferred Arrow batch size in bytes for the server to use when sending Arrow results.
+    The server will attempt to use this size if it is set and within the valid range
+    ([1KB, max batch size on server]). Otherwise, the server's maximum batch size is used.
+    """
+    def __init__(
+        self,
+        *,
+        allow_arrow_batch_chunking: builtins.bool = ...,
+        preferred_arrow_chunk_size: builtins.int | None = ...,
+    ) -> None: ...
+    def HasField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_preferred_arrow_chunk_size",
+            b"_preferred_arrow_chunk_size",
+            "preferred_arrow_chunk_size",
+            b"preferred_arrow_chunk_size",
+        ],
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_preferred_arrow_chunk_size",
+            b"_preferred_arrow_chunk_size",
+            "allow_arrow_batch_chunking",
+            b"allow_arrow_batch_chunking",
+            "preferred_arrow_chunk_size",
+            b"preferred_arrow_chunk_size",
+        ],
+    ) -> None: ...
+    def WhichOneof(
+        self,
+        oneof_group: typing_extensions.Literal[
+            "_preferred_arrow_chunk_size", b"_preferred_arrow_chunk_size"
+        ],
+    ) -> typing_extensions.Literal["preferred_arrow_chunk_size"] | None: ...
+
+global___ResultChunkingOptions = ResultChunkingOptions
+
 class ReattachExecuteRequest(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
```
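Tying the client change and the generated stubs together, a request that opts into chunking is built with the generated types roughly as below — a sketch mirroring the `_execute_plan_request_with_metadata` hunk above; the 4MB value is an illustrative placeholder:

```python
import pyspark.sql.connect.proto.base_pb2 as pb2

req = pb2.ExecutePlanRequest()
req.request_options.append(
    pb2.ExecutePlanRequest.RequestOption(
        result_chunking_options=pb2.ResultChunkingOptions(
            allow_arrow_batch_chunking=True,
            preferred_arrow_chunk_size=4 * 1024 * 1024,  # illustrative 4MB hint
        )
    )
)

# On the response side, HasField("num_chunks_in_batch") distinguishes a chunked
# batch (chunks 0..num_chunks_in_batch-1 must be reassembled) from a whole batch.
resp = pb2.ExecutePlanResponse()
print(resp.arrow_batch.HasField("num_chunks_in_batch"))  # False: not a chunked batch
```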