Skip to content

Commit e84cc97

Browse files
committed
Optimize dask array equality checks.
Dask arrays with the same graph have the same name. We can use this to quickly compare dask-backed variables without computing. Fixes pydata#3068 and pydata#3311
1 parent 79b3cdd commit e84cc97

5 files changed

Lines changed: 119 additions & 9 deletions

File tree

xarray/core/concat.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from . import dtypes, utils
44
from .alignment import align
5+
from .duck_array_ops import lazy_array_equiv
56
from .merge import _VALID_COMPAT, unique_variable
67
from .variable import IndexVariable, Variable, as_variable
78
from .variable import concat as concat_vars
@@ -189,6 +190,21 @@ def process_subset_opt(opt, subset):
189190
# all nonindexes that are not the same in each dataset
190191
for k in getattr(datasets[0], subset):
191192
if k not in concat_over:
193+
equals[k] = None
194+
variables = [ds.variables[k] for ds in datasets]
195+
# first check without comparing values i.e. no computes
196+
for var in variables[1:]:
197+
equals[k] = getattr(variables[0], compat)(
198+
var, equiv=lazy_array_equiv
199+
)
200+
if not equals[k]:
201+
break
202+
203+
if equals[k] is not None:
204+
if equals[k] is False:
205+
concat_over.add(k)
206+
continue
207+
192208
# Compare the variable of all datasets vs. the one
193209
# of the first dataset. Perform the minimum amount of
194210
# loads in order to avoid multiple loads from disk

xarray/core/duck_array_ops.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,16 +181,49 @@ def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8):
181181
arr2 = asarray(arr2)
182182
if arr1.shape != arr2.shape:
183183
return False
184+
if (
185+
dask_array
186+
and isinstance(arr1, dask_array.Array)
187+
and isinstance(arr2, dask_array.Array)
188+
):
189+
# GH3068
190+
if arr1.name == arr2.name:
191+
return True
184192
return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all())
185193

186194

195+
def lazy_array_equiv(arr1, arr2):
196+
"""Like array_equal, but doesn't actually compare values
197+
"""
198+
arr1 = asarray(arr1)
199+
arr2 = asarray(arr2)
200+
if arr1.shape != arr2.shape:
201+
return False
202+
if (
203+
dask_array
204+
and isinstance(arr1, dask_array.Array)
205+
and isinstance(arr2, dask_array.Array)
206+
):
207+
# GH3068
208+
if arr1.name == arr2.name:
209+
return True
210+
211+
187212
def array_equiv(arr1, arr2):
188213
"""Like np.array_equal, but also allows values to be NaN in both arrays
189214
"""
190215
arr1 = asarray(arr1)
191216
arr2 = asarray(arr2)
192217
if arr1.shape != arr2.shape:
193218
return False
219+
if (
220+
dask_array
221+
and isinstance(arr1, dask_array.Array)
222+
and isinstance(arr2, dask_array.Array)
223+
):
224+
# GH3068
225+
if arr1.name == arr2.name:
226+
return True
194227
with warnings.catch_warnings():
195228
warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
196229
flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2))
@@ -205,6 +238,14 @@ def array_notnull_equiv(arr1, arr2):
205238
arr2 = asarray(arr2)
206239
if arr1.shape != arr2.shape:
207240
return False
241+
if (
242+
dask_array
243+
and isinstance(arr1, dask_array.Array)
244+
and isinstance(arr2, dask_array.Array)
245+
):
246+
# GH3068
247+
if arr1.name == arr2.name:
248+
return True
208249
with warnings.catch_warnings():
209250
warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
210251
flag_array = (arr1 == arr2) | isnull(arr1) | isnull(arr2)

xarray/core/merge.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from . import dtypes, pdcompat
2121
from .alignment import deep_align
22+
from .duck_array_ops import lazy_array_equiv
2223
from .utils import Frozen, dict_equiv
2324
from .variable import Variable, as_variable, assert_unique_multiindex_level_names
2425

@@ -123,16 +124,24 @@ def unique_variable(
123124
combine_method = "fillna"
124125

125126
if equals is None:
126-
out = out.compute()
127+
# first check without comparing values i.e. no computes
127128
for var in variables[1:]:
128-
equals = getattr(out, compat)(var)
129+
equals = getattr(out, compat)(var, equiv=lazy_array_equiv)
129130
if not equals:
130131
break
131132

133+
# now compare values with minimum number of computes
134+
if not equals:
135+
out = out.compute()
136+
for var in variables[1:]:
137+
equals = getattr(out, compat)(var)
138+
if not equals:
139+
break
140+
132141
if not equals:
133142
raise MergeError(
134-
"conflicting values for variable {!r} on objects to be combined. "
135-
"You can skip this check by specifying compat='override'.".format(name)
143+
f"conflicting values for variable {name!r} on objects to be combined. "
144+
"You can skip this check by specifying compat='override'."
136145
)
137146

138147
if combine_method:

xarray/core/variable.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,7 +1229,9 @@ def transpose(self, *dims) -> "Variable":
12291229
if len(dims) == 0:
12301230
dims = self.dims[::-1]
12311231
axes = self.get_axis_num(dims)
1232-
if len(dims) < 2: # no need to transpose if only one dimension
1232+
if len(dims) < 2 or dims == self.dims:
1233+
# no need to transpose if only one dimension
1234+
# or dims are in same order
12331235
return self.copy(deep=False)
12341236

12351237
data = as_indexable(self._data).transpose(axes)
@@ -1588,22 +1590,24 @@ def broadcast_equals(self, other, equiv=duck_array_ops.array_equiv):
15881590
return False
15891591
return self.equals(other, equiv=equiv)
15901592

1591-
def identical(self, other):
1593+
def identical(self, other, equiv=duck_array_ops.array_equiv):
15921594
"""Like equals, but also checks attributes.
15931595
"""
15941596
try:
1595-
return utils.dict_equiv(self.attrs, other.attrs) and self.equals(other)
1597+
return utils.dict_equiv(self.attrs, other.attrs) and self.equals(
1598+
other, equiv=equiv
1599+
)
15961600
except (TypeError, AttributeError):
15971601
return False
15981602

1599-
def no_conflicts(self, other):
1603+
def no_conflicts(self, other, equiv=duck_array_ops.array_notnull_equiv):
16001604
"""True if the intersection of two Variable's non-null data is
16011605
equal; otherwise false.
16021606
16031607
Variables can thus still be equal if there are locations where either,
16041608
or both, contain NaN values.
16051609
"""
1606-
return self.broadcast_equals(other, equiv=duck_array_ops.array_notnull_equiv)
1610+
return self.broadcast_equals(other, equiv=equiv)
16071611

16081612
def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None):
16091613
"""Compute the qth quantile of the data along the specified dimension.

xarray/tests/test_dask.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
assert_identical,
2323
raises_regex,
2424
)
25+
from ..core.duck_array_ops import lazy_array_equiv
2526

2627
dask = pytest.importorskip("dask")
2728
da = pytest.importorskip("dask.array")
@@ -1135,3 +1136,42 @@ def test_make_meta(map_ds):
11351136
for variable in map_ds.data_vars:
11361137
assert variable in meta.data_vars
11371138
assert meta.data_vars[variable].shape == (0,) * meta.data_vars[variable].ndim
1139+
1140+
1141+
def test_identical_coords_no_computes():
1142+
lons2 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
1143+
a = xr.DataArray(
1144+
da.zeros((10, 10), chunks=2), dims=("y", "x"), coords={"lons": lons2}
1145+
)
1146+
b = xr.DataArray(
1147+
da.zeros((10, 10), chunks=2), dims=("y", "x"), coords={"lons": lons2}
1148+
)
1149+
with raise_if_dask_computes():
1150+
c = a + b
1151+
assert_identical(c, a)
1152+
1153+
1154+
def test_lazy_array_equiv():
1155+
lons1 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
1156+
lons2 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
1157+
var1 = lons1.variable
1158+
var2 = lons2.variable
1159+
with raise_if_dask_computes():
1160+
lons1.equals(lons2)
1161+
with raise_if_dask_computes():
1162+
var1.equals(var2 / 2, equiv=lazy_array_equiv)
1163+
assert var1.equals(var2.compute(), equiv=lazy_array_equiv) is None
1164+
assert var1.compute().equals(var2.compute(), equiv=lazy_array_equiv) is None
1165+
1166+
with raise_if_dask_computes():
1167+
assert lons1.equals(lons1.transpose("y", "x"))
1168+
1169+
with raise_if_dask_computes():
1170+
for compat in [
1171+
"broadcast_equals",
1172+
"equals",
1173+
"override",
1174+
"identical",
1175+
"no_conflicts",
1176+
]:
1177+
xr.merge([lons1, lons2], compat=compat)

0 commit comments

Comments
 (0)