diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index 0f99d6cbbdf..0ecb7ff2070 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -24,6 +24,7 @@ from executorch.backends.arm._passes.remove_noop_pass import RemoveNoopPass from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo from executorch.backends.arm.constants import DQ_OPS, Q_OPS +from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops @@ -32,6 +33,13 @@ from torch.fx import GraphModule, Node +def _get_special_dtype(qspec: QuantArgs) -> TosaSpecialDtype | None: + if qspec.dtype == torch.int8: + if qspec.qmax == 7 and qspec.qmin == -7: + return TosaSpecialDtype.INT4 + return None + + def get_input_qparams(node: Node) -> dict[int, QuantArgs]: """ Get the input quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'. @@ -157,6 +165,11 @@ def fold_and_annotate_arg( node.replace_input_with(n, cast(Node, n.args[0])) if len(n.users) == 0: graph_module.graph.erase_node(n) + special_dtype = _get_special_dtype(input_qparams) + if special_dtype: + node.all_input_nodes[i].meta[ + TosaSpecialDtype.meta_key() + ] = special_dtype def _handle_control_flow_node(self, node: Node, graph_module: GraphModule): """Fold outmost quant nodes inside submodule. diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py index 79ce4ec8848..c29603d0b4c 100644 --- a/backends/arm/_passes/fuse_constant_ops_pass.py +++ b/backends/arm/_passes/fuse_constant_ops_pass.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -18,6 +18,7 @@ from executorch.backends.arm._passes.fuse_equal_placeholders_pass import ( FuseEqualPlaceholdersPass, ) +from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.backends.transforms.utils import ( create_constant_placeholder, delete_constant_placeholder, @@ -52,6 +53,23 @@ def __init__(self, exported_program: ExportedProgram, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.exported_program = exported_program + def _propagate_special_dtype(self, from_nodes, to_node, data): + """Propagate special dtype meta if it exists.""" + special_dtypes = set() + for input_node in from_nodes: + special_type = input_node.meta.get(TosaSpecialDtype.meta_key(), None) + if special_type: + special_dtypes.add(special_type) + if len(special_dtypes) > 1: + logger.warning( + "Propagating mixed special dtypes is not implemented, skipping." + ) + elif len(special_dtypes) == 1: + special_dtype = list(special_dtypes)[0] + # Make sure data is still within special dtype range. + if data.abs().max() <= special_dtype.max(): + to_node.meta[TosaSpecialDtype.meta_key()] = special_dtype + def _fuse_nodes(self, node) -> bool: """ Takes a node with only parameter inputs and replaces it with one constant tensor node with @@ -105,6 +123,8 @@ def resolve_arg(arg): persistent_buffer=persistent_buffer, ) + self._propagate_special_dtype(input_nodes, const_node, data) + node.replace_all_uses_with(const_node) return True diff --git a/backends/arm/_passes/fuse_equal_placeholders_pass.py b/backends/arm/_passes/fuse_equal_placeholders_pass.py index 6c3f9dde99e..37cac8a8c56 100644 --- a/backends/arm/_passes/fuse_equal_placeholders_pass.py +++ b/backends/arm/_passes/fuse_equal_placeholders_pass.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -53,11 +53,11 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # ensure we don't merge any special case int48_t tensors with int32_t tensors # since int48_t tensors needs to be instantiated separately. - is_int48 = node.meta.get(TosaSpecialDtype.meta_key(), None) + is_special_dtype = node.meta.get(TosaSpecialDtype.meta_key(), None) t_cpu = tensor.detach().cpu().contiguous() data_bytes = t_cpu.numpy().tobytes() key = ( - is_int48, + is_special_dtype, str(t_cpu.dtype), tuple(t_cpu.shape), hashlib.sha1(data_bytes, usedforsecurity=False).hexdigest(), diff --git a/backends/arm/ethosu/compile_spec.py b/backends/arm/ethosu/compile_spec.py index 8f6d6284f74..1d311cbf74c 100644 --- a/backends/arm/ethosu/compile_spec.py +++ b/backends/arm/ethosu/compile_spec.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -73,7 +73,7 @@ def __init__( compiler_flags.append(f"--memory-mode={memory_mode}") # Set TOSA version. - base_tosa_version = "TOSA-1.0+INT+int16" + base_tosa_version = "TOSA-1.0+INT+int16+int4" if "u55" in target_lower: # Add the Ethos-U55 extension marker base_tosa_version += "+u55" diff --git a/backends/arm/operators/ops_identity.py b/backends/arm/operators/ops_identity.py index a7ffd4eacca..0930d7e7997 100644 --- a/backends/arm/operators/ops_identity.py +++ b/backends/arm/operators/ops_identity.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -53,6 +53,8 @@ def define_node( supported_dtypes += [ts.DType.FP32] if self.tosa_spec.support_extension("int16"): supported_dtypes += [ts.DType.INT48] + if self.tosa_spec.support_extension("int4"): + supported_dtypes += [ts.DType.INT4] validate_valid_dtype( self.target, [inputs[0], output], diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 5a1d563ee0b..b85b1b43013 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -12,7 +12,7 @@ import torch.fx import tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import NodeVisitor -from executorch.backends.arm.tosa.mapping import TosaArg, TosaSpecialDtype +from executorch.backends.arm.tosa.mapping import TosaArg from executorch.backends.arm.tosa.specification import TosaSpecification from executorch.backends.arm.tosa.utils import tosa_shape from torch._export.utils import ( @@ -116,21 +116,10 @@ def process_inputs_to_parameters( ) parameter_values = parameter_data.detach().numpy() - if tosa_arg.dtype == torch.float32: - if not tosa_spec.support_float(): - raise ValueError(f"{tosa_spec} doesn't support float operations") - - # Handle special case for INT48 tensors - special_type = node.meta.get(TosaSpecialDtype.meta_key(), None) - if isinstance(special_type, TosaSpecialDtype): - tosa_dtype = special_type.get_tosa_dtype() - else: - tosa_dtype = tosa_arg.dtype - parameter_values = np.transpose(parameter_values, tosa_arg.dim_order) tosa_graph.addConst( - parameter_values.shape, tosa_dtype, parameter_values, name=tosa_arg.name + parameter_values.shape, tosa_arg.dtype, parameter_values, name=tosa_arg.name ) diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 425fea0987b..28cef0d95ca 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -180,6 +180,14 @@ def get_symmetric_quantization_config( return quantization_config +def get_symmetric_a8w4_quantization_config( + is_per_channel: bool = True, is_qat: bool = True, is_dynamic: bool = False +): + return get_symmetric_quantization_config( + is_per_channel, is_qat, is_dynamic, weight_qmin=-7, weight_qmax=7 + ) + + @functools.lru_cache def get_symmetric_a16w8_quantization_config( is_per_channel: bool = True, diff --git a/backends/arm/test/ops/test_conv2d.py b/backends/arm/test/ops/test_conv2d.py index 2b86ea6a5c4..55eee293f95 100644 --- a/backends/arm/test/ops/test_conv2d.py +++ b/backends/arm/test/ops/test_conv2d.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -7,6 +7,9 @@ from typing import List, Tuple, Union import torch +from executorch.backends.arm.quantizer.arm_quantizer import ( + get_symmetric_a8w4_quantization_config, +) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, @@ -17,6 +20,7 @@ VgfPipeline, ) + aten_op = "torch.ops.aten.conv2d.default" exir_op = "executorch_exir_dialects_edge__ops_aten_convolution_default" @@ -162,8 +166,8 @@ def forward(self, x): batches=1, ) -conv2d_2x2_1x1x14x13_st2 = Conv2d( - in_channels=1, +conv2d_2x2_2x1x14x13_st2 = Conv2d( + in_channels=2, out_channels=1, kernel_size=(2, 2), stride=2, @@ -363,7 +367,7 @@ def forward(self, x): "3x3_1x3x24x24_st1": lambda: conv2d_3x3_1x3x24x24_st1, "3x3_1x3x12x12_st2_pd1": lambda: conv2d_3x3_1x3x12x12_st2_pd1, "1x1_1x2x16x16_st1": lambda: conv2d_1x1_1x2x16x16_st1, - "2x2_1x1x14x13_st2_needs_adjust_pass": lambda: conv2d_2x2_1x1x14x13_st2, + "2x2_2x1x14x13_st2_needs_adjust_pass": lambda: conv2d_2x2_2x1x14x13_st2, "5x5_1x3x14x15_st3_pd1_needs_adjust_pass": lambda: conv2d_5x5_1x3x14x15_st3_pd1, "7x7_1x3x16x16_st2_pd1_dl2_needs_adjust_pass": lambda: conv2d_7x7_1x3x16x16_st2_pd1_dl2, "7x7_1x3x15x15_st1_pd0_dl1_needs_adjust_pass": lambda: conv2d_7x7_1x3x15x15_st1_pd0_dl1, @@ -391,6 +395,15 @@ def forward(self, x): input_t = Tuple[torch.Tensor] +def _get_dtype_count(model: torch.nn.Module): + nbr_convs: int = model.nbr_convs # noqa + return { + "CONST": {"INT4": nbr_convs * 2}, # One for the weight, one for the zp. + "CONV2D": {"INT32": nbr_convs}, + "RESCALE": {"INT8": nbr_convs}, + } + + @common.parametrize("test_data", test_data_FP) def test_convolution_2d_tosa_FP(test_data): model = test_data() @@ -417,6 +430,36 @@ def test_convolution_2d_tosa_INT(test_data): pipeline.run() +@common.parametrize( + "test_data", + test_data_INT, + xfails={ + "groups,per_channel_quant=True": "Int4 not supported for grouped convolutions. MLETORCH-1726", + "groups,per_channel_quant=False": "Int4 not supported for grouped convolutions. MLETORCH-1726", + "groups_bias,per_channel_quant=True": "Int4 not supported for grouped convolutions. MLETORCH-1726", + "groups_bias,per_channel_quant=False": "Int4 not supported for grouped convolutions. MLETORCH-1726", + }, +) +def test_convolution_2d_tosa_INT_a8w4(test_data): + model, per_channel_quantization = test_data() + pipeline = TosaPipelineINT[input_t]( + model, + model.get_inputs(), + aten_op, + exir_op, + tosa_extensions=["int4"], + ) + pipeline.quantizer.set_global( + get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization) + ) + pipeline.add_stage_after( + "to_edge_transform_and_lower", + pipeline.tester.check_dtype_count, + _get_dtype_count(model), + ) + pipeline.run() + + @common.parametrize("test_data", test_data_INT) @common.XfailIfNoCorstone300 def test_convolution_2d_u55_INT(test_data): @@ -431,6 +474,21 @@ def test_convolution_2d_u55_INT(test_data): pipeline.run() +@common.parametrize("test_data", test_data_INT) +def test_convolution_2d_u55_INT_a8w4(test_data): + model, per_channel_quantization = test_data() + pipeline = EthosU55PipelineINT[input_t]( + model, + model.get_inputs(), + aten_op, + exir_op, + ) + pipeline.quantizer.set_global( + get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization) + ) + pipeline.run() + + @common.parametrize("test_data", test_data_INT) @common.XfailIfNoCorstone320 def test_convolution_u85_INT(test_data): @@ -445,6 +503,21 @@ def test_convolution_u85_INT(test_data): pipeline.run() +@common.parametrize("test_data", test_data_INT) +def test_convolution_2d_u85_INT_a8w4(test_data): + model, per_channel_quantization = test_data() + pipeline = EthosU85PipelineINT[input_t]( + model, + model.get_inputs(), + aten_op, + exir_op, + ) + pipeline.quantizer.set_global( + get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization) + ) + pipeline.run() + + @common.parametrize("test_data", test_data_FP) @common.SkipIfNoModelConverter def test_convolution_2d_vgf_no_quant(test_data): diff --git a/backends/arm/test/ops/test_conv3d.py b/backends/arm/test/ops/test_conv3d.py index 9c831c9ba49..f28315dcdae 100644 --- a/backends/arm/test/ops/test_conv3d.py +++ b/backends/arm/test/ops/test_conv3d.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -10,6 +10,7 @@ import torch from executorch.backends.arm.quantizer.arm_quantizer import ( get_symmetric_a16w8_quantization_config, + get_symmetric_a8w4_quantization_config, TOSAQuantizer, ) from executorch.backends.arm.test import common, conftest @@ -430,6 +431,15 @@ def forward(self, x): } +def _get_dtype_count(model: torch.nn.Module): + nbr_convs: int = model.nbr_convs # noqa + return { + "CONST": {"INT4": nbr_convs * 2}, + "CONV3D": {"INT32": nbr_convs}, + "RESCALE": {"INT8": nbr_convs}, + } + + def get_symmetric_a16w8_conv3d_quantizer(per_channel_quantization: bool = False): tosa_version = conftest.get_option("tosa_version") tosa_profiles = { @@ -474,6 +484,28 @@ def test_convolution_3d_tosa_INT(test_data): pipeline.run() +@common.parametrize("test_data", test_data_INT) +def test_convolution_3d_tosa_INT_a8w4(test_data): + model, per_channel_quantization = test_data() + pipeline = TosaPipelineINT[input_t]( + model, + model.get_inputs(), + aten_op, + exir_op, + tosa_extensions=["int4"], + qtol=1, + ) + pipeline.quantizer.set_global( + get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization) + ) + pipeline.add_stage_after( + "to_edge_transform_and_lower", + pipeline.tester.check_dtype_count, + _get_dtype_count(model), + ) + pipeline.run() + + @common.parametrize("test_data", test_data_INT16) def test_convolution_3d_tosa_INT_a16w8(test_data): model, per_channel_quantization = test_data() @@ -543,6 +575,22 @@ def test_convolution_3d_u55_INT(test_data): pipeline.run() +@common.parametrize("test_data", test_data_INT) +@pytest.mark.skip(reason="Ethos-U55 does not support CONV3D yet.") +def test_convolution_3d_u55_INT_a8w4(test_data): + model, per_channel_quantization = test_data() + pipeline = EthosU55PipelineINT[input_t]( + model, + model.get_inputs(), + aten_op, + exir_op, + ) + pipeline.quantizer.set_global( + get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization) + ) + pipeline.run() + + @common.parametrize("test_data", test_data_INT) @pytest.mark.skip(reason="Ethos-U85 does not support CONV3D yet.") def test_convolution_3d_u85_INT(test_data): @@ -557,6 +605,22 @@ def test_convolution_3d_u85_INT(test_data): pipeline.run() +@common.parametrize("test_data", test_data_INT) +@pytest.mark.skip(reason="Ethos-U85 does not support CONV3D yet.") +def test_convolution_3d_u85_INT_a8w4(test_data): + model, per_channel_quantization = test_data() + pipeline = EthosU85PipelineINT[input_t]( + model, + model.get_inputs(), + aten_op, + exir_op, + ) + pipeline.quantizer.set_global( + get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization) + ) + pipeline.run() + + @common.parametrize("test_data", test_data_FP) @common.SkipIfNoModelConverter def test_convolution_3d_vgf_no_quant(test_data): diff --git a/backends/arm/test/ops/test_depthwise_conv.py b/backends/arm/test/ops/test_depthwise_conv.py index 017993e737b..b4289f922ce 100644 --- a/backends/arm/test/ops/test_depthwise_conv.py +++ b/backends/arm/test/ops/test_depthwise_conv.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -8,6 +8,9 @@ import pytest import torch +from executorch.backends.arm.quantizer.arm_quantizer import ( + get_symmetric_a8w4_quantization_config, +) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( @@ -198,6 +201,15 @@ } +def _get_dtype_count(model: torch.nn.Module): + nbr_convs: int = model.nbr_convs # noqa + return { + "CONST": {"INT4": nbr_convs * 2}, + "DEPTHWISE_CONV2D": {"INT32": nbr_convs}, + "RESCALE": {"INT8": nbr_convs}, + } + + @common.parametrize("test_data", test_data_conv1d_FP | test_data_conv2d_FP) def test_convolution_2d_tosa_FP_depthwise(test_data: torch.nn.Module): pipeline = TosaPipelineFP[input_t]( @@ -223,6 +235,27 @@ def test_convolution_2d_tosa_INT_depthwise(test_data): pipeline.run() +@common.parametrize("test_data", test_data_conv1d_INT | test_data_conv2d_INT) +def test_convolution_2d_tosa_INT_a8w4_depthwise(test_data): + model, per_channel_quantization = test_data() + pipeline = TosaPipelineINT[input_t]( + model, + model.get_inputs(), + aten_op=[], + exir_op=exir_op, + tosa_extensions=["int4"], + ) + pipeline.quantizer.set_global( + get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization) + ) + pipeline.add_stage_after( + "to_edge_transform_and_lower", + pipeline.tester.check_dtype_count, + _get_dtype_count(model), + ) + pipeline.run() + + @common.parametrize("test_data", test_data_conv1d_FP | test_data_conv2d_FP) @common.SkipIfNoModelConverter def test_convolution_2d_vgf_no_quant_depthwise(test_data: torch.nn.Module): @@ -251,7 +284,7 @@ def test_convolution_2d_vgf_quant_depthwise(test_data): pipeline.run() -@common.XfailIfNoCorstone300 # TODO: MLETORCH-516 +@common.XfailIfNoCorstone300 @common.parametrize("test_data", test_data_conv2d_INT) def test_convolution_2d_u55_INT_depthwise(test_data): model, per_channel_quantization = test_data() @@ -265,7 +298,23 @@ def test_convolution_2d_u55_INT_depthwise(test_data): pipeline.run() -@common.XfailIfNoCorstone300 # TODO: MLETORCH-516 +@common.XfailIfNoCorstone300 +@common.parametrize("test_data", test_data_conv2d_INT) +def test_convolution_2d_u55_INT_a8w4_depthwise(test_data): + model, per_channel_quantization = test_data() + pipeline = EthosU55PipelineINT[input_t]( + model, + model.get_inputs(), + aten_ops=[], + exir_ops=exir_op, + ) + pipeline.quantizer.set_global( + get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization) + ) + pipeline.run() + + +@common.XfailIfNoCorstone300 @common.parametrize("test_data", test_data_conv1d_INT) def test_convolution_1d_u55_INT_depthwise(test_data): model, per_channel_quantization = test_data() @@ -279,7 +328,23 @@ def test_convolution_1d_u55_INT_depthwise(test_data): pipeline.run() -@common.XfailIfNoCorstone320 # TODO: MLETORCH-516 +@common.XfailIfNoCorstone300 +@common.parametrize("test_data", test_data_conv1d_INT) +def test_convolution_1d_u55_INT_a8w4_depthwise(test_data): + model, per_channel_quantization = test_data() + pipeline = EthosU55PipelineINT[input_t]( + model, + model.get_inputs(), + aten_ops=[], + exir_ops=exir_op, + ) + pipeline.quantizer.set_global( + get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization) + ) + pipeline.run() + + +@common.XfailIfNoCorstone320 @common.parametrize("test_data", test_data_conv2d_INT) def test_convolution_2d_u85_INT_depthwise(test_data): model, per_channel_quantization = test_data() @@ -293,7 +358,23 @@ def test_convolution_2d_u85_INT_depthwise(test_data): pipeline.run() -@common.XfailIfNoCorstone320 # TODO: MLETORCH-516 +@common.XfailIfNoCorstone320 +@common.parametrize("test_data", test_data_conv2d_INT) +def test_convolution_2d_u85_INT_a8w4_depthwise(test_data): + model, per_channel_quantization = test_data() + pipeline = EthosU85PipelineINT[input_t]( + model, + model.get_inputs(), + aten_ops=[], + exir_ops=exir_op, + ) + pipeline.quantizer.set_global( + get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization) + ) + pipeline.run() + + +@common.XfailIfNoCorstone320 @common.parametrize("test_data", test_data_conv1d_INT) def test_convolution_1d_u85_INT_depthwise(test_data): model, per_channel_quantization = test_data() @@ -305,3 +386,19 @@ def test_convolution_1d_u85_INT_depthwise(test_data): per_channel_quantization=per_channel_quantization, ) pipeline.run() + + +@common.XfailIfNoCorstone320 +@common.parametrize("test_data", test_data_conv1d_INT) +def test_convolution_1d_u85_INT_a8w4_depthwise(test_data): + model, per_channel_quantization = test_data() + pipeline = EthosU85PipelineINT[input_t]( + model, + model.get_inputs(), + aten_ops=[], + exir_ops=exir_op, + ) + pipeline.quantizer.set_global( + get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization) + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_linear.py b/backends/arm/test/ops/test_linear.py index 77e512cdf2f..7e22ad304e4 100644 --- a/backends/arm/test/ops/test_linear.py +++ b/backends/arm/test/ops/test_linear.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -11,6 +11,7 @@ import torch from executorch.backends.arm.quantizer.arm_quantizer import ( get_symmetric_a16w8_quantization_config, + get_symmetric_a8w4_quantization_config, TOSAQuantizer, ) from executorch.backends.arm.test import common, conftest @@ -166,6 +167,35 @@ def test_linear_tosa_INT(test_data: torch.Tensor): pipeline.run() +@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT) +def test_linear_tosa_INT_a8w4(test_data: torch.Tensor): + test_data, out_features, has_bias, per_channel_quantization = test_data() + in_features = test_data.shape[-1] + pipeline = TosaPipelineINT[input_t1]( + Linear( + in_features=in_features, + out_features=out_features, + bias=has_bias, + ), + (test_data,), + aten_op, + tosa_extensions=["int4"], + ) + pipeline.quantizer.set_global( + get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization) + ) + pipeline.add_stage_after( + "to_edge_transform_and_lower", + pipeline.tester.check_dtype_count, + { + "CONST": {"INT4": 2}, + "CONV2D": {"INT32": 1}, + "RESCALE": {"INT8": 1}, + }, + ) + pipeline.run() + + @common.parametrize("test_data", test_data_rank1_INT) @common.XfailIfNoCorstone300 def test_linear_u55_INT(test_data: torch.Tensor): diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index ca83c6c09ea..c11a046cd66 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -1,4 +1,4 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. +# Copyright 2023-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -36,6 +36,7 @@ class TosaSpecialDtype(Enum): """Special TOSA dtypes not natively expressed in PyTorch.""" INT48 = ts.DType.INT48 + INT4 = ts.DType.INT4 def get_tosa_dtype(self) -> ts.DType: """Return the underlying ``ts.DType`` enumerant. @@ -56,6 +57,24 @@ def meta_key() -> str: """ return "tosa_special_dtype" + def max(self): + match self: + case self.INT4: + return 7 + case self.INT48: + return 2**47 - 1 + case _: + raise ValueError(f"Unrecognized TosaSpecialDtype {self}.") + + def min(self): + match self: + case self.INT4: + return -7 + case self.INT48: + return -(2**47) + case _: + raise ValueError(f"Unrecognized TosaSpecialDtype {self}.") + def map_dtype(data_type: torch.dtype, tosa_spec: TosaSpecification) -> Any: """Map a ``torch.dtype`` to a ``ts.DType``. @@ -180,6 +199,11 @@ def __process_node(self, argument: torch.fx.Node): else: self.multiple_output_names = [] + if not self.__validate(): + raise ValueError( + f"{self.tosa_spec} doesn't support tensor {self.__repr__()}" + ) + def __process_list(self, argument): """Capture a sequence argument as ``special``. @@ -198,6 +222,17 @@ def __process_number(self, argument: float | int): """ self.number: float | int = argument + def __validate(self) -> bool: + match getattr(self, "dtype", None): + case ts.DType.FP32: + if not self.tosa_spec.support_float(): + return False + case ts.DType.INT4: + if not self.tosa_spec.support_extension("int4"): + return False + + return True + def __init__( self, argument: Any, tosa_spec: Optional[TosaSpecification] = None ) -> None: