[mlir][TableGen] Emit interface traits after all interfaces (#147699)
Interface traits may provide default implementation of methods. When
this happens, the implementation may rely on another interface that is
not yet defined meaning that one gets "incomplete type" error during C++
compilation. In pseudo-code, the problem is the following:
```
InterfaceA has methodB() { return InterfaceB(); }
InterfaceB defined later
// What's generated is:
class InterfaceA { ... }
class InterfaceATrait {
// error: InterfaceB is an incomplete type
InterfaceB methodB() { return InterfaceB(); }
}
class InterfaceB { ... } // defined here
```
The two more "advanced" cases are:
* Cyclic dependency (A requires B and B requires A)
* Type-traited usage of an incomplete type (e.g.
`FailureOr<InterfaceB>`)
It seems reasonable to emit interface traits *after* all of the
interfaces have been defined to avoid the problem altogether.
As a drive by, make forward declarations of the interfaces early so that
user code does not need to forward declare.
GitOrigin-RevId: f5d3cf4a643fc13194e09cb39905f7f3b083f85e
diff --git a/test/lib/Dialect/Test/TestInterfaces.td b/test/lib/Dialect/Test/TestInterfaces.td
index dea26b8..d3d96ea 100644
--- a/test/lib/Dialect/Test/TestInterfaces.td
+++ b/test/lib/Dialect/Test/TestInterfaces.td
@@ -174,4 +174,32 @@
}];
}
+// Dummy type interface "A" that requires type interface "B" to be complete.
+def TestCyclicTypeInterfaceA : TypeInterface<"TestCyclicTypeInterfaceA"> {
+ let cppNamespace = "::mlir";
+ let methods = [
+ InterfaceMethod<"",
+ "::mlir::FailureOr<::mlir::TestCyclicTypeInterfaceB>",
+ /*methodName=*/"returnB",
+ (ins),
+ /*methodBody=*/"",
+ /*defaultImpl=*/"return mlir::failure();"
+ >,
+ ];
+}
+
+// Dummy type interface "B" that requires type interface "A" to be complete.
+def TestCyclicTypeInterfaceB : TypeInterface<"TestCyclicTypeInterfaceB"> {
+ let cppNamespace = "::mlir";
+ let methods = [
+ InterfaceMethod<"",
+ "::mlir::FailureOr<::mlir::TestCyclicTypeInterfaceA>",
+ /*methodName=*/"returnA",
+ (ins),
+ /*methodBody=*/"",
+ /*defaultImpl=*/"return mlir::failure();"
+ >,
+ ];
+}
+
#endif // MLIR_TEST_DIALECT_TEST_INTERFACES
diff --git a/test/mlir-tblgen/op-interface.td b/test/mlir-tblgen/op-interface.td
index 17bd631..aa71bad 100644
--- a/test/mlir-tblgen/op-interface.td
+++ b/test/mlir-tblgen/op-interface.td
@@ -31,11 +31,6 @@
// DECL-NEXT: return (*this).someOtherMethod();
// DECL-NEXT: }
-// DECL: struct ExtraShardDeclsInterfaceTrait
-// DECL: bool sharedMethodDeclaration() {
-// DECL-NEXT: return (*static_cast<ConcreteOp *>(this)).someOtherMethod();
-// DECL-NEXT: }
-
def TestInheritanceMultiBaseInterface : OpInterface<"TestInheritanceMultiBaseInterface"> {
let methods = [
InterfaceMethod<
@@ -71,7 +66,7 @@
def TestInheritanceZDerivedInterface
: OpInterface<"TestInheritanceZDerivedInterface", [TestInheritanceMiddleBaseInterface]>;
-// DECL: class TestInheritanceZDerivedInterface
+// DECL: struct TestInheritanceZDerivedInterfaceInterfaceTraits
// DECL: struct Concept {
// DECL: const TestInheritanceMultiBaseInterface::Concept *implTestInheritanceMultiBaseInterface = nullptr;
// DECL-NOT: const TestInheritanceMultiBaseInterface::Concept
@@ -173,10 +168,16 @@
// DECL: /// some function comment
// DECL: int foo(int input);
-// DECL-LABEL: struct TestOpInterfaceVerifyTrait
+// Trait declarations / definitions come after interface definitions.
+// DECL: struct ExtraShardDeclsInterfaceTrait : public
+// DECL: bool sharedMethodDeclaration() {
+// DECL-NEXT: return (*static_cast<ConcreteOp *>(this)).someOtherMethod();
+// DECL-NEXT: }
+
+// DECL-LABEL: struct TestOpInterfaceVerifyTrait : public
// DECL: verifyTrait
-// DECL-LABEL: struct TestOpInterfaceVerifyRegionTrait
+// DECL-LABEL: struct TestOpInterfaceVerifyRegionTrait : public
// DECL: verifyRegionTrait
// Method implementations come last, after all class definitions.
diff --git a/tools/mlir-tblgen/OpInterfacesGen.cpp b/tools/mlir-tblgen/OpInterfacesGen.cpp
index 4dfa190..3cc1636 100644
--- a/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -96,9 +96,9 @@
void emitConceptDecl(const Interface &interface);
void emitModelDecl(const Interface &interface);
void emitModelMethodsDef(const Interface &interface);
- void emitTraitDecl(const Interface &interface, StringRef interfaceName,
- StringRef interfaceTraitsName);
+ void forwardDeclareInterface(const Interface &interface);
void emitInterfaceDecl(const Interface &interface);
+ void emitInterfaceTraitDecl(const Interface &interface);
/// The set of interface records to emit.
std::vector<const Record *> defs;
@@ -445,9 +445,16 @@
os << "} // namespace " << ns << "\n";
}
-void InterfaceGenerator::emitTraitDecl(const Interface &interface,
- StringRef interfaceName,
- StringRef interfaceTraitsName) {
+void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) {
+ llvm::SmallVector<StringRef, 2> namespaces;
+ llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
+ for (StringRef ns : namespaces)
+ os << "namespace " << ns << " {\n";
+
+ os << "namespace detail {\n";
+
+ StringRef interfaceName = interface.getName();
+ auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
os << llvm::formatv(" template <typename {3}>\n"
" struct {0}Trait : public ::mlir::{2}<{0},"
" detail::{1}>::Trait<{3}> {{\n",
@@ -494,6 +501,10 @@
os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n";
os << " };\n";
+ os << "}// namespace detail\n";
+
+ for (StringRef ns : llvm::reverse(namespaces))
+ os << "} // namespace " << ns << "\n";
}
static void emitInterfaceDeclMethods(const Interface &interface,
@@ -517,6 +528,27 @@
os << tblgen::tgfmt(extraDecls->rtrim(), &extraDeclsFmt) << "\n";
}
+void InterfaceGenerator::forwardDeclareInterface(const Interface &interface) {
+ llvm::SmallVector<StringRef, 2> namespaces;
+ llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
+ for (StringRef ns : namespaces)
+ os << "namespace " << ns << " {\n";
+
+ // Emit a forward declaration of the interface class so that it becomes usable
+ // in the signature of its methods.
+ std::string comments = tblgen::emitSummaryAndDescComments(
+ "", interface.getDescription().value_or(""));
+ if (!comments.empty()) {
+ os << comments << "\n";
+ }
+
+ StringRef interfaceName = interface.getName();
+ os << "class " << interfaceName << ";\n";
+
+ for (StringRef ns : llvm::reverse(namespaces))
+ os << "} // namespace " << ns << "\n";
+}
+
void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
llvm::SmallVector<StringRef, 2> namespaces;
llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
@@ -533,7 +565,6 @@
if (!comments.empty()) {
os << comments << "\n";
}
- os << "class " << interfaceName << ";\n";
// Emit the traits struct containing the concept and model declarations.
os << "namespace detail {\n"
@@ -603,10 +634,6 @@
os << "};\n";
- os << "namespace detail {\n";
- emitTraitDecl(interface, interfaceName, interfaceTraitsName);
- os << "}// namespace detail\n";
-
for (StringRef ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
}
@@ -620,9 +647,14 @@
return lhs->getID() < rhs->getID();
});
for (const Record *def : sortedDefs)
+ forwardDeclareInterface(Interface(def));
+ for (const Record *def : sortedDefs)
emitInterfaceDecl(Interface(def));
for (const Record *def : sortedDefs)
+ emitInterfaceTraitDecl(Interface(def));
+ for (const Record *def : sortedDefs)
emitModelMethodsDef(Interface(def));
+
return false;
}