forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
386 lines (322 loc) · 13.6 KB
/
utils.py
File metadata and controls
386 lines (322 loc) · 13.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch
from executorch.exir import ExportedProgram
from torch._export.utils import (
get_buffer,
get_lifted_tensor_constant,
get_param,
is_buffer,
is_lifted_tensor_constant,
is_param,
)
from torch._subclasses.fake_tensor import FakeTensorConverter
from torch.export.graph_signature import (
ExportGraphSignature,
InputKind,
InputSpec,
OutputKind,
OutputSpec,
TensorArgument,
)
def _get_fake_tensor_mode(graph: torch.fx.Graph, data: torch.Tensor) -> torch.Tensor:
    """
    Create a fake tensor for ``data`` using the fake_mode taken from an
    existing node of ``graph``.
    Args:
        graph: The graph whose nodes carry the fake_mode to reuse
        data: The real tensor to convert into a fake tensor
    Returns:
        A fake tensor wrapping ``data`` with the graph's fake_mode
    Raises:
        RuntimeError: If the graph is empty, so no fake_mode can be extracted
    """
    example_node = next(iter(graph.nodes), None)
    if example_node is None:
        raise RuntimeError(
            "Cannot create fake tensor: graph has no nodes to extract fake_mode from"
        )
    val = example_node.meta["val"]
    # Multi-output nodes store a sequence of fake tensors; any element carries
    # the fake_mode, so take the first.
    if isinstance(val, (tuple, torch.fx.immutable_collections.immutable_list)):
        example_fake_tensor = val[0]
    else:
        example_fake_tensor = val
    return FakeTensorConverter().from_real_tensor(example_fake_tensor.fake_mode, t=data)
def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Returns true if the given node is a get attr node for a tensor of the model
    """
    if not isinstance(node, torch.fx.Node):
        return False
    return node.op == "get_attr"
def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Return True if ``node`` is a constant tensor of the exported program:
    a get_attr node, a parameter, a buffer, or a lifted tensor constant.
    """
    if is_get_attr_node(node):
        return True
    if is_param(exp_prog, node):
        return True
    if is_buffer(exp_prog, node):
        return True
    return is_lifted_tensor_constant(exp_prog, node)
def get_param_tensor(
    exp_prog: ExportedProgram, node: torch.fx.Node
) -> Optional[torch.Tensor]:
    """
    Resolve the concrete tensor behind a constant input node.
    Returns None when ``node`` is None; raises RuntimeError when the node is
    not a parameter, buffer, lifted tensor constant, or get_attr node.
    """
    if node is None:
        return None
    if is_param(exp_prog, node):
        return get_param(exp_prog, node)
    if is_buffer(exp_prog, node):
        return get_buffer(exp_prog, node)
    if is_lifted_tensor_constant(exp_prog, node):
        return get_lifted_tensor_constant(exp_prog, node)
    if is_get_attr_node(node):
        # Hack to support both lifted and unlifted graphs: the attribute may
        # live on the graph's owning module or on the program's graph_module.
        try:
            return getattr(node.graph.owning_module, node.target)
        except AttributeError:
            return getattr(exp_prog.graph_module, node.target)
    raise RuntimeError(f"unsupported param type, {node.op}.")
def create_constant_placeholder(
    exp_program: ExportedProgram,
    graph: torch.fx.Graph,
    name: str,
    kind: InputKind,
    data: torch.Tensor,
    persistent_buffer: Optional[bool] = None,
) -> torch.fx.Node:
    """
    Creates and returns a constant placeholder node, meaning that it is of type parameter, buffer,
    or lifted constant tensor. graph.inserting_before/after() should be used before the call to
    decide where to insert the node, at an insertion point before the first input node.

    Args:
        exp_program: Exported program whose state_dict/constants and graph
            signature are updated to register the new constant.
        graph: The fx graph the placeholder node is created in.
        kind: Must be InputKind.PARAMETER, InputKind.BUFFER, or
            InputKind.CONSTANT_TENSOR.
        data: Tensor data backing the placeholder.
        persistent_buffer: Required when kind is InputKind.BUFFER; True stores
            the data in state_dict, False stores it in constants.
    Returns:
        The created placeholder node, or an already existing placeholder with
        the same target.
    Raises:
        RuntimeError: If kind is not a constant input kind, if
            persistent_buffer is unset for a buffer, or if the insertion point
            is after a user input node.
    """
    target = name
    # If a placeholder with this target already exists, return it to avoid
    # duplicate parameter names in the generated function signature which would
    # cause a SyntaxError on recompile. This can happen when multiple pattern
    # replacements independently create placeholders for a shared weight.
    if name in exp_program.state_dict or name in exp_program.constants:
        for n in graph.nodes:
            if n.op == "placeholder" and n.target == name:
                return n
    # Add data to state_dict/ constants
    match kind:
        case InputKind.PARAMETER:
            # Stored frozen; gradients are never needed for lowered constants.
            exp_program.state_dict[target] = torch.nn.Parameter(
                data, requires_grad=False
            )
        case InputKind.BUFFER:
            if persistent_buffer is None:
                raise RuntimeError(
                    "Must set persistent_buffer when creating a new buffer."
                )
            elif persistent_buffer:
                exp_program.state_dict[target] = data
            else:
                # Non-persistent buffers live in constants, mirroring how
                # delete_constant_placeholder looks them up.
                exp_program.constants[target] = data
        case InputKind.CONSTANT_TENSOR:
            exp_program.constants[target] = data
        case _:
            raise RuntimeError("Can only create constant input nodes.")
    # The node's meta["val"] must be a fake tensor in the graph's fake_mode.
    fake_tensor = _get_fake_tensor_mode(graph, data)
    # Create node
    node = graph.create_node(op="placeholder", name=name, target=name)
    node.meta["val"] = fake_tensor
    # Add tensor to graph_signature in the same order as nodes in the graph
    node_names = [n.name for n in graph.nodes if n.op == "placeholder"]
    node_index = node_names.index(name)
    input_specs = exp_program.graph_signature.input_specs
    user_input_indices = [
        i for i, spec in enumerate(input_specs) if spec.kind == InputKind.USER_INPUT
    ]
    # Constant inputs must precede all user inputs in the signature ordering,
    # so every existing user-input spec must end up after the new node.
    if not all(
        (user_input_index >= node_index for user_input_index in user_input_indices)
    ):
        raise RuntimeError(
            f"Failed to insert {name}; Const placeholder nodes must be inserted before user input nodes in the graph."
        )
    arg_spec = TensorArgument(name)
    input_spec = InputSpec(kind, arg_spec, target, persistent_buffer)
    input_specs.insert(node_index, input_spec)
    # Rebuild the signature; ExportedProgram exposes graph_signature via a
    # property, so the private attribute is assigned directly.
    new_graph_signature = ExportGraphSignature(
        input_specs, exp_program.graph_signature.output_specs
    )
    exp_program._graph_signature = new_graph_signature
    return node
def delete_constant_placeholder(exp_program: ExportedProgram, node: torch.fx.Node):
    """
    Deletes a node of type parameter, buffer, or lifted constant tensor and its related
    graph signature and state_dict/constant entries. The node may not have any users.
    """
    if len(node.users) != 0:
        raise RuntimeError(
            f"Cannot delete input node {node.name} since it has users in the graph."
        )
    signature = exp_program.graph_signature
    # Drop the backing tensor from state_dict/constants depending on which
    # kind of constant input this node is.
    if node.name in signature.inputs_to_parameters:
        target = signature.inputs_to_parameters[node.name]
        del exp_program.state_dict[target]
    elif node.name in signature.inputs_to_buffers:
        target = signature.inputs_to_buffers[node.name]
        # Non-persistent buffers are stored in constants, not the state_dict.
        if target in signature.non_persistent_buffers:
            del exp_program.constants[target]
        else:
            del exp_program.state_dict[target]
    elif node.name in signature.inputs_to_lifted_tensor_constants:
        target = signature.inputs_to_lifted_tensor_constants[node.name]
        del exp_program.constants[target]
    else:
        raise RuntimeError(
            f"Cannot delete input node {node.name} since it is not a parameter, a buffer, nor a lifted tensor constant."
        )
    # Rebuild the graph signature without the deleted input's spec.
    remaining_input_specs = [
        spec for spec in signature.input_specs if spec.arg.name != node.name
    ]
    exp_program._graph_signature = ExportGraphSignature(
        remaining_input_specs, signature.output_specs
    )
    # Finally remove the placeholder from the graph itself.
    node.graph.erase_node(node)
def _validate_graph_signature(exp_program: ExportedProgram):
    """
    Validates that the graph signature is up to date with the graph.
    """

    def _check_specs(label, nodes, specs, node_kind, spec_kind):
        # Counts must agree before a pairwise name comparison makes sense.
        if len(nodes) != len(specs):
            raise RuntimeError(
                f"Graph has {len(nodes)} {node_kind} nodes but signature has "
                f"{len(specs)} {spec_kind} specs"
            )
        for graph_node, spec in zip(nodes, specs):
            if graph_node.name != spec.arg.name:
                raise RuntimeError(
                    f"{label} node {graph_node.name} does not match {spec_kind} spec {spec.arg.name}"
                )

    signature = exp_program.graph_signature
    placeholders = [n for n in exp_program.graph.nodes if n.op == "placeholder"]
    _check_specs("Input", placeholders, signature.input_specs, "placeholder", "input")
    outputs = exp_program.graph.output_node().args[0]
    _check_specs("Output", outputs, signature.output_specs, "output", "output")
def _spec_to_node(
    exp_program: ExportedProgram, spec: InputSpec | OutputSpec
) -> torch.fx.Node:
    """
    Converts an InputSpec or OutputSpec to its corresponding node in the graph.
    """
    # The spec must carry a named argument to match against node names.
    if not (hasattr(spec, "arg") and hasattr(spec.arg, "name")):
        raise RuntimeError(f"Invalid spec format: {spec}")
    arg_name = spec.arg.name
    # Linear scan over the graph; the first node with a matching name wins.
    matches = (node for node in exp_program.graph.nodes if node.name == arg_name)
    found = next(matches, None)
    if found is None:
        raise RuntimeError(f"Could not find node with name '{arg_name}' in the graph")
    return found
def create_mutable_buffer(
    exp_program: ExportedProgram,
    name: str,
    data: torch.Tensor,
) -> torch.fx.Node:
    """
    Creates and returns a mutable buffer placeholder node. This is similar to
    create_constant_placeholder but specifically for creating mutable buffers that
    can be modified during execution.

    The difference between this and create_constant_placeholder is that this doesn't
    expect user to set the correct position for the placeholder node to be inserted,
    it finds the correct position automatically.
    It also updates the graph outputs to include the mutable buffer.

    Args:
        exp_program: The exported program to modify
        name: The name for the new buffer node (should start with "b_" prefix by convention)
        data: The initial tensor data for the buffer
    Returns:
        The created placeholder node to be used in the graph
    Raises:
        ValueError: If name is empty/whitespace or data is not a torch.Tensor
        RuntimeError: If the buffer target already exists, or if the existing
            graph signature is out of sync with the graph
    """
    # Input validation
    if not name or not name.strip():
        raise ValueError("Buffer name cannot be empty")
    if not isinstance(data, torch.Tensor):
        raise ValueError("Data must be a torch.Tensor")
    # Extract target name (remove "b_" prefix if present, following export convention)
    if name.startswith("b_"):
        target = name[2:]
    else:
        target = name
    # Check if target already exists
    if target in exp_program.state_dict:
        raise RuntimeError(f"Buffer target '{target}' already exists in state_dict")
    # Fail fast if the signature already disagrees with the graph, since the
    # positional insertion below relies on them being aligned.
    _validate_graph_signature(exp_program)
    # Buffers created here are always persistent, i.e. stored in state_dict.
    persistent_buffer = True
    exp_program.state_dict[target] = data
    graph = exp_program.graph_module.graph
    # Create fake tensor using helper function
    fake_tensor = _get_fake_tensor_mode(graph, data)
    # Signature ordering is as follows:
    # Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
    #                            ^^^^^^^
    #                            insert here (at the end of buffers)
    # Outputs = [*mutated_inputs, *flattened_user_outputs]
    #              ^^^^^^^^^^^^^^^
    #              insert here (at the end of mutated inputs)
    # Inputs
    # Find const or user input node if any, and insert before it
    node_index = 0
    node = None
    input_specs = exp_program.graph_signature.input_specs
    if len(input_specs) == 0 or all(
        spec.kind not in [InputKind.CONSTANT_TENSOR, InputKind.USER_INPUT]
        for spec in input_specs
    ):
        # No const or user input nodes; append the placeholder at the graph's
        # current (default) insertion point and the end of the input specs.
        node_index = len(input_specs)
        node = graph.create_node(op="placeholder", name=name, target=name)
    else:
        # Find the first constant or user input node and insert the new
        # placeholder right before it, keeping buffers ahead of user inputs.
        for i, spec in enumerate(input_specs):
            if spec.kind in [InputKind.CONSTANT_TENSOR, InputKind.USER_INPUT]:
                node_index = i
                with graph.inserting_before(_spec_to_node(exp_program, spec)):
                    node = graph.create_node(op="placeholder", name=name, target=name)
                break
    assert node is not None, "node should be created at this point"
    node.meta["val"] = fake_tensor
    buffer_input_spec = InputSpec(
        InputKind.BUFFER, TensorArgument(name), target, persistent_buffer
    )
    input_specs.insert(node_index, buffer_input_spec)
    # Outputs
    # Create output spec for the mutable buffer, and insert it at the beginning of output specs
    user_output_indices = [
        i
        for i, spec in enumerate(exp_program.graph_signature.output_specs)
        if spec.kind == OutputKind.USER_OUTPUT
    ]
    # Place the mutation just before the first user output (or first overall).
    output_index = user_output_indices[0] if user_output_indices else 0
    output_specs = exp_program.graph_signature.output_specs
    mutation_output_spec = OutputSpec(
        OutputKind.BUFFER_MUTATION, TensorArgument(name), target
    )
    output_specs.insert(output_index, mutation_output_spec)
    # Update the outputs to include the mutable buffer
    output_node = graph.output_node()
    args = list(output_node.args[0])
    args.insert(output_index, node)
    # output node args are a 1-tuple wrapping the list of outputs.
    output_node.args = (args,)
    # Update graph signature in the exported program
    new_graph_signature = ExportGraphSignature(input_specs, output_specs)
    exp_program._graph_signature = new_graph_signature
    return node