| //===- TestTilingInterfaceTransformOps.cpp - Test `TilingInterface` ------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file defines transform dialect operations used for testing |
| // TilingInterface |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/Index/IR/IndexDialect.h" |
| #include "mlir/Dialect/SCF/IR/SCF.h" |
| #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/Dialect/Transform/IR/TransformAttrs.h" |
| #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
| #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
| #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
| #include "mlir/IR/Dominance.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "mlir/Interfaces/TilingInterface.h" |
| #include "llvm/Support/Debug.h" |
| |
| #define DEBUG_TYPE "test-tiling-interface" |
| |
| #define GET_OP_CLASSES |
| #include "TestTilingInterfaceTransformOps.h.inc" |
| |
| using namespace mlir; |
| using namespace mlir::transform; |
| |
| //===----------------------------------------------------------------------===// |
| // TestFuseAndYieldOp |
| //===----------------------------------------------------------------------===// |
| |
| static llvm::SmallDenseSet<Operation *> collectTiledAndFusedOps(Operation *op) { |
| SmallVector<Operation *> worklist; |
| llvm::SmallDenseSet<Operation *> producers; |
| worklist.push_back(op); |
| producers.insert(op); |
| while (!worklist.empty()) { |
| Operation *current = worklist.pop_back_val(); |
| for (OpOperand &operand : current->getOpOperands()) { |
| Operation *producer = operand.get().getDefiningOp(); |
| if (!producer || !isa<TilingInterface>(producer) || |
| producers.contains(producer)) |
| continue; |
| worklist.push_back(producer); |
| producers.insert(producer); |
| } |
| } |
| return producers; |
| } |
| |
/// Apply a tile and fuse transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
///
/// For every payload op (which must implement `TilingInterface`), this tiles
/// it with `tilingOptions` and greedily fuses its producers via
/// `scf::tileConsumerAndFuseProducersUsingSCF`. A fused producer gets a yielded
/// replacement from the loop nest when any of its users is properly dominated
/// by the payload op, i.e. the producer's value is still needed after the
/// tiled consumer.
///
/// Results reported on `transformOp`:
///   - result #0: the first op of `tiledAndFusedOps` for each payload op;
///   - result #(i+1), i < numLoops: the i-th generated loop for each payload op.
/// `numLoops` must match the number of loops produced for every payload op
/// (checked by an assert).
template <typename Range>
static LogicalResult
applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
                      Range &&payloadOps, unsigned numLoops,
                      scf::SCFTilingOptions tilingOptions,
                      TransformResults &transformResults) {
  SmallVector<Operation *> tiledOps;
  // loopOps[i] accumulates the i-th loop generated for each payload op.
  SmallVector<SmallVector<Operation *>> loopOps(numLoops);

  for (Operation *target : payloadOps) {
    auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
    if (!tilingInterfaceOp)
      return transformOp->emitError("only TilingInterface ops are supported");
    // Dominance is computed before any rewriting and is used both to decide
    // which producers must yield replacements and to restrict the later
    // use-replacement below.
    DominanceInfo dominanceInfo(tilingInterfaceOp);

    // Collect the payload op and its transitive TilingInterface producers, and
    // mark those that have at least one user properly dominated by the payload
    // op (a use "after" the consumer that will need a yielded value).
    llvm::SmallDenseSet<Operation *> tiledAndFusedOps =
        collectTiledAndFusedOps(tilingInterfaceOp);
    llvm::DenseSet<Operation *> yieldReplacementsFor;
    for (auto *op : tiledAndFusedOps) {
      if (llvm::any_of(op->getUsers(), [&](Operation *user) {
            return dominanceInfo.properlyDominates(tilingInterfaceOp, user);
          })) {
        yieldReplacementsFor.insert(op);
      }
    }

    scf::SCFTileAndFuseOptions tileAndFuseOptions;
    tileAndFuseOptions.setTilingOptions(tilingOptions);

    // Always fuse; request a yielded replacement only for producers marked
    // above.
    scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
        [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
            bool isDestinationOperand)
        -> std::optional<scf::SCFTileAndFuseOptions::ControlFnResult> {
      Operation *owner = originalProducer.getOwner();
      bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
      return scf::SCFTileAndFuseOptions::ControlFnResult{
          yieldProducerReplacement};
    };
    tileAndFuseOptions.setFusionControlFn(controlFn);

    rewriter.setInsertionPoint(target);
    FailureOr<scf::SCFTileAndFuseResult> tiledResults =
        scf::tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
                                                  tileAndFuseOptions);
    if (failed(tiledResults))
      return failure();

    // Perform the replacement of tiled and fused values.
    // Uses are only redirected when the replacement's defining op properly
    // dominates the user AND both live under the same parent op; any other use
    // keeps the original value. An original op is erased only once it has no
    // remaining uses.
    SmallVector<Operation *> opsToReplace{target};
    llvm::append_range(opsToReplace, tiledResults->fusedProducers);
    for (Operation *toReplace : opsToReplace) {
      for (OpResult res : toReplace->getResults())
        if (auto replacement = tiledResults->replacements.lookup(res)) {
          Operation *replacementOp = replacement.getDefiningOp();
          rewriter.replaceUsesWithIf(res, replacement, [&](OpOperand &use) {
            Operation *user = use.getOwner();
            return dominanceInfo.properlyDominates(replacementOp, user) &&
                   user->getParentOp() == replacementOp->getParentOp();
          });
        }

      if (toReplace->use_empty()) {
        rewriter.eraseOp(toReplace);
      }
    }

    // Report back the relevant handles to the transform op.
    tiledOps.push_back(tiledResults->tiledAndFusedOps.front());
    assert(tiledResults->loops.size() == numLoops &&
           "Mismatched number of loops, tile and fuse transform should have "
           "failed");
    for (unsigned int i = 0; i < numLoops; ++i)
      loopOps[i].push_back(tiledResults->loops[i]);
  }

  transformResults.set(transformOp->getOpResult(0), tiledOps);
  for (unsigned int i = 0; i < numLoops; ++i)
    transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);

  return success();
}
| |
| DiagnosedSilenceableFailure |
| transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter, |
| TransformResults &transformResults, |
| TransformState &state) { |
| SmallVector<int64_t> tileSizes = |
| extractFromIntegerArrayAttr<int64_t>(getTileSizes()); |
| SmallVector<int64_t> tileInterchange = |
| extractFromIntegerArrayAttr<int64_t>(getTileInterchange()); |
| |
| SmallVector<OpFoldResult> tileSizesOfr = |
| getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); |
| |
| scf::SCFTilingOptions tilingOptions; |
| tilingOptions.setTileSizes(tileSizesOfr).setInterchange(tileInterchange); |
| if (getUseForall()) { |
| tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); |
| } |
| |
| LogicalResult result = applyTileAndFuseToAll( |
| rewriter, getOperation(), state.getPayloadOps(getTarget()), |
| tileSizes.size() - llvm::count(tileSizes, 0), tilingOptions, |
| transformResults); |
| return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() |
| : DiagnosedSilenceableFailure::success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // TestFuseConsumerOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Fuse the consumer and store both the original consumer operation as well as |
| /// the fused consumer operation. |
| static LogicalResult |
| applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp, |
| Operation *consumer, |
| MutableArrayRef<LoopLikeOpInterface> loops, |
| TransformResults &transformResults) { |
| SmallVector<Operation *> fusedConsumerOps; |
| rewriter.setInsertionPoint(consumer); |
| |
| FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults = |
| scf::tileAndFuseConsumer(rewriter, consumer, loops); |
| if (failed(fuseConsumerResults)) |
| return consumer->emitOpError("failed to fuse consumer of slice"); |
| |
| // Report back the relevant handles to the transform op. |
| for (OpOperand *tiledAndFusedConsumerOperand : |
| fuseConsumerResults->tiledAndFusedConsumerOperands) { |
| fusedConsumerOps.push_back(tiledAndFusedConsumerOperand->getOwner()); |
| } |
| transformResults.set(transformOp->getOpResult(0), fusedConsumerOps); |
| for (auto [index, loop] : llvm::enumerate(loops)) { |
| transformResults.set(transformOp->getOpResult(index + 1), {loop}); |
| } |
| return success(); |
| } |
| |
| DiagnosedSilenceableFailure |
| transform::TestFuseConsumerOp::apply(TransformRewriter &rewriter, |
| TransformResults &transformResults, |
| TransformState &state) { |
| Operation *consumer = *state.getPayloadOps(getConsumer()).begin(); |
| |
| SmallVector<LoopLikeOpInterface> loops; |
| // Since the matcher works inside-out, we need to iterate the loops in |
| // reverse. |
| for (auto loop : llvm::reverse(getLoops())) { |
| auto loopLikeOp = |
| dyn_cast<LoopLikeOpInterface>(*state.getPayloadOps(loop).begin()); |
| if (!loopLikeOp) { |
| return DiagnosedSilenceableFailure::definiteFailure(); |
| } |
| loops.push_back(loopLikeOp); |
| } |
| LogicalResult result = applyFuseConsumer(rewriter, getOperation(), consumer, |
| loops, transformResults); |
| return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() |
| : DiagnosedSilenceableFailure::success(); |
| } |
| |
/// Declares the side effects of `TestFuseConsumerOp`: it consumes the consumer
/// and loop handles, produces all of its result handles, and modifies the
/// payload IR.
void transform::TestFuseConsumerOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getConsumerMutable(), effects);
  consumesHandle(getLoopsMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}
| |
| //===----------------------------------------------------------------------===// |
| // TestFuseConsumerUsingSliceOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Apply fusing of consumer transformation to all payload ops and store both |
| /// the original consumer operation as well as the fused consumer operation. |
| static LogicalResult applyFuseConsumerUsingSlices( |
| RewriterBase &rewriter, Operation *transformOp, |
| ArrayRef<Operation *> slices, MutableArrayRef<LoopLikeOpInterface> loops, |
| uint32_t numConsumerToFuse, TransformResults &transformResults) { |
| SmallVector<Operation *> originalConsumerOps; |
| SmallVector<Operation *> fusedConsumerOps; |
| |
| rewriter.setInsertionPoint(slices.front()); |
| |
| while (numConsumerToFuse--) { |
| FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults = |
| scf::tileAndFuseConsumerOfSlices(rewriter, slices, loops); |
| |
| if (failed(fuseConsumerResults)) |
| return slices.front()->emitOpError("failed to fuse consumer of slice"); |
| |
| // Report back the relevant handles to the transform op. |
| for (OpOperand *origConsumerOperand : |
| fuseConsumerResults->origConsumerOperands) { |
| originalConsumerOps.push_back(origConsumerOperand->getOwner()); |
| } |
| for (OpOperand *tiledAndFusedConsumerOperand : |
| fuseConsumerResults->tiledAndFusedConsumerOperands) { |
| fusedConsumerOps.push_back(tiledAndFusedConsumerOperand->getOwner()); |
| } |
| } |
| |
| transformResults.set(transformOp->getOpResult(0), originalConsumerOps); |
| transformResults.set(transformOp->getOpResult(1), fusedConsumerOps); |
| return success(); |
| } |
| |
| DiagnosedSilenceableFailure transform::TestFuseConsumerUsingSliceOp::apply( |
| TransformRewriter &rewriter, TransformResults &transformResults, |
| TransformState &state) { |
| SmallVector<Operation *> slices; |
| for (auto op : getTargets()) { |
| auto *sliceOp = *state.getPayloadOps(op).begin(); |
| slices.push_back(sliceOp); |
| } |
| |
| SmallVector<LoopLikeOpInterface> loops; |
| for (auto op : llvm::reverse(getLoops())) { |
| auto loopLikeOp = |
| dyn_cast<LoopLikeOpInterface>(*state.getPayloadOps(op).begin()); |
| if (!loopLikeOp) { |
| return DiagnosedSilenceableFailure::definiteFailure(); |
| } |
| loops.push_back(loopLikeOp); |
| } |
| LogicalResult result = |
| applyFuseConsumerUsingSlices(rewriter, getOperation(), slices, loops, |
| getNumConsumerToFuse(), transformResults); |
| return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() |
| : DiagnosedSilenceableFailure::success(); |
| } |
| |
/// Declares the side effects of `TestFuseConsumerUsingSliceOp`: it consumes
/// the slice (`targets`) and loop handles, produces all of its result handles,
/// and modifies the payload IR.
void transform::TestFuseConsumerUsingSliceOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTargetsMutable(), effects);
  consumesHandle(getLoopsMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}
| |
| //===----------------------------------------------------------------------===// |
| // TestTileUsingForallOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Apply a tiling transformation to all payload ops and store both the |
| /// tiled operation as well as the created tile loops. |
| template <typename Range> |
| static LogicalResult |
| applyTileToAll(RewriterBase &rewriter, Operation *transformOp, |
| Range &&payloadOps, ArrayRef<OpFoldResult> tileSizes, |
| ArrayRef<int64_t> interchange, std::optional<ArrayAttr> mapping, |
| TransformResults &transformResults) { |
| SmallVector<Operation *> tiledOps; |
| SmallVector<Operation *> loopOps; |
| |
| for (Operation *target : payloadOps) { |
| auto tilingInterfaceOp = dyn_cast<TilingInterface>(target); |
| if (!tilingInterfaceOp) |
| return transformOp->emitError("only TilingInterface ops are supported"); |
| scf::SCFTilingOptions tilingOptions; |
| tilingOptions.setTileSizes(tileSizes).setInterchange(interchange); |
| if (mapping) { |
| tilingOptions.setMapping(mapping.value().getValue()); |
| } |
| tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); |
| |
| rewriter.setInsertionPoint(target); |
| FailureOr<scf::SCFTilingResult> tiledResults = |
| scf::tileUsingSCF(rewriter, tilingInterfaceOp, tilingOptions); |
| if (failed(tiledResults)) |
| return failure(); |
| |
| // Perform the replacement of tiled and fused values. |
| rewriter.replaceOp(tilingInterfaceOp, tiledResults->replacements); |
| |
| // Report back the relevant handles to the transform op. |
| tiledOps.push_back(tiledResults->tiledOps.front()); |
| for (Operation *loop : tiledResults->loops) |
| loopOps.push_back(loop); |
| } |
| |
| transformResults.set(transformOp->getOpResult(0), tiledOps); |
| for (auto [index, loop] : llvm::enumerate(loopOps)) |
| transformResults.set(transformOp->getOpResult(index + 1), {loop}); |
| |
| return success(); |
| } |
| |
| DiagnosedSilenceableFailure |
| transform::TestTileUsingForallOp::apply(TransformRewriter &rewriter, |
| TransformResults &transformResults, |
| TransformState &state) { |
| SmallVector<int64_t> tileSizes = |
| extractFromIntegerArrayAttr<int64_t>(getTileSizes()); |
| SmallVector<int64_t> interchange = |
| extractFromIntegerArrayAttr<int64_t>(getInterchange()); |
| SmallVector<OpFoldResult> tileSizesOfr = |
| getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); |
| |
| LogicalResult result = |
| applyTileToAll(rewriter, getOperation(), state.getPayloadOps(getTarget()), |
| tileSizesOfr, interchange, getMapping(), transformResults); |
| return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() |
| : DiagnosedSilenceableFailure::success(); |
| } |
| |
/// Declares the side effects of `TestTileUsingForallOp`: it consumes the target
/// handle, produces all of its result handles, and modifies the payload IR.
void transform::TestTileUsingForallOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTargetMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}
| |
| //===----------------------------------------------------------------------===// |
| // TestFuseUsingForallOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Apply a tiling transformation to all payload ops and store both the |
| /// tiled operation as well as the created tile loops. |
| template <typename Range> |
| static LogicalResult applyTilingToAll( |
| RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, |
| unsigned numLoops, TransformResults &transformResults, |
| function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)> |
| applyFn) { |
| SmallVector<Operation *> tiledLinalgOps; |
| SmallVector<SmallVector<Operation *>> loopOps(1); |
| |
| for (Operation *target : payloadOps) { |
| auto tilingInterfaceOp = dyn_cast<TilingInterface>(target); |
| if (!tilingInterfaceOp) |
| return transformOp->emitError("only TilingInterface ops are supported"); |
| |
| rewriter.setInsertionPoint(target); |
| FailureOr<scf::SCFTileAndFuseResult> tiledResults = |
| applyFn(tilingInterfaceOp); |
| if (failed(tiledResults)) |
| return failure(); |
| |
| // Perform the replacement of tiled and fused values. |
| SmallVector<Operation *> opsToReplace{target}; |
| llvm::append_range(opsToReplace, tiledResults->fusedProducers); |
| for (Operation *toReplace : opsToReplace) { |
| for (OpResult res : toReplace->getResults()) |
| if (auto replacement = tiledResults->replacements.lookup(res)) |
| rewriter.replaceAllUsesWith(res, replacement); |
| if (toReplace->use_empty()) |
| rewriter.eraseOp(toReplace); |
| } |
| |
| // Report back the relevant handles to the transform op. |
| tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front()); |
| assert(tiledResults->loops.size() == 1 && |
| cast<scf::ForallOp>(tiledResults->loops[0]).getRank() == numLoops && |
| "Mismatched number of loops, tile and fuse transform should have " |
| "failed"); |
| loopOps[0] = {tiledResults->loops[0]}; |
| } |
| |
| transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); |
| if (!loopOps.empty()) |
| transformResults.set(transformOp->getOpResult(1), loopOps[0]); |
| |
| return success(); |
| } |
| |
| DiagnosedSilenceableFailure |
| transform::TestFuseUsingForallOp::apply(TransformRewriter &rewriter, |
| TransformResults &transformResults, |
| TransformState &state) { |
| SmallVector<int64_t> tileSizes = |
| extractFromIntegerArrayAttr<int64_t>(getTileSizes()); |
| SmallVector<int64_t> tileInterchange = |
| extractFromIntegerArrayAttr<int64_t>(getInterchange()); |
| |
| scf::SCFTilingOptions tilingOptions; |
| tilingOptions.interchangeVector = tileInterchange; |
| SmallVector<OpFoldResult> tileSizesOfr = |
| getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); |
| tilingOptions = tilingOptions.setTileSizes(tileSizesOfr); |
| tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); |
| scf::SCFTileAndFuseOptions tileAndFuseOptions; |
| tileAndFuseOptions.tilingOptions = tilingOptions; |
| LogicalResult result = applyTilingToAll( |
| rewriter, getOperation(), state.getPayloadOps(getRootOp()), |
| tileSizes.size() - llvm::count(tileSizes, 0), transformResults, |
| [&](TilingInterface tilingInterfaceOp) |
| -> FailureOr<scf::SCFTileAndFuseResult> { |
| return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp, |
| tileAndFuseOptions); |
| }); |
| return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() |
| : DiagnosedSilenceableFailure::success(); |
| } |
| |
/// Declares the side effects of `TestFuseUsingForallOp`: it consumes the root
/// op handle, produces all of its result handles, and modifies the payload IR.
void transform::TestFuseUsingForallOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getRootOpMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}
| |
| //===----------------------------------------------------------------------===// |
| // TestTileAndFuseOuterParallelPartialReduction |
| //===----------------------------------------------------------------------===// |
| |
| DiagnosedSilenceableFailure |
| transform::TestTileAndFuseOuterParallelPartialReductionOp::apply( |
| TransformRewriter &rewriter, TransformResults &transformResults, |
| TransformState &state) { |
| auto target = |
| dyn_cast<TilingInterface>(*state.getPayloadOps(getRootOp()).begin()); |
| if (!target) { |
| emitOpError("expected root operation to implement `TilingInterface`"); |
| return DiagnosedSilenceableFailure::definiteFailure(); |
| } |
| |
| SmallVector<unsigned> reductionDims = |
| extractFromIntegerArrayAttr<unsigned>(getReductionDims()); |
| if (reductionDims.empty()) { |
| for (auto [index, iterator] : |
| llvm::enumerate(target.getLoopIteratorTypes())) |
| if (iterator == utils::IteratorType::reduction) |
| reductionDims.push_back(index); |
| } |
| |
| if (reductionDims.empty()) { |
| emitOpError( |
| "no reduction dimension specified or found in the target operation"); |
| return DiagnosedSilenceableFailure::definiteFailure(); |
| } |
| |
| SmallVector<int64_t> reductionTileSizes = |
| extractFromIntegerArrayAttr<int64_t>(getTileSizes()); |
| if (reductionTileSizes.size() != reductionDims.size()) { |
| emitOpError( |
| "missing tile sizes for reduction dimensions that are to be tiled"); |
| return DiagnosedSilenceableFailure::definiteFailure(); |
| } |
| |
| // Adjust tile sizes so that it corresponds to the reduction iterator types. |
| SmallVector<OpFoldResult> tileSizes; |
| int reductionTileSizeNum = 0; |
| OpFoldResult zero = rewriter.getIndexAttr(0); |
| for (auto iterator : target.getLoopIteratorTypes()) { |
| if (iterator == utils::IteratorType::parallel) { |
| tileSizes.push_back(zero); |
| continue; |
| } |
| tileSizes.push_back( |
| rewriter.getIndexAttr(reductionTileSizes[reductionTileSizeNum++])); |
| } |
| |
| scf::SCFTilingOptions tilingOptions; |
| tilingOptions.setTileSizes(tileSizes) |
| .setLoopType(scf::SCFTilingOptions::LoopType::ForallOp) |
| .setReductionTilingStrategy( |
| ReductionTilingStrategy::PartialReductionOuterParallel) |
| .setReductionDims(reductionDims); |
| if (auto mapping = getMapping()) { |
| tilingOptions.setMapping(getMapping().value()); |
| } |
| |
| LogicalResult result = applyTileAndFuseToAll( |
| rewriter, getOperation(), state.getPayloadOps(getRootOp()), |
| /*numLoops =*/1, tilingOptions, transformResults); |
| |
| return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() |
| : DiagnosedSilenceableFailure::success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
// TestTileUsingCustomLoopOp
| //===----------------------------------------------------------------------===// |
| |
| DiagnosedSilenceableFailure transform::TestTileUsingCustomLoopOp::apply( |
| TransformRewriter &transformRewriter, TransformResults &transformResults, |
| TransformState &state) { |
| auto target = |
| dyn_cast<TilingInterface>(*state.getPayloadOps(getRootOp()).begin()); |
| if (!target) { |
| emitOpError("expected root operation to implement `TilingInterface`"); |
| return DiagnosedSilenceableFailure::definiteFailure(); |
| } |
| |
| OpFoldResult oneOfr = transformRewriter.getIndexAttr(1); |
| |
| scf::SCFTilingOptions::GenerateLoopHeaderFn loopHeaderFn = |
| [&](RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges, |
| ArrayRef<OpFoldResult> givenTileSizes, |
| ValueRange outerDestinationTensors) |
| -> FailureOr<scf::SCFTilingOptions::CustomLoopHeaderInfo> { |
| // Check that the strides are all 1 (to make it easier in the test). |
| if (llvm::any_of(loopRanges, [](Range r) { |
| return !isConstantIntValue(r.stride, 1); |
| })) { |
| return emitOpError("unable to handle loop ranges with strides != 1"); |
| } |
| // Check number of tile sizes is equal to loop dimensions. |
| if (loopRanges.size() != givenTileSizes.size()) { |
| return emitOpError("expected number of tile sizes to be same as the " |
| "number of loops in the operation"); |
| } |
| // For testing disallow any of the tile sizes being 0. |
| if (llvm::any_of(givenTileSizes, isZeroInteger)) { |
| return emitOpError("unhandled case of zero tile size"); |
| } |
| // For testing, only handle tensor tiling. |
| if (outerDestinationTensors.empty()) { |
| return emitOpError("expected destination tensors"); |
| } |
| |
| // Compute the number of iterations for each of the loops. |
| AffineExpr s0, s1, s2; |
| bindSymbols(rewriter.getContext(), s0, s1, s2); |
| AffineExpr numItersExpr = (s1 - s0).ceilDiv(s2); // (ub - lb) / tileSize |
| |
| SmallVector<OpFoldResult> allNumIters; |
| allNumIters.reserve(loopRanges.size()); |
| for (auto [loopRange, tileSize] : |
| llvm::zip_equal(loopRanges, givenTileSizes)) { |
| OpFoldResult numIters = affine::makeComposedFoldedAffineApply( |
| rewriter, loc, numItersExpr, |
| {loopRange.offset, loopRange.size, tileSize}); |
| allNumIters.push_back(numIters); |
| } |
| if (allNumIters.empty()) { |
| return emitOpError("invalid empty tile sizes and loop ranges"); |
| } |
| |
| AffineExpr mulExpr = s0 * s1; |
| OpFoldResult cumulative = oneOfr; |
| for (auto numIters : allNumIters) { |
| cumulative = affine::makeComposedFoldedAffineApply( |
| rewriter, loc, mulExpr, {cumulative, numIters}); |
| } |
| |
| Value zeroVal = arith::ConstantIndexOp::create(rewriter, loc, 0); |
| Value oneVal = arith::ConstantIndexOp::create(rewriter, loc, 1); |
| Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, cumulative); |
| |
| SmallVector<OpFoldResult> offsets; |
| SmallVector<OpFoldResult> sizes; |
| SmallVector<Value> innerDestinationTensors; |
| offsets.reserve(loopRanges.size()); |
| sizes.reserve(loopRanges.size()); |
| |
| AffineExpr d0; |
| bindDims(rewriter.getContext(), d0); |
| AffineExpr offsetExpr = s0 + d0 * s1; // lb + iv * tileSize |
| AffineMap minMap = |
| AffineMap::get(1, 2, {s0 - d0, s1}, |
| rewriter.getContext()); // min(ub - offset, tileSize) |
| auto forOp = scf::ForOp::create( |
| rewriter, loc, zeroVal, ub, oneVal, outerDestinationTensors, |
| [&](OpBuilder &b, Location bodyLoc, Value linearizedIv, |
| ValueRange destinations) { |
| auto delinearizeOp = affine::AffineDelinearizeIndexOp::create( |
| b, bodyLoc, linearizedIv, allNumIters); |
| for (auto [normalizedIv, range, tileSize] : llvm::zip_equal( |
| delinearizeOp.getResults(), loopRanges, givenTileSizes)) { |
| |
| OpFoldResult normalizedIvOfr = getAsOpFoldResult(normalizedIv); |
| OpFoldResult offset = affine::makeComposedFoldedAffineApply( |
| b, bodyLoc, offsetExpr, |
| {normalizedIvOfr, range.offset, tileSize}); |
| offsets.push_back(offset); |
| |
| OpFoldResult size = affine::makeComposedFoldedAffineMin( |
| b, bodyLoc, minMap, {offset, range.size, tileSize}); |
| sizes.push_back(size); |
| } |
| innerDestinationTensors = llvm::to_vector(destinations); |
| }); |
| rewriter.setInsertionPointToEnd(forOp.getBody()); |
| return scf::SCFTilingOptions::CustomLoopHeaderInfo{ |
| {cast<LoopLikeOpInterface>(forOp.getOperation())}, |
| offsets, |
| sizes, |
| innerDestinationTensors}; |
| }; |
| |
| scf::SCFTilingOptions::GenerateLoopTerminatorFn terminatorFn = |
| [&](RewriterBase &rewriter, Location loc, |
| ArrayRef<LoopLikeOpInterface> loops, ValueRange tiledResults, |
| ArrayRef<SmallVector<OpFoldResult>> resultOffsets, |
| ArrayRef<SmallVector<OpFoldResult>> resultSizes, |
| ValueRange destinationTensors) -> LogicalResult { |
| SmallVector<Value> yieldValues; |
| yieldValues.reserve(destinationTensors.size()); |
| for (auto [tiledResult, offsets, sizes, destination] : llvm::zip_equal( |
| tiledResults, resultOffsets, resultSizes, destinationTensors)) { |
| SmallVector<OpFoldResult> strides(offsets.size(), oneOfr); |
| Value insertedVal = tensor::InsertSliceOp::create( |
| rewriter, loc, tiledResult, destination, offsets, sizes, strides); |
| yieldValues.push_back(insertedVal); |
| } |
| scf::YieldOp::create(rewriter, loc, yieldValues); |
| return success(); |
| }; |
| |
| scf::SCFTilingOptions tilingOptions; |
| SmallVector<int64_t> staticTileSizes = |
| extractFromIntegerArrayAttr<int64_t>(getTileSizes()); |
| SmallVector<OpFoldResult> tileSizes = |
| getAsIndexOpFoldResult(transformRewriter.getContext(), staticTileSizes); |
| tilingOptions.setTileSizes(tileSizes) |
| .setLoopType(scf::SCFTilingOptions::LoopType::CustomOp) |
| .setCustomLoopGenerationFns(loopHeaderFn, terminatorFn); |
| |
| OpBuilder::InsertionGuard g(transformRewriter); |
| transformRewriter.setInsertionPoint(target); |
| FailureOr<scf::SCFTilingResult> tiledResults = |
| scf::tileUsingSCF(transformRewriter, target, tilingOptions); |
| if (failed(tiledResults)) { |
| return DiagnosedSilenceableFailure::definiteFailure(); |
| } |
| transformRewriter.replaceOp(target, tiledResults->replacements); |
| transformResults.set(getOperation()->getResult(0), tiledResults->tiledOps); |
| transformResults.set(getOperation()->getResult(1), tiledResults->loops); |
| |
| return DiagnosedSilenceableFailure::success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // TestQueryProducerFusability |
| //===----------------------------------------------------------------------===// |
| |
| DiagnosedSilenceableFailure transform::TestQueryProducerFusability::apply( |
| TransformRewriter &rewriter, TransformResults &transformResults, |
| TransformState &state) { |
| for (Operation *target : state.getPayloadOps(getTarget())) { |
| auto tilingInterfaceOp = dyn_cast<TilingInterface>(target); |
| if (!tilingInterfaceOp) { |
| return emitSilenceableError() |
| << "target operation does not implement TilingInterface"; |
| } |
| |
| // Collect operand numbers and their corresponding producer insert_slice |
| // offsets and sizes. |
| SmallVector<unsigned> operandNumbers; |
| SmallVector<SmallVector<OpFoldResult>> allOffsets; |
| SmallVector<SmallVector<OpFoldResult>> allSizes; |
| |
| for (OpOperand &operand : target->getOpOperands()) { |
| Value operandValue = operand.get(); |
| Operation *definingOp = operandValue.getDefiningOp(); |
| |
| // Look for a producer tensor.insert_slice. This is only for testing |
| // purposes and otherwise is not a useful transformation. |
| if (auto insertSliceOp = |
| dyn_cast_or_null<tensor::InsertSliceOp>(definingOp)) { |
| operandNumbers.push_back(operand.getOperandNumber()); |
| allOffsets.push_back(insertSliceOp.getMixedOffsets()); |
| allSizes.push_back(insertSliceOp.getMixedSizes()); |
| } |
| } |
| |
| if (!operandNumbers.empty()) { |
| bool isFusable = tilingInterfaceOp.isOpFusableWithProducerSlices( |
| operandNumbers, allOffsets, allSizes); |
| |
| if (isFusable) { |
| target->emitRemark() |
| << "can be fused with producer tensor.insert_slice ops"; |
| } else { |
| target->emitRemark() |
| << "cannot be fused with producer tensor.insert_slice ops"; |
| } |
| } |
| } |
| |
| return DiagnosedSilenceableFailure::success(); |
| } |
| |
/// Declares the side effects of `TestQueryProducerFusability`: it only reads
/// the target handle and the payload IR (it emits remarks but does not
/// rewrite anything).
void transform::TestQueryProducerFusability::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getTargetMutable(), effects);
  onlyReadsPayload(effects);
}
| |
| //===----------------------------------------------------------------------===// |
| // TestQueryConsumerFusability |
| //===----------------------------------------------------------------------===// |
| |
| DiagnosedSilenceableFailure transform::TestQueryConsumerFusability::apply( |
| TransformRewriter &rewriter, TransformResults &transformResults, |
| TransformState &state) { |
| for (Operation *target : state.getPayloadOps(getTarget())) { |
| auto tilingInterfaceOp = dyn_cast<TilingInterface>(target); |
| if (!tilingInterfaceOp) { |
| return emitSilenceableError() |
| << "target operation does not implement TilingInterface"; |
| } |
| |
| // Look for tensor.extract_slice ops that consume results of the tilable op. |
| for (OpResult result : target->getResults()) { |
| for (OpOperand &use : result.getUses()) { |
| Operation *user = use.getOwner(); |
| |
| // Look for a consumer tensor.extract_slice. This is only for testing |
| // purposes and otherwise is not a useful transformation. |
| if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user)) { |
| bool isFusable = tilingInterfaceOp.isOpFusableWithConsumerSlice( |
| result.getResultNumber(), extractSliceOp.getMixedOffsets(), |
| extractSliceOp.getMixedSizes()); |
| |
| if (isFusable) { |
| target->emitRemark() |
| << "can be fused with consumer tensor.extract_slice op"; |
| } else { |
| target->emitRemark() |
| << "cannot be fused with consumer tensor.extract_slice op"; |
| } |
| } |
| } |
| } |
| } |
| |
| return DiagnosedSilenceableFailure::success(); |
| } |
| |
/// Declares the side effects of `TestQueryConsumerFusability`: it only reads
/// the target handle and the payload IR (it emits remarks but does not
/// rewrite anything).
void transform::TestQueryConsumerFusability::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getTargetMutable(), effects);
  onlyReadsPayload(effects);
}
| |
| #define GET_OP_CLASSES |
| #include "TestTilingInterfaceTransformOps.cpp.inc" |
| |
namespace {
/// Transform dialect extension that registers the TilingInterface test ops
/// defined in this file.
class TestTilingInterfaceDialectExtension
    : public transform::TransformDialectExtension<
          TestTilingInterfaceDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
      TestTilingInterfaceDialectExtension)

  using Base::Base;

  /// Declares the dialects whose ops these transforms may create (affine,
  /// index, scf, tensor) and registers every op from the generated op list.
  void init() {
    declareDependentDialect<affine::AffineDialect>();
    declareDependentDialect<index::IndexDialect>();
    declareDependentDialect<scf::SCFDialect>();
    declareDependentDialect<tensor::TensorDialect>();

    // The generated op list expands to the comma-separated template
    // arguments of `registerTransformOps`.
    registerTransformOps<
#define GET_OP_LIST
#include "TestTilingInterfaceTransformOps.cpp.inc"
        >();
  }
};
} // namespace
| |
namespace test {
/// Adds the TilingInterface test transform-dialect extension to `registry`.
void registerTestTilingInterfaceTransformDialectExtension(
    DialectRegistry &registry) {
  registry.addExtensions<TestTilingInterfaceDialectExtension>();
}
} // namespace test