77except ImportError :
88 pass
99
10+ import collections
1011import itertools
1112import operator
1213from typing import (
1314 Any ,
1415 Callable ,
1516 Dict ,
17+ DefaultDict ,
1618 Hashable ,
1719 Mapping ,
1820 Sequence ,
@@ -222,6 +224,7 @@ def _wrapper(func, obj, to_array, args, kwargs):
222224 indexes .update ({k : template .indexes [k ] for k in new_indexes })
223225
224226 graph : Dict [Any , Any ] = {}
227+ new_layers : DefaultDict [str , Dict [Any , Any ]] = collections .defaultdict (dict )
225228 gname = "{}-{}" .format (
226229 dask .utils .funcname (func ), dask .base .tokenize (dataset , args , kwargs )
227230 )
@@ -310,9 +313,13 @@ def _wrapper(func, obj, to_array, args, kwargs):
310313 # unchunked dimensions in the input have one chunk in the result
311314 key += (0 ,)
312315
313- graph [key ] = (operator .getitem , from_wrapper , name )
316+ new_layers [ gname_l ] [key ] = (operator .getitem , from_wrapper , name )
314317
315- graph = HighLevelGraph .from_collections (gname , graph , dependencies = [dataset ])
318+ hlg = HighLevelGraph .from_collections (gname , graph , dependencies = [dataset ])
319+
320+ for gname_l , layer in new_layers .items ():
321+ hlg .dependencies [gname_l ] = {gname }
322+ hlg .layers [gname_l ] = layer
316323
317324 result = Dataset (coords = indexes , attrs = template .attrs )
318325 for name , gname_l in var_key_map .items ():
@@ -325,7 +332,7 @@ def _wrapper(func, obj, to_array, args, kwargs):
325332 var_chunks .append ((len (indexes [dim ]),))
326333
327334 data = dask .array .Array (
328- graph , name = gname_l , chunks = var_chunks , dtype = template [name ].dtype
335+ hlg , name = gname_l , chunks = var_chunks , dtype = template [name ].dtype
329336 )
330337 result [name ] = (dims , data , template [name ].attrs )
331338
0 commit comments