@@ -48,6 +48,18 @@ def test_fp32_linear(self):
4848 num_batch_dims = num_batch_dims ,
4949 )
5050
51+ def test_qc8_linear (self ):
52+ for use_bias in (True , False ):
53+ for num_batch_dims in range (1 , 3 ):
54+ self ._test_linear (
55+ lambda in_size , out_size : torch .nn .Linear (
56+ in_size , out_size , bias = use_bias # noqa
57+ ),
58+ uses_bias = use_bias ,
59+ quant_type = "per_channel" ,
60+ num_batch_dims = num_batch_dims ,
61+ )
62+
5163 def test_fp32_addmm (self ):
5264 """
5365 Note that the ConvertToLinear pass requires the weight matrix to be transposed.
@@ -107,7 +119,7 @@ def forward(self, x):
107119 ),
108120 num_batch_dims = num_batch_dims ,
109121 uses_bias = use_bias ,
110- quant = True ,
122+ quant_type = "per_tensor" ,
111123 )
112124
113125 def test_qs8_linear (self ):
@@ -119,6 +131,7 @@ def test_qs8_linear(self):
119131 ),
120132 uses_bias = use_bias ,
121133 num_batch_dims = num_batch_dims ,
134+ quant_type = "per_tensor" ,
122135 )
123136
124137 @unittest .skip ("XNNPACK currently only supports per-channel dynamic quantization." )
@@ -726,7 +739,7 @@ def _test_linear(
726739 make_module ,
727740 uses_bias ,
728741 num_batch_dims = 1 ,
729- quant = False ,
742+ quant_type = None ,
730743 dtype : torch .dtype = torch .float ,
731744 atol = 1e-03 ,
732745 ):
@@ -746,6 +759,8 @@ def _test_linear(
746759 input_sizes = [4 , 37 , 17 ]
747760 output_sizes = [4 , 17 , 37 ]
748761
762+ quant = quant_type is not None
763+
749764 """
750765 Note that torch.nn.Linear maps to aten.mm.default (no bias) or aten.addmm.default (bias),
751766 which are then transformed into aten.linear.default by the ConvertToLinear pass.
@@ -769,7 +784,19 @@ def _test_linear(
769784 tester = Tester (module , inputs , dynamic_shapes = dynamic_shape )
770785
771786 if quant :
772- tester .quantize ()
787+ if quant_type == "per_channel" :
788+ quant_config = get_symmetric_quantization_config (
789+ is_per_channel = True ,
790+ is_dynamic = False ,
791+ )
792+ elif quant_type == "per_tensor" :
793+ quant_config = get_symmetric_quantization_config (
794+ is_per_channel = False ,
795+ is_dynamic = False ,
796+ )
797+ else :
798+ raise ValueError (f"Unsupported quant type { quant_type } " )
799+ tester .quantize (Quantize (quantization_config = quant_config ))
773800
774801 tester .export ()
775802 tester .check_count ({aten_op : 1 })
0 commit comments