Skip to content

Commit 87d828a

Browse files
mcr229 authored and facebook-github-bot committed
don't partition max pool with ceil mode (#3578)
Summary: Pull Request resolved: #3578 XNNPACK doesn't support max pooling with ceil mode, so we should not be partitioning these nodes where ceil mode is True Resolving this issue: #3567 Reviewed By: mergennachin, digantdesai Differential Revision: D57228128 fbshipit-source-id: ee57a783d314d69ebef57f0e1707c0d038582a31
1 parent e288039 commit 87d828a

2 files changed

Lines changed: 41 additions & 0 deletions

File tree

backends/xnnpack/partition/xnnpack_partitioner.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ def _check_outputs_are_valid_dtypes(self, node, valid_dtypes):
166166
return True
167167

168168
def check_node_has_valid_dtype(self, node):
169+
# max_pool2d_with_indices returns indices which is int64
170+
# this is supportable within XNNPACK
169171
if node.target in {exir_ops.edge.aten.max_pool2d_with_indices.default}:
170172
return True
171173

@@ -268,13 +270,16 @@ def maxpool2d_with_indices(
268270
) -> bool:
269271
"""
270272
Only if the first output value is consumed in the graph
273+
and it is not in ceil mode
271274
"""
272275
users = list(node.users.keys())
276+
is_ceil_mode = len(node.args) >= 6 and node.args[5]
273277
return (
274278
True
275279
if len(users) == 1
276280
and users[0].target == operator.getitem
277281
and users[0].args[1] == 0
282+
and not is_ceil_mode
278283
else False
279284
)
280285

backends/xnnpack/test/ops/maxpool2d.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@ def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1):
3838
def forward(self, x):
3939
return self.max_pool2d_module(x)[1]
4040

41+
class MaxPool2dUnsupportedCeilMode(torch.nn.Module):
42+
def __init__(self):
43+
super().__init__()
44+
self.max_pool2d_module = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
45+
46+
def forward(self, x):
47+
return self.max_pool2d_module(x)
48+
4149
def _test_maxpool2d(self, inputs):
4250
"""
4351
Note that the export process generates aten.max_pool2d_with_indices. The remove_getitem_op
@@ -99,6 +107,34 @@ def test_fp32_maxpool2d_unsupported(self):
99107
)
100108
)
101109

110+
def test_fp32_maxpool2d_unsupported_ceilmode(self):
111+
"""
112+
MaxPool2d with ceil mode is not generally supported (see maxpool2d constraint).
113+
"""
114+
inputs = (torch.randn(1, 32, 23, 23),)
115+
(
116+
Tester(self.MaxPool2dUnsupportedCeilMode(), inputs)
117+
.export()
118+
.check_count({"torch.ops.aten.max_pool2d_with_indices.default": 1})
119+
.to_edge()
120+
.check_count(
121+
{
122+
"executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1
123+
}
124+
)
125+
.partition()
126+
# We expect it not to be delegated.
127+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 0})
128+
.check_count(
129+
{
130+
"executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1
131+
}
132+
)
133+
.to_executorch()
134+
.serialize()
135+
.run_method_and_compare_outputs()
136+
)
137+
102138
def test_qs8_maxpool2d(self):
103139
class MaxPool(torch.nn.Module):
104140
def __init__(self, maxpool_params):

0 commit comments

Comments
 (0)