Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -24,6 +24,7 @@
from executorch.backends.arm._passes.remove_noop_pass import RemoveNoopPass
from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
from executorch.exir import ExportedProgram

from executorch.exir.dialects._ops import ops as exir_ops
Expand All @@ -32,6 +33,13 @@
from torch.fx import GraphModule, Node


def _get_special_dtype(qspec: QuantArgs) -> TosaSpecialDtype | None:
    """Map quantization parameters to a special TOSA dtype, if one applies.

    An int8 qspec whose quantization range is exactly the symmetric
    interval [-7, 7] is treated as packed INT4 data; any other qspec has
    no special dtype and ``None`` is returned.
    """
    is_int4_range = (
        qspec.dtype == torch.int8 and qspec.qmin == -7 and qspec.qmax == 7
    )
    return TosaSpecialDtype.INT4 if is_int4_range else None


def get_input_qparams(node: Node) -> dict[int, QuantArgs]:
"""
Get the input quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'.
Expand Down Expand Up @@ -157,6 +165,11 @@ def fold_and_annotate_arg(
node.replace_input_with(n, cast(Node, n.args[0]))
if len(n.users) == 0:
graph_module.graph.erase_node(n)
special_dtype = _get_special_dtype(input_qparams)
if special_dtype:
node.all_input_nodes[i].meta[
TosaSpecialDtype.meta_key()
] = special_dtype

def _handle_control_flow_node(self, node: Node, graph_module: GraphModule):
"""Fold outmost quant nodes inside submodule.
Expand Down
22 changes: 21 additions & 1 deletion backends/arm/_passes/fuse_constant_ops_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -18,6 +18,7 @@
from executorch.backends.arm._passes.fuse_equal_placeholders_pass import (
FuseEqualPlaceholdersPass,
)
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
from executorch.backends.transforms.utils import (
create_constant_placeholder,
delete_constant_placeholder,
Expand Down Expand Up @@ -52,6 +53,23 @@ def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.exported_program = exported_program

def _propagate_special_dtype(self, from_nodes, to_node, data):
    """Copy a special-dtype meta entry from ``from_nodes`` onto ``to_node``.

    Propagation only happens when every annotated input agrees on a single
    special dtype and the fused constant ``data`` is still representable in
    that dtype's value range; mixed special dtypes are skipped with a warning.
    """
    meta_key = TosaSpecialDtype.meta_key()
    found = {n.meta.get(meta_key) for n in from_nodes if n.meta.get(meta_key)}
    if len(found) > 1:
        logger.warning(
            "Propagating mixed special dtypes is not implemented, skipping."
        )
    elif found:
        (special_dtype,) = found
        # Fusing may have produced values outside the narrow range; only tag
        # the node when the data still fits the special dtype.
        if data.abs().max() <= special_dtype.max():
            to_node.meta[meta_key] = special_dtype

def _fuse_nodes(self, node) -> bool:
"""
Takes a node with only parameter inputs and replaces it with one constant tensor node with
Expand Down Expand Up @@ -105,6 +123,8 @@ def resolve_arg(arg):
persistent_buffer=persistent_buffer,
)

self._propagate_special_dtype(input_nodes, const_node, data)

node.replace_all_uses_with(const_node)

return True
Expand Down
6 changes: 3 additions & 3 deletions backends/arm/_passes/fuse_equal_placeholders_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -53,11 +53,11 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:

# ensure we don't merge any special case int48_t tensors with int32_t tensors
# since int48_t tensors needs to be instantiated separately.
is_int48 = node.meta.get(TosaSpecialDtype.meta_key(), None)
is_special_dtype = node.meta.get(TosaSpecialDtype.meta_key(), None)
t_cpu = tensor.detach().cpu().contiguous()
data_bytes = t_cpu.numpy().tobytes()
key = (
is_int48,
is_special_dtype,
str(t_cpu.dtype),
tuple(t_cpu.shape),
hashlib.sha1(data_bytes, usedforsecurity=False).hexdigest(),
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/ethosu/compile_spec.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(
compiler_flags.append(f"--memory-mode={memory_mode}")

# Set TOSA version.
base_tosa_version = "TOSA-1.0+INT+int16"
base_tosa_version = "TOSA-1.0+INT+int16+int4"
if "u55" in target_lower:
# Add the Ethos-U55 extension marker
base_tosa_version += "+u55"
Expand Down
4 changes: 3 additions & 1 deletion backends/arm/operators/ops_identity.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -53,6 +53,8 @@ def define_node(
supported_dtypes += [ts.DType.FP32]
if self.tosa_spec.support_extension("int16"):
supported_dtypes += [ts.DType.INT48]
if self.tosa_spec.support_extension("int4"):
supported_dtypes += [ts.DType.INT4]
validate_valid_dtype(
self.target,
[inputs[0], output],
Expand Down
17 changes: 3 additions & 14 deletions backends/arm/process_node.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -12,7 +12,7 @@
import torch.fx
import tosa_serializer as ts
from executorch.backends.arm.operators.node_visitor import NodeVisitor
from executorch.backends.arm.tosa.mapping import TosaArg, TosaSpecialDtype
from executorch.backends.arm.tosa.mapping import TosaArg
from executorch.backends.arm.tosa.specification import TosaSpecification
from executorch.backends.arm.tosa.utils import tosa_shape
from torch._export.utils import (
Expand Down Expand Up @@ -116,21 +116,10 @@ def process_inputs_to_parameters(
)
parameter_values = parameter_data.detach().numpy()

if tosa_arg.dtype == torch.float32:
if not tosa_spec.support_float():
raise ValueError(f"{tosa_spec} doesn't support float operations")

# Handle special case for INT48 tensors
special_type = node.meta.get(TosaSpecialDtype.meta_key(), None)
if isinstance(special_type, TosaSpecialDtype):
tosa_dtype = special_type.get_tosa_dtype()
else:
tosa_dtype = tosa_arg.dtype

parameter_values = np.transpose(parameter_values, tosa_arg.dim_order)

tosa_graph.addConst(
parameter_values.shape, tosa_dtype, parameter_values, name=tosa_arg.name
parameter_values.shape, tosa_arg.dtype, parameter_values, name=tosa_arg.name
)


Expand Down
10 changes: 9 additions & 1 deletion backends/arm/quantizer/arm_quantizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -180,6 +180,14 @@ def get_symmetric_quantization_config(
return quantization_config


def get_symmetric_a8w4_quantization_config(
    is_per_channel: bool = True, is_qat: bool = True, is_dynamic: bool = False
):
    """Return a symmetric int8-activation / int4-weight quantization config.

    Identical to ``get_symmetric_quantization_config`` except that the
    weight range is narrowed to [-7, 7] so weights fit in a signed
    4-bit value.
    """
    config = get_symmetric_quantization_config(
        is_per_channel,
        is_qat,
        is_dynamic,
        weight_qmin=-7,
        weight_qmax=7,
    )
    return config


@functools.lru_cache
def get_symmetric_a16w8_quantization_config(
is_per_channel: bool = True,
Expand Down
81 changes: 77 additions & 4 deletions backends/arm/test/ops/test_conv2d.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -7,6 +7,9 @@
from typing import List, Tuple, Union

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
get_symmetric_a8w4_quantization_config,
)
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineINT,
Expand All @@ -17,6 +20,7 @@
VgfPipeline,
)


aten_op = "torch.ops.aten.conv2d.default"
exir_op = "executorch_exir_dialects_edge__ops_aten_convolution_default"

Expand Down Expand Up @@ -162,8 +166,8 @@ def forward(self, x):
batches=1,
)

conv2d_2x2_1x1x14x13_st2 = Conv2d(
in_channels=1,
conv2d_2x2_2x1x14x13_st2 = Conv2d(
in_channels=2,
out_channels=1,
kernel_size=(2, 2),
stride=2,
Expand Down Expand Up @@ -363,7 +367,7 @@ def forward(self, x):
"3x3_1x3x24x24_st1": lambda: conv2d_3x3_1x3x24x24_st1,
"3x3_1x3x12x12_st2_pd1": lambda: conv2d_3x3_1x3x12x12_st2_pd1,
"1x1_1x2x16x16_st1": lambda: conv2d_1x1_1x2x16x16_st1,
"2x2_1x1x14x13_st2_needs_adjust_pass": lambda: conv2d_2x2_1x1x14x13_st2,
"2x2_2x1x14x13_st2_needs_adjust_pass": lambda: conv2d_2x2_2x1x14x13_st2,
"5x5_1x3x14x15_st3_pd1_needs_adjust_pass": lambda: conv2d_5x5_1x3x14x15_st3_pd1,
"7x7_1x3x16x16_st2_pd1_dl2_needs_adjust_pass": lambda: conv2d_7x7_1x3x16x16_st2_pd1_dl2,
"7x7_1x3x15x15_st1_pd0_dl1_needs_adjust_pass": lambda: conv2d_7x7_1x3x15x15_st1_pd0_dl1,
Expand Down Expand Up @@ -391,6 +395,15 @@ def forward(self, x):
input_t = Tuple[torch.Tensor]


def _get_dtype_count(model: torch.nn.Module):
nbr_convs: int = model.nbr_convs # noqa
return {
"CONST": {"INT4": nbr_convs * 2}, # One for the weight, one for the zp.
"CONV2D": {"INT32": nbr_convs},
"RESCALE": {"INT8": nbr_convs},
}


@common.parametrize("test_data", test_data_FP)
def test_convolution_2d_tosa_FP(test_data):
model = test_data()
Expand All @@ -417,6 +430,36 @@ def test_convolution_2d_tosa_INT(test_data):
pipeline.run()


@common.parametrize(
    "test_data",
    test_data_INT,
    xfails={
        f"{case},per_channel_quant={flag}": "Int4 not supported for grouped convolutions. MLETORCH-1726"
        for case in ("groups", "groups_bias")
        for flag in (True, False)
    },
)
def test_convolution_2d_tosa_INT_a8w4(test_data):
    """Run the TOSA INT pipeline with an a8w4 (int4-weight) quantization config."""
    module, per_channel = test_data()
    pipeline = TosaPipelineINT[input_t](
        module,
        module.get_inputs(),
        aten_op,
        exir_op,
        tosa_extensions=["int4"],
    )
    pipeline.quantizer.set_global(
        get_symmetric_a8w4_quantization_config(is_per_channel=per_channel)
    )
    # Check that the lowered graph actually serializes INT4 constants.
    pipeline.add_stage_after(
        "to_edge_transform_and_lower",
        pipeline.tester.check_dtype_count,
        _get_dtype_count(module),
    )
    pipeline.run()


@common.parametrize("test_data", test_data_INT)
@common.XfailIfNoCorstone300
def test_convolution_2d_u55_INT(test_data):
Expand All @@ -431,6 +474,21 @@ def test_convolution_2d_u55_INT(test_data):
pipeline.run()


@common.parametrize("test_data", test_data_INT)
def test_convolution_2d_u55_INT_a8w4(test_data):
    """Run the Ethos-U55 INT pipeline with an a8w4 quantization config.

    NOTE(review): unlike ``test_convolution_2d_u55_INT`` this variant has no
    ``XfailIfNoCorstone300`` guard — confirm it is meant to run without the FVP.
    """
    module, per_channel = test_data()
    pipeline = EthosU55PipelineINT[input_t](
        module, module.get_inputs(), aten_op, exir_op
    )
    cfg = get_symmetric_a8w4_quantization_config(is_per_channel=per_channel)
    pipeline.quantizer.set_global(cfg)
    pipeline.run()


@common.parametrize("test_data", test_data_INT)
@common.XfailIfNoCorstone320
def test_convolution_u85_INT(test_data):
Expand All @@ -445,6 +503,21 @@ def test_convolution_u85_INT(test_data):
pipeline.run()


@common.parametrize("test_data", test_data_INT)
def test_convolution_2d_u85_INT_a8w4(test_data):
    """Run the Ethos-U85 INT pipeline with an a8w4 quantization config.

    NOTE(review): unlike ``test_convolution_u85_INT`` this variant has no
    ``XfailIfNoCorstone320`` guard — confirm it is meant to run without the FVP.
    """
    module, per_channel = test_data()
    pipeline = EthosU85PipelineINT[input_t](
        module, module.get_inputs(), aten_op, exir_op
    )
    cfg = get_symmetric_a8w4_quantization_config(is_per_channel=per_channel)
    pipeline.quantizer.set_global(cfg)
    pipeline.run()


@common.parametrize("test_data", test_data_FP)
@common.SkipIfNoModelConverter
def test_convolution_2d_vgf_no_quant(test_data):
Expand Down
Loading
Loading