Skip to content

Commit 8a63430

Browse files
mcr229 authored and facebook-github-bot committed
add per-channel tests for linear (#3551)
Summary: Pull Request resolved: #3551 Adding a test for qc8 linear Reviewed By: digantdesai Differential Revision: D55941565 fbshipit-source-id: ecc870dbd879e00790a1052aaf3b4be748b02c94
1 parent 6c56122 commit 8a63430

1 file changed

Lines changed: 30 additions & 3 deletions

File tree

backends/xnnpack/test/ops/linear.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,18 @@ def test_fp32_linear(self):
4848
num_batch_dims=num_batch_dims,
4949
)
5050

51+
def test_qc8_linear(self):
52+
for use_bias in (True, False):
53+
for num_batch_dims in range(1, 3):
54+
self._test_linear(
55+
lambda in_size, out_size: torch.nn.Linear(
56+
in_size, out_size, bias=use_bias # noqa
57+
),
58+
uses_bias=use_bias,
59+
quant_type="per_channel",
60+
num_batch_dims=num_batch_dims,
61+
)
62+
5163
def test_fp32_addmm(self):
5264
"""
5365
Note that the ConvertToLinear pass requires the weight matrix to be transposed.
@@ -107,7 +119,7 @@ def forward(self, x):
107119
),
108120
num_batch_dims=num_batch_dims,
109121
uses_bias=use_bias,
110-
quant=True,
122+
quant_type="per_tensor",
111123
)
112124

113125
def test_qs8_linear(self):
@@ -119,6 +131,7 @@ def test_qs8_linear(self):
119131
),
120132
uses_bias=use_bias,
121133
num_batch_dims=num_batch_dims,
134+
quant_type="per_tensor",
122135
)
123136

124137
@unittest.skip("XNNPACK currently only supports per-channel dynamic quantization.")
@@ -726,7 +739,7 @@ def _test_linear(
726739
make_module,
727740
uses_bias,
728741
num_batch_dims=1,
729-
quant=False,
742+
quant_type=None,
730743
dtype: torch.dtype = torch.float,
731744
atol=1e-03,
732745
):
@@ -746,6 +759,8 @@ def _test_linear(
746759
input_sizes = [4, 37, 17]
747760
output_sizes = [4, 17, 37]
748761

762+
quant = quant_type is not None
763+
749764
"""
750765
Note that torch.nn.Linear maps to aten.mm.default (no bias) or aten.addmm.default (bias),
751766
which ares then transformed into aten.linear.default by the ConvertToLinear pass.
@@ -769,7 +784,19 @@ def _test_linear(
769784
tester = Tester(module, inputs, dynamic_shapes=dynamic_shape)
770785

771786
if quant:
772-
tester.quantize()
787+
if quant_type == "per_channel":
788+
quant_config = get_symmetric_quantization_config(
789+
is_per_channel=True,
790+
is_dynamic=False,
791+
)
792+
elif quant_type == "per_tensor":
793+
quant_config = get_symmetric_quantization_config(
794+
is_per_channel=False,
795+
is_dynamic=False,
796+
)
797+
else:
798+
raise ValueError(f"Unsupported quant type {quant_type}")
799+
tester.quantize(Quantize(quantization_config=quant_config))
773800

774801
tester.export()
775802
tester.check_count({aten_op: 1})

0 commit comments

Comments (0)