blob: 70dbc05592b78dce32d95193351db5af584824a2 [file] [log] [blame]
//===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the definitions of the infer op interfaces defined in
// `InferTypeOpInterface.td`.
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/IR/StandardTypes.h"
using namespace mlir;
namespace mlir {
#include "mlir/Interfaces/InferTypeOpInterface.cpp.inc"
} // namespace mlir
/// Infers tensor result types from shaped type components.
///
/// Invokes `componentTypeFn` with the given context, location, operands,
/// attributes and regions to obtain one `ShapedTypeComponents` per result, and
/// then appends a tensor type for each component to `inferredReturnTypes`:
/// a `RankedTensorType` when the component has a rank, and an
/// `UnrankedTensorType` otherwise.
///
/// Returns failure (leaving `inferredReturnTypes` untouched by this function)
/// if `componentTypeFn` fails, and success otherwise. Components carrying an
/// attribute are not supported and trigger an assertion.
LogicalResult mlir::detail::inferReturnTensorTypes(
    function_ref<LogicalResult(
        MLIRContext *, Optional<Location> location, ValueRange operands,
        DictionaryAttr attributes, RegionRange regions,
        SmallVectorImpl<ShapedTypeComponents> &retComponents)>
        componentTypeFn,
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  SmallVector<ShapedTypeComponents, 2> retComponents;
  if (failed(componentTypeFn(context, location, operands, attributes, regions,
                             retComponents)))
    return failure();
  // Bind by const reference: the components are only read here, so there is
  // no need to copy each element (including its dimension storage) per
  // iteration.
  for (const auto &shapeAndType : retComponents) {
    assert(shapeAndType.getAttribute() == nullptr && "attribute not supported");
    if (shapeAndType.hasRank())
      inferredReturnTypes.push_back(RankedTensorType::get(
          shapeAndType.getDims(), shapeAndType.getElementType()));
    else
      inferredReturnTypes.push_back(
          UnrankedTensorType::get(shapeAndType.getElementType()));
  }
  return success();
}
/// Checks that the result types of `op` agree with the types its
/// `InferTypeOpInterface` implementation infers.
///
/// The operation's own context, location, operands, attribute dictionary and
/// regions are fed to `inferReturnTypes`. If inference fails, failure is
/// returned without emitting a diagnostic from here. If the inferred types are
/// not compatible with the actual result types (as decided by
/// `isCompatibleReturnTypes`), an op error is emitted and returned.
LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
  auto inferenceIface = cast<InferTypeOpInterface>(op);

  // Ask the op to infer its result types from its current state.
  SmallVector<Type, 4> inferredTypes;
  LogicalResult inferenceStatus = inferenceIface.inferReturnTypes(
      op->getContext(), op->getLoc(), op->getOperands(),
      op->getAttrDictionary(), op->getRegions(), inferredTypes);
  if (failed(inferenceStatus))
    return failure();

  // Accept the op when the inferred types match what it actually produces.
  if (inferenceIface.isCompatibleReturnTypes(inferredTypes,
                                             op->getResultTypes()))
    return success();

  return op->emitOpError(
      "inferred type incompatible with return type of operation");
}