[mlir][SMT] restore custom builder for forall/exists (#135470)
This reverts commit 54e70ac7650f1c22f687937d1a082e4152f97b22 which
itself fixed an [asan
leak](https://lab.llvm.org/buildbot/#/builders/55/builds/9761) from the
original upstreaming commit. The leak was due to op allocations not
being `free`ed.
~~The necessary change was to explicitly `->destroy()` the ops at the
end of the tests. I believe this is because the rewriter used in the
tests doesn't actually insert them into a module and so without an
explicit `->destroy()` no bookkeeping process is able to take care of
them.~~
The necessary change was to use `OwningOpRef` which calls `op->erase()`
in its [own
destructor](https://github.com/makslevental/llvm-project/blob/89cfae41ecc043f8c47be4dea4b7c740d4f950b3/mlir/include/mlir/IR/OwningOpRef.h#L39).
diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
index af73955..1872c00 100644
--- a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
+++ b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
@@ -448,6 +448,18 @@
VariadicRegion<SizedRegion<1>>:$patterns);
let results = (outs BoolType:$result);
+ let builders = [
+ OpBuilder<(ins
+ "TypeRange":$boundVarTypes,
+ "function_ref<Value(OpBuilder &, Location, ValueRange)>":$bodyBuilder,
+ CArg<"std::optional<ArrayRef<StringRef>>", "std::nullopt">:$boundVarNames,
+ CArg<"function_ref<ValueRange(OpBuilder &, Location, ValueRange)>",
+ "{}">:$patternBuilder,
+ CArg<"uint32_t", "0">:$weight,
+ CArg<"bool", "false">:$noPattern)>
+ ];
+ let skipDefaultBuilders = true;
+
let assemblyFormat = [{
($boundVarNames^)? (`no_pattern` $noPattern^)? (`weight` $weight^)?
attr-dict-with-keyword $body (`patterns` $patterns^)?
diff --git a/mlir/lib/Dialect/SMT/IR/SMTOps.cpp b/mlir/lib/Dialect/SMT/IR/SMTOps.cpp
index 604dd26..8977a3a 100644
--- a/mlir/lib/Dialect/SMT/IR/SMTOps.cpp
+++ b/mlir/lib/Dialect/SMT/IR/SMTOps.cpp
@@ -432,6 +432,16 @@
return verifyQuantifierRegions(*this);
}
+void ForallOp::build(
+ OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
+ function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
+ std::optional<ArrayRef<StringRef>> boundVarNames,
+ function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
+ uint32_t weight, bool noPattern) {
+ buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder,
+ boundVarNames, patternBuilder, weight, noPattern);
+}
+
//===----------------------------------------------------------------------===//
// ExistsOp
//===----------------------------------------------------------------------===//
@@ -448,5 +458,15 @@
return verifyQuantifierRegions(*this);
}
+void ExistsOp::build(
+ OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
+ function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
+ std::optional<ArrayRef<StringRef>> boundVarNames,
+ function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder,
+ uint32_t weight, bool noPattern) {
+ buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder,
+ boundVarNames, patternBuilder, weight, noPattern);
+}
+
#define GET_OP_CLASSES
#include "mlir/Dialect/SMT/IR/SMT.cpp.inc"
diff --git a/mlir/unittests/Dialect/SMT/CMakeLists.txt b/mlir/unittests/Dialect/SMT/CMakeLists.txt
index 86e16d6..a133146 100644
--- a/mlir/unittests/Dialect/SMT/CMakeLists.txt
+++ b/mlir/unittests/Dialect/SMT/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_unittest(MLIRSMTTests
AttributeTest.cpp
+ QuantifierTest.cpp
TypeTest.cpp
)
diff --git a/mlir/unittests/Dialect/SMT/QuantifierTest.cpp b/mlir/unittests/Dialect/SMT/QuantifierTest.cpp
new file mode 100644
index 0000000..d7c57f0
--- /dev/null
+++ b/mlir/unittests/Dialect/SMT/QuantifierTest.cpp
@@ -0,0 +1,187 @@
+//===- QuantifierTest.cpp - SMT quantifier operation unit tests -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SMT/IR/SMTOps.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace smt;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Test custom builders of ExistsOp
+//===----------------------------------------------------------------------===//
+
+TEST(QuantifierTest, ExistsBuilderWithPattern) {
+ MLIRContext context;
+ context.loadDialect<SMTDialect>();
+ Location loc(UnknownLoc::get(&context));
+
+ OpBuilder builder(&context);
+ auto boolTy = BoolType::get(&context);
+
+ OwningOpRef<ExistsOp> existsOp = builder.create<ExistsOp>(
+ loc, TypeRange{boolTy, boolTy},
+ [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+ return builder.create<AndOp>(loc, boundVars);
+ },
+ std::nullopt,
+ [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+ return boundVars;
+ },
+ /*weight=*/2);
+
+ SmallVector<char, 1024> buffer;
+ llvm::raw_svector_ostream stream(buffer);
+ existsOp->print(stream);
+
+ ASSERT_STREQ(
+ stream.str().str().c_str(),
+ "%0 = smt.exists weight 2 {\n^bb0(%arg0: !smt.bool, "
+ "%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield %0 : "
+ "!smt.bool\n} patterns {\n^bb0(%arg0: !smt.bool, %arg1: !smt.bool):\n "
+ "smt.yield %arg0, %arg1 : !smt.bool, !smt.bool\n}\n");
+}
+
+TEST(QuantifierTest, ExistsBuilderNoPattern) {
+ MLIRContext context;
+ context.loadDialect<SMTDialect>();
+ Location loc(UnknownLoc::get(&context));
+
+ OpBuilder builder(&context);
+ auto boolTy = BoolType::get(&context);
+
+ OwningOpRef<ExistsOp> existsOp = builder.create<ExistsOp>(
+ loc, TypeRange{boolTy, boolTy},
+ [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+ return builder.create<AndOp>(loc, boundVars);
+ },
+ ArrayRef<StringRef>{"a", "b"}, nullptr, /*weight=*/0, /*noPattern=*/true);
+
+ SmallVector<char, 1024> buffer;
+ llvm::raw_svector_ostream stream(buffer);
+ existsOp->print(stream);
+
+ ASSERT_STREQ(stream.str().str().c_str(),
+ "%0 = smt.exists [\"a\", \"b\"] no_pattern {\n^bb0(%arg0: "
+ "!smt.bool, %arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n "
+ "smt.yield %0 : !smt.bool\n}\n");
+}
+
+TEST(QuantifierTest, ExistsBuilderDefault) {
+ MLIRContext context;
+ context.loadDialect<SMTDialect>();
+ Location loc(UnknownLoc::get(&context));
+
+ OpBuilder builder(&context);
+ auto boolTy = BoolType::get(&context);
+
+ OwningOpRef<ExistsOp> existsOp = builder.create<ExistsOp>(
+ loc, TypeRange{boolTy, boolTy},
+ [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+ return builder.create<AndOp>(loc, boundVars);
+ },
+ ArrayRef<StringRef>{"a", "b"});
+
+ SmallVector<char, 1024> buffer;
+ llvm::raw_svector_ostream stream(buffer);
+ existsOp->print(stream);
+
+ ASSERT_STREQ(stream.str().str().c_str(),
+ "%0 = smt.exists [\"a\", \"b\"] {\n^bb0(%arg0: !smt.bool, "
+ "%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield "
+ "%0 : !smt.bool\n}\n");
+}
+
+//===----------------------------------------------------------------------===//
+// Test custom builders of ForallOp
+//===----------------------------------------------------------------------===//
+
+TEST(QuantifierTest, ForallBuilderWithPattern) {
+ MLIRContext context;
+ context.loadDialect<SMTDialect>();
+ Location loc(UnknownLoc::get(&context));
+
+ OpBuilder builder(&context);
+ auto boolTy = BoolType::get(&context);
+
+ OwningOpRef<ForallOp> forallOp = builder.create<ForallOp>(
+ loc, TypeRange{boolTy, boolTy},
+ [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+ return builder.create<AndOp>(loc, boundVars);
+ },
+ ArrayRef<StringRef>{"a", "b"},
+ [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+ return boundVars;
+ },
+ /*weight=*/2);
+
+ SmallVector<char, 1024> buffer;
+ llvm::raw_svector_ostream stream(buffer);
+ forallOp->print(stream);
+
+ ASSERT_STREQ(
+ stream.str().str().c_str(),
+ "%0 = smt.forall [\"a\", \"b\"] weight 2 {\n^bb0(%arg0: !smt.bool, "
+ "%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield %0 : "
+ "!smt.bool\n} patterns {\n^bb0(%arg0: !smt.bool, %arg1: !smt.bool):\n "
+ "smt.yield %arg0, %arg1 : !smt.bool, !smt.bool\n}\n");
+}
+
+TEST(QuantifierTest, ForallBuilderNoPattern) {
+ MLIRContext context;
+ context.loadDialect<SMTDialect>();
+ Location loc(UnknownLoc::get(&context));
+
+ OpBuilder builder(&context);
+ auto boolTy = BoolType::get(&context);
+
+ OwningOpRef<ForallOp> forallOp = builder.create<ForallOp>(
+ loc, TypeRange{boolTy, boolTy},
+ [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+ return builder.create<AndOp>(loc, boundVars);
+ },
+ ArrayRef<StringRef>{"a", "b"}, nullptr, /*weight=*/0, /*noPattern=*/true);
+
+ SmallVector<char, 1024> buffer;
+ llvm::raw_svector_ostream stream(buffer);
+ forallOp->print(stream);
+
+ ASSERT_STREQ(stream.str().str().c_str(),
+ "%0 = smt.forall [\"a\", \"b\"] no_pattern {\n^bb0(%arg0: "
+ "!smt.bool, %arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n "
+ "smt.yield %0 : !smt.bool\n}\n");
+}
+
+TEST(QuantifierTest, ForallBuilderDefault) {
+ MLIRContext context;
+ context.loadDialect<SMTDialect>();
+ Location loc(UnknownLoc::get(&context));
+
+ OpBuilder builder(&context);
+ auto boolTy = BoolType::get(&context);
+
+ OwningOpRef<ForallOp> forallOp = builder.create<ForallOp>(
+ loc, TypeRange{boolTy, boolTy},
+ [](OpBuilder &builder, Location loc, ValueRange boundVars) {
+ return builder.create<AndOp>(loc, boundVars);
+ },
+ std::nullopt);
+
+ SmallVector<char, 1024> buffer;
+ llvm::raw_svector_ostream stream(buffer);
+ forallOp->print(stream);
+
+ ASSERT_STREQ(stream.str().str().c_str(),
+ "%0 = smt.forall {\n^bb0(%arg0: !smt.bool, "
+ "%arg1: !smt.bool):\n %0 = smt.and %arg0, %arg1\n smt.yield "
+ "%0 : !smt.bool\n}\n");
+}
+
+} // namespace