Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion dask_expr/_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ def __hash__(self):
return hash(self._name)

def __reduce__(self):
    # Pickle support: an expression round-trips as (class, operands).
    # The "dask-expr-no-serialize" config flag turns serialization into a
    # hard error — used by tests to assert that no Expr instance ever ends
    # up inside a materialized task graph.
    if dask.config.get("dask-expr-no-serialize", False):
        raise RuntimeError(f"Serializing a {type(self)} object")
    return type(self), tuple(self.operands)

def _depth(self):
Expand Down Expand Up @@ -1349,7 +1351,7 @@ def _task(self, index: int):
args = [self._blockwise_arg(self.frame, index)] + [
self.state_data[index],
self.frac,
self.replace,
self.operand("replace"),
]
return (self.operation,) + tuple(args)

Expand Down
3 changes: 2 additions & 1 deletion dask_expr/_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,8 @@ class Var(GroupByReduction):
reduction_aggregate = _var_agg
reduction_combine = _var_combine

def chunk(self, frame, by, **kwargs):
@staticmethod
def chunk(frame, by, **kwargs):
    # Per-partition chunk step of the grouped-variance reduction.
    # `by` may be a single grouping key (anything with a `dtype`, e.g. a
    # Series) or already a list of keys; normalize to a list so the keys
    # can be splatted into `_var_chunk`.
    if hasattr(by, "dtype"):
        by = [by]
    return _var_chunk(frame, *by, **kwargs)
Expand Down
14 changes: 10 additions & 4 deletions dask_expr/_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,8 @@ def _tree_repr_lines(self, indent=0, recursive=True):

class Unique(ApplyConcatApply):
_parameters = ["frame"]
chunk = staticmethod(lambda x, **kwargs: methods.unique(x, **kwargs))
aggregate_func = methods.unique
chunk = staticmethod(methods.unique)
aggregate_func = staticmethod(methods.unique)

@functools.cached_property
def _meta(self):
Expand Down Expand Up @@ -602,9 +602,12 @@ def _simplify_up(self, parent):


class Size(Reduction):
reduction_chunk = staticmethod(lambda df: df.size)
reduction_aggregate = sum

@staticmethod
def reduction_chunk(df):
    # Element count of one partition via the backend frame's `.size`
    # attribute (staticmethod so it pickles as a plain function reference).
    return df.size

def _simplify_down(self):
if is_dataframe_like(self.frame._meta) and len(self.frame.columns) > 1:
return len(self.frame.columns) * Len(self.frame)
Expand All @@ -617,10 +620,13 @@ def _simplify_up(self, parent):

class NBytes(Reduction):
# Only supported for Series objects
reduction_chunk = lambda ser: ser.nbytes
reduction_aggregate = sum
_required_attribute = "nbytes"

@staticmethod
def reduction_chunk(ser):
    # Memory footprint of one partition's values via `.nbytes`
    # (Series-only, per the class comment above this block).
    return ser.nbytes


class Var(Reduction):
# Uses the parallel version of Welford's online algorithm (Chan 79')
Expand Down
23 changes: 23 additions & 0 deletions dask_expr/tests/_util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import importlib
import pickle

import pytest
from dask import config
from dask.dataframe.utils import assert_eq as dd_assert_eq


def _backend_name() -> str:
Expand All @@ -16,3 +18,24 @@ def xfail_gpu(reason=None):
condition = _backend_name() == "cudf"
reason = reason or "Failure expected for cudf backend."
return pytest.mark.xfail(condition, reason=reason)


def assert_eq(a, b, *args, serialize_graph=True, **kwargs):
    """Compare two (possibly lazy) frames for equality.

    Wraps ``dask.dataframe.assert_eq`` with an extra pre-check: when
    ``serialize_graph`` is True, the task graph of each lazy argument is
    pickled under the ``dask-expr-no-serialize`` flag, so any ``Expr``
    instance hiding in the graph raises instead of silently serializing.
    """
    if serialize_graph:
        # Check that no `Expr` instances are found in
        # the graph generated by `Expr.dask`
        with config.set({"dask-expr-no-serialize": True}):
            for obj in [a, b]:
                # Only lazy collections expose a `.dask` graph; plain
                # pandas objects are skipped.
                if hasattr(obj, "dask"):
                    try:
                        pickle.dumps(obj.dask)
                    except AttributeError:
                        # Plain pickle can fail with AttributeError on
                        # graph callables (presumably lambdas/local
                        # functions — TODO confirm); retry with the more
                        # capable cloudpickle.
                        try:
                            import cloudpickle as cp

                            cp.dumps(obj.dask)
                        except ImportError:
                            # cloudpickle not installed: best-effort check
                            # is skipped rather than failing the test.
                            pass

    # Use `dask.dataframe.assert_eq`
    return dd_assert_eq(a, b, *args, **kwargs)
3 changes: 1 addition & 2 deletions dask_expr/tests/test_categorical.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import pytest
from dask.dataframe import assert_eq

from dask_expr import from_pandas
from dask_expr.tests._util import _backend_library
from dask_expr.tests._util import _backend_library, assert_eq

# Set DataFrame backend for this module
lib = _backend_library()
Expand Down
4 changes: 2 additions & 2 deletions dask_expr/tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
import numpy as np
import pytest
from dask.dataframe._compat import PANDAS_GE_200, PANDAS_GE_210
from dask.dataframe.utils import UNKNOWN_CATEGORIES, assert_eq
from dask.dataframe.utils import UNKNOWN_CATEGORIES
from dask.utils import M

from dask_expr import expr, from_pandas, is_scalar, optimize
from dask_expr._expr import are_co_aligned
from dask_expr._reductions import Len
from dask_expr.datasets import timeseries
from dask_expr.tests._util import _backend_library, xfail_gpu
from dask_expr.tests._util import _backend_library, assert_eq, xfail_gpu

# Set DataFrame backend for this module
lib = _backend_library()
Expand Down
3 changes: 1 addition & 2 deletions dask_expr/tests/test_concat.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import numpy as np
import pytest
from dask.dataframe import assert_eq

from dask_expr import Len, concat, from_pandas
from dask_expr.tests._util import _backend_library
from dask_expr.tests._util import _backend_library, assert_eq

# Set DataFrame backend for this module
lib = _backend_library()
Expand Down
2 changes: 1 addition & 1 deletion dask_expr/tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

import pytest
from dask.dataframe._compat import PANDAS_GE_200
from dask.dataframe.utils import assert_eq

from dask_expr import new_collection
from dask_expr._expr import Lengths
from dask_expr.datasets import Timeseries, timeseries
from dask_expr.tests._util import assert_eq


def test_timeseries():
Expand Down
3 changes: 1 addition & 2 deletions dask_expr/tests/test_datetime.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import pytest
from dask.dataframe import assert_eq

from dask_expr._collection import from_pandas
from dask_expr.tests._util import _backend_library
from dask_expr.tests._util import _backend_library, assert_eq

lib = _backend_library()

Expand Down
3 changes: 1 addition & 2 deletions dask_expr/tests/test_fusion.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import pytest
from dask.dataframe.utils import assert_eq

from dask_expr import from_pandas, optimize
from dask_expr.tests._util import _backend_library
from dask_expr.tests._util import _backend_library, assert_eq

# Set DataFrame backend for this module
lib = _backend_library()
Expand Down
3 changes: 1 addition & 2 deletions dask_expr/tests/test_groupby.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import pytest
from dask.dataframe.utils import assert_eq

from dask_expr import from_pandas
from dask_expr._reductions import TreeReduce
from dask_expr.tests._util import _backend_library, xfail_gpu
from dask_expr.tests._util import _backend_library, assert_eq, xfail_gpu

# Set DataFrame backend for this module
lib = _backend_library()
Expand Down
3 changes: 1 addition & 2 deletions dask_expr/tests/test_merge.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import pytest
from dask.dataframe.utils import assert_eq

from dask_expr import Merge, from_pandas
from dask_expr._expr import Projection
from dask_expr._shuffle import Shuffle
from dask_expr.tests._util import _backend_library
from dask_expr.tests._util import _backend_library, assert_eq

# Set DataFrame backend for this module
lib = _backend_library()
Expand Down
4 changes: 1 addition & 3 deletions dask_expr/tests/test_quantiles.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from dask.dataframe import assert_eq

from dask_expr import from_pandas
from dask_expr.tests._util import _backend_library
from dask_expr.tests._util import _backend_library, assert_eq

# Set DataFrame backend for this module
lib = _backend_library()
Expand Down
3 changes: 1 addition & 2 deletions dask_expr/tests/test_repartition.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import pytest
from dask.dataframe import assert_eq

from dask_expr import from_pandas
from dask_expr.tests._util import _backend_library
from dask_expr.tests._util import _backend_library, assert_eq

lib = _backend_library()

Expand Down
3 changes: 1 addition & 2 deletions dask_expr/tests/test_reshape.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import pytest
from dask.dataframe import assert_eq

from dask_expr import from_pandas
from dask_expr.tests._util import _backend_library
from dask_expr.tests._util import _backend_library, assert_eq

# Set DataFrame backend for this module
lib = _backend_library()
Expand Down
3 changes: 1 addition & 2 deletions dask_expr/tests/test_shuffle.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from collections import OrderedDict

import pytest
from dask.dataframe.utils import assert_eq

from dask_expr import SetIndexBlockwise, from_pandas
from dask_expr._expr import Blockwise
from dask_expr._repartition import RepartitionToFewer
from dask_expr._shuffle import divisions_lru
from dask_expr.io import FromPandas
from dask_expr.tests._util import _backend_library
from dask_expr.tests._util import _backend_library, assert_eq

# Set DataFrame backend for this module
lib = _backend_library()
Expand Down
3 changes: 1 addition & 2 deletions dask_expr/tests/test_string_accessor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import pytest
from dask.dataframe import assert_eq
from dask.dataframe._compat import PANDAS_GE_200

from dask_expr._collection import from_pandas
from dask_expr.tests._util import _backend_library
from dask_expr.tests._util import _backend_library, assert_eq

lib = _backend_library()

Expand Down