Skip to content

Commit 7b5c60d

Browse files
authored
Cortex-M backend: Add quantized int8 batch matmul (CMSIS-NN) (#17799)
### Summary Add cortex_m::quantized_batch_matmul wrapping arm_batch_matmul_s8. The RHS is always pre-transposed: a constant RHS (parameters) is transposed at AOT time in the pass, while a dynamic RHS gets a cortex_m::transpose node inserted in the graph. It would be preferable if we could pre-compute or cache the constant RHS kernel sums, but I could not find any public CMSIS-NN APIs that would allow us to do so. Fixes #16109 Authored with Claude. ### Test plan ``` pytest backends/cortex_m/test/ops/test_batch_matmul.py ```
1 parent fa4c6a0 commit 7b5c60d

8 files changed

Lines changed: 409 additions & 5 deletions

File tree

backends/cortex_m/CMakeLists.txt

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,21 +53,22 @@ endif()
5353

5454
# Cortex-M ops kernel sources
5555
# Kernel sources for the Cortex-M backend; kept alphabetized so new ops
# slot in predictably.
set(_cortex_m_kernels__srcs
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_maximum.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_minimum.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_pad.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantize_per_tensor.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_add.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_avg_pool2d.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_batch_matmul.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_conv2d.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_depthwise_conv2d.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_linear.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_max_pool2d.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_mul.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_transpose_conv2d.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_softmax.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_transpose.cpp
)
7273

7374
# Generate C++ bindings to register kernels into Executorch
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "cortex_m_ops_common.h"
10+
11+
extern "C" {
12+
#include "arm_nnfunctions.h"
13+
}
14+
15+
namespace cortex_m {
16+
namespace native {
17+
18+
using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
19+
20+
namespace {
21+
22+
bool validate_batch_matmul_arguments(
23+
KernelRuntimeContext& context,
24+
const Tensor& lhs,
25+
const Tensor& rhs_transposed,
26+
const Tensor& out) {
27+
if (lhs.scalar_type() != ScalarType::Char ||
28+
rhs_transposed.scalar_type() != ScalarType::Char ||
29+
out.scalar_type() != ScalarType::Char) {
30+
ET_LOG(Error, "quantized_batch_matmul: all tensors must be int8");
31+
context.fail(Error::InvalidArgument);
32+
return false;
33+
}
34+
35+
if (lhs.dim() != 3 || rhs_transposed.dim() != 3 || out.dim() != 3) {
36+
ET_LOG(Error, "quantized_batch_matmul: all tensors must be 3-D");
37+
context.fail(Error::InvalidArgument);
38+
return false;
39+
}
40+
41+
if (lhs.size(0) != rhs_transposed.size(0)) {
42+
ET_LOG(Error, "quantized_batch_matmul: batch dims must match");
43+
context.fail(Error::InvalidArgument);
44+
return false;
45+
}
46+
47+
if (lhs.size(2) != rhs_transposed.size(2)) {
48+
ET_LOG(Error, "quantized_batch_matmul: inner dims must match");
49+
context.fail(Error::InvalidArgument);
50+
return false;
51+
}
52+
53+
if (out.size(0) != lhs.size(0) || out.size(1) != lhs.size(1) ||
54+
out.size(2) != rhs_transposed.size(1)) {
55+
ET_LOG(Error, "quantized_batch_matmul: output shape mismatch");
56+
context.fail(Error::InvalidArgument);
57+
return false;
58+
}
59+
60+
return true;
61+
}
62+
63+
} // namespace
64+
65+
// Out-variant kernel for cortex_m::quantized_batch_matmul.
//
// Thin wrapper around CMSIS-NN's arm_batch_matmul_s8. The RHS arrives
// pre-transposed in (batch, cols, inner) layout, so no transpose flags
// are set on the CMSIS call. Per the CMSIS-NN convention the *_offset
// arguments are the negated zero points and are added to the raw int8
// values inside the kernel.
//
// On any failure (bad arguments, scratch allocation failure, CMSIS
// error status) the context is marked failed and `out` is returned
// unmodified.
Tensor& quantized_batch_matmul_out(
    KernelRuntimeContext& context,
    const Tensor& lhs,
    int64_t lhs_offset,
    const Tensor& rhs_transposed,
    int64_t rhs_offset,
    int64_t output_offset,
    int64_t output_multiplier,
    int64_t output_shift,
    Tensor& out) {
  if (!validate_batch_matmul_arguments(context, lhs, rhs_transposed, out)) {
    return out;
  }

  const auto batches = static_cast<int32_t>(lhs.size(0));
  const auto rows = static_cast<int32_t>(lhs.size(1));
  const auto depth = static_cast<int32_t>(lhs.size(2));
  const auto cols = static_cast<int32_t>(rhs_transposed.size(1));

  const cmsis_nn_dims lhs_dims = {1, batches, rows, depth};
  const cmsis_nn_dims rhs_dims = {1, batches, cols, depth};
  const cmsis_nn_dims out_dims = {1, batches, rows, cols};

  // adj_x/adj_y stay false: the transpose already happened AOT (constant
  // RHS) or via an inserted cortex_m::transpose node (dynamic RHS).
  const cmsis_nn_bmm_params bmm_params = {
      /* adj_x */ false,
      /* adj_y */ false,
      /* fc_params */
      {static_cast<int32_t>(lhs_offset),
       static_cast<int32_t>(rhs_offset),
       static_cast<int32_t>(output_offset),
       /* activation */
       {std::numeric_limits<int8_t>::min(),
        std::numeric_limits<int8_t>::max()}}};

  cmsis_nn_per_tensor_quant_params quant_params;
  quant_params.multiplier = static_cast<int32_t>(output_multiplier);
  quant_params.shift = static_cast<int32_t>(output_shift);

  // CMSIS sizes the scratch buffer like a fully connected layer over the
  // output dims; allocate from the temp allocator only when needed.
  cmsis_nn_context ctx;
  ctx.buf = nullptr;
  ctx.size = 0;

  const int32_t buf_size = arm_fully_connected_s8_get_buffer_size(&out_dims);
  if (buf_size > 0) {
    auto scratch = context.allocate_temp(buf_size);
    if (!scratch.ok()) {
      ET_LOG(
          Error,
          "quantized_batch_matmul: failed to allocate scratch buffer (%d bytes)",
          buf_size);
      context.fail(scratch.error());
      return out;
    }
    ctx.buf = scratch.get();
    ctx.size = buf_size;
  }

  const arm_cmsis_nn_status status = arm_batch_matmul_s8(
      &ctx,
      &bmm_params,
      &quant_params,
      &lhs_dims,
      lhs.const_data_ptr<int8_t>(),
      &rhs_dims,
      rhs_transposed.const_data_ptr<int8_t>(),
      &out_dims,
      out.mutable_data_ptr<int8_t>());

  if (status != ARM_CMSIS_NN_SUCCESS) {
    ET_LOG(
        Error,
        "quantized_batch_matmul: arm_batch_matmul_s8 failed with status [%d]",
        status);
    context.fail(Error::Internal);
  }

  return out;
}
144+
145+
} // namespace native
146+
} // namespace cortex_m

backends/cortex_m/ops/operators.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,59 @@ def quantized_mul_impl(
255255
return result
256256

257257

258+
# ===================================================================
# QUANTIZED BATCH MATMUL OPERATION DEFINITION
# ===================================================================
# Functional and out-variant schemas. `rhs_transposed` is expected in
# (batch, cols, inner) layout, i.e. pre-transposed relative to aten.bmm.
lib.define(
    "quantized_batch_matmul("
    "Tensor lhs, int lhs_zero_point, "
    "Tensor rhs_transposed, int rhs_zero_point, "
    "int output_zero_point, int output_multiplier, int output_shift) -> Tensor"
)
lib.define(
    "quantized_batch_matmul.out("
    "Tensor lhs, int lhs_zero_point, "
    "Tensor rhs_transposed, int rhs_zero_point, "
    "int output_zero_point, int output_multiplier, int output_shift, "
    "*, Tensor(a!) out) -> Tensor(a!)"
)
274+
275+
276+
@register_fake("cortex_m::quantized_batch_matmul")
def quantized_batch_matmul_meta(
    lhs: torch.Tensor,
    lhs_zero_point: int,
    rhs_transposed: torch.Tensor,
    rhs_zero_point: int,
    output_zero_point: int,
    output_multiplier: int,
    output_shift: int,
) -> torch.Tensor:
    """Fake-tensor shape/dtype propagation for quantized_batch_matmul.

    Output is int8 of shape (batch, lhs rows, rhs cols), where the RHS is
    pre-transposed so its dim 1 holds the output columns.
    """
    batch, rows, inner = lhs.shape
    rhs_batch, cols, rhs_inner = rhs_transposed.shape
    assert batch == rhs_batch and inner == rhs_inner
    return torch.empty((batch, rows, cols), dtype=torch.int8, device=lhs.device)
290+
291+
292+
@impl(lib, "quantized_batch_matmul", "CompositeExplicitAutograd")
def quantized_batch_matmul_impl(
    lhs: torch.Tensor,
    lhs_zero_point: int,
    rhs_transposed: torch.Tensor,
    rhs_zero_point: int,
    output_zero_point: int,
    output_multiplier: int,
    output_shift: int,
) -> torch.Tensor:
    """Reference implementation mirroring the CMSIS-NN arithmetic.

    The zero-point arguments arrive already negated (CMSIS-NN "offset"
    convention), so they are *added* to the raw int8 values before the
    matmul, then the accumulator is requantized and shifted back into
    the int8 output range.
    """
    # NOTE(review): accumulation happens in float32, which represents
    # integers exactly only up to 2**24 — for very large inner dims this
    # could diverge from the device's int32 accumulation; confirm acceptable.
    shifted_lhs = lhs.to(torch.float32) + float(lhs_zero_point)
    shifted_rhs_t = rhs_transposed.to(torch.float32) + float(rhs_zero_point)
    rhs = shifted_rhs_t.permute(0, 2, 1)  # back to (batch, inner, cols)
    acc = torch.bmm(shifted_lhs, rhs).to(torch.int32)
    requantized = requantize_cmsis(acc, output_multiplier, output_shift)
    return torch.clamp(requantized + output_zero_point, -128, 127).to(torch.int8)
309+
310+
258311
# ===================================================================
259312
# MINIMUM/MAXIMUM OPERATION DEFINITIONS
260313
# ===================================================================

backends/cortex_m/ops/operators.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,9 @@
9393
kernels:
9494
- arg_meta: null
9595
kernel_name: cortex_m::quantized_max_pool2d_out
96+
97+
- func: cortex_m::quantized_batch_matmul.out(Tensor lhs, int lhs_zero_point, Tensor rhs_transposed, int rhs_zero_point, int output_zero_point, int output_multiplier, int output_shift, *, Tensor(a!) out) -> Tensor(a!)
  variants: function
  kernels:
    - arg_meta: null
      kernel_name: cortex_m::quantized_batch_matmul_out

backends/cortex_m/passes/convert_to_cortex_m_pass.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from executorch.backends.transforms.utils import (
1717
create_constant_placeholder,
1818
get_param_tensor,
19+
is_param_node,
1920
)
2021

2122
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
@@ -372,6 +373,52 @@ def _get_transpose_conv2d_replacement(self, node) -> tuple:
372373
)
373374
return exir_ops.edge.cortex_m.quantized_transpose_conv2d.default, new_args
374375

376+
def _get_bmm_replacement(self, node):
    """Build (op, args) lowering aten.bmm to cortex_m.quantized_batch_matmul.

    The CMSIS kernel consumes the RHS in transposed (batch, cols, inner)
    layout: a constant RHS is transposed once at AOT time and
    materialized as a fresh parameter placeholder, while a dynamic RHS
    gets a cortex_m.transpose node inserted ahead of the matmul.
    """
    lhs_qparams = node.meta["input_qparams"][0]
    rhs_qparams = node.meta["input_qparams"][1]
    out_qparams = node.meta["output_qparams"][0]

    # Fold the three float scales into one fixed-point multiplier/shift.
    output_mult, output_shift = quantize_multiplier_aot(
        (lhs_qparams.scale * rhs_qparams.scale) / out_qparams.scale
    )

    lhs_node, rhs_node = node.args[0], node.args[1]

    if is_param_node(self.exported_program, rhs_node):
        # Constant RHS: transpose the parameter data at export time.
        rhs_tensor = get_param_tensor(self.exported_program, rhs_node)
        rhs_transposed_tensor = rhs_tensor.permute(0, 2, 1).contiguous()
        with node.graph.inserting_after(rhs_node):
            rhs_transposed = create_constant_placeholder(
                self.exported_program,
                node.graph,
                node.name + "_rhs_transposed",
                InputKind.PARAMETER,
                rhs_transposed_tensor,
            )
    else:
        # Dynamic RHS: transpose at runtime, right before the matmul.
        with node.graph.inserting_before(node):
            rhs_transposed = node.graph.create_node(
                "call_function",
                target=exir_ops.edge.cortex_m.transpose.default,
                args=(rhs_node, [0, 2, 1]),
            )

    # Zero points are negated into CMSIS-NN "offsets" for the kernel.
    args = (
        lhs_node,
        -lhs_qparams.zp,
        rhs_transposed,
        -rhs_qparams.zp,
        out_qparams.zp,
        output_mult,
        output_shift,
    )
    return exir_ops.edge.cortex_m.quantized_batch_matmul.default, args
421+
375422
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
376423
modified = False
377424
for node in graph_module.graph.nodes:
@@ -393,6 +440,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
393440
op, args = self._get_transpose_conv2d_replacement(node)
394441
else:
395442
op, args = self._get_convolution_replacement(node)
443+
case exir_ops.edge.aten.bmm.default:
444+
op, args = self._get_bmm_replacement(node)
396445
case _:
397446
continue
398447

backends/cortex_m/quantizer/pattern_checkers.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,33 @@ def check_quantization_config(
312312
return is_int8 and is_per_tensor
313313

314314

315+
class CortexMBmmCheck(PatternCheck):
    """Pattern check admitting only shape-compatible 3-D aten.bmm inputs."""

    @classmethod
    def check_pattern(cls, pattern):
        # Inspect every two-input node and reject the pattern if its
        # operands are not valid 3-D batch-matmul inputs.
        for node in pattern:
            if len(node.all_input_nodes) != 2:
                continue
            lhs = get_first_fake_tensor(node.all_input_nodes[0])
            rhs = get_first_fake_tensor(node.all_input_nodes[1])
            if lhs.dim() != 3 or rhs.dim() != 3:
                return False
            if lhs.shape[0] != rhs.shape[0]:
                return False
            if lhs.shape[2] != rhs.shape[1]:
                return False
        return True

    @classmethod
    def check_quantization_config(
        cls, pattern: list[Node], quantization_config: CortexMQuantizationConfig
    ):
        # The CMSIS kernel supports only per-tensor int8 activations.
        per_tensor = PatternCheck.is_per_tensor(
            quantization_config.get_input_act_qspec()
        ) and PatternCheck.is_per_tensor(quantization_config.get_output_act_qspec())
        return per_tensor and cls.is_int8_activations(quantization_config)
340+
341+
315342
class CortexMMaxPool2DCheck(PatternCheck):
316343
@classmethod
317344
def _pool_arg_as_bool(cls, node: Node, index: int, default: bool) -> bool:

backends/cortex_m/quantizer/quantizer_support.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from executorch.backends.cortex_m.quantizer.pattern_checkers import (
88
CortexMAddMulCheck,
99
CortexMAvgPool2DCheck,
10+
CortexMBmmCheck,
1011
CortexMConv2DCheck,
1112
CortexMConvTranspose2DCheck,
1213
CortexMLinearCheck,
@@ -118,11 +119,16 @@
118119
(torch.ops.aten.max_pool2d_with_indices.default,): CortexMMaxPool2DCheck,
119120
}
120121

122+
# Batch matmul pattern table, merged into the combined support dict below.
BMM_OP_PATTERNS = {
    (torch.ops.aten.bmm.default,): CortexMBmmCheck,
}

CORTEX_M_QUANTIZER_SUPPORT_DICT = (
    BINARY_OP_PATTERNS
    | LINEAR_OP_PATTERNS
    | CONV_OP_PATTERNS
    | SOFTMAX_OP_PATTERNS
    | CONV_TRANSPOSE_OP_PATTERNS
    | POOL_OP_PATTERNS
    | BMM_OP_PATTERNS
)

0 commit comments

Comments
 (0)