[flang][fir] Update the kind mapping class.

The kind mapper provides a portable mechanism to map Fortran type KIND values
independent of the front-end to their corresponding MLIR and LLVM types.

Differential Revision: https://reviews.llvm.org/D96362

GitOrigin-RevId: 4dc87d1010351170d73ebd23869751fe1bd6ac26
diff --git a/include/flang/Optimizer/Support/KindMapping.h b/include/flang/Optimizer/Support/KindMapping.h
index b65f828..faef765 100644
--- a/include/flang/Optimizer/Support/KindMapping.h
+++ b/include/flang/Optimizer/Support/KindMapping.h
@@ -5,6 +5,10 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
 
 #ifndef OPTIMIZER_SUPPORT_KINDMAPPING_H
 #define OPTIMIZER_SUPPORT_KINDMAPPING_H
@@ -36,7 +40,8 @@
 ///   'c' : COMPLEX (encoding value)
 ///
 /// kind-value is either an unsigned integer (for 'i', 'l', and 'a') or one of
-/// 'Half', 'Float', 'Double', 'X86_FP80', or 'FP128' (for 'r' and 'c').
+/// 'Half', 'BFloat', 'Float', 'Double', 'X86_FP80', or 'FP128' (for 'r' and
+/// 'c').
 ///
 /// If LLVM adds support for new floating-point types, the final list should be
 /// extended.
@@ -47,8 +52,15 @@
   using LLVMTypeID = llvm::Type::TypeID;
   using MatchResult = mlir::ParseResult;
 
-  explicit KindMapping(mlir::MLIRContext *context);
-  explicit KindMapping(mlir::MLIRContext *context, llvm::StringRef map);
+  /// KindMapping constructors take an optional `defs` argument to specify the
+  /// default kinds for intrinsic types. To set the default kinds, an ArrayRef
+  /// of 6 KindTy must be passed. The kinds must be the given in the following
+  /// order: CHARACTER, COMPLEX, DOUBLE PRECISION, INTEGER, LOGICAL, and REAL.
+  /// If `defs` is not specified, default default kinds will be used.
+  explicit KindMapping(mlir::MLIRContext *context,
+                       llvm::ArrayRef<KindTy> defs = llvm::None);
+  explicit KindMapping(mlir::MLIRContext *context, llvm::StringRef map,
+                       llvm::ArrayRef<KindTy> defs = llvm::None);
 
   /// Get the size in bits of !fir.char<kind>
   Bitsize getCharacterBitsize(KindTy kind) const;
@@ -73,13 +85,26 @@
   /// Get the float semantics of !fir.real<kind>
   const llvm::fltSemantics &getFloatSemantics(KindTy kind) const;
 
+  //===--------------------------------------------------------------------===//
+  // Default kinds of intrinsic types
+  //===--------------------------------------------------------------------===//
+
+  KindTy defaultCharacterKind() const;
+  KindTy defaultComplexKind() const;
+  KindTy defaultDoubleKind() const;
+  KindTy defaultIntegerKind() const;
+  KindTy defaultLogicalKind() const;
+  KindTy defaultRealKind() const;
+
 private:
   MatchResult badMapString(const llvm::Twine &ptr);
   MatchResult parse(llvm::StringRef kindMap);
+  mlir::LogicalResult setDefaultKinds(llvm::ArrayRef<KindTy> defs);
 
   mlir::MLIRContext *context;
   llvm::DenseMap<std::pair<char, KindTy>, Bitsize> intMap;
   llvm::DenseMap<std::pair<char, KindTy>, LLVMTypeID> floatMap;
+  llvm::DenseMap<char, KindTy> defaultMap;
 };
 
 } // namespace fir
diff --git a/lib/Optimizer/Support/KindMapping.cpp b/lib/Optimizer/Support/KindMapping.cpp
index e1debae..6cb8d5e 100644
--- a/lib/Optimizer/Support/KindMapping.cpp
+++ b/lib/Optimizer/Support/KindMapping.cpp
@@ -5,6 +5,10 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
 
 #include "flang/Optimizer/Support/KindMapping.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -32,12 +36,14 @@
 }
 
 /// Floating-point types default to the kind value being the size of the value
-/// in bytes. The default is to translate kinds of 2, 4, 8, 10, and 16 to a
+/// in bytes. The default is to translate kinds of 2, 3, 4, 8, 10, and 16 to a
 /// valid llvm::Type::TypeID value. Otherwise, the default is FloatTyID.
 static LLVMTypeID defaultRealKind(KindTy kind) {
   switch (kind) {
   case 2:
     return LLVMTypeID::HalfTyID;
+  case 3:
+    return LLVMTypeID::BFloatTyID;
   case 4:
     return LLVMTypeID::FloatTyID;
   case 8:
@@ -81,6 +87,8 @@
   switch (doLookup<LLVMTypeID, KEY>(defaultRealKind, map, kind)) {
   case LLVMTypeID::HalfTyID:
     return llvm::APFloat::IEEEhalf();
+  case LLVMTypeID::BFloatTyID:
+    return llvm::APFloat::BFloat();
   case LLVMTypeID::FloatTyID:
     return llvm::APFloat::IEEEsingle();
   case LLVMTypeID::DoubleTyID:
@@ -148,6 +156,10 @@
     result = LLVMTypeID::HalfTyID;
     return mlir::success();
   }
+  if (mlir::succeeded(matchString(ptr, "BFloat"))) {
+    result = LLVMTypeID::BFloatTyID;
+    return mlir::success();
+  }
   if (mlir::succeeded(matchString(ptr, "Float"))) {
     result = LLVMTypeID::FloatTyID;
     return mlir::success();
@@ -171,16 +183,18 @@
   return mlir::failure();
 }
 
-fir::KindMapping::KindMapping(mlir::MLIRContext *context, llvm::StringRef map)
+fir::KindMapping::KindMapping(mlir::MLIRContext *context, llvm::StringRef map,
+                              llvm::ArrayRef<KindTy> defs)
     : context{context} {
-  if (mlir::failed(parse(map))) {
-    intMap.clear();
-    floatMap.clear();
-  }
+  if (mlir::failed(setDefaultKinds(defs)))
+    llvm::report_fatal_error("bad default kinds");
+  if (mlir::failed(parse(map)))
+    llvm::report_fatal_error("could not parse kind map");
 }
 
-fir::KindMapping::KindMapping(mlir::MLIRContext *context)
-    : KindMapping{context, clKindMapping} {}
+fir::KindMapping::KindMapping(mlir::MLIRContext *context,
+                              llvm::ArrayRef<KindTy> defs)
+    : KindMapping{context, clKindMapping, defs} {}
 
 MatchResult fir::KindMapping::badMapString(const llvm::Twine &ptr) {
   auto unknown = mlir::UnknownLoc::get(context);
@@ -248,3 +262,65 @@
 fir::KindMapping::getFloatSemantics(KindTy kind) const {
   return getFloatSemanticsOfKind<'r'>(kind, floatMap);
 }
+
+mlir::LogicalResult
+fir::KindMapping::setDefaultKinds(llvm::ArrayRef<KindTy> defs) {
+  if (defs.empty()) {
+    // generic front-end defaults
+    const KindTy genericKind = 4;
+    defaultMap.insert({'a', 1});
+    defaultMap.insert({'c', genericKind});
+    defaultMap.insert({'d', 2 * genericKind});
+    defaultMap.insert({'i', genericKind});
+    defaultMap.insert({'l', genericKind});
+    defaultMap.insert({'r', genericKind});
+    return mlir::success();
+  }
+  if (defs.size() != 6)
+    return mlir::failure();
+
+  // defaults determined after command-line processing
+  defaultMap.insert({'a', defs[0]});
+  defaultMap.insert({'c', defs[1]});
+  defaultMap.insert({'d', defs[2]});
+  defaultMap.insert({'i', defs[3]});
+  defaultMap.insert({'l', defs[4]});
+  defaultMap.insert({'r', defs[5]});
+  return mlir::success();
+}
+
+KindTy fir::KindMapping::defaultCharacterKind() const {
+  auto iter = defaultMap.find('a');
+  assert(iter != defaultMap.end());
+  return iter->second;
+}
+
+KindTy fir::KindMapping::defaultComplexKind() const {
+  auto iter = defaultMap.find('c');
+  assert(iter != defaultMap.end());
+  return iter->second;
+}
+
+KindTy fir::KindMapping::defaultDoubleKind() const {
+  auto iter = defaultMap.find('d');
+  assert(iter != defaultMap.end());
+  return iter->second;
+}
+
+KindTy fir::KindMapping::defaultIntegerKind() const {
+  auto iter = defaultMap.find('i');
+  assert(iter != defaultMap.end());
+  return iter->second;
+}
+
+KindTy fir::KindMapping::defaultLogicalKind() const {
+  auto iter = defaultMap.find('l');
+  assert(iter != defaultMap.end());
+  return iter->second;
+}
+
+KindTy fir::KindMapping::defaultRealKind() const {
+  auto iter = defaultMap.find('r');
+  assert(iter != defaultMap.end());
+  return iter->second;
+}
diff --git a/unittests/Optimizer/CMakeLists.txt b/unittests/Optimizer/CMakeLists.txt
index 9e88a90..a021ce0 100644
--- a/unittests/Optimizer/CMakeLists.txt
+++ b/unittests/Optimizer/CMakeLists.txt
@@ -7,6 +7,7 @@
 
 add_flang_unittest(FlangOptimizerTests
   InternalNamesTest.cpp
+  KindMappingTest.cpp
 )
 target_link_libraries(FlangOptimizerTests
   PRIVATE
diff --git a/unittests/Optimizer/KindMappingTest.cpp b/unittests/Optimizer/KindMappingTest.cpp
new file mode 100644
index 0000000..d99b817
--- /dev/null
+++ b/unittests/Optimizer/KindMappingTest.cpp
@@ -0,0 +1,194 @@
+//===- KindMappingTest.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Support/KindMapping.h"
+#include "gtest/gtest.h"
+#include <string>
+
+using namespace fir;
+namespace llvm {
+struct fltSemantics;
+} // namespace llvm
+
+namespace mlir {
+class MLIRContext;
+} // namespace mlir
+
+using Bitsize = fir::KindMapping::Bitsize;
+using LLVMTypeID = fir::KindMapping::LLVMTypeID;
+
+struct DefaultStringTests : public testing::Test {
+public:
+  void SetUp() { defaultString = new KindMapping(context); }
+  void TearDown() { delete defaultString; }
+
+  KindMapping *defaultString{};
+  mlir::MLIRContext *context{};
+};
+
+struct CommandLineStringTests : public testing::Test {
+public:
+  void SetUp() {
+    commandLineString = new KindMapping(context,
+        "i10:80,l3:24,a1:8,r54:Double,c20:X86_FP80,r11:PPC_FP128,"
+        "r12:FP128,r13:X86_FP80,r14:Double,r15:Float,r16:Half,r23:BFloat");
+    clStringConflict =
+        new KindMapping(context, "i10:80,i10:40,r54:Double,r54:X86_FP80");
+  }
+  void TearDown() {
+    delete commandLineString;
+    delete clStringConflict;
+  }
+
+  KindMapping *commandLineString{};
+  KindMapping *clStringConflict{};
+  mlir::MLIRContext *context{};
+};
+
+struct KindDefaultsTests : public testing::Test {
+public:
+  void SetUp() {
+    defaultDefaultKinds = new KindMapping(context);
+    overrideDefaultKinds =
+        new KindMapping(context, {20, 121, 32, 133, 44, 145});
+  }
+  void TearDown() {
+    delete defaultDefaultKinds;
+    delete overrideDefaultKinds;
+  }
+
+  mlir::MLIRContext *context{};
+  KindMapping *defaultDefaultKinds{};
+  KindMapping *overrideDefaultKinds{};
+};
+
+TEST_F(DefaultStringTests, getIntegerBitsizeTest) {
+  EXPECT_EQ(defaultString->getIntegerBitsize(10), 80u);
+  EXPECT_EQ(defaultString->getIntegerBitsize(0), 0u);
+}
+
+TEST_F(DefaultStringTests, getCharacterBitsizeTest) {
+  EXPECT_EQ(defaultString->getCharacterBitsize(10), 80u);
+  EXPECT_EQ(defaultString->getCharacterBitsize(0), 0u);
+}
+
+TEST_F(DefaultStringTests, getLogicalBitsizeTest) {
+  EXPECT_EQ(defaultString->getLogicalBitsize(10), 80u);
+  // Unsigned values are expected
+  std::string actual = std::to_string(defaultString->getLogicalBitsize(-10));
+  std::string expect = "-80";
+  EXPECT_NE(actual, expect);
+}
+
+TEST_F(DefaultStringTests, getRealTypeIDTest) {
+  EXPECT_EQ(defaultString->getRealTypeID(2), LLVMTypeID::HalfTyID);
+  EXPECT_EQ(defaultString->getRealTypeID(3), LLVMTypeID::BFloatTyID);
+  EXPECT_EQ(defaultString->getRealTypeID(4), LLVMTypeID::FloatTyID);
+  EXPECT_EQ(defaultString->getRealTypeID(8), LLVMTypeID::DoubleTyID);
+  EXPECT_EQ(defaultString->getRealTypeID(10), LLVMTypeID::X86_FP80TyID);
+  EXPECT_EQ(defaultString->getRealTypeID(16), LLVMTypeID::FP128TyID);
+  // Default cases
+  EXPECT_EQ(defaultString->getRealTypeID(-1), LLVMTypeID::FloatTyID);
+  EXPECT_EQ(defaultString->getRealTypeID(1), LLVMTypeID::FloatTyID);
+}
+
+TEST_F(DefaultStringTests, getComplexTypeIDTest) {
+  EXPECT_EQ(defaultString->getComplexTypeID(2), LLVMTypeID::HalfTyID);
+  EXPECT_EQ(defaultString->getComplexTypeID(3), LLVMTypeID::BFloatTyID);
+  EXPECT_EQ(defaultString->getComplexTypeID(4), LLVMTypeID::FloatTyID);
+  EXPECT_EQ(defaultString->getComplexTypeID(8), LLVMTypeID::DoubleTyID);
+  EXPECT_EQ(defaultString->getComplexTypeID(10), LLVMTypeID::X86_FP80TyID);
+  EXPECT_EQ(defaultString->getComplexTypeID(16), LLVMTypeID::FP128TyID);
+  // Default cases
+  EXPECT_EQ(defaultString->getComplexTypeID(-1), LLVMTypeID::FloatTyID);
+  EXPECT_EQ(defaultString->getComplexTypeID(1), LLVMTypeID::FloatTyID);
+}
+
+TEST_F(DefaultStringTests, getFloatSemanticsTest) {
+  EXPECT_EQ(&defaultString->getFloatSemantics(2), &llvm::APFloat::IEEEhalf());
+  EXPECT_EQ(&defaultString->getFloatSemantics(3), &llvm::APFloat::BFloat());
+  EXPECT_EQ(&defaultString->getFloatSemantics(4), &llvm::APFloat::IEEEsingle());
+  EXPECT_EQ(&defaultString->getFloatSemantics(8), &llvm::APFloat::IEEEdouble());
+  EXPECT_EQ(&defaultString->getFloatSemantics(10),
+      &llvm::APFloat::x87DoubleExtended());
+  EXPECT_EQ(&defaultString->getFloatSemantics(16), &llvm::APFloat::IEEEquad());
+
+  // Default cases
+  EXPECT_EQ(
+      &defaultString->getFloatSemantics(-1), &llvm::APFloat::IEEEsingle());
+  EXPECT_EQ(&defaultString->getFloatSemantics(1), &llvm::APFloat::IEEEsingle());
+}
+
+TEST_F(CommandLineStringTests, getIntegerBitsizeTest) {
+  // KEY is present in map.
+  EXPECT_EQ(commandLineString->getIntegerBitsize(10), 80u);
+  EXPECT_EQ(commandLineString->getCharacterBitsize(1), 8u);
+  EXPECT_EQ(commandLineString->getLogicalBitsize(3), 24u);
+  EXPECT_EQ(commandLineString->getComplexTypeID(20), LLVMTypeID::X86_FP80TyID);
+  EXPECT_EQ(commandLineString->getRealTypeID(54), LLVMTypeID::DoubleTyID);
+  EXPECT_EQ(commandLineString->getRealTypeID(11), LLVMTypeID::PPC_FP128TyID);
+  EXPECT_EQ(&commandLineString->getFloatSemantics(11),
+      &llvm::APFloat::PPCDoubleDouble());
+  EXPECT_EQ(
+      &commandLineString->getFloatSemantics(12), &llvm::APFloat::IEEEquad());
+  EXPECT_EQ(&commandLineString->getFloatSemantics(13),
+      &llvm::APFloat::x87DoubleExtended());
+  EXPECT_EQ(
+      &commandLineString->getFloatSemantics(14), &llvm::APFloat::IEEEdouble());
+  EXPECT_EQ(
+      &commandLineString->getFloatSemantics(15), &llvm::APFloat::IEEEsingle());
+  EXPECT_EQ(
+      &commandLineString->getFloatSemantics(16), &llvm::APFloat::IEEEhalf());
+  EXPECT_EQ(
+      &commandLineString->getFloatSemantics(23), &llvm::APFloat::BFloat());
+
+  // Converts to default case
+  EXPECT_EQ(
+      &commandLineString->getFloatSemantics(20), &llvm::APFloat::IEEEsingle());
+
+  // KEY is absent from map, Default values are expected.
+  EXPECT_EQ(commandLineString->getIntegerBitsize(9), 72u);
+  EXPECT_EQ(commandLineString->getCharacterBitsize(9), 72u);
+  EXPECT_EQ(commandLineString->getLogicalBitsize(9), 72u);
+  EXPECT_EQ(commandLineString->getComplexTypeID(9), LLVMTypeID::FloatTyID);
+  EXPECT_EQ(commandLineString->getRealTypeID(9), LLVMTypeID::FloatTyID);
+
+  // KEY repeats in map.
+  EXPECT_NE(clStringConflict->getIntegerBitsize(10), 80u);
+  EXPECT_NE(clStringConflict->getRealTypeID(10), LLVMTypeID::DoubleTyID);
+}
+
+TEST(KindMappingDeathTests, mapTest) {
+  mlir::MLIRContext *context{};
+  // Catch parsing errors
+  ASSERT_DEATH(new KindMapping(context, "r10:Double,r20:Doubl"), "");
+  ASSERT_DEATH(new KindMapping(context, "10:Double"), "");
+  ASSERT_DEATH(new KindMapping(context, "rr:Double"), "");
+  ASSERT_DEATH(new KindMapping(context, "rr:"), "");
+  ASSERT_DEATH(new KindMapping(context, "rr:Double MoreContent"), "");
+  // length of 'size' > 10
+  ASSERT_DEATH(new KindMapping(context, "i11111111111:10"), "");
+}
+
+TEST_F(KindDefaultsTests, getIntegerBitsizeTest) {
+   EXPECT_EQ(defaultDefaultKinds->defaultCharacterKind(), 1u);
+   EXPECT_EQ(defaultDefaultKinds->defaultComplexKind(), 4u);
+   EXPECT_EQ(defaultDefaultKinds->defaultDoubleKind(), 8u);
+   EXPECT_EQ(defaultDefaultKinds->defaultIntegerKind(), 4u);
+   EXPECT_EQ(defaultDefaultKinds->defaultLogicalKind(), 4u);
+   EXPECT_EQ(defaultDefaultKinds->defaultRealKind(), 4u);
+
+   EXPECT_EQ(overrideDefaultKinds->defaultCharacterKind(), 20u);
+   EXPECT_EQ(overrideDefaultKinds->defaultComplexKind(), 121u);
+   EXPECT_EQ(overrideDefaultKinds->defaultDoubleKind(), 32u);
+   EXPECT_EQ(overrideDefaultKinds->defaultIntegerKind(), 133u);
+   EXPECT_EQ(overrideDefaultKinds->defaultLogicalKind(), 44u);
+   EXPECT_EQ(overrideDefaultKinds->defaultRealKind(), 145u);
+}
+
+// main() from gtest_main