[flang][fir][NFC] Move ShapeType to TableGen type definition

This is the first patch of a serie to move FIR types to TableGen format as suggested in D96172.
This patch is setting up the files for FIR types and move the ShapeType to TableGen.

As discussed with @schweitz, I'm taking over this task to help the FIR upstreaming effort.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D96422

GitOrigin-RevId: edd365c7098d212172b6e94a793910a8e1a0f521
diff --git a/include/flang/Optimizer/Dialect/CMakeLists.txt b/include/flang/Optimizer/Dialect/CMakeLists.txt
index f83cb8f..e55b3ed 100644
--- a/include/flang/Optimizer/Dialect/CMakeLists.txt
+++ b/include/flang/Optimizer/Dialect/CMakeLists.txt
@@ -4,6 +4,8 @@
 set(LLVM_TARGET_DEFINITIONS FIROps.td)
 mlir_tablegen(FIROps.h.inc -gen-op-decls)
 mlir_tablegen(FIROps.cpp.inc -gen-op-defs)
+mlir_tablegen(FIROpsTypes.h.inc --gen-typedef-decls)
+mlir_tablegen(FIROpsTypes.cpp.inc --gen-typedef-defs)
 add_public_tablegen_target(FIROpsIncGen)
 
 add_custom_target(flang-doc)
diff --git a/include/flang/Optimizer/Dialect/FIROps.td b/include/flang/Optimizer/Dialect/FIROps.td
index 5eee855..a156ca4 100644
--- a/include/flang/Optimizer/Dialect/FIROps.td
+++ b/include/flang/Optimizer/Dialect/FIROps.td
@@ -24,6 +24,8 @@
   let cppNamespace = "::fir";
 }
 
+include "flang/Optimizer/Dialect/FIRTypes.td"
+
 // Types and predicates
 
 def fir_Type : Type<CPred<"fir::isa_fir_or_std_type($_self)">,
@@ -99,10 +101,9 @@
     fir_HeapType.predicate, fir_PointerType.predicate, fir_BoxType.predicate]>,
     "any reference or box">;
 
-def fir_ShapeType : Type<CPred<"$_self.isa<fir::ShapeType>()">, "shape type">;
 def fir_ShapeShiftType : Type<CPred<"$_self.isa<fir::ShapeShiftType>()">,
     "shape shift type">;
-def AnyShapeLike : TypeConstraint<Or<[fir_ShapeType.predicate,
+def AnyShapeLike : TypeConstraint<Or<[ShapeType.predicate,
     fir_ShapeShiftType.predicate]>, "any legal shape type">;
 def AnyShapeType : Type<AnyShapeLike.predicate, "any legal shape type">;
 def fir_SliceType : Type<CPred<"$_self.isa<fir::SliceType>()">, "slice type">;
diff --git a/include/flang/Optimizer/Dialect/FIRType.h b/include/flang/Optimizer/Dialect/FIRType.h
index a10aef5..487f8b0 100644
--- a/include/flang/Optimizer/Dialect/FIRType.h
+++ b/include/flang/Optimizer/Dialect/FIRType.h
@@ -17,6 +17,9 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/ADT/SmallVector.h"
 
+#define GET_TYPEDEF_CLASSES
+#include "flang/Optimizer/Dialect/FIROpsTypes.h.inc"
+
 namespace llvm {
 class raw_ostream;
 class StringRef;
@@ -54,7 +57,6 @@
 struct RecordTypeStorage;
 struct ReferenceTypeStorage;
 struct SequenceTypeStorage;
-struct ShapeTypeStorage;
 struct ShapeShiftTypeStorage;
 struct SliceTypeStorage;
 struct TypeDescTypeStorage;
@@ -219,16 +221,6 @@
                                                           mlir::Type eleTy);
 };
 
-/// Type of a vector of runtime values that define the shape of a
-/// multidimensional array object. The vector is the extents of each array
-/// dimension. The rank of a ShapeType must be at least 1.
-class ShapeType : public mlir::Type::TypeBase<ShapeType, mlir::Type,
-                                              detail::ShapeTypeStorage> {
-public:
-  using Base::Base;
-  static ShapeType get(mlir::MLIRContext *ctx, unsigned rank);
-  unsigned getRank() const;
-};
 
 /// Type of a vector of runtime values that define the shape and the origin of a
 /// multidimensional array object. The vector is of pairs, origin offset and
diff --git a/include/flang/Optimizer/Dialect/FIRTypes.td b/include/flang/Optimizer/Dialect/FIRTypes.td
new file mode 100644
index 0000000..1ed1608
--- /dev/null
+++ b/include/flang/Optimizer/Dialect/FIRTypes.td
@@ -0,0 +1,48 @@
+//===- FIRTypes.td - FIR types -----------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the FIR dialect types.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FIR_DIALECT_FIR_TYPES
+#define FIR_DIALECT_FIR_TYPES
+
+//===----------------------------------------------------------------------===//
+// FIR Types
+//===----------------------------------------------------------------------===//
+
+class FIR_Type<string name, string typeMnemonic> : TypeDef<fir_Dialect, name> {
+  let mnemonic = typeMnemonic;
+}
+
+def ShapeType : FIR_Type<"Shape", "shape"> {
+  let summary = "shape of a multidimensional array object";
+
+  let description = [{
+    Type of a vector of runtime values that define the shape of a
+    multidimensional array object. The vector is the extents of each array
+    dimension. The rank of a ShapeType must be at least 1.
+  }];
+
+  let parameters = (ins "unsigned":$rank);
+
+  let printer = [{
+    $_printer << "shape<" << getImpl()->rank << ">";
+  }];
+
+  let parser = [{
+    int rank;
+    if ($_parser.parseLess() || $_parser.parseInteger(rank) ||
+        $_parser.parseGreater())
+      return Type();
+    return get(context, rank);
+  }];
+}
+
+#endif // FIR_DIALECT_FIR_TYPES
diff --git a/lib/Optimizer/Dialect/FIRType.cpp b/lib/Optimizer/Dialect/FIRType.cpp
index 863babe..d1e1a8c 100644
--- a/lib/Optimizer/Dialect/FIRType.cpp
+++ b/lib/Optimizer/Dialect/FIRType.cpp
@@ -20,6 +20,9 @@
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/ErrorHandling.h"
 
+#define GET_TYPEDEF_CLASSES
+#include "flang/Optimizer/Dialect/FIROpsTypes.cpp.inc"
+
 using namespace fir;
 
 namespace {
@@ -112,11 +115,6 @@
   return parseKindSingleton<fir::ComplexType>(parser);
 }
 
-// `shape` `<` rank `>`
-ShapeType parseShape(mlir::DialectAsmParser &parser) {
-  return parseRankSingleton<ShapeType>(parser);
-}
-
 // `shapeshift` `<` rank `>`
 ShapeShiftType parseShapeShift(mlir::DialectAsmParser &parser) {
   return parseRankSingleton<ShapeShiftType>(parser);
@@ -352,7 +350,8 @@
 
 // Implementation of the thin interface from dialect to type parser
 
-mlir::Type fir::parseFirType(FIROpsDialect *, mlir::DialectAsmParser &parser) {
+mlir::Type fir::parseFirType(FIROpsDialect *dialect,
+                             mlir::DialectAsmParser &parser) {
   llvm::StringRef typeNameLit;
   if (mlir::failed(parser.parseKeyword(&typeNameLit)))
     return {};
@@ -387,7 +386,8 @@
   if (typeNameLit == "ref")
     return parseReference(parser, loc);
   if (typeNameLit == "shape")
-    return parseShape(parser);
+    // TODO move to generatedTypeParser when all types have been moved
+    return ShapeType::parse(dialect->getContext(), parser);
   if (typeNameLit == "shapeshift")
     return parseShapeShift(parser);
   if (typeNameLit == "slice")
@@ -443,29 +443,6 @@
       : kind{kind}, len{len} {}
 };
 
-struct ShapeTypeStorage : public mlir::TypeStorage {
-  using KeyTy = unsigned;
-
-  static unsigned hashKey(const KeyTy &key) { return llvm::hash_combine(key); }
-
-  bool operator==(const KeyTy &key) const { return key == getRank(); }
-
-  static ShapeTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
-                                     unsigned rank) {
-    auto *storage = allocator.allocate<ShapeTypeStorage>();
-    return new (storage) ShapeTypeStorage{rank};
-  }
-
-  unsigned getRank() const { return rank; }
-
-protected:
-  unsigned rank;
-
-private:
-  ShapeTypeStorage() = delete;
-  explicit ShapeTypeStorage(unsigned rank) : rank{rank} {}
-};
-
 struct ShapeShiftTypeStorage : public mlir::TypeStorage {
   using KeyTy = unsigned;
 
@@ -1272,14 +1249,6 @@
   return llvm::hash_combine(0);
 }
 
-// Shape
-
-ShapeType fir::ShapeType::get(mlir::MLIRContext *ctxt, unsigned rank) {
-  return Base::get(ctxt, rank);
-}
-
-unsigned fir::ShapeType::getRank() const { return getImpl()->getRank(); }
-
 // Shapeshift
 
 ShapeShiftType fir::ShapeShiftType::get(mlir::MLIRContext *ctxt,
@@ -1478,7 +1447,9 @@
     return;
   }
   if (auto type = ty.dyn_cast<ShapeType>()) {
-    os << "shape<" << type.getRank() << '>';
+    // TODO when all type are moved to TableGen can be replaced by
+    // generatedTypePrinter
+    type.print(p);
     return;
   }
   if (auto type = ty.dyn_cast<ShapeShiftType>()) {