blob: e52d0e17cda22b98563831653e6c1b7d7a9095ec [file] [log] [blame]
//===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the definitions of the infer op interfaces defined in
// `InferTypeOpInterface.td`.
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "llvm/Support/FormatVariadic.h"
using namespace mlir;
namespace mlir {
#include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
} // namespace mlir
/// Reifies the result shapes of `op` via its
/// ReifyRankedShapedTypeOpInterface implementation, populating
/// `reifiedReturnShapes` with one OpFoldResult per dimension of each shaped
/// result. Fails if `op` does not implement the interface or if the
/// implementation itself fails.
LogicalResult
mlir::reifyResultShapes(OpBuilder &b, Operation *op,
                        ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto reifiableOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
  if (!reifiableOp)
    return failure();
  LogicalResult status = reifiableOp.reifyResultShapes(b, reifiedReturnShapes);
#ifndef NDEBUG
  // Debug-only sanity checking of the interface implementation's output.
  if (failed(status))
    return failure();
  // Assert that ReifyRankedShapedTypeOpInterface::reifyResultShapes produced
  // a correct result.
  // Note: `resultIdx` indexes into `reifiedReturnShapes` and is only advanced
  // for shaped results; non-shaped results are skipped entirely.
  int64_t resultIdx = 0;
  for (OpResult result : op->getResults()) {
    auto shapedType = dyn_cast<ShapedType>(result.getType());
    if (!shapedType)
      continue;
    if (!shapedType.hasRank()) {
      // Nothing to check for unranked shaped values.
      ++resultIdx;
      continue;
    }
    // Assert one OpFoldResult per dimension.
    assert(shapedType.getRank() ==
               static_cast<int64_t>(reifiedReturnShapes[resultIdx].size()) &&
           "incorrect implementation of ReifyRankedShapedTypeOpInterface");
    for (int64_t dim = 0; dim < shapedType.getRank(); ++dim) {
      // reifyResultShapes must return:
      // * Attribute for static dimensions
      // * Value for dynamic dimensions
      assert(shapedType.isDynamicDim(dim) ==
                 reifiedReturnShapes[resultIdx][dim].is<Value>() &&
             "incorrect implementation of ReifyRankedShapedTypeOpInterface");
    }
    ++resultIdx;
  }
  // Assert that every shaped value result was reified.
  assert(resultIdx == static_cast<int64_t>(reifiedReturnShapes.size()) &&
         "incorrect implementation of ReifyRankedShapedTypeOpInterface");
#endif // NDEBUG
  return status;
}
bool ShapeAdaptor::hasRank() const {
if (val.isNull())
return false;
if (auto t = llvm::dyn_cast_if_present<Type>(val))
return cast<ShapedType>(t).hasRank();
if (val.is<Attribute>())
return true;
return val.get<ShapedTypeComponents *>()->hasRank();
}
/// Returns the element type of the wrapped shape, or null when the adaptor is
/// empty or holds a bare shape attribute (which has no element type).
Type ShapeAdaptor::getElementType() const {
  if (val.isNull())
    return nullptr;
  // Constant shape attributes describe only dimensions, never an element type.
  if (val.is<Attribute>())
    return nullptr;
  if (auto type = llvm::dyn_cast_if_present<Type>(val))
    return cast<ShapedType>(type).getElementType();
  return val.get<ShapedTypeComponents *>()->getElementType();
}
/// Copies the dimension list of the wrapped shape into `res`, replacing any
/// previous contents. Requires a ranked adaptor.
void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
  assert(hasRank());
  res.clear();
  if (auto type = llvm::dyn_cast_if_present<Type>(val)) {
    ArrayRef<int64_t> shape = cast<ShapedType>(type).getShape();
    res.append(shape.begin(), shape.end());
    return;
  }
  if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
    // Decode each element of the constant shape attribute as a signed dim.
    auto elements = cast<DenseIntElementsAttr>(attr);
    res.reserve(elements.size());
    for (const APInt &dim : elements.getValues<APInt>())
      res.push_back(dim.getSExtValue());
    return;
  }
  ArrayRef<int64_t> dims = val.get<ShapedTypeComponents *>()->getDims();
  res.append(dims.begin(), dims.end());
}
/// Populates `res` with this adaptor's dimensions and marks it as ranked.
/// Requires a ranked adaptor.
void ShapeAdaptor::getDims(ShapedTypeComponents &res) const {
  assert(hasRank());
  getDims(res.dims);
  res.ranked = true;
}
/// Returns the size of the dimension at `index`. Requires a ranked adaptor;
/// no bounds checking is performed on `index`.
int64_t ShapeAdaptor::getDimSize(int index) const {
  assert(hasRank());
  if (auto type = llvm::dyn_cast_if_present<Type>(val))
    return cast<ShapedType>(type).getDimSize(index);
  if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
    // Decode the indexed element of the constant shape as a signed value.
    auto elements = cast<DenseIntElementsAttr>(attr);
    return elements.getValues<APInt>()[index].getSExtValue();
  }
  return val.get<ShapedTypeComponents *>()->getDims()[index];
}
/// Returns the rank (number of dimensions) of the wrapped shape. Requires a
/// ranked adaptor.
int64_t ShapeAdaptor::getRank() const {
  assert(hasRank());
  if (auto type = llvm::dyn_cast_if_present<Type>(val))
    return cast<ShapedType>(type).getRank();
  // A constant shape attribute holds one element per dimension.
  if (auto attr = llvm::dyn_cast_if_present<Attribute>(val))
    return cast<DenseIntElementsAttr>(attr).size();
  return static_cast<int64_t>(
      val.get<ShapedTypeComponents *>()->getDims().size());
}
bool ShapeAdaptor::hasStaticShape() const {
if (!hasRank())
return false;
if (auto t = llvm::dyn_cast_if_present<Type>(val))
return cast<ShapedType>(t).hasStaticShape();
if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
auto dattr = cast<DenseIntElementsAttr>(attr);
for (auto index : dattr.getValues<APInt>())
if (ShapedType::isDynamic(index.getSExtValue()))
return false;
return true;
}
auto *stc = val.get<ShapedTypeComponents *>();
return llvm::none_of(stc->getDims(), ShapedType::isDynamic);
}
/// Returns the total element count (product of all dimensions) of the wrapped
/// shape. Requires a fully static shape.
int64_t ShapeAdaptor::getNumElements() const {
  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
  if (auto type = llvm::dyn_cast_if_present<Type>(val))
    return cast<ShapedType>(type).getNumElements();
  if (auto attr = llvm::dyn_cast_if_present<Attribute>(val)) {
    // Multiply the constant shape's extents together, guarding (in debug
    // builds) against signed overflow wrapping to a negative count.
    int64_t count = 1;
    auto elements = cast<DenseIntElementsAttr>(attr);
    for (const APInt &dim : elements.getValues<APInt>()) {
      count *= dim.getZExtValue();
      assert(count >= 0 && "integer overflow in element count computation");
    }
    return count;
  }
  int64_t count = 1;
  for (int64_t dim : val.get<ShapedTypeComponents *>()->getDims()) {
    count *= dim;
    assert(count >= 0 && "integer overflow in element count computation");
  }
  return count;
}
/// Prints a human-readable description of the shape to llvm::errs(), e.g.
/// "rank = 2 dims = [3x?]" or "<<unranked>>" for unranked shapes.
void ShapeAdaptor::dump() const {
  if (!hasRank()) {
    llvm::errs() << "<<unranked>>\n";
    return;
  }
  SmallVector<int64_t> dims;
  getDims(dims);
  // Render dynamic extents as "?" and static ones as their decimal value.
  auto renderDim = [](int64_t dim) -> std::string {
    return ShapedType::isDynamic(dim) ? std::string("?")
                                      : llvm::formatv("{0}", dim).str();
  };
  llvm::errs() << "rank = " << getRank() << " dims = [";
  llvm::interleave(llvm::map_range(dims, renderDim), llvm::errs(), "x");
  llvm::errs() << "]\n";
}
/// Returns the shape that the value at `index` *denotes* (as data), rather
/// than the shape of the value itself. Prefers the user-provided mapping
/// function, then falls back to matching a 1-D constant integer shape;
/// returns a null adaptor when neither applies.
ShapeAdaptor ValueShapeRange::getValueAsShape(int index) {
  Value operand = operator[](index);
  if (valueToShape)
    if (ShapeAdaptor mapped = valueToShape(operand))
      return mapped;
  // Fall back to a constant integer tensor; only rank-1 constants are
  // meaningful as shapes.
  DenseIntElementsAttr shapeAttr;
  if (!matchPattern(operand, m_Constant(&shapeAttr)) ||
      shapeAttr.getType().getRank() != 1)
    return nullptr;
  return shapeAttr;
}
/// Returns the shape adaptor for `val`, consulting the operand-shape override
/// hook first and defaulting to the value's static type.
ShapeAdaptor ValueShapeRange::getShape(Value val) const {
  if (operandShape)
    if (ShapeAdaptor mapped = operandShape(val))
      return mapped;
  return val.getType();
}
/// Returns the shape adaptor for the operand at `index`, or a null adaptor
/// when `index` is out of range.
ShapeAdaptor ValueShapeRange::getShape(int index) const {
  if (index >= 0 && static_cast<size_t>(index) < size())
    return getShape(operator[](index));
  return nullptr;
}
/// Builds one tensor type per entry of `retComponents` and appends it to
/// `inferredReturnTypes`: a RankedTensorType (with the optional encoding
/// attribute) for ranked components, an UnrankedTensorType otherwise.
LogicalResult mlir::detail::inferReturnTensorTypes(
    ArrayRef<ShapedTypeComponents> retComponents,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  for (const ShapedTypeComponents &components : retComponents) {
    Type elementTy = components.getElementType();
    assert(elementTy && "element type required to construct tensor");
    Attribute attr = components.getAttribute();
    if (!components.hasRank()) {
      // Unranked tensors cannot carry an encoding attribute.
      assert(attr == nullptr && "attribute not supported");
      inferredReturnTypes.push_back(UnrankedTensorType::get(elementTy));
      continue;
    }
    inferredReturnTypes.push_back(
        RankedTensorType::get(components.getDims(), elementTy, attr));
  }
  return success();
}
/// Verifies that `op`'s declared result types are compatible with what its
/// InferTypeOpInterface implementation infers, emitting an op error on
/// failure.
LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
  // Seed refinement with the currently-declared result types.
  SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes());
  auto inferInterface = cast<InferTypeOpInterface>(op);
  LogicalResult result = inferInterface.refineReturnTypes(
      op->getContext(), op->getLoc(), op->getOperands(),
      op->getRawDictionaryAttrs(), op->getPropertiesStorage(), op->getRegions(),
      inferredReturnTypes);
  if (failed(result))
    op->emitOpError() << "failed to infer returned types";
  return result;
}