| //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" |
| |
| #include "mlir/AsmParser/AsmParser.h" |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
| #include "mlir/Dialect/PDL/IR/PDL.h" |
| #include "mlir/Dialect/PDL/IR/PDLTypes.h" |
| #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" |
| #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
| #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" |
| #include "mlir/Interfaces/TilingInterface.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| #include "llvm/ADT/StringSet.h" |
| #include "llvm/Support/Debug.h" |
| |
| using namespace mlir; |
| using namespace mlir::linalg; |
| using namespace mlir::transform; |
| |
| #define DEBUG_TYPE "linalg-transforms" |
| |
| /// Extracts a vector of unsigned from an array attribute. Asserts if the |
| /// attribute contains values other than intergers. May truncate. |
| static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) { |
| SmallVector<unsigned> result; |
| result.reserve(attr.size()); |
| for (APInt value : attr.getAsValueRange<IntegerAttr>()) |
| result.push_back(value.getZExtValue()); |
| return result; |
| } |
| |
| /// Extracts a vector of int64_t from an array attribute. Asserts if the |
| /// attribute contains values other than integers. |
| static SmallVector<int64_t> extractI64Array(ArrayAttr attr) { |
| SmallVector<int64_t> result; |
| result.reserve(attr.size()); |
| for (APInt value : attr.getAsValueRange<IntegerAttr>()) |
| result.push_back(value.getSExtValue()); |
| return result; |
| } |
| |
| namespace { |
| /// A simple pattern rewriter that implements no special logic. |
| class SimpleRewriter : public PatternRewriter { |
| public: |
| SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {} |
| }; |
| } // namespace |
| |
| /// Attempts to apply the pattern specified as template argument to the given |
| /// operation. The pattern is expected to have a `returningMatchAndRewrite` |
| /// function that returns the "main" result or failure. Returns failure if the |
| /// pattern failed to apply. Extra arguments are forwarded to the pattern |
| /// constructor. |
| template <typename PatternTy, typename... Args> |
| static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) { |
| // Check if the given operation has the type expected by the pattern. |
| using OpTy = typename llvm::function_traits< |
| decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>; |
| auto op = dyn_cast<OpTy>(operation); |
| if (!op) |
| return failure(); |
| |
| // Apply the pattern directly to the op. |
| PatternTy pattern(operation->getContext(), std::forward<Args>(args)...); |
| SimpleRewriter rewriter(operation->getContext()); |
| rewriter.setInsertionPoint(operation); |
| auto result = pattern.returningMatchAndRewrite(op, rewriter); |
| if (failed(result)) |
| return failure(); |
| return cast<LinalgOp>(result->getOperation()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // DecomposeOp |
| //===----------------------------------------------------------------------===// |
| |
| DiagnosedSilenceableFailure |
| transform::DecomposeOp::applyToOne(linalg::LinalgOp target, |
| SmallVectorImpl<Operation *> &results, |
| transform::TransformState &state) { |
| FailureOr<LinalgOp> windowedNhwc = |
| tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp, |
| Conv1DNwcWcfOp>>(target); |
| if (succeeded(windowedNhwc)) { |
| results.push_back(*windowedNhwc); |
| return DiagnosedSilenceableFailure(success()); |
| } |
| FailureOr<LinalgOp> windowedNchw = |
| tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp, |
| Conv1DNcwFcwOp>>(target); |
| if (succeeded(windowedNchw)) { |
| results.push_back(*windowedNchw); |
| return DiagnosedSilenceableFailure(success()); |
| } |
| FailureOr<LinalgOp> depthwise = |
| tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target); |
| if (succeeded(depthwise)) { |
| results.push_back(*depthwise); |
| return DiagnosedSilenceableFailure(success()); |
| } |
| results.assign(1, nullptr); |
| return emitDefaultSilenceableFailure(target); |
| } |
| //===----------------------------------------------------------------------===// |
| // FuseOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Apply a tiling transformation to all payload ops and store both the |
| /// tiled operation as well as the created tile loops. |
| static LogicalResult applyTilingToAll( |
| Operation *transformOp, ArrayRef<Operation *> payloadOps, unsigned numLoops, |
| transform::TransformResults &transformResults, |
| function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)> |
| applyFn) { |
| SmallVector<Operation *> tiledLinalgOps; |
| SmallVector<SmallVector<Operation *>> loopOps(numLoops); |
| for (unsigned int i = 0; i < numLoops; ++i) |
| loopOps[i].reserve(payloadOps.size()); |
| |
| for (Operation *target : payloadOps) { |
| auto tilingInterfaceOp = dyn_cast<TilingInterface>(target); |
| if (!tilingInterfaceOp) |
| return transformOp->emitError("only TilingInterface ops are supported"); |
| |
| SimpleRewriter rewriter(target->getContext()); |
| rewriter.setInsertionPoint(target); |
| FailureOr<scf::SCFTileAndFuseResult> tiledResults = |
| applyFn(tilingInterfaceOp); |
| if (failed(tiledResults)) |
| return failure(); |
| |
| // Perform the replacement of tiled and fused values. |
| SmallVector<Operation *> opsToReplace{target}; |
| llvm::append_range(opsToReplace, tiledResults->fusedProducers); |
| for (Operation *toReplace : opsToReplace) { |
| SmallVector<Value> replacements; |
| replacements.reserve(toReplace->getNumResults()); |
| for (OpResult res : toReplace->getResults()) { |
| auto it = tiledResults->replacements.find(res); |
| if (it == tiledResults->replacements.end()) |
| replacements.push_back(res); |
| else |
| replacements.push_back(it->getSecond()); |
| } |
| rewriter.replaceOp(toReplace, replacements); |
| } |
| |
| // Report back the relevant handles to the transform op. |
| tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front()); |
| assert(tiledResults->loops.size() == numLoops && |
| "Mismatched number of loops, tile and fuse transform should have " |
| "failed"); |
| for (unsigned int i = 0; i < numLoops; ++i) |
| loopOps[i].push_back(tiledResults->loops[i]); |
| } |
| |
| transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); |
| for (unsigned int i = 0; i < numLoops; ++i) |
| transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); |
| |
| return success(); |
| } |
| |
| /// Parse a tiling-like operation that returns the tiled op as well as the |
| /// created tile loops. The function counts the non-zero tile sizes to compute |
| /// the number of results. |
| static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result, |
| StringRef sizesAttrName) { |
| OpAsmParser::UnresolvedOperand targetOperand; |
| SMLoc opLoc = parser.getCurrentLocation(); |
| if (parser.parseOperand(targetOperand) || |
| parser.parseOptionalAttrDict(result.attributes)) |
| return failure(); |
| Attribute sizesAttr = result.attributes.get(sizesAttrName); |
| if (!sizesAttr) |
| return parser.emitError(opLoc) |
| << "expected '" << sizesAttrName << "' attribute"; |
| auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>(); |
| if (!sizesArrayAttr) |
| return parser.emitError(opLoc) |
| << "'" << sizesAttrName << "' attribute must be an array"; |
| Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>(); |
| size_t numExpectedLoops = |
| sizesArrayAttr.size() - |
| llvm::count(extractFromI64ArrayAttr(sizesArrayAttr), 0); |
| result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType)); |
| if (parser.resolveOperand(targetOperand, pdlOpType, result.operands)) |
| return failure(); |
| return success(); |
| } |
| |
| DiagnosedSilenceableFailure |
| transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, |
| mlir::transform::TransformState &state) { |
| SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getTileSizes()); |
| SmallVector<int64_t> tileInterchange = |
| extractFromI64ArrayAttr(getTileInterchange()); |
| |
| scf::SCFTilingOptions tilingOptions; |
| tilingOptions.interchangeVector = tileInterchange; |
| tilingOptions = tilingOptions.setTileSizes(tileSizes); |
| scf::SCFTileAndFuseOptions tileAndFuseOptions; |
| tileAndFuseOptions.tilingOptions = tilingOptions; |
| LogicalResult result = applyTilingToAll( |
| getOperation(), state.getPayloadOps(getTarget()), |
| tileSizes.size() - llvm::count(tileSizes, 0), transformResults, |
| [&](TilingInterface tilingInterfaceOp) |
| -> FailureOr<scf::SCFTileAndFuseResult> { |
| SimpleRewriter rewriter(getContext()); |
| return tileConsumerAndFuseProducerGreedilyUsingSCFForOp( |
| rewriter, tilingInterfaceOp, tileAndFuseOptions); |
| }); |
| return DiagnosedSilenceableFailure(result); |
| } |
| |
| ParseResult transform::FuseOp::parse(OpAsmParser &parser, |
| OperationState &result) { |
| return parseTileLikeOp( |
| parser, result, |
| transform::FuseOp::getTileSizesAttrName(result.name).getValue()); |
| } |
| |
| void transform::FuseOp::print(OpAsmPrinter &p) { |
| p << ' '; |
| p << getTarget(); |
| p.printOptionalAttrDict((*this)->getAttrs()); |
| } |
| |
| LogicalResult transform::FuseOp::verify() { |
| SmallVector<int64_t> permutation = |
| extractFromI64ArrayAttr(getTileInterchange()); |
| auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size())); |
| if (!std::is_permutation(sequence.begin(), sequence.end(), |
| permutation.begin(), permutation.end())) { |
| return emitOpError() << "expects interchange to be a permutation, found " |
| << getTileInterchange(); |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // FuseIntoContainingOp |
| //===----------------------------------------------------------------------===// |
| |
| void transform::FuseIntoContainingOp::build(OpBuilder &builder, |
| OperationState &result, |
| Value producerOp, |
| Value containingOp) { |
| result.addOperands({producerOp, containingOp}); |
| result.addTypes(pdl::OperationType::get(builder.getContext())); |
| } |
| |
| /// Find the first "extract" user of `producerOp` and tile it right before its |
| /// use. The tiled op is fused under the `containingOp`. |
| /// Return this fused op on success or nullptr if anything fails. |
| static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter, |
| Diagnostic &diag, |
| Operation *producerOp, |
| Operation *containingOp) { |
| LLVM_DEBUG(llvm::dbgs() << "Try to fuse a direct extract use\n"); |
| auto tileableProducer = dyn_cast<TilingInterface>(producerOp); |
| if (!tileableProducer) { |
| diag.attachNote(producerOp->getLoc()) |
| << "producer is not a TileableInterface: " << *producerOp; |
| return nullptr; |
| } |
| |
| // Search the producer slices accessed within the containing operation. |
| // TODO: Generalize to more extract/insert/parallel_insert triples, maybe |
| // evolve into an interface. |
| auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) { |
| auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user); |
| return sliceOp && containingOp->isProperAncestor(sliceOp); |
| }); |
| |
| // Find a fusion opportunity. |
| if (it == tileableProducer->getUsers().end()) { |
| diag.attachNote(tileableProducer->getLoc()) |
| << "could not find fusion opportunity for: " << *tileableProducer; |
| return nullptr; |
| } |
| auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it); |
| |
| // Try to fuse the producer in-place. |
| OpBuilder::InsertionGuard guard(rewriter); |
| rewriter.setInsertionPoint(sliceOpToTile); |
| |
| // Tile the producer. |
| int64_t resultNumber = |
| sliceOpToTile.getSource().cast<OpResult>().getResultNumber(); |
| LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n"); |
| |
| FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue( |
| rewriter, resultNumber, sliceOpToTile.getMixedOffsets(), |
| sliceOpToTile.getMixedSizes()); |
| if (failed(tiledProducer)) { |
| diag.attachNote(tileableProducer->getLoc()) |
| << "failed to tile producer op: " << *tileableProducer; |
| return nullptr; |
| } |
| LLVM_DEBUG(llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n"); |
| |
| // Replace the extract op. |
| Operation *fusedOp = tiledProducer->getDefiningOp(); |
| rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(resultNumber)); |
| return fusedOp; |
| } |
| |
| /// First, find the first "scf::ForeachThreadOp" user of `producerOp` and ensure |
| /// it is exactly the `containingOp`, otherwise bail. |
| /// Then, find the first "extract" user of the tied block argument and tile it |
| /// right before its "extract" use. The tiled op is fused under the |
| /// `containingOp`. |
| /// Return this fused op on success or nullptr if anything fails. |
| static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( |
| RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, |
| Operation *containingOp) { |
| LLVM_DEBUG( |
| llvm::dbgs() << "Try to fuse an extract use through block argument\n"); |
| |
| auto tileableProducer = dyn_cast<TilingInterface>(producerOp); |
| if (!tileableProducer) { |
| diag.attachNote(producerOp->getLoc()) |
| << "producer is not a TileableInterface: " << *producerOp; |
| return nullptr; |
| } |
| |
| // Search the first use by a "scf::ForeachThreadOp" user. |
| scf::ForeachThreadOp foreachThreadOp; |
| auto itProducerUses = |
| llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) { |
| foreachThreadOp = dyn_cast<scf::ForeachThreadOp>(use.getOwner()); |
| return foreachThreadOp; |
| }); |
| // If it's not from the containing op, return. |
| if (!foreachThreadOp || foreachThreadOp != containingOp) { |
| diag.attachNote(tileableProducer->getLoc()) |
| << "could not find a use by the containing op: " << *tileableProducer; |
| return nullptr; |
| } |
| |
| // Search the producer slices accessed within the containing |
| // operation. |
| // TODO: Generalize to more extract/insert/parallel_insert triples. |
| // Maybe evolve into an interface. |
| OpOperand *pUse = &(*itProducerUses); |
| BlockArgument bbArg = foreachThreadOp.getTiedBlockArgument(pUse); |
| |
| // Search the producer slices accessed within the containing operation. |
| // TODO: Generalize to more extract/insert/parallel_insert triples, maybe |
| // evolve into an interface. |
| auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) { |
| auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user); |
| return sliceOp && containingOp->isProperAncestor(sliceOp); |
| }); |
| |
| // Find a fusion opportunity. |
| if (itBBArgUsers == bbArg.getUsers().end()) { |
| diag.attachNote(containingOp->getLoc()) |
| << "could not find fusion opportunity for bbArg: " << bbArg; |
| return nullptr; |
| } |
| auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers); |
| |
| // Try to fuse the producer in-place. |
| OpBuilder::InsertionGuard guard(rewriter); |
| rewriter.setInsertionPoint(sliceOpToTile); |
| |
| // Replace the use in the tileableProducer before tiling: clone, replace and |
| // then tile. |
| int64_t resultNumber = pUse->get().cast<OpResult>().getResultNumber(); |
| LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n"); |
| |
| // Gather destination tensors. |
| SmallVector<Value> destinationTensors; |
| if (failed(tensor::getOrCreateDestinations( |
| rewriter, tileableProducer->getLoc(), tileableProducer, |
| destinationTensors))) { |
| diag.attachNote(tileableProducer->getLoc()) |
| << "failed to get destination tensors for: " << *tileableProducer; |
| return nullptr; |
| } |
| |
| BlockAndValueMapping bvm; |
| bvm.map(destinationTensors[resultNumber], bbArg); |
| auto tileableProducerClone = |
| cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm)); |
| auto scopeGuard = |
| llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); }); |
| |
| // Tile the producer. |
| FailureOr<Value> tiledProducer = |
| tileableProducerClone.generateResultTileValue( |
| rewriter, resultNumber, sliceOpToTile.getMixedOffsets(), |
| sliceOpToTile.getMixedSizes()); |
| if (failed(tiledProducer)) { |
| diag.attachNote(tileableProducer->getLoc()) |
| << "failed to tile producer op: " << *tileableProducer; |
| return nullptr; |
| } |
| LLVM_DEBUG(llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n"); |
| |
| // Replace the extract op. |
| Operation *fusedOp = tiledProducer->getDefiningOp(); |
| rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(resultNumber)); |
| |
| // Replace the use in containingOp. |
| rewriter.updateRootInPlace(containingOp, [&]() { |
| containingOp->setOperand(pUse->getOperandNumber(), |
| destinationTensors.front()); |
| }); |
| |
| return fusedOp; |
| } |
| |
| static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, |
| Operation *producerOp, |
| Operation *containingOp) { |
| LLVM_DEBUG(llvm::dbgs() << "Try to fuse an use by cloning\n"); |
| |
| // Gather all uses inside the containing op. |
| SmallVector<OpOperand *> uses; |
| for (OpResult result : producerOp->getOpResults()) { |
| for (OpOperand &use : result.getUses()) { |
| if (containingOp->isProperAncestor(use.getOwner())) { |
| uses.push_back(&use); |
| continue; |
| } |
| // Cannot clone and fuse if the use is by the containing op itself: fail |
| // immediately. |
| if (containingOp == use.getOwner()) { |
| diag.attachNote(producerOp->getLoc()) |
| << "producer op use by containing op cannot be fused by cloning"; |
| return nullptr; |
| } |
| } |
| } |
| |
| // Check for a non-empty list of fusion opportunities. |
| if (uses.empty()) { |
| diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning"; |
| return nullptr; |
| } |
| |
| // Clone and fuse inside the containing op. |
| Operation *fusedOp = nullptr; |
| OpOperand *use = uses.front(); |
| // Parallel insert slice is not a valid clone destination. |
| // TODO: Generalize to other type of ops. |
| assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) && |
| "Parallel insert slice is not a valid clone destination"); |
| unsigned resultNumber = use->get().cast<OpResult>().getResultNumber(); |
| LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n"); |
| |
| OpBuilder::InsertionGuard guard(rewriter); |
| rewriter.setInsertionPoint(use->getOwner()); |
| fusedOp = rewriter.clone(*producerOp); |
| rewriter.updateRootInPlace( |
| use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); }); |
| |
| return fusedOp; |
| } |
| |
/// Fuses the payload ops of the producer handle into the single payload op of
/// the containing-op handle, one use at a time. For each selected producer,
/// three strategies are tried in order:
///   1. tiling at a direct tensor.extract_slice use inside the container;
///   2. tiling at an extract_slice of the container's tied block argument
///      (scf::ForeachThreadOp containers only);
///   3. cloning before the first nested use.
/// Every op created this way is associated with the fused-op result.
/// Fails definitely if the containing handle does not map to exactly one op.
/// Fails silenceably, with an empty result, when no remaining producer has a
/// use in the container or when none of the strategies applies.
DiagnosedSilenceableFailure
transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
                                       transform::TransformState &state) {
  SmallVector<Operation *> fusedOps;
  ArrayRef<Operation *> producerOps = state.getPayloadOps(getProducerOp());
  // If nothing to fuse, propagate success.
  if (producerOps.empty()) {
    results.set(getFusedOp().cast<OpResult>(),
                SmallVector<mlir::Operation *>{});
    return DiagnosedSilenceableFailure::success();
  }
  ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp());
  if (containingOps.size() != 1) {
    return emitDefiniteFailure()
           << "requires exactly one containing_op handle (got "
           << containingOps.size() << ")";
  }
  Operation *containingOp = containingOps.front();

  // Helper function to find the next producer that should be fused. Take any
  // producer that has a use inside the containing op.
  SmallVector<Operation *> remainingProducers(producerOps.begin(),
                                              producerOps.end());
  auto getNextProducer = [&]() -> FailureOr<Operation *> {
    for (const auto &it : enumerate(remainingProducers)) {
      Operation *producerOp = it.value();
      // The containing op may be a user of producerOp: use isAncestor.
      int64_t numUsesInContainingOp =
          llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
            return containingOp->isAncestor(op);
          });
      // TODO: When resolving the TODO below (no duplicate ops), take an op
      // that has no use among the remaining producers. This is a topological
      // sorting.
      if (numUsesInContainingOp > 0) {
        // Each loop iteration below fuses a single use, so a producer with
        // several uses stays in the worklist until its last use is handled.
        // Erasing while iterating is safe because we return right away.
        if (numUsesInContainingOp == 1)
          remainingProducers.erase(remainingProducers.begin() + it.index());
        return producerOp;
      }
    }
    return failure();
  };

  IRRewriter rewriter(getContext());
  while (!remainingProducers.empty()) {
    auto nextProducer = getNextProducer();
    if (failed(nextProducer)) {
      results.set(getFusedOp().cast<OpResult>(), ArrayRef<Operation *>());
      Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Remark);
      diag << "could not find next producer to fuse into container";
      return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
    }

    Operation *producerOp = *nextProducer;

    // Default diagnostic, to be complemented with more failure information.
    Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
    diag << "could not fuse " << *producerOp << " into " << *containingOp;

    // TODO: If there are multiple uses of the producer in the containing op,
    // we currently tile/clone the op multiple times (once per use). In some
    // cases, we can tile/clone once and reuse the value for each use.
    // Futhermore, producers should then be traversed according to a
    // topological sorting.
    Operation *tiled =
        tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
    if (tiled) {
      LLVM_DEBUG(llvm::dbgs() << "\nFused a direct extract use\n"
                              << *containingOp);
      fusedOps.push_back(tiled);
      continue;
    }

    Operation *tiledContainingOpOperand =
        tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
            rewriter, diag, producerOp, containingOp);
    if (tiledContainingOpOperand) {
      LLVM_DEBUG(llvm::dbgs()
                 << "\nFused an extract use through block argument\n"
                 << *containingOp);
      fusedOps.push_back(tiledContainingOpOperand);
      continue;
    }

    Operation *cloned =
        cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
    if (cloned) {
      LLVM_DEBUG(llvm::dbgs() << "\nFused an use by cloning\n"
                              << *containingOp);
      fusedOps.push_back(cloned);
      continue;
    }
    // All strategies failed; `diag` now carries the notes from each attempt.
    results.set(getFusedOp().cast<OpResult>(), ArrayRef<Operation *>());
    return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
  }

  results.set(getFusedOp().cast<OpResult>(), fusedOps);
  return DiagnosedSilenceableFailure::success();
}
| |
| //===----------------------------------------------------------------------===// |
| // GeneralizeOp |
| //===----------------------------------------------------------------------===// |
| |
| DiagnosedSilenceableFailure |
| transform::GeneralizeOp::applyToOne(linalg::LinalgOp target, |
| SmallVectorImpl<Operation *> &results, |
| transform::TransformState &state) { |
| // Exit early if no transformation is needed. |
| if (isa<GenericOp>(target)) { |
| results.push_back(target); |
| return DiagnosedSilenceableFailure(success()); |
| } |
| FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target); |
| if (succeeded(generic)) { |
| results.push_back(generic->getOperation()); |
| return DiagnosedSilenceableFailure(success()); |
| } |
| results.assign(1, nullptr); |
| return emitDefaultSilenceableFailure(target); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // InterchangeOp |
| //===----------------------------------------------------------------------===// |
| |
| DiagnosedSilenceableFailure |
| transform::InterchangeOp::applyToOne(linalg::GenericOp target, |
| SmallVectorImpl<Operation *> &results, |
| transform::TransformState &state) { |
| SmallVector<unsigned> interchangeVector = |
| extractUIntArray(getIteratorInterchange()); |
| // Exit early if no transformation is needed. |
| if (interchangeVector.empty()) { |
| results.push_back(target); |
| return DiagnosedSilenceableFailure(success()); |
| } |
| SimpleRewriter rewriter(target->getContext()); |
| FailureOr<GenericOp> res = |
| interchangeGenericOp(rewriter, target, interchangeVector); |
| if (failed(res)) |
| return DiagnosedSilenceableFailure::definiteFailure(); |
| results.push_back(res->getOperation()); |
| return DiagnosedSilenceableFailure(success()); |
| } |
| |
| LogicalResult transform::InterchangeOp::verify() { |
| SmallVector<unsigned> permutation = |
| extractUIntArray(getIteratorInterchange()); |
| auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size())); |
| if (!std::is_permutation(sequence.begin(), sequence.end(), |
| permutation.begin(), permutation.end())) { |
| return emitOpError() |
| << "expects iterator_interchange to be a permutation, found " |
| << getIteratorInterchange(); |
| } |
| return success(); |
| } |
| |
| //===---------------------------------------------------------------------===// |
| // MatchOp |
| //===---------------------------------------------------------------------===// |
| |
| void transform::MatchOp::build(OpBuilder &builder, OperationState &result, |
| Value target, ArrayRef<StringRef> opNames) { |
| result.addOperands(target); |
| result.addAttribute(MatchOp::getOpsAttrName(result.name), |
| builder.getStrArrayAttr(opNames)); |
| result.addTypes(pdl::OperationType::get(builder.getContext())); |
| } |
| |
| DiagnosedSilenceableFailure |
| transform::MatchOp::apply(transform::TransformResults &results, |
| transform::TransformState &state) { |
| llvm::StringSet<> strs; |
| if (getOps().has_value()) |
| strs.insert(getOps()->getAsValueRange<StringAttr>().begin(), |
| getOps()->getAsValueRange<StringAttr>().end()); |
| |
| ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget()); |
| if (payloadOps.size() != 1) { |
| results.set(getResult().cast<OpResult>(), {}); |
| return DiagnosedSilenceableFailure( |
| this->emitOpError("requires exactly one target handle")); |
| } |
| |
| SmallVector<Operation *> res; |
| auto matchFun = [&](Operation *op) { |
| if (getOps().has_value() && !strs.contains(op->getName().getStringRef())) |
| return; |
| |
| // Interfaces cannot be matched by name, just by ID. |
| // So we specifically encode the interfaces we care about for this op. |
| if (getInterface().has_value()) { |
| auto iface = getInterface().value(); |
| if (iface == transform::MatchInterfaceEnum::LinalgOp && |
| !isa<linalg::LinalgOp>(op)) |
| return; |
| if (iface == transform::MatchInterfaceEnum::TilingInterface && |
| isa<TilingInterface>(op)) |
| return; |
| } |
| |
| // Check if all specified attributes match. |
| if (getOpAttrs().has_value()) { |
| DictionaryAttr opAttrs = getOpAttrs().value(); |
| for (NamedAttribute attr : opAttrs) { |
| if (attr.getName() == getInterfaceAttrName() || |
| attr.getName() == getOpsAttrName()) |
| continue; |
| if (!op->hasAttr(attr.getName())) |
| return; |
| if (op->getAttr(attr.getName()) != attr.getValue()) |
| return; |
| } |
| } |
| |
| if (getFilterResultType().has_value()) { |
| Type t = getFilterResultType().value(); |
| if (op->getNumResults() != 1 || op->getResultTypes().front() != t) |
| return; |
| } |
| |
| // All constraints are satisfied. |
| res.push_back(op); |
| return; |
| }; |
| |
| payloadOps.front()->walk(matchFun); |
| results.set(getResult().cast<OpResult>(), res); |
| return DiagnosedSilenceableFailure(success()); |
| } |
| |
| //===---------------------------------------------------------------------===// |
| // MultiTileSizesOp |
| //===---------------------------------------------------------------------===// |
| |
| DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne( |
| LinalgOp target, SmallVector<Operation *> &results, TransformState &state) { |
| OpBuilder builder(target.getContext()); |
| builder.setInsertionPoint(target); |
| OpFoldResult targetSize = builder.getIndexAttr(getTargetSize()); |
| OpFoldResult divisor = builder.getIndexAttr(getDivisor()); |
| FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes( |
| builder, target, getDimension(), targetSize, divisor); |
| if (failed(spec)) { |
| return emitSilenceableError() << "could not generate tile size computation"; |
| } |
| |
| AffineExpr s0 = builder.getAffineSymbolExpr(0); |
| AffineExpr s1 = builder.getAffineSymbolExpr(1); |
| Operation *splitPoint = |
| makeComposedAffineApply(builder, target.getLoc(), s0 * s1, |
| {spec->lowTileSize, spec->lowTripCount}); |
| Operation *lowTileSize = spec->lowTileSize.getDefiningOp(); |
| Operation *highTileSize = spec->highTileSize.getDefiningOp(); |
| assert(lowTileSize && highTileSize && splitPoint && |
| "tile sizes are not produced by operations"); |
| results.reserve(results.size() + 3); |
| results.push_back(lowTileSize); |
| results.push_back(highTileSize); |
| results.push_back(splitPoint); |
| return DiagnosedSilenceableFailure::success(); |
| } |
| |
/// Declares this transform op's memory effects. It only reads the target
/// handle and produces its result handles. It also modifies the payload,
/// because applyToOne inserts the tile-size computation before the target.
void transform::MultiTileSizesOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getTarget(), effects);
  producesHandle(getResults(), effects);
  modifiesPayload(effects);
}
| |
| //===---------------------------------------------------------------------===// |
| // PadOp |
| //===---------------------------------------------------------------------===// |
| |
| DiagnosedSilenceableFailure |
| transform::PadOp::applyToOne(linalg::LinalgOp target, |
| SmallVectorImpl<Operation *> &results, |
| transform::TransformState &state) { |
| // Convert the integer packing flags to booleans. |
| SmallVector<bool> packPaddings; |
| for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings())) |
| packPaddings.push_back(static_cast<bool>(packPadding)); |
| |
| // Convert the padding values to attributes. |
| SmallVector<Attribute> paddingValues; |
| for (auto const &it : |
| llvm::zip(getPaddingValues(), target->getOperandTypes())) { |
| auto attr = std::get<0>(it).dyn_cast<TypedAttr>(); |
| if (!attr) { |
| emitOpError("expects padding values to be typed attributes"); |
| return DiagnosedSilenceableFailure::definiteFailure(); |
| } |
| Type elementType = getElementTypeOrSelf(std::get<1>(it)); |
| // Try to parse string attributes to obtain an attribute of element type. |
| if (auto stringAttr = attr.dyn_cast<StringAttr>()) { |
| paddingValues.push_back( |
| parseAttribute(attr.cast<StringAttr>(), elementType)); |
| if (!paddingValues.back()) { |
| auto diag = this->emitOpError("expects a padding that parses to ") |
| << elementType << ", got " << std::get<0>(it); |
| diag.attachNote(target.getLoc()) << "when applied to this op"; |
| return DiagnosedSilenceableFailure::definiteFailure(); |
| } |
| continue; |
| } |
| // Otherwise, add the attribute directly. |
| if (attr.getType() != elementType) { |
| auto diag = this->emitOpError("expects a padding value of type ") |
| << elementType << ", got " << attr; |
| diag.attachNote(target.getLoc()) << "when applied to this op"; |
| return DiagnosedSilenceableFailure::definiteFailure(); |
| } |
| paddingValues.push_back(attr); |
| } |
| |
| // Extract the transpose vectors. |
| SmallVector<SmallVector<int64_t>> transposePaddings; |
| for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>()) |
| transposePaddings.push_back( |
| extractFromI64ArrayAttr(transposeVector.cast<ArrayAttr>())); |
| |
| LinalgPaddingOptions paddingOptions; |
| paddingOptions.setPaddingValues(paddingValues); |
| paddingOptions.setPaddingDimensions( |
| extractFromI64ArrayAttr(getPaddingDimensions())); |
| paddingOptions.setPackPaddings(packPaddings); |
| paddingOptions.setHoistPaddings(extractFromI64ArrayAttr(getHoistPaddings())); |
| paddingOptions.setTransposePaddings(transposePaddings); |
| |
| FailureOr<LinalgOp> result = |
| tryApply<LinalgPaddingPattern>(target, paddingOptions); |
| if (succeeded(result)) { |
| results.push_back(result->getOperation()); |
| return DiagnosedSilenceableFailure(success()); |
| } |
| |
| results.assign(1, nullptr); |
| return emitDefaultSilenceableFailure(target); |
| } |
| |
| LogicalResult transform::PadOp::verify() { |
| SmallVector<int64_t> packPaddings = |
| extractFromI64ArrayAttr(getPackPaddings()); |
| if (any_of(packPaddings, [](int64_t packPadding) { |
| return packPadding != 0 && packPadding != 1; |
| })) { |
| return emitOpError() |
| << "expects pack_paddings to contain booleans (0/1), found " |
| << getPackPaddings(); |
| } |
| |
| SmallVector<int64_t> paddingDimensions = |
| extractFromI64ArrayAttr(getPaddingDimensions()); |
| if (any_of(paddingDimensions, |
| [](int64_t paddingDimension) { return paddingDimension < 0; })) { |
| return emitOpError() << "expects padding_dimensions to contain positive " |
| "integers, found " |
| << getPaddingDimensions(); |
| } |
| |
| SmallVector<int64_t> hoistPaddings = |
| extractFromI64ArrayAttr(getHoistPaddings()); |
| if (any_of(hoistPaddings, |
| [](int64_t hoistPadding) { return hoistPadding < 0; })) { |
| return emitOpError() |
| << "expects hoist_paddings to contain positive integers, found " |
| << getHoistPaddings(); |
| } |
| |
| ArrayAttr transposes = getTransposePaddings(); |
| for (Attribute attr : transposes) { |
| SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr); |
| auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size())); |
| if (!std::is_permutation(sequence.begin(), sequence.end(), |
| transpose.begin(), transpose.end())) { |
| return emitOpError() |
| << "expects transpose_paddings to be a permutation, found " |
| << attr; |
| } |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // PromoteOp |
| //===----------------------------------------------------------------------===// |
| |
| DiagnosedSilenceableFailure |
| transform::PromoteOp::applyToOne(linalg::LinalgOp target, |
| SmallVectorImpl<Operation *> &results, |
| transform::TransformState &state) { |
| LinalgPromotionOptions promotionOptions; |
| if (!getOperandsToPromote().empty()) |
| promotionOptions = promotionOptions.setOperandsToPromote( |
| extractFromI64ArrayAttr(getOperandsToPromote())); |
| if (getUseFullTilesByDefault()) |
| promotionOptions = promotionOptions.setUseFullTileBuffersByDefault( |
| getUseFullTilesByDefault()); |
| if (getUseAlloca()) |
| promotionOptions = promotionOptions.setUseAlloca(getUseAlloca()); |
| if (!getUseFullTileBuffers().empty()) |
| promotionOptions = promotionOptions.setUseFullTileBuffers( |
| llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>())); |
| if (getAlignment().has_value()) |
| promotionOptions = promotionOptions.setAlignment(*getAlignment()); |
| |
| if (failed(promoteSubviewsPrecondition(target, promotionOptions))) |
| return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); |
| |
| SimpleRewriter rewriter(target->getContext()); |
| rewriter.setInsertionPoint(target); |
| FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions); |
| if (failed(res)) |
| return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); |
| results.push_back(target); |
| return DiagnosedSilenceableFailure(success()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ScalarizeOp |
| //===----------------------------------------------------------------------===// |
| |
| DiagnosedSilenceableFailure |
| transform::ScalarizeOp::applyToOne(linalg::LinalgOp target, |
| SmallVectorImpl<Operation *> &results, |
| transform::TransformState &state) { |
| scf::SCFTilingOptions tilingOptions; |
| tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) { |
| SmallVector<Value, 4> tileSizes; |
| Location loc = target.getLoc(); |
| SmallVector<OpFoldResult> allShapeSizes = |
| target.createFlatListOfOperandDims(b, loc); |
| AffineMap map = target.getShapesToLoopsMap(); |
| if (!map) |
| return tileSizes; |
| IRRewriter rewriter(b); |
| SmallVector<OpFoldResult> shapeSizes = |
| makeComposedFoldedMultiResultAffineApply(rewriter, loc, map, |
| allShapeSizes); |
| // If the shape size is dynamic, tile by 1. |
| // Otherwise, do not tile (i.e. tile size 0). |
| for (OpFoldResult shapeSize : shapeSizes) { |
| tileSizes.push_back(getConstantIntValue(shapeSize) |
| ? b.create<arith::ConstantIndexOp>(loc, 0) |
| : b.create<arith::ConstantIndexOp>(loc, 1)); |
| } |
| return tileSizes; |
| }); |
| SmallVector<int64_t> emptyTileSizes; |
| SimpleRewriter rewriter(getContext()); |
| rewriter.setInsertionPoint(target); |
| FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp( |
| rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions); |
| if (failed(maybeTilingResult)) |
| return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); |
| |
| results.append(maybeTilingResult->tiledOps); |
| return DiagnosedSilenceableFailure(success()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // SplitOp |
| //===----------------------------------------------------------------------===// |
| |
| DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results, |
| TransformState &state) { |
| // Collect the dynamic split points if provided. |
| ArrayRef<Operation *> payload = state.getPayloadOps(getTarget()); |
| SimpleRewriter rewriter(getContext()); |
| SmallVector<OpFoldResult> splitPoints; |
| splitPoints.reserve(payload.size()); |
| if (getDynamicSplitPoint()) { |
| auto diag = DiagnosedSilenceableFailure::success(); |
| splitPoints = llvm::to_vector(llvm::map_range( |
| state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) { |
| if (op->getNumResults() != 1 || |
| !op->getResult(0).getType().isIndex()) { |
| diag = emitSilenceableError() |
| << "expected dynamic split point handle to point to a " |
| "single-result index-typed op"; |
| diag.attachNote(op->getLoc()) << "dynamic split point"; |
| } |
| return OpFoldResult(op->getResult(0)); |
| })); |
| if (diag.isSilenceableFailure()) { |
| results.set(getFirst().cast<OpResult>(), {}); |
| results.set(getSecond().cast<OpResult>(), {}); |
| return diag; |
| } |
| |
| if (splitPoints.size() != payload.size()) { |
| return emitDefiniteFailure() |
| << "expected the dynamic split point handle to point to as " |
| "many operations (" |
| << splitPoints.size() << ") as the target handle (" |
| << payload.size() << ")"; |
| } |
| } else { |
| splitPoints.resize(payload.size(), |
| rewriter.getIndexAttr(getStaticSplitPoint())); |
| } |
| |
| // Split each target operation. |
| SmallVector<Operation *> first, second; |
| for (const auto &pair : llvm::zip(payload, splitPoints)) { |
| Operation *target = std::get<0>(pair); |
| auto linalgOp = dyn_cast<LinalgOp>(target); |
| if (!linalgOp) { |
| auto diag = emitSilenceableError() << "only applies to structured ops"; |
| diag.attachNote(target->getLoc()) << "target op"; |
| results.set(getFirst().cast<OpResult>(), {}); |
| results.set(getSecond().cast<OpResult>(), {}); |
| return diag; |
| } |
| |
| if (getDimension() >= linalgOp.getNumLoops()) { |
| auto diag = emitSilenceableError() << "dimension " << getDimension() |
| << " does not exist in target op"; |
| diag.attachNote(target->getLoc()) << "target op"; |
| results.set(getFirst().cast<OpResult>(), {}); |
| results.set(getSecond().cast<OpResult>(), {}); |
| return diag; |
| } |
| |
| rewriter.setInsertionPoint(linalgOp); |
| std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp( |
| rewriter, cast<TilingInterface>(linalgOp.getOperation()), |
| getDimension(), std::get<1>(pair)); |
| } |
| |
| results.set(getFirst().cast<OpResult>(), first); |
| results.set(getSecond().cast<OpResult>(), second); |
| return DiagnosedSilenceableFailure::success(); |
| } |
| |
| void SplitOp::getEffects( |
| SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| consumesHandle(getTarget(), effects); |
| if (getDynamicSplitPoint()) |
| onlyReadsHandle(getDynamicSplitPoint(), effects); |
| producesHandle(getResults(), effects); |
| modifiesPayload(effects); |
| } |
| |
/// Parses the custom assembly of SplitOp:
///   %target `after` (integer-literal | %dynamic_split_point) attr-dict
/// The target and the optional dynamic split point are resolved as
/// `!pdl.operation` values, and two `!pdl.operation` result types are added.
/// When a dynamic split point is parsed, static_split_point is set to the
/// ShapedType::kDynamicSize sentinel.
ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
  IntegerAttr staticSplitPoint;
  auto pdlOperationType =
      pdl::OperationType::get(parser.getBuilder().getContext());
  if (parser.parseOperand(target) ||
      parser.resolveOperand(target, pdlOperationType, result.operands) ||
      parser.parseKeyword("after"))
    return failure();

  // An operand after `after` selects the dynamic form; otherwise an integer
  // literal is required.
  OptionalParseResult dynamicPointParseResult =
      parser.parseOptionalOperand(dynamicSplitPoint);
  if (!dynamicPointParseResult.has_value()) {
    int64_t staticSplitPointValue;
    if (failed(parser.parseInteger(staticSplitPointValue)))
      return failure();

    staticSplitPoint =
        parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
  } else {
    // An operand was present but failed to parse or to resolve.
    if (failed(*dynamicPointParseResult) ||
        parser.resolveOperand(dynamicSplitPoint, pdlOperationType,
                              result.operands)) {
      return failure();
    }

    // Sentinel meaning "no static split point"; SplitOp::verify and
    // SplitOp::print compare against it.
    staticSplitPoint =
        parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize);
  }

  result.addAttribute(
      SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
      staticSplitPoint);
  if (failed(parser.parseOptionalAttrDict(result.attributes)))
    return failure();

  // Two result handles: the `first` and `second` parts of the split.
  result.addTypes({pdlOperationType, pdlOperationType});
  return success();
}
| |
| void SplitOp::print(OpAsmPrinter &printer) { |
| printer << " " << getTarget() << " after "; |
| int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint()); |
| if (staticSplitSize != ShapedType::kDynamicSize) |
| printer << staticSplitSize; |
| else |
| printer << getDynamicSplitPoint(); |
| printer << " "; |
| printer.printOptionalAttrDict(getOperation()->getAttrs(), |
| {getStaticSplitPointAttrName()}); |
| } |
| |
| LogicalResult SplitOp::verify() { |
| if ((static_cast<int64_t>(getStaticSplitPoint()) != |
| ShapedType::kDynamicSize) ^ |
| (getDynamicSplitPoint() == nullptr)) { |
| return emitOpError() << "expects either a dynamic or a static split " |
| "point to be provided"; |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // SplitReductionOp |
| //===----------------------------------------------------------------------===// |
| |
| void transform::SplitReductionOp::build( |
| OpBuilder &builder, OperationState &result, Value target, |
| int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel, |
| bool useScalingAlgorithm, bool useAlloc) { |
| MLIRContext *ctx = builder.getContext(); |
| result.addOperands(target); |
| result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name), |
| builder.getI64IntegerAttr(splitFactor)); |
| result.addAttribute( |
| SplitReductionOp::getInsertSplitDimensionAttrName(result.name), |
| builder.getI64IntegerAttr(insertSplitDimension)); |
| if (innerParallel) { |
| result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name), |
| builder.getUnitAttr()); |
| } |
| if (useScalingAlgorithm) { |
| result.addAttribute( |
| SplitReductionOp::getUseScalingAlgorithmAttrName(result.name), |
| builder.getUnitAttr()); |
| } |
| if (useAlloc) { |
| result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name), |
| builder.getUnitAttr()); |
| } |
| auto resultType = pdl::OperationType::get(ctx); |
| result.addTypes({resultType, resultType, resultType, resultType}); |
| } |
| |
| DiagnosedSilenceableFailure |
| transform::SplitReductionOp::applyToOne(linalg::LinalgOp target, |
| SmallVectorImpl<Operation *> &results, |
| transform::TransformState &state) { |
| ControlSplitReductionFn splitFn = [&](LinalgOp) { |
| return linalg::SplitReductionOptions{int64_t(getSplitFactor()), |
| unsigned(getInsertSplitDimension()), |
| bool(getInnerParallel())}; |
| }; |
| SimpleRewriter rewriter(getContext()); |
| rewriter.setInsertionPoint(target); |
| FailureOr<SplitReductionResult> splitResult = |
| (getUseScalingAlgorithm()) |
| ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc()) |
| : splitReduction(rewriter, target, splitFn, getUseAlloc()); |
| if (failed(splitResult)) |
| return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); |
| |
| results.push_back(splitResult->initOrAlloc); |
| results.push_back(splitResult->fillOp); |
| results.push_back(splitResult->splitLinalgOp); |
| results.push_back(splitResult->resultCombiningLinalgOp); |
| return DiagnosedSilenceableFailure(success()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // SplitReductionOp |
| //===----------------------------------------------------------------------===// |
| |
| DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne( |
| linalg::LinalgOp target, SmallVectorImpl<Operation *> &results, |
| transform::TransformState &state) { |
| SimpleRewriter rewriter(getContext()); |
| rewriter.setInsertionPoint(target); |
| SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getTileSizes()); |
| SmallVector<OpFoldResult> sizes; |
| for (int64_t size : tileSizes) { |
| sizes.push_back(rewriter.getIndexAttr(size)); |
| } |
| |
| FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf( |
| rewriter, cast<PartialReductionOpInterface>(target.getOperation()), |
| sizes); |
| |
| if (failed(result)) |
| return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); |
| results.push_back(result->initialOp); |
| results.push_back(result->parallelTiledOp); |
| results.push_back(result->mergeOp); |
| return DiagnosedSilenceableFailure(success()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // TileOp |
| //===----------------------------------------------------------------------===// |
| |
| DiagnosedSilenceableFailure |
| transform::TileOp::apply(TransformResults &transformResults, |
| TransformState &state) { |
| SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes()); |
| |
| ArrayRef<Operation *> targets = state.getPayloadOps(getTarget()); |
| SmallVector<ArrayRef<Operation *>> dynamicSizeProducers; |
| dynamicSizeProducers.reserve(getDynamicSizes().size()); |
| for (Value dynamicSizeProducerHandle : getDynamicSizes()) { |
| dynamicSizeProducers.push_back( |
| state.getPayloadOps(dynamicSizeProducerHandle)); |
| |
| if (dynamicSizeProducers.back().size() != targets.size()) { |
| DiagnosedSilenceableFailure diag = |
| emitSilenceableError() |
| << "expected as many dynamic size-producing operations (" |
| << dynamicSizeProducers.back().size() << ") as target ops (" |
| << targets.size() << ")"; |
| diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; |
| return diag; |
| } |
| |
| for (Operation *op : dynamicSizeProducers.back()) { |
| if (op->getNumResults() == 1 && |
| op->getResult(0).getType().isa<IndexType>()) |
| continue; |
| DiagnosedSilenceableFailure diag = |
| emitSilenceableError() << "expected sizes to be produced by ops " |
| "with a single index-type result"; |
| diag.attachNote(op->getLoc()) << "size producer op"; |
| diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; |
| return diag; |
| } |
| } |
| |
| SmallVector<Operation *> tiled; |
| SmallVector<SmallVector<Operation *, 4>, 4> loops; |
| loops.resize(getLoops().size()); |
| for (auto &en : llvm::enumerate(targets)) { |
| auto linalgOp = dyn_cast<LinalgOp>(en.value()); |
| if (!linalgOp) { |
| DiagnosedSilenceableFailure diag = emitSilenceableError() |
| << "only linalg ops are supported"; |
| diag.attachNote(en.value()->getLoc()) << "target op"; |
| return diag; |
| } |
| |
| scf::SCFTilingOptions tilingOptions; |
| unsigned index = en.index(); |
| if (!tileSizes.empty()) { |
| tilingOptions.setTileSizeComputationFunction( |
| [&, index](OpBuilder &b, Operation *) { |
| SmallVector<Value, 4> sizes; |
| sizes.reserve(tileSizes.size()); |
| unsigned dynamicIdx = 0; |
| for (OpFoldResult ofr : getMixedSizes()) { |
| if (auto attr = ofr.dyn_cast<Attribute>()) { |
| sizes.push_back(b.create<arith::ConstantIndexOp>( |
| getLoc(), attr.cast<IntegerAttr>().getInt())); |
| } else { |
| sizes.push_back( |
| dynamicSizeProducers[dynamicIdx++][index]->getResult(0)); |
| } |
| } |
| return sizes; |
| }); |
| } |
| |
| tilingOptions.setInterchange(extractI64Array(getInterchange())); |
| SimpleRewriter rewriter(linalgOp.getContext()); |
| FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp( |
| rewriter, cast<TilingInterface>(linalgOp.getOperation()), |
| tilingOptions); |
| if (failed(maybeTilingResult)) |
| return DiagnosedSilenceableFailure::definiteFailure(); |
| |
| if (linalgOp.hasBufferSemantics()) |
| rewriter.eraseOp(linalgOp); |
| else |
| rewriter.replaceOp(linalgOp, |
| maybeTilingResult->loops.front()->getResults()); |
| |
| tiled.append(maybeTilingResult->tiledOps); |
| for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops)) |
| loops[en2.index()].push_back(en2.value()); |
| } |
| |
| transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled); |
| for (const auto &en : llvm::enumerate(loops)) |
| transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value()); |
| |
| return DiagnosedSilenceableFailure::success(); |
| } |
| |
| SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() { |
| ValueRange dynamic = getDynamicSizes(); |
| SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes()); |
| SmallVector<OpFoldResult> results; |
| results.reserve(tileSizes.size()); |
| unsigned dynamicPos = 0; |
| Builder builder(getContext()); |
| for (int64_t size : tileSizes) { |
| if (size == ShapedType::kDynamicSize) { |
| results.push_back(dynamic[dynamicPos++]); |
| } else { |
| results.push_back(builder.getIndexAttr(size)); |
| } |
| } |
| return results; |
| } |
| |
/// Parses the custom assembly of TileOp: the target handle, a dynamic index
/// list of tile sizes (integer literals or `!pdl.operation` size handles), and
/// an optional attribute dictionary. Dynamic entries are recorded as
/// ShapedType::kDynamicSize in the static_sizes attribute.
///
/// Result types: one `!pdl.operation` for the tiled op plus one per expected
/// loop. The parser counts every static_sizes entry that is not 0 as a loop,
/// so a zero tile size is assumed to produce no loop.
ParseResult transform::TileOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  OpAsmParser::UnresolvedOperand target;
  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
  ArrayAttr staticSizes;
  auto pdlOperationType = pdl::OperationType::get(parser.getContext());
  if (parser.parseOperand(target) ||
      parser.resolveOperand(target, pdlOperationType, result.operands) ||
      parseDynamicIndexList(parser, dynamicSizes, staticSizes,
                            ShapedType::kDynamicSize) ||
      parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) ||
      parser.parseOptionalAttrDict(result.attributes))
    return ParseResult::failure();

  result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
  // kDynamicSize entries are non-zero and therefore also count as loops.
  size_t numExpectedLoops =
      staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0);
  result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
  return success();
}
| |
| void TileOp::print(OpAsmPrinter &p) { |
| p << ' ' << getTarget(); |
| printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(), |
| ShapedType::kDynamicSize); |
| p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()}); |
| } |
| |
/// Declares the memory effects of TileOp.
void transform::TileOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // The target handle is consumed; the dynamic size handles are only read.
  consumesHandle(getTarget(), effects);
  onlyReadsHandle(getDynamicSizes(), effects);
  // New handles: the tiled op and one handle per generated loop depth.
  producesHandle(getTiledLinalgOp(), effects);
  producesHandle(getLoops(), effects);
  // Tiling erases or replaces the target payload op (see TileOp::apply).
  modifiesPayload(effects);
}
| |
| //===----------------------------------------------------------------------===// |
| // TileToForeachThreadOp |
| //===----------------------------------------------------------------------===// |
| |
| void transform::TileToForeachThreadOp::build( |
| OpBuilder &builder, OperationState &result, Value target, |
| ArrayRef<int64_t> staticTileSizes, transform::TileSizesSpec, |
| ArrayRef<int64_t> threadDimMapping) { |
| return build(builder, result, target, |
| getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), |
| TileSizesSpec(), threadDimMapping); |
| } |
| |
| void transform::TileToForeachThreadOp::build( |
| OpBuilder &builder, OperationState &result, Value target, |
| ArrayRef<OpFoldResult> mixedTileSizes, transform::TileSizesSpec, |
| ArrayRef<int64_t> threadDimMapping) { |
| SmallVector<int64_t> staticTileSizes; |
| SmallVector<Value> dynamicTileSizes; |
| dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes, |
| ShapedType::kDynamicSize); |
| // Call the default builder which sets up the proper operands segment sizes |
| // attributes for multiple variadic operands. In the absence of this, horrible |
| // bugs ensue. |
| MLIRContext *ctx = builder.getContext(); |
| auto operationType = pdl::OperationType::get(ctx); |
| auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes); |
| ArrayAttr threadDimMappingAttr; |
| if (!threadDimMapping.empty()) |
| threadDimMappingAttr = builder.getI64ArrayAttr(threadDimMapping); |
| build(builder, result, TypeRange{operationType, operationType}, target, |
| /*numThreads=*/ValueRange{}, dynamicTileSizes, |
| /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr, |
| threadDimMappingAttr); |
| } |
| |
| void transform::TileToForeachThreadOp::build( |
| OpBuilder &builder, OperationState &result, Value target, |
| ArrayRef<int64_t> staticNumThreads, transform::NumThreadsSpec, |
| ArrayRef<int64_t> threadDimMapping) { |
| return build(builder, result, target, |
| getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)), |
| NumThreadsSpec(), threadDimMapping); |
| } |
| |
| void transform::TileToForeachThreadOp::build( |
| OpBuilder &builder, OperationState &result, Value target, |
| ArrayRef<OpFoldResult> mixedNumThreads, transform::NumThreadsSpec, |
| ArrayRef<int64_t> threadDimMapping) { |
| SmallVector<int64_t> staticNumThreads; |
| SmallVector<Value> dynamicNumThreads; |
| dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads, |
| staticNumThreads, ShapedType::kDynamicSize); |
| // Call the default builder which sets up the proper operands segment sizes |
| // attributes for multiple variadic operands. In the absence of this, horrible |
| // bugs ensue. |
| MLIRContext *ctx = builder.getContext(); |
| auto operationType = pdl::OperationType::get(ctx); |
| auto staticNumThreadsAttr = builder.getI64ArrayAttr(staticNumThreads); |
| ArrayAttr threadDimMappingAttr; |
| if (!threadDimMapping.empty()) |
| threadDimMappingAttr = builder.getI64ArrayAttr(threadDimMapping); |
| build(builder, result, TypeRange{operationType, operationType}, target, |
| dynamicNumThreads, /*tileSizes=*/ValueRange{}, staticNumThreadsAttr, |
| /*staticTileSizes=*/ArrayAttr(), threadDimMappingAttr); |
| } |
| |
| DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl( |
| RewriterBase &rewriter, transform::TransformState &state, |
| TransformOpInterface transformOp, ArrayRef<Operation *> targets, |
| ArrayRef<OpFoldResult> mixedNumThreads, |
| ArrayRef<OpFoldResult> mixedTileSizes, Optional<ArrayAttr> threadDimMapping, |
| SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps) { |
| if (targets.empty()) |
| return DiagnosedSilenceableFailure(success()); |
| |
| // Given a list of OpFoldResults that are either index attrs or op handles, |
| // return a list of OpFoldResults where all op handles are replaced with the |
| // first (and only) OpResult of that payload op. (There must be exactly one |
| // mapped payload op and it must have exactly one index result.) |
| auto getOpResultsOrIndexAttrs = |
| [&](SmallVector<OpFoldResult> &result, |
| ArrayRef<OpFoldResult> opHandlesOrIndexAttrs) { |
| for (OpFoldResult ofr : opHandlesOrIndexAttrs) { |
| if (ofr.is<Attribute>()) { |
| result.push_back(ofr); |
| continue; |
| } |
| ArrayRef<Operation *> dynamicNumThreads = |
| state.getPayloadOps(ofr.get<Value>()); |
| if (dynamicNumThreads.size() != 1) { |
| DiagnosedSilenceableFailure diag = |
| transformOp.emitSilenceableError() |
| << "handle must be mapped to exactly 1 payload op"; |
| diag.attachNote(ofr.get<Value>().getLoc()) |
| << "mapped to " << dynamicNumThreads.size() << " ops"; |
| return diag; |
| } |
| Operation *op = dynamicNumThreads[0]; |
| if (op->getNumResults() != 1 || |
| !op->getResult(0).getType().isIndex()) { |
| DiagnosedSilenceableFailure diag = |
| transformOp.emitSilenceableError() |
| << "payload op must have exactly 1 index result"; |
| diag.attachNote(op->getLoc()) |
| << "has " << op->getNumResults() << " results"; |
| return diag; |
| } |
| result.push_back(op->getResult(0)); |
| } |
| |
| return DiagnosedSilenceableFailure(success()); |
| }; |
| |
| // getMixedNumThreads are OpFoldResults[index attributes or PDL operation]. |
| // Convert to OpFoldResults[index attributes or payload op]. |
| SmallVector<OpFoldResult> numThreads; |
| DiagnosedSilenceableFailure status = |
| getOpResultsOrIndexAttrs(numThreads, mixedNumThreads); |
| if (!status.succeeded()) |
| return status; |
| |
| // getMixedTileSizes are OpFoldResults[index attributes or PDL operation]. |
| // Convert to OpFoldResults[index attributes or payload op]. |
| SmallVector<OpFoldResult> tileSizes; |
| status = getOpResultsOrIndexAttrs(tileSizes, mixedTileSizes); |
| if (!status.succeeded()) |
| return status; |
| |
| // Transform all targets one by one. |
| for (Operation *target : targets) { |
| auto tilableOp = dyn_cast<TilingInterface>(target); |
| if (!tilableOp) { |
| DiagnosedSilenceableFailure diag = |
| transformOp.emitSilenceableError() |
| << "only TilingInterface ops are supported"; |
| diag.attachNote(target->getLoc()) << "target op"; |
| return diag; |
| } |
| rewriter.setInsertionPoint(tilableOp); |
| auto maybeThreadDimMappingAttr = threadDimMapping; |
| auto dimMapping = llvm::to_vector( |
| maybeThreadDimMappingAttr |
| ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr) |
| : ArrayRef<int64_t>{}); |
| |
| FailureOr<linalg::ForeachThreadTilingResult> tilingResult = failure(); |
| if (!mixedNumThreads.empty()) { |
| tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp, |
| numThreads, dimMapping); |
| } else { |
| tilingResult = linalg::tileToForeachThreadOpUsingTileSizes( |
| rewriter, tilableOp, tileSizes, dimMapping); |
| } |
| |
| if (failed(tilingResult)) |
| return transformOp.emitDefaultSilenceableFailure(tilableOp); |
| rewriter.replaceOp(tilableOp, tilingResult->tileOp->getResults()); |
| |
| tileOps.push_back(tilingResult->tileOp); |
| tiledOps.push_back(tilingResult->tiledOp); |
| } |
| return DiagnosedSilenceableFailure(success()); |
| } |
| |
| DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply( |
| transform::TransformResults &transformResults, |
| transform::TransformState &state) { |
| IRRewriter rewriter(getContext()); |
| ArrayRef<Operation *> targets = state.getPayloadOps(getTarget()); |
| |
| // Result payload ops. |
| SmallVector<Operation *> tileOps; |
| SmallVector<Operation *> tiledOps; |
| |
| DiagnosedSilenceableFailure diag = tileToForeachThreadOpImpl( |
| rewriter, state, cast<TransformOpInterface>(getOperation()), targets, |
| getMixedNumThreads(), getMixedTileSizes(), getThreadDimMapping(), tileOps, |
| tiledOps); |
| |
| if (!diag.succeeded()) |
| return diag; |
| |
| transformResults.set(getForeachThreadOp().cast<OpResult>(), tileOps); |
| transformResults.set(getTiledOp().cast<OpResult>(), tiledOps); |
| |
| return DiagnosedSilenceableFailure(success()); |
| } |
| |
| void transform::TileToForeachThreadOp::getEffects( |
| SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| consumesHandle(getTarget(), effects); |
| onlyReadsHandle(getTileSizes(), effects); |
| onlyReadsHandle(getNumThreads(), effects); |
| producesHandle(getResults(), effects); |
| } |
| |
| SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedNumThreads() { |
| return getMixedSizes(getStaticNumThreads(), getNumThreads()); |
| } |
| |
| SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedTileSizes() { |
| return getMixedSizes(getStaticTileSizes(), getTileSizes()); |
| } |
| |
| LogicalResult TileToForeachThreadOp::verify() { |
| if (getMixedNumThreads().empty() == getMixedTileSizes().empty()) |
| return emitOpError("either num_threads or tile_sizes must be specified"); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // TileToScfForOp |
| //===----------------------------------------------------------------------===// |
| |
| DiagnosedSilenceableFailure |
| transform::TileToScfForOp::apply(TransformResults &transformResults, |
| TransformState &state) { |
| SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes()); |
| |
| ArrayRef<Operation *> targets = state.getPayloadOps(getTarget()); |
| SmallVector<ArrayRef<Operation *>> dynamicSizeProducers; |
| dynamicSizeProducers.reserve(getDynamicSizes().size()); |
| for (Value dynamicSizeProducerHandle : getDynamicSizes()) { |
| dynamicSizeProducers.push_back( |
| state.getPayloadOps(dynamicSizeProducerHandle)); |
| |
| if (dynamicSizeProducers.back().size() != targets.size()) { |
| DiagnosedSilenceableFailure diag = |
| emitSilenceableError() |
| << "expected as many dynamic size-producing operations (" |
| << dynamicSizeProducers.back().size() << ") as target ops (" |
| << targets.size() << ")"; |
| diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; |
| return diag; |
| } |
| |
| for (Operation *op : dynamicSizeProducers.back()) { |
| if (op->getNumResults() == 1 && |
| op->getResult(0).getType().isa<IndexType>()) |
| continue; |
| DiagnosedSilenceableFailure diag = |
| emitSilenceableError() << "expected sizes to be produced by ops " |
| "with a single index-type result"; |
| diag.attachNote(op->getLoc()) << "size producer op"; |
| diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; |
| return diag; |
| } |
| } |
| |
| SmallVector<Operation *> tiled; |
| SmallVector<SmallVector<Operation *, 4>, 4> loops; |
| loops.resize(getLoops().size()); |
| for (auto &en : llvm::enumerate(targets)) { |
| auto tilingInterfaceOp = dyn_cast<TilingInterface>(en.value()); |
| if (!tilingInterfaceOp) { |
| DiagnosedSilenceableFailure diag = |
| emitSilenceableError() << "only TilingInterface ops are supported"; |
| diag.attachNote(en.value()->getLoc()) << "target op"; |
| return diag; |
| } |
| |
| scf::SCFTilingOptions tilingOptions; |
| unsigned index = en.index(); |
| if (!tileSizes.empty()) { |
| tilingOptions.setTileSizeComputationFunction( |
| [&, index](OpBuilder &b, Operation *) { |
| SmallVector<Value, 4> sizes; |
| sizes.reserve(tileSizes.size()); |
| unsigned dynamicIdx = 0; |
| for (OpFoldResult ofr : getMixedSizes()) { |
| if (auto attr = ofr.dyn_cast<Attribute>()) { |
| sizes.push_back(b.create<arith::ConstantIndexOp>( |
| getLoc(), attr.cast<IntegerAttr>().getInt())); |
| } else { |
| sizes.push_back( |
| dynamicSizeProducers[dynamicIdx++][index]->getResult(0)); |
| } |
| } |
| return sizes; |
| }); |
| } |
| |
| tilingOptions.setInterchange(extractI64Array(getInterchange())); |
| SimpleRewriter rewriter(tilingInterfaceOp.getContext()); |
| FailureOr<scf::SCFTilingResult> tilingResult = |
| tileUsingSCFForOp(rewriter, tilingInterfaceOp, tilingOptions); |
| if (failed(tilingResult)) |
| return DiagnosedSilenceableFailure::definiteFailure(); |
| |
| rewriter.replaceOp(tilingInterfaceOp, tilingResult->replacements); |
| |
| tiled.append(tilingResult->tiledOps); |
| for (const auto &en2 : llvm::enumerate(tilingResult->loops)) |
| loops[en2.index()].push_back(en2.value()); |
| } |
| |
| transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled); |
| for (const auto &en : llvm::enumerate(loops)) |
| transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value()); |
| |
| return DiagnosedSilenceableFailure::success(); |
| } |
| |
| SmallVector<OpFoldResult> transform::TileToScfForOp::getMixedSizes() { |
| ValueRange dynamic = getDynamicSizes(); |
| SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes()); |
| SmallVector<OpFoldResult> results; |
| results.reserve(tileSizes.size()); |
| unsigned dynamicPos = 0; |
| Builder builder(getContext()); |
| for (int64_t size : tileSizes) { |
| if (size == ShapedType::kDynamicSize) { |
| results.push_back(dynamic[dynamicPos++]); |
| } else { |
| results.push_back(builder.getIndexAttr(size)); |
| } |
| } |
| return results; |
| } |
| |
| ParseResult transform::TileToScfForOp::parse(OpAsmParser &parser, |
| OperationState &result) { |
| OpAsmParser::UnresolvedOperand target; |
| SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes; |
| ArrayAttr staticSizes; |
| auto pdlOperationType = pdl::OperationType::get(parser.getContext()); |
| if (parser.parseOperand(target) || |
| parser.resolveOperand(target, pdlOperationType, result.operands) || |
| parseDynamicIndexList(parser, dynamicSizes, staticSizes, |
| ShapedType::kDynamicSize) || |
| parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) || |
| parser.parseOptionalAttrDict(result.attributes)) |
| return ParseResult::failure(); |
| |
| result.addAttribute(getStaticSizesAttrName(result.name), staticSizes); |
| size_t numExpectedLoops = |
| staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0); |
| result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType)); |
| return success(); |
| } |
| |
| void TileToScfForOp::print(OpAsmPrinter &p) { |
| p << ' ' << getTarget(); |
| printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(), |
| ShapedType::kDynamicSize); |
| p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()}); |
| } |
| |
| void transform::TileToScfForOp::getEffects( |
| SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| consumesHandle(getTarget(), effects); |
| onlyReadsHandle(getDynamicSizes(), effects); |
| producesHandle(getTiledLinalgOp(), effects); |
| producesHandle(getLoops(), effects); |
| modifiesPayload(effects); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // VectorizeOp |
| //===----------------------------------------------------------------------===// |
| |
| void transform::VectorizeOp::build(OpBuilder &builder, OperationState &result, |
| Value target, bool vectorizePadding) { |
| result.addOperands(target); |
| if (vectorizePadding) { |
| result.addAttribute(VectorizeOp::getVectorizePaddingAttrName(result.name), |
| builder.getUnitAttr()); |
| } |
| result.addTypes(pdl::OperationType::get(builder.getContext())); |
| } |
| |
| namespace { |
| /// This is an helper only to call vectorize via a pattern inside of |
| /// VectorizeOp::applyToOne. |
| struct VectorizationPattern : public RewritePattern { |
| explicit VectorizationPattern(MLIRContext *context) |
| : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} |
| LogicalResult matchAndRewrite(Operation *op, |
| PatternRewriter &rewriter) const override { |
| LinalgOp linalgOp = dyn_cast<LinalgOp>(op); |
| if (!linalgOp) |
| return failure(); |
| return vectorize(rewriter, linalgOp); |
| } |
| }; |
| } // namespace |
| |
/// Vectorizes the Linalg ops nested under `target` by greedily applying
/// VectorizationPattern together with a set of cleanup/forwarding patterns.
/// `target` must be isolated from above. On success, `target` itself is
/// returned as the single result handle.
DiagnosedSilenceableFailure
transform::VectorizeOp::applyToOne(Operation *target,
                                   SmallVectorImpl<Operation *> &results,
                                   transform::TransformState &state) {
  // A non-isolated target is a definite (non-silenceable) failure. The op error
  // carries a note pointing at the offending target.
  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
    auto diag = this->emitOpError("requires isolated-from-above targets");
    diag.attachNote(target->getLoc()) << "non-isolated target";
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  MLIRContext *ctx = getContext();
  RewritePatternSet patterns(ctx);
  patterns.add<VectorizationPattern>(ctx);

  // Optional vector lowering/cleanup pattern groups, each enabled unless the
  // corresponding op attribute disables it.
  if (!getDisableTransferPermutationMapLoweringPatterns())
    vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);

  if (!getDisableMultiReductionToContractPatterns())
    vector::populateVectorReductionToContractPatterns(patterns);

  // Copy-forwarding patterns are given benefit 2, higher than the benefit-1
  // VectorizationPattern above, so the driver prefers them when both apply.
  patterns.add<linalg::LinalgCopyVTRForwardingPattern,
               linalg::LinalgCopyVTWForwardingPattern>(ctx,
                                                       /*benefit=*/2);
  vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
  vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);

  patterns.add<CopyVectorizationPattern>(ctx);

  // tensor.pad vectorization is opt-in via the `vectorize_padding` attribute.
  if (getVectorizePadding())
    linalg::populatePadOpVectorizationPatterns(patterns);

  // NOTE(review): the greedy driver reports failure e.g. when it does not
  // converge. That failure is surfaced here as an unknown transform error.
  if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));

  results.push_back(target);
  return DiagnosedSilenceableFailure(success());
}
| |
| //===----------------------------------------------------------------------===// |
| // Transform op registration |
| //===----------------------------------------------------------------------===// |
| |
namespace {
/// Transform dialect extension that registers the Linalg transform ops.
/// - Dependent dialects: PDL, because the ops use PDL types for their operands
///   and results, and Linalg.
/// - Generated dialects: dialects whose ops the transformations may create in
///   the payload IR (e.g. arith.constant for tile sizes).
class LinalgTransformDialectExtension
    : public transform::TransformDialectExtension<
          LinalgTransformDialectExtension> {
public:
  using Base::Base;

  void init() {
    declareDependentDialect<pdl::PDLDialect>();
    declareDependentDialect<LinalgDialect>();
    declareGeneratedDialect<AffineDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<scf::SCFDialect>();
    declareGeneratedDialect<vector::VectorDialect>();
    declareGeneratedDialect<gpu::GPUDialect>();

    // The template argument list is the ODS-generated list of all ops defined
    // in LinalgTransformOps.td.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
        >();
  }
};
} // namespace
| |
| #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc" |
| |
| #define GET_OP_CLASSES |
| #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" |
| |
| void mlir::linalg::registerTransformDialectExtension( |
| DialectRegistry ®istry) { |
| registry.addExtensions<LinalgTransformDialectExtension>(); |
| } |