From 25eae446d14f5d54693d4f590ec88dcad8d3672a Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Mon, 6 May 2024 21:01:33 -0700 Subject: [PATCH] fix constant tagging in mps backend (#3503) Summary: Test with https://github.com/pytorch/executorch/pull/3399 and this command passes ``` python -m examples.models.llama2.export_llama -kv --mps ``` Without this diff, it will error out ``` in _verify_exported_program_signature raise SpecViolationError( torch._export.verifier.SpecViolationError: Buffer output getitem_1 does not point to a buffer that exists. Dict of buffers that are mutated, in order: {'getitem_1': 'layers_0_attention_SDPA_kv_cache_k_cache', 'getitem': 'layers_0_attention_SDPA_kv_cache_v_cache', 'getitem_3': 'layers_1_attention_SDPA_kv_cache_k_cache', 'getitem_2': 'layers_1_attention_SDPA_kv_cache_v_cache', 'getitem_5': 'layers_2_attention_SDPA_kv_cache_k_cache', 'getitem_4': 'layers_2_attention_SDPA_kv_cache_v_cache', 'getitem_7': 'layers_3_attention_SDPA_kv_cache_k_cache', 'getitem_6': 'layers_3_attention_SDPA_kv_cache_v_cache', 'getitem_9': 'layers_4_attention_SDPA_kv_cache_k_cache', 'getitem_8': 'layers_4_attention_SDPA_kv_cache_v_cache'} Buffer nodes available: [] ``` The root cause is that by `is_parameter`, it tags all data including mutable buffers. Reviewed By: larryliu0820 Differential Revision: D56941763 --- backends/apple/mps/partition/mps_partitioner.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/backends/apple/mps/partition/mps_partitioner.py b/backends/apple/mps/partition/mps_partitioner.py index e5497389d14..84472022f6a 100644 --- a/backends/apple/mps/partition/mps_partitioner.py +++ b/backends/apple/mps/partition/mps_partitioner.py @@ -9,7 +9,6 @@ import torch from executorch.backends.apple.mps.mps_preprocess import MPSBackend from executorch.backends.apple.mps.operators.node_visitor import get_node_visitors -from executorch.backends.apple.mps.utils.mps_utils import is_parameter from executorch.backends.transforms import get_shape from executorch.exir.backend.backend_details import CompileSpec from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import ( @@ -43,12 +42,6 @@ def __init__(self, edge_program: torch.export.ExportedProgram, compiler_specs): self.edge_program = edge_program def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: - # Parameters are supported if any of their users are supported - if is_parameter(self.edge_program, node): - return any( - self.is_node_supported(submodules, user) for user in node.users.keys() - ) - if node.op != "call_function": return False @@ -132,6 +125,7 @@ def partition(self, edge_program: ExportedProgram) -> PartitionResult: partitions = self.generate_partitions(edge_program=edge_program) if self.check_partitions(partitions): self.tag_nodes(partitions) + # Tag constant data that are used by the supported ops in MPS backend. tag_constant_data(edge_program) x = PartitionResult( tagged_exported_program=edge_program, partition_tags=self.partition_tags