From a3c504e5088ae4fed03522967de3cbb06fdb4df6 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Wed, 12 Jul 2023 11:25:53 -0500 Subject: [PATCH 1/3] enforce deterministic tokens --- dask_expr/_util.py | 8 ++++++++ dask_expr/expr.py | 12 ++++++++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/dask_expr/_util.py b/dask_expr/_util.py index a206fa9ad..7e716a961 100644 --- a/dask_expr/_util.py +++ b/dask_expr/_util.py @@ -1,5 +1,8 @@ from __future__ import annotations +from dask import config +from dask.base import tokenize + def _convert_to_list(column) -> list | None: if column is None or isinstance(column, list): @@ -11,3 +14,8 @@ def _convert_to_list(column) -> list | None: else: column = [column] return column + + +def _tokenize_deterministic(*args, **kwargs): + with config.set({"tokenize.ensure-deterministic": True}): + return tokenize(*args, **kwargs) diff --git a/dask_expr/expr.py b/dask_expr/expr.py index b5472339e..5c4d66a7f 100644 --- a/dask_expr/expr.py +++ b/dask_expr/expr.py @@ -10,7 +10,7 @@ import dask import pandas as pd import toolz -from dask.base import normalize_token, tokenize +from dask.base import normalize_token from dask.core import flatten, ishashable from dask.dataframe import methods from dask.dataframe.core import ( @@ -27,6 +27,8 @@ from dask.utils import M, apply, funcname, import_required, is_arraylike from tlz import merge_sorted, unique +from dask_expr._util import _tokenize_deterministic + replacement_rules = [] no_default = "__no_default__" @@ -530,7 +532,9 @@ def npartitions(self): @functools.cached_property def _name(self): - return funcname(type(self)).lower() + "-" + tokenize(*self.operands) + return ( + funcname(type(self)).lower() + "-" + _tokenize_deterministic(*self.operands) + ) @property def columns(self) -> list: @@ -836,7 +840,7 @@ def _name(self): head = funcname(self.operation) else: head = funcname(type(self)).lower() - return head + "-" + tokenize(*self.operands) + return head + "-" + _tokenize_deterministic(*self.operands) def _blockwise_arg(self, arg, i): """Return a Blockwise-task argument""" @@ -2025,7 +2029,7 @@ def __str__(self): @functools.cached_property def _name(self): - return f"{str(self)}-{tokenize(self.exprs)}" + return f"{str(self)}-{_tokenize_deterministic(self.exprs)}" def _divisions(self): return self.exprs[0]._divisions() From 73fbf2b0b246416bf8146a8eeaf9ccb1dfc2b104 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Wed, 12 Jul 2023 12:05:38 -0500 Subject: [PATCH 2/3] get all tests passing --- dask_expr/_util.py | 10 +++++++++- dask_expr/collection.py | 20 +++++++++++++------- dask_expr/expr.py | 4 ++-- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/dask_expr/_util.py b/dask_expr/_util.py index 7e716a961..6275e3769 100644 --- a/dask_expr/_util.py +++ b/dask_expr/_util.py @@ -1,7 +1,9 @@ from __future__ import annotations +from types import LambdaType + from dask import config -from dask.base import tokenize +from dask.base import normalize_token, tokenize def _convert_to_list(column) -> list | None: @@ -16,6 +18,12 @@ def _convert_to_list(column) -> list | None: return column +@normalize_token.register(LambdaType) +def _normalize_lambda(func): + return str(func) + + def _tokenize_deterministic(*args, **kwargs): + # Utility to be strict about deterministic tokens with config.set({"tokenize.ensure-deterministic": True}): return tokenize(*args, **kwargs) diff --git a/dask_expr/collection.py b/dask_expr/collection.py index 41a4dcedb..2b9fd7aa7 100644 --- a/dask_expr/collection.py +++ b/dask_expr/collection.py @@ -23,7 +23,7 @@ from tlz import first from dask_expr import expr -from dask_expr._util import _convert_to_list +from dask_expr._util import _convert_to_list, _wrap_lambdas from dask_expr.align import AlignPartitions from dask_expr.concat import Concat from dask_expr.expr import Eval, no_default @@ -374,7 +374,7 @@ def map_partitions( raise NotImplementedError() new_expr = expr.MapPartitions( self.expr, - func, + _wrap_lambdas(func), meta, enforce_metadata, transform_divisions, @@ -655,7 +655,9 @@ def __dir__(self): return list(o) def map(self, func, na_action=None): - return new_collection(expr.Map(self.expr, arg=func, na_action=na_action)) + return new_collection( + expr.Map(self.expr, arg=_wrap_lambdas(func), na_action=na_action) + ) def __repr__(self): return f"" @@ -777,7 +779,9 @@ def nbytes(self): return new_collection(self.expr.nbytes) def map(self, arg, na_action=None): - return new_collection(expr.Map(self.expr, arg=arg, na_action=na_action)) + return new_collection( + expr.Map(self.expr, arg=_wrap_lambdas(arg), na_action=na_action) + ) def __repr__(self): return f"" @@ -897,10 +901,12 @@ def from_dask_dataframe(ddf: _Frame, optimize: bool = True) -> FrameBase: return from_graph(graph, ddf._meta, ddf.divisions, ddf._name) -def read_csv(*args, **kwargs): +def read_csv(path, *args, **kwargs): from dask_expr.io.csv import ReadCSV - return new_collection(ReadCSV(*args, **kwargs)) + if not isinstance(path, str): + path = stringify_path(path) + return new_collection(ReadCSV(path, *args, **kwargs)) def read_parquet( @@ -923,7 +929,7 @@ def read_parquet( ): from dask_expr.io.parquet import ReadParquet - if hasattr(path, "name"): + if not isinstance(path, str): path = stringify_path(path) kwargs["dtype_backend"] = dtype_backend diff --git a/dask_expr/expr.py b/dask_expr/expr.py index 5c4d66a7f..76a890e86 100644 --- a/dask_expr/expr.py +++ b/dask_expr/expr.py @@ -27,7 +27,7 @@ from dask.utils import M, apply, funcname, import_required, is_arraylike from tlz import merge_sorted, unique -from dask_expr._util import _tokenize_deterministic +from dask_expr._util import _tokenize_deterministic, _wrap_lambdas replacement_rules = [] @@ -466,7 +466,7 @@ def round(self, decimals=0): return Round(self, decimals=decimals) def apply(self, function, *args, **kwargs): - return Apply(self, function, args, kwargs) + return Apply(self, _wrap_lambdas(function), args, kwargs) def replace(self, to_replace=None, value=no_default, regex=False): return Replace(self, to_replace=to_replace, value=value, regex=regex) From b01267abcd2816ba309fb44fc9d36630cbfa61f1 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Wed, 12 Jul 2023 12:15:10 -0500 Subject: [PATCH 3/3] remove leftover _wrap_lambdas --- dask_expr/collection.py | 12 ++++-------- dask_expr/expr.py | 4 ++-- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/dask_expr/collection.py b/dask_expr/collection.py index 2b9fd7aa7..038fa09e4 100644 --- a/dask_expr/collection.py +++ b/dask_expr/collection.py @@ -23,7 +23,7 @@ from tlz import first from dask_expr import expr -from dask_expr._util import _convert_to_list, _wrap_lambdas +from dask_expr._util import _convert_to_list from dask_expr.align import AlignPartitions from dask_expr.concat import Concat from dask_expr.expr import Eval, no_default @@ -374,7 +374,7 @@ def map_partitions( raise NotImplementedError() new_expr = expr.MapPartitions( self.expr, - _wrap_lambdas(func), + func, meta, enforce_metadata, transform_divisions, @@ -655,9 +655,7 @@ def __dir__(self): return list(o) def map(self, func, na_action=None): - return new_collection( - expr.Map(self.expr, arg=_wrap_lambdas(func), na_action=na_action) - ) + return new_collection(expr.Map(self.expr, arg=func, na_action=na_action)) def __repr__(self): return f"" @@ -779,9 +777,7 @@ def nbytes(self): return new_collection(self.expr.nbytes) def map(self, arg, na_action=None): - return new_collection( - expr.Map(self.expr, arg=_wrap_lambdas(arg), na_action=na_action) - ) + return new_collection(expr.Map(self.expr, arg=arg, na_action=na_action)) def __repr__(self): return f"" diff --git a/dask_expr/expr.py b/dask_expr/expr.py index 76a890e86..5c4d66a7f 100644 --- a/dask_expr/expr.py +++ b/dask_expr/expr.py @@ -27,7 +27,7 @@ from dask.utils import M, apply, funcname, import_required, is_arraylike from tlz import merge_sorted, unique -from dask_expr._util import _tokenize_deterministic, _wrap_lambdas +from dask_expr._util import _tokenize_deterministic replacement_rules = [] @@ -466,7 +466,7 @@ def round(self, decimals=0): return Round(self, decimals=decimals) def apply(self, function, *args, **kwargs): - return Apply(self, _wrap_lambdas(function), args, kwargs) + return Apply(self, function, args, kwargs) def replace(self, to_replace=None, value=no_default, regex=False): return Replace(self, to_replace=to_replace, value=value, regex=regex)