Skip to content

Commit 7d834d6

Browse files
GH-36709: [Python] Allow to specify use_threads=False in Table.group_by to have stable ordering (#36768)
### Rationale for this change Add a `use_threads` keyword to the `group_by` method on Table, and pass it through to the `Declaration.to_table` call. This allows specifying `use_threads=False` to get stable ordering of the output, which is also required for certain aggregations (e.g. `"first"` will fail with the default of `use_threads=True`). ### Are these changes tested? Yes, added a test (similar to the one we have for `filter`) that would fail (>50% of the time) if the output was no longer ordered. * Closes: #36709 Authored-by: Joris Van den Bossche <jorisvandenbossche@gmail.com> Signed-off-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
1 parent 02de3c1 commit 7d834d6

4 files changed

Lines changed: 46 additions & 7 deletions

File tree

python/pyarrow/acero.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,10 +299,10 @@ def _sort_source(table_or_dataset, sort_keys, output_type=Table, **kwargs):
299299
raise TypeError("Unsupported output type")
300300

301301

302-
def _group_by(table, aggregates, keys):
302+
def _group_by(table, aggregates, keys, use_threads=True):
303303

304304
decl = Declaration.from_sequence([
305305
Declaration("table_source", TableSourceNodeOptions(table)),
306306
Declaration("aggregate", AggregateNodeOptions(aggregates, keys=keys))
307307
])
308-
return decl.to_table(use_threads=True)
308+
return decl.to_table(use_threads=use_threads)

python/pyarrow/table.pxi

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4599,8 +4599,9 @@ cdef class Table(_Tabular):
45994599
"""
46004600
return self.drop_columns(columns)
46014601

4602-
def group_by(self, keys):
4603-
"""Declare a grouping over the columns of the table.
4602+
def group_by(self, keys, use_threads=True):
4603+
"""
4604+
Declare a grouping over the columns of the table.
46044605
46054606
Resulting grouping can then be used to perform aggregations
46064607
with a subsequent ``aggregate()`` method.
@@ -4609,6 +4610,9 @@ cdef class Table(_Tabular):
46094610
----------
46104611
keys : str or list[str]
46114612
Name of the columns that should be used as the grouping key.
4613+
use_threads : bool, default True
4614+
Whether to use multithreading or not. When set to True (the
4615+
default), no stable ordering of the output is guaranteed.
46124616
46134617
Returns
46144618
-------
@@ -4635,7 +4639,7 @@ cdef class Table(_Tabular):
46354639
year: [[2020,2022,2021,2019]]
46364640
n_legs_sum: [[2,6,104,5]]
46374641
"""
4638-
return TableGroupBy(self, keys)
4642+
return TableGroupBy(self, keys, use_threads=use_threads)
46394643

46404644
def join(self, right_table, keys, right_keys=None, join_type="left outer",
46414645
left_suffix=None, right_suffix=None, coalesce_keys=True,
@@ -5183,6 +5187,9 @@ class TableGroupBy:
51835187
Input table to execute the aggregation on.
51845188
keys : str or list[str]
51855189
Name of the grouped columns.
5190+
use_threads : bool, default True
5191+
Whether to use multithreading or not. When set to True (the default),
5192+
no stable ordering of the output is guaranteed.
51865193
51875194
Examples
51885195
--------
@@ -5208,12 +5215,13 @@ class TableGroupBy:
52085215
values_sum: [[3,7,5]]
52095216
"""
52105217

5211-
def __init__(self, table, keys):
5218+
def __init__(self, table, keys, use_threads=True):
52125219
if isinstance(keys, str):
52135220
keys = [keys]
52145221

52155222
self._table = table
52165223
self.keys = keys
5224+
self._use_threads = use_threads
52175225

52185226
def aggregate(self, aggregations):
52195227
"""
@@ -5328,4 +5336,6 @@ list[tuple(str, str, FunctionOptions)]
53285336
aggr_name = "_".join(target) + "_" + func_nohash
53295337
group_by_aggrs.append((target, func, opt, aggr_name))
53305338

5331-
return _pac()._group_by(self._table, group_by_aggrs, self.keys)
5339+
return _pac()._group_by(
5340+
self._table, group_by_aggrs, self.keys, use_threads=self._use_threads
5341+
)

python/pyarrow/tests/test_exec_plan.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,3 +321,17 @@ def test_join_extension_array_column():
321321
result = _perform_join(
322322
"left outer", t1, ["colB"], t3, ["colC"])
323323
assert result["colB"] == pa.chunked_array(ext_array)
324+
325+
326+
def test_group_by_ordering():
327+
# GH-36709 - preserve ordering in groupby by setting use_threads=False
328+
table1 = pa.table({'a': [1, 2, 3, 4], 'b': ['a'] * 4})
329+
table2 = pa.table({'a': [1, 2, 3, 4], 'b': ['b'] * 4})
330+
table = pa.concat_tables([table1, table2])
331+
332+
for _ in range(50):
333+
# 50 seems to consistently cause errors when order is not preserved.
334+
# If the order problem is reintroduced this test will become flaky
335+
# which is still a signal that the order is not preserved.
336+
result = table.group_by("b", use_threads=False).aggregate([])
337+
assert result["b"] == pa.chunked_array([["a"], ["b"]])

python/pyarrow/tests/test_table.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2175,6 +2175,21 @@ def sorted_by_keys(d):
21752175
}
21762176

21772177

2178+
@pytest.mark.acero
2179+
def test_table_group_by_first():
2180+
# "first" is an ordered aggregation -> requires to specify use_threads=False
2181+
table1 = pa.table({'a': [1, 2, 3, 4], 'b': ['a', 'b'] * 2})
2182+
table2 = pa.table({'a': [1, 2, 3, 4], 'b': ['b', 'a'] * 2})
2183+
table = pa.concat_tables([table1, table2])
2184+
2185+
with pytest.raises(NotImplementedError):
2186+
table.group_by("b").aggregate([("a", "first")])
2187+
2188+
result = table.group_by("b", use_threads=False).aggregate([("a", "first")])
2189+
expected = pa.table({"b": ["a", "b"], "a_first": [1, 2]})
2190+
assert result.equals(expected)
2191+
2192+
21782193
def test_table_to_recordbatchreader():
21792194
table = pa.Table.from_pydict({'x': [1, 2, 3]})
21802195
reader = table.to_reader()

0 commit comments

Comments
 (0)