diff --git a/backends/apple/mps/partition/mps_partitioner.py b/backends/apple/mps/partition/mps_partitioner.py index e5497389d14..84472022f6a 100644 --- a/backends/apple/mps/partition/mps_partitioner.py +++ b/backends/apple/mps/partition/mps_partitioner.py @@ -9,7 +9,6 @@ import torch from executorch.backends.apple.mps.mps_preprocess import MPSBackend from executorch.backends.apple.mps.operators.node_visitor import get_node_visitors -from executorch.backends.apple.mps.utils.mps_utils import is_parameter from executorch.backends.transforms import get_shape from executorch.exir.backend.backend_details import CompileSpec from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import ( @@ -43,12 +42,6 @@ def __init__(self, edge_program: torch.export.ExportedProgram, compiler_specs): self.edge_program = edge_program def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: - # Parameters are supported if any of their users are supported - if is_parameter(self.edge_program, node): - return any( - self.is_node_supported(submodules, user) for user in node.users.keys() - ) - if node.op != "call_function": return False @@ -132,6 +125,7 @@ def partition(self, edge_program: ExportedProgram) -> PartitionResult: partitions = self.generate_partitions(edge_program=edge_program) if self.check_partitions(partitions): self.tag_nodes(partitions) + # Tag constant data that are used by the supported ops in MPS backend. tag_constant_data(edge_program) x = PartitionResult( tagged_exported_program=edge_program, partition_tags=self.partition_tags