From 4c59916654c432fe80a44522d537ab17fc6f5e70 Mon Sep 17 00:00:00 2001 From: Milos Poletanovic Date: Fri, 12 Dec 2025 13:45:16 +0100 Subject: [PATCH] [MLIR][Transform] Safely erase transform ops by collecting first Avoids runtime crashes caused by deleting operations inside a walk. Operations are gathered during the walk and then erased in the correct dependency order after the walk finishes. --- .../transform-dialect-erase-schedule.mlir | 15 +++++++++++++++ .../Transform/TestTransformDialectInterpreter.cpp | 7 ++++++- 2 files changed, 21 insertions(+), 1 deletion(-) create mode 100644 mlir/test/Dialect/Transform/transform-dialect-erase-schedule.mlir diff --git a/mlir/test/Dialect/Transform/transform-dialect-erase-schedule.mlir b/mlir/test/Dialect/Transform/transform-dialect-erase-schedule.mlir new file mode 100644 index 0000000000000..a258568d8679b --- /dev/null +++ b/mlir/test/Dialect/Transform/transform-dialect-erase-schedule.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt %s -test-transform-dialect-erase-schedule | FileCheck %s + +module attributes {transform.with_named_sequence} { + func.func @transform_example(%arg0: !transform.any_op) { + %transform_copy = transform.structured.match ops{["linalg.copy"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.nvgpu.rewrite_copy_as_tma %transform_copy : (!transform.any_op) -> () + transform.yield + } +} + +// CHECK-LABEL: module attributes {transform.with_named_sequence} { +// CHECK-NEXT: func.func @transform_example(%arg0: !transform.any_op) { +// CHECK-NEXT: transform.yield +// CHECK-NEXT: } +// CHECK-NEXT: } \ No newline at end of file diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp index 1273414cd4dfc..bd2de68fb276d 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp @@ -36,13 +36,18 @@ struct TestTransformDialectEraseSchedulePass } void runOnOperation() override { + SmallVector opsToDelete; getOperation()->walk([&](Operation *nestedOp) { if (isa(nestedOp)) { - nestedOp->erase(); + opsToDelete.push_back(nestedOp); return WalkResult::skip(); } return WalkResult::advance(); }); + for (Operation *op : llvm::reverse(opsToDelete)) { + // erase the operation + op->erase(); + } } }; } // namespace