blob: b334eedf61ef63d737ca5acdd0da33ef53d2877a [file] [log] [blame]
//===- DestinationStyleOpInterface.cpp -- Destination style ops -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
using namespace mlir;
namespace mlir {
#include "mlir/Interfaces/DestinationStyleOpInterface.cpp.inc"
} // namespace mlir
/// Conversion to a vector of values: collects, in iteration order, the value
/// obtained from `get()` on each `OpOperand *` stored in this vector.
OpOperandVector::operator SmallVector<Value>() {
  SmallVector<Value> values;
  // Size is known up front, so reserve once before appending.
  values.reserve(this->size());
  for (OpOperand *operand : *this)
    values.push_back(operand->get());
  return values;
}
/// Verifies `op` against the checks below, treating it as a
/// `DestinationStyleOpInterface`. The checks run in this order and the first
/// one that fails determines the returned failure:
///   1. every init operand has `MemRefType` or `RankedTensorType`;
///   2. there is at least one init operand;
///   3. the operand count equals the number of inputs plus inits;
///   4. the result count equals the number of tensor-typed init operands;
///   5. init operands are either all tensors or all memrefs;
///   6. each tensor init operand has the same type as its tied result.
/// Returns success when all checks pass.
LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
  auto dpsOp = cast<DestinationStyleOpInterface>(op);

  // Sort the init operands into memref-typed and tensor-typed groups; any
  // other type is rejected immediately.
  SmallVector<OpOperand *> memrefInits;
  SmallVector<OpOperand *> tensorInits;
  for (OpOperand *init : dpsOp.getDpsInitOperands()) {
    Type initType = init->get().getType();
    if (initType.isa<MemRefType>()) {
      memrefInits.push_back(init);
      continue;
    }
    if (initType.isa<RankedTensorType>()) {
      tensorInits.push_back(init);
      continue;
    }
    return op->emitOpError("expected that operand #")
           << init->getOperandNumber()
           << " is a ranked tensor or a ranked memref";
  }

  // There must be at least one init (output) operand, and the op must have
  // exactly as many operands as inputs and inits combined.
  int64_t inputCount = dpsOp.getNumDpsInputs();
  int64_t initCount = dpsOp.getNumDpsInits();
  if (initCount == 0)
    return op->emitOpError("expected at least one output operand");
  if (failed(OpTrait::impl::verifyNOperands(op, inputCount + initCount)))
    return failure();

  // The op must produce one result per tensor-typed init operand.
  if (op->getNumResults() != tensorInits.size()) {
    return op->emitOpError("expected the number of results (")
           << op->getNumResults()
           << ") to be equal to the number of output tensors ("
           << tensorInits.size() << ")";
  }

  // Simplifying assumption: the op is either fully in tensor mode or fully in
  // buffer mode. This keeps the operand-vs-result type check below simple, as
  // there is no need to track which operand is what in a mixed-mode op.
  // TODO: relax when mixed-mode needs to pass verification.
  const bool isMixedMode = !(memrefInits.empty() || tensorInits.empty());
  if (isMixedMode) {
    return op->emitOpError(
        "expected output operands to all have tensor type or "
        "all have buffer type");
  }

  // Each tensor init operand must have exactly the type of its tied result.
  for (OpOperand *init : tensorInits) {
    OpResult tiedResult = dpsOp.getTiedOpResult(init);
    Type resultType = tiedResult.getType();
    Type initType = init->get().getType();
    if (resultType == initType)
      continue;
    return op->emitOpError("expected type of operand #")
           << init->getOperandNumber() << " ("
           << initType << ")"
           << " to match type of corresponding result (" << resultType
           << ")";
  }

  return success();
}