Skip to content

Commit f8c877f

Browse files
authored
Merge branch 'main' into op-select-scatter
2 parents 79cd4d5 + 4860984 commit f8c877f

33 files changed

Lines changed: 447 additions & 158 deletions

backends/arm/_passes/arm_pass_manager.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@
113113

114114
from executorch.backends.arm._passes.arm_pass import ArmPass
115115
from executorch.backends.arm.tosa.specification import (
116+
tosa_spec_in_set,
116117
TosaLoweringContext,
117118
TosaSpecification,
118119
)
@@ -309,16 +310,20 @@ def transform_to_backend_pipeline(
309310
self, exported_program: ExportedProgram, graph_module: GraphModule
310311
):
311312
"""Apply passes before transforming program to backend"""
312-
if self.tosa_spec in (
313-
TosaSpecification.create_from_string("TOSA-1.0+FP"),
314-
TosaSpecification.create_from_string("TOSA-1.0+INT"),
313+
314+
if not tosa_spec_in_set(
315+
self.tosa_spec,
316+
{
317+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
318+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
319+
},
315320
):
316-
return self._tosa_pipeline(exported_program, graph_module)
317-
else:
318-
raise NotImplementedError(
319-
f"No pass pipeline implemented for {self.tosa_spec}"
321+
raise RuntimeError(
322+
f"No pass pipeline found for TOSA specification: {self.tosa_spec}"
320323
)
321324

325+
return self._tosa_pipeline(exported_program, graph_module)
326+
322327
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
323328
# Preprocessing passes
324329
self.add_pass(RemoveGraphAssertsPass())

backends/arm/_passes/rewrite_conv2d_pass.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def _is_depthwise_conv2d(self, node: torch.fx.Node) -> bool:
9090
return False
9191
groups = node.args[-1]
9292
in_channels = get_first_fake_tensor(node.all_input_nodes[0]).shape[1]
93-
out_channels = get_first_fake_tensor(node.all_input_nodes[1]).shape[0]
93+
out_channels = get_first_fake_tensor(node).shape[1]
9494
return (in_channels == groups) and (out_channels % in_channels) == 0
9595

9696
def _reshape_weights(self, weight_node: torch.fx.Node, in_channels: int) -> None:
@@ -103,6 +103,7 @@ def _reshape_weights(self, weight_node: torch.fx.Node, in_channels: int) -> None
103103
raise RuntimeError(
104104
f"Weight node {weight_node.name} is not a parameter or buffer"
105105
)
106+
106107
reshaped_weight_tensor = (
107108
weight_tensor.permute(HWCM_ORDER)
108109
.reshape(
@@ -118,14 +119,19 @@ def _reshape_weights(self, weight_node: torch.fx.Node, in_channels: int) -> None
118119
param_name = self.exported_program.graph_signature.inputs_to_buffers[
119120
weight_node.name
120121
]
122+
reshaped_weight_tensor = torch.nn.Buffer(reshaped_weight_tensor)
121123
elif is_param(self.exported_program, weight_node):
122124
param_name = self.exported_program.graph_signature.inputs_to_parameters[
123125
weight_node.name
124126
]
127+
reshaped_weight_tensor = torch.nn.Parameter(
128+
reshaped_weight_tensor, requires_grad=False
129+
)
125130
else:
126131
raise RuntimeError(
127132
f"Weight node {weight_node.name} is neither a parameter nor a buffer"
128133
)
134+
129135
self.exported_program.state_dict[param_name] = reshaped_weight_tensor
130136
weight_node.meta["val"] = weight_node.meta["val"].reshape(
131137
weight_tensor.shape[2],
@@ -243,7 +249,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
243249

244250
if self._is_depthwise_conv2d(node):
245251
target_op = exir_ops.backend.tosa.DEPTHWISE_CONV2D.default
246-
self._reshape_weights(weight, input_fake_tensor.shape[1])
252+
# If there are any TOSA.DEPTHWISE_CONV2D nodes using the weights, we've already reshaped them.
253+
if all(user.target != target_op for user in weight.users):
254+
self._reshape_weights(weight, input_fake_tensor.shape[1])
247255
weight_fake_tensor = get_first_fake_tensor(weight)
248256
else:
249257
target_op = exir_ops.backend.tosa.CONV2D.default

backends/arm/operator_support/ethos_u55_support.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,12 @@ def __init__(self, reporter: WhyNoPartitionReporter):
7878

7979
targeted_ops_i8_i16_i32 = [
8080
exir_ops.edge.aten.cat.default,
81+
exir_ops.edge.aten.expand_copy.default,
8182
exir_ops.edge.aten.repeat.default,
8283
exir_ops.edge.aten.constant_pad_nd.default,
8384
exir_ops.edge.aten.view.default,
8485
exir_ops.edge.aten.permute.default,
86+
exir_ops.edge.aten.permute_copy.default,
8587
]
8688

8789
target_ops_i8 = tuple(TableOps.included_ops())

backends/arm/operator_support/slice_copy_support.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,6 @@ def is_node_tosa_supported(
4141
non-unit step sizes.
4242
4343
"""
44-
if tosa_spec not in self.tosa_specs:
45-
return False
46-
4744
args = node.args
4845
if len(args) == 5 and (step := args[4]) != 1:
4946
logger.warning(f"{node.target} with step size of {step} not supported.")

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 81 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,61 @@ def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]):
146146
return checker
147147

148148

149+
def _is_quantized_constant(node: torch.fx.Node) -> bool:
150+
if node.target not in (
151+
exir_ops.edge.aten.full_like.default,
152+
*ComputeConstantOpsAOTPass.targeted_ops,
153+
):
154+
return False
155+
156+
users = tuple(node.users)
157+
if users and all(user.target in Q_OPS for user in users):
158+
# The node feeds directly into only quantized ops.
159+
return True
160+
161+
for user in users:
162+
if user.target == exir_ops.edge.dim_order_ops._to_dim_order_copy.default:
163+
dim_order_dtype = get_first_fake_tensor(user).dtype
164+
if dim_order_dtype.is_complex or dim_order_dtype.is_floating_point:
165+
return False
166+
else:
167+
return False
168+
169+
return len(users) > 0
170+
171+
172+
def is_quantized(node: torch.fx.Node) -> bool:
173+
"""Checks if the node is quantized.
174+
175+
A node is considered quantized if any of the following is true:
176+
- Its output dtype is not floating point or complex => integer
177+
- It is an op that produces a constant that in turn feeds only quantized users
178+
- It has been marked as quantized in the ArmAnnotationInfo custom meta.
179+
180+
Args:
181+
node (torch.fx.Node): The FX node to check.
182+
183+
Returns:
184+
bool: True if the node is quantized, False otherwise.
185+
"""
186+
187+
node_dtype = get_first_fake_tensor(node).dtype
188+
# Integer-like dtype implies the node is already quantized.
189+
if not node_dtype.is_complex and not node_dtype.is_floating_point:
190+
return True
191+
192+
# Nodes introduced during lowering that exclusively feed quantized users.
193+
if _is_quantized_constant(node):
194+
return True
195+
196+
# Finally, fall back to the explicit annotation emitted by Arm passes.
197+
custom_meta = node.meta.get("custom", {})
198+
if ArmAnnotationInfo.CUSTOM_META_KEY in custom_meta:
199+
return custom_meta[ArmAnnotationInfo.CUSTOM_META_KEY]["quantized"]
200+
201+
return False
202+
203+
149204
def get_registered_tosa_support_checks(
150205
tosa_spec: TosaSpecification,
151206
) -> list[Type[SupportedTOSAOperatorCheck]]:
@@ -194,9 +249,11 @@ def tosa_support_factory(
194249
ControlFlowOpSupported(exported_program, tosa_spec, reporter),
195250
]
196251

197-
if tosa_spec.support_integer():
252+
if tosa_spec.support_integer() and tosa_spec.support_float():
253+
positive_checks.append(TOSAProINTFPSupportList())
254+
elif tosa_spec.support_integer():
198255
positive_checks.append(TOSAProINTSupportList())
199-
if tosa_spec.support_float():
256+
elif tosa_spec.support_float():
200257
positive_checks.append(TOSAProFPSupportList())
201258
# TODO: Refactor to use TOSAProSupportLists + negative checks
202259
positive_checks += [
@@ -268,6 +325,27 @@ def is_node_supported(
268325
return node.op == "call_function" and node.target in TOSA_PRO_FP_SupportList
269326

270327

328+
class TOSAProINTFPSupportList(OperatorSupportBase):
329+
"""
330+
TOSA_PRO_INT_FP_SupportList:
331+
Ops supported in INT+FP profile via native TOSA ops, decomposition/transformation, pre-compute, or TableOp.
332+
"""
333+
334+
def is_node_supported(
335+
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
336+
) -> bool:
337+
if node.op != "call_function":
338+
return False
339+
340+
# Select list based on whether the node is quantized.
341+
if is_quantized(node) or node.target in (*Q_OPS, *DQ_OPS):
342+
support_list = TOSA_PRO_INT_SupportList
343+
else:
344+
support_list = TOSA_PRO_FP_SupportList
345+
346+
return node.target in support_list
347+
348+
271349
class CheckArmQuantized(OperatorSupportBase):
272350
"""
273351
Check if the node was marked as quantized in the Arm backend.
@@ -278,60 +356,14 @@ class CheckArmQuantized(OperatorSupportBase):
278356
def __init__(self, reporter: WhyNoPartitionReporter):
279357
self.reporter = reporter
280358

281-
def _is_quantized(self, node: torch.fx.Node) -> bool:
282-
"""Checks if the node is quantized.
283-
284-
A node is considered quantized if at least one criteria is met:
285-
- Its dtype is not floating point or complex => integer
286-
- It is one of the special cases where the node has been created in to_edge, e.g.
287-
.Scalar operations that have been promoted .Tensor operations
288-
where the scalar is replaced by a full op.
289-
- It has been marked as quantized in the ArmAnnotationInfo custom meta.
290-
291-
Args:
292-
node (torch.fx.Node): The FX node to check.
293-
294-
Returns:
295-
bool: True if the node is quantized, False otherwise.
296-
"""
297-
node_dtype = get_first_fake_tensor(node).dtype
298-
if not node_dtype.is_complex and not node_dtype.is_floating_point:
299-
return True
300-
if node.target in (
301-
exir_ops.edge.aten.full_like.default,
302-
*ComputeConstantOpsAOTPass.targeted_ops,
303-
):
304-
# Special cases where nodes have been created in to_edge, e.g.
305-
# .Scalar operations that have been promoted .Tensor operations
306-
# where the scalar is replaced by a full op.
307-
if all(user.target in Q_OPS for user in node.users):
308-
return True
309-
for user in node.users:
310-
if (
311-
user.target
312-
== exir_ops.edge.dim_order_ops._to_dim_order_copy.default
313-
):
314-
dim_order_dtype = get_first_fake_tensor(user).dtype
315-
if dim_order_dtype.is_complex or dim_order_dtype.is_floating_point:
316-
return False
317-
else:
318-
return False
319-
return True
320-
return (
321-
ArmAnnotationInfo.CUSTOM_META_KEY in node.meta.get("custom", {})
322-
and ArmAnnotationInfo(
323-
node.meta["custom"][ArmAnnotationInfo.CUSTOM_META_KEY]
324-
).quantized
325-
)
326-
327359
def is_node_supported(
328360
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
329361
) -> bool:
330362

331363
if node.target in (*DQ_OPS, *Q_OPS):
332364
return True
333365

334-
if not self._is_quantized(node):
366+
if not is_quantized(node):
335367
self.reporter.report_reject(
336368
node, "Node was not marked as quantized in the Arm backend."
337369
)

backends/arm/operators/op_avg_pool2d.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,13 @@ def define_node(
115115
) -> None:
116116
validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7])
117117
validate_same_dtype(self.target, [inputs[0], output], ts)
118+
supported_dtypes = [ts.DType.INT8, ts.DType.FP32]
119+
if self.tosa_spec.support_extension("int16"):
120+
supported_dtypes.append(ts.DType.INT16)
118121
validate_valid_dtype(
119122
self.target,
120123
[inputs[0], output],
121-
[ts.DType.INT8, ts.DType.INT16, ts.DType.FP32],
124+
supported_dtypes,
122125
output.tosa_spec,
123126
)
124127

backends/arm/operators/op_cat.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
)
1515
from executorch.backends.arm.operators.operator_validation_utils import (
1616
validate_num_inputs,
17+
validate_same_dtype,
18+
validate_valid_dtype,
1719
)
1820
from executorch.backends.arm.tosa.mapping import TosaArg
1921
from torch.fx import Node
@@ -35,9 +37,19 @@ def define_node(
3537
inputs: List[TosaArg],
3638
output: TosaArg,
3739
) -> None:
40+
supported_dtypes = [ts.DType.BOOL, ts.DType.INT8, ts.DType.INT32, ts.DType.FP32]
41+
if self.tosa_spec.support_extension("int16"):
42+
supported_dtypes.append(ts.DType.INT16)
3843
validate_num_inputs(self.target, inputs, [1, 2])
44+
input_tosa_args = [TosaArg(arg, output.tosa_spec) for arg in inputs[0].special]
45+
validate_same_dtype(self.target, [*input_tosa_args, output], ts)
46+
validate_valid_dtype(
47+
self.target,
48+
[*input_tosa_args, output],
49+
supported_dtypes,
50+
output.tosa_spec,
51+
)
3952

40-
tensors = inputs[0].special
4153
dim = 0 if len(inputs) < 2 else inputs[1].number
4254
rank = len(output.shape)
4355
dim = (dim + rank) % rank
@@ -50,7 +62,7 @@ def define_node(
5062
node,
5163
tosa_graph,
5264
ts.Op.CONCAT,
53-
[tensor.name for tensor in tensors],
65+
[tensor.name for tensor in input_tosa_args],
5466
[output.name],
5567
attr,
5668
)

backends/arm/operators/op_clamp.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,20 +87,18 @@ def define_node(
8787
) -> None:
8888
validate_num_inputs(self.target, inputs, [2, 3])
8989
validate_same_dtype(self.target, [inputs[0], output], ts)
90+
supported_dtypes = [ts.DType.INT8, ts.DType.FP16, ts.DType.FP32]
91+
if self.tosa_spec.support_extension("int16"):
92+
supported_dtypes.append(ts.DType.INT16)
9093
validate_valid_dtype(
9194
self.target,
9295
[inputs[0], output],
93-
[
94-
ts.DType.INT8,
95-
ts.DType.INT16,
96-
ts.DType.FP16,
97-
ts.DType.FP32,
98-
],
96+
supported_dtypes,
9997
output.tosa_spec,
10098
)
10199

102100
node_input_dtype = node.meta["val"].dtype
103-
# NOTE: Quantization of the min/max arguments is handled by QuantizeClampArgumentsPass
101+
# NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments
104102
min_val, max_val = self._get_min_max_arguments(node, node_input_dtype)
105103

106104
attr = ts.TosaSerializerAttribute()

backends/arm/operators/op_eq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def define_node(
4747
validate_valid_dtype(
4848
self.target,
4949
inputs,
50-
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
50+
[ts.DType.INT32, ts.DType.FP32],
5151
output.tosa_spec,
5252
)
5353
validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)

backends/arm/operators/op_ge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def define_node(
4747
validate_valid_dtype(
4848
self.target,
4949
inputs,
50-
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
50+
[ts.DType.INT32, ts.DType.FP32],
5151
output.tosa_spec,
5252
)
5353
validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)

0 commit comments

Comments
 (0)