From fb00d4e599620d06b60b3c89eb200517e5417cf8 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Thu, 12 Oct 2023 10:34:44 -0500 Subject: [PATCH 1/4] add custom check to assert_eq --- dask_expr/_expr.py | 2 ++ dask_expr/_groupby.py | 3 ++- dask_expr/tests/_util.py | 20 ++++++++++++++++++++ dask_expr/tests/test_categorical.py | 3 +-- dask_expr/tests/test_collection.py | 4 ++-- dask_expr/tests/test_concat.py | 3 +-- dask_expr/tests/test_datasets.py | 2 +- dask_expr/tests/test_datetime.py | 3 +-- dask_expr/tests/test_fusion.py | 3 +-- dask_expr/tests/test_groupby.py | 3 +-- dask_expr/tests/test_merge.py | 3 +-- dask_expr/tests/test_quantiles.py | 4 +--- dask_expr/tests/test_repartition.py | 3 +-- dask_expr/tests/test_reshape.py | 3 +-- dask_expr/tests/test_shuffle.py | 3 +-- dask_expr/tests/test_string_accessor.py | 3 +-- 16 files changed, 38 insertions(+), 27 deletions(-) diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index 47bba4182..ef0dd1246 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -136,6 +136,8 @@ def __hash__(self): return hash(self._name) def __reduce__(self): + if dask.config.get("dask-expr-no-serialize", False): + raise RuntimeError(f"Serializing a {type(self)} object") return type(self), tuple(self.operands) def _depth(self): diff --git a/dask_expr/_groupby.py b/dask_expr/_groupby.py index c7dbd3259..7bcc7a393 100644 --- a/dask_expr/_groupby.py +++ b/dask_expr/_groupby.py @@ -328,7 +328,8 @@ class Var(GroupByReduction): reduction_aggregate = _var_agg reduction_combine = _var_combine - def chunk(self, frame, by, **kwargs): + @staticmethod + def chunk(frame, by, **kwargs): if hasattr(by, "dtype"): by = [by] return _var_chunk(frame, *by, **kwargs) diff --git a/dask_expr/tests/_util.py b/dask_expr/tests/_util.py index 1343da74b..f94355cd5 100644 --- a/dask_expr/tests/_util.py +++ b/dask_expr/tests/_util.py @@ -2,6 +2,7 @@ import pytest from dask import config +from dask.dataframe.utils import assert_eq as dd_assert_eq def _backend_name() -> str: @@ -16,3 +17,22 @@ def xfail_gpu(reason=None): condition = _backend_name() == "cudf" reason = reason or "Failure expected for cudf backend." return pytest.mark.xfail(condition, reason=reason) + + +def assert_eq(a, b, *args, serialize_graph=True, **kwargs): + if serialize_graph: + # Check that no `Expr` instances are found in + # the graph generated by `Expr.dask` + try: + from distributed.protocol import serialize + from distributed.protocol.serialize import ToPickle + + with config.set({"dask-expr-no-serialize": True}): + for obj in [a, b]: + if hasattr(obj, "dask"): + serialize(ToPickle(obj.dask)) + except ImportError: + pass + + # Use `dask.dataframe.assert_eq` + return dd_assert_eq(a, b, *args, **kwargs) diff --git a/dask_expr/tests/test_categorical.py b/dask_expr/tests/test_categorical.py index 7fe9e9abb..91e7c5de1 100644 --- a/dask_expr/tests/test_categorical.py +++ b/dask_expr/tests/test_categorical.py @@ -1,8 +1,7 @@ import pytest -from dask.dataframe import assert_eq from dask_expr import from_pandas -from dask_expr.tests._util import _backend_library +from dask_expr.tests._util import _backend_library, assert_eq # Set DataFrame backend for this module lib = _backend_library() diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py index 1d5c08d6b..ba9933b01 100644 --- a/dask_expr/tests/test_collection.py +++ b/dask_expr/tests/test_collection.py @@ -7,14 +7,14 @@ import numpy as np import pytest from dask.dataframe._compat import PANDAS_GE_200, PANDAS_GE_210 -from dask.dataframe.utils import UNKNOWN_CATEGORIES, assert_eq +from dask.dataframe.utils import UNKNOWN_CATEGORIES from dask.utils import M from dask_expr import expr, from_pandas, is_scalar, optimize from dask_expr._expr import are_co_aligned from dask_expr._reductions import Len from dask_expr.datasets import timeseries -from dask_expr.tests._util import _backend_library, xfail_gpu +from dask_expr.tests._util import _backend_library, assert_eq, xfail_gpu # Set DataFrame backend for this module lib = _backend_library() diff --git a/dask_expr/tests/test_concat.py b/dask_expr/tests/test_concat.py index 9022bd9bf..d319f3796 100644 --- a/dask_expr/tests/test_concat.py +++ b/dask_expr/tests/test_concat.py @@ -1,9 +1,8 @@ import numpy as np import pytest -from dask.dataframe import assert_eq from dask_expr import Len, concat, from_pandas -from dask_expr.tests._util import _backend_library +from dask_expr.tests._util import _backend_library, assert_eq # Set DataFrame backend for this module lib = _backend_library() diff --git a/dask_expr/tests/test_datasets.py b/dask_expr/tests/test_datasets.py index 5b644def0..1541beb29 100644 --- a/dask_expr/tests/test_datasets.py +++ b/dask_expr/tests/test_datasets.py @@ -3,11 +3,11 @@ import pytest from dask.dataframe._compat import PANDAS_GE_200 -from dask.dataframe.utils import assert_eq from dask_expr import new_collection from dask_expr._expr import Lengths from dask_expr.datasets import Timeseries, timeseries +from dask_expr.tests._util import assert_eq def test_timeseries(): diff --git a/dask_expr/tests/test_datetime.py b/dask_expr/tests/test_datetime.py index c32968ce2..b68973689 100644 --- a/dask_expr/tests/test_datetime.py +++ b/dask_expr/tests/test_datetime.py @@ -1,8 +1,7 @@ import pytest -from dask.dataframe import assert_eq from dask_expr._collection import from_pandas -from dask_expr.tests._util import _backend_library +from dask_expr.tests._util import _backend_library, assert_eq lib = _backend_library() diff --git a/dask_expr/tests/test_fusion.py b/dask_expr/tests/test_fusion.py index 747e634d4..3dbc15c0b 100644 --- a/dask_expr/tests/test_fusion.py +++ b/dask_expr/tests/test_fusion.py @@ -1,8 +1,7 @@ import pytest -from dask.dataframe.utils import assert_eq from dask_expr import from_pandas, optimize -from dask_expr.tests._util import _backend_library +from dask_expr.tests._util import _backend_library, assert_eq # Set DataFrame backend for this module lib = _backend_library() diff --git a/dask_expr/tests/test_groupby.py b/dask_expr/tests/test_groupby.py index ef229b18f..9b6970027 100644 --- a/dask_expr/tests/test_groupby.py +++ b/dask_expr/tests/test_groupby.py @@ -1,9 +1,8 @@ import pytest -from dask.dataframe.utils import assert_eq from dask_expr import from_pandas from dask_expr._reductions import TreeReduce -from dask_expr.tests._util import _backend_library, xfail_gpu +from dask_expr.tests._util import _backend_library, assert_eq, xfail_gpu # Set DataFrame backend for this module lib = _backend_library() diff --git a/dask_expr/tests/test_merge.py b/dask_expr/tests/test_merge.py index ca8013f0d..5ee31d333 100644 --- a/dask_expr/tests/test_merge.py +++ b/dask_expr/tests/test_merge.py @@ -1,8 +1,7 @@ import pytest -from dask.dataframe.utils import assert_eq from dask_expr import from_pandas -from dask_expr.tests._util import _backend_library +from dask_expr.tests._util import _backend_library, assert_eq # Set DataFrame backend for this module lib = _backend_library() diff --git a/dask_expr/tests/test_quantiles.py b/dask_expr/tests/test_quantiles.py index 4c63b603d..733987db5 100644 --- a/dask_expr/tests/test_quantiles.py +++ b/dask_expr/tests/test_quantiles.py @@ -1,7 +1,5 @@ -from dask.dataframe import assert_eq - from dask_expr import from_pandas -from dask_expr.tests._util import _backend_library +from dask_expr.tests._util import _backend_library, assert_eq # Set DataFrame backend for this module lib = _backend_library() diff --git a/dask_expr/tests/test_repartition.py b/dask_expr/tests/test_repartition.py index 8f4d2b7b5..321508c24 100644 --- a/dask_expr/tests/test_repartition.py +++ b/dask_expr/tests/test_repartition.py @@ -1,8 +1,7 @@ import pytest -from dask.dataframe import assert_eq from dask_expr import from_pandas -from dask_expr.tests._util import _backend_library +from dask_expr.tests._util import _backend_library, assert_eq lib = _backend_library() diff --git a/dask_expr/tests/test_reshape.py b/dask_expr/tests/test_reshape.py index 682926ca5..5eb6cf69f 100644 --- a/dask_expr/tests/test_reshape.py +++ b/dask_expr/tests/test_reshape.py @@ -1,8 +1,7 @@ import pytest -from dask.dataframe import assert_eq from dask_expr import from_pandas -from dask_expr.tests._util import _backend_library +from dask_expr.tests._util import _backend_library, assert_eq # Set DataFrame backend for this module lib = _backend_library() diff --git a/dask_expr/tests/test_shuffle.py b/dask_expr/tests/test_shuffle.py index 1348fafdb..6649c01cf 100644 --- a/dask_expr/tests/test_shuffle.py +++ b/dask_expr/tests/test_shuffle.py @@ -1,14 +1,13 @@ from collections import OrderedDict import pytest -from dask.dataframe.utils import assert_eq from dask_expr import SetIndexBlockwise, from_pandas from dask_expr._expr import Blockwise from dask_expr._repartition import RepartitionToFewer from dask_expr._shuffle import divisions_lru from dask_expr.io import FromPandas -from dask_expr.tests._util import _backend_library +from dask_expr.tests._util import _backend_library, assert_eq # Set DataFrame backend for this module lib = _backend_library() diff --git a/dask_expr/tests/test_string_accessor.py b/dask_expr/tests/test_string_accessor.py index 74a8ddb00..643a99a8e 100644 --- a/dask_expr/tests/test_string_accessor.py +++ b/dask_expr/tests/test_string_accessor.py @@ -1,9 +1,8 @@ import pytest -from dask.dataframe import assert_eq from dask.dataframe._compat import PANDAS_GE_200 from dask_expr._collection import from_pandas -from dask_expr.tests._util import _backend_library +from dask_expr.tests._util import _backend_library, assert_eq lib = _backend_library() From 4597094c3949b0dc0eba69b4403808090d19660b Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Thu, 12 Oct 2023 13:31:42 -0500 Subject: [PATCH 2/4] try using pickle first --- dask_expr/_expr.py | 2 +- dask_expr/tests/_util.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index ef0dd1246..11f8d7852 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -1344,7 +1344,7 @@ def _task(self, index: int): args = [self._blockwise_arg(self.frame, index)] + [ self.state_data[index], self.frac, - self.replace, + self.operand("replace"), ] return (self.operation,) + tuple(args) diff --git a/dask_expr/tests/_util.py b/dask_expr/tests/_util.py index f94355cd5..3383d3f98 100644 --- a/dask_expr/tests/_util.py +++ b/dask_expr/tests/_util.py @@ -1,4 +1,5 @@ import importlib +import pickle import pytest from dask import config @@ -30,7 +31,10 @@ def assert_eq(a, b, *args, serialize_graph=True, **kwargs): with config.set({"dask-expr-no-serialize": True}): for obj in [a, b]: if hasattr(obj, "dask"): - serialize(ToPickle(obj.dask)) + try: + pickle.dumps(obj.dask) + except (AttributeError, pickle.PicklingError): + serialize(ToPickle(obj.dask)) except ImportError: pass From f9342f66eb975925f8cd8c37d19a94e4ac6b850b Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Fri, 13 Oct 2023 08:14:38 -0500 Subject: [PATCH 3/4] remove a few more lambda instances within the library --- dask_expr/_reductions.py | 14 ++++++++++---- dask_expr/tests/_util.py | 2 +- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/dask_expr/_reductions.py b/dask_expr/_reductions.py index 0e8ce5188..426f5b856 100644 --- a/dask_expr/_reductions.py +++ b/dask_expr/_reductions.py @@ -242,8 +242,8 @@ def _tree_repr_lines(self, indent=0, recursive=True): class Unique(ApplyConcatApply): _parameters = ["frame"] - chunk = staticmethod(lambda x, **kwargs: methods.unique(x, **kwargs)) - aggregate_func = methods.unique + chunk = staticmethod(methods.unique) + aggregate_func = staticmethod(methods.unique) @functools.cached_property def _meta(self): @@ -602,9 +602,12 @@ def _simplify_up(self, parent): class Size(Reduction): - reduction_chunk = staticmethod(lambda df: df.size) reduction_aggregate = sum + @staticmethod + def reduction_chunk(df): + return df.size + def _simplify_down(self): if is_dataframe_like(self.frame._meta) and len(self.frame.columns) > 1: return len(self.frame.columns) * Len(self.frame) @@ -617,10 +620,13 @@ def _simplify_up(self, parent): class NBytes(Reduction): # Only supported for Series objects - reduction_chunk = lambda ser: ser.nbytes reduction_aggregate = sum _required_attribute = "nbytes" + @staticmethod + def reduction_chunk(ser): + return ser.nbytes + class Var(Reduction): # Uses the parallel version of Welford's online algorithm (Chan 79') diff --git a/dask_expr/tests/_util.py b/dask_expr/tests/_util.py index 3383d3f98..ac9e48b4b 100644 --- a/dask_expr/tests/_util.py +++ b/dask_expr/tests/_util.py @@ -33,7 +33,7 @@ def assert_eq(a, b, *args, serialize_graph=True, **kwargs): if hasattr(obj, "dask"): try: pickle.dumps(obj.dask) - except (AttributeError, pickle.PicklingError): + except AttributeError: serialize(ToPickle(obj.dask)) except ImportError: pass From 870dac7b93288139fe584fa63ee36e182606bf80 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Fri, 13 Oct 2023 08:23:52 -0500 Subject: [PATCH 4/4] use cloudpickle --- dask_expr/tests/_util.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/dask_expr/tests/_util.py b/dask_expr/tests/_util.py index ac9e48b4b..1f24bfda1 100644 --- a/dask_expr/tests/_util.py +++ b/dask_expr/tests/_util.py @@ -24,19 +24,18 @@ def assert_eq(a, b, *args, serialize_graph=True, **kwargs): if serialize_graph: # Check that no `Expr` instances are found in # the graph generated by `Expr.dask` - try: - from distributed.protocol import serialize - from distributed.protocol.serialize import ToPickle - - with config.set({"dask-expr-no-serialize": True}): - for obj in [a, b]: - if hasattr(obj, "dask"): + with config.set({"dask-expr-no-serialize": True}): + for obj in [a, b]: + if hasattr(obj, "dask"): + try: + pickle.dumps(obj.dask) + except AttributeError: try: - pickle.dumps(obj.dask) - except AttributeError: - serialize(ToPickle(obj.dask)) - except ImportError: - pass + import cloudpickle as cp + + cp.dumps(obj.dask) + except ImportError: + pass # Use `dask.dataframe.assert_eq` return dd_assert_eq(a, b, *args, **kwargs)