[mlir] Added the dialect inliner to the SCF dialect
Currently no restrictions are added to the destination regions.
Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, aartbik, stephenneuendorffer, Joonsoo, grosul1, Kayjukh, jurahul, msifontes
Tags: #mlir
Differential Revision: https://reviews.llvm.org/D82336
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index f980cdb..559a5f1 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -19,11 +19,46 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/MathExtras.h"
+#include "mlir/Transforms/InliningUtils.h"
using namespace mlir;
using namespace mlir::scf;
//===----------------------------------------------------------------------===//
+// SCFDialect Dialect Interfaces
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct SCFInlinerInterface : public DialectInlinerInterface {
+ using DialectInlinerInterface::DialectInlinerInterface;
+ // We don't have any special restrictions on what can be inlined into
+ // destination regions (e.g. while/conditional bodies). Always allow it.
+ bool isLegalToInline(Region *dest, Region *src,
+ BlockAndValueMapping &valueMapping) const final {
+ return true;
+ }
+ // Operations in scf dialect are always legal to inline since they are
+ // pure.
+ bool isLegalToInline(Operation *, Region *,
+ BlockAndValueMapping &) const final {
+ return true;
+ }
+ // Handle the given inlined terminator by replacing it with a new operation
+ // as necessary. Required when the region has only one block.
+ void handleTerminator(Operation *op,
+ ArrayRef<Value> valuesToRepl) const final {
+ auto retValOp = dyn_cast<YieldOp>(op);
+ if (!retValOp)
+ return;
+
+ for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
+ std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
+ }
+ }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
// SCFDialect
//===----------------------------------------------------------------------===//
@@ -33,6 +68,7 @@
#define GET_OP_LIST
#include "mlir/Dialect/SCF/SCFOps.cpp.inc"
>();
+ addInterfaces<SCFInlinerInterface>();
}
/// Default callback for IfOp builders. Inserts a yield without arguments.