diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index f9e23f73cc5..a5d8e17f0cd 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -46,6 +46,9 @@ from .decompose_glu_pass import DecomposeGluPass # noqa from .decompose_grouped_conv import DecomposeGroupedConv # noqa from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa +from .decompose_int16_activation_conv2d_pass import ( # noqa + DecomposeConv2dWithInt16ActivationPass, +) from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa from .decompose_linalg_vector_norm_pass import DecomposeLinearVectorNormPass # noqa diff --git a/backends/arm/_passes/add_bias_pass.py b/backends/arm/_passes/add_bias_pass.py index a8a76c0a47b..fd5476f51b8 100644 --- a/backends/arm/_passes/add_bias_pass.py +++ b/backends/arm/_passes/add_bias_pass.py @@ -8,6 +8,7 @@ import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.backends.transforms.utils import create_constant_placeholder from executorch.exir.dialects._ops import ops as exir_ops @@ -59,6 +60,10 @@ def call(self, graph_module): persistent_buffer=True, name=f"{node.name}_bias", ) + if node.args[0].meta["val"].dtype == torch.int16: + bias_node.meta[TosaSpecialDtype.meta_key()] = ( + TosaSpecialDtype.INT48 + ) node.update_arg(2, bias_node) if modified: diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index c6530357f3b..70470890317 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -42,6 +42,7 @@ DecomposeAtanPass, DecomposeAvgPool2d, DecomposeBatchNormNoStatsPass, + DecomposeConv2dWithInt16ActivationPass, DecomposeCoshPass, DecomposeCosineSimilarityPass, 
DecomposeCumsumPass, @@ -183,6 +184,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(ComputeConstantOpsAOT(exported_program)) self.add_pass(DecomposeGroupedConv()) + self.add_pass(ConvertExpandCopyToRepeatPass()) self.add_pass(UnsqueezeBeforeRepeatPass()) self.add_pass(CastInt64BuffersToInt32Pass(exported_program)) @@ -196,9 +198,14 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(FuseViewCopyTransform()) self.add_pass(FuseConstantArgsPass(exported_program)) + self.add_pass(InsertTableOpsPass(exported_program)) + # If we have a conv2d with int16 activation, split it up into a convolution + # and an addition to work around the lack of int48 support in torch. This + # needs to happen before AddBiasPass, but after the table ops are inserted, + # to be able to validate that conv2d has the right dtype arguments. + self.add_pass(DecomposeConv2dWithInt16ActivationPass()) self.add_pass(AddBiasPass(exported_program)) - self.add_pass(InsertTableOpsPass(exported_program)) self.add_pass(FuseEqualPlaceholdersPass(exported_program)) self.add_pass(ToTosaMemoryFormatPass(exported_program)) self.add_pass(RemoveNoopPass()) diff --git a/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py b/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py new file mode 100644 index 00000000000..d43c2a8c89c --- /dev/null +++ b/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py @@ -0,0 +1,145 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +from typing import cast + +import torch +from executorch.backends.arm._passes.quant_args import QuantArgs + +from executorch.backends.arm.tosa.specification import get_context_spec, Tosa_1_00 +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + + +class DecomposeConv2dWithInt16ActivationPass(ExportPass): + """ + This pass decomposes a convolution with input dtype int16 and bias into a + convolution without bias followed by an addition of the bias, since the TOSA + op requires the bias to be int48, which is hard to represent in torch. Instead, + the conv output and the bias are rescaled to int32, added, and rescaled to int16. + """ + + def call_operator(self, op, args, kwargs, meta): + if op != exir_ops.edge.aten.convolution.default: + return super().call_operator(op, args, kwargs, meta) + + tosa_spec = get_context_spec() + if not tosa_spec.support_integer(): + return super().call_operator(op, args, kwargs, meta) + + # return if no bias + if args[2] is None: + return super().call_operator(op, args, kwargs, meta) + + if args[0].data.dtype == torch.int8: + return super().call_operator(op, args, kwargs, meta) + elif args[0].data.dtype == torch.int16: + if isinstance(tosa_spec, Tosa_1_00) and not tosa_spec.support_extension( + "int16" + ): + raise ValueError( + "int16 activation for convolution requires TOSA int16 extension" + ) + else: + raise NotImplementedError( + "Decomposition to conv+add only implemented for activation of int16 type" + ) + + # convolution with bias and activation is int16 + # The bias is assumed to be quantized with the same quantization parameters as + # the output of the convolution + bias = args[2] + assert ( + meta.data["output_qparams"][0].dtype == bias.data.dtype + ), "Bias needs to have same type as quantized output type" + no_bias_args = list(args) + no_bias_args[2] = None + # split up to convolution + bias + convolution = super().call_operator(op, tuple(no_bias_args), kwargs, meta) 
+ + # create a copy of the meta without the qparams, to be used with the new nodes + new_meta = meta.copy() + new_meta.data.pop("output_qparams", None) + new_meta.data.pop("input_qparams", None) + + # reshape the tensor to the same rank as the convolution output to add the bias to the channels + channel_bias = super().call_operator( + exir_ops.edge.aten.view_copy.default, + (bias, [1, len(bias.data), 1, 1]), + {}, + new_meta, + ) + + output_dtype = meta.data["output_qparams"][0].dtype + + if output_dtype == torch.int16: + # The conv will get the output int48 scaled to int32 in serialization step. + # To be able to add the bias we first need to rescale the output to int32. + # The resulting i32 sum will then need to be scaled back to the output dtype. + + # calculate common rescale factor from convolution output and bias quantization + output_qparams = cast(QuantArgs, meta.data["output_qparams"][0]) + conv_output_scale = output_qparams.scale + bias_qparams = cast(QuantArgs, meta.data["input_qparams"][2]) + bias_scale = bias_qparams.scale + + common_scale = max(bias_scale, conv_output_scale) + + # calculate how we can rescale bias and conv to a common scale and maximize the output range + bias_rescale_factor = bias_scale / common_scale + conv_rescale_factor = conv_output_scale / common_scale + + # Either of conv output or bias now covers the full int16 range and the other one a smaller range. + # Since we are upscaling to int32 we have 16 additional bits to work with to maximize the output range. + # Worst case here is that both bias and conv output cover the full int16 range, so we leave one bit + # and then one for the sign bit. 
+ bits_left_to_shift = 14 + + # update rescale factors + bias_rescale_factor *= 1 << bits_left_to_shift + conv_rescale_factor *= 1 << bits_left_to_shift + + conv_output = super().call_operator( + exir_ops.backend.tosa.RESCALE.default, + (convolution, torch.int32, conv_rescale_factor, 0, 0), + {}, + new_meta, + ) + + bias_rescaled = super().call_operator( + exir_ops.backend.tosa.RESCALE.default, + (channel_bias, torch.int32, bias_rescale_factor, 0, 0), + {}, + new_meta, + ) + + add = super().call_operator( + exir_ops.edge.aten.add.Tensor, + (conv_output, bias_rescaled), + {}, + new_meta, + ) + + res_rescale = super().call_operator( + exir_ops.backend.tosa.RESCALE.default, + ( + add, + output_dtype, + (common_scale / (conv_output_scale * (1 << bits_left_to_shift))), + 0, + 0, + ), + {}, + new_meta, + ) + + else: + raise NotImplementedError( + f"Decomposition to conv+add only implemented for activation of int16 type, not for {output_dtype}" + ) + + return res_rescale diff --git a/backends/arm/_passes/fuse_equal_placeholders_pass.py b/backends/arm/_passes/fuse_equal_placeholders_pass.py index cf1177a0448..b8b8143e6c5 100644 --- a/backends/arm/_passes/fuse_equal_placeholders_pass.py +++ b/backends/arm/_passes/fuse_equal_placeholders_pass.py @@ -8,11 +8,13 @@ from typing import Set, Type import torch + from executorch.backends.arm._passes.arm_pass_utils import ( get_constant_placeholder_kind, get_param_tensor, is_param_node, ) +from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.backends.transforms.utils import ( create_constant_placeholder, delete_constant_placeholder, @@ -47,9 +49,14 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: continue # Create a lightweight fingerprint: dtype + shape + SHA1 of raw bytes # Ensure tensor is on CPU and contiguous + + # ensure we don't merge any special case int48_t tensors with int32_t tensors + # since int48_t tensors need to be instantiated separately. 
+ is_int48 = node.meta.get(TosaSpecialDtype.meta_key(), None) t_cpu = tensor.detach().cpu().contiguous() data_bytes = t_cpu.numpy().tobytes() key = ( + is_int48, str(t_cpu.dtype), tuple(t_cpu.shape), hashlib.sha1(data_bytes).hexdigest(), diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py index 6bfe0ab21eb..41e422b5504 100644 --- a/backends/arm/operators/op_conv2d.py +++ b/backends/arm/operators/op_conv2d.py @@ -19,10 +19,11 @@ ) from executorch.backends.arm.operators.operator_validation_utils import ( validate_num_inputs, + validate_valid_dtype, ) -from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg from executorch.backends.arm.tosa.quant_utils import build_rescale +from executorch.backends.arm.tosa.specification import Tosa_1_00, TosaSpecification from executorch.backends.arm.tosa.utils import tosa_shape @@ -73,6 +74,32 @@ def define_node( input, weight, bias, stride, pad, dilation, _, _, group = inputs validate_num_inputs(self.target, inputs, 9) + valid_input_dtypes = [] + if self.tosa_spec.support_float(): + valid_input_dtypes.append(ts.DType.FP32) + if self.tosa_spec.support_integer(): + valid_input_dtypes.append(ts.DType.INT8) + + if isinstance(self.tosa_spec, Tosa_1_00) and self.tosa_spec.support_extension( + "int16" + ): + valid_input_dtypes.append(ts.DType.INT16) + # Check constraints for int16 activations + if inputs[0].dtype == ts.DType.INT16: + validate_valid_dtype( + self.target, [inputs[1]], [ts.DType.INT8], self.tosa_spec + ) + validate_valid_dtype( + self.target, [inputs[2]], [ts.DType.INT48], self.tosa_spec + ) + + validate_valid_dtype( + self.target, + [inputs[0]], + valid_input_dtypes, + self.tosa_spec, + ) + # Get the attributes of convolution. 
attr = ts.TosaSerializerAttribute() pad_attr = [val for val in pad.special for _ in (0, 1)] @@ -97,8 +124,8 @@ def define_node( ) input_zp = 0 - if inputs[0].dtype == ts.DType.INT8: - # int8 input requires quantization information + if inputs[0].dtype in (ts.DType.INT8, ts.DType.INT16): + # int8 and int16 input requires quantization information input_qparams = get_input_qparams(node) input_zp = input_qparams[0].get_zp_per_tensor() @@ -109,15 +136,22 @@ def define_node( weight_zp = input_qparams[1].zp # type: ignore[assignment] # The output type is int32 when input type is int8. - conv2d_output_name = output.name - if output.dtype == ts.DType.INT8: + if inputs[0].dtype == ts.DType.INT8: conv2d_res = tosa_graph.addIntermediate( tosa_shape(output.shape, output.dim_order), ts.DType.INT32 ) conv2d_output_name = conv2d_res.name - acc_type = ( - inputs[0].dtype if inputs[0].dtype == ts.DType.FP32 else ts.DType.INT32 - ) + acc_type = ts.DType.INT32 + elif inputs[0].dtype == ts.DType.INT16: + conv2d_res = tosa_graph.addIntermediate( + tosa_shape(output.shape, output.dim_order), ts.DType.INT48 + ) + conv2d_output_name = conv2d_res.name + acc_type = ts.DType.INT48 + else: + conv2d_output_name = output.name + conv2d_res = output + acc_type = ts.DType.FP32 tosa_graph.addConst( [1], output.dtype, [input_zp], name=f"{conv2d_output_name}_input_zp" @@ -207,7 +241,7 @@ def define_node( # For quantized convolution, rescale the output value back to the same # integer value domain of the next op. Otherwise return float32 output. - if inputs[0].dtype == ts.DType.INT8: + if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16: # Get scale_factor from input, weight, and output. 
input_scale = input_qparams[0].get_scale_per_tensor() # type: ignore[possibly-undefined] # pyre-ignore [61] per_channel_quant = input_qparams[1].per_channel # pyre-ignore [61] diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 5093ea32d4c..50257bc9180 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -12,7 +12,7 @@ import torch import torch.fx from executorch.backends.arm.operators.node_visitor import NodeVisitor -from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.mapping import TosaArg, TosaSpecialDtype from executorch.backends.arm.tosa.specification import TosaSpecification from executorch.backends.arm.tosa.utils import tosa_shape from torch._export.utils import ( @@ -112,10 +112,17 @@ def process_inputs_to_parameters( if tosa_arg.dtype == torch.float32: assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float" + # Handle special case for INT48 tensors + special_type = node.meta.get(TosaSpecialDtype.meta_key(), None) + if isinstance(special_type, TosaSpecialDtype): + tosa_dtype = special_type.get_tosa_dtype() + else: + tosa_dtype = tosa_arg.dtype + parameter_values = np.transpose(parameter_values, tosa_arg.dim_order) tosa_graph.addConst( - parameter_values.shape, tosa_arg.dtype, parameter_values, name=tosa_arg.name + parameter_values.shape, tosa_dtype, parameter_values, name=tosa_arg.name ) diff --git a/backends/arm/quantizer/quantization_config.py b/backends/arm/quantizer/quantization_config.py index d5c3aab1060..29af10dfd1d 100644 --- a/backends/arm/quantizer/quantization_config.py +++ b/backends/arm/quantizer/quantization_config.py @@ -89,29 +89,48 @@ def _derive_qparams_fn( torch.ops.aten.linear.default, torch.ops.aten.conv2d.padding, ]: - input_act = node.args[0] - weight = node.args[1] - # If the weights are quantized per_tensor, do the same with bias - qscheme = ( - torch.per_tensor_symmetric - if self.weight is None - else self.weight.qscheme - 
) - ch_axis = None - if self.weight is not None: - if qscheme == torch.per_channel_symmetric: - ch_axis = self.weight.ch_axis + if self.input_activation is None or self.weight is None: + raise ValueError( + "Input activation and weight QuantizationConfig must be specified." + ) + if self.input_activation.dtype == self.weight.dtype == torch.int8: + # This is the default int8 quantization which uses the derived quantization + # calculated from the activation and weight scale + input_act = node.args[0] + weight = node.args[1] - quantization_spec = DerivedQuantizationSpec( - derived_from=[(input_act, node), (weight, node)], # type: ignore[list-item] - derive_qparams_fn=_derive_qparams_fn, - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max - 1, - qscheme=qscheme, - ch_axis=ch_axis, - ) - return quantization_spec # type: ignore[return-value] + # If the weights are quantized per_tensor, do the same with bias + qscheme = ( + torch.per_tensor_symmetric + if self.weight is None + else self.weight.qscheme + ) + ch_axis = None + if self.weight is not None: + if qscheme == torch.per_channel_symmetric: + ch_axis = self.weight.ch_axis + + quantization_spec = DerivedQuantizationSpec( + derived_from=[(input_act, node), (weight, node)], # type: ignore[list-item] + derive_qparams_fn=_derive_qparams_fn, + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max - 1, + qscheme=qscheme, + ch_axis=ch_axis, + ) + return quantization_spec # type: ignore[return-value] + elif ( + self.input_activation.dtype == torch.int16 + and self.weight.dtype == torch.int8 + ): + # In case the activation is quantized to int16, the bias needs to be + # added after the convolution, so use the output quantization for this case. 
+ return self.output_activation + else: + raise NotImplementedError( + f"Bias quantization of types: i:{self.input_activation.dtype}, w:{self.weight.dtype} not implemented" + ) if self.bias is None: return None diff --git a/backends/arm/scripts/parse_test_names.py b/backends/arm/scripts/parse_test_names.py index c6eaafa597b..2629d8eb257 100644 --- a/backends/arm/scripts/parse_test_names.py +++ b/backends/arm/scripts/parse_test_names.py @@ -95,6 +95,9 @@ def parse_test_name( op = op.removesuffix("_1d") op = op.removesuffix("_2d") + # Remove suffix for 16 bit activation and 8 bit weight test cases + op = op.removesuffix("_16a8w") + assert target != "None", f"{test_name} does not contain one of {TARGETS}" assert ( op in op_name_map.keys() diff --git a/backends/arm/test/ops/test_linear.py b/backends/arm/test/ops/test_linear.py index f9aa4f14048..ebc2ead8a83 100644 --- a/backends/arm/test/ops/test_linear.py +++ b/backends/arm/test/ops/test_linear.py @@ -277,10 +277,14 @@ def get_symmetric_a16w8_linear_quantizer( ) -@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT) -@pytest.mark.xfail( - reason="missing int16 linear ops support; fails at TOSA reference model run with Invalid TOSA graph" -) +test_data_all_16a8w = test_data_rank1_INT | test_data_rank4_INT +# TODO: Remove large rand test as they are flaky until sorted out why: MLETORCH-1377 +for k in list(test_data_all_16a8w.keys()): + if "large_rand" in k: + test_data_all_16a8w.pop(k) + + +@common.parametrize("test_data", test_data_all_16a8w) def test_linear_16a8w_tosa_INT(test_data: torch.Tensor): """Test linear operation with 16A8W quantization (16-bit activations, 8-bit weights)""" test_data, out_features, has_bias, per_channel_quantization = test_data() diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index 935d9f8da77..64e4ae96e08 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -11,6 +11,7 @@ """ +from enum import Enum from typing import 
Any, Optional, Sequence import serializer.tosa_serializer as ts # type: ignore @@ -31,6 +32,22 @@ ) +class TosaSpecialDtype(Enum): + """ + Special TOSA data types that are not natively supported in PyTorch, to be + used in specific scenarios as a value in the key from meta_key(). + """ + + INT48 = ts.DType.INT48 + + def get_tosa_dtype(self) -> ts.TosaDType.DType: + return self.value + + @staticmethod + def meta_key() -> str: + return "tosa_special_dtype" + + def map_dtype(data_type: torch.dtype, tosa_spec: TosaSpecification) -> Any: """Map a ``torch.dtype`` to a ``ts.DType``. @@ -130,10 +147,16 @@ def __process_node(self, argument: torch.fx.Node): """ self.name: str = argument.name - self.dtype, self.shape, self.dim_order = extract_tensor_meta( + output_dtype, self.shape, self.dim_order = extract_tensor_meta( argument.meta, self.tosa_spec ) + # Handle special case of types not representable in torch (i.e. i48_t) + if special_type := argument.meta.get(TosaSpecialDtype.meta_key(), None): + output_dtype = special_type.get_tosa_dtype() + + self.dtype = output_dtype + def __process_list(self, argument): """Capture a sequence argument as ``special``. 
diff --git a/backends/arm/tosa/quant_utils.py b/backends/arm/tosa/quant_utils.py index 027c26fc20a..68ceec8d97c 100644 --- a/backends/arm/tosa/quant_utils.py +++ b/backends/arm/tosa/quant_utils.py @@ -268,6 +268,9 @@ def compute_multiplier_and_shift( if shift > 62: multiplier = multiplier >> min(31, shift - 62) shift = 62 + + assert multiplier >= 0, "Multiplier should be non-negative" + assert shift >= 2 and shift <= 62, "Shift should be in range [2, 62]" multipliers.append(multiplier) shifts.append(shift) return multipliers, shifts @@ -322,8 +325,8 @@ def build_rescale( import tosa.Op as TosaOp # type: ignore - scaleWidth = 32 - is_scale32 = True + scaleWidth = 16 if input_node.dtype == ts.DType.INT48 else 32 + is_scale32 = False if input_node.dtype == ts.DType.INT48 else True multipliers, shifts = compute_multiplier_and_shift(scale, scaleWidth) rescale_inputs = create_const_ops_for_rescale( tosa_fb,