[mlir][SMT] restore custom builder for forall/exists (#135470)

This reverts commit 54e70ac7650f1c22f687937d1a082e4152f97b22 which
itself fixed an [asan
leak](https://lab.llvm.org/buildbot/#/builders/55/builds/9761) from the
original upstreaming commit. The leak was due to op allocations not
being `free`ed.

~~The necessary change was to explicitly `->destroy()` the ops at the
end of the tests. I believe this is because the rewriter used in the
tests doesn't actually insert them into a module and so without an
explicit `->destroy()` no bookkeeping process is able to take care of
them.~~

The necessary change was to use `OwningOpRef` which calls `op->erase()`
in its [own
destructor](https://github.com/makslevental/llvm-project/blob/89cfae41ecc043f8c47be4dea4b7c740d4f950b3/mlir/include/mlir/IR/OwningOpRef.h#L39).
diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
index af73955..1872c00 100644
--- a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
+++ b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
@@ -448,6 +448,18 @@
                         VariadicRegion<SizedRegion<1>>:$patterns);
   let results = (outs BoolType:$result);
 
+  let builders = [
+    OpBuilder<(ins
+      "TypeRange":$boundVarTypes,
+      "function_ref<Value(OpBuilder &, Location, ValueRange)>":$bodyBuilder,
+      CArg<"std::optional<ArrayRef<StringRef>>", "std::nullopt">:$boundVarNames,
+      CArg<"function_ref<ValueRange(OpBuilder &, Location, ValueRange)>",
+           "{}">:$patternBuilder,
+      CArg<"uint32_t", "0">:$weight,
+      CArg<"bool", "false">:$noPattern)>
+  ];
+  let skipDefaultBuilders = true;
+
   let assemblyFormat = [{
     ($boundVarNames^)? (`no_pattern` $noPattern^)? (`weight` $weight^)?
     attr-dict-with-keyword $body (`patterns` $patterns^)?
diff --git a/mlir/lib/Dialect/SMT/IR/SMTOps.cpp b/mlir/lib/Dialect/SMT/IR/SMTOps.cpp
index 604dd26..8977a3a 100644
--- a/mlir/lib/Dialect/SMT/IR/SMTOps.cpp
+++ b/mlir/lib/Dialect/SMT/IR/SMTOps.cpp
@@ -432,6 +432,16 @@
   return verifyQuantifierRegions(*this);
 }
 
+void ForallOp::build(
+    OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
+    function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
+    std::optional<ArrayRef<StringRef>> boundVarNames,
+    function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
+    uint32_t weight, bool noPattern) {
+  buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder,
+                              boundVarNames, patternBuilder, weight, noPattern);
+}
+
 //===----------------------------------------------------------------------===//
 // ExistsOp
 //===----------------------------------------------------------------------===//
@@ -448,5 +458,15 @@
   return verifyQuantifierRegions(*this);
 }
 
+void ExistsOp::build(
+    OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
+    function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
+    std::optional<ArrayRef<StringRef>> boundVarNames,
+    function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
+    uint32_t weight, bool noPattern) {
+  buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder,
+                              boundVarNames, patternBuilder, weight, noPattern);
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/SMT/IR/SMT.cpp.inc"
diff --git a/mlir/unittests/Dialect/SMT/CMakeLists.txt b/mlir/unittests/Dialect/SMT/CMakeLists.txt
index 86e16d6..a133146 100644
--- a/mlir/unittests/Dialect/SMT/CMakeLists.txt
+++ b/mlir/unittests/Dialect/SMT/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_unittest(MLIRSMTTests
   AttributeTest.cpp
+  QuantifierTest.cpp
   TypeTest.cpp
 )
 
diff --git a/mlir/unittests/Dialect/SMT/QuantifierTest.cpp b/mlir/unittests/Dialect/SMT/QuantifierTest.cpp
new file mode 100644
index 0000000..d7c57f0
--- /dev/null
+++ b/mlir/unittests/Dialect/SMT/QuantifierTest.cpp
@@ -0,0 +1,187 @@
+//===- QuantifierTest.cpp - SMT quantifier operation unit tests -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SMT/IR/SMTOps.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace smt;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Test custom builders of ExistsOp
+//===----------------------------------------------------------------------===//
+
+TEST(QuantifierTest, ExistsBuilderWithPattern) {
+  MLIRContext context;
+  context.loadDialect<SMTDialect>();
+  Location loc(UnknownLoc::get(&context));
+
+  OpBuilder builder(&context);
+  auto boolTy = BoolType::get(&context);
+
+  OwningOpRef<ExistsOp> existsOp = builder.create<ExistsOp>(
+      loc, TypeRange{boolTy, boolTy},
+      [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+        return builder.create<AndOp>(loc, boundVars);
+      },
+      std::nullopt,
+      [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+        return boundVars;
+      },
+      /*weight=*/2);
+
+  SmallVector<char, 1024> buffer;
+  llvm::raw_svector_ostream stream(buffer);
+  existsOp->print(stream);
+
+  ASSERT_STREQ(
+      stream.str().str().c_str(),
+      "%0 = smt.exists weight 2 {\n^bb0(%arg0: !smt.bool, "
+      "%arg1: !smt.bool):\n  %0 = smt.and %arg0, %arg1\n  smt.yield %0 : "
+      "!smt.bool\n} patterns {\n^bb0(%arg0: !smt.bool, %arg1: !smt.bool):\n  "
+      "smt.yield %arg0, %arg1 : !smt.bool, !smt.bool\n}\n");
+}
+
+TEST(QuantifierTest, ExistsBuilderNoPattern) {
+  MLIRContext context;
+  context.loadDialect<SMTDialect>();
+  Location loc(UnknownLoc::get(&context));
+
+  OpBuilder builder(&context);
+  auto boolTy = BoolType::get(&context);
+
+  OwningOpRef<ExistsOp> existsOp = builder.create<ExistsOp>(
+      loc, TypeRange{boolTy, boolTy},
+      [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+        return builder.create<AndOp>(loc, boundVars);
+      },
+      ArrayRef<StringRef>{"a", "b"}, nullptr, /*weight=*/0, /*noPattern=*/true);
+
+  SmallVector<char, 1024> buffer;
+  llvm::raw_svector_ostream stream(buffer);
+  existsOp->print(stream);
+
+  ASSERT_STREQ(stream.str().str().c_str(),
+               "%0 = smt.exists [\"a\", \"b\"] no_pattern {\n^bb0(%arg0: "
+               "!smt.bool, %arg1: !smt.bool):\n  %0 = smt.and %arg0, %arg1\n  "
+               "smt.yield %0 : !smt.bool\n}\n");
+}
+
+TEST(QuantifierTest, ExistsBuilderDefault) {
+  MLIRContext context;
+  context.loadDialect<SMTDialect>();
+  Location loc(UnknownLoc::get(&context));
+
+  OpBuilder builder(&context);
+  auto boolTy = BoolType::get(&context);
+
+  OwningOpRef<ExistsOp> existsOp = builder.create<ExistsOp>(
+      loc, TypeRange{boolTy, boolTy},
+      [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+        return builder.create<AndOp>(loc, boundVars);
+      },
+      ArrayRef<StringRef>{"a", "b"});
+
+  SmallVector<char, 1024> buffer;
+  llvm::raw_svector_ostream stream(buffer);
+  existsOp->print(stream);
+
+  ASSERT_STREQ(stream.str().str().c_str(),
+               "%0 = smt.exists [\"a\", \"b\"] {\n^bb0(%arg0: !smt.bool, "
+               "%arg1: !smt.bool):\n  %0 = smt.and %arg0, %arg1\n  smt.yield "
+               "%0 : !smt.bool\n}\n");
+}
+
+//===----------------------------------------------------------------------===//
+// Test custom builders of ForallOp
+//===----------------------------------------------------------------------===//
+
+TEST(QuantifierTest, ForallBuilderWithPattern) {
+  MLIRContext context;
+  context.loadDialect<SMTDialect>();
+  Location loc(UnknownLoc::get(&context));
+
+  OpBuilder builder(&context);
+  auto boolTy = BoolType::get(&context);
+
+  OwningOpRef<ForallOp> forallOp = builder.create<ForallOp>(
+      loc, TypeRange{boolTy, boolTy},
+      [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+        return builder.create<AndOp>(loc, boundVars);
+      },
+      ArrayRef<StringRef>{"a", "b"},
+      [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+        return boundVars;
+      },
+      /*weight=*/2);
+
+  SmallVector<char, 1024> buffer;
+  llvm::raw_svector_ostream stream(buffer);
+  forallOp->print(stream);
+
+  ASSERT_STREQ(
+      stream.str().str().c_str(),
+      "%0 = smt.forall [\"a\", \"b\"] weight 2 {\n^bb0(%arg0: !smt.bool, "
+      "%arg1: !smt.bool):\n  %0 = smt.and %arg0, %arg1\n  smt.yield %0 : "
+      "!smt.bool\n} patterns {\n^bb0(%arg0: !smt.bool, %arg1: !smt.bool):\n  "
+      "smt.yield %arg0, %arg1 : !smt.bool, !smt.bool\n}\n");
+}
+
+TEST(QuantifierTest, ForallBuilderNoPattern) {
+  MLIRContext context;
+  context.loadDialect<SMTDialect>();
+  Location loc(UnknownLoc::get(&context));
+
+  OpBuilder builder(&context);
+  auto boolTy = BoolType::get(&context);
+
+  OwningOpRef<ForallOp> forallOp = builder.create<ForallOp>(
+      loc, TypeRange{boolTy, boolTy},
+      [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+        return builder.create<AndOp>(loc, boundVars);
+      },
+      ArrayRef<StringRef>{"a", "b"}, nullptr, /*weight=*/0, /*noPattern=*/true);
+
+  SmallVector<char, 1024> buffer;
+  llvm::raw_svector_ostream stream(buffer);
+  forallOp->print(stream);
+
+  ASSERT_STREQ(stream.str().str().c_str(),
+               "%0 = smt.forall [\"a\", \"b\"] no_pattern {\n^bb0(%arg0: "
+               "!smt.bool, %arg1: !smt.bool):\n  %0 = smt.and %arg0, %arg1\n  "
+               "smt.yield %0 : !smt.bool\n}\n");
+}
+
+TEST(QuantifierTest, ForallBuilderDefault) {
+  MLIRContext context;
+  context.loadDialect<SMTDialect>();
+  Location loc(UnknownLoc::get(&context));
+
+  OpBuilder builder(&context);
+  auto boolTy = BoolType::get(&context);
+
+  OwningOpRef<ForallOp> forallOp = builder.create<ForallOp>(
+      loc, TypeRange{boolTy, boolTy},
+      [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+        return builder.create<AndOp>(loc, boundVars);
+      },
+      std::nullopt);
+
+  SmallVector<char, 1024> buffer;
+  llvm::raw_svector_ostream stream(buffer);
+  forallOp->print(stream);
+
+  ASSERT_STREQ(stream.str().str().c_str(),
+               "%0 = smt.forall {\n^bb0(%arg0: !smt.bool, "
+               "%arg1: !smt.bool):\n  %0 = smt.and %arg0, %arg1\n  smt.yield "
+               "%0 : !smt.bool\n}\n");
+}
+
+} // namespace