[mlir][IR][NFC] Simplify "splat" handling in `DenseIntOrFPElementsAttr` (#180965)
Since #180397, all elements of a `DenseIntOrFPElementsAttr` are padded
to full bytes. This enables additional simplifications: whether a
`DenseIntOrFPElementsAttr` is a splat or not can now be inferred from
the size of the buffer. This was not possible before because a single
byte sometimes contained multiple `i1` elements.
Discussion:
https://discourse.llvm.org/t/denseelementsattr-i1-element-type/62525
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index ee4707a..ee6a8f4 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -204,16 +204,7 @@
ArrayRef<char> rawBuffer);
/// Returns true if the given buffer is a valid raw buffer for the given type.
- /// `detectedSplat` is set if the buffer is valid and represents a splat
- /// buffer. The definition may be expanded over time, but currently, a
- /// splat buffer is detected if:
- /// - For >1bit: The buffer consists of a single element.
- /// - For 1bit: The buffer consists of a single byte with value 0 or 255.
- ///
- /// User code should be prepared for additional, conformant patterns to be
- /// identified as splats in the future.
- static bool isValidRawBuffer(ShapedType type, ArrayRef<char> rawBuffer,
- bool &detectedSplat);
+ static bool isValidRawBuffer(ShapedType type, ArrayRef<char> rawBuffer);
//===--------------------------------------------------------------------===//
// Iterators
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index b67b8f9..798d3c8 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -408,8 +408,7 @@
let builders = [
AttrBuilderWithInferredContext<(ins "ShapedType":$type,
"ArrayRef<StringRef>":$values), [{
- return $_get(type.getContext(), type, values,
- /* isSplat */(values.size() == 1));
+ return $_get(type.getContext(), type, values);
}]>,
];
let extraClassDeclaration = [{
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 519609a3..5978a11 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -718,8 +718,7 @@
return nullptr;
ArrayRef<char> rawData(data);
- bool detectedSplat = false;
- if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) {
+ if (!DenseElementsAttr::isValidRawBuffer(type, rawData)) {
p.emitError(loc) << "elements hex data size is invalid for provided type: "
<< type;
return nullptr;
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index f1d95af..44a3dea 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -582,9 +582,7 @@
auto shapedTypeCpp = llvm::cast<ShapedType>(unwrap(shapedType));
ArrayRef<char> rawBufferCpp(static_cast<const char *>(rawBuffer),
rawBufferSize);
- bool isSplat = false;
- if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp,
- isSplat))
+ if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp))
return mlirAttributeGetNull();
return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp));
}
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index fa8788d..299c8af 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -262,8 +262,7 @@
// Check that the buffer meets the requirements to get converted to a
// DenseElementsAttr
- bool detectedSplat = false;
- if (!DenseElementsAttr::isValidRawBuffer(srcType, ptr, detectedSplat))
+ if (!DenseElementsAttr::isValidRawBuffer(srcType, ptr))
return constOp->emitError("resource is not a valid buffer");
dstElementsAttr =
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index 5786f53..f01a87a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -185,9 +185,8 @@
return std::nullopt;
// Check that the data are in a valid form
- bool isSplat = false;
if (!DenseElementsAttr::isValidRawBuffer(attr.getShapedType(),
- blob->getData(), isSplat)) {
+ blob->getData())) {
return std::nullopt;
}
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index c0d8f7d..1f26860 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -43,23 +43,19 @@
/// An attribute representing a reference to a dense vector or tensor object.
struct DenseElementsAttributeStorage : public AttributeStorage {
public:
- DenseElementsAttributeStorage(ShapedType type, bool isSplat)
- : type(type), isSplat(isSplat) {}
+ DenseElementsAttributeStorage(ShapedType type) : type(type) {}
ShapedType type;
- bool isSplat;
};
/// An attribute representing a reference to a dense vector or tensor object.
struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
- DenseIntOrFPElementsAttrStorage(ShapedType ty, ArrayRef<char> data,
- bool isSplat = false)
- : DenseElementsAttributeStorage(ty, isSplat), data(data) {}
+ DenseIntOrFPElementsAttrStorage(ShapedType ty, ArrayRef<char> data)
+ : DenseElementsAttributeStorage(ty), data(data) {}
struct KeyTy {
- KeyTy(ShapedType type, ArrayRef<char> data, llvm::hash_code hashCode,
- bool isSplat = false)
- : type(type), data(data), hashCode(hashCode), isSplat(isSplat) {}
+ KeyTy(ShapedType type, ArrayRef<char> data, llvm::hash_code hashCode)
+ : type(type), data(data), hashCode(hashCode) {}
/// The type of the dense elements.
ShapedType type;
@@ -69,9 +65,6 @@
/// The computed hash code for the storage data.
llvm::hash_code hashCode;
-
- /// A boolean that indicates if this data is a splat or not.
- bool isSplat;
};
/// Compare this storage instance with the provided key.
@@ -79,31 +72,22 @@
return key.type == type && key.data == data;
}
- /// Construct a key from a shaped type, raw data buffer, and a flag that
- /// signals if the data is already known to be a splat. Callers to this
- /// function are expected to tag preknown splat values when possible, e.g. one
- /// element shapes.
- static KeyTy getKey(ShapedType ty, ArrayRef<char> data, bool isKnownSplat) {
+ /// Construct a key from a shaped type and raw data buffer.
+ static KeyTy getKey(ShapedType ty, ArrayRef<char> data) {
// Handle an empty storage instance.
if (data.empty())
return KeyTy(ty, data, 0);
- // If the data is already known to be a splat, the key hash value is
- // directly the data buffer.
- if (isKnownSplat) {
- return KeyTy(ty, data, llvm::hash_value(data), isKnownSplat);
- }
-
- // Otherwise, we need to check if the data corresponds to a splat or not.
-
- // Handle the simple case of only one element.
- size_t numElements = ty.getNumElements();
- assert(numElements != 1 && "splat of 1 element should already be detected");
-
size_t elementWidth = getDenseElementBitWidth(ty.getElementType());
// Dense elements are padded to 8-bits.
size_t storageSize = llvm::divideCeil(elementWidth, CHAR_BIT);
- assert(((data.size() / storageSize) == numElements) &&
+
+ // If the data buffer holds a single element, it is a known splat.
+ if (data.size() == storageSize)
+ return KeyTy(ty, data, llvm::hash_value(data));
+
+ assert(((data.size() / storageSize) ==
+ static_cast<size_t>(ty.getNumElements())) &&
"data does not hold expected number of elements");
// Create the initial hash value with just the first element.
@@ -117,7 +101,7 @@
return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));
// Otherwise, this is a splat so just return the hash of the first element.
- return KeyTy(ty, firstElt, hashVal, /*isSplat=*/true);
+ return KeyTy(ty, firstElt, hashVal);
}
/// Hash the key for the storage.
@@ -139,7 +123,7 @@
}
return new (allocator.allocate<DenseIntOrFPElementsAttrStorage>())
- DenseIntOrFPElementsAttrStorage(key.type, copy, key.isSplat);
+ DenseIntOrFPElementsAttrStorage(key.type, copy);
}
ArrayRef<char> data;
@@ -148,14 +132,12 @@
/// An attribute representing a reference to a dense vector or tensor object
/// containing strings.
struct DenseStringElementsAttrStorage : public DenseElementsAttributeStorage {
- DenseStringElementsAttrStorage(ShapedType ty, ArrayRef<StringRef> data,
- bool isSplat = false)
- : DenseElementsAttributeStorage(ty, isSplat), data(data) {}
+ DenseStringElementsAttrStorage(ShapedType ty, ArrayRef<StringRef> data)
+ : DenseElementsAttributeStorage(ty), data(data) {}
struct KeyTy {
- KeyTy(ShapedType type, ArrayRef<StringRef> data, llvm::hash_code hashCode,
- bool isSplat = false)
- : type(type), data(data), hashCode(hashCode), isSplat(isSplat) {}
+ KeyTy(ShapedType type, ArrayRef<StringRef> data, llvm::hash_code hashCode)
+ : type(type), data(data), hashCode(hashCode) {}
/// The type of the dense elements.
ShapedType type;
@@ -165,9 +147,6 @@
/// The computed hash code for the storage data.
llvm::hash_code hashCode;
-
- /// A boolean that indicates if this data is a splat or not.
- bool isSplat;
};
/// Compare this storage instance with the provided key.
@@ -180,24 +159,15 @@
return key.data == data;
}
- /// Construct a key from a shaped type, StringRef data buffer, and a flag that
- /// signals if the data is already known to be a splat. Callers to this
- /// function are expected to tag preknown splat values when possible, e.g. one
- /// element shapes.
- static KeyTy getKey(ShapedType ty, ArrayRef<StringRef> data,
- bool isKnownSplat) {
+ /// Construct a key from a shaped type and StringRef data buffer.
+ static KeyTy getKey(ShapedType ty, ArrayRef<StringRef> data) {
// Handle an empty storage instance.
if (data.empty())
return KeyTy(ty, data, 0);
- // If the data is already known to be a splat, the key hash value is
- // directly the data buffer.
- if (isKnownSplat)
- return KeyTy(ty, data, llvm::hash_value(data.front()), isKnownSplat);
-
- // Handle the simple case of only one element.
- assert(ty.getNumElements() != 1 &&
- "splat of 1 element should already be detected");
+ // If the data buffer holds a single element, it is a known splat.
+ if (data.size() == 1)
+ return KeyTy(ty, data, llvm::hash_value(data.front()));
// Create the initial hash value with just the first element.
const auto &firstElt = data.front();
@@ -205,12 +175,12 @@
// Check to see if this storage represents a splat. If it doesn't then
// combine the hash for the data starting with the first non splat element.
- for (size_t i = 1, e = data.size(); i != e; i++)
+ for (size_t i = 1, e = data.size(); i != e; ++i)
if (firstElt != data[i])
return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));
// Otherwise, this is a splat so just return the hash of the first element.
- return KeyTy(ty, data.take_front(), hashVal, /*isSplat=*/true);
+ return KeyTy(ty, data.take_front(), hashVal);
}
/// Hash the key for the storage.
@@ -226,15 +196,15 @@
ArrayRef<StringRef> copy, data = key.data;
if (data.empty()) {
return new (allocator.allocate<DenseStringElementsAttrStorage>())
- DenseStringElementsAttrStorage(key.type, copy, key.isSplat);
+ DenseStringElementsAttrStorage(key.type, copy);
}
- int numEntries = key.isSplat ? 1 : data.size();
+ size_t numEntries = data.size();
// Compute the amount data needed to store the ArrayRef and StringRef
// contents.
size_t dataSize = sizeof(StringRef) * numEntries;
- for (int i = 0; i < numEntries; i++)
+ for (size_t i = 0; i < numEntries; ++i)
dataSize += data[i].size();
char *rawData = reinterpret_cast<char *>(
@@ -246,7 +216,7 @@
reinterpret_cast<StringRef *>(rawData), numEntries);
auto *stringData = rawData + numEntries * sizeof(StringRef);
- for (int i = 0; i < numEntries; i++) {
+ for (size_t i = 0; i < numEntries; ++i) {
memcpy(stringData, data[i].data(), data[i].size());
mutableCopy[i] = StringRef(stringData, data[i].size());
stringData += data[i].size();
@@ -256,7 +226,7 @@
ArrayRef<StringRef>(reinterpret_cast<StringRef *>(rawData), numEntries);
return new (allocator.allocate<DenseStringElementsAttrStorage>())
- DenseStringElementsAttrStorage(key.type, copy, key.isSplat);
+ DenseStringElementsAttrStorage(key.type, copy);
}
ArrayRef<StringRef> data;
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 241165a..fbbd9d2 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -1030,24 +1030,15 @@
/// Returns true if the given buffer is a valid raw buffer for the given type.
bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
- ArrayRef<char> rawBuffer,
- bool &detectedSplat) {
+ ArrayRef<char> rawBuffer) {
size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
int64_t numElements = type.getNumElements();
- // The initializer is always a splat if the result type has a single element.
- detectedSplat = numElements == 1;
-
- // All types are 8-bit aligned, so we can just check the buffer width
- // to know if only a single initializer element was passed in.
- if (rawBufferWidth == storageWidth) {
- detectedSplat = true;
- return true;
- }
-
- // The raw buffer is valid if it has the right size.
- return rawBufferWidth == storageWidth * numElements;
+ // The raw buffer is valid if it has a single element (splat) or the right
+ // number of elements.
+ return rawBufferWidth == storageWidth ||
+ rawBufferWidth == storageWidth * numElements;
}
/// Check the information for a C++ data type, check if this type is valid for
@@ -1121,7 +1112,13 @@
/// Returns true if this attribute corresponds to a splat, i.e. if all element
/// values are the same.
bool DenseElementsAttr::isSplat() const {
- return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
+ // Splat iff the data array has exactly one element.
+ if (isa<DenseStringElementsAttr>(*this))
+ return getRawStringData().size() == 1;
+ // FP/Int case.
+ size_t storageSize = llvm::divideCeil(
+ getDenseElementBitWidth(getType().getElementType()), CHAR_BIT);
+ return getRawData().size() == storageSize;
}
/// Return if the given complex type has an integer element type.
@@ -1286,11 +1283,8 @@
DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
ArrayRef<char> data) {
assert(type.hasStaticShape() && "type must have static shape");
- bool isSplat = false;
- bool isValid = isValidRawBuffer(type, data, isSplat);
- assert(isValid);
- (void)isValid;
- return Base::get(type.getContext(), type, data, isSplat);
+ assert(isValidRawBuffer(type, data));
+ return Base::get(type.getContext(), type, data);
}
/// Overload of the raw 'get' method that asserts that the given type is of