19 changes: 17 additions & 2 deletions pyiceberg/io/pyarrow.py
@@ -1625,7 +1625,9 @@ def _table_from_scan_task(task: FileScanTask) -> pa.Table:
 
         return result
 
-    def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.RecordBatch]:
+    def to_record_batches(
+        self, tasks: Iterable[FileScanTask], concurrent_tasks: Optional[int] = None
+    ) -> Iterator[pa.RecordBatch]:
         """Scan the Iceberg table and return an Iterator[pa.RecordBatch].
 
         Returns an Iterator of pa.RecordBatch with data from the Iceberg table
@@ -1634,6 +1636,7 @@ def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.RecordBatch]:
 
         Args:
             tasks: FileScanTasks representing the data files and delete files to read from.
+            concurrent_tasks: Number of file scan tasks to read concurrently; when None (the default), tasks are read sequentially.
 
         Returns:
             An Iterator of PyArrow RecordBatches.
@@ -1643,8 +1646,20 @@ def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.RecordBatch]:
             ResolveError: When a required field cannot be found in the file
             ValueError: When a field type in the file cannot be projected to the schema type
         """
+        from concurrent.futures import ThreadPoolExecutor
+
         deletes_per_file = _read_all_delete_files(self._io, tasks)
-        return self._record_batches_from_scan_tasks_and_deletes(tasks, deletes_per_file)
+
+        if concurrent_tasks is not None:
+            with ThreadPoolExecutor(max_workers=concurrent_tasks) as pool:
+                for batches in pool.map(
+                    lambda task: list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file)), tasks
+                ):
+                    for batch in batches:
+                        yield batch
+
+        else:
+            yield from self._record_batches_from_scan_tasks_and_deletes(tasks, deletes_per_file)
 
     def _record_batches_from_scan_tasks_and_deletes(
         self, tasks: Iterable[FileScanTask], deletes_per_file: Dict[str, List[ChunkedArray]]
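For reference, a standalone sketch (not part of the PR) of the fan-out pattern the new `concurrent_tasks` branch relies on: `pool.map` yields results in submission order, so record batches still come back in plan-file order even though individual scan tasks are read concurrently. The names `read_task` and `fan_out` below are illustrative stand-ins, not PyIceberg API.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Iterable, Iterator, List


def read_task(task: int) -> List[str]:
    # Stand-in for _record_batches_from_scan_tasks_and_deletes([task], deletes_per_file).
    return [f"batch-{task}-{i}" for i in range(2)]


def fan_out(tasks: Iterable[int], concurrent_tasks: int) -> Iterator[str]:
    # pool.map preserves submission order: one list of "batches" per task, in order.
    with ThreadPoolExecutor(max_workers=concurrent_tasks) as pool:
        for batches in pool.map(read_task, tasks):
            yield from batches


print(list(fan_out(range(3), concurrent_tasks=2)))
# ['batch-0-0', 'batch-0-1', 'batch-1-0', 'batch-1-1', 'batch-2-0', 'batch-2-1']
```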
62 changes: 39 additions & 23 deletions pyiceberg/table/__init__.py
@@ -774,39 +774,55 @@ def upsert(
         matched_predicate = upsert_util.create_match_filter(df, join_cols)
 
         # We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes.
-        matched_iceberg_table = DataScan(
+        matched_iceberg_record_batches = DataScan(
             table_metadata=self.table_metadata,
             io=self._table.io,
             row_filter=matched_predicate,
             case_sensitive=case_sensitive,
-        ).to_arrow()
+        ).to_arrow_batch_reader()
 
-        update_row_cnt = 0
-        insert_row_cnt = 0
+        batches_to_overwrite = []
+        overwrite_predicates = []
+        rows_to_insert = df
 
-        if when_matched_update_all:
-            # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed
-            # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed
-            # this extra step avoids unnecessary IO and writes
-            rows_to_update = upsert_util.get_rows_to_update(df, matched_iceberg_table, join_cols)
+        for batch in matched_iceberg_record_batches:
+            rows = pa.Table.from_batches([batch])
 
-            update_row_cnt = len(rows_to_update)
+            if when_matched_update_all:
+                # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed
+                # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed
+                # this extra step avoids unnecessary IO and writes
+                rows_to_update = upsert_util.get_rows_to_update(df, rows, join_cols)
 
-            if len(rows_to_update) > 0:
-                # build the match predicate filter
-                overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols)
+                if len(rows_to_update) > 0:
+                    # build the match predicate filter
+                    overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols)
 
-                self.overwrite(rows_to_update, overwrite_filter=overwrite_mask_predicate)
+                    batches_to_overwrite.append(rows_to_update)
+                    overwrite_predicates.append(overwrite_mask_predicate)
 
-        if when_not_matched_insert_all:
-            expr_match = upsert_util.create_match_filter(matched_iceberg_table, join_cols)
-            expr_match_bound = bind(self.table_metadata.schema(), expr_match, case_sensitive=case_sensitive)
-            expr_match_arrow = expression_to_pyarrow(expr_match_bound)
-            rows_to_insert = df.filter(~expr_match_arrow)
+            if when_not_matched_insert_all:
+                expr_match = upsert_util.create_match_filter(rows, join_cols)
+                expr_match_bound = bind(self.table_metadata.schema(), expr_match, case_sensitive=case_sensitive)
+                expr_match_arrow = expression_to_pyarrow(expr_match_bound)
 
-            insert_row_cnt = len(rows_to_insert)
+                # Filter rows per batch.
+                rows_to_insert = rows_to_insert.filter(~expr_match_arrow)
 
-            if insert_row_cnt > 0:
+        update_row_cnt = 0
+        insert_row_cnt = 0
+
+        if batches_to_overwrite:
+            rows_to_update = pa.concat_tables(batches_to_overwrite)
+            update_row_cnt = len(rows_to_update)
+            self.overwrite(
+                rows_to_update,
+                overwrite_filter=Or(*overwrite_predicates) if len(overwrite_predicates) > 1 else overwrite_predicates[0],
+            )
+
+        if when_not_matched_insert_all:
+            insert_row_cnt = len(rows_to_insert)
+            if rows_to_insert:
                 self.append(rows_to_insert)
 
         return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt)
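The per-batch predicates are later collapsed into one overwrite filter. A minimal sketch of that combination step, using illustrative `EqualTo` predicates in place of the ones produced by `upsert_util.create_match_filter`:

```python
from pyiceberg.expressions import EqualTo, Or

# Illustrative per-batch match predicates; the real ones come from create_match_filter.
overwrite_predicates = [EqualTo("id", 1), EqualTo("id", 2), EqualTo("id", 3)]

# Same expression the new code builds: a single filter covering every matched batch.
combined = Or(*overwrite_predicates) if len(overwrite_predicates) > 1 else overwrite_predicates[0]
print(combined)
```

Collapsing the predicates this way keeps the streaming read (one record batch at a time) while still issuing a single overwrite call, as in the original implementation.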
@@ -1848,7 +1864,7 @@ def to_arrow(self) -> pa.Table:
             self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
         ).to_table(self.plan_files())
 
-    def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
+    def to_arrow_batch_reader(self, concurrent_tasks: Optional[int] = None) -> pa.RecordBatchReader:
         """Return an Arrow RecordBatchReader from this DataScan.
 
         For large results, using a RecordBatchReader requires less memory than
@@ -1866,7 +1882,7 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
         target_schema = schema_to_pyarrow(self.projection())
         batches = ArrowScan(
             self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
-        ).to_record_batches(self.plan_files())
+        ).to_record_batches(self.plan_files(), concurrent_tasks=concurrent_tasks)
 
         return pa.RecordBatchReader.from_batches(
             target_schema,
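Finally, a usage sketch of the public surface these changes touch. The catalog name "default", the table identifier "db.events", the column names, and `concurrent_tasks=8` are placeholders, not values taken from this PR.

```python
import pyarrow as pa

from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")       # placeholder catalog name
tbl = catalog.load_table("db.events")   # placeholder table identifier

# Stream a scan as record batches, reading up to 8 file scan tasks concurrently.
reader = tbl.scan().to_arrow_batch_reader(concurrent_tasks=8)
for batch in reader:
    ...  # per-batch processing keeps peak memory lower than to_arrow()

# Upsert a small source table: matched rows are overwritten, unmatched rows appended
# (assumes the target table has matching id/value columns).
df = pa.table({"id": [1, 2, 3], "value": ["a", "b", "c"]})
result = tbl.upsert(df, join_cols=["id"])
print(result.rows_updated, result.rows_inserted)
```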