blob: 36828eabd59f72b9831237aa59809be8dbdde1f3 [file] [log] [blame]
//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::linalg;
#include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.cpp.inc"
/// Forward declarations of the static helpers that are used before their
/// definitions in this file.
/// Generic entry point to create the block for the region of a LinalgOp.
/// This is used by both named structured ops created by ods-gen and by manually
/// defined C++ ops.
/// This is used by both builders and parsers.
/// This function creates the block in the region with arguments corresponding
/// to the elemental types of `inputTypes` and `outputTypes`. The latter are
/// asserted to be of ShapedType.
/// NOTE(review): the meaning of the two `unsigned` values passed to
/// `errorHandler` is not visible from this declaration — see the definition.
template <typename NamedStructuredOpType>
static void fillStructuredOpRegion(
OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
TypeRange outputTypes,
std::function<void(unsigned, unsigned)> errorHandler = nullptr);
/// Generic entry point to create both the region and the block of a LinalgOp.
template <typename NamedStructuredOpType>
static void
createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result,
TypeRange inputTypes, TypeRange outputTypes);
/// Common parsing and printing used for both named structured ops created by
/// ods-gen and by manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
SmallVectorImpl<Type> &inputTypes,
SmallVectorImpl<Type> &outputTypes);
/// Printing counterpart of `parseCommonStructuredOpParts`.
template <typename NamedStructuredOpType>
static void printCommonStructuredOpParts(OpAsmPrinter &p,
NamedStructuredOpType op);
/// Specific parsing and printing for named structured ops created by ods-gen.
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
TypeRange inputTypes, TypeRange outputTypes);
/// Parses the result types of a structured op into `resultTypes`.
static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
SmallVectorImpl<Type> &resultTypes);
/// Parses a complete named structured op into `result`.
template <typename NamedStructuredOpType>
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
OperationState &result);
/// Prints the result types of a structured op.
static void printNamedStructuredOpResults(OpAsmPrinter &p,
TypeRange resultTypes);
/// Prints a complete named structured op.
template <typename NamedStructuredOpType>
static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
/// Materializes every entry of `valueOrAttrVec` as a `Value`. Each entry is
/// run through `getValueOrCreateConstantIndexOp` with builder `b` at `loc`;
/// the results are returned in the same order as the input.
static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
  SmallVector<Value> materialized;
  materialized.reserve(valueOrAttrVec.size());
  for (OpFoldResult entry : valueOrAttrVec)
    materialized.push_back(getValueOrCreateConstantIndexOp(b, loc, entry));
  return materialized;
}
/// Shared implementation for folds of the form
/// ```
///   someop(memrefcast(%src)) -> someop(%src)
/// ```
/// Every operand of `op` that is produced by a foldable `memref.cast` is
/// rewired to the source of that cast. Returns success iff at least one
/// operand was updated.
static LogicalResult foldMemRefCast(Operation *op) {
  bool changed = false;
  for (OpOperand &use : op->getOpOperands()) {
    auto producer = use.get().getDefiningOp<memref::CastOp>();
    // Skip operands that are not produced by a cast, or whose cast cannot be
    // absorbed by the consumer.
    if (!producer || !memref::CastOp::canFoldIntoConsumerOp(producer))
      continue;
    use.set(producer.getOperand());
    changed = true;
  }
  return changed ? success() : failure();
}
/// Variant of `foldMemRefCast` for `linalg.tiled_loop`:
/// ```
///   tiled_loop(memrefcast(%src)) -> tiled_loop(%src)
/// ```
/// Besides rewiring the operand to the cast source, the matching block
/// argument of the loop body is retyped and a `memref.cast` back to the
/// original type is inserted at the start of the body for the existing uses.
/// Returns success iff at least one operand was folded.
static LogicalResult foldMemRefCastInTiledLoopOp(TiledLoopOp op) {
  Block *loopBody = op.getBody();
  Location loc = op->getLoc();
  OpBuilder bodyBuilder = OpBuilder::atBlockBegin(loopBody);
  bool changed = false;

  // Operands list:   [lbs, ubs, steps, inputs, outputs].
  // Block args list: [ivs, inputs, outputs].
  // Walk the input/output operands and their block arguments in lock-step.
  size_t argIdx = op.getNumLoops();
  const size_t numOperands = op.getNumOperands();
  for (size_t opIdx = op.getNumControlOperands(); opIdx < numOperands;
       ++opIdx, ++argIdx) {
    OpOperand &use = op->getOpOperand(opIdx);
    auto producer = use.get().getDefiningOp<memref::CastOp>();
    if (!producer || !memref::CastOp::canFoldIntoConsumerOp(producer))
      continue;

    use.set(producer.getOperand());

    // Insert the retyped argument right before the old one, so the old
    // argument is now located one slot after the new one.
    BlockArgument retypedArg =
        loopBody->insertArgument(argIdx, producer.getOperand().getType());
    BlockArgument staleArg =
        loopBody->getArgument(retypedArg.getArgNumber() + 1);

    // Insert memref.cast back to the original type.
    Value castBack =
        bodyBuilder.create<memref::CastOp>(loc, staleArg.getType(), retypedArg);
    staleArg.replaceAllUsesWith(castBack);
    loopBody->eraseArgument(staleArg.getArgNumber());
    changed = true;
  }
  return changed ? success() : failure();
}
//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated code
// and bind by name to math functions in the DSL as:
// `applyfn__{fnName}`
// Examples:
// `applyfn__add`
// `applyfn__mul`
// The naming convention is intentional in order to match snake-cased DSL names.
// See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class.
//
// Implementations of the math functions must be polymorphic over numeric types,
// internally performing necessary casts. If the function application makes no
// sense, then the only recourse is to assert and return nullptr. This can be
// extended later if it becomes possible to fail construction of the region. The
// invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care and work is being considered to extend the op
// language to make such cases explicit. In the mean-time, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//
namespace {
/// Helper that appends payload operations to the end of `block`. Every public
/// method creates a fresh builder positioned at the end of the block, emits the
/// requested operation(s) and returns the resulting value.
class RegionBuilderHelper {
public:
  RegionBuilderHelper(MLIRContext *context, Block &block)
      : context(context), block(block) {}

  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder builder = getBuilder();
    auto loc = operand.getLoc();

    if (operand.getType() == toType)
      return operand;
    if (auto toIntType = toType.dyn_cast<IntegerType>()) {
      // If operand is floating point, cast directly to the int type.
      if (operand.getType().isa<FloatType>()) {
        if (isUnsignedCast)
          return builder.create<arith::FPToUIOp>(loc, toType, operand);
        return builder.create<arith::FPToSIOp>(loc, toType, operand);
      }
      // Cast index operands directly to the int type.
      if (operand.getType().isIndex())
        return builder.create<arith::IndexCastOp>(loc, toType, operand);
      if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
        // Either extend or truncate.
        if (toIntType.getWidth() > fromIntType.getWidth()) {
          if (isUnsignedCast)
            return builder.create<arith::ExtUIOp>(loc, toType, operand);
          return builder.create<arith::ExtSIOp>(loc, toType, operand);
        }
        if (toIntType.getWidth() < fromIntType.getWidth())
          return builder.create<arith::TruncIOp>(loc, toType, operand);
      }
    } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
      // If operand is integer, cast directly to the float type.
      // Note that it is unclear how to cast from BF16<->FP16.
      if (operand.getType().isa<IntegerType>()) {
        if (isUnsignedCast)
          return builder.create<arith::UIToFPOp>(loc, toFloatType, operand);
        return builder.create<arith::SIToFPOp>(loc, toFloatType, operand);
      }
      if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
        if (toFloatType.getWidth() > fromFloatType.getWidth())
          return builder.create<arith::ExtFOp>(loc, toFloatType, operand);
        if (toFloatType.getWidth() < fromFloatType.getWidth())
          return builder.create<arith::TruncFOp>(loc, toFloatType, operand);
      }
    }

    emitWarning(operand.getLoc()) << "could not cast operand of type "
                                  << operand.getType() << " to " << toType;
    return operand;
  }

  // Binary math functions. The operation is selected from the type of `lhs`:
  // the first template argument is used for floating point operands, the
  // second one for integer operands.
  Value applyfn__add(Value lhs, Value rhs) {
    return buildBinaryFn<arith::AddFOp, arith::AddIOp>(lhs, rhs);
  }

  Value applyfn__sub(Value lhs, Value rhs) {
    return buildBinaryFn<arith::SubFOp, arith::SubIOp>(lhs, rhs);
  }

  Value applyfn__mul(Value lhs, Value rhs) {
    return buildBinaryFn<arith::MulFOp, arith::MulIOp>(lhs, rhs);
  }

  Value applyfn__max(Value lhs, Value rhs) {
    return buildBinaryFn<arith::MaxFOp, arith::MaxSIOp>(lhs, rhs);
  }

  Value applyfn__max_unsigned(Value lhs, Value rhs) {
    return buildBinaryFn<arith::MaxFOp, arith::MaxUIOp>(lhs, rhs);
  }

  Value applyfn__min(Value lhs, Value rhs) {
    return buildBinaryFn<arith::MinFOp, arith::MinSIOp>(lhs, rhs);
  }

  Value applyfn__min_unsigned(Value lhs, Value rhs) {
    return buildBinaryFn<arith::MinFOp, arith::MinUIOp>(lhs, rhs);
  }

  // Unary math functions; only floating point operands are supported.
  Value applyfn__exp(Value x) { return buildUnaryFloatFn<math::ExpOp>(x); }

  Value applyfn__log(Value x) { return buildUnaryFloatFn<math::LogOp>(x); }

  // Terminates the block with a `linalg.yield` of `values`, located at the
  // first yielded value. Does nothing (after asserting) if `values` is empty.
  void yieldOutputs(ValueRange values) {
    assert(!values.empty() && "linalg ops must yield outputs");
    if (values.empty())
      return;
    Value first = values.front();
    OpBuilder builder = getBuilder();
    builder.create<YieldOp>(first.getLoc(), values);
  }

  // Parses `value` as an attribute and emits an `arith.constant` holding it.
  // A string that does not parse is a programming error: this asserts and, in
  // builds without assertions, returns a null Value instead of dereferencing
  // the null attribute.
  Value constant(const std::string &value) {
    OpBuilder builder = getBuilder();
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    assert(valueAttr && "failed to parse constant attribute");
    if (!valueAttr)
      return Value();
    return builder.create<arith::ConstantOp>(loc, valueAttr.getType(),
                                             valueAttr);
  }

  // Emits a `linalg.index` operation for iteration dimension `dim`.
  Value index(int64_t dim) {
    OpBuilder builder = getBuilder();
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(context, width);
  }

  Type getFloat32Type() { return Float32Type::get(context); }

  Type getFloat64Type() { return Float64Type::get(context); }

private:
  MLIRContext *context;
  Block &block;

  bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
  bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }

  // Emits `FloatOpTy` when `lhs` is a float, `IntOpTy` when it is an integer;
  // any other operand type is unreachable.
  template <typename FloatOpTy, typename IntOpTy>
  Value buildBinaryFn(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create<FloatOpTy>(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create<IntOpTy>(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  // Emits `FloatOpTy` on a floating point operand; any other operand type is
  // unreachable.
  template <typename FloatOpTy>
  Value buildUnaryFloatFn(Value x) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(x))
      return builder.create<FloatOpTy>(x.getLoc(), x);
    llvm_unreachable("unsupported non numeric type");
  }

  // Returns a builder whose insertion point is the end of `block`.
  OpBuilder getBuilder() {
    OpBuilder builder(context);
    builder.setInsertionPointToEnd(&block);
    return builder;
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//
/// Builds the payload of a copy: the element read from the input (block
/// argument 0) is yielded unchanged.
void CopyOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {
  assert(block.getNumArguments() == 2 && "CopyOp regionBuilder expects 2 args");
  Value inputElement = block.getArgument(0);
  b.create<linalg::YieldOp>(inputElement);
}
/// Populates `result` for a `linalg.copy` from `input` to `output`. The
/// permutation attributes are only attached when the corresponding map is
/// non-null. A region is added and filled with the CopyOp payload.
void CopyOp::build(OpBuilder &builder, OperationState &result, Value input,
                   Value output, AffineMap inputPermutation,
                   AffineMap outputPermutation,
                   ArrayRef<NamedAttribute> namedAttrs) {
  result.addOperands({input, output});
  result.addAttributes(namedAttrs);

  // Optional permutation maps.
  if (inputPermutation) {
    auto inputMapAttr = AffineMapAttr::get(inputPermutation);
    result.addAttribute("inputPermutation", inputMapAttr);
  }
  if (outputPermutation) {
    auto outputMapAttr = AffineMapAttr::get(outputPermutation);
    result.addAttribute("outputPermutation", outputMapAttr);
  }

  // Create the region and populate its body block.
  result.addRegion();
  Region &payload = *result.regions.front();
  fillStructuredOpRegion<CopyOp>(builder, payload, TypeRange{input.getType()},
                                 TypeRange{output.getType()});
}
/// The CopyOp region is not present in the textual form; it is rebuilt here
/// from the operand types. Always succeeds.
ParseResult parseCopyOpRegion(OpAsmParser &parser, Region &r, Type inputType,
                              Type outputType) {
  OpBuilder regionBuilder(parser.getContext());
  TypeRange inputs{inputType};
  TypeRange outputs{outputType};
  fillStructuredOpRegion<CopyOp>(regionBuilder, r, inputs, outputs);
  return success();
}
/// CopyOp region is elided when printing: this printer hook intentionally
/// emits nothing (`parseCopyOpRegion` rebuilds the region from the operand
/// types).
void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}
/// Verifies a `linalg.copy`:
///   * input and output have the same element type and the same rank;
///   * each optional permutation map, when present, has as many inputs as the
///     op has parallel loops and is a permutation;
///   * no permutation map is given when the rank is 0.
static LogicalResult verify(CopyOp op) {
  OpOperand *src = op.getInputOperand(0);
  OpOperand *dst = op.getOutputOperand(0);

  if (getElementTypeOrSelf(src->get()) != getElementTypeOrSelf(dst->get()))
    return op.emitOpError("expects views of the same type");
  if (op.getRank(src) != op.getRank(dst))
    return op.emitOpError("expects views of the same rank");

  auto numLoops = op.getNumParallelLoops();

  auto inPerm = op.inputPermutation();
  if (inPerm) {
    if (inPerm->getNumInputs() != numLoops)
      return op.emitOpError("expects optional input_permutation map of rank ")
             << numLoops;
    if (!inPerm->isPermutation())
      return op.emitOpError(
          "expects optional input_permutation map to be a permutation");
  }

  auto outPerm = op.outputPermutation();
  if (outPerm) {
    if (outPerm->getNumInputs() != numLoops)
      return op.emitOpError("expects optional output_permutation map of rank ")
             << numLoops;
    if (!outPerm->isPermutation())
      return op.emitOpError(
          "expects optional output_permutation map to be a permutation");
  }

  // Rank-0 copies must not carry any permutation.
  if (numLoops == 0) {
    if (inPerm)
      return op.emitOpError("expected no input permutation when rank == 0");
    if (outPerm)
      return op.emitOpError("expected no output permutation when rank == 0");
  }
  return success();
}
/// Reports the memory effects of a copy: a read of `input()` and a write of
/// `output()`, both on the default resource.
void CopyOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  auto *defaultResource = SideEffects::DefaultResource::get();
  effects.emplace_back(MemoryEffects::Read::get(), input(), defaultResource);
  effects.emplace_back(MemoryEffects::Write::get(), output(), defaultResource);
}
namespace {
/// Erases copies that copy a buffer onto itself. A copy is erased when:
///   1) its input and output values are identical, and
///   2) its input and output permutation maps are identical.
struct EraseIdentityCopyOp : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    assert(copyOp.hasBufferSemantics());
    const bool sameBuffer = copyOp.input() == copyOp.output();
    if (!sameBuffer ||
        copyOp.inputPermutation() != copyOp.outputPermutation())
      return failure();
    rewriter.eraseOp(copyOp);
    return success();
  }
};
} // namespace
/// Registers the canonicalization patterns of `linalg.copy`; currently only
/// `EraseIdentityCopyOp`.
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<EraseIdentityCopyOp>(context);
}
//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//
/// Builds the payload of a fill: the fill value (block argument 0) is yielded
/// unchanged.
void FillOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {
  assert(block.getNumArguments() == 2 && "FillOp regionBuilder expects 2 args");
  Value fillValue = block.getArgument(0);
  b.create<linalg::YieldOp>(fillValue);
}
/// Populates `result` for a `linalg.fill` of `output` with `value`. The result
/// type passed on is the type of `output` when it is a ranked tensor and a
/// null type otherwise. The first region of `result` is then filled with the
/// FillOp payload.
void FillOp::build(OpBuilder &builder, OperationState &result, Value value,
                   Value output) {
  build(builder, result, output.getType().dyn_cast<RankedTensorType>(), value,
        output);

  Region &payload = *result.regions.front();
  TypeRange inputTypes{value.getType()};
  TypeRange outputTypes{output.getType()};
  fillStructuredOpRegion<FillOp>(builder, payload, inputTypes, outputTypes, {});
}
/// The FillOp region is not present in the textual form; it is rebuilt here
/// from the value and output types. Always succeeds.
ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type valueType,
                              Type outputType) {
  OpBuilder regionBuilder(parser.getContext());
  TypeRange inputs{valueType};
  TypeRange outputs{outputType};
  fillStructuredOpRegion<FillOp>(regionBuilder, r, inputs, outputs);
  return success();
}
/// FillOp region is elided when printing: this printer hook intentionally
/// emits nothing (`parseFillOpRegion` rebuilds the region from the operand
/// types).
void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}
/// Verifies that the type of the fill value equals the element type of the
/// output operand.
static LogicalResult verify(FillOp op) {
  Type valueType = op.value().getType();
  Type outputElementType = getElementTypeOrSelf(op.getOutputOperand(0)->get());
  if (outputElementType == valueType)
    return success();
  return op.emitOpError("expects fill type to match view elemental type");
}
/// Reports a write effect on `output()` when it is a memref; no effect is
/// reported for any other output type.
void FillOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  const bool writesBuffer = output().getType().isa<MemRefType>();
  if (!writesBuffer)
    return;
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
}
//===----------------------------------------------------------------------===//
// GenericOps
//===----------------------------------------------------------------------===//
/// Main convenience builder of `linalg.generic`. Converts the indexing maps,
/// iterator types, `doc` and `libraryCall` into attributes (empty strings map
/// to null attributes), appends `attributes`, and, when `bodyBuild` is
/// provided, creates the body block with one argument per input/output element
/// type and invokes `bodyBuild` with the builder positioned inside that block.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  StringAttr docAttr;
  if (!doc.empty())
    docAttr = builder.getStringAttr(doc);
  StringAttr libraryCallAttr;
  if (!libraryCall.empty())
    libraryCallAttr = builder.getStringAttr(libraryCall);

  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getStrArrayAttr(iteratorTypes), docAttr, libraryCallAttr);
  result.addAttributes(attributes);

  // Without a body callback the region is left untouched.
  if (!bodyBuild)
    return;

  // Block arguments: element types of the inputs followed by the outputs.
  SmallVector<Type, 4> argTypes;
  for (Value in : inputs)
    argTypes.push_back(getElementTypeOrSelf(in));
  for (Value out : outputs)
    argTypes.push_back(getElementTypeOrSelf(out));

  // The guard restores the caller's insertion point on exit.
  OpBuilder::InsertionGuard guard(builder);
  Region &bodyRegion = *result.regions.front();
  Block *body = builder.createBlock(&bodyRegion, bodyRegion.end(), argTypes);
  bodyBuild(builder, result.location, body->getArguments());
}
/// Convenience overload without result tensor types: forwards to the main
/// builder with an empty `TypeRange`.
void GenericOp::build(
OpBuilder &builder, OperationState &result, ValueRange inputs,
ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
ArrayRef<NamedAttribute> attributes) {
build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}
/// Convenience overload without result tensor types, `doc` or `libraryCall`:
/// forwards with empty strings for both.
void GenericOp::build(
OpBuilder &builder, OperationState &result, ValueRange inputs,
ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
ArrayRef<StringRef> iteratorTypes,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
ArrayRef<NamedAttribute> attributes) {
build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
/*doc=*/"",
/*libraryCall=*/"", bodyBuild, attributes);
}
/// Convenience overload with result tensor types but without `doc` or
/// `libraryCall`: forwards with empty strings for both.
void GenericOp::build(
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
ArrayRef<StringRef> iteratorTypes,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
ArrayRef<NamedAttribute> attributes) {
build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
iteratorTypes,
/*doc=*/"",
/*libraryCall=*/"", bodyBuild, attributes);
}
/// Prints a `linalg.generic` in its custom form: the dictionary of core linalg
/// trait attributes, the common structured-op parts, any remaining attributes
/// behind `attrs = `, the region (if non-empty) and the result types.
static void print(OpAsmPrinter &p, GenericOp op) {
  p << " ";

  // Collect and print the core trait attributes as a single dictionary.
  auto traitNames = op.linalgTraitAttrNames();
  llvm::StringSet<> traitNameSet;
  traitNameSet.insert(traitNames.begin(), traitNames.end());
  SmallVector<NamedAttribute, 8> traitAttrs;
  for (auto attr : op->getAttrs()) {
    if (traitNameSet.count(attr.getName().strref()) > 0)
      traitAttrs.push_back(attr);
  }
  if (!traitAttrs.empty())
    p << DictionaryAttr::get(op.getContext(), traitAttrs);

  // Printing is shared with named ops, except for the region and attributes
  printCommonStructuredOpParts(p, op);

  // Everything that is neither a trait attribute nor the operand segment
  // sizes is an "extra" attribute and is printed separately.
  traitNames.push_back("operand_segment_sizes");
  traitNameSet.insert(traitNames.back());
  bool hasExtraAttrs = llvm::any_of(op->getAttrs(), [&](NamedAttribute attr) {
    return !traitNameSet.contains(attr.getName().strref());
  });
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/traitNames);
  }

  // Print region.
  if (!op.region().empty())
    p.printRegion(op.region());

  // Print results.
  printNamedStructuredOpResults(p, op.result_tensors().getTypes());
}
/// Parses a `linalg.generic` in its custom form. The expected pieces, in
/// order, are: the dictionary of core linalg trait attributes, the common
/// structured-op parts (shared with named ops), an optional
/// `attrs = {...}` dictionary, the region, and the result types.
/// Returns failure as soon as any piece fails to parse.
static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  // The region is parsed with empty argument/type lists.
  SmallVector<OpAsmParser::OperandType, 8> regionOperands;
  std::unique_ptr<Region> region = std::make_unique<Region>();
  SmallVector<Type, 8> regionTypes;
  if (parser.parseRegion(*region, regionOperands, regionTypes))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);
  return success();
}
/// Appends the memory effects of a generic structured op to `effects`, in this
/// order: an allocate effect per result, a read effect per input buffer, and a
/// read followed by a write effect per output. All effects use the default
/// resource.
static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    ValueRange results, ValueRange inputBuffers, ValueRange outputs) {
  auto *defaultResource = SideEffects::DefaultResource::get();
  for (Value res : results)
    effects.emplace_back(MemoryEffects::Allocate::get(), res, defaultResource);
  for (Value in : inputBuffers)
    effects.emplace_back(MemoryEffects::Read::get(), in, defaultResource);
  for (Value out : outputs) {
    effects.emplace_back(MemoryEffects::Read::get(), out, defaultResource);
    effects.emplace_back(MemoryEffects::Write::get(), out, defaultResource);
  }
}
/// Reports the memory effects of `linalg.generic` by delegating to
/// `getGenericEffectsImpl` with the op results and its input/output buffer
/// operands.
void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  SmallVector<Value> readBuffers = getInputBufferOperands();
  SmallVector<Value> writtenBuffers = getOutputBufferOperands();
  ValueRange opResults = getOperation()->getResults();
  getGenericEffectsImpl(effects, opResults, readBuffers, writtenBuffers);
}
/// Verification hook shared by generic-like ops. It currently performs no
/// checks and always succeeds.
template <typename GenericOpType>
static LogicalResult verifyGenericOp(GenericOpType op) {
return success();
}
/// Verifier entry point for `linalg.generic`; delegates to `verifyGenericOp`.
static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); }
namespace {
// Deduplicate redundant args of a linalg generic op.
// An arg is redundant if it has the same Value and indexing map as another.
// The pattern creates a new GenericOp that keeps only the first occurrence of
// each (Value, indexing map) input pair, moves the payload region over,
// rewires the block arguments of the dropped inputs to their canonical
// counterpart and finally replaces the original op.
struct DeduplicateGenericOpInputs : public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
// Associate each input to an equivalent "canonical" input that has the same
// Value and indexing map.
//
// In the non-duplicate case, input `i` will have canonical input `i`. But
// in the case of duplicated inputs, the canonical input could be some other
// input `< i`. That is, a later input will have some earlier input as its
// canonical input.
llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput;
// For later remapping tasks like deduplicating payload block arguments,
// having a simple "inputIndex -> canonicalInputIndex" integer mapping is
// convenient.
// NOTE(review): this vector is filled in input order but indexed below by
// `getOperandNumber()`, which presumably relies on the inputs being the
// leading operands of the op — confirm against the op definition.
SmallVector<unsigned> canonicalInputIndices;
for (OpOperand *opOperand : genericOp.getInputOperands()) {
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
// STL-like maps have a convenient behavior for our use case here. In the
// case of duplicate keys, the insertion is rejected, and the returned
// iterator gives access to the value already in the map.
auto pair = canonicalInput.insert(
{{opOperand->get(), indexingMap}, opOperand->getOperandNumber()});
canonicalInputIndices.push_back(pair.first->second);
}
// If there are no duplicate args, then bail out.
if (canonicalInput.size() == genericOp.getNumInputs())
return failure();
// The operands for the newly canonicalized op.
// Only inputs that are their own canonical input are kept.
SmallVector<Value> newInputOperands;
for (OpOperand *opOperand : genericOp.getInputOperands())
if (canonicalInputIndices[opOperand->getOperandNumber()] ==
opOperand->getOperandNumber())
newInputOperands.push_back(opOperand->get());
// Repair the indexing maps by filtering out the ones that have been
// eliminated.
SmallVector<AffineMap> newIndexingMaps;
for (OpOperand *opOperand : genericOp.getInputOperands())
if (canonicalInputIndices[opOperand->getOperandNumber()] ==
opOperand->getOperandNumber())
newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
// The indexing maps of all outputs are kept unchanged.
for (OpOperand *opOperand : genericOp.getOutputOperands())
newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
// Clone the old op with new operands.
SmallVector<Value> outputOperands = genericOp.getOutputOperands();
auto newOp = rewriter.create<GenericOp>(
genericOp.getLoc(), genericOp->getResultTypes(), newInputOperands,
outputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps),
genericOp.iterator_types(), genericOp.docAttr(),
genericOp.library_callAttr());
// Copy over unknown attributes. They might be load bearing for some flow.
ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
for (NamedAttribute kv : genericOp->getAttrs()) {
if (!llvm::is_contained(odsAttrs, kv.getName().getValue())) {
newOp->setAttr(kv.getName(), kv.getValue());
}
}
// Move the payload of the old op into the new op.
rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
newOp.region().begin());
// Repair the payload entry block by RAUW'ing redundant arguments and
// erasing them.
Block &payload = newOp.region().front();
SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
for (OpOperand *opOperand : llvm::reverse(inputOperands)) {
// Iterate in reverse, so that we erase later args first, preventing the
// argument list from shifting unexpectedly and invalidating all our
// indices.
unsigned operandNumber = opOperand->getOperandNumber();
if (canonicalInputIndices[operandNumber] == operandNumber)
continue;
payload.getArgument(operandNumber)
.replaceAllUsesWith(
payload.getArgument(canonicalInputIndices[operandNumber]));
payload.eraseArgument(operandNumber);
}
rewriter.replaceOp(genericOp, newOp->getResults());
return success();
}
};
/// Remove generic operations (on tensors) that are just copying
/// the values from inputs to the results. Requirements are
/// 1) The op has tensor semantics and all indexing maps are identity.
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
/// 3) Each forwarded operand either already has the type of the result it
///    replaces, or can be converted to it with a `tensor.cast` (e.g. a
///    dynamic dimension of the operand is static in the result). If neither
///    holds the pattern does not apply.
struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();

    // Check all indexing maps are identity.
    if (llvm::any_of(genericOp.getIndexingMaps(),
                     [](AffineMap map) { return !map.isIdentity(); }))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = genericOp.region().front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    OperandRange yieldedValues = yieldOp.values();
    unsigned numResults = genericOp->getNumResults();
    if (yieldedValues.size() != numResults)
      return failure();

    // Matching phase: no IR is created here, so that bailing out leaves the
    // IR untouched. Get the argument number of the returned values. That is
    // the operand number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    returnedArgs.reserve(numResults);
    for (unsigned i = 0; i != numResults; ++i) {
      auto yieldArg = yieldedValues[i].dyn_cast<BlockArgument>();
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      Value returnedArg = genericOp->getOperand(yieldArg.getArgNumber());
      Type resultType = genericOp->getResult(i).getType();
      // Replacing a result with a value of another type would create invalid
      // IR; only accept a mismatch that a tensor.cast can bridge.
      if (returnedArg.getType() != resultType &&
          !tensor::CastOp::areCastCompatible(returnedArg.getType(), resultType))
        return failure();
      returnedArgs.push_back(returnedArg);
    }

    // Rewrite phase: insert the casts that are needed and replace the op.
    for (unsigned i = 0; i != numResults; ++i) {
      Type resultType = genericOp->getResult(i).getType();
      if (returnedArgs[i].getType() != resultType)
        returnedArgs[i] = rewriter.create<tensor::CastOp>(
            genericOp.getLoc(), resultType, returnedArgs[i]);
    }
    rewriter.replaceOp(genericOp, returnedArgs);
    return success();
  }
};
} // namespace
/// Registers the canonicalization patterns of `linalg.generic`:
/// `DeduplicateGenericOpInputs` and `EraseIdentityGenericOp`.
void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp>(context);
}
//===----------------------------------------------------------------------===//
// InitTensorOp
//===----------------------------------------------------------------------===//
/// Builds a `linalg.init_tensor` from mixed static/dynamic `sizes`. The sizes
/// are split into dynamic operands and a static shape (dynamic entries are
/// marked with `ShapedType::kDynamicSize`), from which the ranked tensor
/// result type is derived. `attrs` are appended to the op.
void InitTensorOp::build(OpBuilder &b, OperationState &result,
                         ArrayRef<OpFoldResult> sizes, Type elementType,
                         ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t, 4> staticShape;
  SmallVector<Value, 4> dynamicExtents;
  dispatchIndexOpFoldResults(sizes, dynamicExtents, staticShape,
                             ShapedType::kDynamicSize);

  auto tensorType = RankedTensorType::get(staticShape, elementType);
  build(b, result, tensorType, dynamicExtents, b.getI64ArrayAttr(staticShape));
  result.addAttributes(attrs);
}
/// Verifies a `linalg.init_tensor`: the mixed static/dynamic size list must
/// be consistent, must have one entry per result dimension, and the declared
/// result type must equal the type inferred from the static sizes.
static LogicalResult verify(InitTensorOp op) {
  RankedTensorType declaredType = op.getType();

  // Unpack the static sizes attribute into plain integers.
  SmallVector<int64_t, 4> staticShape;
  for (Attribute sizeAttr : op.static_sizes().cast<ArrayAttr>())
    staticShape.push_back(sizeAttr.cast<IntegerAttr>().getInt());

  LogicalResult sizesOk = verifyListOfOperandsOrIntegers(
      op, "sizes", declaredType.getRank(), op.static_sizes(), op.sizes(),
      ShapedType::isDynamic);
  if (failed(sizesOk))
    return failure();

  if (op.static_sizes().size() != static_cast<unsigned>(declaredType.getRank()))
    return op->emitError("expected ")
           << declaredType.getRank() << " sizes values";

  Type inferredType =
      InitTensorOp::inferResultType(staticShape, declaredType.getElementType());
  if (declaredType == inferredType)
    return success();
  return op.emitError("specified type ")
         << declaredType << " does not match the inferred type "
         << inferredType;
}
/// Returns the ranked tensor type with shape `staticSizes` and element type
/// `elementType`.
Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
Type elementType) {
return RankedTensorType::get(staticSizes, elementType);
}
namespace {
/// Change the type of the result of a `linalg.init_tensor` by making the result
/// type statically sized along dimensions that were defined as dynamic in the
/// original operation, but whose size was defined using a `constant` op. For
/// example
///
///  %c5 = arith.constant 5: index
///  %0 = linalg.init_tensor [%arg0, %c5] : tensor<?x?xf32>
///
///  to
///
///  %0 = linalg.init_tensor [%arg0, 5] : tensor<?x5xf32>
///
/// A `tensor.cast` back to the original type replaces the original op.
struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
  using OpRewritePattern<InitTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InitTensorOp op,
                                PatternRewriter &rewriter) const override {
    RankedTensorType oldType = op.getType();
    SmallVector<int64_t, 4> newStaticSizes;
    SmallVector<Value, 4> remainingDynamicSizes;

    const unsigned rank = oldType.getRank();
    for (unsigned dim = 0; dim != rank; ++dim) {
      if (op.isDynamicSize(dim)) {
        Value sizeOperand = op.getOperand(op.getIndexOfDynamicSize(dim));
        auto constantOp = sizeOperand.getDefiningOp<arith::ConstantIndexOp>();
        if (constantOp) {
          // Dynamic size fed by a constant: promote it to a static size.
          newStaticSizes.push_back(constantOp.value());
        } else {
          // Genuinely dynamic size: keep the operand.
          remainingDynamicSizes.push_back(sizeOperand);
          newStaticSizes.push_back(ShapedType::kDynamicSize);
        }
      } else {
        // Already static: carry the size over unchanged.
        newStaticSizes.push_back(op.getStaticSize(dim));
      }
    }

    RankedTensorType newType =
        RankedTensorType::get(newStaticSizes, oldType.getElementType());
    // Nothing became static: no rewrite.
    if (newType == oldType)
      return failure();

    auto newOp = rewriter.create<InitTensorOp>(
        op.getLoc(), newType, remainingDynamicSizes,
        rewriter.getI64ArrayAttr(newStaticSizes));
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, oldType, newOp);
    return success();
  }
};
} // namespace
namespace {
/// An `init_tensor` only carries shape information, so a slice taken from it
/// carries only shape information as well. Such a slice is replaced by a fresh
/// `init_tensor` that has the size of the extract slice op result.
struct FoldInitTensorWithExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto initTensorOp =
        sliceOp.source().getDefiningOp<linalg::InitTensorOp>();
    if (!initTensorOp)
      return failure();

    // The slice may be rank-reducing, so the new op takes its static shape
    // from the slice result type and keeps the slice's dynamic sizes.
    auto slicedType = sliceOp.result().getType().cast<RankedTensorType>();
    Type elementType = sliceOp.getSourceType().getElementType();
    rewriter.replaceOpWithNewOp<linalg::InitTensorOp>(
        sliceOp, sliceOp.sizes(), slicedType.getShape(), elementType);
    return success();
  }
};
/// Replaces a reshape of an `init_tensor` by a new `init_tensor` of the
/// reshaped size. The new sizes come from the reshape op's reified result
/// shape; a `tensor.cast` is inserted when the new type differs from the
/// reshape result type.
template <typename TensorReshapeOp>
struct FoldInitTensorWithTensorReshapeOp
    : public OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto initOp = reshapeOp.src().template getDefiningOp<InitTensorOp>();
    if (!initOp)
      return failure();

    Location loc = reshapeOp.getLoc();
    ReifiedRankedShapedTypeDims resultShapes;
    if (failed(reshapeOp.reifyResultShapes(rewriter, resultShapes)))
      return failure();
    if (!llvm::hasSingleElement(resultShapes))
      return failure();

    auto expectedType = reshapeOp.getResultType();
    Value initTensor = rewriter.create<InitTensorOp>(
        loc, getAsOpFoldResult(resultShapes[0]),
        expectedType.getElementType());

    // Types already agree: use the new init_tensor directly.
    if (initTensor.getType() == expectedType) {
      rewriter.replaceOp(reshapeOp, initTensor);
      return success();
    }
    // Otherwise cast back to the type the users expect.
    rewriter.replaceOpWithNewOp<tensor::CastOp>(reshapeOp, expectedType,
                                                initTensor);
    return success();
  }
};
/// Folds `tensor.dim` of an `init_tensor` along a dynamic dimension (with a
/// constant dimension index) to the SSA value that defines that size.
struct FoldInitTensorWithDimOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    Optional<int64_t> constantIndex = dimOp.getConstantIndex();
    auto initTensorOp = dimOp.source().getDefiningOp<linalg::InitTensorOp>();
    if (!initTensorOp)
      return failure();
    if (!constantIndex)
      return failure();

    // Only dynamic sizes have an operand to forward.
    if (initTensorOp.isDynamicSize(*constantIndex)) {
      rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(*constantIndex));
      return success();
    }
    return failure();
  }
};
} // namespace
/// Populates `results` with the canonicalization patterns of InitTensorOp.
void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  // Registered individually, in the same order as the original single call.
  results.add<FoldInitTensorWithDimOp>(context);
  results.add<FoldInitTensorWithExtractSliceOp>(context);
  results.add<FoldInitTensorWithTensorReshapeOp<TensorExpandShapeOp>>(context);
  results.add<FoldInitTensorWithTensorReshapeOp<TensorCollapseShapeOp>>(
      context);
  results.add<ReplaceStaticShapeDims>(context);
}
/// Appends to `reifiedReturnShapes` one list holding a Value per result
/// dimension: the size operand for dynamic dimensions, and a newly created
/// `arith.constant` index for static ones. Always succeeds.
LogicalResult InitTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  SmallVector<Value, 4> dims;
  const int64_t rank = getType().getRank();
  for (int64_t dim = 0; dim < rank; ++dim) {
    if (isDynamicSize(dim)) {
      dims.push_back(getDynamicSize(dim));
      continue;
    }
    Value staticDim =
        builder.create<arith::ConstantIndexOp>(getLoc(), getStaticSize(dim));
    dims.push_back(staticDim);
  }
  reifiedReturnShapes.emplace_back(std::move(dims));
  return success();
}
//===----------------------------------------------------------------------===//
// PadTensorOp
//===----------------------------------------------------------------------===//
// TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
// supports optional types.
/// Printer half of the custom<InferType> directive. The body is intentionally
/// empty: nothing is emitted for the inferred type. The matching
/// `parseInferType` sets `typeToInfer` from `typeToInferFrom` when the
/// optional operand is present.
void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
Type typeToInfer, Type typeToInferFrom) {}
/// Parser half of the custom<InferType> directive. Consumes no input: when
/// `optOperand` is present, `typeToInfer` is set to `typeToInferFrom`;
/// otherwise `typeToInfer` is left untouched. Always succeeds.
ParseResult parseInferType(OpAsmParser &parser,
                           Optional<OpAsmParser::OperandType> optOperand,
                           Type &typeToInfer, Type typeToInferFrom) {
  if (!optOperand)
    return success();
  typeToInfer = typeToInferFrom;
  return success();
}
/// Verifies a PadTensorOp:
///  - the result type must have the same rank as the type inferred from the
///    source type and the static low/high padding;
///  - every result dimension must match the inferred one wherever the inferred
///    dimension is static;
///  - the region's entry block must have one `index` argument per result
///    dimension.
static LogicalResult verify(PadTensorOp op) {
  auto sourceType = op.source().getType().cast<RankedTensorType>();
  auto resultType = op.result().getType().cast<RankedTensorType>();
  auto expectedType = PadTensorOp::inferResultType(
      sourceType, extractFromI64ArrayAttr(op.static_low()),
      extractFromI64ArrayAttr(op.static_high()));
  // The loop below indexes `resultType` with source-rank indices, so reject a
  // rank mismatch first instead of reading dimensions that do not exist.
  if (resultType.getRank() != expectedType.getRank())
    return op.emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
    if (resultType.getDimSize(i) == expectedType.getDimSize(i))
      continue;
    // A dynamic inferred dimension is compatible with any declared size.
    if (expectedType.isDynamicDim(i))
      continue;
    return op.emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }

  auto &region = op.region();
  unsigned rank = resultType.getRank();
  Block &block = region.front();
  if (block.getNumArguments() != rank)
    return op.emitError("expected the block to have ") << rank << " arguments";

  // Note: the number and type of yield values are checked in the YieldOp.
  for (auto en : llvm::enumerate(block.getArgumentTypes())) {
    if (!en.value().isIndex())
      return op.emitOpError("expected block argument ")
             << (en.index() + 1) << " to be an index";
  }
  return success();
}
/// Infers the result type of a pad from `sourceType` and the static padding.
///
/// A dimension is computed as `source + low + high` when all three are static.
/// Otherwise the dimension is taken from `resultShape` if that is non-empty,
/// and is dynamic if it is empty. `staticLow`, `staticHigh` and a non-empty
/// `resultShape` must each have one entry per source dimension (asserted).
RankedTensorType PadTensorOp::inferResultType(RankedTensorType sourceType,
                                              ArrayRef<int64_t> staticLow,
                                              ArrayRef<int64_t> staticHigh,
                                              ArrayRef<int64_t> resultShape) {
  const unsigned rank = sourceType.getRank();
  assert(staticLow.size() == rank && "unexpected staticLow size mismatch");
  assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch");
  assert((resultShape.empty() || resultShape.size() == rank) &&
         "unexpected resultShape size mismatch");

  SmallVector<int64_t, 4> inferredShape;
  for (unsigned dim = 0; dim < rank; ++dim) {
    const bool allStatic = !sourceType.isDynamicDim(dim) &&
                           staticLow[dim] != ShapedType::kDynamicSize &&
                           staticHigh[dim] != ShapedType::kDynamicSize;
    if (!allStatic) {
      // Cannot compute the size: fall back to the caller-provided shape.
      inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamicSize
                                                  : resultShape[dim]);
      continue;
    }
    int64_t paddedSize =
        sourceType.getDimSize(dim) + staticLow[dim] + staticHigh[dim];
    assert((resultShape.empty() || paddedSize == resultShape[dim] ||
            resultShape[dim] == ShapedType::kDynamicSize) &&
           "mismatch between inferred shape and result shape");
    inferredShape.push_back(paddedSize);
  }
  return RankedTensorType::get(inferredShape, sourceType.getElementType());
}
/// Builds a PadTensorOp from separate static (`staticLow`/`staticHigh`) and
/// dynamic (`low`/`high`) padding. The result type is inferred from the source
/// type and the static padding.
void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
                        ArrayRef<int64_t> staticLow,
                        ArrayRef<int64_t> staticHigh, ValueRange low,
                        ValueRange high, bool nofold,
                        ArrayRef<NamedAttribute> attrs) {
  auto srcTensorType = source.getType().cast<RankedTensorType>();
  auto paddedType = inferResultType(srcTensorType, staticLow, staticHigh);
  build(b, result, paddedType, source, low, high, b.getI64ArrayAttr(staticLow),
        b.getI64ArrayAttr(staticHigh), nofold ? b.getUnitAttr() : UnitAttr());
  result.addAttributes(attrs);
}
/// Builds a PadTensorOp whose padding is entirely dynamic: every static
/// low/high entry is set to `ShapedType::kDynamicSize` and the work is
/// forwarded to the static/dynamic builder.
void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
                        ValueRange low, ValueRange high, bool nofold,
                        ArrayRef<NamedAttribute> attrs) {
  const unsigned rank = source.getType().cast<RankedTensorType>().getRank();
  SmallVector<int64_t, 4> allDynamic(rank, ShapedType::kDynamicSize);
  build(b, result, source, allDynamic, allDynamic, low, high, nofold, attrs);
}
/// Builds a PadTensorOp from mixed static/dynamic `low` and `high` padding.
///
/// `resultType` may be null, in which case the result type is inferred from
/// the source type and the static part of the padding. A non-null
/// `resultType` must be a RankedTensorType.
void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType,
                        Value source, ArrayRef<OpFoldResult> low,
                        ArrayRef<OpFoldResult> high, bool nofold,
                        ArrayRef<NamedAttribute> attrs) {
  // Calling isa<> on a null Type asserts, so a null `resultType` has to be
  // accepted here for the inference fallback below to be reachable.
  assert((!resultType || resultType.isa<RankedTensorType>()) &&
         "expected a null or ranked tensor result type");
  auto sourceType = source.getType().cast<RankedTensorType>();
  SmallVector<Value, 4> dynamicLow, dynamicHigh;
  SmallVector<int64_t, 4> staticLow, staticHigh;
  // staticLow and staticHigh have full information of the padding config.
  // This will grow staticLow and staticHigh with 1 value. If the config is
  // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
  // value as well.
  dispatchIndexOpFoldResults(low, dynamicLow, staticLow,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh,
                             ShapedType::kDynamicSize);
  if (!resultType) {
    resultType =
        PadTensorOp::inferResultType(sourceType, staticLow, staticHigh);
  }
  assert(resultType.isa<RankedTensorType>() &&
         "expected a ranked tensor result type");
  build(b, result, resultType, source, dynamicLow, dynamicHigh,
        b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
        nofold ? b.getUnitAttr() : UnitAttr());
  result.addAttributes(attrs);
}
/// Creates a PadTensorOp of result type `type` that pads `source` by
/// `low`/`high` and whose region yields the value `pad`. The region gets one
/// block with one `index` argument per result dimension. The insertion point
/// of `builder` is restored before returning.
PadTensorOp PadTensorOp::createPadScalarOp(Type type, Value source, Value pad,
                                           ArrayRef<OpFoldResult> low,
                                           ArrayRef<OpFoldResult> high,
                                           bool nofold, Location loc,
                                           OpBuilder &builder) {
  auto padOp =
      builder.create<linalg::PadTensorOp>(loc, type, source, low, high, nofold);

  const int rank = padOp.getResultType().getRank();
  SmallVector<Type, 4> indexArgTypes(rank, builder.getIndexType());

  auto &body = padOp.region();
  // `createBlock` moves the insertion point into the new block; the guard
  // puts it back when this function returns.
  OpBuilder::InsertionGuard restoreInsertionPoint(builder);
  builder.createBlock(&body, body.end(), indexArgTypes);
  builder.create<linalg::YieldOp>(loc, pad);
  return padOp;
}
/// Creates a PadTensorOp that pads `source` only on the high side so that the
/// result has the static shape of `type`. For every dimension the low padding
/// is the constant 0 and the high padding is `resultDim - dim(source)`. The
/// padding value is `pad`. `type` must be a statically shaped ranked tensor.
PadTensorOp PadTensorOp::createPadHighOp(Type type, Value source, Value pad,
                                         bool nofold, Location loc,
                                         OpBuilder &b) {
  auto resultTensorType = type.cast<RankedTensorType>();
  assert(resultTensorType.hasStaticShape());

  SmallVector<OpFoldResult, 4> low, high;
  ArrayRef<int64_t> resultShape = resultTensorType.getShape();
  for (int64_t dim = 0, rank = resultShape.size(); dim < rank; ++dim) {
    AffineExpr d0;
    bindDims(b.getContext(), d0);
    auto dimOp = b.createOrFold<tensor::DimOp>(loc, source, dim);
    // high = static result size - runtime source size.
    Value paddingWidth =
        makeComposedAffineApply(b, loc, resultShape[dim] - d0, {dimOp});
    high.push_back(paddingWidth);
    low.push_back(b.createOrFold<arith::ConstantIndexOp>(loc, 0));
  }
  return PadTensorOp::createPadScalarOp(type, source, pad, low, high, nofold,
                                        loc, b);
}
/// Appends to `reifiedReturnShapes` one list with a Value per source
/// dimension, each computed as `dim(source) + low pad + high pad` through an
/// affine map. Static padding is folded into the affine expression as a
/// constant; dynamic padding becomes a symbol operand. Always succeeds.
LogicalResult PadTensorOp::reifyResultShapes(
    OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  Location loc = getLoc();
  auto lowPad = getMixedLowPad();
  auto highPad = getMixedHighPad();

  SmallVector<Value> resultDims;
  const int64_t rank = getSourceType().getRank();
  for (int64_t dim = 0; dim < rank; ++dim) {
    // Operand 0 is the source dimension, bound to affine dim d0.
    SmallVector<Value> operands;
    operands.push_back(b.createOrFold<tensor::DimOp>(loc, source(), dim));
    AffineExpr sizeExpr = b.getAffineDimExpr(0);
    unsigned symbolCount = 0;

    // Adds one padding amount to `sizeExpr`, as a symbol or as a constant.
    auto accumulatePad = [&](OpFoldResult pad) {
      Value dynamicPad = pad.dyn_cast<Value>();
      if (!dynamicPad) {
        int64_t constantPad =
            pad.get<Attribute>().cast<IntegerAttr>().getInt();
        sizeExpr = sizeExpr + constantPad;
        return;
      }
      sizeExpr = sizeExpr + b.getAffineSymbolExpr(symbolCount++);
      operands.push_back(dynamicPad);
    };
    accumulatePad(lowPad[dim]);
    accumulatePad(highPad[dim]);

    AffineMap sizeMap = AffineMap::get(1, symbolCount, sizeExpr);
    resultDims.push_back(applyMapToValues(b, loc, sizeMap, operands)[0]);
  }
  reifiedReturnShapes.emplace_back(std::move(resultDims));
  return success();
}
//===----------------------------------------------------------------------===//
// Methods related to PadTensor tiling.
//===----------------------------------------------------------------------===//
/// Returns a single destination operand: a new InitTensorOp whose sizes are
/// the reified result shape of this op and whose element type is the result
/// element type.
SmallVector<Value> PadTensorOp::getDestinationOperands(OpBuilder &b) {
  ReifiedRankedShapedTypeDims reifiedShapes;
  (void)reifyResultShapes(b, reifiedShapes);

  Type elementType = getResultType().getElementType();
  SmallVector<OpFoldResult> mixedSizes = getAsOpFoldResult(reifiedShapes[0]);
  Value initTensor = b.create<InitTensorOp>(getLoc(), mixedSizes, elementType);

  SmallVector<Value> destinations;
  destinations.push_back(initTensor);
  return destinations;
}
/// Returns one parallel iterator type name per result dimension.
SmallVector<StringRef> PadTensorOp::getLoopIteratorTypes() {
  const int64_t rank = getResultType().getRank();
  return SmallVector<StringRef>(rank, getParallelIteratorTypeName());
}
/// Returns one loop range per reified result dimension. Every range is
/// initialized from the constants {0, 1, 1} and then has its `size` set to
/// the corresponding reified result dimension.
SmallVector<Range> PadTensorOp::getLoopBounds(OpBuilder &b) {
  ReifiedRankedShapedTypeDims reifiedShapes;
  (void)reifyResultShapes(b, reifiedShapes);
  Value zero = b.create<arith::ConstantIndexOp>(getLoc(), 0);
  Value one = b.create<arith::ConstantIndexOp>(getLoc(), 1);

  SmallVector<Range> loopRanges;
  loopRanges.reserve(reifiedShapes[0].size());
  for (Value upperBound : reifiedShapes[0]) {
    Range range{zero, one, one};
    range.size = upperBound;
    loopRanges.push_back(range);
  }
  return loopRanges;
}
/// Builds the IR that computes the tile of this pad_tensor described by
/// `offsets` and `sizes` (unit strides only), and returns the operation
/// producing that tile.
///
/// Returns nullptr when the padding value is not constant (see
/// `getConstantPaddingValue`). Otherwise returns one of:
///  - a tensor.cast of a tensor.generate that yields only the padding value,
///    when some dimension of the source slice is statically known to be empty;
///  - an scf.if selecting at runtime between that generate and
///    pad_tensor(extract_slice(source)), when emptiness is only known
///    dynamically;
///  - a tensor.cast of pad_tensor(extract_slice(source)) otherwise.
///
/// `dest` is not referenced by this implementation.
Operation *PadTensorOp::getTiledImplementation(OpBuilder &b, ValueRange dest,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) {
// Only constant padding value supported.
Value padValue = getConstantPaddingValue();
if (!padValue)
return nullptr;
// Helper variables and functions for various arithmetic operations. These are
// used extensively for computing new offset/length and padding values.
Location loc = getLoc();
AffineExpr dim0, dim1;
bindDims(b.getContext(), dim0, dim1);
// Add two integers.
auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
auto add = [&](Value v1, Value v2) {
return b.createOrFold<AffineApplyOp>(loc, addMap, ValueRange{v1, v2});
};
// Subtract two integers.
auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
auto sub = [&](Value v1, Value v2) {
return b.createOrFold<AffineApplyOp>(loc, subMap, ValueRange{v1, v2});
};
// Take the minimum of two integers.
auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
auto min = [&](Value v1, Value v2) {
return b.createOrFold<AffineMinOp>(loc, idMap, ValueRange{v1, v2});
};
// Take the maximum of two integers.
auto max = [&](Value v1, Value v2) {
return b.createOrFold<AffineMaxOp>(loc, idMap, ValueRange{v1, v2});
};
// Zero index-typed integer.
auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
// Helper function for filling static/dynamic low/high padding indices vectors
// of PadTensorOp. Constant values go to `staticIndices`; anything else is
// recorded as kDynamicSize there and appended to `dynIndices`.
auto appendIndex = [&](Value val, SmallVector<Value> &dynIndices,
SmallVector<int64_t> &staticIndices) {
if (auto constInt = getConstantIntValue(val)) {
staticIndices.push_back(*constInt);
} else {
staticIndices.push_back(ShapedType::kDynamicSize);
dynIndices.push_back(val);
}
};
// Compute new offsets, lengths, low padding, high padding.
SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
SmallVector<Value> newLows, newHighs;
SmallVector<int64_t> staticNewLows, staticNewHighs;
// Set to true if the original data source is not read at all.
bool hasZeroLen = false;
// Same as hasZeroLen, but for dynamic dimension sizes. This condition
// is true if the original data source turns out to be unused at runtime.
Value dynHasZeroLenCond;
int64_t rank = getSourceType().getRank();
for (unsigned dim = 0; dim < rank; ++dim) {
auto low = getValueOrCreateConstantIndexOp(b, loc, getMixedLowPad()[dim]);
bool hasLowPad = getConstantIntValue(low) != static_cast<int64_t>(0);
auto high = getValueOrCreateConstantIndexOp(b, loc, getMixedHighPad()[dim]);
bool hasHighPad = getConstantIntValue(high) != static_cast<int64_t>(0);
auto offset = getValueOrCreateConstantIndexOp(b, loc, offsets[dim]);
auto length = getValueOrCreateConstantIndexOp(b, loc, sizes[dim]);
auto srcSize = b.createOrFold<tensor::DimOp>(loc, source(), dim);
// The new amount of low padding is `low - offset`. Except for the case
// where none of the low padding is read. In that case, the new amount of
// low padding is zero.
//
// Optimization: If low = 0, then newLow = 0.
Value newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
appendIndex(newLow, newLows, staticNewLows);
// Start reading the data from position `offset - low`. Since the original
// read may have started in the low padding zone, this value could be
// negative. Therefore, start reading from:
//
// max(offset - low, 0)
//
// The original read could also have started in the high padding zone.
// In that case, set the offset to the end of source tensor. The new
// ExtractSliceOp length will be zero in that case. (Effectively reading no
// data from the source.)
//
// Optimization: If low = 0, then the formula can be simplified.
Value newOffset = hasLowPad ? min(max(sub(offset, low), zero), srcSize)
: min(offset, srcSize);
newOffsets.push_back(getAsOpFoldResult(newOffset));
// The original ExtractSliceOp was reading until position `offset + length`.
// Therefore, the corresponding position within the source tensor is:
//
// offset + length - low
//
// In case the original ExtractSliceOp stopped reading within the low
// padding zone, this value can be negative. In that case, the end position
// of the read should be zero. (Similar to newOffset.)
//
// The original read could also have stopped in the high padding zone.
// In that case, the end position of the read should be the end of the
// source tensor. (Similar to newOffset.)
//
// endLoc = min(max(offset - low + length, 0), srcSize)
//
// The new ExtractSliceOp length is `endLoc - newOffset`.
//
// Optimization: If low = 0, then the formula can be simplified.
Value endLoc = hasLowPad
? min(max(add(sub(offset, low), length), zero), srcSize)
: min(add(offset, length), srcSize);
Value newLength = sub(endLoc, newOffset);
newLengths.push_back(getAsOpFoldResult(newLength));
// Check if newLength is zero. In that case, no ExtractSliceOp should be
// executed.
if (auto newLengthInt = getConstantIntValue(newLength)) {
hasZeroLen |= *newLengthInt == 0;
} else {
// Not a constant: accumulate a runtime "some length is zero" condition
// by OR-ing the per-dimension comparisons.
Value check = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
newLength, zero);
dynHasZeroLenCond =
dynHasZeroLenCond
? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
: check;
}
// The amount of high padding is simply the number of elements remaining,
// so that the result has the same length as the original ExtractSliceOp.
// As an optimization, if the original high padding is zero, then the new
// high padding must also be zero.
Value newHigh = hasHighPad ? sub(sub(length, newLength), newLow) : zero;
appendIndex(newHigh, newHighs, staticNewHighs);
// Only unit stride supported.
newStrides.push_back(b.getIndexAttr(1));
}
// The shape of the result can be obtained from the sizes passed in.
SmallVector<Value> dynDims;
SmallVector<int64_t> shape;
dispatchIndexOpFoldResults(sizes, dynDims, shape, ShapedType::kDynamicSize);
RankedTensorType resultType =
RankedTensorType::get(shape, getResultType().getElementType());
// Insert cast to ensure that types match. (May be folded away.)
auto castResult = [&](Value val) -> Operation * {
auto castOp = b.create<tensor::CastOp>(loc, resultType, val);
return castOp;
};
// In cases where the original data source is unused: Emit a GenerateOp and
// do not generate a SliceOp. (The result shape of the SliceOp would
// have a dimension of size 0, the semantics of which is unclear.)
auto createGenerateOp = [&]() {
// Create GenerateOp whose body yields the padding value for every element.
auto generateOp = b.create<tensor::GenerateOp>(
loc, resultType, dynDims,
[&](OpBuilder &builder, Location gLoc, ValueRange indices) {
builder.create<tensor::YieldOp>(gLoc, padValue);
});
return castResult(generateOp);
};
// Emit a SliceOp and a PadTensorOp. Should not be used in cases where
// the result shape of the new SliceOp has a zero dimension.
auto createPadTensorOfSubTensor = [&]() {
// Create pad_tensor(extract_slice(x)).
auto newSliceOp = b.create<tensor::ExtractSliceOp>(
loc, source(), newOffsets, newLengths, newStrides);
auto newPadTensorOp = b.create<PadTensorOp>(
loc, newSliceOp, staticNewLows, staticNewHighs, newLows, newHighs);
// Copy region to new PadTensorOp.
BlockAndValueMapping bvm;
region().cloneInto(&newPadTensorOp.getRegion(), bvm);
// Cast result and return.
return castResult(newPadTensorOp);
};
// Rewrite extract_slice(pad_tensor(x)) into a GenerateOp if it is statically
// known that the original data source x is not used.
if (hasZeroLen) {
return createGenerateOp();
}
// If there are dynamic dimensions: Generate an scf.if check to avoid creating
// SliceOps with result dimensions of size 0 at runtime.
if (dynHasZeroLenCond) {
auto result = b.create<scf::IfOp>(
loc, resultType, dynHasZeroLenCond,
/*thenBuilder=*/
[&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(loc, createGenerateOp()->getResult(0));
},
/*elseBuilder=*/
[&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(loc,
createPadTensorOfSubTensor()->getResult(0));
});
return result;
}
return createPadTensorOfSubTensor();
}
namespace {
// Replaces a linalg.pad_tensor whose low and high padding are all zero by a
// tensor.cast of its source to the pad result type, unless the op carries the
// `nofold` attribute.
struct FoldStaticZeroPadding : public OpRewritePattern<PadTensorOp> {
  using OpRewritePattern<PadTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadTensorOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    const bool padsSomething =
        !padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad();
    if (padsSomething || padTensorOp.nofold())
      return failure();

    Type paddedType = padTensorOp.result().getType();
    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, paddedType,
                                                padTensorOp.source());
    return success();
  }
};
// Fold CastOp into PadTensorOp when adding static information.
//
// The result type is re-inferred from the type of the cast's source. If it
// equals the current result type, the pad is updated in place to read the
// cast's source. Otherwise a new PadTensorOp with the re-inferred type is
// created, the region is cloned into it, and the old op is replaced by a
// tensor.cast back to the original result type.
struct FoldSourceTensorCast : public OpRewritePattern<PadTensorOp> {
using OpRewritePattern<PadTensorOp>::OpRewritePattern;
LogicalResult matchAndRewrite(PadTensorOp padTensorOp,
PatternRewriter &rewriter) const override {
// NOTE(review): `castOp` is null when the source is not produced by a
// tensor.cast; presumably canFoldIntoConsumerOp rejects a null op — confirm.
auto castOp = padTensorOp.source().getDefiningOp<tensor::CastOp>();
if (!tensor::canFoldIntoConsumerOp(castOp))
return failure();
// Re-infer from the pre-cast type, keeping the current result shape for
// dimensions that cannot be computed statically.
auto newResultType = PadTensorOp::inferResultType(
castOp.source().getType().cast<RankedTensorType>(),
extractFromI64ArrayAttr(padTensorOp.static_low()),
extractFromI64ArrayAttr(padTensorOp.static_high()),
padTensorOp.getResultType().getShape());
if (newResultType == padTensorOp.getResultType()) {
// Same result type: just bypass the cast.
rewriter.updateRootInPlace(padTensorOp, [&]() {
padTensorOp.sourceMutable().assign(castOp.source());
});
} else {
// NOTE(review): the new op still uses `padTensorOp.source()` (the cast
// result) rather than `castOp.source()`; looks like the cast is only
// bypassed on a later application of this pattern — confirm intent.
auto newOp = rewriter.create<PadTensorOp>(
padTensorOp->getLoc(), newResultType, padTensorOp.source(),
padTensorOp.low(), padTensorOp.high(), padTensorOp.static_low(),
padTensorOp.static_high(), padTensorOp.nofold());
BlockAndValueMapping mapper;
padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
rewriter.replaceOpWithNewOp<tensor::CastOp>(
padTensorOp, padTensorOp.getResultType(), newOp);
}
return success();
}
};
// Fold CastOp using the result of PadTensorOp back into the latter if it adds
// static information.
//
// Applies only when the pad result has exactly one use and that user is a
// tensor.cast accepted by tensor::preservesStaticInformation. A new
// PadTensorOp is created with the cast's destination type, and both the old
// pad and the cast are replaced by its result.
struct FoldTargetTensorCast : public OpRewritePattern<PadTensorOp> {
using OpRewritePattern<PadTensorOp>::OpRewritePattern;
LogicalResult matchAndRewrite(PadTensorOp padTensorOp,
PatternRewriter &rewriter) const override {
if (!padTensorOp.result().hasOneUse())
return failure();
// The single user must be a tensor.cast.
auto tensorCastOp =
dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
if (!tensorCastOp)
return failure();
if (!tensor::preservesStaticInformation(padTensorOp.result().getType(),
tensorCastOp.dest().getType()))
return failure();
auto replacementOp = rewriter.create<PadTensorOp>(
padTensorOp.getLoc(), tensorCastOp.dest().getType(),
padTensorOp.source(), padTensorOp.low(), padTensorOp.high(),
padTensorOp.static_low(), padTensorOp.static_high(),
padTensorOp.nofold());
// Move (not clone) the body into the replacement op.
// NOTE(review): the region is moved directly rather than through the
// rewriter — confirm this is acceptable for the pattern drivers in use.
replacementOp.region().takeBody(padTensorOp.region());
rewriter.replaceOp(padTensorOp, replacementOp.result());
rewriter.replaceOp(tensorCastOp, replacementOp.result());
return success();
}
};
} // namespace
/// Populates `results` with the canonicalization patterns of PadTensorOp.
void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<FoldStaticZeroPadding, FoldSourceTensorCast,
              FoldTargetTensorCast>(context);
}
/// Return the padding value of the PadTensorOp if it is constant. In this
/// context, "constant" means an actual constant or "defined outside of the
/// block".
///
/// Values are considered constant in three cases:
///  - A ConstantLike value.
///  - A basic block argument from a different block.
///  - A value defined outside of the block.
///
/// If the padding value is not constant, or the block terminator is not a
/// YieldOp with exactly one value, an empty Value is returned.
Value PadTensorOp::getConstantPaddingValue() {
  Block &body = getRegion().front();
  auto yieldOp = dyn_cast<YieldOp>(body.getTerminator());
  if (!yieldOp)
    return {};
  if (yieldOp.values().size() != 1)
    return {};

  Value padValue = yieldOp.values().front();
  // Accept a matched constant, or any value whose parent block is not the
  // pad body.
  if (matchPattern(padValue, m_Constant()) ||
      padValue.getParentBlock() != &body)
    return padValue;
  // Defined inside the pad body and not a constant.
  return {};
}
/// Folds the pad to its source when the result type is statically shaped and
/// identical to the source type, unless `nofold` is set. Returns an empty
/// result otherwise.
OpFoldResult PadTensorOp::fold(ArrayRef<Attribute>) {
  auto resultType = getResultType();
  if (!resultType.hasStaticShape())
    return {};
  if (resultType != getSourceType())
    return {};
  if (nofold())
    return {};
  return source();
}
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
/// Prints a TensorExpandShapeOp by delegating to the shared `printReshapeOp`
/// helper.
static void print(OpAsmPrinter &p, linalg::TensorExpandShapeOp op) {
::mlir::printReshapeOp<linalg::TensorExpandShapeOp>(p, op);
}
/// Prints a TensorCollapseShapeOp by delegating to the shared
/// `printReshapeOp` helper.
static void print(OpAsmPrinter &p, linalg::TensorCollapseShapeOp op) {
::mlir::printReshapeOp<linalg::TensorCollapseShapeOp>(p, op);
}
/// Returns the largest position of any sub-expression of type `AffineExprTy`
/// found while walking every expression in `exprArrays`. Returns 0 when no
/// such sub-expression exists.
template <typename AffineExprTy>
unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
  unsigned maxPos = 0;
  auto recordPosition = [&maxPos](AffineExpr e) {
    auto typedExpr = e.dyn_cast<AffineExprTy>();
    if (typedExpr && typedExpr.getPosition() > maxPos)
      maxPos = typedExpr.getPosition();
  };
  for (const auto &group : exprArrays)
    for (auto expr : group)
      expr.walk(recordPosition);
  return maxPos;
}
/// Returns the reassociation of this op as symbol-less affine maps.
SmallVector<AffineMap, 4> TensorCollapseShapeOp::getReassociationMaps() {
  auto reassociationExprs = getReassociationExprs();
  return getSymbolLessAffineMaps(reassociationExprs);
}
/// Returns the reassociation indices of this op converted to affine
/// expressions.
SmallVector<ReassociationExprs, 4>
TensorCollapseShapeOp::getReassociationExprs() {
  auto reassociationIndices = getReassociationIndices();
  return convertReassociationIndicesToExprs(getContext(),
                                            reassociationIndices);
}
/// Returns the reassociation of this op as symbol-less affine maps.
SmallVector<AffineMap, 4> TensorExpandShapeOp::getReassociationMaps() {
  auto reassociationExprs = getReassociationExprs();
  return getSymbolLessAffineMaps(reassociationExprs);
}
/// Returns the reassociation indices of this op converted to affine
/// expressions.
SmallVector<ReassociationExprs, 4>
TensorExpandShapeOp::getReassociationExprs() {
  auto reassociationIndices = getReassociationIndices();
  return convertReassociationIndicesToExprs(getContext(),
                                            reassociationIndices);
}
/// For a collapsing reshape op, computes the size of output dimension
/// `dimIndex` in terms of the shape of `src`: the product of the sizes of the
/// `src` dimensions that `reassociationMap[dimIndex]` groups together (from
/// the position of its first result to the position of its last result).
static OpFoldResult
getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc,
                                    int64_t dimIndex, Value src,
                                    ArrayRef<AffineMap> reassociationMap) {
  ArrayRef<AffineExpr> group = reassociationMap[dimIndex].getResults();
  unsigned firstPos = group.front().cast<AffineDimExpr>().getPosition();
  unsigned lastPos = group.back().cast<AffineDimExpr>().getPosition();

  SmallVector<Value, 2> srcDims;
  AffineExpr product;
  for (unsigned dim = firstPos; dim <= lastPos; ++dim) {
    srcDims.push_back(builder.createOrFold<tensor::DimOp>(loc, src, dim));
    // One symbol per grouped source dimension, multiplied together.
    AffineExpr symbol = builder.getAffineSymbolExpr(dim - firstPos);
    if (product)
      product = product * symbol;
    else
      product = symbol;
  }

  AffineMap productMap = AffineMap::get(0, lastPos - firstPos + 1, product);
  return applyMapToValues(builder, loc, productMap, srcDims)[0];
}
/// Given the `src` of a collapsing reshape op and its reassociation maps,
/// computes the size of every result dimension. Only the number of entries of
/// `dstStaticShape` is used.
static SmallVector<OpFoldResult, 4> getCollapsedOutputShapeFromInputShape(
    OpBuilder &builder, Location loc, Value src,
    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
  SmallVector<OpFoldResult, 4> outputShape;
  const int64_t dstRank = dstStaticShape.size();
  for (int64_t dim = 0; dim < dstRank; ++dim)
    outputShape.push_back(getCollapsedOutputDimFromInputShape(
        builder, loc, dim, src, reassociation));
  return outputShape;
}
/// Computes a map that, for a given dimension of the expanded type, gives the
/// dimension of the collapsed type it maps to. Essentially it is the inverse
/// of the `reassociation` maps.
static llvm::DenseMap<int64_t, int64_t>
getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
  llvm::DenseMap<int64_t, int64_t> expandedToCollapsed;
  for (unsigned collapsedDim = 0, e = reassociation.size(); collapsedDim < e;
       ++collapsedDim) {
    ArrayRef<AffineExpr> group = reassociation[collapsedDim].getResults();
    unsigned firstPos = group.front().cast<AffineDimExpr>().getPosition();
    unsigned lastPos = group.back().cast<AffineDimExpr>().getPosition();
    // Every expanded dimension in [firstPos, lastPos] belongs to this group.
    for (unsigned expandedDim = firstPos; expandedDim <= lastPos;
         ++expandedDim)
      expandedToCollapsed[expandedDim] = collapsedDim;
  }
  return expandedToCollapsed;
}
/// For an expanding reshape op, computes the size of output dimension
/// `dimIndex` from the shape of the input.
///
/// A static output dimension is returned as an integer attribute. A dynamic
/// one is computed as `dim(src, collapsedDim) floordiv P`, where P is the
/// product of the other (necessarily static, asserted) output dimensions in
/// the same reassociation group.
static OpFoldResult getExpandedOutputDimFromInputShape(
    OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
    llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
  if (!ShapedType::isDynamic(dstStaticShape[dimIndex]))
    return builder.getI64IntegerAttr(dstStaticShape[dimIndex]);

  unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
  ArrayRef<AffineExpr> group = reassociation[sourceDimPos].getResults();
  unsigned startPos = group.front().cast<AffineDimExpr>().getPosition();
  unsigned endPos = group.back().cast<AffineDimExpr>().getPosition();

  // Multiply the static sizes of all sibling dimensions in the group.
  int64_t staticProduct = 1;
  ArrayRef<int64_t> groupShape =
      dstStaticShape.slice(startPos, endPos - startPos + 1);
  for (unsigned offset = 0, e = groupShape.size(); offset < e; ++offset) {
    if (offset + startPos == static_cast<unsigned>(dimIndex))
      continue;
    int64_t extent = groupShape[offset];
    assert(!ShapedType::isDynamic(extent) &&
           "single dimension cannot be expanded into multiple dynamic "
           "dimensions");
    staticProduct *= extent;
  }

  Value sourceDim = builder.create<tensor::DimOp>(loc, src, sourceDimPos);
  AffineMap divideMap = AffineMap::get(
      0, 1, builder.getAffineSymbolExpr(0).floorDiv(staticProduct));
  return applyMapToValues(builder, loc, divideMap, sourceDim)[0];
}
/// Given the `src` of an expanding reshape op, the reassociation maps and the
/// static result shape, computes the size of every result dimension.
static SmallVector<OpFoldResult, 4> getExpandedOutputShapeFromInputShape(
    OpBuilder &builder, Location loc, Value src,
    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
  llvm::DenseMap<int64_t, int64_t> expandedToCollapsed =
      getExpandedDimToCollapsedDimMap(reassociation);

  SmallVector<OpFoldResult, 4> outputShape;
  const int64_t dstRank = dstStaticShape.size();
  for (int64_t dim = 0; dim < dstRank; ++dim)
    outputShape.push_back(getExpandedOutputDimFromInputShape(
        builder, loc, dim, src, dstStaticShape, reassociation,
        expandedToCollapsed));
  return outputShape;
}
/// Computes the result shape of a reshape of `src`. The reshape is treated as
/// an expansion when `dstStaticShape` has more dimensions than `src`, and as
/// a collapse otherwise.
static SmallVector<OpFoldResult, 4>
getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
                                    ArrayRef<int64_t> dstStaticShape,
                                    ArrayRef<AffineMap> reassocation) {
  const size_t srcRank =
      static_cast<size_t>(src.getType().cast<ShapedType>().getRank());
  if (dstStaticShape.size() > srcRank)
    return getExpandedOutputShapeFromInputShape(builder, loc, src,
                                                dstStaticShape, reassocation);
  return getCollapsedOutputShapeFromInputShape(builder, loc, src,
                                               dstStaticShape, reassocation);
}
//===----------------------------------------------------------------------===//
// TensorReshapeOp
//===----------------------------------------------------------------------===//
/// Computes the RankedTensorType obtained by applying `reassociation` to
/// `type`. Each map consumes as many consecutive dimensions of `type` as it
/// has results; the collapsed size is their product, or dynamic if any of
/// them is dynamic. `reassociation` must be valid (asserted).
static RankedTensorType
computeTensorReshapeCollapsedType(RankedTensorType type,
                                  ArrayRef<AffineMap> reassociation) {
  ArrayRef<int64_t> srcShape = type.getShape();
  SmallVector<int64_t, 4> collapsedShape;
  collapsedShape.reserve(reassociation.size());

  // A valid reassociation lets us rely on each map's number of results only.
  assert(isReassociationValid(reassociation) && "invalid reassociation");
  unsigned bandStart = 0;
  for (AffineMap m : reassociation) {
    const unsigned bandRank = m.getNumResults();
    ArrayRef<int64_t> band = srcShape.slice(bandStart, bandRank);

    const bool hasDynamicExtent = llvm::any_of(band, [](int64_t extent) {
      return extent == ShapedType::kDynamicSize;
    });
    int64_t collapsedSize = ShapedType::kDynamicSize;
    if (!hasDynamicExtent) {
      collapsedSize = 1;
      for (int64_t extent : band)
        collapsedSize *= extent;
    }
    collapsedShape.push_back(collapsedSize);
    bandStart += bandRank;
  }
  return RankedTensorType::get(collapsedShape, type.getElementType());
}
/// Builds a collapse-shape op on `src`, inferring the result type from the
/// source type and `reassociation`, and records the reassociation as an
/// attribute on the op.
void mlir::linalg::TensorCollapseShapeOp::build(
    OpBuilder &b, OperationState &result, Value src,
    ArrayRef<ReassociationIndices> reassociation,
    ArrayRef<NamedAttribute> attrs) {
  auto srcType = src.getType().cast<RankedTensorType>();
  auto reassociationExprs =
      convertReassociationIndicesToExprs(b.getContext(), reassociation);
  auto reassociationMaps = getSymbolLessAffineMaps(reassociationExprs);
  auto collapsedType =
      computeTensorReshapeCollapsedType(srcType, reassociationMaps);

  build(b, result, collapsedType, src, attrs);
  result.addAttribute(getReassociationAttrName(),
                      getReassociationIndicesAttribute(b, reassociation));
}
/// Builds an expand-shape op on `src`, deriving the result type from the
/// source type and `reassociation`, and records the reassociation as an
/// attribute on the op.
///
/// NOTE(review): the result type is derived with
/// `computeTensorReshapeCollapsedType` applied to the *source* type, exactly
/// like the collapse builder does. Confirm this is the intended inference for
/// an expansion.
void mlir::linalg::TensorExpandShapeOp::build(
    OpBuilder &b, OperationState &result, Value src,
    ArrayRef<ReassociationIndices> reassociation,
    ArrayRef<NamedAttribute> attrs) {
  auto srcType = src.getType().cast<RankedTensorType>();
  auto reassociationExprs =
      convertReassociationIndicesToExprs(b.getContext(), reassociation);
  auto reassociationMaps = getSymbolLessAffineMaps(reassociationExprs);
  auto derivedType =
      computeTensorReshapeCollapsedType(srcType, reassociationMaps);

  build(b, result, derivedType, src, attrs);
  result.addAttribute(getReassociationAttrName(),
                      getReassociationIndicesAttribute(b, reassociation));
}
/// Shared verifier for the tensor expand/collapse reshape ops.
///
/// `expandedType` is the higher-rank side of the reshape and `collapsedType`
/// the lower-rank side, regardless of the direction of `op`. After the generic
/// reshape-like checks, the collapsed type must be exactly the type obtained by
/// collapsing `expandedType` with the op's reassociation maps.
template <typename TensorReshapeOp,
          bool isExpansion =
              std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value>
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
                                           RankedTensorType expandedType,
                                           RankedTensorType collapsedType) {
  if (failed(
          verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
    return failure();

  auto reassociationMaps = op.getReassociationMaps();
  RankedTensorType inferredCollapsedType =
      computeTensorReshapeCollapsedType(expandedType, reassociationMaps);
  if (collapsedType == inferredCollapsedType)
    return success();

  return op.emitOpError("expected collapsed type to be ")
         << inferredCollapsedType << ", but got " << collapsedType;
}
/// Verifies an expand-shape op: its result is the expanded side of the reshape
/// and its source the collapsed side.
static LogicalResult verify(TensorExpandShapeOp op) {
  auto expandedType = op.getResultType();
  auto collapsedType = op.getSrcType();
  return verifyTensorReshapeOp(op, expandedType, collapsedType);
}
/// Verifies a collapse-shape op: its source is the expanded side of the
/// reshape and its result the collapsed side.
static LogicalResult verify(TensorCollapseShapeOp op) {
  auto expandedType = op.getSrcType();
  auto collapsedType = op.getResultType();
  return verifyTensorReshapeOp(op, expandedType, collapsedType);
}
namespace {
/// Replaces a reshape whose source is a splat constant with a constant of the
/// reshape's result type that is built from the same raw splat data.
template <typename TensorReshapeOp>
struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Only dense splat constants are handled.
    DenseElementsAttr splatAttr;
    if (!matchPattern(reshapeOp.src(), m_Constant(&splatAttr)) || !splatAttr ||
        !splatAttr.isSplat())
      return failure();

    // Re-wrap the splat's raw buffer in the result type of the reshape.
    DenseElementsAttr reshapedAttr = DenseElementsAttr::getFromRawBuffer(
        reshapeOp.getResultType(), splatAttr.getRawData(), true);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, reshapedAttr);
    return success();
  }
};
/// Folds a linalg.fill -> tensor reshape chain.
///
/// The reshape is moved onto the output operand of the fill, and a new
/// linalg.fill of the original fill value into that reshaped output replaces
/// the reshape.
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    FillOp producerFill = reshapeOp.src().template getDefiningOp<FillOp>();
    if (!producerFill)
      return failure();

    // Reshape the tensor being filled rather than the filled result.
    auto reshapedOutput = rewriter.create<TensorReshapeOp>(
        producerFill.getLoc(), reshapeOp.getResultType(),
        producerFill.output(), reshapeOp.reassociation());
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, producerFill.value(),
                                        reshapedOutput);
    return success();
  }
};
} // namespace
/// Registers the canonicalization patterns of the expand-shape op in
/// `results`. The registration order is kept as a single ordered sequence
/// split over two calls.
void TensorExpandShapeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Patterns operating on pairs of reshape ops.
  results.add<
      CollapseReshapeOps<TensorExpandShapeOp>,
      CollapseMixedReshapeOps<TensorExpandShapeOp, TensorCollapseShapeOp>>(
      context);
  // Patterns folding the reshape with the op defining its source.
  results.add<FoldFillWithTensorReshape<TensorExpandShapeOp>,
              FoldInitTensorWithTensorReshapeOp<TensorExpandShapeOp>,
              FoldReshapeWithConstant<TensorExpandShapeOp>>(context);
}
/// Registers the canonicalization patterns of the collapse-shape op in
/// `results`. The registration order is kept as a single ordered sequence
/// split over two calls.
void TensorCollapseShapeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Patterns operating on pairs of reshape ops.
  results.add<
      CollapseReshapeOps<TensorCollapseShapeOp>,
      CollapseMixedReshapeOps<TensorCollapseShapeOp, TensorExpandShapeOp>>(
      context);
  // Patterns folding the reshape with the op defining its source.
  results.add<FoldFillWithTensorReshape<TensorCollapseShapeOp>,
              FoldInitTensorWithTensorReshapeOp<TensorCollapseShapeOp>,
              FoldReshapeWithConstant<TensorCollapseShapeOp>>(context);
}
/// Appends to `reifiedReturnShapes` the shape of the op's result, computed
/// from the source value, the static result shape and the reassociation maps,
/// and materialized as a list of values. Always succeeds.
LogicalResult TensorExpandShapeOp::reifyResultShapes(
    OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  Location loc = getLoc();
  auto mixedShape = getReshapeOutputShapeFromInputShape(
      b, loc, src(), getResultType().getShape(), getReassociationMaps());
  auto shapeValues = getAsValues(b, loc, mixedShape);
  reifiedReturnShapes.emplace_back(std::move(shapeValues));
  return success();
}
/// Appends to `reifiedReturnShapes` the shape of the op's result, computed
/// from the source value, the static result shape and the reassociation maps,
/// and materialized as a list of values. Always succeeds.
LogicalResult TensorCollapseShapeOp::reifyResultShapes(
    OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  Location loc = getLoc();
  auto mixedShape = getReshapeOutputShapeFromInputShape(
      b, loc, src(), getResultType().getShape(), getReassociationMaps());
  auto shapeValues = getAsValues(b, loc, mixedShape);
  reifiedReturnShapes.emplace_back(std::move(shapeValues));
  return success();
}
//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//
/// Prints a linalg.yield as: operands (if any), the optional attribute
/// dictionary, then ` : ` followed by the operand types (if any operands).
static void print(OpAsmPrinter &p, linalg::YieldOp op) {
  const bool hasOperands = op.getNumOperands() > 0;
  if (hasOperands)
    p << ' ' << op.getOperands();
  p.printOptionalAttrDict(op->getAttrs());
  if (hasOperands)
    p << " : " << op.getOperandTypes();
}
/// Parses a linalg.yield: an operand list, an optional attribute dictionary
/// and, only when operands are present, a colon-prefixed type list. The
/// operands are then resolved against the parsed types.
static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 2> yielded;
  SmallVector<Type, 2> yieldedTypes;
  // Recorded before any parsing and passed on to operand resolution.
  llvm::SMLoc startLoc = parser.getCurrentLocation();

  if (parser.parseOperandList(yielded) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // The type list is mandatory if, and only if, there are operands.
  if (!yielded.empty() && parser.parseColonTypeList(yieldedTypes))
    return failure();
  if (parser.resolveOperands(yielded, yieldedTypes, startLoc, result.operands))
    return failure();
  return success();
}
/// Verifies a linalg.yield nested in `linalgOp`: the number of yielded values
/// must equal the number of outputs of `linalgOp`, and the type of the i-th
/// yielded value must be the element type (or the type itself) of the i-th
/// output operand.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
  // NOTE(review): the first number printed is the enclosing op's output count
  // and the second is the yield's operand count, which looks swapped relative
  // to the wording of the message. Kept as-is since diagnostics text may be
  // relied upon; confirm before changing.
  if (op.getNumOperands() != linalgOp.getNumOutputs())
    return op.emitOpError("expected number of yield values (")
           << linalgOp.getNumOutputs()
           << ") to match the number of operands of the enclosing "
           << "LinalgOp (" << op.getNumOperands() << ")";

  for (OpOperand &yielded : op->getOpOperands()) {
    const unsigned position = yielded.getOperandNumber();
    Type yieldedType = yielded.get().getType();
    Type expectedType = getElementTypeOrSelf(
        linalgOp.getOutputOperand(position)->get().getType());
    if (yieldedType == expectedType)
      continue;
    // Operand positions are reported 1-based.
    return op.emitOpError("type of yield operand ")
           << (position + 1) << " (" << yieldedType << ") doesn't match "
           << "the element type of the enclosing linalg.generic op ("
           << expectedType << ")";
  }
  return success();
}
/// Verifies a linalg.yield against its parent op. The parent must have a
/// single, non-empty region and be one of:
///  - an op implementing LinalgOp (checked by `verifyYield`);
///  - a linalg.pad_tensor-style op: exactly one yielded value whose type is
///    the element type of the parent's result;
///  - a tiled loop: one yielded value per ranked-tensor output, with matching
///    types.
static LogicalResult verify(linalg::YieldOp op) {
  Operation *enclosingOp = op->getParentOp();
  if (enclosingOp->getNumRegions() != 1 || enclosingOp->getRegion(0).empty())
    return op.emitOpError("expected single non-empty parent region");

  if (auto linalgOp = dyn_cast<LinalgOp>(enclosingOp))
    return verifyYield(op, linalgOp);

  if (auto padTensorOp = dyn_cast<linalg::PadTensorOp>(enclosingOp)) {
    if (op.getNumOperands() != 1)
      return op.emitOpError("expected single yield operand (got ")
             << op->getNumOperands() << ")";
    Type paddedElementType =
        padTensorOp.getType().cast<ShapedType>().getElementType();
    if (op.getOperand(0).getType() != paddedElementType)
      return op.emitOpError("expected yield type to match shape element type");
    return success();
  }

  if (auto tiledLoopOp = dyn_cast<linalg::TiledLoopOp>(enclosingOp)) {
    // Only outputs of ranked tensor type are matched against yielded values.
    SmallVector<Value, 2> tensorOuts;
    for (Value out : tiledLoopOp.outputs())
      if (out.getType().isa<RankedTensorType>())
        tensorOuts.push_back(out);

    if (tensorOuts.size() != op.values().size())
      return op.emitOpError("expected number of tensor output args = ")
             << tensorOuts.size() << " to match the number of yield operands = "
             << op.values().size();

    // Both ranges have the same length here, so they can be walked by index.
    for (unsigned index = 0, e = tensorOuts.size(); index < e; ++index) {
      Type outType = tensorOuts[index].getType();
      Type resultType = op.getOperand(index).getType();
      if (outType == resultType)
        continue;
      return op.emitOpError("expected yield operand ")
             << index << " with type = " << resultType
             << " to match output arg type = " << outType;
    }
    return success();
  }

  return op.emitOpError("expected parent op with LinalgOp interface");
}
//===----------------------------------------------------------------------===//
// TiledLoopOp
//===----------------------------------------------------------------------===//
/// Convenience builder without distribution types: forwards to the full
/// builder with `llvm::None` for the distribution types.
void TiledLoopOp::build(OpBuilder &builder, OperationState &result,
                        ValueRange lowerBounds, ValueRange upperBounds,
                        ValueRange steps, ValueRange inputs, ValueRange outputs,
                        ArrayAttr iteratorTypes,
                        function_ref<void(OpBuilder &, Location, ValueRange,
                                          ValueRange, ValueRange)>
                            bodyBuilderFn) {
  build(builder, result, lowerBounds, upperBounds, steps, inputs, outputs,
        iteratorTypes, /*distributionTypes=*/llvm::None, bodyBuilderFn);
}
/// Full builder of a tiled loop.
///
/// Operands are added in the order lower bounds, upper bounds, steps, inputs,
/// outputs, and their sizes are recorded in the operand-segment-size
/// attribute. One result is added per output of ranked tensor type. The body
/// block receives one index argument per step, followed by one argument per
/// input and per output. When `bodyBuilderFn` is provided it is invoked to
/// populate the body (with induction variables, input args, output args) and a
/// terminator is ensured afterwards; otherwise the block is left empty.
void TiledLoopOp::build(OpBuilder &builder, OperationState &result,
                        ValueRange lowerBounds, ValueRange upperBounds,
                        ValueRange steps, ValueRange inputs, ValueRange outputs,
                        ArrayAttr iteratorTypes,
                        Optional<ArrayAttr> distributionTypes,
                        function_ref<void(OpBuilder &, Location, ValueRange,
                                          ValueRange, ValueRange)>
                            bodyBuilderFn) {
  // Operands, one segment after the other.
  for (ValueRange segment : {lowerBounds, upperBounds, steps, inputs, outputs})
    result.addOperands(segment);

  // Attributes.
  auto segmentSize = [](ValueRange segment) {
    return static_cast<int32_t>(segment.size());
  };
  result.addAttribute(
      TiledLoopOp::getOperandSegmentSizeAttr(),
      builder.getI32VectorAttr({segmentSize(lowerBounds),
                                segmentSize(upperBounds), segmentSize(steps),
                                segmentSize(inputs), segmentSize(outputs)}));
  result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
  if (distributionTypes.hasValue())
    result.addAttribute(getDistributionTypesAttrName(),
                        distributionTypes.getValue());

  // Results: only `RankedTensorType` outputs produce a result.
  for (Value output : outputs) {
    Type outputType = output.getType();
    if (outputType.isa<RankedTensorType>())
      result.addTypes(outputType);
  }

  // Body block: index-typed induction variables, then inputs, then outputs.
  // The guard restores the builder's insertion point on exit.
  OpBuilder::InsertionGuard guard(builder);
  const unsigned numIVs = steps.size();
  SmallVector<Type, 8> blockArgTypes(numIVs, builder.getIndexType());
  for (Type inputType : TypeRange(inputs))
    blockArgTypes.push_back(inputType);
  for (Type outputType : TypeRange(outputs))
    blockArgTypes.push_back(outputType);

  Region *bodyRegion = result.addRegion();
  Block *bodyBlock = builder.createBlock(bodyRegion, {}, blockArgTypes);
  if (!bodyBuilderFn)
    return;

  builder.setInsertionPointToStart(bodyBlock);
  bodyBuilderFn(builder, result.location,
                bodyBlock->getArguments().take_front(numIVs),
                bodyBlock->getArguments().slice(numIVs, inputs.size()),
                bodyBlock->getArguments().take_back(outputs.size()));
  TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location);
}
/// Prints the custom assembly form of a tiled loop: induction variables and
/// bounds, the optional `ins`/`outs` assignment lists, the `iterators` list
/// (only when some iterator is not parallel), the optional `distribution`
/// list, the region without its entry block arguments, and finally the
/// remaining attributes.
static void print(OpAsmPrinter &p, TiledLoopOp op) {
  p << " (" << op.getInductionVars() << ") = (" << op.lowerBound() << ") to ("
    << op.upperBound() << ") step (" << op.step() << ")";

  // Prints `opening` followed by comma-separated
  // `regionArg = operand: type` entries and a closing parenthesis; prints
  // nothing when there are no operands.
  auto printAssignments = [&p](const char *opening, auto regionArgs,
                               auto operands) {
    if (operands.empty())
      return;
    p << opening;
    llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
      p << std::get<0>(it) << " = " << std::get<1>(it) << ": "
        << std::get<1>(it).getType();
    });
    p << ")";
  };
  printAssignments(" ins (", op.getRegionInputArgs(), op.inputs());
  printAssignments(" outs (", op.getRegionOutputArgs(), op.outputs());

  // All-parallel iterator types are implicit and therefore not printed.
  const bool hasNonParallelIterator =
      llvm::any_of(op.iterator_types(), [](Attribute attr) {
        return attr.cast<StringAttr>().getValue() !=
               getParallelIteratorTypeName();
      });
  if (hasNonParallelIterator)
    p << " iterators" << op.iterator_types() << "";

  if (op.distribution_types().hasValue())
    p << " distribution" << op.distribution_types().getValue() << "";

  p.printRegion(op.region(), /*printEntryBlockArgs=*/false);

  // Attributes already conveyed by the custom syntax are elided.
  const StringRef elidedAttrs[] = {TiledLoopOp::getOperandSegmentSizeAttr(),
                                   getIteratorTypesAttrName(),
                                   getDistributionTypesAttrName()};
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}
/// Parses the custom assembly form of a tiled loop:
///
///   `(` ivs `)` `=` `(` lower-bounds `)` `to` `(` upper-bounds `)`
///   `step` `(` steps `)`
///   (`ins` assignment-list-with-types)?
///   (`outs` assignment-list-with-types)?
///   (`iterators` `[` string-attrs `]`)?
///   (`distribution` `[` string-attrs `]`)?
///   region attr-dict?
///
/// Bounds and steps are resolved as `index` values and must each have as many
/// entries as there are induction variables. One result type is added per
/// output of ranked tensor type. When no `iterators` list is present, every
/// loop is given the parallel iterator type. Returns failure as soon as any
/// component fails to parse, including the trailing attribute dictionary.
static ParseResult parseTiledLoopOp(OpAsmParser &parser,
                                    OperationState &result) {
  auto &builder = parser.getBuilder();
  // Parse an opening `(` followed by induction variables followed by `)`
  SmallVector<OpAsmParser::OperandType, 4> ivs;
  if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
                                     OpAsmParser::Delimiter::Paren))
    return failure();

  // Parse loop bounds.
  SmallVector<OpAsmParser::OperandType, 4> lower;
  if (parser.parseEqual() ||
      parser.parseOperandList(lower, ivs.size(),
                              OpAsmParser::Delimiter::Paren) ||
      parser.resolveOperands(lower, builder.getIndexType(), result.operands))
    return failure();

  SmallVector<OpAsmParser::OperandType, 4> upper;
  if (parser.parseKeyword("to") ||
      parser.parseOperandList(upper, ivs.size(),
                              OpAsmParser::Delimiter::Paren) ||
      parser.resolveOperands(upper, builder.getIndexType(), result.operands))
    return failure();

  // Parse step values.
  SmallVector<OpAsmParser::OperandType, 4> steps;
  if (parser.parseKeyword("step") ||
      parser.parseOperandList(steps, ivs.size(),
                              OpAsmParser::Delimiter::Paren) ||
      parser.resolveOperands(steps, builder.getIndexType(), result.operands))
    return failure();

  // Parse the optional inputs: `ins` followed by an assignment list that binds
  // each region argument to an operand and its type.
  SmallVector<OpAsmParser::OperandType, 4> inputs, input_region_args;
  SmallVector<Type, 4> inputTypes;
  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    llvm::SMLoc inputsOperandsLoc = parser.getCurrentLocation();

    if (parser.parseAssignmentListWithTypes(input_region_args, inputs,
                                            inputTypes))
      return failure();

    if (parser.resolveOperands(inputs, inputTypes, inputsOperandsLoc,
                               result.operands))
      return failure();
  }

  // Parse the optional outputs, in the same form as the inputs.
  SmallVector<OpAsmParser::OperandType, 4> outputs, output_region_args;
  SmallVector<Type, 4> outputTypes;
  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    llvm::SMLoc outputsOperandsLoc = parser.getCurrentLocation();

    if (parser.parseAssignmentListWithTypes(output_region_args, outputs,
                                            outputTypes))
      return failure();

    if (parser.resolveOperands(outputs, outputTypes, outputsOperandsLoc,
                               result.operands))
      return failure();
    // Only outputs of ranked tensor type produce a result.
    for (Type outputType : outputTypes)
      if (outputType.isa<RankedTensorType>())
        result.addTypes(outputType);
  }

  // Parse attributes.
  SmallVector<Attribute, 4> iterTypes, distributionTypes;
  // Parses an optional `keyword [attr, attr, ...]` list into `attrs`: one
  // string attribute followed by `ivs.size() - 1` comma-separated ones.
  // Succeeds without consuming anything when `keyword` is absent.
  auto parseAttr = [&](StringRef keyword, SmallVector<Attribute, 4> *attrs) {
    if (succeeded(parser.parseOptionalKeyword(keyword))) {
      StringAttr attr;

      if (parser.parseLSquare() || parser.parseAttribute(attr))
        return failure();
      attrs->push_back(attr);
      for (int i = 1, e = ivs.size(); i < e; ++i) {
        if (parser.parseComma() || parser.parseAttribute(attr))
          return failure();
        attrs->push_back(attr);
      }
      if (parser.parseRSquare())
        return failure();
    }
    return success();
  };
  if (failed(parseAttr("iterators", &iterTypes)) ||
      failed(parseAttr("distribution", &distributionTypes)))
    return failure();

  // Set all loop iterator types to "parallel" if they are not printed in IR.
  if (iterTypes.empty()) {
    auto parallelIter = builder.getStringAttr(getParallelIteratorTypeName());
    iterTypes = SmallVector<Attribute, 4>(ivs.size(), parallelIter);
  }
  result.addAttribute(getIteratorTypesAttrName(),
                      builder.getArrayAttr(iterTypes));
  if (!distributionTypes.empty())
    result.addAttribute(getDistributionTypesAttrName(),
                        builder.getArrayAttr(distributionTypes));
  result.addAttribute(
      TiledLoopOp::getOperandSegmentSizeAttr(),
      builder.getI32VectorAttr({static_cast<int32_t>(lower.size()),
                                static_cast<int32_t>(upper.size()),
                                static_cast<int32_t>(steps.size()),
                                static_cast<int32_t>(inputs.size()),
                                static_cast<int32_t>(outputs.size())}));

  // Parse the body. Region arguments are the induction variables (of index
  // type), followed by the input and then the output region arguments.
  Region *body = result.addRegion();

  SmallVector<Type, 4> region_types(ivs.size(), builder.getIndexType());
  region_types.append(inputTypes);
  region_types.append(outputTypes);

  SmallVector<OpAsmParser::OperandType, 4> region_args(ivs);
  region_args.append(input_region_args);
  region_args.append(output_region_args);

  if (parser.parseRegion(*body, region_args, region_types))
    return failure();

  // Parse optional attributes. The result must not be dropped: a malformed
  // dictionary has to make the whole parse fail.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
/// Returns the region of the op, which holds the loop body.
Region &TiledLoopOp::getLoopBody() {
  Region &bodyRegion = region();
  return bodyRegion;
}
/// Moves each operation in `ops`, in order, to right before this loop op.
/// Always succeeds.
LogicalResult TiledLoopOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
  for (Operation *hoisted : ops) {
    hoisted->moveBefore(*this);
  }
  return success();
}
/// Returns true when the region `value` belongs to is not the loop region nor
/// nested inside it.
bool TiledLoopOp::isDefinedOutsideOfLoop(Value value) {
  Region *definingRegion = value.getParentRegion();
  const bool definedInLoop = region().isAncestor(definingRegion);
  return !definedInLoop;
}
/// Verifies a tiled loop: there must be one iterator type per loop, and every
/// input/output operand must have the same type as the region argument bound
/// to it.
static LogicalResult verify(TiledLoopOp op) {
  // One iterator type is required for every loop dimension.
  if (op.iterator_types().size() != op.getNumLoops())
    return op.emitOpError("expected iterator types array attribute size = ")
           << op.iterator_types().size()
           << " to match the number of loops = " << op.getNumLoops();

  // Input operands must agree with the types of their region arguments, which
  // follow the induction variables in the block argument list.
  for (auto &item :
       llvm::enumerate(llvm::zip(op.inputs(), op.getRegionInputArgs()))) {
    Value operand, regionArg;
    std::tie(operand, regionArg) = item.value();
    if (operand.getType() == regionArg.getType())
      continue;
    unsigned index = item.index();
    return op.emitOpError("expected input arg ")
           << index << " with type = " << operand.getType()
           << " to match region arg " << index + op.getNumLoops()
           << " type = " << regionArg.getType();
  }

  // Output operands must agree with the types of their region arguments, which
  // follow the induction variables and the input arguments.
  for (auto &item :
       llvm::enumerate(llvm::zip(op.outputs(), op.getRegionOutputArgs()))) {
    Value operand, regionArg;
    std::tie(operand, regionArg) = item.value();
    if (operand.getType() == regionArg.getType())
      continue;
    unsigned index = item.index();
    return op.emitOpError("expected output arg ")
           << index << " with type = " << operand.getType()
           << " to match region arg "
           << index + op.getNumLoops() + op.inputs().size()
           << " type = " << regionArg.getType();
  }

  return success();
}
namespace {
/// Marker stored in old-to-new index tables for entries that have no
/// counterpart in the new op.
static constexpr int64_t kNoMatch = -1;

// Folds away TiledLoopOp inputs if they have no uses within the body.
//
// Example, assuming %in_ is unused inside the body:
//
// %0 = linalg.tiled_loop ...  ins (%in_ = %in: tensor<...>,
//                                  %in_buf_ = %in_buf: memref<...>) {...}
// Becomes
//
// linalg.tiled_loop ...  ins (%in_buf_ = %in_buf: memref<...>) {...}
struct TiledLoopInputsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
  using OpRewritePattern<linalg::TiledLoopOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TiledLoopOp tiledLoop,
                                PatternRewriter &rewriter) const final {
    SmallVector<Value, 2> newInputs;
    // Store ids of the corresponding old and new input operands.
    SmallVector<int64_t, 2> oldInputIdToNew(tiledLoop.inputs().size(),
                                            kNoMatch);
    for (auto en : llvm::enumerate(
             llvm::zip(tiledLoop.inputs(), tiledLoop.getRegionInputArgs()))) {
      Value in, bbArg;
      size_t index = en.index();
      std::tie(in, bbArg) = en.value();
      // Keep only the inputs whose region argument is actually used.
      if (!bbArg.use_empty()) {
        oldInputIdToNew[index] = newInputs.size();
        newInputs.push_back(in);
      }
    }
    // Nothing to fold when every input is used.
    if (newInputs.size() == tiledLoop.inputs().size())
      return failure();
    Location loc = tiledLoop.getLoc();
    auto newTiledLoop = rewriter.create<TiledLoopOp>(
        loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(),
        newInputs, tiledLoop.outputs(), tiledLoop.iterator_types(),
        tiledLoop.distribution_types());

    // Clone the region, remapping induction variables, output arguments and
    // the surviving input arguments to the block arguments of the new loop.
    BlockAndValueMapping bvm;
    bvm.map(tiledLoop.getInductionVars(), newTiledLoop.getInductionVars());
    bvm.map(tiledLoop.getRegionOutputArgs(),
            newTiledLoop.getRegionOutputArgs());
    for (const auto &en : llvm::enumerate(oldInputIdToNew))
      if (en.value() != kNoMatch)
        bvm.map(tiledLoop.getRegionInputArgs()[en.index()],
                newTiledLoop.getRegionInputArgs()[en.value()]);
    OpBuilder innerBuilder =
        OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
    for (auto &op : *tiledLoop.getBody())
      innerBuilder.clone(op, bvm);
    rewriter.replaceOp(tiledLoop, newTiledLoop.getResults());

    return success();
  }
};
} // namespace
/// Conservatively decides whether the loop is shape preserving for its
/// `arg`-th yielded value, i.e. whether that value is reached from the
/// corresponding output block argument of the loop only through operations
/// that are known to keep the type.
///
/// Starting from the yielded value, the chain of producers is followed through
/// the destination of tensor.insert_slice ops and through the matching output
/// of nested tiled loops that are themselves shape preserving. Any other
/// producer, or a block argument other than the expected one, yields `false`.
/// Note: only these simple cases are handled. Expand as needed.
static bool isShapePreserving(TiledLoopOp loopOp, int64_t arg) {
  auto yieldOp = cast<YieldOp>(loopOp.getLoopBody().front().getTerminator());
  // Nothing is yielded: the loop has no outputs or is a "memref-based
  // version"; both are considered shape preserving.
  if (yieldOp.values().empty())
    return true;
  assert(arg < static_cast<int64_t>(yieldOp.values().size()) &&
         "arg is out of bounds");

  for (Value current = yieldOp.values()[arg]; current;) {
    if (current == loopOp.getRegionOutputArgs()[arg])
      return true;

    // Only results of operations can be traced further back.
    OpResult produced = current.dyn_cast<OpResult>();
    if (!produced)
      return false;

    Operation *producer = produced.getOwner();
    if (auto insertOp = dyn_cast<tensor::InsertSliceOp>(producer)) {
      current = insertOp.dest();
      continue;
    }
    if (auto nestedLoop = dyn_cast<TiledLoopOp>(producer)) {
      current = isShapePreserving(nestedLoop, produced.getResultNumber())
                    ? nestedLoop.outputs()[produced.getResultNumber()]
                    : Value();
      continue;
    }
    // Unknown producer: give up.
    return false;
  }
  return false;
}
namespace {
/// Fold dim(x) where `x` is an input/output argument of a TiledLoopOp block
/// to dim(y) where `y` is the initial input/output value of the argument.
///
/// E.g.:
/// %y = ... : tensor<...>
/// linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
/// tensor.dim %x, %c0 : tensor<...>
/// }
///
/// is folded to:
/// %y = ... : tensor<...>
/// linalg.tiled_loop ... ins(%x = %y : tensor<...>) {
/// tensor.dim %y, %c0 : tensor<...>
/// }
///
/// Note: Dim ops are folded only if it can be proven that the runtime type of
/// the yielded value (in case of outputs) does not change with loop iterations.
template <typename OpTy>
struct DimOfTiledLoopInsOutsFolder : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy dimOp,
PatternRewriter &rewriter) const final {
auto src = dimOp.source().template dyn_cast<BlockArgument>();
if (!src)
return failure();
auto loopOp =
dyn_cast<TiledLoopOp>(src.getOwner()->getParent()->getParentOp());
if (!loopOp)
return failure();
unsigned numLoops = loopOp.getNumLoops();
unsigned numInputArgs = loopOp.getRegionInputArgs().size();
if (src.getArgNumber() >= numInputArgs + numLoops &&