diff --git a/backends/apple/coreml/partition/coreml_partitioner.py b/backends/apple/coreml/partition/coreml_partitioner.py
index c0b6663f729..99aa2a0a60e 100644
--- a/backends/apple/coreml/partition/coreml_partitioner.py
+++ b/backends/apple/coreml/partition/coreml_partitioner.py
@@ -3,7 +3,7 @@
 # Please refer to the license found in the LICENSE file in the root directory of the source tree.
 
 import logging
-from typing import List, Optional
+from typing import Callable, List, Optional, Tuple
 
 import coremltools as ct
 
@@ -104,3 +104,17 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         return PartitionResult(
             tagged_exported_program=exported_program, partition_tags=partition_tags
         )
+
+    def ops_to_not_decompose(
+        self, ep: ExportedProgram
+    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
+        do_not_decompose = []
+        op_support = OperatorsSupportedForCoreMLBackend()
+        for node in ep.graph.nodes:
+            if (
+                node.op == "call_function"
+                and isinstance(node.target, torch._ops.OpOverload)
+                and op_support.is_node_supported(None, node)
+            ):
+                do_not_decompose.append(node.target)
+        return do_not_decompose, None
diff --git a/backends/apple/coreml/test/test_coreml_partitioner.py b/backends/apple/coreml/test/test_coreml_partitioner.py
index 2b84558a0f0..03aac6a8611 100644
--- a/backends/apple/coreml/test/test_coreml_partitioner.py
+++ b/backends/apple/coreml/test/test_coreml_partitioner.py
@@ -13,6 +13,7 @@
 
 from executorch.backends.apple.coreml.compiler import CoreMLBackend
 from executorch.backends.apple.coreml.partition import CoreMLPartitioner
+from executorch.exir.backend.utils import format_delegated_graph
 
 
 class TestCoreMLPartitioner(unittest.TestCase):
@@ -79,6 +80,50 @@ def test_vit_skip_conv(self):
             "getitem",
         ]
 
+    def test_ops_to_not_decompose(self):
+        class Model(torch.nn.Module):
+            def forward(self, q, k, v, mask):
+                return torch.ops.aten.scaled_dot_product_attention.default(
+                    q, k, v, attn_mask=mask
+                )
+
+        model = Model()
+        model.eval()
+
+        batch_size = 1
+        n_heads = 12
+        seq_len = 1
+        max_seq_length = 32
+        embedding_dim = 16
+        q = torch.randn(batch_size, n_heads, seq_len, embedding_dim)
+        k = torch.randn(batch_size, n_heads, max_seq_length, embedding_dim)
+        v = torch.randn(batch_size, n_heads, max_seq_length, embedding_dim)
+        mask = torch.randn(seq_len, max_seq_length)
+        example_inputs = (q, k, v, mask)
+        ep = torch.export.export(model, example_inputs)
+        coreml_partitioner = CoreMLPartitioner()
+
+        # Using to_edge_transform_and_lower, we expect SDPA will be preserved and show up in delegated graph
+        edge_program_manager = executorch.exir.to_edge_transform_and_lower(
+            ep, partitioner=[coreml_partitioner]
+        )
+        self.assertTrue(
+            "executorch.exir.dialects.edge._ops.aten.scaled_dot_product_attention.default"
+            in format_delegated_graph(
+                edge_program_manager.exported_program().graph_module
+            )
+        )
+
+        # Using to_edge flow, we expect SDPA will be decomposed (to_backend returns a new manager, so rebind it)
+        edge_program_manager2 = executorch.exir.to_edge(ep)
+        edge_program_manager2 = edge_program_manager2.to_backend(coreml_partitioner)
+        self.assertTrue(
+            "executorch.exir.dialects.edge._ops.aten.scaled_dot_product_attention.default"
+            not in format_delegated_graph(
+                edge_program_manager2.exported_program().graph_module
+            )
+        )
+
     def test_buffer(self):
         embedding_dim = 3
         max_seq_len = 2
@@ -129,4 +174,5 @@ def forward(self, q, k_val, input_pos):
     test_runner = TestCoreMLPartitioner()
     test_runner.test_add_sub_skip_mm()
     test_runner.test_vit_skip_conv()
+    test_runner.test_ops_to_not_decompose()
     test_runner.test_buffer()