[mlir][IR][NFC] Simplify "splat" handling in `DenseIntOrFPElementsAttr` (#180965)

Since #180397, all elements of a `DenseIntOrFPElementsAttr` are padded
to full bytes. This enables additional simplifications: whether a
`DenseIntOrFPElementsAttr` is a splat or not can now be inferred from
the size of the buffer. This was not possible before because a single
byte sometimes contained multiple `i1` elements.

Discussion:
https://discourse.llvm.org/t/denseelementsattr-i1-element-type/62525
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index ee4707a..ee6a8f4 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -204,16 +204,7 @@
                                             ArrayRef<char> rawBuffer);
 
   /// Returns true if the given buffer is a valid raw buffer for the given type.
-  /// `detectedSplat` is set if the buffer is valid and represents a splat
-  /// buffer. The definition may be expanded over time, but currently, a
-  /// splat buffer is detected if:
-  ///   - For >1bit: The buffer consists of a single element.
-  ///   - For 1bit: The buffer consists of a single byte with value 0 or 255.
-  ///
-  /// User code should be prepared for additional, conformant patterns to be
-  /// identified as splats in the future.
-  static bool isValidRawBuffer(ShapedType type, ArrayRef<char> rawBuffer,
-                               bool &detectedSplat);
+  static bool isValidRawBuffer(ShapedType type, ArrayRef<char> rawBuffer);
 
   //===--------------------------------------------------------------------===//
   // Iterators
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index b67b8f9..798d3c8 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -408,8 +408,7 @@
   let builders = [
     AttrBuilderWithInferredContext<(ins "ShapedType":$type,
                                         "ArrayRef<StringRef>":$values), [{
-      return $_get(type.getContext(), type, values,
-                   /* isSplat */(values.size() == 1));
+      return $_get(type.getContext(), type, values);
     }]>,
   ];
   let extraClassDeclaration = [{
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 519609a3..5978a11 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -718,8 +718,7 @@
     return nullptr;
 
   ArrayRef<char> rawData(data);
-  bool detectedSplat = false;
-  if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) {
+  if (!DenseElementsAttr::isValidRawBuffer(type, rawData)) {
     p.emitError(loc) << "elements hex data size is invalid for provided type: "
                      << type;
     return nullptr;
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index f1d95af..44a3dea 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -582,9 +582,7 @@
   auto shapedTypeCpp = llvm::cast<ShapedType>(unwrap(shapedType));
   ArrayRef<char> rawBufferCpp(static_cast<const char *>(rawBuffer),
                               rawBufferSize);
-  bool isSplat = false;
-  if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp,
-                                           isSplat))
+  if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp))
     return mlirAttributeGetNull();
   return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp));
 }
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index fa8788d..299c8af 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -262,8 +262,7 @@
 
       // Check that the buffer meets the requirements to get converted to a
       // DenseElementsAttr
-      bool detectedSplat = false;
-      if (!DenseElementsAttr::isValidRawBuffer(srcType, ptr, detectedSplat))
+      if (!DenseElementsAttr::isValidRawBuffer(srcType, ptr))
         return constOp->emitError("resource is not a valid buffer");
 
       dstElementsAttr =
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index 5786f53..f01a87a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -185,9 +185,8 @@
       return std::nullopt;
 
     // Check that the data are in a valid form
-    bool isSplat = false;
     if (!DenseElementsAttr::isValidRawBuffer(attr.getShapedType(),
-                                             blob->getData(), isSplat)) {
+                                             blob->getData())) {
       return std::nullopt;
     }
 
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index c0d8f7d..1f26860 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -43,23 +43,19 @@
 /// An attribute representing a reference to a dense vector or tensor object.
 struct DenseElementsAttributeStorage : public AttributeStorage {
 public:
-  DenseElementsAttributeStorage(ShapedType type, bool isSplat)
-      : type(type), isSplat(isSplat) {}
+  DenseElementsAttributeStorage(ShapedType type) : type(type) {}
 
   ShapedType type;
-  bool isSplat;
 };
 
 /// An attribute representing a reference to a dense vector or tensor object.
 struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
-  DenseIntOrFPElementsAttrStorage(ShapedType ty, ArrayRef<char> data,
-                                  bool isSplat = false)
-      : DenseElementsAttributeStorage(ty, isSplat), data(data) {}
+  DenseIntOrFPElementsAttrStorage(ShapedType ty, ArrayRef<char> data)
+      : DenseElementsAttributeStorage(ty), data(data) {}
 
   struct KeyTy {
-    KeyTy(ShapedType type, ArrayRef<char> data, llvm::hash_code hashCode,
-          bool isSplat = false)
-        : type(type), data(data), hashCode(hashCode), isSplat(isSplat) {}
+    KeyTy(ShapedType type, ArrayRef<char> data, llvm::hash_code hashCode)
+        : type(type), data(data), hashCode(hashCode) {}
 
     /// The type of the dense elements.
     ShapedType type;
@@ -69,9 +65,6 @@
 
     /// The computed hash code for the storage data.
     llvm::hash_code hashCode;
-
-    /// A boolean that indicates if this data is a splat or not.
-    bool isSplat;
   };
 
   /// Compare this storage instance with the provided key.
@@ -79,31 +72,22 @@
     return key.type == type && key.data == data;
   }
 
-  /// Construct a key from a shaped type, raw data buffer, and a flag that
-  /// signals if the data is already known to be a splat. Callers to this
-  /// function are expected to tag preknown splat values when possible, e.g. one
-  /// element shapes.
-  static KeyTy getKey(ShapedType ty, ArrayRef<char> data, bool isKnownSplat) {
+  /// Construct a key from a shaped type and raw data buffer.
+  static KeyTy getKey(ShapedType ty, ArrayRef<char> data) {
     // Handle an empty storage instance.
     if (data.empty())
       return KeyTy(ty, data, 0);
 
-    // If the data is already known to be a splat, the key hash value is
-    // directly the data buffer.
-    if (isKnownSplat) {
-      return KeyTy(ty, data, llvm::hash_value(data), isKnownSplat);
-    }
-
-    // Otherwise, we need to check if the data corresponds to a splat or not.
-
-    // Handle the simple case of only one element.
-    size_t numElements = ty.getNumElements();
-    assert(numElements != 1 && "splat of 1 element should already be detected");
-
     size_t elementWidth = getDenseElementBitWidth(ty.getElementType());
     // Dense elements are padded to 8-bits.
     size_t storageSize = llvm::divideCeil(elementWidth, CHAR_BIT);
-    assert(((data.size() / storageSize) == numElements) &&
+
+    // If the data buffer holds a single element, it is a known splat.
+    if (data.size() == storageSize)
+      return KeyTy(ty, data, llvm::hash_value(data));
+
+    assert(((data.size() / storageSize) ==
+            static_cast<size_t>(ty.getNumElements())) &&
            "data does not hold expected number of elements");
 
     // Create the initial hash value with just the first element.
@@ -117,7 +101,7 @@
         return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));
 
     // Otherwise, this is a splat so just return the hash of the first element.
-    return KeyTy(ty, firstElt, hashVal, /*isSplat=*/true);
+    return KeyTy(ty, firstElt, hashVal);
   }
 
   /// Hash the key for the storage.
@@ -139,7 +123,7 @@
     }
 
     return new (allocator.allocate<DenseIntOrFPElementsAttrStorage>())
-        DenseIntOrFPElementsAttrStorage(key.type, copy, key.isSplat);
+        DenseIntOrFPElementsAttrStorage(key.type, copy);
   }
 
   ArrayRef<char> data;
@@ -148,14 +132,12 @@
 /// An attribute representing a reference to a dense vector or tensor object
 /// containing strings.
 struct DenseStringElementsAttrStorage : public DenseElementsAttributeStorage {
-  DenseStringElementsAttrStorage(ShapedType ty, ArrayRef<StringRef> data,
-                                 bool isSplat = false)
-      : DenseElementsAttributeStorage(ty, isSplat), data(data) {}
+  DenseStringElementsAttrStorage(ShapedType ty, ArrayRef<StringRef> data)
+      : DenseElementsAttributeStorage(ty), data(data) {}
 
   struct KeyTy {
-    KeyTy(ShapedType type, ArrayRef<StringRef> data, llvm::hash_code hashCode,
-          bool isSplat = false)
-        : type(type), data(data), hashCode(hashCode), isSplat(isSplat) {}
+    KeyTy(ShapedType type, ArrayRef<StringRef> data, llvm::hash_code hashCode)
+        : type(type), data(data), hashCode(hashCode) {}
 
     /// The type of the dense elements.
     ShapedType type;
@@ -165,9 +147,6 @@
 
     /// The computed hash code for the storage data.
     llvm::hash_code hashCode;
-
-    /// A boolean that indicates if this data is a splat or not.
-    bool isSplat;
   };
 
   /// Compare this storage instance with the provided key.
@@ -180,24 +159,15 @@
     return key.data == data;
   }
 
-  /// Construct a key from a shaped type, StringRef data buffer, and a flag that
-  /// signals if the data is already known to be a splat. Callers to this
-  /// function are expected to tag preknown splat values when possible, e.g. one
-  /// element shapes.
-  static KeyTy getKey(ShapedType ty, ArrayRef<StringRef> data,
-                      bool isKnownSplat) {
+  /// Construct a key from a shaped type and StringRef data buffer.
+  static KeyTy getKey(ShapedType ty, ArrayRef<StringRef> data) {
     // Handle an empty storage instance.
     if (data.empty())
       return KeyTy(ty, data, 0);
 
-    // If the data is already known to be a splat, the key hash value is
-    // directly the data buffer.
-    if (isKnownSplat)
-      return KeyTy(ty, data, llvm::hash_value(data.front()), isKnownSplat);
-
-    // Handle the simple case of only one element.
-    assert(ty.getNumElements() != 1 &&
-           "splat of 1 element should already be detected");
+    // If the data buffer holds a single element, it is a known splat.
+    if (data.size() == 1)
+      return KeyTy(ty, data, llvm::hash_value(data.front()));
 
     // Create the initial hash value with just the first element.
     const auto &firstElt = data.front();
@@ -205,12 +175,12 @@
 
     // Check to see if this storage represents a splat. If it doesn't then
     // combine the hash for the data starting with the first non splat element.
-    for (size_t i = 1, e = data.size(); i != e; i++)
+    for (size_t i = 1, e = data.size(); i != e; ++i)
       if (firstElt != data[i])
         return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));
 
     // Otherwise, this is a splat so just return the hash of the first element.
-    return KeyTy(ty, data.take_front(), hashVal, /*isSplat=*/true);
+    return KeyTy(ty, data.take_front(), hashVal);
   }
 
   /// Hash the key for the storage.
@@ -226,15 +196,15 @@
     ArrayRef<StringRef> copy, data = key.data;
     if (data.empty()) {
       return new (allocator.allocate<DenseStringElementsAttrStorage>())
-          DenseStringElementsAttrStorage(key.type, copy, key.isSplat);
+          DenseStringElementsAttrStorage(key.type, copy);
     }
 
-    int numEntries = key.isSplat ? 1 : data.size();
+    size_t numEntries = data.size();
 
     // Compute the amount data needed to store the ArrayRef and StringRef
     // contents.
     size_t dataSize = sizeof(StringRef) * numEntries;
-    for (int i = 0; i < numEntries; i++)
+    for (size_t i = 0; i < numEntries; ++i)
       dataSize += data[i].size();
 
     char *rawData = reinterpret_cast<char *>(
@@ -246,7 +216,7 @@
         reinterpret_cast<StringRef *>(rawData), numEntries);
     auto *stringData = rawData + numEntries * sizeof(StringRef);
 
-    for (int i = 0; i < numEntries; i++) {
+    for (size_t i = 0; i < numEntries; ++i) {
       memcpy(stringData, data[i].data(), data[i].size());
       mutableCopy[i] = StringRef(stringData, data[i].size());
       stringData += data[i].size();
@@ -256,7 +226,7 @@
         ArrayRef<StringRef>(reinterpret_cast<StringRef *>(rawData), numEntries);
 
     return new (allocator.allocate<DenseStringElementsAttrStorage>())
-        DenseStringElementsAttrStorage(key.type, copy, key.isSplat);
+        DenseStringElementsAttrStorage(key.type, copy);
   }
 
   ArrayRef<StringRef> data;
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 241165a..fbbd9d2 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -1030,24 +1030,15 @@
 
 /// Returns true if the given buffer is a valid raw buffer for the given type.
 bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
-                                         ArrayRef<char> rawBuffer,
-                                         bool &detectedSplat) {
+                                         ArrayRef<char> rawBuffer) {
   size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
   size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
   int64_t numElements = type.getNumElements();
 
-  // The initializer is always a splat if the result type has a single element.
-  detectedSplat = numElements == 1;
-
-  // All types are 8-bit aligned, so we can just check the buffer width
-  // to know if only a single initializer element was passed in.
-  if (rawBufferWidth == storageWidth) {
-    detectedSplat = true;
-    return true;
-  }
-
-  // The raw buffer is valid if it has the right size.
-  return rawBufferWidth == storageWidth * numElements;
+  // The raw buffer is valid if it has a single element (splat) or the right
+  // number of elements.
+  return rawBufferWidth == storageWidth ||
+         rawBufferWidth == storageWidth * numElements;
 }
 
 /// Check the information for a C++ data type, check if this type is valid for
@@ -1121,7 +1112,13 @@
 /// Returns true if this attribute corresponds to a splat, i.e. if all element
 /// values are the same.
 bool DenseElementsAttr::isSplat() const {
-  return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
+  // Splat iff the data array has exactly one element.
+  if (isa<DenseStringElementsAttr>(*this))
+    return getRawStringData().size() == 1;
+  // FP/Int case.
+  size_t storageSize = llvm::divideCeil(
+      getDenseElementBitWidth(getType().getElementType()), CHAR_BIT);
+  return getRawData().size() == storageSize;
 }
 
 /// Return if the given complex type has an integer element type.
@@ -1286,11 +1283,8 @@
 DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
                                                    ArrayRef<char> data) {
   assert(type.hasStaticShape() && "type must have static shape");
-  bool isSplat = false;
-  bool isValid = isValidRawBuffer(type, data, isSplat);
-  assert(isValid);
-  (void)isValid;
-  return Base::get(type.getContext(), type, data, isSplat);
+  assert(isValidRawBuffer(type, data));
+  return Base::get(type.getContext(), type, data);
 }
 
 /// Overload of the raw 'get' method that asserts that the given type is of