diff --git a/backends/arm/_passes/decompose_masked_fill_pass.py b/backends/arm/_passes/decompose_masked_fill_pass.py index 09a3492a0c6..49a4bbb9b4b 100644 --- a/backends/arm/_passes/decompose_masked_fill_pass.py +++ b/backends/arm/_passes/decompose_masked_fill_pass.py @@ -17,7 +17,7 @@ edge_ops = (exir_ops.edge.aten.masked_fill.Scalar,) -aten_ops = (torch.ops.aten.masked_fill.Scalar,) +aten_ops = (torch.ops.aten.masked_fill.Scalar, torch.ops.aten.masked_fill_.Scalar) def _get_decomposition(op) -> tuple: @@ -26,7 +26,7 @@ def _get_decomposition(op) -> tuple: exir_ops.edge.aten.where.self, exir_ops.edge.aten.full_like.default, ) - if op in aten_ops: + elif op in aten_ops: return ( torch.ops.aten.where.self, torch.ops.aten.full_like.default, @@ -44,7 +44,7 @@ class DecomposeMaskedFillPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = {ConvertFullLikeToFullPass} def call_operator(self, op, args, kwargs, meta, updated=False): - if op not in (edge_ops + aten_ops): + if op not in (*aten_ops, *edge_ops): return super().call_operator(op, args, kwargs, meta, updated) x, mask, scalar = args diff --git a/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py b/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py index d41007e1e76..938732fa91a 100644 --- a/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py +++ b/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py @@ -45,9 +45,7 @@ class TestCLIPTextModelWithProjection: "executorch_exir_dialects_edge__ops_aten_index_select_default": 1, "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor": 1, "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1, - "executorch_exir_dialects_edge__ops_aten_where_self": 1, "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2, - "torch.ops.aten.scalar_tensor.default": 1, "torch.ops.higher_order.executorch_call_delegate": 2, } diff --git a/backends/arm/test/models/test_conformer.py b/backends/arm/test/models/test_conformer.py index 85ac2733e70..e04d8bd44a5 100644 --- a/backends/arm/test/models/test_conformer.py +++ b/backends/arm/test/models/test_conformer.py @@ -72,15 +72,8 @@ def test_conformer_tosa_INT(): aten_op=[], exir_op=[], use_to_edge_transform_and_lower=True, - ) - pipeline.pop_stage("check_count.exir") - pipeline.change_args( - "run_method_and_compare_outputs", - get_test_inputs( - TestConformer.dim, TestConformer.lengths, TestConformer.num_examples - ), - rtol=TestConformer.rtol, atol=TestConformer.atol, + rtol=TestConformer.rtol, ) pipeline.run() @@ -93,38 +86,26 @@ def test_conformer_u55_INT(): pipeline = EthosU55PipelineINT[input_t]( TestConformer.conformer, TestConformer.model_example_inputs, - aten_ops=TestConformer.aten_ops, + aten_ops=[], exir_ops=[], use_to_edge_transform_and_lower=True, + atol=TestConformer.atol, + rtol=TestConformer.rtol, ) - pipeline.change_args( - "run_method_and_compare_outputs", - get_test_inputs( - TestConformer.dim, TestConformer.lengths, TestConformer.num_examples - ), - rtol=1.0, - atol=5.0, - ) + pipeline.pop_stage("check_count.exir") pipeline.run() @common.XfailIfNoCorstone320 -@pytest.mark.xfail(reason="All IO needs to have the same data type (MLETORCH-635)") def test_conformer_u85_INT(): pipeline = EthosU85PipelineINT[input_t]( TestConformer.conformer, TestConformer.model_example_inputs, - aten_ops=TestConformer.aten_ops, + aten_ops=[], exir_ops=[], use_to_edge_transform_and_lower=True, - ) - pipeline.change_args( - "run_method_and_compare_outputs", - get_test_inputs( - TestConformer.dim, TestConformer.lengths, TestConformer.num_examples - ), - rtol=1.0, - atol=5.0, + atol=TestConformer.atol, + rtol=TestConformer.rtol, ) pipeline.run() @@ -137,16 +118,9 @@ def test_conformer_vgf_quant(): aten_op=[], exir_op=[], use_to_edge_transform_and_lower=True, - quantize=True, - ) - pipeline.pop_stage("check_count.exir") - pipeline.change_args( - "run_method_and_compare_outputs", - get_test_inputs( - TestConformer.dim, TestConformer.lengths, TestConformer.num_examples - ), - rtol=TestConformer.rtol, atol=TestConformer.atol, + rtol=TestConformer.rtol, + quantize=True, ) pipeline.run()