| //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/DialectImplementation.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| |
| using namespace mlir; |
| using namespace mlir::sparse_tensor; |
| |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc" |
| |
| //===----------------------------------------------------------------------===// |
// SparseTensorDialect Attribute Methods.
| //===----------------------------------------------------------------------===// |
| |
| #define GET_ATTRDEF_CLASSES |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" |
| |
/// Returns true when `bitWidth` is a storage width this dialect supports for
/// pointers and indices: 0 (which the op verifiers below map to the `index`
/// type) or one of the fixed integer widths 8, 16, 32 and 64.
static bool acceptBitWidth(unsigned bitWidth) {
  return bitWidth == 0 || bitWidth == 8 || bitWidth == 16 || bitWidth == 32 ||
         bitWidth == 64;
}
| |
| Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { |
| if (failed(parser.parseLess())) |
| return {}; |
| // Parse the data as a dictionary. |
| DictionaryAttr dict; |
| if (failed(parser.parseAttribute(dict))) |
| return {}; |
| if (failed(parser.parseGreater())) |
| return {}; |
| // Process the data from the parsed dictionary value into struct-like data. |
| SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt; |
| AffineMap map = {}; |
| unsigned ptr = 0; |
| unsigned ind = 0; |
| for (const NamedAttribute &attr : dict) { |
| if (attr.getName() == "dimLevelType") { |
| auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>(); |
| if (!arrayAttr) { |
| parser.emitError(parser.getNameLoc(), |
| "expected an array for dimension level types"); |
| return {}; |
| } |
| for (unsigned i = 0, e = arrayAttr.size(); i < e; i++) { |
| auto strAttr = arrayAttr[i].dyn_cast<StringAttr>(); |
| if (!strAttr) { |
| parser.emitError(parser.getNameLoc(), |
| "expected a string value in dimension level types"); |
| return {}; |
| } |
| auto strVal = strAttr.getValue(); |
| if (strVal == "dense") { |
| dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense); |
| } else if (strVal == "compressed") { |
| dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed); |
| } else if (strVal == "singleton") { |
| dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton); |
| } else { |
| parser.emitError(parser.getNameLoc(), |
| "unexpected dimension level type: ") |
| << strVal; |
| return {}; |
| } |
| } |
| } else if (attr.getName() == "dimOrdering") { |
| auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>(); |
| if (!affineAttr) { |
| parser.emitError(parser.getNameLoc(), |
| "expected an affine map for dimension ordering"); |
| return {}; |
| } |
| map = affineAttr.getValue(); |
| } else if (attr.getName() == "pointerBitWidth") { |
| auto intAttr = attr.getValue().dyn_cast<IntegerAttr>(); |
| if (!intAttr) { |
| parser.emitError(parser.getNameLoc(), |
| "expected an integral pointer bitwidth"); |
| return {}; |
| } |
| ptr = intAttr.getInt(); |
| } else if (attr.getName() == "indexBitWidth") { |
| auto intAttr = attr.getValue().dyn_cast<IntegerAttr>(); |
| if (!intAttr) { |
| parser.emitError(parser.getNameLoc(), |
| "expected an integral index bitwidth"); |
| return {}; |
| } |
| ind = intAttr.getInt(); |
| } else { |
| parser.emitError(parser.getNameLoc(), "unexpected key: ") |
| << attr.getName().strref(); |
| return {}; |
| } |
| } |
| // Construct struct-like storage for attribute. |
| return parser.getChecked<SparseTensorEncodingAttr>(parser.getContext(), dlt, |
| map, ptr, ind); |
| } |
| |
| void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { |
| // Print the struct-like storage in dictionary fashion. |
| printer << "<{ dimLevelType = [ "; |
| for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) { |
| switch (getDimLevelType()[i]) { |
| case DimLevelType::Dense: |
| printer << "\"dense\""; |
| break; |
| case DimLevelType::Compressed: |
| printer << "\"compressed\""; |
| break; |
| case DimLevelType::Singleton: |
| printer << "\"singleton\""; |
| break; |
| } |
| if (i != e - 1) |
| printer << ", "; |
| } |
| printer << " ]"; |
| if (getDimOrdering()) |
| printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">"; |
| printer << ", pointerBitWidth = " << getPointerBitWidth() |
| << ", indexBitWidth = " << getIndexBitWidth() << " }>"; |
| } |
| |
| LogicalResult SparseTensorEncodingAttr::verify( |
| function_ref<InFlightDiagnostic()> emitError, |
| ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering, |
| unsigned pointerBitWidth, unsigned indexBitWidth) { |
| if (!acceptBitWidth(pointerBitWidth)) |
| return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth; |
| if (!acceptBitWidth(indexBitWidth)) |
| return emitError() << "unexpected index bitwidth: " << indexBitWidth; |
| if (dimOrdering) { |
| if (!dimOrdering.isPermutation()) |
| return emitError() |
| << "expected a permutation affine map for dimension ordering"; |
| if (dimOrdering.getNumResults() != dimLevelType.size()) |
| return emitError() << "unexpected mismatch in ordering and dimension " |
| "level types size"; |
| } |
| return success(); |
| } |
| |
| LogicalResult SparseTensorEncodingAttr::verifyEncoding( |
| ArrayRef<int64_t> shape, Type elementType, |
| function_ref<InFlightDiagnostic()> emitError) const { |
| // Check structural integrity. |
| if (failed(verify(emitError, getDimLevelType(), getDimOrdering(), |
| getPointerBitWidth(), getIndexBitWidth()))) |
| return failure(); |
| // Check integrity with tensor type specifics. Dimension ordering is optional, |
| // but we always should have dimension level types for the full rank. |
| unsigned size = shape.size(); |
| if (size == 0) |
| return emitError() << "expected non-scalar sparse tensor"; |
| if (getDimOrdering() && getDimOrdering().getNumResults() != size) |
| return emitError() << "expected an affine map of size " << size |
| << " for dimension ordering"; |
| if (getDimLevelType().size() != size) |
| return emitError() << "expected an array of size " << size |
| << " for dimension level types"; |
| return success(); |
| } |
| |
| SparseTensorEncodingAttr |
| mlir::sparse_tensor::getSparseTensorEncoding(Type type) { |
| if (auto ttp = type.dyn_cast<RankedTensorType>()) |
| return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>(); |
| return nullptr; |
| } |
| |
| //===----------------------------------------------------------------------===// |
// SparseTensorDialect Operations.
| //===----------------------------------------------------------------------===// |
| |
| static LogicalResult isInBounds(Value dim, Value tensor) { |
| if (auto constantOp = dim.getDefiningOp<arith::ConstantOp>()) { |
| unsigned d = constantOp.getValue().cast<IntegerAttr>().getInt(); |
| if (d >= tensor.getType().cast<RankedTensorType>().getRank()) |
| return failure(); |
| } |
| return success(); // in bounds, or symbolic |
| } |
| |
| static LogicalResult isMatchingWidth(Value result, unsigned width) { |
| Type etp = result.getType().cast<MemRefType>().getElementType(); |
| if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width))) |
| return success(); |
| return failure(); |
| } |
| |
| static LogicalResult verify(NewOp op) { |
| if (!getSparseTensorEncoding(op.result().getType())) |
| return op.emitError("expected a sparse tensor result"); |
| return success(); |
| } |
| |
| static LogicalResult verify(InitOp op) { |
| if (!getSparseTensorEncoding(op.result().getType())) |
| return op.emitError("expected a sparse tensor result"); |
| RankedTensorType ttp = op.getType().cast<RankedTensorType>(); |
| unsigned rank = ttp.getRank(); |
| if (rank != op.sizes().size()) |
| return op.emitError("unexpected mismatch between tensor rank and sizes: ") |
| << rank << " vs. " << op.sizes().size(); |
| auto shape = ttp.getShape(); |
| for (unsigned i = 0; i < rank; i++) { |
| if (shape[i] == ShapedType::kDynamicSize) |
| continue; |
| auto constantOp = op.sizes()[i].getDefiningOp<arith::ConstantOp>(); |
| if (!constantOp || |
| constantOp.getValue().cast<IntegerAttr>().getInt() != shape[i]) |
| return op.emitError("unexpected mismatch with static dimension size ") |
| << shape[i]; |
| } |
| return success(); |
| } |
| |
| static LogicalResult verify(ConvertOp op) { |
| if (auto tp1 = op.source().getType().dyn_cast<RankedTensorType>()) { |
| if (auto tp2 = op.dest().getType().dyn_cast<RankedTensorType>()) { |
| if (tp1.getRank() != tp2.getRank()) |
| return op.emitError("unexpected conversion mismatch in rank"); |
| auto shape1 = tp1.getShape(); |
| auto shape2 = tp2.getShape(); |
| // Accept size matches between the source and the destination type |
| // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or |
| // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10). |
| for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) { |
| if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize) |
| return op.emitError("unexpected conversion mismatch in dimension ") |
| << d; |
| } |
| return success(); |
| } |
| } |
| return op.emitError("unexpected type in convert"); |
| } |
| |
| OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) { |
| if (getType() == source().getType()) |
| return source(); |
| return {}; |
| } |
| |
| static LogicalResult verify(ToPointersOp op) { |
| if (auto e = getSparseTensorEncoding(op.tensor().getType())) { |
| if (failed(isInBounds(op.dim(), op.tensor()))) |
| return op.emitError("requested pointers dimension out of bounds"); |
| if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth()))) |
| return op.emitError("unexpected type for pointers"); |
| return success(); |
| } |
| return op.emitError("expected a sparse tensor to get pointers"); |
| } |
| |
| static LogicalResult verify(ToIndicesOp op) { |
| if (auto e = getSparseTensorEncoding(op.tensor().getType())) { |
| if (failed(isInBounds(op.dim(), op.tensor()))) |
| return op.emitError("requested indices dimension out of bounds"); |
| if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth()))) |
| return op.emitError("unexpected type for indices"); |
| return success(); |
| } |
| return op.emitError("expected a sparse tensor to get indices"); |
| } |
| |
| static LogicalResult verify(ToValuesOp op) { |
| if (!getSparseTensorEncoding(op.tensor().getType())) |
| return op.emitError("expected a sparse tensor to get values"); |
| RankedTensorType ttp = op.tensor().getType().cast<RankedTensorType>(); |
| MemRefType mtp = op.result().getType().cast<MemRefType>(); |
| if (ttp.getElementType() != mtp.getElementType()) |
| return op.emitError("unexpected mismatch in element types"); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
// SparseTensorDialect Management Operations.
| //===----------------------------------------------------------------------===// |
| |
| static LogicalResult verify(LexInsertOp op) { |
| if (!getSparseTensorEncoding(op.tensor().getType())) |
| return op.emitError("expected a sparse tensor for insertion"); |
| return success(); |
| } |
| |
| static LogicalResult verify(LoadOp op) { |
| if (!getSparseTensorEncoding(op.tensor().getType())) |
| return op.emitError("expected a sparse tensor to materialize"); |
| return success(); |
| } |
| |
| static LogicalResult verify(ReleaseOp op) { |
| if (!getSparseTensorEncoding(op.tensor().getType())) |
| return op.emitError("expected a sparse tensor to release"); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
// SparseTensorDialect Methods.
| //===----------------------------------------------------------------------===// |
| |
/// Registers the dialect's attributes, then its operations. The lists of
/// classes are spliced into the template argument lists from the included
/// .inc files, selected by the GET_ATTRDEF_LIST / GET_OP_LIST macros.
void SparseTensorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
}
| |
| #define GET_OP_CLASSES |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" |
| |
| Attribute SparseTensorDialect::parseAttribute(DialectAsmParser &parser, |
| Type type) const { |
| StringRef attrTag; |
| if (failed(parser.parseKeyword(&attrTag))) |
| return Attribute(); |
| Attribute attr; |
| auto parseResult = generatedAttributeParser(parser, attrTag, type, attr); |
| if (parseResult.hasValue()) |
| return attr; |
| parser.emitError(parser.getNameLoc(), "unknown sparse tensor attribute"); |
| return Attribute(); |
| } |
| |
| void SparseTensorDialect::printAttribute(Attribute attr, |
| DialectAsmPrinter &printer) const { |
| if (succeeded(generatedAttributePrinter(attr, printer))) |
| return; |
| } |