forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlowered_backend_module.py
More file actions
964 lines (842 loc) · 38.7 KB
/
lowered_backend_module.py
File metadata and controls
964 lines (842 loc) · 38.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import copy
import operator
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import torch
import torch.utils._pytree as pytree
from executorch.exir._serialize import _PTEFile, _serialize_pte_binary
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name
from executorch.exir.emit import emit_program
from executorch.exir.graph_module import _get_submodule
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from executorch.exir.passes.spec_prop_pass import make_spec, SpecPropPass
from executorch.exir.schema import Program
from executorch.exir.tracer import Value
from torch._library.fake_class_registry import FakeScriptObject
from torch._subclasses import FakeTensor
from torch.export.exported_program import (
ConstantArgument,
ExportedProgram,
ExportGraphSignature,
InputKind,
InputSpec,
ModuleCallEntry,
ModuleCallSignature,
OutputKind,
OutputSpec,
TensorArgument,
)
from torch.fx.passes.utils.fuser_utils import (
erase_nodes,
fuse_as_graphmodule,
insert_subgm,
legalize_graph,
NodeList,
topo_sort,
)
class LoweredBackendModule(torch.nn.Module):
    """
    A subclass of nn.Module that is generated for modules containing
    delegated functions. This can be created by calling `to_backend`.
    """

    _backend_id: str  # The backend's name
    _processed_bytes: bytes  # The delegate blobs created from backend.preprocess
    _compile_specs: List[
        CompileSpec
    ]  # A list of backend-specific objects with static metadata to configure the "compilation" process.
    _original_exported_program: ExportedProgram  # The original EXIR module
    _named_data_store_output: Optional[
        NamedDataStoreOutput
    ]  # Named Data serialized by the backend
    meta: Optional[Dict[str, Any]]  # Metadata for the lowered module

    def __init__(
        self,
        edge_program: ExportedProgram,
        backend_id: str,
        processed_bytes: bytes,
        compile_specs: List[CompileSpec],
        named_data_store_output: Optional[NamedDataStoreOutput] = None,
    ) -> None:
        super().__init__()
        self._original_exported_program = edge_program
        self._backend_id = backend_id
        self._processed_bytes = processed_bytes
        self._compile_specs = compile_specs
        self._named_data_store_output = named_data_store_output
        self.meta = None

    # pyre-ignore
    def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule":
        """
        Deep-copy this lowered module.

        The exported program's graph module, graph, signature, range
        constraints, module call graph and verifier are deep-copied, as are the
        compile specs. The state dict, constants table, processed bytes and
        named-data-store output are shared with the original. ``meta`` is
        shallow-copied.
        """
        # Copy exported program
        copied_program = ExportedProgram(
            root=copy.deepcopy(self._original_exported_program.graph_module),
            graph=copy.deepcopy(self._original_exported_program.graph),
            graph_signature=copy.deepcopy(
                self._original_exported_program.graph_signature
            ),
            state_dict=self._original_exported_program.state_dict,
            range_constraints=copy.deepcopy(
                self._original_exported_program.range_constraints
            ),
            module_call_graph=copy.deepcopy(
                self._original_exported_program.module_call_graph
            ),
            constants=self._original_exported_program.constants,
            verifiers=[copy.deepcopy(self._original_exported_program.verifier)],
        )

        res = LoweredBackendModule(
            edge_program=copied_program,
            backend_id=self._backend_id,
            processed_bytes=self._processed_bytes,
            compile_specs=copy.deepcopy(self._compile_specs, memo),
            named_data_store_output=self._named_data_store_output,
        )
        res.meta = copy.copy(getattr(self, "meta", {}))
        return res

    @property
    def backend_id(self) -> str:
        """
        Returns the backends name.
        """
        return self._backend_id

    @property
    def processed_bytes(self) -> bytes:
        """
        Returns the delegate blob created from backend.preprocess
        """
        return self._processed_bytes

    @property
    def compile_specs(self) -> List[CompileSpec]:
        """
        Returns a list of backend-specific objects with static metadata to configure the "compilation" process.
        """
        return self._compile_specs

    @property
    def original_module(self) -> ExportedProgram:
        """
        Returns the original EXIR module
        """
        return self._original_exported_program

    @property
    def named_data_store_output(self) -> Optional[NamedDataStoreOutput]:
        """
        Returns the Named Data Store Output
        """
        return self._named_data_store_output

    # TODO(chenlai): consolidate the serialization config with serialize_to_flatbuffer api
    def buffer(
        self,
        extract_delegate_segments: bool = False,
        segment_alignment: int = 128,
        constant_tensor_alignment: Optional[int] = None,
        delegate_alignment: Optional[int] = None,
        memory_planning: Optional[MemoryPlanningPass] = None,
    ) -> bytes:
        """
        Returns a buffer containing the serialized ExecuTorch binary.

        The program built by ``self.program(memory_planning=...)`` is serialized
        together with this module's named-data-store output; the remaining
        arguments are forwarded to ``_serialize_pte_binary``.
        """
        # TODO(T181463742): avoid calling bytes(..) which incurs large copies.
        out = bytes(
            _serialize_pte_binary(
                pte_file=_PTEFile(
                    program=self.program(memory_planning=memory_planning),
                    named_data=self.named_data_store_output,
                ),
                extract_delegate_segments=extract_delegate_segments,
                segment_alignment=segment_alignment,
                constant_tensor_alignment=constant_tensor_alignment,
                delegate_alignment=delegate_alignment,
            )
        )
        return out

    # TODO(chenlai): re-consider recapture instead of manually constructing the program because
    # the meta data construction is done manually.
    def program(
        self,
        emit_stacktrace: bool = False,
        memory_planning: Optional[MemoryPlanningPass] = None,
    ) -> Program:
        """
        Returns the object that represents the ExecuTorch binary before serialization.

        A deep copy of the original exported program is rewritten so that its
        body is a single ``executorch_call_delegate`` on this lowered module,
        followed by one ``getitem`` per original output. Spec propagation and
        memory planning are then run through ``_transform`` and the result is
        emitted with ``emit_program``.

        Args:
            emit_stacktrace: Forwarded to ``emit_program``.
            memory_planning: Memory planning pass to run; a default
                ``MemoryPlanningPass()`` is used when None.
        """
        # Imported locally: a module-level import introduces a cyclic dependency
        # program -> verifier -> lowered_backend_module -> program
        # @manual
        from executorch.exir.program._program import (
            _get_updated_graph_signature,
            _transform,
        )

        # Creates a new module based on the original module. The original module will
        # look something like following:
        #
        # opcode         name                 target            args                                        kwargs
        # -------------  -------------------  ----------------  ------------------------------------------  --------
        # placeholder    arg0_1               arg0_1            ()                                          {}
        # placeholder    arg1_1               arg1_1            ()                                          {}
        # call_function  aten_repeat_default  *                 (arg1_1, [4, 1])                            {}
        # call_function  aten_mul_tensor      *                 (aten_repeat_default, aten_repeat_default)  {}
        # call_function  aten_add_tensor      *                 (arg1_1, arg1_1)                            {}
        # output         output               output            ([aten_mul_tensor, aten_add_tensor],)       {}
        #
        # if the whole module is lowered, the resulting lowered module look like
        #
        # opcode         name                      target                       args                                kwargs
        # -------------  ------------------------  ---------------------------  ----------------------------------  --------
        # placeholder    arg0_1                    arg0_1                       ()                                  {}
        # placeholder    arg1_1                    arg1_1                       ()                                  {}
        # get_attr       lowered_module_0          lowered_module_0             ()                                  {}
        # call_function  executorch_call_delegate  executorch_call_delegate     (lowered_module_0, arg0_1, arg1_1)  {}
        # call_function  getitem                   <built-in function getitem>  (executorch_call_delegate, 0)       {}
        # call_function  getitem_1                 <built-in function getitem>  (executorch_call_delegate, 1)       {}
        # output         output_1                  output                       ([getitem, getitem_1],)             {}
        #
        # We'll remove all call_function nodes, insert an call_delegate node, inserting getitems nodes to get the result for call_delegate node
        # and return the list of getitems as the output

        lowered_exported_program = copy.deepcopy(self._original_exported_program)

        # Cache these properties to avoid rebuilding the dict on each access.
        sig = lowered_exported_program.graph_signature
        params_map = sig.inputs_to_parameters
        buffers_map = sig.inputs_to_buffers

        # The real input nodes are the ones not buffer or parameter
        all_input_nodes = [
            node
            for node in lowered_exported_program.graph.nodes
            if (
                node.op == "placeholder"
                and node.name not in buffers_map
                and node.name not in params_map
            )
        ]

        output_node = lowered_exported_program.graph.output_node()

        # Step 1. Cleaning up the graph before inserting the call_delegate node
        # Remove the original output node
        lowered_exported_program.graph.erase_node(output_node)

        # Remove everything else except the inputs (walk backwards so users are
        # erased before the nodes they consume).
        for node in reversed(lowered_exported_program.graph.nodes):
            if node.op != "placeholder":
                lowered_exported_program.graph.erase_node(node)

        # Find placeholders that are parameters or buffers, remove them from the main graph
        for node in lowered_exported_program.graph.nodes:
            if node.op == "placeholder" and (
                node.name in buffers_map or node.name in params_map
            ):
                lowered_exported_program.graph.erase_node(node)

        # Step 2. Start constructing the graph
        lowered_name = get_lowered_module_name(
            lowered_exported_program.graph_module, self
        )
        # Insert the lowered module to the graph module as an attribute
        lowered_node = lowered_exported_program.graph.get_attr(lowered_name)

        # Insert a call_delegate node to the graph module, with arguments from the arg list
        delegate_node = lowered_exported_program.graph.call_function(
            executorch_call_delegate, (lowered_node, *all_input_nodes)
        )
        # Get the output list. Since the output node is a tuple of list, like ([aten_mul_tensor, aten_add_tensor],)
        # We add some handling logic to get the list `[aten_mul_tensor, aten_add_tensor]` properly
        original_output_nodes = (
            self._original_exported_program.graph.output_node().args[0]
        )

        delegate_node.meta["spec"] = tuple(
            make_spec(node.meta["val"]) for node in original_output_nodes
        )
        delegate_node.meta["val"] = tuple(
            node.meta["val"] for node in original_output_nodes
        )

        # The getitem nodes that are going to be inserted to the lowered graph module
        getitem_nodes = []
        for i, output_val in enumerate(delegate_node.meta["val"]):
            getitem_node = lowered_exported_program.graph.call_function(
                operator.getitem,
                args=(delegate_node, i),
            )
            getitem_node.meta["val"] = output_val
            getitem_nodes.append(getitem_node)
        lowered_exported_program.graph.output(getitem_nodes)

        lowered_exported_program.graph_module.recompile()
        lowered_exported_program.graph.lint()

        # Users output will be the get items nodes instead
        output_specs = [
            OutputSpec(
                kind=OutputKind.USER_OUTPUT,
                arg=TensorArgument(name=getitem_node.name),
                target=None,
            )
            for getitem_node in getitem_nodes
        ]
        # All data are consumed by the delegates so they should be removed from the state dict.
        # Each remaining user input keeps its own name (previously a stale loop
        # variable was used here, giving every input the same wrong name).
        input_specs = [
            InputSpec(
                kind=InputKind.USER_INPUT,
                arg=TensorArgument(name=user_input),
                target=None,
            )
            for user_input in sig.user_inputs
            if user_input not in params_map and user_input not in buffers_map
        ]

        # Double check the ExportedProgram data(especially everything except graph) is good
        exported_program = ExportedProgram(
            root=lowered_exported_program.graph_module,
            graph=lowered_exported_program.graph,
            graph_signature=_get_updated_graph_signature(
                ExportGraphSignature(
                    input_specs=input_specs, output_specs=output_specs
                ),
                lowered_exported_program.graph_module,
            ),
            # TODO: May need to set lowered_exported_program.call_spec = CallSpec(None, None)
            # somewhere as we should pass it a list of tensors to the lowered module and output a
            # list of tensors. Putting call_spec=lowered_exported_program.call_spec is correct here as the
            # inputs/outputs to the toplevel program will be in the format of the eager module.
            state_dict={},  # Empty because all data are consumed by the delegate
            range_constraints=lowered_exported_program.range_constraints,
            module_call_graph=lowered_exported_program.module_call_graph,
            example_inputs=None,
            verifiers=[lowered_exported_program.verifier],
        )
        if memory_planning is None:
            memory_planning = MemoryPlanningPass()
        exported_program = _transform(exported_program, SpecPropPass(), memory_planning)
        emitted_program = emit_program(
            exported_program, emit_stacktrace=emit_stacktrace
        ).program
        return emitted_program

    # Used to patch each delegated function with a call_delegate call
    # @staticmethod
    def forward(
        self,
        *args: Value,
        **kwargs: Tuple[Value, ...],
    ) -> Value:
        # Keyword arguments are accepted but not forwarded to the delegate call.
        return executorch_call_delegate(self, *args)
# TODO(zhxchen17) Try ExportPass
def _fixup_output_node(gm: torch.fx.GraphModule) -> None:
for node in reversed(gm.graph.nodes):
if node.op == "output":
with gm.graph.inserting_before(node):
assert len(node.args) == 1
outputs = node.args[0]
if isinstance(outputs, torch.fx.Node):
val = outputs.meta.get("val")
if isinstance(val, list):
# If a list is returned, in some cases it is represented as a
# singular node, like `split_copy_tensor` but EXIR will return a
# opened-up list like `[getitem1, getitem2]`
outputs = [
torch.fx.Proxy(outputs)[i].node for i in range(len(val))
]
returns, out_spec = pytree.tree_flatten(outputs)
node.args = (returns,)
return
def arrange_graph_placeholders(
gm: torch.fx.GraphModule, owning_program: ExportedProgram, tag
) -> torch.fx.GraphModule:
"""
Modifies the graph of the given graphmodule with one that contains the same nodes as the original,
but with placeholders in order of (Params + Buffers) (User Inputs)
This is used by the delegate api which disturbs the placeholder ordering when creating a submodule
from partitioned nodes
Args:
gm: The graph module that we want arranged
owning_program: ExportedProgram that the submodule (gm) belongs to
Returns:
The graph module in-placed arranged
"""
new_graph = torch.fx.Graph()
node_map = {} # mapping of nodes from old graph to new graph
graph_sign = owning_program.graph_signature
# Add all placeholders into the graph first:
# Cache these properties — each call rebuilds the dict from input_specs.
params_map = graph_sign.inputs_to_parameters
buffers_map = graph_sign.inputs_to_buffers
param_nodes = []
buffer_nodes = []
input_nodes = []
for node in gm.graph.nodes:
if node.op != "placeholder":
continue
if node.name in params_map and node.meta.get("delegation_tag", None) == tag:
param_nodes.append(node)
elif node.name in buffers_map and node.meta.get("delegation_tag", None) == tag:
buffer_nodes.append(node)
else:
input_nodes.append(node)
for param_node in param_nodes:
new_node = new_graph.node_copy(param_node, lambda x: node_map[x])
node_map[param_node] = new_node
for buffer_node in buffer_nodes:
new_node = new_graph.node_copy(buffer_node, lambda x: node_map[x])
node_map[buffer_node] = new_node
for input_node in input_nodes:
new_node = new_graph.node_copy(input_node, lambda x: node_map[x])
node_map[input_node] = new_node
# Now add all the other nodes in order
for node in gm.graph.nodes:
if node.op == "placeholder":
continue
new_node = new_graph.node_copy(node, lambda x: node_map[x])
node_map[node] = new_node
# lint to ensure correctness
new_graph.lint()
new_graph._codegen = gm.graph._codegen
gm.graph = new_graph
return gm
# TODO Don't regenerate new signature manually.
def _get_new_signature(  # noqa: C901
    original_program: ExportedProgram,
    gm: torch.fx.GraphModule,
    call_module_node: torch.fx.Node,
    tag: str,
    is_submodule: bool = False,
) -> Tuple[
    ExportGraphSignature,
    Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
    Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
    Dict[str, InputSpec],
    Dict[str, OutputSpec],
]:
    """
    Build the graph signature of a partitioned graph module and determine which
    params/buffers/constants (and their buffer mutations) it consumes from the
    toplevel program.

    Args:
        original_program: The original program that we are partitioning.
        gm: The partitioned graph module.
        call_module_node: The node in the original program that is calling the
            partitioned graph module.
        tag: The tag being used for this partitioned submodule. This is used to
            tell if a particular parameter/buffer/constant node is being tagged,
            aka consumed by the delegate.
        is_submodule: True if we are currently partitioning inside of a
            submodule (like cond's submodule). If we are inside of a submodule,
            we do not care about consuming params/buffers.

    Returns:
        new_signature (ExportGraphSignature): The new signature for the
            partitioned graph module.
        new_state_dict (Dict[str, Union[torch.Tensor, torch.nn.Parameter]]): The
            new state dict containing the consumed params/buffers.
        new_constants (Dict[str, Union[torch.Tensor, FakeScriptObject,
            torch.ScriptObject]]): The new constants table containing the
            consumed constants.
        input_specs_to_delete (Dict[str, InputSpec]): The input specs that have
            been consumed by the delegate (param/buffer input nodes) and should
            be removed from the toplevel ExportedProgram.
        output_specs_to_delete (Dict[str, OutputSpec]): The output specs that have
            been consumed by the delegate (buffer mutation nodes) and should be
            removed from the toplevel ExportedProgram.

    Raises:
        RuntimeError: If a tagged attribute input's target is found in neither
            the original program's state dict nor its constants table.
    """
    old_signature = original_program.graph_signature

    input_specs = []
    output_specs = []
    input_specs_to_delete = {}
    output_specs_to_delete = {}
    new_state_dict = {}
    new_constants = {}

    # If we are within a submodule, we do not need to care about consuming
    # parameter/buffers
    input_node_to_sig: Dict[str, InputSpec] = (
        {input_spec.arg.name: input_spec for input_spec in old_signature.input_specs}
        if not is_submodule
        else {}
    )

    # Map every toplevel output node name to all output specs that refer to it
    # (a single node may back several output specs).
    toplevel_output_node_to_sig: Dict[str, List[OutputSpec]] = defaultdict(list)
    if not is_submodule:
        for output_spec in old_signature.output_specs:
            toplevel_output_node_to_sig[output_spec.arg.name].append(output_spec)

    for node in gm.graph.nodes:
        if node.op == "placeholder":
            # Not an input of the toplevel signature (or we are inside a
            # submodule): treat it as a plain user input.
            if node.name not in input_node_to_sig:
                input_specs.append(
                    InputSpec(
                        kind=InputKind.USER_INPUT,
                        arg=TensorArgument(name=node.name),
                        target=None,
                    )
                )
                continue

            orig_input_spec = input_node_to_sig[node.name]

            if not isinstance(orig_input_spec.arg, TensorArgument):
                # Non-tensor inputs keep their original spec unchanged.
                input_specs.append(orig_input_spec)

            elif node.meta.get("delegation_tag", None) == tag:
                input_specs.append(orig_input_spec)

                if orig_input_spec.kind == InputKind.USER_INPUT:
                    continue

                # The following input specs are all attributes that should be
                # consumed by the delegate, so we want to remove it from the
                # toplevel module input/output
                input_specs_to_delete[node.name] = orig_input_spec

                input_target = orig_input_spec.target
                if input_target in original_program.state_dict:
                    assert orig_input_spec.kind in (
                        InputKind.PARAMETER,
                        InputKind.BUFFER,
                    )

                    new_state_dict[input_target] = original_program.state_dict[
                        input_target
                    ]
                elif input_target in original_program.constants:
                    # Buffers may also live in the constants table (the
                    # assertion below allows BUFFER here).
                    assert orig_input_spec.kind in (
                        InputKind.CONSTANT_TENSOR,
                        InputKind.CUSTOM_OBJ,
                        InputKind.BUFFER,
                    )

                    new_constants[input_target] = original_program.constants[
                        input_target
                    ]
                else:
                    raise RuntimeError(f"Invalid input spec {orig_input_spec} received")
            else:
                # An attribute not tagged for this delegate: the toplevel program
                # keeps it and passes it in as an ordinary user input.
                input_specs.append(
                    InputSpec(
                        kind=InputKind.USER_INPUT,
                        arg=TensorArgument(name=node.name),
                        target=None,
                    )
                )

        if node.op == "output":
            # For each getitem index on the call_module node, collect the
            # toplevel output specs that are fed by that getitem.
            buffer_mutation_idxs: Dict[int, List[OutputSpec]] = defaultdict(list)
            for user in call_module_node.users.keys():
                if user.name in toplevel_output_node_to_sig:
                    assert (
                        user.op == "call_function" and user.target == operator.getitem
                    ), f"Invalid user {user}, node.op is {user.op} and node.target is {user.target}"
                    getitem_idx = user.args[1]
                    assert isinstance(
                        getitem_idx, int
                    ), f"Invalid getitem type: {type(getitem_idx)}"
                    buffer_mutation_idxs[getitem_idx].extend(
                        toplevel_output_node_to_sig[user.name]
                    )

            for i, output_node in enumerate(node.args[0]):
                if i in buffer_mutation_idxs:
                    assert isinstance(output_node, torch.fx.Node)
                    orig_output_specs = buffer_mutation_idxs[i]

                    if any(
                        orig_output_spec.kind == OutputKind.BUFFER_MUTATION
                        and (
                            orig_output_spec.target in new_state_dict
                            or orig_output_spec.target in new_constants
                        )
                        for orig_output_spec in orig_output_specs
                    ):
                        # If the delegate wants to consume the buffer, then the
                        # delegate should also consume the buffer mutation
                        # (output spec would be a BUFFER_MUTATION). Otherwise
                        # the delegate will just return the result of the
                        # mutation as a USER_OUTPUT.

                        orig_output_spec = [
                            orig_output_spec
                            for orig_output_spec in orig_output_specs
                            if orig_output_spec.kind == OutputKind.BUFFER_MUTATION
                            and (
                                orig_output_spec.target in new_state_dict
                                or orig_output_spec.target in new_constants
                            )
                        ][0]

                        # A consumed buffer mutation must not also back another
                        # toplevel output spec.
                        assert len(orig_output_specs) == 1, (
                            f"Constant {orig_output_spec.target} was tagged to be "
                            "consumed by the buffer, and was found to also contain "
                            "a buffer mutation. However this buffer mutation node "
                            "was found to also be used as other types of outputs "
                            "which is currently not supported. Please file an "
                            "issue on Github. \n\n"
                            f"The toplevel program: {original_program}\n"
                        )
                        output_specs.append(
                            OutputSpec(
                                kind=OutputKind.BUFFER_MUTATION,
                                arg=TensorArgument(name=output_node.name),
                                target=orig_output_spec.target,
                            )
                        )
                        output_specs_to_delete[orig_output_spec.arg.name] = (
                            orig_output_spec
                        )
                    else:
                        output_specs.append(
                            OutputSpec(
                                kind=OutputKind.USER_OUTPUT,
                                arg=TensorArgument(name=output_node.name),
                                target=None,
                            )
                        )

                elif not isinstance(output_node, torch.fx.Node):
                    # Non-node outputs are recorded as constant user outputs.
                    output_specs.append(
                        OutputSpec(
                            kind=OutputKind.USER_OUTPUT,
                            arg=ConstantArgument(name="", value=output_node),
                            target=None,
                        )
                    )

                else:
                    output_specs.append(
                        OutputSpec(
                            kind=OutputKind.USER_OUTPUT,
                            arg=TensorArgument(name=output_node.name),
                            target=None,
                        )
                    )

    new_signature = ExportGraphSignature(
        input_specs=input_specs, output_specs=output_specs
    )

    return (
        new_signature,
        new_state_dict,
        new_constants,
        input_specs_to_delete,
        output_specs_to_delete,
    )
def create_exported_program_from_submodule(
    submodule: torch.fx.GraphModule,
    owning_program: ExportedProgram,
    tag: str,
    call_module_node: torch.fx.Node,
    is_submodule: bool,
) -> Tuple[ExportedProgram, Dict[str, InputSpec], Dict[str, OutputSpec]]:
    """
    Creates an ExportedProgram from the given submodule using the parameters,
    buffers and constants of the toplevel owning program.

    Args:
        submodule: The partitioned submodule to wrap. Its placeholders are
            reordered in place to (params) (buffers) (user inputs).
        owning_program: Exported program containing the parameters and buffers
            used within the submodule.
        tag: Delegation tag of the partition; decides which attribute inputs
            are consumed by the submodule.
        call_module_node: The node in the owning program that calls
            ``submodule``.
        is_submodule: True when partitioning inside a nested submodule (e.g. a
            ``cond`` branch), in which case params/buffers are not consumed.

    Returns:
        A tuple of:
        - the ExportedProgram created from ``submodule``;
        - input_specs_to_delete (Dict[str, InputSpec]): input specs consumed by
          the delegate that should be removed from the toplevel program;
        - output_specs_to_delete (Dict[str, OutputSpec]): buffer-mutation output
          specs consumed by the delegate that should be removed from the
          toplevel program.
    """
    arranged = arrange_graph_placeholders(submodule, owning_program, tag)
    # TODO: we probably need to arrange the outputs wrt buffer mutations.

    (
        signature,
        state_dict,
        constants,
        inputs_to_drop,
        outputs_to_drop,
    ) = _get_new_signature(owning_program, arranged, call_module_node, tag, is_submodule)

    _, in_spec = pytree.tree_flatten((tuple(signature.user_inputs), {}))
    _, out_spec = pytree.tree_flatten(signature.user_outputs)
    root_call = ModuleCallEntry(
        "",
        ModuleCallSignature(inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec),
    )

    program = ExportedProgram(
        root=arranged,
        graph=arranged.graph,
        graph_signature=signature,
        state_dict=state_dict,
        range_constraints=copy.deepcopy(owning_program.range_constraints),
        module_call_graph=[root_call],
        constants=constants,
        verifiers=[owning_program.verifier],
    )
    return program, inputs_to_drop, outputs_to_drop
def _find_fused_call_module(
    gm: torch.fx.GraphModule, submodule_name: str, node_list: NodeList
) -> torch.fx.Node:
    """Return the last call_module node in ``gm`` targeting ``submodule_name``; raise if none."""
    matches = [
        candidate
        for candidate in gm.graph.nodes
        if candidate.op == "call_module" and candidate.target == submodule_name
    ]
    if not matches:
        raise RuntimeError(
            f"The submodule created with nodes {node_list} did not form "
            "one fully contained subgraph. Check that these nodes form a "
            f"fully contained graph. Partitioned graph: {gm.graph}."
        )
    return matches[-1]


def create_submodule_from_nodes(
    gm: torch.fx.GraphModule,
    node_list: NodeList,
    tag: str,
    skip_legalize_graph: bool = False,
) -> Tuple[torch.fx.GraphModule, torch.fx.Node]:
    """
    Modifies the given graph module in-place to separate out the given nodes
    into a submodule named ``"fused_" + tag``. The given node_list should form
    a fully connected subgraph.

    Args:
        gm: The graph module that we want to partition.
        node_list: A list of nodes that belong in the partition.
        tag: Partition tag, used to name the fused submodule.
        skip_legalize_graph: If True, skip re-sorting ``gm`` topologically
            after the fused call has been inserted.

    Returns:
        The submodule that has been partitioned, and the call_module node in
        the toplevel graph module calling the submodule.

    Raises:
        RuntimeError: If no call_module node targeting the fused submodule can
            be found in ``gm``.
    """
    ordered_nodes = topo_sort(node_list)

    fused_name = "fused_" + tag
    sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, ordered_nodes, fused_name)

    _fixup_output_node(sub_gm)

    gm = insert_subgm(gm, sub_gm, orig_inputs, orig_outputs)
    call_node = _find_fused_call_module(gm, fused_name, node_list)

    single_tensor_output = len(orig_outputs) == 1 and isinstance(
        orig_outputs[0].meta["val"], FakeTensor
    )
    if single_tensor_output:
        # The lone tensor output was flattened into a singleton list, so every
        # use must go through element 0 of the call's result.
        with gm.graph.inserting_after(call_node):
            first_item = torch.fx.Proxy(call_node)[0].node  # type: ignore[index]
            call_node.replace_all_uses_with(first_item)
            first_item.meta["val"] = call_node.meta["val"]
            # replace_all_uses_with also rewrote first_item's own argument; undo that.
            first_item.args = (call_node, 0)
    else:
        # fuse_as_graphmodule copies the partition's last-node metadata onto the
        # getitem consumers; for delegation that metadata is irrelevant.
        for consumer in call_node.users:
            for stale_key in ("nn_module_stack", "source_fn_stack"):
                consumer.meta.pop(stale_key, None)

    erase_nodes(gm, ordered_nodes)

    # TODO : T153794167 Get rid of support for skipping legalize graph in create_submodule_from_nodes
    # once we transition to using fuse_by_partitions.
    if not skip_legalize_graph:
        # Topological sort original gm with newly created sub_gm
        legalize_graph(gm)

    # Re-resolve the call node, since legalize_graph may have rebuilt the graph.
    call_node = _find_fused_call_module(gm, fused_name, node_list)
    return sub_gm, call_node
def get_lowered_submodules(
    graph_module: torch.fx.GraphModule,
) -> List[Tuple[str, LoweredBackendModule, torch.fx.Node]]:
    """
    Collect the lowered modules called directly from ``graph_module``'s graph.

    Nested submodules are not inspected. Each entry is a tuple of (the name the
    lowered module is stored under in the graph module, the lowered module
    itself, the fx node that called this lowered module).
    """
    delegate_calls = (
        candidate
        for candidate in graph_module.graph.nodes
        if candidate.op == "call_function"
        and candidate.target == executorch_call_delegate
    )
    found: List[Tuple[str, LoweredBackendModule, torch.fx.Node]] = []
    for delegate_call in delegate_calls:
        attr_name, lowered, call_node = _get_submodule(graph_module, delegate_call, 0)
        assert isinstance(lowered, LoweredBackendModule)
        found.append((attr_name, lowered, call_node))
    return found
def get_lowered_backend_modules(
    graph_module: torch.fx.GraphModule,
) -> List[LoweredBackendModule]:
    """
    Returns a list of the lowered backend modules (delegates) called from the
    toplevel graph of ``graph_module``; submodules are not searched.

    Each delegate is resolved through the ``get_attr`` node passed as the first
    argument of ``executorch_call_delegate``. The lookup uses that node's
    ``target``, the attribute path it reads (the string handed to
    ``graph.get_attr``), rather than its ``name``. FX may give a node a name
    that differs from its target, e.g. when de-duplicating names or replacing
    ``.`` in a dotted path. In that case ``getattr`` on the name would fail or
    pick up the wrong attribute.
    """
    lowered_programs = []
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target == executorch_call_delegate:
            lowered_backend_module = graph_module
            # Walk the (possibly dotted) attribute path.
            for atom in str(node.args[0].target).split("."):
                lowered_backend_module = getattr(lowered_backend_module, atom)
            lowered_programs.append(lowered_backend_module)

    return lowered_programs
def _unsafe_adjust_original_program( # noqa: C901
original_program: ExportedProgram,
call_delegate_node: torch.fx.Node,
input_specs_to_delete: Dict[str, InputSpec],
output_specs_to_delete: Dict[str, OutputSpec],
) -> None:
"""
Directly modify the original exported program's signature and state dict
based on the consumed params/buffers in the delegate.
"""
original_program._graph_signature.input_specs = [
input_spec
for input_spec in original_program.graph_signature.input_specs
if input_spec.arg.name not in input_specs_to_delete
]
currently_used_targets: Set[str] = {
input_spec.target
for input_spec in original_program._graph_signature.input_specs
if input_spec.target is not None
}
original_program._graph_signature.output_specs = [
output_spec
for output_spec in original_program.graph_signature.output_specs
if output_spec.arg.name not in output_specs_to_delete
]
# Delete all parameters/buffers consumed by the created exported program
# from the graph signature, state dict, constants table
for node in original_program.graph.nodes:
if node.op == "placeholder":
if node.name in input_specs_to_delete:
assert len(node.users) == 0
original_program.graph.erase_node(node)
else:
break
for input_spec in input_specs_to_delete.values():
input_target = input_spec.target
assert input_target is not None
if input_target in currently_used_targets:
continue
if input_spec.kind == InputKind.PARAMETER:
del original_program._state_dict[input_target]
elif input_spec.kind == InputKind.BUFFER:
if input_spec.persistent:
original_program._state_dict.pop(input_target, None)
else:
del original_program._constants[input_spec.target]
elif input_spec.kind == InputKind.CONSTANT_TENSOR:
del original_program._constants[input_spec.target]
else:
raise RuntimeError(f"Invalid input spec {input_spec} received")
# Delete buffer mutations from the output which were consumed by the delegate
toplevel_output_node = original_program.graph.output_node()
assert toplevel_output_node is not None
assert (
len(toplevel_output_node.args) == 1
), f"Invalid output node: {toplevel_output_node} with args {toplevel_output_node.args}"
new_output_args = [
arg
for arg in toplevel_output_node.args[0]
if not isinstance(arg, torch.fx.Node) or arg.name not in output_specs_to_delete
]
toplevel_output_node.args = (tuple(new_output_args),)
# Delete the buffer mutation getitem nodes
getitem_idxs: List[int] = []
user_nodes = list(call_delegate_node.users.keys())
for user in user_nodes:
if user.name in output_specs_to_delete:
assert (
user.op == "call_function" and user.target == operator.getitem
), f"Invalid user {user}, node.op is {node.op} and node.target is {node.target}"
user_idx = user.args[1]
assert isinstance(user_idx, int), f"Invalid getitem type: {type(user_idx)}"
getitem_idxs.append(user_idx)
original_program.graph.erase_node(user)
getitem_idxs.sort(reverse=True)
# Adjust all the getitem indices after the deleted getitems
user_nodes = list(call_delegate_node.users.keys())
for user in user_nodes:
assert user.op == "call_function" and user.target == operator.getitem
user_idx = user.args[1]
assert isinstance(user_idx, int)
for i, idx in enumerate(getitem_idxs):
if user_idx > idx:
user.args = (user.args[0], user_idx - (len(getitem_idxs) - i))
break