[flang][fir] Add OpaqueAttr.

Add the opaque attribute class used in flang.

https://github.com/flang-compiler/f18-llvm-project/pull/402

Differential Revision: https://reviews.llvm.org/D96293

GitOrigin-RevId: 2cd0a113df2c12405e7a81f970f2df5a0de46df2
diff --git a/include/flang/Optimizer/Dialect/FIRAttr.h b/include/flang/Optimizer/Dialect/FIRAttr.h
index e7075f3..5199508 100644
--- a/include/flang/Optimizer/Dialect/FIRAttr.h
+++ b/include/flang/Optimizer/Dialect/FIRAttr.h
@@ -25,10 +25,13 @@
 class FIROpsDialect;
 
 namespace detail {
+struct OpaqueAttributeStorage;
 struct RealAttributeStorage;
 struct TypeAttributeStorage;
 } // namespace detail
 
+using KindTy = unsigned;
+
 class ExactTypeAttr
     : public mlir::Attribute::AttrBase<ExactTypeAttr, mlir::Attribute,
                                        detail::TypeAttributeStorage> {
@@ -127,10 +130,31 @@
   static constexpr llvm::StringRef getAttrName() { return "real"; }
   static RealAttr get(mlir::MLIRContext *ctxt, const ValueType &key);
 
-  int getFKind() const;
+  KindTy getFKind() const;
   llvm::APFloat getValue() const;
 };
 
+/// An opaque attribute is used to provide dictionary lookups of pointers. The
+/// underlying type of the pointee object is left up to the client. Opaque
+/// attributes are always constructed as null pointers when parsing. Clearly,
+/// opaque attributes come with restrictions and must be used with care.
+/// 1. An opaque attribute should not refer to information of semantic
+/// significance, since the pointed-to object will not be a part of
+/// round-tripping the IR.
+/// 2. The lifetime of the pointed-to object must outlive any possible uses
+/// via the opaque attribute.
+class OpaqueAttr
+    : public mlir::Attribute::AttrBase<OpaqueAttr, mlir::Attribute,
+                                       detail::OpaqueAttributeStorage> {
+public:
+  using Base::Base;
+
+  static constexpr llvm::StringRef getAttrName() { return "opaque"; }
+  static OpaqueAttr get(mlir::MLIRContext *ctxt, void *pointer);
+
+  void *getPointer() const;
+};
+
 mlir::Attribute parseFirAttribute(FIROpsDialect *dialect,
                                   mlir::DialectAsmParser &parser,
                                   mlir::Type type);
diff --git a/lib/Optimizer/Dialect/FIRAttr.cpp b/lib/Optimizer/Dialect/FIRAttr.cpp
index 1e7a276..1996143 100644
--- a/lib/Optimizer/Dialect/FIRAttr.cpp
+++ b/lib/Optimizer/Dialect/FIRAttr.cpp
@@ -15,13 +15,12 @@
 #include "flang/Optimizer/Support/KindMapping.h"
 #include "mlir/IR/AttributeSupport.h"
 #include "mlir/IR/DialectImplementation.h"
-#include "mlir/IR/Types.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "llvm/ADT/SmallString.h"
 
 using namespace fir;
 
-namespace fir {
-namespace detail {
+namespace fir::detail {
 
 struct RealAttributeStorage : public mlir::AttributeStorage {
   using KeyTy = std::pair<int, llvm::APFloat>;
@@ -44,7 +43,7 @@
         RealAttributeStorage(key);
   }
 
-  int getFKind() const { return kind; }
+  KindTy getFKind() const { return kind; }
   llvm::APFloat getValue() const { return value; }
 
 private:
@@ -75,55 +74,98 @@
 private:
   mlir::Type value;
 };
-} // namespace detail
 
-ExactTypeAttr ExactTypeAttr::get(mlir::Type value) {
+/// An attribute representing a raw pointer.
+struct OpaqueAttributeStorage : public mlir::AttributeStorage {
+  using KeyTy = void *;
+
+  OpaqueAttributeStorage(void *value) : value(value) {}
+
+  /// Key equality function.
+  bool operator==(const KeyTy &key) const { return key == value; }
+
+  /// Construct a new storage instance.
+  static OpaqueAttributeStorage *
+  construct(mlir::AttributeStorageAllocator &allocator, KeyTy key) {
+    return new (allocator.allocate<OpaqueAttributeStorage>())
+        OpaqueAttributeStorage(key);
+  }
+
+  void *getPointer() const { return value; }
+
+private:
+  void *value;
+};
+} // namespace fir::detail
+
+//===----------------------------------------------------------------------===//
+// Attributes for SELECT TYPE
+//===----------------------------------------------------------------------===//
+
+ExactTypeAttr fir::ExactTypeAttr::get(mlir::Type value) {
   return Base::get(value.getContext(), value);
 }
 
-mlir::Type ExactTypeAttr::getType() const { return getImpl()->getType(); }
+mlir::Type fir::ExactTypeAttr::getType() const { return getImpl()->getType(); }
 
-SubclassAttr SubclassAttr::get(mlir::Type value) {
+SubclassAttr fir::SubclassAttr::get(mlir::Type value) {
   return Base::get(value.getContext(), value);
 }
 
-mlir::Type SubclassAttr::getType() const { return getImpl()->getType(); }
+mlir::Type fir::SubclassAttr::getType() const { return getImpl()->getType(); }
+
+//===----------------------------------------------------------------------===//
+// Attributes for SELECT CASE
+//===----------------------------------------------------------------------===//
 
 using AttributeUniquer = mlir::detail::AttributeUniquer;
 
-ClosedIntervalAttr ClosedIntervalAttr::get(mlir::MLIRContext *ctxt) {
+ClosedIntervalAttr fir::ClosedIntervalAttr::get(mlir::MLIRContext *ctxt) {
   return AttributeUniquer::get<ClosedIntervalAttr>(ctxt);
 }
 
-UpperBoundAttr UpperBoundAttr::get(mlir::MLIRContext *ctxt) {
+UpperBoundAttr fir::UpperBoundAttr::get(mlir::MLIRContext *ctxt) {
   return AttributeUniquer::get<UpperBoundAttr>(ctxt);
 }
 
-LowerBoundAttr LowerBoundAttr::get(mlir::MLIRContext *ctxt) {
+LowerBoundAttr fir::LowerBoundAttr::get(mlir::MLIRContext *ctxt) {
   return AttributeUniquer::get<LowerBoundAttr>(ctxt);
 }
 
-PointIntervalAttr PointIntervalAttr::get(mlir::MLIRContext *ctxt) {
+PointIntervalAttr fir::PointIntervalAttr::get(mlir::MLIRContext *ctxt) {
   return AttributeUniquer::get<PointIntervalAttr>(ctxt);
 }
 
+//===----------------------------------------------------------------------===//
 // RealAttr
+//===----------------------------------------------------------------------===//
 
-RealAttr RealAttr::get(mlir::MLIRContext *ctxt,
-                       const RealAttr::ValueType &key) {
+RealAttr fir::RealAttr::get(mlir::MLIRContext *ctxt,
+                            const RealAttr::ValueType &key) {
   return Base::get(ctxt, key);
 }
 
-int RealAttr::getFKind() const { return getImpl()->getFKind(); }
+KindTy fir::RealAttr::getFKind() const { return getImpl()->getFKind(); }
 
-llvm::APFloat RealAttr::getValue() const { return getImpl()->getValue(); }
+llvm::APFloat fir::RealAttr::getValue() const { return getImpl()->getValue(); }
 
+//===----------------------------------------------------------------------===//
+// OpaqueAttr
+//===----------------------------------------------------------------------===//
+
+OpaqueAttr fir::OpaqueAttr::get(mlir::MLIRContext *ctxt, void *key) {
+  return Base::get(ctxt, key);
+}
+
+void *fir::OpaqueAttr::getPointer() const { return getImpl()->getPointer(); }
+
+//===----------------------------------------------------------------------===//
 // FIR attribute parsing
+//===----------------------------------------------------------------------===//
 
-namespace {
-mlir::Attribute parseFirRealAttr(FIROpsDialect *dialect,
-                                 mlir::DialectAsmParser &parser,
-                                 mlir::Type type) {
+static mlir::Attribute parseFirRealAttr(FIROpsDialect *dialect,
+                                        mlir::DialectAsmParser &parser,
+                                        mlir::Type type) {
   int kind = 0;
   if (parser.parseLess() || parser.parseInteger(kind) || parser.parseComma()) {
     parser.emitError(parser.getNameLoc(), "expected '<' kind ','");
@@ -158,11 +200,10 @@
   }
   return RealAttr::get(dialect->getContext(), {kind, value});
 }
-} // namespace
 
-mlir::Attribute parseFirAttribute(FIROpsDialect *dialect,
-                                  mlir::DialectAsmParser &parser,
-                                  mlir::Type type) {
+mlir::Attribute fir::parseFirAttribute(FIROpsDialect *dialect,
+                                       mlir::DialectAsmParser &parser,
+                                       mlir::Type type) {
   auto loc = parser.getNameLoc();
   llvm::StringRef attrName;
   if (parser.parseKeyword(&attrName)) {
@@ -186,6 +227,15 @@
     }
     return SubclassAttr::get(type);
   }
+  if (attrName == OpaqueAttr::getAttrName()) {
+    if (parser.parseLess() || parser.parseGreater()) {
+      parser.emitError(loc, "expected <>");
+      return {};
+    }
+    // NB: opaque pointers are always parsed in as nullptrs. The tool must
+    // rebuild the context.
+    return OpaqueAttr::get(dialect->getContext(), nullptr);
+  }
   if (attrName == PointIntervalAttr::getAttrName())
     return PointIntervalAttr::get(dialect->getContext());
   if (attrName == LowerBoundAttr::getAttrName())
@@ -201,10 +251,12 @@
   return {};
 }
 
+//===----------------------------------------------------------------------===//
 // FIR attribute pretty printer
+//===----------------------------------------------------------------------===//
 
-void printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
-                       mlir::DialectAsmPrinter &p) {
+void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
+                            mlir::DialectAsmPrinter &p) {
   auto &os = p.getStream();
   if (auto exact = attr.dyn_cast<fir::ExactTypeAttr>()) {
     os << fir::ExactTypeAttr::getAttrName() << '<';
@@ -227,9 +279,10 @@
     llvm::SmallString<40> ss;
     a.getValue().bitcastToAPInt().toStringUnsigned(ss, 16);
     os << ss << '>';
+  } else if (attr.isa<fir::OpaqueAttr>()) {
+    os << fir::OpaqueAttr::getAttrName() << "<>";
   } else {
-    llvm_unreachable("attribute pretty-printer is not implemented");
+    // don't know how to print the attribute, so use a default
+    os << "<(unknown attribute)>";
   }
 }
-
-} // namespace fir
diff --git a/lib/Optimizer/Dialect/FIRDialect.cpp b/lib/Optimizer/Dialect/FIRDialect.cpp
index f174c89..c424b98 100644
--- a/lib/Optimizer/Dialect/FIRDialect.cpp
+++ b/lib/Optimizer/Dialect/FIRDialect.cpp
@@ -19,7 +19,7 @@
            FieldType, HeapType, fir::IntegerType, LenType, LogicalType,
            PointerType, RealType, RecordType, ReferenceType, SequenceType,
            TypeDescType, fir::VectorType>();
-  addAttributes<ClosedIntervalAttr, ExactTypeAttr, LowerBoundAttr,
+  addAttributes<ClosedIntervalAttr, ExactTypeAttr, LowerBoundAttr, OpaqueAttr,
                 PointIntervalAttr, RealAttr, SubclassAttr, UpperBoundAttr>();
   addOperations<
 #define GET_OP_LIST
diff --git a/test/Fir/fir-ops.fir b/test/Fir/fir-ops.fir
index fe10ef7..4a1f21a 100644
--- a/test/Fir/fir-ops.fir
+++ b/test/Fir/fir-ops.fir
@@ -32,6 +32,9 @@
 func private @nop()
 func private @get_func() -> (() -> ())
 
+// CHECK-LABEL: func private @attr1() -> none attributes {a = #fir.opaque<>, b = #fir.opaque<>}
+func private @attr1() -> none attributes {a = #fir.opaque<>, b = #fir.opaque<>}
+
 // CHECK-LABEL:       func @instructions() {
 func @instructions() {
 // CHECK: [[VAL_0:%.*]] = fir.alloca !fir.array<10xi32>