[mlir] DialectConversion: support block creation in ConversionPatternRewriter

PatternRewriter and derived classes provide a set of virtual methods to
manipulate blocks, which ConversionPatternRewriter overrides to keep track of
the manipulations and undo them in case the conversion fails. However, one can
currently create a block only by splitting another block into two. This not
only makes the API inconsistent (`splitBlock` is allowed in conversion
patterns, but `createBlock` is not), but it also make it impossible for one to
create blocks with argument lists different from those of already existing
blocks since in-place block updates are not supported either. Such
functionality precludes dialect conversion infrastructure from being used more
extensively on region-containing ops, for example, for value-returning "if"
operations. At the same time, ConversionPatternRewriter already allows one to
undo block creation as block creation is one of the primitive operations in
already supported region inlining.

Support block creation in conversion patterns by hooking `createBlock` on the
block action undo mechanism. This requires to make `Builder::createBlock`
virtual, similarly to Op insertion. This is a minimal change to the Builder
infrastructure that will later help support additional use cases such as block
signature changes. `createBlock` now additionally takes the types of the block
arguments that are added immediately so as to avoid in-place argument list
manipulation that would be illegal in conversion patterns.
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 1c6b16f..75f49e8 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -298,13 +298,15 @@
   /// Insert the given operation at the current insertion point and return it.
   virtual Operation *insert(Operation *op);
 
-  /// Add new block and set the insertion point to the end of it. The block is
-  /// inserted at the provided insertion point of 'parent'.
-  Block *createBlock(Region *parent, Region::iterator insertPt = {});
+  /// Add new block with 'argTypes' arguments and set the insertion point to the
+  /// end of it. The block is inserted at the provided insertion point of
+  /// 'parent'.
+  virtual Block *createBlock(Region *parent, Region::iterator insertPt = {},
+                             TypeRange argTypes = llvm::None);
 
-  /// Add new block and set the insertion point to the end of it. The block is
-  /// placed before 'insertBefore'.
-  Block *createBlock(Block *insertBefore);
+  /// Add new block with 'argTypes' arguments and set the insertion point to the
+  /// end of it. The block is placed before 'insertBefore'.
+  Block *createBlock(Block *insertBefore, TypeRange argTypes = llvm::None);
 
   /// Returns the current block of the builder.
   Block *getBlock() const { return block; }
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 7760073..9ab3a71 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -344,6 +344,10 @@
   /// otherwise an assert will be issued.
   void eraseOp(Operation *op) override;
 
+  /// PatternRewriter hook for creating a new block with the given arguments.
+  Block *createBlock(Region *parent, Region::iterator insertPt = {},
+                     TypeRange argTypes = llvm::None) override;
+
   /// PatternRewriter hook for splitting a block into two parts.
   Block *splitBlock(Block *block, Block::iterator before) override;
 
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 2353665..c8d5ea6 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -339,24 +339,28 @@
   return op;
 }
 
-/// Add new block and set the insertion point to the end of it. The block is
-/// inserted at the provided insertion point of 'parent'.
-Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt) {
+/// Add new block with 'argTypes' arguments and set the insertion point to the
+/// end of it. The block is inserted at the provided insertion point of
+/// 'parent'.
+Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
+                              TypeRange argTypes) {
   assert(parent && "expected valid parent region");
   if (insertPt == Region::iterator())
     insertPt = parent->end();
 
   Block *b = new Block();
+  b->addArguments(argTypes);
   parent->getBlocks().insert(insertPt, b);
   setInsertionPointToEnd(b);
   return b;
 }
 
-/// Add new block and set the insertion point to the end of it.  The block is
-/// placed before 'insertBefore'.
-Block *OpBuilder::createBlock(Block *insertBefore) {
+/// Add new block with 'argTypes' arguments and set the insertion point to the
+/// end of it.  The block is placed before 'insertBefore'.
+Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes) {
   assert(insertBefore && "expected valid insertion block");
-  return createBlock(insertBefore->getParent(), Region::iterator(insertBefore));
+  return createBlock(insertBefore->getParent(), Region::iterator(insertBefore),
+                     argTypes);
 }
 
 /// Create an operation given the fields represented as an OperationState.
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index 19304b3..725f5f4 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -585,6 +585,9 @@
   /// PatternRewriter hook for replacing the results of an operation.
   void replaceOp(Operation *op, ValueRange newValues);
 
+  /// Notifies that a block was created.
+  void notifyCreatedBlock(Block *block);
+
   /// Notifies that a block was split.
   void notifySplitBlock(Block *block, Block *continuation);
 
@@ -804,6 +807,10 @@
   markNestedOpsIgnored(op);
 }
 
+void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) {
+  blockActions.push_back(BlockAction::getCreate(block));
+}
+
 void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
                                                      Block *continuation) {
   blockActions.push_back(BlockAction::getSplit(continuation, block));
@@ -910,6 +917,15 @@
   return impl->mapping.lookupOrDefault(key);
 }
 
+/// PatternRewriter hook for creating a new block with the given arguments.
+Block *ConversionPatternRewriter::createBlock(Region *parent,
+                                              Region::iterator insertPtr,
+                                              TypeRange argTypes) {
+  Block *block = PatternRewriter::createBlock(parent, insertPtr, argTypes);
+  impl->notifyCreatedBlock(block);
+  return block;
+}
+
 /// PatternRewriter hook for splitting a block into two parts.
 Block *ConversionPatternRewriter::splitBlock(Block *block,
                                              Block::iterator before) {
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index bd73cf3..3305e01 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -130,6 +130,19 @@
   return %0 : i32
 }
 
+// CHECK-LABEL: @create_block
+func @create_block() {
+  "test.container"() ({
+    // Check that we created a block with arguments.
+    // CHECK-NOT: test.create_block
+    // CHECK: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32):
+    // CHECK: test.finish
+    "test.create_block"() : () -> ()
+    "test.finish"() : () -> ()
+  }) : () -> ()
+  return
+}
+
 // -----
 
 func @fail_to_convert_illegal_op() -> i32 {
@@ -163,3 +176,17 @@
   }) : () -> ()
   return
 }
+
+// -----
+
+// CHECK-LABEL: @create_illegal_block
+func @create_illegal_block() {
+  "test.container"() ({
+    // Check that we can undo block creation, i.e. that the block was removed.
+    // CHECK: test.create_illegal_block
+    // CHECK-NOT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32):
+    "test.create_illegal_block"() : () -> ()
+    "test.finish"() : () -> ()
+  }) : () -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 0b73f09..23d650e 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -183,6 +183,41 @@
     return success();
   }
 };
+/// A simple pattern that creates a block at the end of the parent region of the
+/// matched operation.
+struct TestCreateBlock : public RewritePattern {
+  TestCreateBlock(MLIRContext *ctx)
+      : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const final {
+    Region &region = *op->getParentRegion();
+    Type i32Type = rewriter.getIntegerType(32);
+    rewriter.createBlock(&region, region.end(), {i32Type, i32Type});
+    rewriter.create<TerminatorOp>(op->getLoc());
+    rewriter.replaceOp(op, {});
+    return success();
+  }
+};
+
+/// A simple pattern that creates a block containing an invalid operaiton in
+/// order to trigger the block creation undo mechanism.
+struct TestCreateIllegalBlock : public RewritePattern {
+  TestCreateIllegalBlock(MLIRContext *ctx)
+      : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const final {
+    Region &region = *op->getParentRegion();
+    Type i32Type = rewriter.getIntegerType(32);
+    rewriter.createBlock(&region, region.end(), {i32Type, i32Type});
+    // Create an illegal op to ensure the conversion fails.
+    rewriter.create<ILLegalOpF>(op->getLoc(), i32Type);
+    rewriter.create<TerminatorOp>(op->getLoc());
+    rewriter.replaceOp(op, {});
+    return success();
+  }
+};
 
 //===----------------------------------------------------------------------===//
 // Type-Conversion Rewrite Testing
@@ -373,12 +408,12 @@
     TestTypeConverter converter;
     mlir::OwningRewritePatternList patterns;
     populateWithGenerated(&getContext(), &patterns);
-    patterns
-        .insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
-                TestPassthroughInvalidOp, TestSplitReturnType,
-                TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
-                TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
-                TestNonRootReplacement>(&getContext());
+    patterns.insert<
+        TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
+        TestCreateIllegalBlock, TestPassthroughInvalidOp, TestSplitReturnType,
+        TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
+        TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
+        TestNonRootReplacement>(&getContext());
     patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
     mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
                                               converter);
@@ -388,7 +423,8 @@
     // Define the conversion target used for the test.
     ConversionTarget target(getContext());
     target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
-    target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp>();
+    target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp,
+                      TerminatorOp>();
     target
         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {