blob: 67c9ccbaec5bed2a55174f23ea7a57f795ba6126 [file] [log] [blame]
//===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the definitions of the infer op interfaces defined in
// `InferTypeOpInterface.td`.
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "llvm/Support/FormatVariadic.h"
using namespace mlir;
namespace mlir {
#include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
} // namespace mlir
bool ShapeAdaptor::hasRank() const {
if (val.isNull())
return false;
if (auto t = val.dyn_cast<Type>())
return t.cast<ShapedType>().hasRank();
if (val.is<Attribute>())
return true;
return val.get<ShapedTypeComponents *>()->hasRank();
}
// Returns the element type of the wrapped shape, or a null Type when the
// adaptor is null or backed by an attribute.
Type ShapeAdaptor::getElementType() const {
  // Neither a null adaptor nor an attribute yields an element type here.
  if (val.isNull() || val.is<Attribute>())
    return nullptr;
  if (auto type = val.dyn_cast<Type>())
    return type.cast<ShapedType>().getElementType();
  auto *components = val.get<ShapedTypeComponents *>();
  return components->getElementType();
}
// Overwrites `res` with the dimensions of the wrapped shape. The adaptor must
// be ranked (asserted).
void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const {
  assert(hasRank());

  // Type-backed: copy the shape of the ShapedType.
  if (auto type = val.dyn_cast<Type>()) {
    ArrayRef<int64_t> shape = type.cast<ShapedType>().getShape();
    res.assign(shape.begin(), shape.end());
    return;
  }

  // Attribute-backed: each element of the dense integer attribute is one
  // dimension, read as a sign-extended value.
  if (auto attr = val.dyn_cast<Attribute>()) {
    auto elements = attr.cast<DenseIntElementsAttr>();
    res.clear();
    res.reserve(elements.size());
    for (auto dim : elements.getValues<APInt>())
      res.push_back(dim.getSExtValue());
    return;
  }

  // Otherwise the adaptor wraps a ShapedTypeComponents pointer.
  auto dims = val.get<ShapedTypeComponents *>()->getDims();
  res.assign(dims.begin(), dims.end());
}
// Populates `res` with the dimensions of the wrapped shape and marks it as
// ranked. The adaptor must be ranked (asserted).
void ShapeAdaptor::getDims(ShapedTypeComponents &res) const {
  assert(hasRank());
  getDims(res.dims);
  res.ranked = true;
}
// Returns the size of dimension `index` of the wrapped shape. The adaptor must
// be ranked (asserted); `index` is not range-checked here.
int64_t ShapeAdaptor::getDimSize(int index) const {
  assert(hasRank());
  if (auto type = val.dyn_cast<Type>())
    return type.cast<ShapedType>().getDimSize(index);
  if (auto attr = val.dyn_cast<Attribute>()) {
    // Attribute elements are read as sign-extended integers.
    auto elements = attr.cast<DenseIntElementsAttr>();
    auto dim = elements.getValues<APInt>()[index];
    return dim.getSExtValue();
  }
  return val.get<ShapedTypeComponents *>()->getDims()[index];
}
// Returns the number of dimensions of the wrapped shape. The adaptor must be
// ranked (asserted).
int64_t ShapeAdaptor::getRank() const {
  assert(hasRank());
  if (auto type = val.dyn_cast<Type>())
    return type.cast<ShapedType>().getRank();
  if (auto attr = val.dyn_cast<Attribute>()) {
    // For an attribute-backed shape the rank is the attribute's element count.
    auto elements = attr.cast<DenseIntElementsAttr>();
    return elements.size();
  }
  auto *components = val.get<ShapedTypeComponents *>();
  return components->getDims().size();
}
bool ShapeAdaptor::hasStaticShape() const {
if (!hasRank())
return false;
if (auto t = val.dyn_cast<Type>())
return t.cast<ShapedType>().hasStaticShape();
if (auto attr = val.dyn_cast<Attribute>()) {
auto dattr = attr.cast<DenseIntElementsAttr>();
for (auto index : dattr.getValues<APInt>())
if (ShapedType::isDynamic(index.getSExtValue()))
return false;
return true;
}
auto *stc = val.get<ShapedTypeComponents *>();
for (int64_t dim : stc->getDims())
if (ShapedType::isDynamic(dim))
return false;
return true;
}
// Returns the product of all dimension sizes of the wrapped shape. The shape
// must be static (asserted); in assert-enabled builds a negative running
// product is reported as an overflow.
int64_t ShapeAdaptor::getNumElements() const {
  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");

  if (auto type = val.dyn_cast<Type>())
    return type.cast<ShapedType>().getNumElements();

  int64_t count = 1;
  if (auto attr = val.dyn_cast<Attribute>()) {
    auto elements = attr.cast<DenseIntElementsAttr>();
    // Attribute extents are read zero-extended in this path.
    for (auto extent : elements.getValues<APInt>()) {
      count *= extent.getZExtValue();
      assert(count >= 0 && "integer overflow in element count computation");
    }
    return count;
  }

  auto *components = val.get<ShapedTypeComponents *>();
  for (int64_t extent : components->getDims()) {
    count *= extent;
    assert(count >= 0 && "integer overflow in element count computation");
  }
  return count;
}
void ShapeAdaptor::dump() const {
if (!hasRank()) {
llvm::errs() << "<<unranked>>\n";
return;
}
SmallVector<int64_t> dims;
getDims(dims);
auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string {
if (ShapedType::isDynamic(dim))
return "?";
return llvm::formatv("{0}", dim).str();
});
llvm::errs() << "rank = " << getRank() << " dims = [";
llvm::interleave(mapped, llvm::errs(), "x");
llvm::errs() << "]\n";
}
// Interprets the value at position `index` as a shape. The `valueToShape`
// callback, when set and returning a non-null adaptor, takes precedence.
// Otherwise the value must match a constant dense integer attribute of rank 1;
// a null adaptor is returned if it does not.
ShapeAdaptor ValueShapeRange::getValueAsShape(int index) {
  Value value = (*this)[index];

  if (valueToShape) {
    ShapeAdaptor mapped = valueToShape(value);
    if (mapped)
      return mapped;
  }

  DenseIntElementsAttr constant;
  bool isRankOneConstant = matchPattern(value, m_Constant(&constant)) &&
                           constant.getType().getRank() == 1;
  if (!isRankOneConstant)
    return nullptr;
  return constant;
}
// Returns the shape associated with `val`: the result of the `operandShape`
// callback when it is set and yields a non-null adaptor, otherwise the type of
// the value itself.
ShapeAdaptor ValueShapeRange::getShape(Value val) const {
  if (operandShape) {
    ShapeAdaptor mapped = operandShape(val);
    if (mapped)
      return mapped;
  }
  return val.getType();
}
// Returns the shape of the value at position `index`, or a null adaptor when
// `index` is outside [0, size()).
ShapeAdaptor ValueShapeRange::getShape(int index) const {
  bool inBounds = index >= 0 && static_cast<size_t>(index) < size();
  if (!inBounds)
    return nullptr;
  return getShape((*this)[index]);
}
// Infers tensor result types from shaped type components.
//
// Invokes `componentTypeFn` with the given context, location, operands,
// attributes and regions to obtain one ShapedTypeComponents per result, then
// appends to `inferredReturnTypes` a RankedTensorType for each ranked
// component and an UnrankedTensorType for each unranked one. Components
// carrying an attribute are not supported (asserted).
//
// Returns failure if `componentTypeFn` fails (in which case
// `inferredReturnTypes` is left untouched), success otherwise.
LogicalResult mlir::detail::inferReturnTensorTypes(
    function_ref<LogicalResult(
        MLIRContext *, Optional<Location> location, ValueShapeRange operands,
        DictionaryAttr attributes, RegionRange regions,
        SmallVectorImpl<ShapedTypeComponents> &retComponents)>
        componentTypeFn,
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  SmallVector<ShapedTypeComponents, 2> retComponents;
  if (failed(componentTypeFn(context, location, operands, attributes, regions,
                             retComponents)))
    return failure();
  // Iterate by const reference: each ShapedTypeComponents owns a vector of
  // dimensions, so iterating by value would copy it for every result.
  for (const auto &shapeAndType : retComponents) {
    assert(shapeAndType.getAttribute() == nullptr && "attribute not supported");
    if (shapeAndType.hasRank())
      inferredReturnTypes.push_back(RankedTensorType::get(
          shapeAndType.getDims(), shapeAndType.getElementType()));
    else
      inferredReturnTypes.push_back(
          UnrankedTensorType::get(shapeAndType.getElementType()));
  }
  return success();
}
// Verifies that the result types of `op` are compatible with the types
// inferred through its InferTypeOpInterface implementation. Returns failure if
// inference itself fails, and emits an op error if the inferred types are
// incompatible with the actual result types.
LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
  auto inferOp = cast<InferTypeOpInterface>(op);

  SmallVector<Type, 4> inferredTypes;
  LogicalResult inferStatus = inferOp.inferReturnTypes(
      op->getContext(), op->getLoc(), op->getOperands(),
      op->getAttrDictionary(), op->getRegions(), inferredTypes);
  if (failed(inferStatus))
    return failure();

  if (inferOp.isCompatibleReturnTypes(inferredTypes, op->getResultTypes()))
    return success();

  return op->emitOpError("inferred type(s) ")
         << inferredTypes
         << " are incompatible with return type(s) of operation "
         << op->getResultTypes();
}