| //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/Affine/IR/AffineValueMap.h" |
| #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/IR/BlockAndValueMapping.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/IntegerSet.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/Transforms/InliningUtils.h" |
| #include "llvm/ADT/SetVector.h" |
| #include "llvm/ADT/SmallBitVector.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| #include "llvm/Support/Debug.h" |
| |
| using namespace mlir; |
| |
| #define DEBUG_TYPE "affine-analysis" |
| |
| #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc" |
| |
| /// A utility function to check if a value is defined at the top level of |
| /// `region` or is an argument of `region`. A value of index type defined at the |
| /// top level of a `AffineScope` region is always a valid symbol for all |
| /// uses in that region. |
| static bool isTopLevelValue(Value value, Region *region) { |
| if (auto arg = value.dyn_cast<BlockArgument>()) |
| return arg.getParentRegion() == region; |
| return value.getDefiningOp()->getParentRegion() == region; |
| } |
| |
| /// Checks if `value` known to be a legal affine dimension or symbol in `src` |
| /// region remains legal if the operation that uses it is inlined into `dest` |
| /// with the given value mapping. `legalityCheck` is either `isValidDim` or |
| /// `isValidSymbol`, depending on the value being required to remain a valid |
| /// dimension or symbol. |
| static bool |
| remainsLegalAfterInline(Value value, Region *src, Region *dest, |
| const BlockAndValueMapping &mapping, |
| function_ref<bool(Value, Region *)> legalityCheck) { |
| // If the value is a valid dimension for any other reason than being |
| // a top-level value, it will remain valid: constants get inlined |
| // with the function, transitive affine applies also get inlined and |
| // will be checked themselves, etc. |
| if (!isTopLevelValue(value, src)) |
| return true; |
| |
| // If it's a top-level value because it's a block operand, i.e. a |
| // function argument, check whether the value replacing it after |
| // inlining is a valid dimension in the new region. |
| if (value.isa<BlockArgument>()) |
| return legalityCheck(mapping.lookup(value), dest); |
| |
| // If it's a top-level value because it's defined in the region, |
| // it can only be inlined if the defining op is a constant or a |
| // `dim`, which can appear anywhere and be valid, since the defining |
| // op won't be top-level anymore after inlining. |
| Attribute operandCst; |
| return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) || |
| value.getDefiningOp<memref::DimOp>() || |
| value.getDefiningOp<tensor::DimOp>(); |
| } |
| |
| /// Checks if all values known to be legal affine dimensions or symbols in `src` |
| /// remain so if their respective users are inlined into `dest`. |
| static bool |
| remainsLegalAfterInline(ValueRange values, Region *src, Region *dest, |
| const BlockAndValueMapping &mapping, |
| function_ref<bool(Value, Region *)> legalityCheck) { |
| return llvm::all_of(values, [&](Value v) { |
| return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck); |
| }); |
| } |
| |
| /// Checks if an affine read or write operation remains legal after inlining |
| /// from `src` to `dest`. |
| template <typename OpTy> |
| static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest, |
| const BlockAndValueMapping &mapping) { |
| static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface, |
| AffineWriteOpInterface>::value, |
| "only ops with affine read/write interface are supported"); |
| |
| AffineMap map = op.getAffineMap(); |
| ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims()); |
| ValueRange symbolOperands = |
| op.getMapOperands().take_back(map.getNumSymbols()); |
| if (!remainsLegalAfterInline( |
| dimOperands, src, dest, mapping, |
| static_cast<bool (*)(Value, Region *)>(isValidDim))) |
| return false; |
| if (!remainsLegalAfterInline( |
| symbolOperands, src, dest, mapping, |
| static_cast<bool (*)(Value, Region *)>(isValidSymbol))) |
| return false; |
| return true; |
| } |
| |
| /// Checks if an affine apply operation remains legal after inlining from `src` |
| /// to `dest`. |
| // Use "unused attribute" marker to silence clang-tidy warning stemming from |
| // the inability to see through "llvm::TypeSwitch". |
| template <> |
| bool LLVM_ATTRIBUTE_UNUSED |
| remainsLegalAfterInline(AffineApplyOp op, Region *src, Region *dest, |
| const BlockAndValueMapping &mapping) { |
| // If it's a valid dimension, we need to check that it remains so. |
| if (isValidDim(op.getResult(), src)) |
| return remainsLegalAfterInline( |
| op.getMapOperands(), src, dest, mapping, |
| static_cast<bool (*)(Value, Region *)>(isValidDim)); |
| |
| // Otherwise it must be a valid symbol, check that it remains so. |
| return remainsLegalAfterInline( |
| op.getMapOperands(), src, dest, mapping, |
| static_cast<bool (*)(Value, Region *)>(isValidSymbol)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AffineDialect Interfaces |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| /// This class defines the interface for handling inlining with affine |
| /// operations. |
| struct AffineInlinerInterface : public DialectInlinerInterface { |
| using DialectInlinerInterface::DialectInlinerInterface; |
| |
| //===--------------------------------------------------------------------===// |
| // Analysis Hooks |
| //===--------------------------------------------------------------------===// |
| |
| /// Returns true if the given region 'src' can be inlined into the region |
| /// 'dest' that is attached to an operation registered to the current dialect. |
| /// 'wouldBeCloned' is set if the region is cloned into its new location |
| /// rather than moved, indicating there may be other users. |
| bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, |
| BlockAndValueMapping &valueMapping) const final { |
| // We can inline into affine loops and conditionals if this doesn't break |
| // affine value categorization rules. |
| Operation *destOp = dest->getParentOp(); |
| if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp)) |
| return false; |
| |
| // Multi-block regions cannot be inlined into affine constructs, all of |
| // which require single-block regions. |
| if (!llvm::hasSingleElement(*src)) |
| return false; |
| |
| // Side-effecting operations that the affine dialect cannot understand |
| // should not be inlined. |
| Block &srcBlock = src->front(); |
| for (Operation &op : srcBlock) { |
| // Ops with no side effects are fine, |
| if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) { |
| if (iface.hasNoEffect()) |
| continue; |
| } |
| |
| // Assuming the inlined region is valid, we only need to check if the |
| // inlining would change it. |
| bool remainsValid = |
| llvm::TypeSwitch<Operation *, bool>(&op) |
| .Case<AffineApplyOp, AffineReadOpInterface, |
| AffineWriteOpInterface>([&](auto op) { |
| return remainsLegalAfterInline(op, src, dest, valueMapping); |
| }) |
| .Default([](Operation *) { |
| // Conservatively disallow inlining ops we cannot reason about. |
| return false; |
| }); |
| |
| if (!remainsValid) |
| return false; |
| } |
| |
| return true; |
| } |
| |
| /// Returns true if the given operation 'op', that is registered to this |
| /// dialect, can be inlined into the given region, false otherwise. |
| bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned, |
| BlockAndValueMapping &valueMapping) const final { |
| // Always allow inlining affine operations into a region that is marked as |
| // affine scope, or into affine loops and conditionals. There are some edge |
| // cases when inlining *into* affine structures, but that is handled in the |
| // other 'isLegalToInline' hook above. |
| Operation *parentOp = region->getParentOp(); |
| return parentOp->hasTrait<OpTrait::AffineScope>() || |
| isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp); |
| } |
| |
| /// Affine regions should be analyzed recursively. |
| bool shouldAnalyzeRecursively(Operation *op) const final { return true; } |
| }; |
| } // end anonymous namespace |
| |
| //===----------------------------------------------------------------------===// |
| // AffineDialect |
| //===----------------------------------------------------------------------===// |
| |
| void AffineDialect::initialize() { |
| addOperations<AffineDmaStartOp, AffineDmaWaitOp, |
| #define GET_OP_LIST |
| #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc" |
| >(); |
| addInterfaces<AffineInlinerInterface>(); |
| } |
| |
| /// Materialize a single constant operation from a given attribute value with |
| /// the desired resultant type. |
| Operation *AffineDialect::materializeConstant(OpBuilder &builder, |
| Attribute value, Type type, |
| Location loc) { |
| return builder.create<arith::ConstantOp>(loc, type, value); |
| } |
| |
| /// A utility function to check if a value is defined at the top level of an |
| /// op with trait `AffineScope`. If the value is defined in an unlinked region, |
| /// conservatively assume it is not top-level. A value of index type defined at |
| /// the top level is always a valid symbol. |
| bool mlir::isTopLevelValue(Value value) { |
| if (auto arg = value.dyn_cast<BlockArgument>()) { |
| // The block owning the argument may be unlinked, e.g. when the surrounding |
| // region has not yet been attached to an Op, at which point the parent Op |
| // is null. |
| Operation *parentOp = arg.getOwner()->getParentOp(); |
| return parentOp && parentOp->hasTrait<OpTrait::AffineScope>(); |
| } |
| // The defining Op may live in an unlinked block so its parent Op may be null. |
| Operation *parentOp = value.getDefiningOp()->getParentOp(); |
| return parentOp && parentOp->hasTrait<OpTrait::AffineScope>(); |
| } |
| |
| /// Returns the closest region enclosing `op` that is held by an operation with |
| /// trait `AffineScope`; `nullptr` if there is no such region. |
| // TODO: getAffineScope should be publicly exposed for affine passes/utilities. |
| static Region *getAffineScope(Operation *op) { |
| auto *curOp = op; |
| while (auto *parentOp = curOp->getParentOp()) { |
| if (parentOp->hasTrait<OpTrait::AffineScope>()) |
| return curOp->getParentRegion(); |
| curOp = parentOp; |
| } |
| return nullptr; |
| } |
| |
| // A Value can be used as a dimension id iff it meets one of the following |
| // conditions: |
| // *) It is valid as a symbol. |
| // *) It is an induction variable. |
| // *) It is the result of affine apply operation with dimension id arguments. |
| bool mlir::isValidDim(Value value) { |
| // The value must be an index type. |
| if (!value.getType().isIndex()) |
| return false; |
| |
| if (auto *defOp = value.getDefiningOp()) |
| return isValidDim(value, getAffineScope(defOp)); |
| |
| // This value has to be a block argument for an op that has the |
| // `AffineScope` trait or for an affine.for or affine.parallel. |
| auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp(); |
| return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() || |
| isa<AffineForOp, AffineParallelOp>(parentOp)); |
| } |
| |
| // Value can be used as a dimension id iff it meets one of the following |
| // conditions: |
| // *) It is valid as a symbol. |
| // *) It is an induction variable. |
| // *) It is the result of an affine apply operation with dimension id operands. |
| bool mlir::isValidDim(Value value, Region *region) { |
| // The value must be an index type. |
| if (!value.getType().isIndex()) |
| return false; |
| |
| // All valid symbols are okay. |
| if (isValidSymbol(value, region)) |
| return true; |
| |
| auto *op = value.getDefiningOp(); |
| if (!op) { |
| // This value has to be a block argument for an affine.for or an |
| // affine.parallel. |
| auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp(); |
| return isa<AffineForOp, AffineParallelOp>(parentOp); |
| } |
| |
| // Affine apply operation is ok if all of its operands are ok. |
| if (auto applyOp = dyn_cast<AffineApplyOp>(op)) |
| return applyOp.isValidDim(region); |
| // The dim op is okay if its operand memref/tensor is defined at the top |
| // level. |
| if (auto dimOp = dyn_cast<memref::DimOp>(op)) |
| return isTopLevelValue(dimOp.source()); |
| if (auto dimOp = dyn_cast<tensor::DimOp>(op)) |
| return isTopLevelValue(dimOp.source()); |
| return false; |
| } |
| |
| /// Returns true if the 'index' dimension of the `memref` defined by |
| /// `memrefDefOp` is a statically shaped one or defined using a valid symbol |
| /// for `region`. |
| template <typename AnyMemRefDefOp> |
| static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, |
| Region *region) { |
| auto memRefType = memrefDefOp.getType(); |
| // Statically shaped. |
| if (!memRefType.isDynamicDim(index)) |
| return true; |
| // Get the position of the dimension among dynamic dimensions; |
| unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index); |
| return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos), |
| region); |
| } |
| |
| /// Returns true if the result of the dim op is a valid symbol for `region`. |
| template <typename OpTy> |
| static bool isDimOpValidSymbol(OpTy dimOp, Region *region) { |
| // The dim op is okay if its source is defined at the top level. |
| if (isTopLevelValue(dimOp.source())) |
| return true; |
| |
| // Conservatively handle remaining BlockArguments as non-valid symbols. |
| // E.g. scf.for iterArgs. |
| if (dimOp.source().template isa<BlockArgument>()) |
| return false; |
| |
| // The dim op is also okay if its operand memref is a view/subview whose |
| // corresponding size is a valid symbol. |
| Optional<int64_t> index = dimOp.getConstantIndex(); |
| assert(index.hasValue() && |
| "expect only `dim` operations with a constant index"); |
| int64_t i = index.getValue(); |
| return TypeSwitch<Operation *, bool>(dimOp.source().getDefiningOp()) |
| .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>( |
| [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); }) |
| .Default([](Operation *) { return false; }); |
| } |
| |
| // A value can be used as a symbol (at all its use sites) iff it meets one of |
| // the following conditions: |
| // *) It is a constant. |
| // *) Its defining op or block arg appearance is immediately enclosed by an op |
| // with `AffineScope` trait. |
| // *) It is the result of an affine.apply operation with symbol operands. |
| // *) It is a result of the dim op on a memref whose corresponding size is a |
| // valid symbol. |
| bool mlir::isValidSymbol(Value value) { |
| if (!value) |
| return false; |
| |
| // The value must be an index type. |
| if (!value.getType().isIndex()) |
| return false; |
| |
| // Check that the value is a top level value. |
| if (isTopLevelValue(value)) |
| return true; |
| |
| if (auto *defOp = value.getDefiningOp()) |
| return isValidSymbol(value, getAffineScope(defOp)); |
| |
| return false; |
| } |
| |
| /// A value can be used as a symbol for `region` iff it meets one of the |
| /// following conditions: |
| /// *) It is a constant. |
| /// *) It is the result of an affine apply operation with symbol arguments. |
| /// *) It is a result of the dim op on a memref whose corresponding size is |
| /// a valid symbol. |
| /// *) It is defined at the top level of 'region' or is its argument. |
| /// *) It dominates `region`'s parent op. |
| /// If `region` is null, conservatively assume the symbol definition scope does |
| /// not exist and only accept the values that would be symbols regardless of |
| /// the surrounding region structure, i.e. the first three cases above. |
| bool mlir::isValidSymbol(Value value, Region *region) { |
| // The value must be an index type. |
| if (!value.getType().isIndex()) |
| return false; |
| |
| // A top-level value is a valid symbol. |
| if (region && ::isTopLevelValue(value, region)) |
| return true; |
| |
| auto *defOp = value.getDefiningOp(); |
| if (!defOp) { |
| // A block argument that is not a top-level value is a valid symbol if it |
| // dominates region's parent op. |
| Operation *regionOp = region ? region->getParentOp() : nullptr; |
| if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>()) |
| if (auto *parentOpRegion = region->getParentOp()->getParentRegion()) |
| return isValidSymbol(value, parentOpRegion); |
| return false; |
| } |
| |
| // Constant operation is ok. |
| Attribute operandCst; |
| if (matchPattern(defOp, m_Constant(&operandCst))) |
| return true; |
| |
| // Affine apply operation is ok if all of its operands are ok. |
| if (auto applyOp = dyn_cast<AffineApplyOp>(defOp)) |
| return applyOp.isValidSymbol(region); |
| |
| // Dim op results could be valid symbols at any level. |
| if (auto dimOp = dyn_cast<memref::DimOp>(defOp)) |
| return isDimOpValidSymbol(dimOp, region); |
| if (auto dimOp = dyn_cast<tensor::DimOp>(defOp)) |
| return isDimOpValidSymbol(dimOp, region); |
| |
| // Check for values dominating `region`'s parent op. |
| Operation *regionOp = region ? region->getParentOp() : nullptr; |
| if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>()) |
| if (auto *parentRegion = region->getParentOp()->getParentRegion()) |
| return isValidSymbol(value, parentRegion); |
| |
| return false; |
| } |
| |
| // Returns true if 'value' is a valid index to an affine operation (e.g. |
| // affine.load, affine.store, affine.dma_start, affine.dma_wait) where |
| // `region` provides the polyhedral symbol scope. Returns false otherwise. |
| static bool isValidAffineIndexOperand(Value value, Region *region) { |
| return isValidDim(value, region) || isValidSymbol(value, region); |
| } |
| |
| /// Prints dimension and symbol list. |
| static void printDimAndSymbolList(Operation::operand_iterator begin, |
| Operation::operand_iterator end, |
| unsigned numDims, OpAsmPrinter &printer) { |
| OperandRange operands(begin, end); |
| printer << '(' << operands.take_front(numDims) << ')'; |
| if (operands.size() > numDims) |
| printer << '[' << operands.drop_front(numDims) << ']'; |
| } |
| |
| /// Parses dimension and symbol list and returns true if parsing failed. |
| ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser, |
| SmallVectorImpl<Value> &operands, |
| unsigned &numDims) { |
| SmallVector<OpAsmParser::OperandType, 8> opInfos; |
| if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren)) |
| return failure(); |
| // Store number of dimensions for validation by caller. |
| numDims = opInfos.size(); |
| |
| // Parse the optional symbol operands. |
| auto indexTy = parser.getBuilder().getIndexType(); |
| return failure(parser.parseOperandList( |
| opInfos, OpAsmParser::Delimiter::OptionalSquare) || |
| parser.resolveOperands(opInfos, indexTy, operands)); |
| } |
| |
| /// Utility function to verify that a set of operands are valid dimension and |
| /// symbol identifiers. The operands should be laid out such that the dimension |
| /// operands are before the symbol operands. This function returns failure if |
| /// there was an invalid operand. An operation is provided to emit any necessary |
| /// errors. |
| template <typename OpTy> |
| static LogicalResult |
| verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, |
| unsigned numDims) { |
| unsigned opIt = 0; |
| for (auto operand : operands) { |
| if (opIt++ < numDims) { |
| if (!isValidDim(operand, getAffineScope(op))) |
| return op.emitOpError("operand cannot be used as a dimension id"); |
| } else if (!isValidSymbol(operand, getAffineScope(op))) { |
| return op.emitOpError("operand cannot be used as a symbol"); |
| } |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AffineApplyOp |
| //===----------------------------------------------------------------------===// |
| |
| AffineValueMap AffineApplyOp::getAffineValueMap() { |
| return AffineValueMap(getAffineMap(), getOperands(), getResult()); |
| } |
| |
| static ParseResult parseAffineApplyOp(OpAsmParser &parser, |
| OperationState &result) { |
| auto &builder = parser.getBuilder(); |
| auto indexTy = builder.getIndexType(); |
| |
| AffineMapAttr mapAttr; |
| unsigned numDims; |
| if (parser.parseAttribute(mapAttr, "map", result.attributes) || |
| parseDimAndSymbolList(parser, result.operands, numDims) || |
| parser.parseOptionalAttrDict(result.attributes)) |
| return failure(); |
| auto map = mapAttr.getValue(); |
| |
| if (map.getNumDims() != numDims || |
| numDims + map.getNumSymbols() != result.operands.size()) { |
| return parser.emitError(parser.getNameLoc(), |
| "dimension or symbol index mismatch"); |
| } |
| |
| result.types.append(map.getNumResults(), indexTy); |
| return success(); |
| } |
| |
| static void print(OpAsmPrinter &p, AffineApplyOp op) { |
| p << " " << op.mapAttr(); |
| printDimAndSymbolList(op.operand_begin(), op.operand_end(), |
| op.getAffineMap().getNumDims(), p); |
| p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"map"}); |
| } |
| |
| static LogicalResult verify(AffineApplyOp op) { |
| // Check input and output dimensions match. |
| auto map = op.map(); |
| |
| // Verify that operand count matches affine map dimension and symbol count. |
| if (op.getNumOperands() != map.getNumDims() + map.getNumSymbols()) |
| return op.emitOpError( |
| "operand count and affine map dimension and symbol count must match"); |
| |
| // Verify that the map only produces one result. |
| if (map.getNumResults() != 1) |
| return op.emitOpError("mapping must produce one value"); |
| |
| return success(); |
| } |
| |
| // The result of the affine apply operation can be used as a dimension id if all |
| // its operands are valid dimension ids. |
| bool AffineApplyOp::isValidDim() { |
| return llvm::all_of(getOperands(), |
| [](Value op) { return mlir::isValidDim(op); }); |
| } |
| |
| // The result of the affine apply operation can be used as a dimension id if all |
| // its operands are valid dimension ids with the parent operation of `region` |
| // defining the polyhedral scope for symbols. |
| bool AffineApplyOp::isValidDim(Region *region) { |
| return llvm::all_of(getOperands(), |
| [&](Value op) { return ::isValidDim(op, region); }); |
| } |
| |
| // The result of the affine apply operation can be used as a symbol if all its |
| // operands are symbols. |
| bool AffineApplyOp::isValidSymbol() { |
| return llvm::all_of(getOperands(), |
| [](Value op) { return mlir::isValidSymbol(op); }); |
| } |
| |
| // The result of the affine apply operation can be used as a symbol in `region` |
| // if all its operands are symbols in `region`. |
| bool AffineApplyOp::isValidSymbol(Region *region) { |
| return llvm::all_of(getOperands(), [&](Value operand) { |
| return mlir::isValidSymbol(operand, region); |
| }); |
| } |
| |
| OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) { |
| auto map = getAffineMap(); |
| |
| // Fold dims and symbols to existing values. |
| auto expr = map.getResult(0); |
| if (auto dim = expr.dyn_cast<AffineDimExpr>()) |
| return getOperand(dim.getPosition()); |
| if (auto sym = expr.dyn_cast<AffineSymbolExpr>()) |
| return getOperand(map.getNumDims() + sym.getPosition()); |
| |
| // Otherwise, default to folding the map. |
| SmallVector<Attribute, 1> result; |
| if (failed(map.constantFold(operands, result))) |
| return {}; |
| return result[0]; |
| } |
| |
| /// Replace all occurrences of AffineExpr at position `pos` in `map` by the |
| /// defining AffineApplyOp expression and operands. |
| /// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced. |
| /// When `dimOrSymbolPosition >= dims.size()`, |
| /// AffineSymbolExpr@[pos - dims.size()] is replaced. |
| /// Mutate `map`,`dims` and `syms` in place as follows: |
| /// 1. `dims` and `syms` are only appended to. |
| /// 2. `map` dim and symbols are gradually shifted to higer positions. |
| /// 3. Old `dim` and `sym` entries are replaced by nullptr |
| /// This avoids the need for any bookkeeping. |
| static LogicalResult replaceDimOrSym(AffineMap *map, |
| unsigned dimOrSymbolPosition, |
| SmallVectorImpl<Value> &dims, |
| SmallVectorImpl<Value> &syms) { |
| bool isDimReplacement = (dimOrSymbolPosition < dims.size()); |
| unsigned pos = isDimReplacement ? dimOrSymbolPosition |
| : dimOrSymbolPosition - dims.size(); |
| Value &v = isDimReplacement ? dims[pos] : syms[pos]; |
| if (!v) |
| return failure(); |
| |
| auto affineApply = v.getDefiningOp<AffineApplyOp>(); |
| if (!affineApply) |
| return failure(); |
| |
| // At this point we will perform a replacement of `v`, set the entry in `dim` |
| // or `sym` to nullptr immediately. |
| v = nullptr; |
| |
| // Compute the map, dims and symbols coming from the AffineApplyOp. |
| AffineMap composeMap = affineApply.getAffineMap(); |
| assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results"); |
| AffineExpr composeExpr = |
| composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0); |
| ValueRange composeDims = |
| affineApply.getMapOperands().take_front(composeMap.getNumDims()); |
| ValueRange composeSyms = |
| affineApply.getMapOperands().take_back(composeMap.getNumSymbols()); |
| |
| // Append the dims and symbols where relevant and perform the replacement. |
| MLIRContext *ctx = map->getContext(); |
| AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx) |
| : getAffineSymbolExpr(pos, ctx); |
| dims.append(composeDims.begin(), composeDims.end()); |
| syms.append(composeSyms.begin(), composeSyms.end()); |
| *map = map->replace(toReplace, composeExpr, dims.size(), syms.size()); |
| |
| return success(); |
| } |
| |
| /// Iterate over `operands` and fold away all those produced by an AffineApplyOp |
| /// iteratively. Perform canonicalization of map and operands as well as |
| /// AffineMap simplification. `map` and `operands` are mutated in place. |
| static void composeAffineMapAndOperands(AffineMap *map, |
| SmallVectorImpl<Value> *operands) { |
| if (map->getNumResults() == 0) { |
| canonicalizeMapAndOperands(map, operands); |
| *map = simplifyAffineMap(*map); |
| return; |
| } |
| |
| MLIRContext *ctx = map->getContext(); |
| SmallVector<Value, 4> dims(operands->begin(), |
| operands->begin() + map->getNumDims()); |
| SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(), |
| operands->end()); |
| |
| // Iterate over dims and symbols coming from AffineApplyOp and replace until |
| // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims` |
| // and `syms` can only increase by construction. |
| // The implementation uses a `while` loop to support the case of symbols |
| // that may be constructed from dims ;this may be overkill. |
| while (true) { |
| bool changed = false; |
| for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos) |
| if ((changed |= succeeded(replaceDimOrSym(map, pos, dims, syms)))) |
| break; |
| if (!changed) |
| break; |
| } |
| |
| // Clear operands so we can fill them anew. |
| operands->clear(); |
| |
| // At this point we may have introduced null operands, prune them out before |
| // canonicalizing map and operands. |
| unsigned nDims = 0, nSyms = 0; |
| SmallVector<AffineExpr, 4> dimReplacements, symReplacements; |
| dimReplacements.reserve(dims.size()); |
| symReplacements.reserve(syms.size()); |
| for (auto *container : {&dims, &syms}) { |
| bool isDim = (container == &dims); |
| auto &repls = isDim ? dimReplacements : symReplacements; |
| for (auto en : llvm::enumerate(*container)) { |
| Value v = en.value(); |
| if (!v) { |
| assert(isDim ? !map->isFunctionOfDim(en.index()) |
| : !map->isFunctionOfSymbol(en.index()) && |
| "map is function of unexpected expr@pos"); |
| repls.push_back(getAffineConstantExpr(0, ctx)); |
| continue; |
| } |
| repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx) |
| : getAffineSymbolExpr(nSyms++, ctx)); |
| operands->push_back(v); |
| } |
| } |
| *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims, |
| nSyms); |
| |
| // Canonicalize and simplify before returning. |
| canonicalizeMapAndOperands(map, operands); |
| *map = simplifyAffineMap(*map); |
| } |
| |
| void mlir::fullyComposeAffineMapAndOperands(AffineMap *map, |
| SmallVectorImpl<Value> *operands) { |
| while (llvm::any_of(*operands, [](Value v) { |
| return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp()); |
| })) { |
| composeAffineMapAndOperands(map, operands); |
| } |
| } |
| |
| AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc, |
| AffineMap map, |
| ValueRange operands) { |
| AffineMap normalizedMap = map; |
| SmallVector<Value, 8> normalizedOperands(operands.begin(), operands.end()); |
| composeAffineMapAndOperands(&normalizedMap, &normalizedOperands); |
| assert(normalizedMap); |
| return b.create<AffineApplyOp>(loc, normalizedMap, normalizedOperands); |
| } |
| |
| AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc, |
| AffineExpr e, ValueRange values) { |
| return makeComposedAffineApply( |
| b, loc, AffineMap::inferFromExprList(ArrayRef<AffineExpr>{e}).front(), |
| values); |
| } |
| |
| // A symbol may appear as a dim in affine.apply operations. This function |
| // canonicalizes dims that are valid symbols into actual symbols. |
| template <class MapOrSet> |
| static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, |
| SmallVectorImpl<Value> *operands) { |
| if (!mapOrSet || operands->empty()) |
| return; |
| |
| assert(mapOrSet->getNumInputs() == operands->size() && |
| "map/set inputs must match number of operands"); |
| |
| auto *context = mapOrSet->getContext(); |
| SmallVector<Value, 8> resultOperands; |
| resultOperands.reserve(operands->size()); |
| SmallVector<Value, 8> remappedSymbols; |
| remappedSymbols.reserve(operands->size()); |
| unsigned nextDim = 0; |
| unsigned nextSym = 0; |
| unsigned oldNumSyms = mapOrSet->getNumSymbols(); |
| SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims()); |
| for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) { |
| if (i < mapOrSet->getNumDims()) { |
| if (isValidSymbol((*operands)[i])) { |
| // This is a valid symbol that appears as a dim, canonicalize it. |
| dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context); |
| remappedSymbols.push_back((*operands)[i]); |
| } else { |
| dimRemapping[i] = getAffineDimExpr(nextDim++, context); |
| resultOperands.push_back((*operands)[i]); |
| } |
| } else { |
| resultOperands.push_back((*operands)[i]); |
| } |
| } |
| |
| resultOperands.append(remappedSymbols.begin(), remappedSymbols.end()); |
| *operands = resultOperands; |
| *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim, |
| oldNumSyms + nextSym); |
| |
| assert(mapOrSet->getNumInputs() == operands->size() && |
| "map/set inputs must match number of operands"); |
| } |
| |
| // Works for either an affine map or an integer set. |
| template <class MapOrSet> |
| static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, |
| SmallVectorImpl<Value> *operands) { |
| static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value, |
| "Argument must be either of AffineMap or IntegerSet type"); |
| |
| if (!mapOrSet || operands->empty()) |
| return; |
| |
| assert(mapOrSet->getNumInputs() == operands->size() && |
| "map/set inputs must match number of operands"); |
| |
| canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands); |
| |
| // Check to see what dims are used. |
| llvm::SmallBitVector usedDims(mapOrSet->getNumDims()); |
| llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols()); |
| mapOrSet->walkExprs([&](AffineExpr expr) { |
| if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) |
| usedDims[dimExpr.getPosition()] = true; |
| else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) |
| usedSyms[symExpr.getPosition()] = true; |
| }); |
| |
| auto *context = mapOrSet->getContext(); |
| |
| SmallVector<Value, 8> resultOperands; |
| resultOperands.reserve(operands->size()); |
| |
| llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims; |
| SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims()); |
| unsigned nextDim = 0; |
| for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) { |
| if (usedDims[i]) { |
| // Remap dim positions for duplicate operands. |
| auto it = seenDims.find((*operands)[i]); |
| if (it == seenDims.end()) { |
| dimRemapping[i] = getAffineDimExpr(nextDim++, context); |
| resultOperands.push_back((*operands)[i]); |
| seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i])); |
| } else { |
| dimRemapping[i] = it->second; |
| } |
| } |
| } |
| llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols; |
| SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols()); |
| unsigned nextSym = 0; |
| for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) { |
| if (!usedSyms[i]) |
| continue; |
| // Handle constant operands (only needed for symbolic operands since |
| // constant operands in dimensional positions would have already been |
| // promoted to symbolic positions above). |
| IntegerAttr operandCst; |
| if (matchPattern((*operands)[i + mapOrSet->getNumDims()], |
| m_Constant(&operandCst))) { |
| symRemapping[i] = |
| getAffineConstantExpr(operandCst.getValue().getSExtValue(), context); |
| continue; |
| } |
| // Remap symbol positions for duplicate operands. |
| auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]); |
| if (it == seenSymbols.end()) { |
| symRemapping[i] = getAffineSymbolExpr(nextSym++, context); |
| resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]); |
| seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()], |
| symRemapping[i])); |
| } else { |
| symRemapping[i] = it->second; |
| } |
| } |
| *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping, |
| nextDim, nextSym); |
| *operands = resultOperands; |
| } |
| |
| void mlir::canonicalizeMapAndOperands(AffineMap *map, |
| SmallVectorImpl<Value> *operands) { |
| canonicalizeMapOrSetAndOperands<AffineMap>(map, operands); |
| } |
| |
| void mlir::canonicalizeSetAndOperands(IntegerSet *set, |
| SmallVectorImpl<Value> *operands) { |
| canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands); |
| } |
| |
| namespace { |
| /// Simplify AffineApply, AffineLoad, and AffineStore operations by composing |
| /// maps that supply results into them. |
| /// |
| template <typename AffineOpTy> |
| struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> { |
| using OpRewritePattern<AffineOpTy>::OpRewritePattern; |
| |
| /// Replace the affine op with another instance of it with the supplied |
| /// map and mapOperands. |
| void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp, |
| AffineMap map, ArrayRef<Value> mapOperands) const; |
| |
| LogicalResult matchAndRewrite(AffineOpTy affineOp, |
| PatternRewriter &rewriter) const override { |
| static_assert( |
| llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp, |
| AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp, |
| AffineVectorStoreOp, AffineVectorLoadOp>::value, |
| "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op " |
| "expected"); |
| auto map = affineOp.getAffineMap(); |
| AffineMap oldMap = map; |
| auto oldOperands = affineOp.getMapOperands(); |
| SmallVector<Value, 8> resultOperands(oldOperands); |
| composeAffineMapAndOperands(&map, &resultOperands); |
| canonicalizeMapAndOperands(&map, &resultOperands); |
| if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(), |
| resultOperands.begin())) |
| return failure(); |
| |
| replaceAffineOp(rewriter, affineOp, map, resultOperands); |
| return success(); |
| } |
| }; |
| |
| // Specialize the template to account for the different build signatures for |
| // affine load, store, and apply ops. |
| template <> |
| void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp( |
| PatternRewriter &rewriter, AffineLoadOp load, AffineMap map, |
| ArrayRef<Value> mapOperands) const { |
| rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map, |
| mapOperands); |
| } |
| template <> |
| void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp( |
| PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map, |
| ArrayRef<Value> mapOperands) const { |
| rewriter.replaceOpWithNewOp<AffinePrefetchOp>( |
| prefetch, prefetch.memref(), map, mapOperands, prefetch.localityHint(), |
| prefetch.isWrite(), prefetch.isDataCache()); |
| } |
| template <> |
| void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp( |
| PatternRewriter &rewriter, AffineStoreOp store, AffineMap map, |
| ArrayRef<Value> mapOperands) const { |
| rewriter.replaceOpWithNewOp<AffineStoreOp>( |
| store, store.getValueToStore(), store.getMemRef(), map, mapOperands); |
| } |
| template <> |
| void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp( |
| PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map, |
| ArrayRef<Value> mapOperands) const { |
| rewriter.replaceOpWithNewOp<AffineVectorLoadOp>( |
| vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map, |
| mapOperands); |
| } |
| template <> |
| void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp( |
| PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map, |
| ArrayRef<Value> mapOperands) const { |
| rewriter.replaceOpWithNewOp<AffineVectorStoreOp>( |
| vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map, |
| mapOperands); |
| } |
| |
| // Generic version for ops that don't have extra operands. |
| template <typename AffineOpTy> |
| void SimplifyAffineOp<AffineOpTy>::replaceAffineOp( |
| PatternRewriter &rewriter, AffineOpTy op, AffineMap map, |
| ArrayRef<Value> mapOperands) const { |
| rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands); |
| } |
| } // end anonymous namespace. |
| |
| void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<SimplifyAffineOp<AffineApplyOp>>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Common canonicalization pattern support logic |
| //===----------------------------------------------------------------------===// |
| |
| /// This is a common class used for patterns of the form |
| /// "someop(memrefcast) -> someop". It folds the source of any memref.cast |
| /// into the root operation directly. |
| static LogicalResult foldMemRefCast(Operation *op, Value ignore = nullptr) { |
| bool folded = false; |
| for (OpOperand &operand : op->getOpOperands()) { |
| auto cast = operand.get().getDefiningOp<memref::CastOp>(); |
| if (cast && operand.get() != ignore && |
| !cast.getOperand().getType().isa<UnrankedMemRefType>()) { |
| operand.set(cast.getOperand()); |
| folded = true; |
| } |
| } |
| return success(folded); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AffineDmaStartOp |
| //===----------------------------------------------------------------------===// |
| |
| // TODO: Check that map operands are loop IVs or symbols. |
| void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result, |
| Value srcMemRef, AffineMap srcMap, |
| ValueRange srcIndices, Value destMemRef, |
| AffineMap dstMap, ValueRange destIndices, |
| Value tagMemRef, AffineMap tagMap, |
| ValueRange tagIndices, Value numElements, |
| Value stride, Value elementsPerStride) { |
| result.addOperands(srcMemRef); |
| result.addAttribute(getSrcMapAttrName(), AffineMapAttr::get(srcMap)); |
| result.addOperands(srcIndices); |
| result.addOperands(destMemRef); |
| result.addAttribute(getDstMapAttrName(), AffineMapAttr::get(dstMap)); |
| result.addOperands(destIndices); |
| result.addOperands(tagMemRef); |
| result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap)); |
| result.addOperands(tagIndices); |
| result.addOperands(numElements); |
| if (stride) { |
| result.addOperands({stride, elementsPerStride}); |
| } |
| } |
| |
| void AffineDmaStartOp::print(OpAsmPrinter &p) { |
| p << " " << getSrcMemRef() << '['; |
| p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices()); |
| p << "], " << getDstMemRef() << '['; |
| p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices()); |
| p << "], " << getTagMemRef() << '['; |
| p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices()); |
| p << "], " << getNumElements(); |
| if (isStrided()) { |
| p << ", " << getStride(); |
| p << ", " << getNumElementsPerStride(); |
| } |
| p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", " |
| << getTagMemRefType(); |
| } |
| |
| // Parse AffineDmaStartOp. |
| // Ex: |
| // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size, |
| // %stride, %num_elt_per_stride |
| // : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32> |
| // |
| ParseResult AffineDmaStartOp::parse(OpAsmParser &parser, |
| OperationState &result) { |
| OpAsmParser::OperandType srcMemRefInfo; |
| AffineMapAttr srcMapAttr; |
| SmallVector<OpAsmParser::OperandType, 4> srcMapOperands; |
| OpAsmParser::OperandType dstMemRefInfo; |
| AffineMapAttr dstMapAttr; |
| SmallVector<OpAsmParser::OperandType, 4> dstMapOperands; |
| OpAsmParser::OperandType tagMemRefInfo; |
| AffineMapAttr tagMapAttr; |
| SmallVector<OpAsmParser::OperandType, 4> tagMapOperands; |
| OpAsmParser::OperandType numElementsInfo; |
| SmallVector<OpAsmParser::OperandType, 2> strideInfo; |
| |
| SmallVector<Type, 3> types; |
| auto indexType = parser.getBuilder().getIndexType(); |
| |
| // Parse and resolve the following list of operands: |
| // *) dst memref followed by its affine maps operands (in square brackets). |
| // *) src memref followed by its affine map operands (in square brackets). |
| // *) tag memref followed by its affine map operands (in square brackets). |
| // *) number of elements transferred by DMA operation. |
| if (parser.parseOperand(srcMemRefInfo) || |
| parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr, |
| getSrcMapAttrName(), result.attributes) || |
| parser.parseComma() || parser.parseOperand(dstMemRefInfo) || |
| parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr, |
| getDstMapAttrName(), result.attributes) || |
| parser.parseComma() || parser.parseOperand(tagMemRefInfo) || |
| parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr, |
| getTagMapAttrName(), result.attributes) || |
| parser.parseComma() || parser.parseOperand(numElementsInfo)) |
| return failure(); |
| |
| // Parse optional stride and elements per stride. |
| if (parser.parseTrailingOperandList(strideInfo)) { |
| return failure(); |
| } |
| if (!strideInfo.empty() && strideInfo.size() != 2) { |
| return parser.emitError(parser.getNameLoc(), |
| "expected two stride related operands"); |
| } |
| bool isStrided = strideInfo.size() == 2; |
| |
| if (parser.parseColonTypeList(types)) |
| return failure(); |
| |
| if (types.size() != 3) |
| return parser.emitError(parser.getNameLoc(), "expected three types"); |
| |
| if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || |
| parser.resolveOperands(srcMapOperands, indexType, result.operands) || |
| parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || |
| parser.resolveOperands(dstMapOperands, indexType, result.operands) || |
| parser.resolveOperand(tagMemRefInfo, types[2], result.operands) || |
| parser.resolveOperands(tagMapOperands, indexType, result.operands) || |
| parser.resolveOperand(numElementsInfo, indexType, result.operands)) |
| return failure(); |
| |
| if (isStrided) { |
| if (parser.resolveOperands(strideInfo, indexType, result.operands)) |
| return failure(); |
| } |
| |
| // Check that src/dst/tag operand counts match their map.numInputs. |
| if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() || |
| dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() || |
| tagMapOperands.size() != tagMapAttr.getValue().getNumInputs()) |
| return parser.emitError(parser.getNameLoc(), |
| "memref operand count not equal to map.numInputs"); |
| return success(); |
| } |
| |
| LogicalResult AffineDmaStartOp::verify() { |
| if (!getOperand(getSrcMemRefOperandIndex()).getType().isa<MemRefType>()) |
| return emitOpError("expected DMA source to be of memref type"); |
| if (!getOperand(getDstMemRefOperandIndex()).getType().isa<MemRefType>()) |
| return emitOpError("expected DMA destination to be of memref type"); |
| if (!getOperand(getTagMemRefOperandIndex()).getType().isa<MemRefType>()) |
| return emitOpError("expected DMA tag to be of memref type"); |
| |
| unsigned numInputsAllMaps = getSrcMap().getNumInputs() + |
| getDstMap().getNumInputs() + |
| getTagMap().getNumInputs(); |
| if (getNumOperands() != numInputsAllMaps + 3 + 1 && |
| getNumOperands() != numInputsAllMaps + 3 + 1 + 2) { |
| return emitOpError("incorrect number of operands"); |
| } |
| |
| Region *scope = getAffineScope(*this); |
| for (auto idx : getSrcIndices()) { |
| if (!idx.getType().isIndex()) |
| return emitOpError("src index to dma_start must have 'index' type"); |
| if (!isValidAffineIndexOperand(idx, scope)) |
| return emitOpError("src index must be a dimension or symbol identifier"); |
| } |
| for (auto idx : getDstIndices()) { |
| if (!idx.getType().isIndex()) |
| return emitOpError("dst index to dma_start must have 'index' type"); |
| if (!isValidAffineIndexOperand(idx, scope)) |
| return emitOpError("dst index must be a dimension or symbol identifier"); |
| } |
| for (auto idx : getTagIndices()) { |
| if (!idx.getType().isIndex()) |
| return emitOpError("tag index to dma_start must have 'index' type"); |
| if (!isValidAffineIndexOperand(idx, scope)) |
| return emitOpError("tag index must be a dimension or symbol identifier"); |
| } |
| return success(); |
| } |
| |
| LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands, |
| SmallVectorImpl<OpFoldResult> &results) { |
| /// dma_start(memrefcast) -> dma_start |
| return foldMemRefCast(*this); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AffineDmaWaitOp |
| //===----------------------------------------------------------------------===// |
| |
| // TODO: Check that map operands are loop IVs or symbols. |
| void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result, |
| Value tagMemRef, AffineMap tagMap, |
| ValueRange tagIndices, Value numElements) { |
| result.addOperands(tagMemRef); |
| result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap)); |
| result.addOperands(tagIndices); |
| result.addOperands(numElements); |
| } |
| |
| void AffineDmaWaitOp::print(OpAsmPrinter &p) { |
| p << " " << getTagMemRef() << '['; |
| SmallVector<Value, 2> operands(getTagIndices()); |
| p.printAffineMapOfSSAIds(getTagMapAttr(), operands); |
| p << "], "; |
| p.printOperand(getNumElements()); |
| p << " : " << getTagMemRef().getType(); |
| } |
| |
| // Parse AffineDmaWaitOp. |
| // Eg: |
| // affine.dma_wait %tag[%index], %num_elements |
| // : memref<1 x i32, (d0) -> (d0), 4> |
| // |
| ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser, |
| OperationState &result) { |
| OpAsmParser::OperandType tagMemRefInfo; |
| AffineMapAttr tagMapAttr; |
| SmallVector<OpAsmParser::OperandType, 2> tagMapOperands; |
| Type type; |
| auto indexType = parser.getBuilder().getIndexType(); |
| OpAsmParser::OperandType numElementsInfo; |
| |
| // Parse tag memref, its map operands, and dma size. |
| if (parser.parseOperand(tagMemRefInfo) || |
| parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr, |
| getTagMapAttrName(), result.attributes) || |
| parser.parseComma() || parser.parseOperand(numElementsInfo) || |
| parser.parseColonType(type) || |
| parser.resolveOperand(tagMemRefInfo, type, result.operands) || |
| parser.resolveOperands(tagMapOperands, indexType, result.operands) || |
| parser.resolveOperand(numElementsInfo, indexType, result.operands)) |
| return failure(); |
| |
| if (!type.isa<MemRefType>()) |
| return parser.emitError(parser.getNameLoc(), |
| "expected tag to be of memref type"); |
| |
| if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs()) |
| return parser.emitError(parser.getNameLoc(), |
| "tag memref operand count != to map.numInputs"); |
| return success(); |
| } |
| |
| LogicalResult AffineDmaWaitOp::verify() { |
| if (!getOperand(0).getType().isa<MemRefType>()) |
| return emitOpError("expected DMA tag to be of memref type"); |
| Region *scope = getAffineScope(*this); |
| for (auto idx : getTagIndices()) { |
| if (!idx.getType().isIndex()) |
| return emitOpError("index to dma_wait must have 'index' type"); |
| if (!isValidAffineIndexOperand(idx, scope)) |
| return emitOpError("index must be a dimension or symbol identifier"); |
| } |
| return success(); |
| } |
| |
| LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands, |
| SmallVectorImpl<OpFoldResult> &results) { |
| /// dma_wait(memrefcast) -> dma_wait |
| return foldMemRefCast(*this); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AffineForOp |
| //===----------------------------------------------------------------------===// |
| |
| /// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and |
| /// bodyBuilder are empty/null, we include default terminator op. |
| void AffineForOp::build(OpBuilder &builder, OperationState &result, |
| ValueRange lbOperands, AffineMap lbMap, |
| ValueRange ubOperands, AffineMap ubMap, int64_t step, |
| ValueRange iterArgs, BodyBuilderFn bodyBuilder) { |
| assert(((!lbMap && lbOperands.empty()) || |
| lbOperands.size() == lbMap.getNumInputs()) && |
| "lower bound operand count does not match the affine map"); |
| assert(((!ubMap && ubOperands.empty()) || |
| ubOperands.size() == ubMap.getNumInputs()) && |
| "upper bound operand count does not match the affine map"); |
| assert(step > 0 && "step has to be a positive integer constant"); |
| |
| for (Value val : iterArgs) |
| result.addTypes(val.getType()); |
| |
| // Add an attribute for the step. |
| result.addAttribute(getStepAttrName(), |
| builder.getIntegerAttr(builder.getIndexType(), step)); |
| |
| // Add the lower bound. |
| result.addAttribute(getLowerBoundAttrName(), AffineMapAttr::get(lbMap)); |
| result.addOperands(lbOperands); |
| |
| // Add the upper bound. |
| result.addAttribute(getUpperBoundAttrName(), AffineMapAttr::get(ubMap)); |
| result.addOperands(ubOperands); |
| |
| result.addOperands(iterArgs); |
| // Create a region and a block for the body. The argument of the region is |
| // the loop induction variable. |
| Region *bodyRegion = result.addRegion(); |
| bodyRegion->push_back(new Block); |
| Block &bodyBlock = bodyRegion->front(); |
| Value inductionVar = bodyBlock.addArgument(builder.getIndexType()); |
| for (Value val : iterArgs) |
| bodyBlock.addArgument(val.getType()); |
| |
| // Create the default terminator if the builder is not provided and if the |
| // iteration arguments are not provided. Otherwise, leave this to the caller |
| // because we don't know which values to return from the loop. |
| if (iterArgs.empty() && !bodyBuilder) { |
| ensureTerminator(*bodyRegion, builder, result.location); |
| } else if (bodyBuilder) { |
| OpBuilder::InsertionGuard guard(builder); |
| builder.setInsertionPointToStart(&bodyBlock); |
| bodyBuilder(builder, result.location, inductionVar, |
| bodyBlock.getArguments().drop_front()); |
| } |
| } |
| |
| void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb, |
| int64_t ub, int64_t step, ValueRange iterArgs, |
| BodyBuilderFn bodyBuilder) { |
| auto lbMap = AffineMap::getConstantMap(lb, builder.getContext()); |
| auto ubMap = AffineMap::getConstantMap(ub, builder.getContext()); |
| return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs, |
| bodyBuilder); |
| } |
| |
| static LogicalResult verify(AffineForOp op) { |
| // Check that the body defines as single block argument for the induction |
| // variable. |
| auto *body = op.getBody(); |
| if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex()) |
| return op.emitOpError( |
| "expected body to have a single index argument for the " |
| "induction variable"); |
| |
| // Verify that the bound operands are valid dimension/symbols. |
| /// Lower bound. |
| if (op.getLowerBoundMap().getNumInputs() > 0) |
| if (failed( |
| verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(), |
| op.getLowerBoundMap().getNumDims()))) |
| return failure(); |
| /// Upper bound. |
| if (op.getUpperBoundMap().getNumInputs() > 0) |
| if (failed( |
| verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(), |
| op.getUpperBoundMap().getNumDims()))) |
| return failure(); |
| |
| unsigned opNumResults = op.getNumResults(); |
| if (opNumResults == 0) |
| return success(); |
| |
| // If ForOp defines values, check that the number and types of the defined |
| // values match ForOp initial iter operands and backedge basic block |
| // arguments. |
| if (op.getNumIterOperands() != opNumResults) |
| return op.emitOpError( |
| "mismatch between the number of loop-carried values and results"); |
| if (op.getNumRegionIterArgs() != opNumResults) |
| return op.emitOpError( |
| "mismatch between the number of basic block args and results"); |
| |
| return success(); |
| } |
| |
| /// Parse a for operation loop bounds. |
| static ParseResult parseBound(bool isLower, OperationState &result, |
| OpAsmParser &p) { |
| // 'min' / 'max' prefixes are generally syntactic sugar, but are required if |
| // the map has multiple results. |
| bool failedToParsedMinMax = |
| failed(p.parseOptionalKeyword(isLower ? "max" : "min")); |
| |
| auto &builder = p.getBuilder(); |
| auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName() |
| : AffineForOp::getUpperBoundAttrName(); |
| |
| // Parse ssa-id as identity map. |
| SmallVector<OpAsmParser::OperandType, 1> boundOpInfos; |
| if (p.parseOperandList(boundOpInfos)) |
| return failure(); |
| |
| if (!boundOpInfos.empty()) { |
| // Check that only one operand was parsed. |
| if (boundOpInfos.size() > 1) |
| return p.emitError(p.getNameLoc(), |
| "expected only one loop bound operand"); |
| |
| // TODO: improve error message when SSA value is not of index type. |
| // Currently it is 'use of value ... expects different type than prior uses' |
| if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(), |
| result.operands)) |
| return failure(); |
| |
| // Create an identity map using symbol id. This representation is optimized |
| // for storage. Analysis passes may expand it into a multi-dimensional map |
| // if desired. |
| AffineMap map = builder.getSymbolIdentityMap(); |
| result.addAttribute(boundAttrName, AffineMapAttr::get(map)); |
| return success(); |
| } |
| |
| // Get the attribute location. |
| llvm::SMLoc attrLoc = p.getCurrentLocation(); |
| |
| Attribute boundAttr; |
| if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrName, |
| result.attributes)) |
| return failure(); |
| |
| // Parse full form - affine map followed by dim and symbol list. |
| if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) { |
| unsigned currentNumOperands = result.operands.size(); |
| unsigned numDims; |
| if (parseDimAndSymbolList(p, result.operands, numDims)) |
| return failure(); |
| |
| auto map = affineMapAttr.getValue(); |
| if (map.getNumDims() != numDims) |
| return p.emitError( |
| p.getNameLoc(), |
| "dim operand count and affine map dim count must match"); |
| |
| unsigned numDimAndSymbolOperands = |
| result.operands.size() - currentNumOperands; |
| if (numDims + map.getNumSymbols() != numDimAndSymbolOperands) |
| return p.emitError( |
| p.getNameLoc(), |
| "symbol operand count and affine map symbol count must match"); |
| |
| // If the map has multiple results, make sure that we parsed the min/max |
| // prefix. |
| if (map.getNumResults() > 1 && failedToParsedMinMax) { |
| if (isLower) { |
| return p.emitError(attrLoc, "lower loop bound affine map with " |
| "multiple results requires 'max' prefix"); |
| } |
| return p.emitError(attrLoc, "upper loop bound affine map with multiple " |
| "results requires 'min' prefix"); |
| } |
| return success(); |
| } |
| |
| // Parse custom assembly form. |
| if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) { |
| result.attributes.pop_back(); |
| result.addAttribute( |
| boundAttrName, |
| AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt()))); |
| return success(); |
| } |
| |
| return p.emitError( |
| p.getNameLoc(), |
| "expected valid affine map representation for loop bounds"); |
| } |
| |
| static ParseResult parseAffineForOp(OpAsmParser &parser, |
| OperationState &result) { |
| auto &builder = parser.getBuilder(); |
| OpAsmParser::OperandType inductionVariable; |
| // Parse the induction variable followed by '='. |
| if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual()) |
| return failure(); |
| |
| // Parse loop bounds. |
| if (parseBound(/*isLower=*/true, result, parser) || |
| parser.parseKeyword("to", " between bounds") || |
| parseBound(/*isLower=*/false, result, parser)) |
| return failure(); |
| |
| // Parse the optional loop step, we default to 1 if one is not present. |
| if (parser.parseOptionalKeyword("step")) { |
| result.addAttribute( |
| AffineForOp::getStepAttrName(), |
| builder.getIntegerAttr(builder.getIndexType(), /*value=*/1)); |
| } else { |
| llvm::SMLoc stepLoc = parser.getCurrentLocation(); |
| IntegerAttr stepAttr; |
| if (parser.parseAttribute(stepAttr, builder.getIndexType(), |
| AffineForOp::getStepAttrName().data(), |
| result.attributes)) |
| return failure(); |
| |
| if (stepAttr.getValue().getSExtValue() < 0) |
| return parser.emitError( |
| stepLoc, |
| "expected step to be representable as a positive signed integer"); |
| } |
| |
| // Parse the optional initial iteration arguments. |
| SmallVector<OpAsmParser::OperandType, 4> regionArgs, operands; |
| SmallVector<Type, 4> argTypes; |
| regionArgs.push_back(inductionVariable); |
| |
| if (succeeded(parser.parseOptionalKeyword("iter_args"))) { |
| // Parse assignment list and results type list. |
| if (parser.parseAssignmentList(regionArgs, operands) || |
| parser.parseArrowTypeList(result.types)) |
| return failure(); |
| // Resolve input operands. |
| for (auto operandType : llvm::zip(operands, result.types)) |
| if (parser.resolveOperand(std::get<0>(operandType), |
| std::get<1>(operandType), result.operands)) |
| return failure(); |
| } |
| // Induction variable. |
| Type indexType = builder.getIndexType(); |
| argTypes.push_back(indexType); |
| // Loop carried variables. |
| argTypes.append(result.types.begin(), result.types.end()); |
| // Parse the body region. |
| Region *body = result.addRegion(); |
| if (regionArgs.size() != argTypes.size()) |
| return parser.emitError( |
| parser.getNameLoc(), |
| "mismatch between the number of loop-carried values and results"); |
| if (parser.parseRegion(*body, regionArgs, argTypes)) |
| return failure(); |
| |
| AffineForOp::ensureTerminator(*body, builder, result.location); |
| |
| // Parse the optional attribute list. |
| return parser.parseOptionalAttrDict(result.attributes); |
| } |
| |
| static void printBound(AffineMapAttr boundMap, |
| Operation::operand_range boundOperands, |
| const char *prefix, OpAsmPrinter &p) { |
| AffineMap map = boundMap.getValue(); |
| |
| // Check if this bound should be printed using custom assembly form. |
| // The decision to restrict printing custom assembly form to trivial cases |
| // comes from the will to roundtrip MLIR binary -> text -> binary in a |
| // lossless way. |
| // Therefore, custom assembly form parsing and printing is only supported for |
| // zero-operand constant maps and single symbol operand identity maps. |
| if (map.getNumResults() == 1) { |
| AffineExpr expr = map.getResult(0); |
| |
| // Print constant bound. |
| if (map.getNumDims() == 0 && map.getNumSymbols() == 0) { |
| if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) { |
| p << constExpr.getValue(); |
| return; |
| } |
| } |
| |
| // Print bound that consists of a single SSA symbol if the map is over a |
| // single symbol. |
| if (map.getNumDims() == 0 && map.getNumSymbols() == 1) { |
| if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) { |
| p.printOperand(*boundOperands.begin()); |
| return; |
| } |
| } |
| } else { |
| // Map has multiple results. Print 'min' or 'max' prefix. |
| p << prefix << ' '; |
| } |
| |
| // Print the map and its operands. |
| p << boundMap; |
| printDimAndSymbolList(boundOperands.begin(), boundOperands.end(), |
| map.getNumDims(), p); |
| } |
| |
| unsigned AffineForOp::getNumIterOperands() { |
| AffineMap lbMap = getLowerBoundMapAttr().getValue(); |
| AffineMap ubMap = getUpperBoundMapAttr().getValue(); |
| |
| return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs(); |
| } |
| |
| static void print(OpAsmPrinter &p, AffineForOp op) { |
| p << ' '; |
| p.printOperand(op.getBody()->getArgument(0)); |
| p << " = "; |
| printBound(op.getLowerBoundMapAttr(), op.getLowerBoundOperands(), "max", p); |
| p << " to "; |
| printBound(op.getUpperBoundMapAttr(), op.getUpperBoundOperands(), "min", p); |
| |
| if (op.getStep() != 1) |
| p << " step " << op.getStep(); |
| |
| bool printBlockTerminators = false; |
| if (op.getNumIterOperands() > 0) { |
| p << " iter_args("; |
| auto regionArgs = op.getRegionIterArgs(); |
| auto operands = op.getIterOperands(); |
| |
| llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) { |
| p << std::get<0>(it) << " = " << std::get<1>(it); |
| }); |
| p << ") -> (" << op.getResultTypes() << ")"; |
| printBlockTerminators = true; |
| } |
| |
| p.printRegion(op.region(), |
| /*printEntryBlockArgs=*/false, printBlockTerminators); |
| p.printOptionalAttrDict(op->getAttrs(), |
| /*elidedAttrs=*/{op.getLowerBoundAttrName(), |
| op.getUpperBoundAttrName(), |
| op.getStepAttrName()}); |
| } |
| |
| /// Fold the constant bounds of a loop. |
| static LogicalResult foldLoopBounds(AffineForOp forOp) { |
| auto foldLowerOrUpperBound = [&forOp](bool lower) { |
| // Check to see if each of the operands is the result of a constant. If |
| // so, get the value. If not, ignore it. |
| SmallVector<Attribute, 8> operandConstants; |
| auto boundOperands = |
| lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands(); |
| for (auto operand : boundOperands) { |
| Attribute operandCst; |
| matchPattern(operand, m_Constant(&operandCst)); |
| operandConstants.push_back(operandCst); |
| } |
| |
| AffineMap boundMap = |
| lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap(); |
| assert(boundMap.getNumResults() >= 1 && |
| "bound maps should have at least one result"); |
| SmallVector<Attribute, 4> foldedResults; |
| if (failed(boundMap.constantFold(operandConstants, foldedResults))) |
| return failure(); |
| |
| // Compute the max or min as applicable over the results. |
| assert(!foldedResults.empty() && "bounds should have at least one result"); |
| auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue(); |
| for (unsigned i = 1, e = foldedResults.size(); i < e; i++) { |
| auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue(); |
| maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult) |
| : llvm::APIntOps::smin(maxOrMin, foldedResult); |
| } |
| lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue()) |
| : forOp.setConstantUpperBound(maxOrMin.getSExtValue()); |
| return success(); |
| }; |
| |
| // Try to fold the lower bound. |
| bool folded = false; |
| if (!forOp.hasConstantLowerBound()) |
| folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true)); |
| |
| // Try to fold the upper bound. |
| if (!forOp.hasConstantUpperBound()) |
| folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false)); |
| return success(folded); |
| } |
| |
| /// Canonicalize the bounds of the given loop. |
| static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) { |
| SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands()); |
| SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands()); |
| |
| auto lbMap = forOp.getLowerBoundMap(); |
| auto ubMap = forOp.getUpperBoundMap(); |
| auto prevLbMap = lbMap; |
| auto prevUbMap = ubMap; |
| |
| composeAffineMapAndOperands(&lbMap, &lbOperands); |
| canonicalizeMapAndOperands(&lbMap, &lbOperands); |
| lbMap = removeDuplicateExprs(lbMap); |
| |
| composeAffineMapAndOperands(&ubMap, &ubOperands); |
| canonicalizeMapAndOperands(&ubMap, &ubOperands); |
| ubMap = removeDuplicateExprs(ubMap); |
| |
| // Any canonicalization change always leads to updated map(s). |
| if (lbMap == prevLbMap && ubMap == prevUbMap) |
| return failure(); |
| |
| if (lbMap != prevLbMap) |
| forOp.setLowerBound(lbOperands, lbMap); |
| if (ubMap != prevUbMap) |
| forOp.setUpperBound(ubOperands, ubMap); |
| return success(); |
| } |
| |
| namespace { |
| /// This is a pattern to fold trivially empty loop bodies. |
| /// TODO: This should be moved into the folding hook. |
| struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> { |
| using OpRewritePattern<AffineForOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(AffineForOp forOp, |
| PatternRewriter &rewriter) const override { |
| // Check that the body only contains a yield. |
| if (!llvm::hasSingleElement(*forOp.getBody())) |
| return failure(); |
| // The initial values of the iteration arguments would be the op's results. |
| rewriter.replaceOp(forOp, forOp.getIterOperands()); |
| return success(); |
| } |
| }; |
| } // end anonymous namespace |
| |
| void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<AffineForEmptyLoopFolder>(context); |
| } |
| |
| /// Returns true if the affine.for has zero iterations in trivial cases. |
| static bool hasTrivialZeroTripCount(AffineForOp op) { |
| if (!op.hasConstantBounds()) |
| return false; |
| int64_t lb = op.getConstantLowerBound(); |
| int64_t ub = op.getConstantUpperBound(); |
| return ub - lb <= 0; |
| } |
| |
| LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands, |
| SmallVectorImpl<OpFoldResult> &results) { |
| bool folded = succeeded(foldLoopBounds(*this)); |
| folded |= succeeded(canonicalizeLoopBounds(*this)); |
| if (hasTrivialZeroTripCount(*this)) { |
| // The initial values of the loop-carried variables (iter_args) are the |
| // results of the op. |
| results.assign(getIterOperands().begin(), getIterOperands().end()); |
| folded = true; |
| } |
| return success(folded); |
| } |
| |
| AffineBound AffineForOp::getLowerBound() { |
| auto lbMap = getLowerBoundMap(); |
| return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap); |
| } |
| |
| AffineBound AffineForOp::getUpperBound() { |
| auto lbMap = getLowerBoundMap(); |
| auto ubMap = getUpperBoundMap(); |
| return AffineBound(AffineForOp(*this), lbMap.getNumInputs(), |
| lbMap.getNumInputs() + ubMap.getNumInputs(), ubMap); |
| } |
| |
| void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) { |
| assert(lbOperands.size() == map.getNumInputs()); |
| assert(map.getNumResults() >= 1 && "bound map has at least one result"); |
| |
| SmallVector<Value, 4> newOperands(lbOperands.begin(), lbOperands.end()); |
| |
| auto ubOperands = getUpperBoundOperands(); |
| newOperands.append(ubOperands.begin(), ubOperands.end()); |
| auto iterOperands = getIterOperands(); |
| newOperands.append(iterOperands.begin(), iterOperands.end()); |
| (*this)->setOperands(newOperands); |
| |
| (*this)->setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map)); |
| } |
| |
| void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) { |
| assert(ubOperands.size() == map.getNumInputs()); |
| assert(map.getNumResults() >= 1 && "bound map has at least one result"); |
| |
| SmallVector<Value, 4> newOperands(getLowerBoundOperands()); |
| newOperands.append(ubOperands.begin(), ubOperands.end()); |
| auto iterOperands = getIterOperands(); |
| newOperands.append(iterOperands.begin(), iterOperands.end()); |
| (*this)->setOperands(newOperands); |
| |
| (*this)->setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map)); |
| } |
| |
| void AffineForOp::setLowerBoundMap(AffineMap map) { |
| auto lbMap = getLowerBoundMap(); |
| assert(lbMap.getNumDims() == map.getNumDims() && |
| lbMap.getNumSymbols() == map.getNumSymbols()); |
| assert(map.getNumResults() >= 1 && "bound map has at least one result"); |
| (void)lbMap; |
| (*this)->setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map)); |
| } |
| |
| void AffineForOp::setUpperBoundMap(AffineMap map) { |
| auto ubMap = getUpperBoundMap(); |
| assert(ubMap.getNumDims() == map.getNumDims() && |
| ubMap.getNumSymbols() == map.getNumSymbols()); |
| assert(map.getNumResults() >= 1 && "bound map has at least one result"); |
| (void)ubMap; |
| (*this)->setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map)); |
| } |
| |
| bool AffineForOp::hasConstantLowerBound() { |
| return getLowerBoundMap().isSingleConstant(); |
| } |
| |
| bool AffineForOp::hasConstantUpperBound() { |
| return getUpperBoundMap().isSingleConstant(); |
| } |
| |
| int64_t AffineForOp::getConstantLowerBound() { |
| return getLowerBoundMap().getSingleConstantResult(); |
| } |
| |
| int64_t AffineForOp::getConstantUpperBound() { |
| return getUpperBoundMap().getSingleConstantResult(); |
| } |
| |
| void AffineForOp::setConstantLowerBound(int64_t value) { |
| setLowerBound({}, AffineMap::getConstantMap(value, getContext())); |
| } |
| |
| void AffineForOp::setConstantUpperBound(int64_t value) { |
| setUpperBound({}, AffineMap::getConstantMap(value, getContext())); |
| } |
| |
| AffineForOp::operand_range AffineForOp::getLowerBoundOperands() { |
| return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()}; |
| } |
| |
| AffineForOp::operand_range AffineForOp::getUpperBoundOperands() { |
| return {operand_begin() + getLowerBoundMap().getNumInputs(), |
| operand_begin() + getLowerBoundMap().getNumInputs() + |
| getUpperBoundMap().getNumInputs()}; |
| } |
| |
| AffineForOp::operand_range AffineForOp::getControlOperands() { |
| return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs() + |
| getUpperBoundMap().getNumInputs()}; |
| } |
| |
| bool AffineForOp::matchingBoundOperandList() { |
| auto lbMap = getLowerBoundMap(); |
| auto ubMap = getUpperBoundMap(); |
| if (lbMap.getNumDims() != ubMap.getNumDims() || |
| lbMap.getNumSymbols() != ubMap.getNumSymbols()) |
| return false; |
| |
| unsigned numOperands = lbMap.getNumInputs(); |
| for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) { |
| // Compare Value 's. |
| if (getOperand(i) != getOperand(numOperands + i)) |
| return false; |
| } |
| return true; |
| } |
| |
| Region &AffineForOp::getLoopBody() { return region(); } |
| |
| bool AffineForOp::isDefinedOutsideOfLoop(Value value) { |
| return !region().isAncestor(value.getParentRegion()); |
| } |
| |
| LogicalResult AffineForOp::moveOutOfLoop(ArrayRef<Operation *> ops) { |
| for (auto *op : ops) |
| op->moveBefore(*this); |
| return success(); |
| } |
| |
| /// Returns true if the provided value is the induction variable of a |
| /// AffineForOp. |
| bool mlir::isForInductionVar(Value val) { |
| return getForInductionVarOwner(val) != AffineForOp(); |
| } |
| |
| /// Returns the loop parent of an induction variable. If the provided value is |
| /// not an induction variable, then return nullptr. |
| AffineForOp mlir::getForInductionVarOwner(Value val) { |
| auto ivArg = val.dyn_cast<BlockArgument>(); |
| if (!ivArg || !ivArg.getOwner()) |
| return AffineForOp(); |
| auto *containingInst = ivArg.getOwner()->getParent()->getParentOp(); |
| if (auto forOp = dyn_cast<AffineForOp>(containingInst)) |
| // Check to make sure `val` is the induction variable, not an iter_arg. |
| return forOp.getInductionVar() == val ? forOp : AffineForOp(); |
| return AffineForOp(); |
| } |
| |
| /// Extracts the induction variables from a list of AffineForOps and returns |
| /// them. |
| void mlir::extractForInductionVars(ArrayRef<AffineForOp> forInsts, |
| SmallVectorImpl<Value> *ivs) { |
| ivs->reserve(forInsts.size()); |
| for (auto forInst : forInsts) |
| ivs->push_back(forInst.getInductionVar()); |
| } |
| |
| /// Builds an affine loop nest, using "loopCreatorFn" to create individual loop |
| /// operations. |
| template <typename BoundListTy, typename LoopCreatorTy> |
| static void buildAffineLoopNestImpl( |
| OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs, |
| ArrayRef<int64_t> steps, |
| function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, |
| LoopCreatorTy &&loopCreatorFn) { |
| assert(lbs.size() == ubs.size() && "Mismatch in number of arguments"); |
| assert(lbs.size() == steps.size() && "Mismatch in number of arguments"); |
| |
| // If there are no loops to be constructed, construct the body anyway. |
| OpBuilder::InsertionGuard guard(builder); |
| if (lbs.empty()) { |
| if (bodyBuilderFn) |
| bodyBuilderFn(builder, loc, ValueRange()); |
| return; |
| } |
| |
| // Create the loops iteratively and store the induction variables. |
| SmallVector<Value, 4> ivs; |
| ivs.reserve(lbs.size()); |
| for (unsigned i = 0, e = lbs.size(); i < e; ++i) { |
| // Callback for creating the loop body, always creates the terminator. |
| auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, |
| ValueRange iterArgs) { |
| ivs.push_back(iv); |
| // In the innermost loop, call the body builder. |
| if (i == e - 1 && bodyBuilderFn) { |
| OpBuilder::InsertionGuard nestedGuard(nestedBuilder); |
| bodyBuilderFn(nestedBuilder, nestedLoc, ivs); |
| } |
| nestedBuilder.create<AffineYieldOp>(nestedLoc); |
| }; |
| |
| // Delegate actual loop creation to the callback in order to dispatch |
| // between constant- and variable-bound loops. |
| auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody); |
| builder.setInsertionPointToStart(loop.getBody()); |
| } |
| } |
| |
| /// Creates an affine loop from the bounds known to be constants. |
| static AffineForOp |
| buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, |
| int64_t ub, int64_t step, |
| AffineForOp::BodyBuilderFn bodyBuilderFn) { |
| return builder.create<AffineForOp>(loc, lb, ub, step, /*iterArgs=*/llvm::None, |
| bodyBuilderFn); |
| } |
| |
| /// Creates an affine loop from the bounds that may or may not be constants. |
| static AffineForOp |
| buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, |
| int64_t step, |
| AffineForOp::BodyBuilderFn bodyBuilderFn) { |
| auto lbConst = lb.getDefiningOp<arith::ConstantIndexOp>(); |
| auto ubConst = ub.getDefiningOp<arith::ConstantIndexOp>(); |
| if (lbConst && ubConst) |
| return buildAffineLoopFromConstants(builder, loc, lbConst.value(), |
| ubConst.value(), step, bodyBuilderFn); |
| return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub, |
| builder.getDimIdentityMap(), step, |
| /*iterArgs=*/llvm::None, bodyBuilderFn); |
| } |
| |
| void mlir::buildAffineLoopNest( |
| OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs, |
| ArrayRef<int64_t> ubs, ArrayRef<int64_t> steps, |
| function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) { |
| buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn, |
| buildAffineLoopFromConstants); |
| } |
| |
| void mlir::buildAffineLoopNest( |
| OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, |
| ArrayRef<int64_t> steps, |
| function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) { |
| buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn, |
| buildAffineLoopFromValues); |
| } |
| |
| AffineForOp mlir::replaceForOpWithNewYields(OpBuilder &b, AffineForOp loop, |
| ValueRange newIterOperands, |
| ValueRange newYieldedValues, |
| ValueRange newIterArgs, |
| bool replaceLoopResults) { |
| assert(newIterOperands.size() == newYieldedValues.size() && |
| "newIterOperands must be of the same size as newYieldedValues"); |
| // Create a new loop before the existing one, with the extra operands. |
| OpBuilder::InsertionGuard g(b); |
| b.setInsertionPoint(loop); |
| auto operands = llvm::to_vector<4>(loop.getIterOperands()); |
| operands.append(newIterOperands.begin(), newIterOperands.end()); |
| SmallVector<Value, 4> lbOperands(loop.getLowerBoundOperands()); |
| SmallVector<Value, 4> ubOperands(loop.getUpperBoundOperands()); |
| SmallVector<Value, 4> steps(loop.getStep()); |
| auto lbMap = loop.getLowerBoundMap(); |
| auto ubMap = loop.getUpperBoundMap(); |
| AffineForOp newLoop = |
| b.create<AffineForOp>(loop.getLoc(), lbOperands, lbMap, ubOperands, ubMap, |
| loop.getStep(), operands); |
| // Take the body of the original parent loop. |
| newLoop.getLoopBody().takeBody(loop.getLoopBody()); |
| for (Value val : newIterArgs) |
| newLoop.getLoopBody().addArgument(val.getType()); |
| |
| // Update yield operation with new values to be added. |
| if (!newYieldedValues.empty()) { |
| auto yield = cast<AffineYieldOp>(newLoop.getBody()->getTerminator()); |
| b.setInsertionPoint(yield); |
| auto yieldOperands = llvm::to_vector<4>(yield.getOperands()); |
| yieldOperands.append(newYieldedValues.begin(), newYieldedValues.end()); |
| b.create<AffineYieldOp>(yield.getLoc(), yieldOperands); |
| yield.erase(); |
| } |
| if (replaceLoopResults) { |
| for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front( |
| loop.getNumResults()))) { |
| std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); |
| } |
| } |
| return newLoop; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AffineIfOp |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| /// Remove else blocks that have nothing other than a zero value yield. |
| struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> { |
| using OpRewritePattern<AffineIfOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(AffineIfOp ifOp, |
| PatternRewriter &rewriter) const override { |
| if (ifOp.elseRegion().empty() || |
| !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults()) |
| return failure(); |
| |
| rewriter.startRootUpdate(ifOp); |
| rewriter.eraseBlock(ifOp.getElseBlock()); |
| rewriter.finalizeRootUpdate(ifOp); |
| return success(); |
| } |
| }; |
| |
| /// Removes affine.if cond if the condition is always true or false in certain |
| /// trivial cases. Promotes the then/else block in the parent operation block. |
| struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> { |
| using OpRewritePattern<AffineIfOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(AffineIfOp op, |
| PatternRewriter &rewriter) const override { |
| |
| auto isTriviallyFalse = [](IntegerSet iSet) { |
| return iSet.isEmptyIntegerSet(); |
| }; |
| |
| auto isTriviallyTrue = [](IntegerSet iSet) { |
| return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 && |
| iSet.getConstraint(0) == 0); |
| }; |
| |
| IntegerSet affineIfConditions = op.getIntegerSet(); |
| Block *blockToMove; |
| if (isTriviallyFalse(affineIfConditions)) { |
| // The absence, or equivalently, the emptiness of the else region need not |
| // be checked when affine.if is returning results because if an affine.if |
| // operation is returning results, it always has a non-empty else region. |
| if (op.getNumResults() == 0 && !op.hasElse()) { |
| // If the else region is absent, or equivalently, empty, remove the |
| // affine.if operation (which is not returning any results). |
| rewriter.eraseOp(op); |
| return success(); |
| } |
| blockToMove = op.getElseBlock(); |
| } else if (isTriviallyTrue(affineIfConditions)) { |
| blockToMove = op.getThenBlock(); |
| } else { |
| return failure(); |
| } |
| Operation *blockToMoveTerminator = blockToMove->getTerminator(); |
| // Promote the "blockToMove" block to the parent operation block between the |
| // prologue and epilogue of "op". |
| rewriter.mergeBlockBefore(blockToMove, op); |
| // Replace the "op" operation with the operands of the |
| // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is |
| // the affine.yield operation present in the "blockToMove" block. It has no |
| // operands when affine.if is not returning results and therefore, in that |
| // case, replaceOp just erases "op". When affine.if is not returning |
| // results, the affine.yield operation can be omitted. It gets inserted |
| // implicitly. |
| rewriter.replaceOp(op, blockToMoveTerminator->getOperands()); |
| // Erase the "blockToMoveTerminator" operation since it is now in the parent |
| // operation block, which already has its own terminator. |
| rewriter.eraseOp(blockToMoveTerminator); |
| return success(); |
| } |
| }; |
| } // end anonymous namespace. |
| |
| static LogicalResult verify(AffineIfOp op) { |
| // Verify that we have a condition attribute. |
| auto conditionAttr = |
| op->getAttrOfType<IntegerSetAttr>(op.getConditionAttrName()); |
| if (!conditionAttr) |
| return op.emitOpError( |
| "requires an integer set attribute named 'condition'"); |
| |
| // Verify that there are enough operands for the condition. |
| IntegerSet condition = conditionAttr.getValue(); |
| if (op.getNumOperands() != condition.getNumInputs()) |
| return op.emitOpError( |
| "operand count and condition integer set dimension and " |
| "symbol count must match"); |
| |
| // Verify that the operands are valid dimension/symbols. |
| if (failed(verifyDimAndSymbolIdentifiers(op, op.getOperands(), |
| condition.getNumDims()))) |
| return failure(); |
| |
| return success(); |
| } |
| |
| static ParseResult parseAffineIfOp(OpAsmParser &parser, |
| OperationState &result) { |
| // Parse the condition attribute set. |
| IntegerSetAttr conditionAttr; |
| unsigned numDims; |
| if (parser.parseAttribute(conditionAttr, AffineIfOp::getConditionAttrName(), |
| result.attributes) || |
| parseDimAndSymbolList(parser, result.operands, numDims)) |
| return failure(); |
| |
| // Verify the condition operands. |
| auto set = conditionAttr.getValue(); |
| if (set.getNumDims() != numDims) |
| return parser.emitError( |
| parser.getNameLoc(), |
| "dim operand count and integer set dim count must match"); |
| if (numDims + set.getNumSymbols() != result.operands.size()) |
| return parser.emitError( |
| parser.getNameLoc(), |
| "symbol operand count and integer set symbol count must match"); |
| |
| if (parser.parseOptionalArrowTypeList(result.types)) |
| return failure(); |
| |
| // Create the regions for 'then' and 'else'. The latter must be created even |
| // if it remains empty for the validity of the operation. |
| result.regions.reserve(2); |
| Region *thenRegion = result.addRegion(); |
| Region *elseRegion = result.addRegion(); |
| |
| // Parse the 'then' region. |
| if (parser.parseRegion(*thenRegion, {}, {})) |
| return failure(); |
| AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(), |
| result.location); |
| |
| // If we find an 'else' keyword then parse the 'else' region. |
| if (!parser.parseOptionalKeyword("else")) { |
| if (parser.parseRegion(*elseRegion, {}, {})) |
| return failure(); |
| AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(), |
| result.location); |
| } |
| |
| // Parse the optional attribute list. |
| if (parser.parseOptionalAttrDict(result.attributes)) |
| return failure(); |
| |
| return success(); |
| } |
| |
| static void print(OpAsmPrinter &p, AffineIfOp op) { |
| auto conditionAttr = |
| op->getAttrOfType<IntegerSetAttr>(op.getConditionAttrName()); |
| p << " " << conditionAttr; |
| printDimAndSymbolList(op.operand_begin(), op.operand_end(), |
| conditionAttr.getValue().getNumDims(), p); |
| p.printOptionalArrowTypeList(op.getResultTypes()); |
| p.printRegion(op.thenRegion(), |
| /*printEntryBlockArgs=*/false, |
| /*printBlockTerminators=*/op.getNumResults()); |
| |
| // Print the 'else' regions if it has any blocks. |
| auto &elseRegion = op.elseRegion(); |
| if (!elseRegion.empty()) { |
| p << " else"; |
| p.printRegion(elseRegion, |
| /*printEntryBlockArgs=*/false, |
| /*printBlockTerminators=*/op.getNumResults()); |
| } |
| |
| // Print the attribute list. |
| p.printOptionalAttrDict(op->getAttrs(), |
| /*elidedAttrs=*/op.getConditionAttrName()); |
| } |
| |
| IntegerSet AffineIfOp::getIntegerSet() { |
| return (*this) |
| ->getAttrOfType<IntegerSetAttr>(getConditionAttrName()) |
| .getValue(); |
| } |
| |
| void AffineIfOp::setIntegerSet(IntegerSet newSet) { |
| (*this)->setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet)); |
| } |
| |
| void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) { |
| setIntegerSet(set); |
| (*this)->setOperands(operands); |
| } |
| |
| void AffineIfOp::build(OpBuilder &builder, OperationState &result, |
| TypeRange resultTypes, IntegerSet set, ValueRange args, |
| bool withElseRegion) { |
| assert(resultTypes.empty() || withElseRegion); |
| result.addTypes(resultTypes); |
| result.addOperands(args); |
| result.addAttribute(getConditionAttrName(), IntegerSetAttr::get(set)); |
| |
| Region *thenRegion = result.addRegion(); |
| thenRegion->push_back(new Block()); |
| if (resultTypes.empty()) |
| AffineIfOp::ensureTerminator(*thenRegion, builder, result.location); |
| |
| Region *elseRegion = result.addRegion(); |
| if (withElseRegion) { |
| elseRegion->push_back(new Block()); |
| if (resultTypes.empty()) |
| AffineIfOp::ensureTerminator(*elseRegion, builder, result.location); |
| } |
| } |
| |
| void AffineIfOp::build(OpBuilder &builder, OperationState &result, |
| IntegerSet set, ValueRange args, bool withElseRegion) { |
| AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args, |
| withElseRegion); |
| } |
| |
| /// Canonicalize an affine if op's conditional (integer set + operands). |
| LogicalResult AffineIfOp::fold(ArrayRef<Attribute>, |
| SmallVectorImpl<OpFoldResult> &) { |
| auto set = getIntegerSet(); |
| SmallVector<Value, 4> operands(getOperands()); |
| canonicalizeSetAndOperands(&set, &operands); |
| |
| // Any canonicalization change always leads to either a reduction in the |
| // number of operands or a change in the number of symbolic operands |
| // (promotion of dims to symbols). |
| if (operands.size() < getIntegerSet().getNumInputs() || |
| set.getNumSymbols() > getIntegerSet().getNumSymbols()) { |
| setConditional(set, operands); |
| return success(); |
| } |
| |
| return failure(); |
| } |
| |
| void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AffineLoadOp |
| //===----------------------------------------------------------------------===// |
| |
| void AffineLoadOp::build(OpBuilder &builder, OperationState &result, |
| AffineMap map, ValueRange operands) { |
| assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands"); |
| result.addOperands(operands); |
| if (map) |
| result.addAttribute(getMapAttrName(), AffineMapAttr::get(map)); |
| auto memrefType = operands[0].getType().cast<MemRefType>(); |
| result.types.push_back(memrefType.getElementType()); |
| } |
| |
| void AffineLoadOp::build(OpBuilder &builder, OperationState &result, |
| Value memref, AffineMap map, ValueRange mapOperands) { |
| assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info"); |
| result.addOperands(memref); |
| result.addOperands(mapOperands); |
| auto memrefType = memref.getType().cast<MemRefType>(); |
| result.addAttribute(getMapAttrName(), AffineMapAttr::get(map)); |
| result.types.push_back(memrefType.getElementType()); |
| } |
| |
| void AffineLoadOp::build(OpBuilder &builder, OperationState &result, |
| Value memref, ValueRange indices) { |
| auto memrefType = memref.getType().cast<MemRefType>(); |
| int64_t rank = memrefType.getRank(); |
| // Create identity map for memrefs with at least one dimension or () -> () |
| // for zero-dimensional memrefs. |
| auto map = |
| rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap(); |
| build(builder, result, memref, map, indices); |
| } |
| |
| static ParseResult parseAffineLoadOp(OpAsmParser &parser, |
| OperationState &result) { |
| auto &builder = parser.getBuilder(); |
| auto indexTy = builder.getIndexType(); |
| |
| MemRefType type; |
| OpAsmParser::OperandType memrefInfo; |
| AffineMapAttr mapAttr; |
| SmallVector<OpAsmParser::OperandType, 1> mapOperands; |
| return failure( |
| parser.parseOperand(memrefInfo) || |
| parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, |
| AffineLoadOp::getMapAttrName(), |
| result.attributes) || |
| parser.parseOptionalAttrDict(result.attributes) || |
| parser.parseColonType(type) || |
| parser.resolveOperand(memrefInfo, type, result.operands) || |
| parser.resolveOperands(mapOperands, indexTy, result.operands) || |
| parser.addTypeToList(type.getElementType(), result.types)); |
| } |
| |
| static void print(OpAsmPrinter &p, AffineLoadOp op) { |
| p << " " << op.getMemRef() << '['; |
| if (AffineMapAttr mapAttr = |
| op->getAttrOfType<AffineMapAttr>(op.getMapAttrName())) |
| p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands()); |
| p << ']'; |
| p.printOptionalAttrDict(op->getAttrs(), |
| /*elidedAttrs=*/{op.getMapAttrName()}); |
| p << " : " << op.getMemRefType(); |
| } |
| |
| /// Verify common indexing invariants of affine.load, affine.store, |
| /// affine.vector_load and affine.vector_store. |
| static LogicalResult |
| verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr, |
| Operation::operand_range mapOperands, |
| MemRefType memrefType, unsigned numIndexOperands) { |
| if (mapAttr) { |
| AffineMap map = mapAttr.getValue(); |
| if (map.getNumResults() != memrefType.getRank()) |
| return op->emitOpError("affine map num results must equal memref rank"); |
| if (map.getNumInputs() != numIndexOperands) |
| return op->emitOpError("expects as many subscripts as affine map inputs"); |
| } else { |
| if (memrefType.getRank() != numIndexOperands) |
| return op->emitOpError( |
| "expects the number of subscripts to be equal to memref rank"); |
| } |
| |
| Region *scope = getAffineScope(op); |
| for (auto idx : mapOperands) { |
| if (!idx.getType().isIndex()) |
| return op->emitOpError("index to load must have 'index' type"); |
| if (!isValidAffineIndexOperand(idx, scope)) |
| return op->emitOpError("index must be a dimension or symbol identifier"); |
| } |
| |
| return success(); |
| } |
| |
| LogicalResult verify(AffineLoadOp op) { |
| auto memrefType = op.getMemRefType(); |
| if (op.getType() != memrefType.getElementType()) |
| return op.emitOpError("result type must match element type of memref"); |
| |
| if (failed(verifyMemoryOpIndexing( |
| op.getOperation(), |
| op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()), |
| op.getMapOperands(), memrefType, |
| /*numIndexOperands=*/op.getNumOperands() - 1))) |
| return failure(); |
| |
| return success(); |
| } |
| |
| void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<SimplifyAffineOp<AffineLoadOp>>(context); |
| } |
| |
| OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) { |
| /// load(memrefcast) -> load |
| if (succeeded(foldMemRefCast(*this))) |
| return getResult(); |
| return OpFoldResult(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AffineStoreOp |
| //===----------------------------------------------------------------------===// |
| |
| void AffineStoreOp::build(OpBuilder &builder, OperationState &result, |
| Value valueToStore, Value memref, AffineMap map, |
| ValueRange mapOperands) { |
| assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info"); |
| result.addOperands(valueToStore); |
| result.addOperands(memref); |
| result.addOperands(mapOperands); |
| result.addAttribute(getMapAttrName(), AffineMapAttr::get(map)); |
| } |
| |
| // Use identity map. |
| void AffineStoreOp::build(OpBuilder &builder, OperationState &result, |
| Value valueToStore, Value memref, |
| ValueRange indices) { |
| auto memrefType = memref.getType().cast<MemRefType>(); |
| int64_t rank = memrefType.getRank(); |
| // Create identity map for memrefs with at least one dimension or () -> () |
| // for zero-dimensional memrefs. |
| auto map = |
| rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap(); |
| build(builder, result, valueToStore, memref, map, indices); |
| } |
| |
| static ParseResult parseAffineStoreOp(OpAsmParser &parser, |
| OperationState &result) { |
| auto indexTy = parser.getBuilder().getIndexType(); |
| |
| MemRefType type; |
| OpAsmParser::OperandType storeValueInfo; |
| OpAsmParser::OperandType memrefInfo; |
| AffineMapAttr mapAttr; |
| SmallVector<OpAsmParser::OperandType, 1> mapOperands; |
| return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() || |
| parser.parseOperand(memrefInfo) || |
| parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, |
| AffineStoreOp::getMapAttrName(), |
| result.attributes) || |
| parser.parseOptionalAttrDict(result.attributes) || |
| parser.parseColonType(type) || |
| parser.resolveOperand(storeValueInfo, type.getElementType(), |
| result.operands) || |
| parser.resolveOperand(memrefInfo, type, result.operands) || |
| parser.resolveOperands(mapOperands, indexTy, result.operands)); |
| } |
| |
| static void print(OpAsmPrinter &p, AffineStoreOp op) { |
| p << " " << op.getValueToStore(); |
| p << ", " << op.getMemRef() << '['; |
| if (AffineMapAttr mapAttr = |
| op->getAttrOfType<AffineMapAttr>(op.getMapAttrName())) |
| p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands()); |
| p << ']'; |
| p.printOptionalAttrDict(op->getAttrs(), |
| /*elidedAttrs=*/{op.getMapAttrName()}); |
| p << " : " << op.getMemRefType(); |
| } |
| |
| LogicalResult verify(AffineStoreOp op) { |
| // The value to store must have the same type as memref element type. |
| auto memrefType = op.getMemRefType(); |
| if (op.getValueToStore().getType() != memrefType.getElementType()) |
| return op.emitOpError( |
| "value to store must have the same type as memref element type"); |
| |
| if (failed(verifyMemoryOpIndexing( |
| op.getOperation(), |
| op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()), |
| op.getMapOperands(), memrefType, |
| /*numIndexOperands=*/op.getNumOperands() - 2))) |
| return failure(); |
| |
| return success(); |
| } |
| |
| void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<SimplifyAffineOp<AffineStoreOp>>(context); |
| } |
| |
| LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands, |
| SmallVectorImpl<OpFoldResult> &results) { |
| /// store(memrefcast) -> store |
| return foldMemRefCast(*this, getValueToStore()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AffineMinMaxOpBase |
| //===----------------------------------------------------------------------===// |
| |
| template <typename T> static LogicalResult verifyAffineMinMaxOp(T op) { |
| // Verify that operand count matches affine map dimension and symbol count. |
| if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols()) |
| return op.emitOpError( |
| "operand count and affine map dimension and symbol count must match"); |
| return success(); |
| } |
| |
| template <typename T> static void printAffineMinMaxOp(OpAsmPrinter &p, T op) { |
| p << ' ' << op->getAttr(T::getMapAttrName()); |
| auto operands = op.getOperands(); |
| unsigned numDims = op.map().getNumDims(); |
| p << '(' << operands.take_front(numDims) << ')'; |
| |
| if (operands.size() != numDims) |
| p << '[' << operands.drop_front(numDims) << ']'; |
| p.printOptionalAttrDict(op->getAttrs(), |
| /*elidedAttrs=*/{T::getMapAttrName()}); |
| } |
| |
| template <typename T> |
| static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, |
| OperationState &result) { |
| auto &builder = parser.getBuilder(); |
| auto indexType = builder.getIndexType(); |
| SmallVector<OpAsmParser::OperandType, 8> dimInfos; |
| SmallVector<OpAsmParser::OperandType, 8> symInfos; |
| AffineMapAttr mapAttr; |
| return failure( |
| parser.parseAttribute(mapAttr, T::getMapAttrName(), result.attributes) || |
| parser.parseOperandList(dimInfos, OpAsmParser::Delimiter::Paren) || |
| parser.parseOperandList(symInfos, |
| OpAsmParser::Delimiter::OptionalSquare) || |
| parser.parseOptionalAttrDict(result.attributes) || |
| parser.resolveOperands(dimInfos, indexType, result.operands) || |
| parser.resolveOperands(symInfos, indexType, result.operands) || |
| parser.addTypeToList(indexType, result.types)); |
| } |
| |
| /// Fold an affine min or max operation with the given operands. The operand |
| /// list may contain nulls, which are interpreted as the operand not being a |
| /// constant. |
| template <typename T> |
| static OpFoldResult foldMinMaxOp(T op, ArrayRef<Attribute> operands) { |
| static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value, |
| "expected affine min or max op"); |
| |
| // Fold the affine map. |
| // TODO: Fold more cases: |
| // min(some_affine, some_affine + constant, ...), etc. |
| SmallVector<int64_t, 2> results; |
| auto foldedMap = op.map().partialConstantFold(operands, &results); |
| |
| // If some of the map results are not constant, try changing the map in-place. |
| if (results.empty()) { |
| // If the map is the same, report that folding did not happen. |
| if (foldedMap == op.map()) |
| return {}; |
| op->setAttr("map", AffineMapAttr::get(foldedMap)); |
| return op.getResult(); |
| } |
| |
| // Otherwise, completely fold the op into a constant. |
| auto resultIt = std::is_same<T, AffineMinOp>::value |
| ? std::min_element(results.begin(), results.end()) |
| : std::max_element(results.begin(), results.end()); |
| if (resultIt == results.end()) |
| return {}; |
| return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt); |
| } |
| |
| /// Remove duplicated expressions in affine min/max ops. |
| template <typename T> |
| struct DeduplicateAffineMinMaxExpressions : public OpRewritePattern<T> { |
| using OpRewritePattern<T>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(T affineOp, |
| PatternRewriter &rewriter) const override { |
| AffineMap oldMap = affineOp.getAffineMap(); |
| |
| SmallVector<AffineExpr, 4> newExprs; |
| for (AffineExpr expr : oldMap.getResults()) { |
| // This is a linear scan over newExprs, but it should be fine given that |
| // we typically just have a few expressions per op. |
| if (!llvm::is_contained(newExprs, expr)) |
| newExprs.push_back(expr); |
| } |
| |
| if (newExprs.size() == oldMap.getNumResults()) |
| return failure(); |
| |
| auto newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), |
| newExprs, rewriter.getContext()); |
| rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands()); |
| |
| return success(); |
| } |
| }; |
| |
| /// Merge an affine min/max op to its consumers if its consumer is also an |
| /// affine min/max op. |
| /// |
| /// This pattern requires the producer affine min/max op is bound to a |
| /// dimension/symbol that is used as a standalone expression in the consumer |
| /// affine op's map. |
| /// |
| /// For example, a pattern like the following: |
| /// |
| /// %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1] |
| /// %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2] |
| /// |
| /// Can be turned into: |
| /// |
| /// %1 = affine.min affine_map< |
| /// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1] |
| template <typename T> struct MergeAffineMinMaxOp : public OpRewritePattern<T> { |
| using OpRewritePattern<T>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(T affineOp, |
| PatternRewriter &rewriter) const override { |
| AffineMap oldMap = affineOp.getAffineMap(); |
| ValueRange dimOperands = |
| affineOp.getMapOperands().take_front(oldMap.getNumDims()); |
| ValueRange symOperands = |
| affineOp.getMapOperands().take_back(oldMap.getNumSymbols()); |
| |
| auto newDimOperands = llvm::to_vector<8>(dimOperands); |
| auto newSymOperands = llvm::to_vector<8>(symOperands); |
| SmallVector<AffineExpr, 4> newExprs; |
| SmallVector<T, 4> producerOps; |
| |
| // Go over each expression to see whether it's a single dimension/symbol |
| // with the corresponding operand which is the result of another affine |
| // min/max op. If So it can be merged into this affine op. |
| for (AffineExpr expr : oldMap.getResults()) { |
| if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) { |
| Value symValue = symOperands[symExpr.getPosition()]; |
| if (auto producerOp = symValue.getDefiningOp<T>()) { |
| producerOps.push_back(producerOp); |
| continue; |
| } |
| } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) { |
| Value dimValue = dimOperands[dimExpr.getPosition()]; |
| if (auto producerOp = dimValue.getDefiningOp<T>()) { |
| producerOps.push_back(producerOp); |
| continue; |
| } |
| } |
| // For the above cases we will remove the expression by merging the |
| // producer affine min/max's affine expressions. Otherwise we need to |
| // keep the existing expression. |
| newExprs.push_back(expr); |
| } |
| |
| if (producerOps.empty()) |
| return failure(); |
| |
| unsigned numUsedDims = oldMap.getNumDims(); |
| unsigned numUsedSyms = oldMap.getNumSymbols(); |
| |
| // Now go over all producer affine ops and merge their expressions. |
| for (T producerOp : producerOps) { |
| AffineMap producerMap = producerOp.getAffineMap(); |
| unsigned numProducerDims = producerMap.getNumDims(); |
| unsigned numProducerSyms = producerMap.getNumSymbols(); |
| |
| // Collect all dimension/symbol values. |
| ValueRange dimValues = |
| producerOp.getMapOperands().take_front(numProducerDims); |
| ValueRange symValues = |
| producerOp.getMapOperands().take_back(numProducerSyms); |
| newDimOperands.append(dimValues.begin(), dimValues.end()); |
| newSymOperands.append(symValues.begin(), symValues.end()); |
| |
| // For expressions we need to shift to avoid overlap. |
| for (AffineExpr expr : producerMap.getResults()) { |
| newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims) |
| .shiftSymbols(numProducerSyms, numUsedSyms)); |
| } |
| |
| numUsedDims += numProducerDims; |
| numUsedSyms += numProducerSyms; |
| } |
| |
| auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs, |
| rewriter.getContext()); |
| auto newOperands = |
| llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands)); |
| rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands); |
| |
| return success(); |
| } |
| }; |
| |
| template <typename T> |
| struct CanonicalizeSingleResultAffineMinMaxOp : public OpRewritePattern<T> { |
| using OpRewritePattern<T>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(T affineOp, |
| PatternRewriter &rewriter) const override { |
| if (affineOp.map().getNumResults() != 1) |
| return failure(); |
| rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.map(), |
| affineOp.getOperands()); |
| return success(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // AffineMinOp |
| //===----------------------------------------------------------------------===// |
| // |
| // %0 = affine.min (d0) -> (1000, d0 + 512) (%i0) |
| // |
| |
| OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) { |
| return foldMinMaxOp(*this, operands); |
| } |
| |
| void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
| MLIRContext *context) { |
| patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMinOp>, |
| DeduplicateAffineMinMaxExpressions<AffineMinOp>, |
| MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>>( |
| context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AffineMaxOp |
| //===----------------------------------------------------------------------===// |
| // |
| // %0 = affine.max (d0) -> (1000, d0 + 512) (%i0) |
| // |
| |
| OpFoldResult AffineMaxOp::fold(ArrayRef<Attribute> operands) { |
| return foldMinMaxOp(*this, operands); |
| } |
| |
| void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns, |
| MLIRContext *context) { |
| patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMaxOp>, |
| DeduplicateAffineMinMaxExpressions<AffineMaxOp>, |
| MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>>( |
| context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AffinePrefetchOp |
| //===----------------------------------------------------------------------===// |
| |
| // |
| // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32> |
| // |
| static ParseResult parseAffinePrefetchOp(OpAsmParser &parser, |
| OperationState &result) { |
| auto &builder = parser.getBuilder(); |
| auto indexTy = builder.getIndexType(); |
| |
| MemRefType type; |
| OpAsmParser::OperandType memrefInfo; |
| IntegerAttr hintInfo; |
| auto i32Type = parser.getBuilder().getIntegerType(32); |
| StringRef readOrWrite, cacheType; |
| |
| AffineMapAttr mapAttr; |
| SmallVector<OpAsmParser::OperandType, 1> mapOperands; |
| if (parser.parseOperand(memrefInfo) || |
| parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, |
| AffinePrefetchOp::getMapAttrName(), |
| result.attributes) || |
| parser.parseComma() || parser.parseKeyword(&readOrWrite) || |
| parser.parseComma() || parser.parseKeyword("locality") || |
| parser.parseLess() || |
| parser.parseAttribute(hintInfo, i32Type, |
| AffinePrefetchOp::getLocalityHintAttrName(), |
| result.attributes) || |
| parser.parseGreater() || parser.parseComma() || |
| parser.parseKeyword(&cacheType) || |
| parser.parseOptionalAttrDict(result.attributes) || |
| parser.parseColonType(type) || |
| parser.resolveOperand(memrefInfo, type, result.operands) || |
| parser.resolveOperands(mapOperands, indexTy, result.operands)) |
| return failure(); |
| |
| if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) |
| return parser.emitError(parser.getNameLoc(), |
| "rw specifier has to be 'read' or 'write'"); |
| result.addAttribute( |
| AffinePrefetchOp::getIsWriteAttrName(), |
| parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); |
| |
| if (!cacheType.equals("data") && !cacheType.equals("instr")) |
| return parser.emitError(parser.getNameLoc(), |
| "cache type has to be 'data' or 'instr'"); |
| |
| result.addAttribute( |
| AffinePrefetchOp::getIsDataCacheAttrName(), |
| parser.getBuilder().getBoolAttr(cacheType.equals("data"))); |
| |
| return success(); |
| } |
| |
| static void print(OpAsmPrinter &p, AffinePrefetchOp op) { |
| p << " " << op.memref() << '['; |
| AffineMapAttr mapAttr = op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()); |
| if (mapAttr) { |
| SmallVector<Value, 2> operands(op.getMapOperands()); |
| p.printAffineMapOfSSAIds(mapAttr, operands); |
| } |
| p << ']' << ", " << (op.isWrite() ? "write" : "read") << ", " |
| << "locality<" << op.localityHint() << ">, " |
| << (op.isDataCache() ? "data" : "instr"); |
| p.printOptionalAttrDict( |
| op->getAttrs(), |
| /*elidedAttrs=*/{op.getMapAttrName(), op.getLocalityHintAttrName(), |
| op.getIsDataCacheAttrName(), op.getIsWriteAttrName()}); |
| p << " : " << op.getMemRefType(); |
| } |
| |
| static LogicalResult verify(AffinePrefetchOp op) { |
| auto mapAttr = op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()); |
| if (mapAttr) { |
| AffineMap map = mapAttr.getValue(); |
| if (map.getNumResults() != op.getMemRefType().getRank()) |
| return op.emitOpError("affine.prefetch affine map num results must equal" |
| " memref rank"); |
| if (map.getNumInputs() + 1 != op.getNumOperands()) |
| return op.emitOpError("too few operands"); |
| } else { |
| if (op.getNumOperands() != 1) |
| return op.emitOpError("too few operands"); |
| } |
| |
| Region *scope = getAffineScope(op); |
| for (auto idx : op.getMapOperands()) { |
| if (!isValidAffineIndexOperand(idx, scope)) |
| return op.emitOpError("index must be a dimension or symbol identifier"); |
| } |
| return success(); |
| } |
| |
| void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| // prefetch(memrefcast) -> prefetch |
| results.add<SimplifyAffineOp<AffinePrefetchOp>>(context); |
| } |
| |
| LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands, |
| SmallVectorImpl<OpFoldResult> &results) { |
| /// prefetch(memrefcast) -> prefetch |
| return foldMemRefCast(*this); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AffineParallelOp |
| //===----------------------------------------------------------------------===// |
| |
| void AffineParallelOp::build(OpBuilder &builder, OperationState &result, |
| TypeRange resultTypes, |
| ArrayRef<AtomicRMWKind> reductions, |
| ArrayRef<int64_t> ranges) { |
| SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0)); |
| auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) { |
| return builder.getConstantAffineMap(value); |
| })); |
| SmallVector<int64_t> steps(ranges.size(), 1); |
| build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs, |
| /*ubArgs=*/{}, steps); |
| } |
| |
| void AffineParallelOp::build(OpBuilder &builder, OperationState &result, |
| TypeRange resultTypes, |
| ArrayRef<AtomicRMWKind> reductions, |
| ArrayRef<AffineMap> lbMaps, ValueRange lbArgs, |
| ArrayRef<AffineMap> ubMaps, ValueRange ubArgs, |
| ArrayRef<int64_t> steps) { |
| assert(llvm::all_of(lbMaps, |
| [lbMaps](AffineMap m) { |
| return m.getNumDims() == lbMaps[0].getNumDims() && |
| m.getNumSymbols() == lbMaps[0].getNumSymbols(); |
| }) && |
| "expected all lower bounds maps to have the same number of dimensions " |
| "and symbols"); |
| assert(llvm::all_of(ubMaps, |
| [ubMaps](AffineMap m) { |
| return m.getNumDims() == ubMaps[0].getNumDims() && |
| m.getNumSymbols() == ubMaps[0].getNumSymbols(); |
| }) && |
| "expected all upper bounds maps to have the same number of dimensions " |
| "and symbols"); |
| assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) && |
| "expected lower bound maps to have as many inputs as lower bound " |
| "operands"); |
| assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) && |
| "expected upper bound maps to have as many inputs as upper bound " |
| "operands"); |
| |
| result.addTypes(resultTypes); |
| |
| // Convert the reductions to integer attributes. |
| SmallVector<Attribute, 4> reductionAttrs; |
| for (AtomicRMWKind reduction : reductions) |
| reductionAttrs.push_back( |
| builder.getI64IntegerAttr(static_cast<int64_t>(reduction))); |
| result.addAttribute(getReductionsAttrName(), |
| builder.getArrayAttr(reductionAttrs)); |
| |
| // Concatenates maps defined in the same input space (same dimensions and |
| // symbols), assumes there is at least one map. |
| auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps, |
| SmallVectorImpl<int32_t> &groups) { |
| if (maps.empty()) |
| return AffineMap::get(builder.getContext()); |
| SmallVector<AffineExpr> exprs; |
| groups.reserve(groups.size() + maps.size()); |
| exprs.reserve(maps.size()); |
| for (AffineMap m : maps) { |
| llvm::append_range(exprs, m.getResults()); |
| groups.push_back(m.getNumResults()); |
| } |
| return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs, |
| maps[0].getContext()); |
| }; |
| |
| // Set up the bounds. |
| SmallVector<int32_t> lbGroups, ubGroups; |
| AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups); |
| AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups); |
| result.addAttribute(getLowerBoundsMapAttrName(), AffineMapAttr::get(lbMap)); |
| result.addAttribute(getLowerBoundsGroupsAttrName(), |
| builder.getI32TensorAttr(lbGroups)); |
| result.addAttribute(getUpperBoundsMapAttrName(), AffineMapAttr::get(ubMap)); |
| result.addAttribute(getUpperBoundsGroupsAttrName(), |
| builder.getI32TensorAttr(ubGroups)); |
| result.addAttribute(getStepsAttrName(), builder.getI64ArrayAttr(steps)); |
| result.addOperands(lbArgs); |
| result.addOperands(ubArgs); |
| |
| // Create a region and a block for the body. |
| auto *bodyRegion = result.addRegion(); |
| auto *body = new Block(); |
| // Add all the block arguments. |
| for (unsigned i = 0, e = steps.size(); i < e; ++i) |
| body->addArgument(IndexType::get(builder.getContext())); |
| bodyRegion->push_back(body); |
| if (resultTypes.empty()) |
| ensureTerminator(*bodyRegion, builder, result.location); |
| } |
| |
| Region &AffineParallelOp::getLoopBody() { return region(); } |
| |
| bool AffineParallelOp::isDefinedOutsideOfLoop(Value value) { |
| return !region().isAncestor(value.getParentRegion()); |
| } |
| |
| LogicalResult AffineParallelOp::moveOutOfLoop(ArrayRef<Operation *> ops) { |
| for (Operation *op : ops) |
| op->moveBefore(*this); |
| return success(); |
| } |
| |
| unsigned AffineParallelOp::getNumDims() { return steps().size(); } |
| |
| AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() { |
| return getOperands().take_front(lowerBoundsMap().getNumInputs()); |
| } |
| |
| AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() { |
| return getOperands().drop_front(lowerBoundsMap().getNumInputs()); |
| } |
| |
| AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) { |
| auto values = lowerBoundsGroups().getValues<int32_t>(); |
| unsigned start = 0; |
| for (unsigned i = 0; i < pos; ++i) |
| start += values[i]; |
| return lowerBoundsMap().getSliceMap(start, values[pos]); |
| } |
| |
| AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) { |
| auto values = upperBoundsGroups().getValues<int32_t>(); |
| unsigned start = 0; |
| for (unsigned i = 0; i < pos; ++i) |
| start += values[i]; |
| return upperBoundsMap().getSliceMap(start, values[pos]); |
| } |
| |
| AffineValueMap AffineParallelOp::getLowerBoundsValueMap() { |
| return AffineValueMap(lowerBoundsMap(), getLowerBoundsOperands()); |
| } |
| |
| AffineValueMap AffineParallelOp::getUpperBoundsValueMap() { |
| return AffineValueMap(upperBoundsMap(), getUpperBoundsOperands()); |
| } |
| |
| Optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() { |
| if (hasMinMaxBounds()) |
| return llvm::None; |
| |
| // Try to convert all the ranges to constant expressions. |
| SmallVector<int64_t, 8> out; |
| AffineValueMap rangesValueMap; |
| AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(), |
| &rangesValueMap); |
| out.reserve(rangesValueMap.getNumResults()); |
| for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) { |
| auto expr = rangesValueMap.getResult(i); |
| auto cst = expr.dyn_cast<AffineConstantExpr>(); |
| if (!cst) |
| return llvm::None; |
| out.push_back(cst.getValue()); |
| } |
| return out; |
| } |
| |
| Block *AffineParallelOp::getBody() { return ®ion().front(); } |
| |
| OpBuilder AffineParallelOp::getBodyBuilder() { |
| return OpBuilder(getBody(), std::prev(getBody()->end())); |
| } |
| |
| void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) { |
| assert(lbOperands.size() == map.getNumInputs() && |
| "operands to map must match number of inputs"); |
| |
| auto ubOperands = getUpperBoundsOperands(); |
| |
| SmallVector<Value, 4> newOperands(lbOperands); |
| newOperands.append(ubOperands.begin(), ubOperands.end()); |
| (*this)->setOperands(newOperands); |
| |
| lowerBoundsMapAttr(AffineMapAttr::get(map)); |
| } |
| |
| void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) { |
| assert(ubOperands.size() == map.getNumInputs() && |
| "operands to map must match number of inputs"); |
| |
| SmallVector<Value, 4> newOperands(getLowerBoundsOperands()); |
| newOperands.append(ubOperands.begin(), ubOperands.end()); |
| (*this)->setOperands(newOperands); |
| |
| upperBoundsMapAttr(AffineMapAttr::get(map)); |
| } |
| |
| void AffineParallelOp::setLowerBoundsMap(AffineMap map) { |
| AffineMap lbMap = lowerBoundsMap(); |
| assert(lbMap.getNumDims() == map.getNumDims() && |
| lbMap.getNumSymbols() == map.getNumSymbols()); |
| (void)lbMap; |
| lowerBoundsMapAttr(AffineMapAttr::get(map)); |
| } |
| |
| void AffineParallelOp::setUpperBoundsMap(AffineMap map) { |
| AffineMap ubMap = upperBoundsMap(); |
| assert(ubMap.getNumDims() == map.getNumDims() && |
| ubMap.getNumSymbols() == map.getNumSymbols()); |
| (void)ubMap; |
| upperBoundsMapAttr(AffineMapAttr::get(map)); |
| } |
| |
| SmallVector<int64_t, 8> AffineParallelOp::getSteps() { |
| SmallVector<int64_t, 8> result; |
| for (Attribute attr : steps()) { |
| result.push_back(attr.cast<IntegerAttr>().getInt()); |
| } |
| return result; |
| } |
| |
| void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) { |
| stepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps)); |
| } |
| |
| static LogicalResult verify(AffineParallelOp op) { |
| auto numDims = op.getNumDims(); |
| if (op.lowerBoundsGroups().getNumElements() != numDims || |
| op.upperBoundsGroups().getNumElements() != numDims || |
| op.steps().size() != numDims || |
| op.getBody()->getNumArguments() != numDims) { |
| return op.emitOpError() |
| << "the number of region arguments (" |
| << op.getBody()->getNumArguments() |
| << ") and the number of map groups for lower (" |
| << op.lowerBoundsGroups().getNumElements() << ") and upper bound (" |
| << op.upperBoundsGroups().getNumElements() |
| << "), and the number of steps (" << op.steps().size() |
| << ") must all match"; |
| } |
| |
| unsigned expectedNumLBResults = 0; |
| for (APInt v : op.lowerBoundsGroups()) |
| expectedNumLBResults += v.getZExtValue(); |
| if (expectedNumLBResults != op.lowerBoundsMap().getNumResults()) |
| return op.emitOpError() << "expected lower bounds map to have " |
| << expectedNumLBResults << " results"; |
| unsigned expectedNumUBResults = 0; |
| for (APInt v : op.upperBoundsGroups()) |
| expectedNumUBResults += v.getZExtValue(); |
| if (expectedNumUBResults != op.upperBoundsMap().getNumResults()) |
| return op.emitOpError() << "expected upper bounds map to have " |
| << expectedNumUBResults << " results"; |
| |
| if (op.reductions().size() != op.getNumResults()) |
| return op.emitOpError("a reduction must be specified for each output"); |
| |
| // Verify reduction ops are all valid |
| for (Attribute attr : op.reductions()) { |
| auto intAttr = attr.dyn_cast<IntegerAttr>(); |
| if (!intAttr || !symbolizeAtomicRMWKind(intAttr.getInt())) |
| return op.emitOpError("invalid reduction attribute"); |
| } |
| |
| // Verify that the bound operands are valid dimension/symbols. |
| /// Lower bounds. |
| if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundsOperands(), |
| op.lowerBoundsMap().getNumDims()))) |
| return failure(); |
| /// Upper bounds. |
| if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundsOperands(), |
| op.upperBoundsMap().getNumDims()))) |
| return failure(); |
| return success(); |
| } |
| |
| LogicalResult AffineValueMap::canonicalize() { |
| SmallVector<Value, 4> newOperands{operands}; |
| auto newMap = getAffineMap(); |
| composeAffineMapAndOperands(&newMap, &newOperands); |
| if (newMap == getAffineMap() && newOperands == operands) |
| return failure(); |
| reset(newMap, newOperands); |
| return success(); |
| } |
| |
| /// Canonicalize the bounds of the given loop. |
| static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) { |
| AffineValueMap lb = op.getLowerBoundsValueMap(); |
| bool lbCanonicalized = succeeded(lb.canonicalize()); |
| |
| AffineValueMap ub = op.getUpperBoundsValueMap(); |
| bool ubCanonicalized = succeeded(ub.canonicalize()); |
| |
| // Any canonicalization change always leads to updated map(s). |
| if (!lbCanonicalized && !ubCanonicalized) |
| return failure(); |
| |
| if (lbCanonicalized) |
| op.setLowerBounds(lb.getOperands(), lb.getAffineMap()); |
| if (ubCanonicalized) |
| op.setUpperBounds(ub.getOperands(), ub.getAffineMap()); |
| |
| return success(); |
| } |
| |
| LogicalResult AffineParallelOp::fold(ArrayRef<Attribute> operands, |
| SmallVectorImpl<OpFoldResult> &results) { |
| return canonicalizeLoopBounds(*this); |
| } |
| |
| /// Prints a lower(upper) bound of an affine parallel loop with max(min) |
| /// conditions in it. `mapAttr` is a flat list of affine expressions and `group` |
| /// identifies which of the those expressions form max/min groups. `operands` |
| /// are the SSA values of dimensions and symbols and `keyword` is either "min" |
| /// or "max". |
| static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr, |
| DenseIntElementsAttr group, ValueRange operands, |
| StringRef keyword) { |
| AffineMap map = mapAttr.getValue(); |
| unsigned numDims = map.getNumDims(); |
| ValueRange dimOperands = operands.take_front(numDims); |
| ValueRange symOperands = operands.drop_front(numDims); |
| unsigned start = 0; |
| for (llvm::APInt groupSize : group) { |
| if (start != 0) |
| p << ", "; |
| |
| unsigned size = groupSize.getZExtValue(); |
| if (size == 1) { |
| p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands); |
| ++start; |
| } else { |
| p << keyword << '('; |
| AffineMap submap = map.getSliceMap(start, size); |
| p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands); |
| p << ')'; |
| start += size; |
| } |
| } |
| } |
| |
| static void print(OpAsmPrinter &p, AffineParallelOp op) { |
| p << " (" << op.getBody()->getArguments() << ") = ("; |
| printMinMaxBound(p, op.lowerBoundsMapAttr(), op.lowerBoundsGroupsAttr(), |
| op.getLowerBoundsOperands(), "max"); |
| p << ") to ("; |
| printMinMaxBound(p, op.upperBoundsMapAttr(), op.upperBoundsGroupsAttr(), |
| op.getUpperBoundsOperands(), "min"); |
| p << ')'; |
| SmallVector<int64_t, 8> steps = op.getSteps(); |
| bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; }); |
| if (!elideSteps) { |
| p << " step ("; |
| llvm::interleaveComma(steps, p); |
| p << ')'; |
| } |
| if (op.getNumResults()) { |
| p << " reduce ("; |
| llvm::interleaveComma(op.reductions(), p, [&](auto &attr) { |
| AtomicRMWKind sym = |
| *symbolizeAtomicRMWKind(attr.template cast<IntegerAttr>().getInt()); |
| p << "\"" << stringifyAtomicRMWKind(sym) << "\""; |
| }); |
| p << ") -> (" << op.getResultTypes() << ")"; |
| } |
| |
| p.printRegion(op.region(), /*printEntryBlockArgs=*/false, |
| /*printBlockTerminators=*/op.getNumResults()); |
| p.printOptionalAttrDict( |
| op->getAttrs(), |
| /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrName(), |
| AffineParallelOp::getLowerBoundsMapAttrName(), |
| AffineParallelOp::getLowerBoundsGroupsAttrName(), |
| AffineParallelOp::getUpperBoundsMapAttrName(), |
| AffineParallelOp::getUpperBoundsGroupsAttrName(), |
| AffineParallelOp::getStepsAttrName()}); |
| } |
| |
| /// Given a list of lists of parsed operands, populates `uniqueOperands` with |
| /// unique operands. Also populates `replacements with affine expressions of |
| /// `kind` that can be used to update affine maps previously accepting a |
| /// `operands` to accept `uniqueOperands` instead. |
| static void deduplicateAndResolveOperands( |
| OpAsmParser &parser, |
| ArrayRef<SmallVector<OpAsmParser::OperandType>> operands, |
| SmallVectorImpl<Value> &uniqueOperands, |
| SmallVectorImpl<AffineExpr> &replacements, AffineExprKind kind) { |
| assert((kind == AffineExprKind::DimId || kind == AffineExprKind::SymbolId) && |
| "expected operands to be dim or symbol expression"); |
| |
| Type indexType = parser.getBuilder().getIndexType(); |
| for (const auto &list : operands) { |
| SmallVector<Value> valueOperands; |
| parser.resolveOperands(list, indexType, valueOperands); |
| for (Value operand : valueOperands) { |
| unsigned pos = std::distance(uniqueOperands.begin(), |
| llvm::find(uniqueOperands, operand)); |
| if (pos == uniqueOperands.size()) |
| uniqueOperands.push_back(operand); |
| replacements.push_back( |
| kind == AffineExprKind::DimId |
| ? getAffineDimExpr(pos, parser.getContext()) |
| : getAffineSymbolExpr(pos, parser.getContext())); |
| } |
| } |
| } |
| |
namespace {
/// Selects whether a parsed bound is a `min` group or a `max` group.
enum class MinMaxKind {
  Min,
  Max,
};
} // namespace
| |
| /// Parses an affine map that can contain a min/max for groups of its results, |
| /// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates |
| /// `result` attributes with the map (flat list of expressions) and the grouping |
| /// (list of integers that specify how many expressions to put into each |
| /// min/max) attributes. Deduplicates repeated operands. |
| /// |
| /// parallel-bound ::= `(` parallel-group-list `)` |
| /// parallel-group-list ::= parallel-group (`,` parallel-group-list)? |
| /// parallel-group ::= simple-group | min-max-group |
| /// simple-group ::= expr-of-ssa-ids |
| /// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)` |
| /// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-ids-list)? |
| /// |
| /// Examples: |
| /// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6)) |
| /// (%0, max(%1 - 2 * %2)) |
| static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser, |
| OperationState &result, |
| MinMaxKind kind) { |
| constexpr llvm::StringLiteral tmpAttrName = "__pseudo_bound_map"; |
| |
| StringRef mapName = kind == MinMaxKind::Min |
| ? AffineParallelOp::getUpperBoundsMapAttrName() |
| : AffineParallelOp::getLowerBoundsMapAttrName(); |
| StringRef groupsName = kind == MinMaxKind::Min |
| ? AffineParallelOp::getUpperBoundsGroupsAttrName() |
| : AffineParallelOp::getLowerBoundsGroupsAttrName(); |
| |
| if (failed(parser.parseLParen())) |
| return failure(); |
| |
| if (succeeded(parser.parseOptionalRParen())) { |
| result.addAttribute( |
| mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap())); |
| result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({})); |
| return success(); |
| } |
| |
| SmallVector<AffineExpr> flatExprs; |
| SmallVector<SmallVector<OpAsmParser::OperandType>> flatDimOperands; |
| SmallVector<SmallVector<OpAsmParser::OperandType>> flatSymOperands; |
| SmallVector<int32_t> numMapsPerGroup; |
| SmallVector<OpAsmParser::OperandType> mapOperands; |
| do { |
| if (succeeded(parser.parseOptionalKeyword( |
| kind == MinMaxKind::Min ? "min" : "max"))) { |
| mapOperands.clear(); |
| AffineMapAttr map; |
| if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrName, |
| result.attributes, |
| OpAsmParser::Delimiter::Paren))) |
| return failure(); |
| result.attributes.erase(tmpAttrName); |
| llvm::append_range(flatExprs, map.getValue().getResults()); |
| auto operandsRef = llvm::makeArrayRef(mapOperands); |
| auto dimsRef = operandsRef.take_front(map.getValue().getNumDims()); |
| SmallVector<OpAsmParser::OperandType> dims(dimsRef.begin(), |
| dimsRef.end()); |
| auto symsRef = operandsRef.drop_front(map.getValue().getNumDims()); |
| SmallVector<OpAsmParser::OperandType> syms(symsRef.begin(), |
| symsRef.end()); |
| flatDimOperands.append(map.getValue().getNumResults(), dims); |
| flatSymOperands.append(map.getValue().getNumResults(), syms); |
| numMapsPerGroup.push_back(map.getValue().getNumResults()); |
| } else { |
| if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(), |
| flatSymOperands.emplace_back(), |
| flatExprs.emplace_back()))) |
| return failure(); |
| numMapsPerGroup.push_back(1); |
| } |
| } while (succeeded(parser.parseOptionalComma())); |
| |
| if (failed(parser.parseRParen())) |
| return failure(); |
| |
| unsigned totalNumDims = 0; |
| unsigned totalNumSyms = 0; |
| for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) { |
| unsigned numDims = flatDimOperands[i].size(); |
| unsigned numSyms = flatSymOperands[i].size(); |
| flatExprs[i] = flatExprs[i] |
| .shiftDims(numDims, totalNumDims) |
| .shiftSymbols(numSyms, totalNumSyms); |
| totalNumDims += numDims; |
| totalNumSyms += numSyms; |
| } |
| |
| // Deduplicate map operands. |
| SmallVector<Value> dimOperands, symOperands; |
| SmallVector<AffineExpr> dimRplacements, symRepacements; |
| deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands, |
| dimRplacements, AffineExprKind::DimId); |
| deduplicateAndResolveOperands(parser, flatSymOperands, symOperands, |
| symRepacements, AffineExprKind::SymbolId); |
| |
| result.operands.append(dimOperands.begin(), dimOperands.end()); |
| result.operands.append(symOperands.begin(), symOperands.end()); |
| |
| Builder &builder = parser.getBuilder(); |
| auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs, |
| parser.getContext()); |
| flatMap = flatMap.replaceDimsAndSymbols( |
| dimRplacements, symRepacements, dimOperands.size(), symOperands.size()); |
| |
| result.addAttribute(mapName, AffineMapAttr::get(flatMap)); |
| result.addAttribute(groupsName, builder.getI32TensorAttr(numMapsPerGroup)); |
| return success(); |
| } |
| |
| // |
| // operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound |
| // `to` parallel-bound steps? region attr-dict? |
| // steps ::= `steps` `(` integer-literals `)` |
| // |
| static ParseResult parseAffineParallelOp(OpAsmParser &parser, |
| OperationState &result) { |
| auto &builder = parser.getBuilder(); |
| auto indexType = builder.getIndexType(); |
| SmallVector<OpAsmParser::OperandType, 4> ivs; |
| if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, |
| OpAsmParser::Delimiter::Paren) || |
| parser.parseEqual() || |
| parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) || |
| parser.parseKeyword("to") || |
| parseAffineMapWithMinMax(parser, result, MinMaxKind::Min)) |
| return failure(); |
| |
| AffineMapAttr stepsMapAttr; |
| NamedAttrList stepsAttrs; |
| SmallVector<OpAsmParser::OperandType, 4> stepsMapOperands; |
| if (failed(parser.parseOptionalKeyword("step"))) { |
| SmallVector<int64_t, 4> steps(ivs.size(), 1); |
| result.addAttribute(AffineParallelOp::getStepsAttrName(), |
| builder.getI64ArrayAttr(steps)); |
| } else { |
| if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr, |
| AffineParallelOp::getStepsAttrName(), |
| stepsAttrs, |
| OpAsmParser::Delimiter::Paren)) |
| return failure(); |
| |
| // Convert steps from an AffineMap into an I64ArrayAttr. |
| SmallVector<int64_t, 4> steps; |
| auto stepsMap = stepsMapAttr.getValue(); |
| for (const auto &result : stepsMap.getResults()) { |
| auto constExpr = result.dyn_cast<AffineConstantExpr>(); |
| if (!constExpr) |
| return parser.emitError(parser.getNameLoc(), |
| "steps must be constant integers"); |
| steps.push_back(constExpr.getValue()); |
| } |
| result.addAttribute(AffineParallelOp::getStepsAttrName(), |
| builder.getI64ArrayAttr(steps)); |
| } |
| |
| // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the |
| // quoted strings are a member of the enum AtomicRMWKind. |
| SmallVector<Attribute, 4> reductions; |
| if (succeeded(parser.parseOptionalKeyword("reduce"))) { |
| if (parser.parseLParen()) |
| return failure(); |
| do { |
| // Parse a single quoted string via the attribute parsing, and then |
| // verify it is a member of the enum and convert to it's integer |
| // representation. |
| StringAttr attrVal; |
| NamedAttrList attrStorage; |
| auto loc = parser.getCurrentLocation(); |
| if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce", |
| attrStorage)) |
| return failure(); |
| llvm::Optional<AtomicRMWKind> reduction = |
| symbolizeAtomicRMWKind(attrVal.getValue()); |
| if (!reduction) |
| return parser.emitError(loc, "invalid reduction value: ") << attrVal; |
| reductions.push_back(builder.getI64IntegerAttr( |
| static_cast<int64_t>(reduction.getValue()))); |
| // While we keep getting commas, keep parsing. |
| } while (succeeded(parser.parseOptionalComma())); |
| if (parser.parseRParen()) |
| return failure(); |
| } |
| result.addAttribute(AffineParallelOp::getReductionsAttrName(), |
| builder.getArrayAttr(reductions)); |
| |
| // Parse return types of reductions (if any) |
| if (parser.parseOptionalArrowTypeList(result.types)) |
| return failure(); |
| |
| // Now parse the body. |
| Region *body = result.addRegion(); |
| SmallVector<Type, 4> types(ivs.size(), indexType); |
| if (parser.parseRegion(*body, ivs, types) || |
| parser.parseOptionalAttrDict(result.attributes)) |
| return failure(); |
| |
| // Add a terminator if none was parsed. |
| AffineParallelOp::ensureTerminator(*body, builder, result.location); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AffineYieldOp |
| //===----------------------------------------------------------------------===// |
| |
| static LogicalResult verify(AffineYieldOp op) { |
| auto *parentOp = op->getParentOp(); |
| auto results = parentOp->getResults(); |
| auto operands = op.getOperands(); |
| |
| if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp)) |
| return op.emitOpError() << "only terminates affine.if/for/parallel regions"; |
| if (parentOp->getNumResults() != op.getNumOperands()) |
| return op.emitOpError() << "parent of yield must have same number of " |
| "results as the yield operands"; |
| for (auto it : llvm::zip(results, operands)) { |
| if (std::get<0>(it).getType() != std::get<1>(it).getType()) |
| return op.emitOpError() |
| << "types mismatch between yield op and its parent"; |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AffineVectorLoadOp |
| //===----------------------------------------------------------------------===// |
| |
| void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result, |
| VectorType resultType, AffineMap map, |
| ValueRange operands) { |
| assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands"); |
| result.addOperands(operands); |
| if (map) |
| result.addAttribute(getMapAttrName(), AffineMapAttr::get(map)); |
| result.types.push_back(resultType); |
| } |
| |
| void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result, |
| VectorType resultType, Value memref, |
| AffineMap map, ValueRange mapOperands) { |
| assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info"); |
| result.addOperands(memref); |
| result.addOperands(mapOperands); |
| result.addAttribute(getMapAttrName(), AffineMapAttr::get(map)); |
| result.types.push_back(resultType); |
| } |
| |
| void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result, |
| VectorType resultType, Value memref, |
| ValueRange indices) { |
| auto memrefType = memref.getType().cast<MemRefType>(); |
| int64_t rank = memrefType.getRank(); |
| // Create identity map for memrefs with at least one dimension or () -> () |
| // for zero-dimensional memrefs. |
| auto map = |
| rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap(); |
| build(builder, result, resultType, memref, map, indices); |
| } |
| |
| void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) { |
| results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context); |
| } |
| |
| static ParseResult parseAffineVectorLoadOp(OpAsmParser &parser, |
| OperationState &result) { |
| auto &builder = parser.getBuilder(); |
| auto indexTy = builder.getIndexType(); |
| |
| MemRefType memrefType; |
| VectorType resultType; |
| OpAsmParser::OperandType memrefInfo; |
| AffineMapAttr mapAttr; |
| SmallVector<OpAsmParser::OperandType, 1> mapOperands; |
| return failure( |
| parser.parseOperand(memrefInfo) || |
| parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, |
| AffineVectorLoadOp::getMapAttrName(), |
| result.attributes) || |
| parser.parseOptionalAttrDict(result.attributes) || |
| parser.parseColonType(memrefType) || parser.parseComma() || |
| parser.parseType(resultType) || |
| parser.resolveOperand(memrefInfo, memrefType, result.operands) || |
| parser.resolveOperands(mapOperands, indexTy, result.operands) || |
| parser.addTypeToList(resultType, result.types)); |
| } |
| |
| static void print(OpAsmPrinter &p, AffineVectorLoadOp op) { |
| p << " " << op.getMemRef() << '['; |
| if (AffineMapAttr mapAttr = |
| op->getAttrOfType<AffineMapAttr>(op.getMapAttrName())) |
| p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands()); |
| p << ']'; |
| p.printOptionalAttrDict(op->getAttrs(), |
| /*elidedAttrs=*/{op.getMapAttrName()}); |
| p << " : " << op.getMemRefType() << ", " << op.getType(); |
| } |
| |
| /// Verify common invariants of affine.vector_load and affine.vector_store. |
| static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, |
| VectorType vectorType) { |
| // Check that memref and vector element types match. |
| if (memrefType.getElementType() != vectorType.getElementType()) |
| return op->emitOpError( |
| "requires memref and vector types of the same elemental type"); |
| return success(); |
| } |
| |
| static LogicalResult verify(AffineVectorLoadOp op) { |
| MemRefType memrefType = op.getMemRefType(); |
| if (failed(verifyMemoryOpIndexing( |
| op.getOperation(), |
| op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()), |
| op.getMapOperands(), memrefType, |
| /*numIndexOperands=*/op.getNumOperands() - 1))) |
| return failure(); |
| |
| if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType, |
| op.getVectorType()))) |
| return failure(); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // AffineVectorStoreOp |
| //===----------------------------------------------------------------------===// |
| |
| void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result, |
| Value valueToStore, Value memref, AffineMap map, |
| ValueRange mapOperands) { |
| assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info"); |
| result.addOperands(valueToStore); |
| result.addOperands(memref); |
| result.addOperands(mapOperands); |
| result.addAttribute(getMapAttrName(), AffineMapAttr::get(map)); |
| } |
| |
| // Use identity map. |
| void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result, |
| Value valueToStore, Value memref, |
| ValueRange indices) { |
| auto memrefType = memref.getType().cast<MemRefType>(); |
| int64_t rank = memrefType.getRank(); |
| // Create identity map for memrefs with at least one dimension or () -> () |
| // for zero-dimensional memrefs. |
| auto map = |
| rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap(); |
| build(builder, result, valueToStore, memref, map, indices); |
| } |
| void AffineVectorStoreOp::getCanonicalizationPatterns( |
| RewritePatternSet &results, MLIRContext *context) { |
| results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context); |
| } |
| |
| static ParseResult parseAffineVectorStoreOp(OpAsmParser &parser, |
| OperationState &result) { |
| auto indexTy = parser.getBuilder().getIndexType(); |
| |
| MemRefType memrefType; |
| VectorType resultType; |
| OpAsmParser::OperandType storeValueInfo; |
| OpAsmParser::OperandType memrefInfo; |
| AffineMapAttr mapAttr; |
| SmallVector<OpAsmParser::OperandType, 1> mapOperands; |
| return failure( |
| parser.parseOperand(storeValueInfo) || parser.parseComma() || |
| parser.parseOperand(memrefInfo) || |
| parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, |
| AffineVectorStoreOp::getMapAttrName(), |
| result.attributes) || |
| parser.parseOptionalAttrDict(result.attributes) || |
| parser.parseColonType(memrefType) || parser.parseComma() || |
| parser.parseType(resultType) || |
| parser.resolveOperand(storeValueInfo, resultType, result.operands) || |
| parser.resolveOperand(memrefInfo, memrefType, result.operands) || |
| parser.resolveOperands(mapOperands, indexTy, result.operands)); |
| } |
| |
| static void print(OpAsmPrinter &p, AffineVectorStoreOp op) { |
| p << " " << op.getValueToStore(); |
| p << ", " << op.getMemRef() << '['; |
| if (AffineMapAttr mapAttr = |
| op->getAttrOfType<AffineMapAttr>(op.getMapAttrName())) |
| p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands()); |
| p << ']'; |
| p.printOptionalAttrDict(op->getAttrs(), |
| /*elidedAttrs=*/{op.getMapAttrName()}); |
| p << " : " << op.getMemRefType() << ", " << op.getValueToStore().getType(); |
| } |
| |
| static LogicalResult verify(AffineVectorStoreOp op) { |
| MemRefType memrefType = op.getMemRefType(); |
| if (failed(verifyMemoryOpIndexing( |
| op.getOperation(), |
| op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()), |
| op.getMapOperands(), memrefType, |
| /*numIndexOperands=*/op.getNumOperands() - 2))) |
| return failure(); |
| |
| if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType, |
| op.getVectorType()))) |
| return failure(); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // TableGen'd op method definitions |
| //===----------------------------------------------------------------------===// |
| |
| #define GET_OP_CLASSES |
| #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc" |