| //===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file contains the definitions of the infer op interfaces defined in |
| // `InferTypeOpInterface.td`. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Interfaces/InferTypeOpInterface.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/Matchers.h" |
| #include "llvm/Support/FormatVariadic.h" |
| |
| using namespace mlir; |
| |
| namespace mlir { |
| #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc" |
| } // namespace mlir |
| |
| LogicalResult |
| mlir::reifyResultShapes(OpBuilder &b, Operation *op, |
| ReifiedRankedShapedTypeDims &reifiedReturnShapes) { |
| auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op); |
| if (!reifiableOp) |
| return failure(); |
| LogicalResult status = reifiableOp.reifyResultShapes(b, reifiedReturnShapes); |
| #ifndef NDEBUG |
| if (failed(status)) |
| return failure(); |
| // Assert that ReifyRankedShapedTypeOpInterface::reifyResultShapes produced |
| // a correct result. |
| int64_t resultIdx = 0; |
| for (OpResult result : op->getResults()) { |
| auto shapedType = dyn_cast<ShapedType>(result.getType()); |
| if (!shapedType) |
| continue; |
| if (!shapedType.hasRank()) { |
| // Nothing to check for unranked shaped values. |
| ++resultIdx; |
| continue; |
| } |
| // Assert one OpFoldResult per dimension. |
| assert(shapedType.getRank() == |
| static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) && |
| "incorrect implementation of ReifyRankedShapedTypeOpInterface"); |
| for (int64_t dim = 0; dim < shapedType.getRank(); ++dim) { |
| // reifyResultShapes must return: |
| // * Attribute for static dimensions |
| // * Value for dynamic dimensions |
| assert(shapedType.isDynamicDim(dim) == |
| reifiedReturnShapes[resultIdx][dim].is<Value>() && |
| "incorrect implementation of ReifyRankedShapedTypeOpInterface"); |
| } |
| ++resultIdx; |
| } |
| // Assert that every shaped value result was reified. |
| assert(resultIdx == static_cast<int64_t>(reifiedReturnShapes.size()) && |
| "incorrect implementation of ReifyRankedShapedTypeOpInterface"); |
| #endif // NDEBUG |
| return status; |
| } |
| |
| bool ShapeAdaptor::hasRank() const { |
| if (val.isNull()) |
| return false; |
| if (auto t = llvm::dyn_cast_if_present<Type>(val)) |
| return cast<ShapedType>(t).hasRank(); |
| if (val.is<Attribute>()) |
| return true; |
| return val.get<ShapedTypeComponents *>()->hasRank(); |
| } |
| |
| Type ShapeAdaptor::getElementType() const { |
| if (val.isNull()) |
| return nullptr; |
| if (auto t = llvm::dyn_cast_if_present<Type>(val)) |
| return cast<ShapedType>(t).getElementType(); |
| if (val.is<Attribute>()) |
| return nullptr; |
| return val.get<ShapedTypeComponents *>()->getElementType(); |
| } |
| |
| void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const { |
| assert(hasRank()); |
| if (auto t = llvm::dyn_cast_if_present<Type>(val)) { |
| ArrayRef<int64_t> vals = cast<ShapedType>(t).getShape(); |
| res.assign(vals.begin(), vals.end()); |
| } else if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) { |
| auto dattr = cast<DenseIntElementsAttr>(attr); |
| res.clear(); |
| res.reserve(dattr.size()); |
| for (auto it : dattr.getValues<APInt>()) |
| res.push_back(it.getSExtValue()); |
| } else { |
| auto vals = val.get<ShapedTypeComponents *>()->getDims(); |
| res.assign(vals.begin(), vals.end()); |
| } |
| } |
| |
| void ShapeAdaptor::getDims(ShapedTypeComponents &res) const { |
| assert(hasRank()); |
| res.ranked = true; |
| getDims(res.dims); |
| } |
| |
| int64_t ShapeAdaptor::getDimSize(int index) const { |
| assert(hasRank()); |
| if (auto t = llvm::dyn_cast_if_present<Type>(val)) |
| return cast<ShapedType>(t).getDimSize(index); |
| if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) |
| return cast<DenseIntElementsAttr>(attr) |
| .getValues<APInt>()[index] |
| .getSExtValue(); |
| auto *stc = val.get<ShapedTypeComponents *>(); |
| return stc->getDims()[index]; |
| } |
| |
| int64_t ShapeAdaptor::getRank() const { |
| assert(hasRank()); |
| if (auto t = llvm::dyn_cast_if_present<Type>(val)) |
| return cast<ShapedType>(t).getRank(); |
| if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) |
| return cast<DenseIntElementsAttr>(attr).size(); |
| return val.get<ShapedTypeComponents *>()->getDims().size(); |
| } |
| |
| bool ShapeAdaptor::hasStaticShape() const { |
| if (!hasRank()) |
| return false; |
| |
| if (auto t = llvm::dyn_cast_if_present<Type>(val)) |
| return cast<ShapedType>(t).hasStaticShape(); |
| if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) { |
| auto dattr = cast<DenseIntElementsAttr>(attr); |
| for (auto index : dattr.getValues<APInt>()) |
| if (ShapedType::isDynamic(index.getSExtValue())) |
| return false; |
| return true; |
| } |
| auto *stc = val.get<ShapedTypeComponents *>(); |
| return llvm::none_of(stc->getDims(), ShapedType::isDynamic); |
| } |
| |
| int64_t ShapeAdaptor::getNumElements() const { |
| assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); |
| |
| if (auto t = llvm::dyn_cast_if_present<Type>(val)) |
| return cast<ShapedType>(t).getNumElements(); |
| |
| if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) { |
| auto dattr = cast<DenseIntElementsAttr>(attr); |
| int64_t num = 1; |
| for (auto index : dattr.getValues<APInt>()) { |
| num *= index.getZExtValue(); |
| assert(num >= 0 && "integer overflow in element count computation"); |
| } |
| return num; |
| } |
| |
| auto *stc = val.get<ShapedTypeComponents *>(); |
| int64_t num = 1; |
| for (int64_t dim : stc->getDims()) { |
| num *= dim; |
| assert(num >= 0 && "integer overflow in element count computation"); |
| } |
| return num; |
| } |
| |
| void ShapeAdaptor::dump() const { |
| if (!hasRank()) { |
| llvm::errs() << "<<unranked>>\n"; |
| return; |
| } |
| |
| SmallVector<int64_t> dims; |
| getDims(dims); |
| auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string { |
| if (ShapedType::isDynamic(dim)) |
| return "?"; |
| return llvm::formatv("{0}", dim).str(); |
| }); |
| llvm::errs() << "rank = " << getRank() << " dims = ["; |
| llvm::interleave(mapped, llvm::errs(), "x"); |
| llvm::errs() << "]\n"; |
| } |
| |
| ShapeAdaptor ValueShapeRange::getValueAsShape(int index) { |
| Value val = operator[](index); |
| if (valueToShape) |
| if (ShapeAdaptor ret = valueToShape(val)) |
| return ret; |
| |
| DenseIntElementsAttr attr; |
| if (!matchPattern(val, m_Constant(&attr))) |
| return nullptr; |
| if (attr.getType().getRank() != 1) |
| return nullptr; |
| return attr; |
| } |
| |
| ShapeAdaptor ValueShapeRange::getShape(Value val) const { |
| if (operandShape) |
| if (ShapeAdaptor ret = operandShape(val)) |
| return ret; |
| return val.getType(); |
| } |
| |
| ShapeAdaptor ValueShapeRange::getShape(int index) const { |
| if (index < 0 || static_cast<size_t>(index) >= size()) |
| return nullptr; |
| return getShape(operator[](index)); |
| } |
| |
| LogicalResult mlir::detail::inferReturnTensorTypes( |
| ArrayRef<ShapedTypeComponents> retComponents, |
| SmallVectorImpl<Type> &inferredReturnTypes) { |
| for (const auto &shapeAndType : retComponents) { |
| Type elementTy = shapeAndType.getElementType(); |
| assert(elementTy && "element type required to construct tensor"); |
| |
| Attribute attr = shapeAndType.getAttribute(); |
| if (shapeAndType.hasRank()) { |
| inferredReturnTypes.push_back( |
| RankedTensorType::get(shapeAndType.getDims(), elementTy, attr)); |
| } else { |
| assert(attr == nullptr && "attribute not supported"); |
| inferredReturnTypes.push_back(UnrankedTensorType::get(elementTy)); |
| } |
| } |
| return success(); |
| } |
| |
| LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) { |
| SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes()); |
| auto retTypeFn = cast<InferTypeOpInterface>(op); |
| auto result = retTypeFn.refineReturnTypes( |
| op->getContext(), op->getLoc(), op->getOperands(), |
| op->getRawDictionaryAttrs(), op->getPropertiesStorage(), op->getRegions(), |
| inferredReturnTypes); |
| if (failed(result)) |
| op->emitOpError() << "failed to infer returned types"; |
| |
| return result; |
| } |