diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index 9d6dea8a97b..db82172a9e2 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -6,12 +6,12 @@ from typing import Sequence import torch -from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY from executorch.backends.qualcomm.quantizer.quantizer import ( get_16a8w_qnn_ptq_config, - get_8a8w_qnn_ptq_config, + get_default_8bit_qnn_ptq_config, QuantizationConfig, ) +from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY from executorch.exir.dialects._ops import ops as exir_ops from torch.ao.quantization.quantizer import ( QuantizationAnnotation, @@ -110,7 +110,7 @@ def annotate_matmul_input1(node: Node, quantization_config: QuantizationConfig): # Annotate 16a8w for matmul op to get better performance quantization_config_16a8w = get_16a8w_qnn_ptq_config() # Annotate 8a8w for second input of matmul until past_kv_cache - quantization_config_8a8w = get_8a8w_qnn_ptq_config(act_symmetric=True) + quantization_config_8a8w = get_default_8bit_qnn_ptq_config(act_symmetric=True) for node in gm.graph.nodes: if node.op == "call_function" and node.target == torch.ops.aten.matmul.default: if "nn_module_stack" in node.meta: diff --git a/backends/qualcomm/quantizer/observers/per_channel_param_observer.py b/backends/qualcomm/quantizer/observers/per_channel_param_observer.py deleted file mode 100644 index d556dfa4ba3..00000000000 --- a/backends/qualcomm/quantizer/observers/per_channel_param_observer.py +++ /dev/null @@ -1,104 +0,0 @@ -import torch -from torch.ao.quantization.observer import UniformQuantizationObserverBase - - -# TODO move to torch/ao/quantization/observer.py. -class PerChannelParamObserver(UniformQuantizationObserverBase): - def __init__( - self, - ch_axis=0, - use_mse=True, - steps=100, - dtype=torch.int8, - qscheme=torch.per_channel_symmetric, - reduce_range=False, - quant_min=None, - quant_max=None, - factory_kwargs=None, - eps=torch.finfo(torch.float32).eps, # noqa: B008 - is_dynamic=False, - **kwargs, - ) -> None: - super().__init__( - dtype=dtype, - qscheme=qscheme, - reduce_range=reduce_range, - quant_min=quant_min, - quant_max=quant_max, - factory_kwargs=factory_kwargs, - eps=eps, - is_dynamic=is_dynamic, - **kwargs, - ) - - factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) - self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) - self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) - self.ch_axis = ch_axis - self.use_mse = use_mse - self.steps = steps - self.calibrated = False - - def to_ch_axis(self, x): - axis_order = list(range(len(x.size()))) - axis_order[self.ch_axis], axis_order[0] = 0, self.ch_axis - return torch.flatten(x.permute(axis_order), start_dim=1) - - def mse(self, pred, expect): - loss = (pred - expect).abs().pow(2) - return self.to_ch_axis(loss).mean(1) - - def cosine(self, pred, expect): - target = torch.ones(pred.shape[self.ch_axis]) - pred_n = self.to_ch_axis(pred).reshape(pred.shape[0], -1) - expect_n = self.to_ch_axis(expect).reshape(expect.shape[0], -1) - return torch.nn.CosineEmbeddingLoss()(pred_n, expect_n, target) - - def loss_fn(self, x, new_min, new_max): - scale, offset = self._calculate_qparams(new_min, new_max) - x_q = torch.fake_quantize_per_channel_affine( - x, - scale.data, - offset.data.int(), - self.ch_axis, - self.quant_min, - self.quant_max, - ) - return self.mse(x_q, x) if self.use_mse else self.cosine(x_q, x) - - def line_search(self, x): - x_min, x_max = torch.aminmax(self.to_ch_axis(x), dim=1) - x_range = torch.max(x_min.abs(), x_max) - optimal_loss = torch.zeros_like(x_min) + 1e9 - - # check which clip range could produce smallest loss - for i in range(1, self.steps + 1): - thres = x_range / self.steps * i - current_loss = self.loss_fn(x, -thres, thres) - x_min = torch.where(current_loss < optimal_loss, -thres, x_min) - x_max = torch.where(current_loss < optimal_loss, thres, x_max) - optimal_loss = torch.min(current_loss, optimal_loss) - - return x_min, x_max - - def forward(self, x_orig): - # since params are static, one calibration is enough - if not self.calibrated: - x = x_orig.detach().to(self.min_val.dtype) - self.min_val, self.max_val = self.line_search(x) - self.calibrated = True - - # return fake-quant result for saturating outliers - scale, zero_point = self._calculate_qparams(self.min_val, self.max_val) - return torch.fake_quantize_per_channel_affine( - x_orig, - scale.data, - zero_point.data.int(), - self.ch_axis, - self.quant_min, - self.quant_max, - ) - - @torch.jit.export - def calculate_qparams(self): - return self._calculate_qparams(self.min_val, self.max_val) diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py deleted file mode 100644 index e07ca24d90f..00000000000 --- a/backends/qualcomm/quantizer/qconfig.py +++ /dev/null @@ -1,464 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple - -import torch -from torch import Tensor -from torch.ao.quantization.fake_quantize import ( - FakeQuantize, - FusedMovingAvgObsFakeQuantize, -) -from torch.ao.quantization.observer import ( - MinMaxObserver, - MovingAverageMinMaxObserver, - MovingAveragePerChannelMinMaxObserver, - PerChannelMinMaxObserver, -) -from torch.ao.quantization.quantizer import DerivedQuantizationSpec, QuantizationSpec -from torch.fx import Node - - -@dataclass(eq=True, frozen=True) -class QuantizationConfig: - input_activation: Optional[QuantizationSpec] - output_activation: Optional[QuantizationSpec] - weight: Optional[QuantizationSpec] - bias: Optional[QuantizationSpec | Callable] - - -def _derived_bias_quant_spec(node: Node) -> DerivedQuantizationSpec: - def _derive_bias_qparams_fn( - obs_or_fqs: List, - ) -> Tuple[Tensor, Tensor]: - assert ( - len(obs_or_fqs) == 2 - ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}" - act_obs_or_fq = obs_or_fqs[0] - weight_obs_or_fq = obs_or_fqs[1] - weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams() - act_scale, act_zp = act_obs_or_fq.calculate_qparams() - (broadcast_act_scale, broadcast_weight_scale) = torch.broadcast_tensors( - act_scale, weight_scale - ) - derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32) - derived_zero = torch.zeros(derived_scale.size()).to(torch.int32) - return (derived_scale, derived_zero) - - input_act = node.args[0] - assert isinstance(input_act, Node) - weight = node.args[1] - assert isinstance(weight, Node) - - return DerivedQuantizationSpec( - derived_from=[(input_act, node), (weight, node)], - derive_qparams_fn=_derive_bias_qparams_fn, - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - ch_axis=0, - qscheme=torch.per_channel_symmetric, - ) - - -def get_8a8w_qnn_ptq_config( - act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver -) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-12} - - act_quantization_spec = QuantizationSpec( - dtype=torch.uint8, - qscheme=( - torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine - ), - ch_axis=0, - observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=torch.iinfo(torch.int8).min + 1, - quant_max=torch.iinfo(torch.int8).max, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -# 4 bits quantization only supports specific ops. -def get_16a4w_qnn_ptq_config( - act_observer=MovingAverageMinMaxObserver, -) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-20} - act_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.uint16).min, - quant_max=torch.iinfo(torch.uint16).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=-7, - quant_max=7, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -def get_16a8w_qnn_ptq_config( - act_observer=MovingAverageMinMaxObserver, -) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-20} - act_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.uint16).min, - quant_max=torch.iinfo(torch.uint16).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.uint8, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -def get_16a16w_qnn_ptq_config( - act_observer=MovingAverageMinMaxObserver, -) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-20} - act_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.uint16).min, - quant_max=torch.iinfo(torch.uint16).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int16, - quant_min=torch.iinfo(torch.int16).min + 1, - quant_max=torch.iinfo(torch.int16).max, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - # torch does not support uint16 quantization, use int32 to bypass - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -def get_ptq_per_channel_quant_config( - act_dtype=torch.uint8, - weight_dtype=torch.int8, - act_observer=MovingAverageMinMaxObserver, -) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-12} - - supported_act_types = { - torch.uint8, - torch.uint16, - torch.int8, - torch.int16, - } - # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype - supported_weight_dtypes = {"int4", torch.int8, torch.int16} - assert ( - act_dtype in supported_act_types - ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}" - - assert ( - weight_dtype in supported_weight_dtypes - ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}" - - # torch do not support uint16 quantization, use int32 to bypass - act_quantization_spec = QuantizationSpec( - dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, - quant_min=torch.iinfo(act_dtype).min, - quant_max=torch.iinfo(act_dtype).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, - quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, - quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, - qscheme=torch.per_channel_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = _derived_bias_quant_spec - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -# TODO merge qat and ptq to a fucntion, and use a bool flag to control it -def get_8a8w_qnn_qat_config( - act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver -) -> QuantizationConfig: - act_fake_quant_ctr = FakeQuantize.with_args( - dtype=torch.uint8, - qscheme=( - torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine - ), - reduce_range=True, - observer=act_observer, - ) - act_quantization_spec = QuantizationSpec( - dtype=torch.uint8, - qscheme=( - torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine - ), - ch_axis=0, - observer_or_fake_quant_ctr=act_fake_quant_ctr, - ) - - weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( - dtype=torch.int8, - quant_min=torch.iinfo(torch.int8).min + 1, - quant_max=torch.iinfo(torch.int8).max, - qscheme=torch.per_tensor_symmetric, - reduce_range=True, - observer=MovingAverageMinMaxObserver, - ) - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=torch.iinfo(torch.int8).min + 1, - quant_max=torch.iinfo(torch.int8).max, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=weight_fake_quant_ctr, - ) - - bias_fake_quant_ctr = FakeQuantize.with_args( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - reduce_range=True, - observer=MovingAverageMinMaxObserver, - ) - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=bias_fake_quant_ctr, - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -def get_16a4w_qnn_qat_config( - act_observer=MovingAverageMinMaxObserver, -) -> QuantizationConfig: - act_fake_quant_ctr = FakeQuantize.with_args( - dtype=torch.int32, - quant_min=torch.iinfo(torch.uint16).min, - quant_max=torch.iinfo(torch.uint16).max, - qscheme=torch.per_tensor_affine, - reduce_range=True, - observer=act_observer, - ) - act_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.uint16).min, - quant_max=torch.iinfo(torch.uint16).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=act_fake_quant_ctr, - ) - - weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( - dtype=torch.int8, - quant_min=-7, - quant_max=7, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - reduce_range=True, - observer=MovingAverageMinMaxObserver, - ) - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=-7, - quant_max=7, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=weight_fake_quant_ctr, - ) - - bias_fake_quant_ctr = FakeQuantize.with_args( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - reduce_range=True, - observer=MovingAverageMinMaxObserver, - ) - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=bias_fake_quant_ctr, - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -def get_qat_per_channel_quant_config( - act_dtype=torch.uint8, - weight_dtype=torch.int8, - act_observer=MovingAverageMinMaxObserver, -) -> QuantizationConfig: - supported_act_types = { - torch.uint8, - torch.uint16, - torch.int8, - torch.int16, - } - # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype - supported_weight_dtypes = {"int4", torch.int8, torch.int16} - assert ( - act_dtype in supported_act_types - ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}" - - assert ( - weight_dtype in supported_weight_dtypes - ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}" - - # torch do not support uint16 quantization, use int32 to bypass - act_fake_quant_ctr = FakeQuantize.with_args( - dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, - quant_min=torch.iinfo(act_dtype).min, - quant_max=torch.iinfo(act_dtype).max, - qscheme=torch.per_tensor_affine, - reduce_range=True, - observer=act_observer, - ) - act_quantization_spec = QuantizationSpec( - dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, - quant_min=torch.iinfo(act_dtype).min, - quant_max=torch.iinfo(act_dtype).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=act_fake_quant_ctr, - ) - - weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args( - dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, - quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, - quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, - qscheme=torch.per_channel_symmetric, - ch_axis=0, - observer=MovingAveragePerChannelMinMaxObserver, - ) - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, - quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, - quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, - qscheme=torch.per_channel_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=weight_fake_quant_ctr, - ) - - bias_quantization_spec = _derived_bias_quant_spec - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 50ed07788fd..9e5aaf782a7 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. from enum import IntEnum, unique -from typing import Callable, Optional, Sequence, Set +from typing import Callable, Dict, Optional, Sequence, Set import torch from executorch.backends.qualcomm._passes.decompose_einsum import DecomposeEinsum @@ -22,17 +22,14 @@ from torch.ao.quantization.quantizer import Quantizer from torch.fx import GraphModule -from .annotators import OP_ANNOTATOR - -from .qconfig import ( - get_16a16w_qnn_ptq_config, +from .utils import ( get_16a4w_qnn_ptq_config, - get_16a4w_qnn_qat_config, get_16a8w_qnn_ptq_config, - get_8a8w_qnn_ptq_config, - get_8a8w_qnn_qat_config, + get_default_16bit_qnn_ptq_config, + get_default_8bit_qat_proto, + get_default_8bit_qnn_ptq_config, get_ptq_per_channel_quant_config, - get_qat_per_channel_quant_config, + OP_ANNOTATOR, QuantizationConfig, ) @@ -41,10 +38,9 @@ "QuantDtype", "get_16a4w_qnn_ptq_config", "get_16a8w_qnn_ptq_config", - "get_16a16w_qnn_ptq_config", - "get_8a8w_qnn_ptq_config", - "get_8a8w_qnn_qat_config", - "get_16a4w_qnn_qat_config", + "get_default_16bit_qnn_ptq_config", + "get_default_8bit_qnn_ptq_config", + "get_default_8bit_qat_proto", ] @@ -55,39 +51,8 @@ class QuantDtype(IntEnum): """ use_16a16w = 0 - use_16a8w = 1 - use_16a4w = 2 - use_8a8w = 3 - - -quant_config_dict = { - # PTQ - (QuantDtype.use_16a16w, False): ( - get_16a16w_qnn_ptq_config, - get_ptq_per_channel_quant_config(torch.uint16, torch.int16), - ), - (QuantDtype.use_16a8w, False): ( - get_16a8w_qnn_ptq_config, - get_ptq_per_channel_quant_config(torch.uint16, torch.int8), - ), - (QuantDtype.use_16a4w, False): ( - get_16a4w_qnn_ptq_config, - get_ptq_per_channel_quant_config(torch.uint16, "int4"), - ), - (QuantDtype.use_8a8w, False): ( - get_8a8w_qnn_ptq_config, - get_ptq_per_channel_quant_config(), - ), - # QAT, - (QuantDtype.use_16a4w, True): ( - get_16a4w_qnn_qat_config, - get_qat_per_channel_quant_config(torch.uint16, "int4"), - ), - (QuantDtype.use_8a8w, True): ( - get_8a8w_qnn_qat_config, - get_qat_per_channel_quant_config(), - ), -} + use_16a4w = 1 + use_8a8w = 2 class QnnQuantizer(Quantizer): @@ -95,17 +60,23 @@ class QnnQuantizer(Quantizer): def __init__(self): super().__init__() - self.quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy() + self.bit8_quant_config: QuantizationConfig = get_default_8bit_qnn_ptq_config() + self.bit16_quant_config: QuantizationConfig = get_default_16bit_qnn_ptq_config() - self.is_qat = False - self.quant_dtype = QuantDtype.use_8a8w - self.quant_config: QuantizationConfig = get_8a8w_qnn_ptq_config() - self.per_channel_quant_config = get_ptq_per_channel_quant_config() - self.use_per_channel_weight_quant_ops: Set[OpOverload] = set() + self.bit8_quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy() + self.bit16_quant_ops: Set[OpOverload] = set() self.custom_quant_annotations: Sequence[Callable] = [] self.discard_nodes: Set[str] = set() + self.use_per_channel_weight_quant_ops: Set[OpOverload] = set() + # the weight quantized for activation 8 bits and 16 bits + self.per_channel_weight_dtype: Dict = { + "8bit_act": torch.int8, + "16bit_act": torch.int16, + } + self.per_channel_quant_config = None + def _annotate(self, gm: GraphModule) -> None: for node in gm.graph.nodes: if node.name in self.discard_nodes: @@ -123,16 +94,29 @@ def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig """ Priority: 1. is one of use_per_channel_weight_quant_ops - 2. quant config + 2. int8 / int16 config """ if isinstance(op, str): return if op in self.use_per_channel_weight_quant_ops: + if self.per_channel_quant_config is None: + if op in self.bit16_quant_ops: + return get_ptq_per_channel_quant_config( + act_dtype=torch.uint16, + weight_dtype=self.per_channel_weight_dtype["16bit_act"], + ) + return get_ptq_per_channel_quant_config( + act_dtype=torch.uint8, + weight_dtype=self.per_channel_weight_dtype["8bit_act"], + ) return self.per_channel_quant_config - if op in self.quant_ops: - return self.quant_config + if op in self.bit8_quant_ops: + return self.bit8_quant_config + + if op in self.bit16_quant_ops: + return self.bit16_quant_config print(f"No quant config is implemented for op, {op}") @@ -142,6 +126,15 @@ def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: boo else: self.use_per_channel_weight_quant_ops.difference_update(ops) + def add_16bit_quant_ops(self, ops: Set[OpOverload]) -> None: + for op in ops: + assert ( + op in self.SUPPORTED_OPS + ), f"The annotation of op {op} is not implemented" + + self.bit8_quant_ops.remove(op) + self.bit16_quant_ops.add(op) + def add_custom_quant_annotations( self, custom_quant_annotations: Sequence[Callable] ) -> None: @@ -152,7 +145,10 @@ def add_discard_nodes(self, nodes: Sequence[str]) -> None: def add_discard_ops(self, ops: Sequence[OpOverload]) -> None: for op in ops: - self.quant_ops.remove(op) + if op in self.bit8_quant_ops: + self.bit8_quant_ops.remove(op) + if op in self.bit16_quant_ops: + self.bit16_quant_ops.remove(op) def annotate(self, model: GraphModule) -> GraphModule: self._annotate(model) @@ -163,22 +159,24 @@ def annotate(self, model: GraphModule) -> GraphModule: def get_supported_ops(self) -> Set[OpOverload]: return self.SUPPORTED_OPS - def set_quant_config( - self, quant_dtype: QuantDtype, is_qat=False, act_observer=None + def set_bit16_op_quant_config( + self, quantization_config: QuantizationConfig + ) -> None: + self.bit16_quant_config = quantization_config + + def set_bit8_op_quant_config(self, quantization_config: QuantizationConfig) -> None: + self.bit8_quant_config = quantization_config + + def set_per_channel_weight_dtype( + self, + weight_dtype_for_8bit_act: Optional[str | torch.dtype] = None, + weight_dtype_for_16bit_act: Optional[str | torch.dtype] = None, ) -> None: - self.quant_dtype = quant_dtype - self.is_qat = is_qat - if (quant_dtype, is_qat) not in quant_config_dict: - raise RuntimeError( - f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not support" - ) - - quant_config_fuc, self.per_channel_quant_config = quant_config_dict[ - (quant_dtype, is_qat) - ] - self.quant_config = ( - quant_config_fuc(act_observer) if act_observer else quant_config_fuc() - ) + # TODO accept temporally str type. Remove it when torch support torch.int4 dtype + if weight_dtype_for_8bit_act: + self.per_channel_weight_dtype["8bit_act"] = weight_dtype_for_8bit_act + if weight_dtype_for_16bit_act: + self.per_channel_weight_dtype["16bit_act"] = weight_dtype_for_16bit_act def set_per_channel_conv_quant(self, enable: bool) -> None: conv_ops = {torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default} diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/utils.py similarity index 68% rename from backends/qualcomm/quantizer/annotators.py rename to backends/qualcomm/quantizer/utils.py index 275da567e8f..dc3d2a6841b 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/utils.py @@ -5,16 +5,29 @@ # LICENSE file in the root directory of this source tree. import numbers import operator +from dataclasses import dataclass from functools import partial -from typing import Callable, Dict, List, Sequence, Tuple +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple import torch -from torch._ops import OpOverload +from torch import Tensor +from torch._ops import OpOverload from torch._subclasses import FakeTensor -from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize -from torch.ao.quantization.observer import FixedQParamsObserver +from torch.ao.quantization.fake_quantize import ( + default_fake_quant, + FusedMovingAvgObsFakeQuantize, +) + +from torch.ao.quantization.observer import ( + FixedQParamsObserver, + MinMaxObserver, + MovingAverageMinMaxObserver, + PerChannelMinMaxObserver, + UniformQuantizationObserverBase, +) + from torch.ao.quantization.quantizer import ( DerivedQuantizationSpec, QuantizationAnnotation, @@ -27,12 +40,397 @@ ) from torch.fx import Node -from .qconfig import ( - get_16a16w_qnn_ptq_config, - get_16a4w_qnn_qat_config, - get_8a8w_qnn_qat_config, - QuantizationConfig, -) + +class ParamObserver(UniformQuantizationObserverBase): + def __init__( + self, + ch_axis=0, + use_mse=True, + steps=100, + dtype=torch.int8, + qscheme=torch.per_channel_symmetric, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, # noqa: B008 + is_dynamic=False, + **kwargs, + ) -> None: + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) + self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) + self.ch_axis = ch_axis + self.use_mse = use_mse + self.steps = steps + self.calibrated = False + + def to_ch_axis(self, x): + axis_order = list(range(len(x.size()))) + axis_order[self.ch_axis], axis_order[0] = 0, self.ch_axis + return torch.flatten(x.permute(axis_order), start_dim=1) + + def mse(self, pred, expect): + loss = (pred - expect).abs().pow(2) + return self.to_ch_axis(loss).mean(1) + + def cosine(self, pred, expect): + target = torch.ones(pred.shape[self.ch_axis]) + pred_n = self.to_ch_axis(pred).reshape(pred.shape[0], -1) + expect_n = self.to_ch_axis(expect).reshape(expect.shape[0], -1) + return torch.nn.CosineEmbeddingLoss()(pred_n, expect_n, target) + + def loss_fn(self, x, new_min, new_max): + scale, offset = self._calculate_qparams(new_min, new_max) + x_q = torch.fake_quantize_per_channel_affine( + x, + scale.data, + offset.data.int(), + self.ch_axis, + self.quant_min, + self.quant_max, + ) + return self.mse(x_q, x) if self.use_mse else self.cosine(x_q, x) + + def line_search(self, x): + x_min, x_max = torch.aminmax(self.to_ch_axis(x), dim=1) + x_range = torch.max(x_min.abs(), x_max) + optimal_loss = torch.zeros_like(x_min) + 1e9 + + # check which clip range could produce smallest loss + for i in range(1, self.steps + 1): + thres = x_range / self.steps * i + current_loss = self.loss_fn(x, -thres, thres) + x_min = torch.where(current_loss < optimal_loss, -thres, x_min) + x_max = torch.where(current_loss < optimal_loss, thres, x_max) + optimal_loss = torch.min(current_loss, optimal_loss) + + return x_min, x_max + + def forward(self, x_orig): + # since params are static, one calibration is enough + if not self.calibrated: + x = x_orig.detach().to(self.min_val.dtype) + self.min_val, self.max_val = self.line_search(x) + self.calibrated = True + + # return fake-quant result for saturating outliers + scale, zero_point = self._calculate_qparams(self.min_val, self.max_val) + return torch.fake_quantize_per_channel_affine( + x_orig, + scale.data, + zero_point.data.int(), + self.ch_axis, + self.quant_min, + self.quant_max, + ) + + @torch.jit.export + def calculate_qparams(self): + return self._calculate_qparams(self.min_val, self.max_val) + + +@dataclass(eq=True, frozen=True) +class QuantizationConfig: + input_activation: Optional[QuantizationSpec] + output_activation: Optional[QuantizationSpec] + weight: Optional[QuantizationSpec] + bias: Optional[QuantizationSpec | Callable] + + +def _derived_bias_quant_spec(node: Node) -> DerivedQuantizationSpec: + def _derive_bias_qparams_fn( + obs_or_fqs: List, + ) -> Tuple[Tensor, Tensor]: + assert ( + len(obs_or_fqs) == 2 + ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}" + act_obs_or_fq = obs_or_fqs[0] + weight_obs_or_fq = obs_or_fqs[1] + weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams() + act_scale, act_zp = act_obs_or_fq.calculate_qparams() + (broadcast_act_scale, broadcast_weight_scale) = torch.broadcast_tensors( + act_scale, weight_scale + ) + derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32) + derived_zero = torch.zeros(derived_scale.size()).to(torch.int32) + return (derived_scale, derived_zero) + + input_act = node.args[0] + assert isinstance(input_act, Node) + weight = node.args[1] + assert isinstance(weight, Node) + + return DerivedQuantizationSpec( + derived_from=[(input_act, node), (weight, node)], + derive_qparams_fn=_derive_bias_qparams_fn, + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + ch_axis=0, + qscheme=torch.per_channel_symmetric, + ) + + +def get_default_8bit_qat_proto(act_symmetric: bool = False) -> QuantizationConfig: + + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + qscheme=( + torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine + ), + ch_axis=0, + observer_or_fake_quant_ctr=default_fake_quant, + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=torch.iinfo(torch.int8).min + 1, + quant_max=torch.iinfo(torch.int8).max, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver + ), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=default_fake_quant, + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_default_8bit_qnn_ptq_config( + act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-12} + + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + qscheme=( + torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine + ), + ch_axis=0, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=torch.iinfo(torch.int8).min + 1, + quant_max=torch.iinfo(torch.int8).max, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +# 4 bits quantization only supports specific ops. +def get_16a4w_qnn_ptq_config( + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-20} + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-7, + quant_max=7, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_16a8w_qnn_ptq_config( + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-20} + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_default_16bit_qnn_ptq_config( + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-20} + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int16, + quant_min=torch.iinfo(torch.int16).min + 1, + quant_max=torch.iinfo(torch.int16).max, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + # torch does not support uint16 quantization, use int32 to bypass + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_ptq_per_channel_quant_config( + act_dtype=torch.uint8, weight_dtype=torch.int8 +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-12} + + supported_act_types = { + torch.uint8, + torch.uint16, + torch.int8, + torch.int16, + } + # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype + supported_weight_dtypes = {"int4", torch.int8, torch.int16} + assert ( + act_dtype in supported_act_types + ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}" + + assert ( + weight_dtype in supported_weight_dtypes + ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}" + + # torch do not support uint16 quantization, use int32 to bypass + act_quantization_spec = QuantizationSpec( + dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, + quant_min=torch.iinfo(act_dtype).min, + quant_max=torch.iinfo(act_dtype).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, + quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, + quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = _derived_bias_quant_spec + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config QUANT_ANNOTATION_KEY = "quantization_annotation" @@ -503,34 +901,19 @@ def annotate_sigmoid(node: Node, quantization_config: QuantizationConfig) -> Non scale = 1 / (q_max - q_min + 1) - bias_obs_ctr = observer = FixedQParamsObserver.with_args( - scale=scale, - zero_point=0, + # make sigmoid map to the range between 0~1 + out_act_quantization_spec = QuantizationSpec( dtype=quantization_config.output_activation.dtype, - qscheme=torch.torch.per_tensor_affine, quant_max=q_max, quant_min=q_min, - ) - if quantization_config in ( - get_8a8w_qnn_qat_config(), - get_16a4w_qnn_qat_config(), - ): - bias_obs_ctr = FixedQParamsFakeQuantize.with_args( - observer=observer, + observer_or_fake_quant_ctr=FixedQParamsObserver.with_args( scale=scale, zero_point=0, dtype=quantization_config.output_activation.dtype, qscheme=torch.torch.per_tensor_affine, quant_max=q_max, quant_min=q_min, - ) - - # make sigmoid map to the range between 0~1 - out_act_quantization_spec = QuantizationSpec( - dtype=quantization_config.output_activation.dtype, - quant_max=q_max, - quant_min=q_min, - observer_or_fake_quant_ctr=bias_obs_ctr, + ), qscheme=torch.torch.per_tensor_affine, ) @@ -703,7 +1086,7 @@ def annotate_matmul(node: Node, quantization_config: QuantizationConfig) -> None # In matmul, QNN_DATATYPE_SFIXED_POINT_16 Input1 must have QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetric quantized. if input_act_qspec.dtype == torch.int32: # we should use int16 for mm / bmm instead of int4 - input_qspec_map[input_act1] = get_16a16w_qnn_ptq_config().weight + input_qspec_map[input_act1] = get_default_16bit_qnn_ptq_config().weight else: input_qspec_map[input_act1] = input_act_qspec @@ -732,7 +1115,7 @@ def annotate_bmm(node: Node, quantization_config: QuantizationConfig) -> None: # In bmm, QNN_DATATYPE_SFIXED_POINT_16 Input1 must have QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetric quantized. if input_act_qspec.dtype == torch.int32: # we should use int16 for mm / bmm instead of int4 - input_qspec_map[input_act1] = get_16a16w_qnn_ptq_config().weight + input_qspec_map[input_act1] = get_default_16bit_qnn_ptq_config().weight else: input_qspec_map[input_act1] = input_act_qspec @@ -875,7 +1258,7 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) -> _annotate_input_qspec_map( node, weight_node, - get_16a16w_qnn_ptq_config().weight, + get_default_16bit_qnn_ptq_config().weight, ) else: _annotate_input_qspec_map( diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 64b0490d461..4bfdedcd4b4 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -698,17 +698,6 @@ def test_qnn_backend_16a4w_conv2d(self): ) self.lower_module_and_test_output(module, sample_input) - def test_qnn_backend_16a4w_conv2d_qat(self): - modules = [Conv2dSingle(), Conv2dSingle(bias=False)] # noqa: F405 - sample_input = (torch.randn([1, 1, 3, 3]),) - for i, module in enumerate(modules): - with self.subTest(i=i): - prepared = self.get_prepared_qat_module(module, sample_input) - converted = self.get_converted_sgd_trained_module( - module, prepared, sample_input - ) - self.lower_module_and_test_output(converted, sample_input) - def test_qnn_backend_16a4w_layer_norm(self): module = LayerNorm() # noqa: F405 sample_input = (torch.randn(196, 768),) @@ -1074,8 +1063,18 @@ def test_qnn_backend_linear_qat(self): """ module = Linear() # noqa: F405 sample_input = (torch.randn([3, 4]),) - prepared = self.get_prepared_qat_module(module, sample_input) - module = self.get_converted_sgd_trained_module(module, prepared, sample_input) + + module = self.get_prepared_qat_module(module, sample_input) + + optimizer = torch.optim.SGD(module.parameters(), lr=0.1) + criterion = torch.nn.CrossEntropyLoss() + output = module(*sample_input) + loss = criterion(output, module(*sample_input)) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + module = torch.ao.quantization.quantize_pt2e.convert_pt2e(module) self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_log_softmax(self): diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index d2a3e7c2417..114493c7d2f 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -17,7 +17,13 @@ from executorch import exir from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner from executorch.backends.qualcomm.qnn_preprocess import QnnBackend -from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype +from executorch.backends.qualcomm.quantizer.quantizer import ( + get_16a4w_qnn_ptq_config, + get_default_16bit_qnn_ptq_config, + get_default_8bit_qat_proto, + QnnQuantizer, + QuantDtype, +) from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( QcomChipset, ) @@ -399,7 +405,18 @@ def get_qdq_module( quantizer.add_custom_quant_annotations(custom_quant_annotations) quantizer.set_per_channel_conv_quant(is_conv_per_channel) quantizer.set_per_channel_linear_quant(is_linear_per_channel) - quantizer.set_quant_config(quant_dtype) + + if quant_dtype == QuantDtype.use_8a8w: + pass # default setting + elif quant_dtype == QuantDtype.use_16a16w: + quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) + quantizer.set_bit16_op_quant_config(get_default_16bit_qnn_ptq_config()) + elif quant_dtype == QuantDtype.use_16a4w: + quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) + quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config()) + quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") + else: + raise AssertionError(f"No support for QuantDtype {quant_dtype}.") prepared = prepare_pt2e(m, quantizer) prepared(*inputs) @@ -431,28 +448,13 @@ def get_prepared_qat_module( quantizer.set_per_channel_linear_quant(is_linear_per_channel) if quant_dtype == QuantDtype.use_8a8w: - quantizer.set_quant_config(quant_dtype, is_qat=True) + quantizer.set_bit8_op_quant_config(get_default_8bit_qat_proto()) else: raise RuntimeError("Shuld not be here") prepared = prepare_qat_pt2e(m, quantizer) return torch.ao.quantization.move_exported_model_to_train(prepared) - def get_converted_sgd_trained_module( - self, - ori_module: torch.nn.Module, - prepared: torch.nn.Module, - inputs: Tuple[torch.Tensor], - ) -> torch.fx.GraphModule: - optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001) - criterion = torch.nn.CrossEntropyLoss() - output = prepared(*inputs) - loss = criterion(output, ori_module(*inputs)) - optimizer.zero_grad() - loss.backward() - optimizer.step() - return torch.ao.quantization.quantize_pt2e.convert_pt2e(prepared) - def split_graph(self, graph_module: torch.fx.GraphModule, division: int): class SplitGraph(ExportPass): """ diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index cb54412add0..0ea4512abce 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -331,7 +331,7 @@ def _transform( def capture_program( module: torch.nn.Module, inputs: Tuple[torch.Tensor], - custom_pass_config: Set[str] = frozenset(), + custom_pass_config: Set[str] = None, ) -> exir.ExirExportedProgram: ep = torch.export.export(module, inputs) decomposed_ep = ep.run_decompositions(get_decomp_table()) diff --git a/examples/qualcomm/oss_scripts/fastvit.py b/examples/qualcomm/oss_scripts/fastvit.py index 0e2c695ab34..30fe74f35b5 100644 --- a/examples/qualcomm/oss_scripts/fastvit.py +++ b/examples/qualcomm/oss_scripts/fastvit.py @@ -10,19 +10,15 @@ import numpy as np import torch -from executorch.backends.qualcomm.quantizer.annotators import ( - QuantizationConfig, - QuantizationSpec, -) -from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import ( - PerChannelParamObserver, -) -from executorch.backends.qualcomm.quantizer.qconfig import ( + +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from executorch.backends.qualcomm.quantizer.utils import ( _derived_bias_quant_spec, MovingAverageMinMaxObserver, + ParamObserver, + QuantizationConfig, + QuantizationSpec, ) - -from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.utils.constants import ( QCOM_PASS_EXPAND_BROADCAST_SHAPE, ) @@ -91,7 +87,7 @@ def main(args): quant_max=torch.iinfo(torch.int8).max, qscheme=torch.per_channel_symmetric, ch_axis=0, - observer_or_fake_quant_ctr=PerChannelParamObserver.with_args( + observer_or_fake_quant_ctr=ParamObserver.with_args( **{"steps": 200, "use_mse": True} ), ) diff --git a/examples/qualcomm/oss_scripts/llama2/llama.py b/examples/qualcomm/oss_scripts/llama2/llama.py index 9f7198a3447..04569df5c92 100644 --- a/examples/qualcomm/oss_scripts/llama2/llama.py +++ b/examples/qualcomm/oss_scripts/llama2/llama.py @@ -56,12 +56,12 @@ def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None: This function is specific for matmul op 16a8w. """ - from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY from executorch.backends.qualcomm.quantizer.quantizer import ( get_16a8w_qnn_ptq_config, - get_8a8w_qnn_ptq_config, + get_default_8bit_qnn_ptq_config, QuantizationConfig, ) + from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY from torch.ao.quantization.quantizer import ( QuantizationAnnotation, SharedQuantizationSpec, @@ -119,7 +119,7 @@ def annotate_single_in_single_out( ) def annotate_matmul_input1(node: Node): - quantization_config_8a8w = get_8a8w_qnn_ptq_config(act_symmetric=True) + quantization_config_8a8w = get_default_8bit_qnn_ptq_config(act_symmetric=True) while isinstance(node, Node) and node.op == "call_function": if node.target in [ torch.ops.aten.permute.default, @@ -142,11 +142,11 @@ def annotate_matmul_input1(node: Node): def annotate_linear_16a8w_in_affine_layer(gm: torch.fx.GraphModule) -> None: - from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY from executorch.backends.qualcomm.quantizer.quantizer import ( get_ptq_per_channel_quant_config, QuantizationConfig, ) + from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY from torch.ao.quantization.quantizer import QuantizationAnnotation from torch.fx import Node diff --git a/examples/qualcomm/scripts/export_example.py b/examples/qualcomm/scripts/export_example.py index 56169e39a2e..2e49a2344b8 100644 --- a/examples/qualcomm/scripts/export_example.py +++ b/examples/qualcomm/scripts/export_example.py @@ -4,7 +4,10 @@ import torch from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner -from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer +from executorch.backends.qualcomm.quantizer.quantizer import ( + get_default_8bit_qnn_ptq_config, + QnnQuantizer, +) from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( QcomChipset, ) @@ -61,6 +64,8 @@ def main() -> None: # Get quantizer quantizer = QnnQuantizer() + quant_config = get_default_8bit_qnn_ptq_config() + quantizer.set_bit8_op_quant_config(quant_config) # Typical pytorch 2.0 quantization flow m = torch.export.export(model.eval(), example_inputs).module() diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index 100008e91ca..06225be2d1c 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -16,7 +16,13 @@ import torch from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner -from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype +from executorch.backends.qualcomm.quantizer.quantizer import ( + get_16a4w_qnn_ptq_config, + get_default_16bit_qnn_ptq_config, + get_default_8bit_qnn_ptq_config, + QnnQuantizer, + QuantDtype, +) from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( QcomChipset, ) @@ -31,11 +37,7 @@ from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass from torch.ao.quantization.observer import MovingAverageMinMaxObserver -from torch.ao.quantization.quantize_pt2e import ( - convert_pt2e, - prepare_pt2e, - prepare_qat_pt2e, -) +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e class SimpleADB: @@ -185,58 +187,36 @@ def pull_debug_output(self, etdump_path, debug_ouput_path, callback=None): callback() -def ptq_calibrate(captured_model, quantizer, dataset): - annotated_model = prepare_pt2e(captured_model, quantizer) - print("Quantizing(PTQ) the model...") - # calibration - if callable(dataset): - dataset(annotated_model) - else: - for data in dataset: - annotated_model(*data) - return annotated_model - - -def qat_train(ori_model, captured_model, quantizer, dataset): - data, targets = dataset - annotated_model = torch.ao.quantization.move_exported_model_to_train( - prepare_qat_pt2e(captured_model, quantizer) - ) - optimizer = torch.optim.SGD(annotated_model.parameters(), lr=0.00001) - criterion = torch.nn.CrossEntropyLoss() - for i, d in enumerate(data): - print(f"Epoch {i}") - if i > 3: - # Freeze quantizer parameters - annotated_model.apply(torch.ao.quantization.disable_observer) - if i > 2: - # Freeze batch norm mean and variance estimates - annotated_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats) - - output = annotated_model(*d) - loss = criterion(output, targets[i]) - optimizer.zero_grad() - loss.backward() - optimizer.step() - - return torch.ao.quantization.quantize_pt2e.convert_pt2e( - torch.ao.quantization.move_exported_model_to_eval(annotated_model) - ) - - def make_quantizer( - quant_dtype: Optional[QuantDtype] = QuantDtype.use_8a8w, + quant_dtype: Optional[QuantDtype], custom_annotations=(), per_channel_conv=True, per_channel_linear=False, act_observer=MovingAverageMinMaxObserver, - is_qat=False, ): quantizer = QnnQuantizer() quantizer.add_custom_quant_annotations(custom_annotations) quantizer.set_per_channel_conv_quant(per_channel_conv) quantizer.set_per_channel_linear_quant(per_channel_linear) - quantizer.set_quant_config(quant_dtype, is_qat, act_observer) + + if quant_dtype == QuantDtype.use_8a8w: + quantizer.set_bit8_op_quant_config( + get_default_8bit_qnn_ptq_config(act_observer=act_observer) + ) + elif quant_dtype == QuantDtype.use_16a16w: + quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) + quantizer.set_bit16_op_quant_config( + get_default_16bit_qnn_ptq_config(act_observer=act_observer) + ) + elif quant_dtype == QuantDtype.use_16a4w: + quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) + quantizer.set_bit16_op_quant_config( + get_16a4w_qnn_ptq_config(act_observer=act_observer) + ) + quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") + else: + raise AssertionError(f"No support for QuantDtype {quant_dtype}.") + return quantizer @@ -255,22 +235,18 @@ def build_executorch_binary( metadata=None, dump_intermediate_outputs=False, custom_pass_config=frozenset(), - qat_training_data=None, ): if quant_dtype is not None: + quantizer = custom_quantizer or make_quantizer(quant_dtype=quant_dtype) captured_model = torch.export.export(model, inputs).module() - if qat_training_data: - quantizer = custom_quantizer or make_quantizer( - quant_dtype=quant_dtype, is_qat=True - ) - # qat training - annotated_model = qat_train( - model, captured_model, quantizer, qat_training_data - ) + annotated_model = prepare_pt2e(captured_model, quantizer) + print("Quantizing the model...") + # calibration + if callable(dataset): + dataset(annotated_model) else: - quantizer = custom_quantizer or make_quantizer(quant_dtype=quant_dtype) - # ptq calibration - annotated_model = ptq_calibrate(captured_model, quantizer, dataset) + for data in dataset: + annotated_model(*data) quantized_model = convert_pt2e(annotated_model) edge_prog = capture_program(quantized_model, inputs, custom_pass_config) diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index ba281864a9f..fd368d73f1f 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -144,7 +144,6 @@ def check_embedding_byte_registered(): def get_qnn_quantizer( pt2e_quantize: str, quantization_mode: Optional[str] = None, - is_qat: bool = False, ): try: from executorch.backends.qualcomm.quantizer.custom_annotation import ( # pyre-fixme[21] @@ -153,6 +152,8 @@ def get_qnn_quantizer( # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.quantizer.quantizer` from executorch.backends.qualcomm.quantizer.quantizer import ( + get_16a4w_qnn_ptq_config, + get_default_16bit_qnn_ptq_config, QnnQuantizer, QuantDtype, ) @@ -174,7 +175,6 @@ def get_qnn_quantizer( custom_annotations = () if quant_config == "8a8w": quant_dtype = QuantDtype.use_8a8w # pyre-fixme[16] - qnn_quantizer.set_quant_config(quant_dtype, is_qat=is_qat) elif quant_config == "16a16w": quant_dtype = QuantDtype.use_16a16w # pyre-fixme[16] # Due to the error with 16a16w in Qnn Htp, we need to disable per channel linear quantization when use 16a16w @@ -184,17 +184,20 @@ def get_qnn_quantizer( ) qnn_quantizer.set_per_channel_conv_quant(enable=False) qnn_quantizer.set_per_channel_linear_quant(enable=False) - # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. - qnn_quantizer.set_quant_config( - quant_dtype, is_qat=is_qat, act_observer=MinMaxObserver + qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS) + qnn_quantizer.set_bit16_op_quant_config( + # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. + get_default_16bit_qnn_ptq_config(act_observer=MinMaxObserver) ) elif quant_config == "16a4w": # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. quant_dtype = QuantDtype.use_16a4w - # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. - qnn_quantizer.set_quant_config( - quant_dtype, is_qat=is_qat, act_observer=MinMaxObserver + qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS) + qnn_quantizer.set_bit16_op_quant_config( + # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. + get_16a4w_qnn_ptq_config(act_observer=MinMaxObserver) ) + qnn_quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`. custom_annotations = (custom_annotate_llama_matmul_16a8w,) else: