| //===- Detensorize.cpp - Linalg transformations as patterns ----------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "PassDetail.h" |
| #include "mlir/Dialect/Linalg/IR/LinalgOps.h" |
| #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" |
| #include "mlir/Dialect/Linalg/Passes.h" |
| #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/IR/OpDefinition.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| #include <iterator> |
| #include <memory> |
| |
| using namespace mlir; |
| using namespace mlir::linalg; |
| |
| static Value sourceMaterializationCallback(OpBuilder &builder, Type type, |
| ValueRange inputs, Location loc) { |
| assert(inputs.size() == 1); |
| if (inputs[0].getType().isa<TensorType>()) |
| return nullptr; |
| |
| // A detensored value is converted back by creating a new tensor from its |
| // element(s). |
| auto createNewTensorOp = builder.create<tensor::FromElementsOp>( |
| loc, inputs[0].getType(), inputs[0]); |
| |
| // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to |
| // a tensor<dtype> instead. |
| return builder.create<linalg::TensorCollapseShapeOp>( |
| loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{}); |
| } |
| |
| namespace { |
| /// Defines the criteria a TensorType must follow in order to be considered |
| /// "detensorable". |
| /// |
| /// NOTE: For now, only 0-D tensors are supported. |
| /// |
| /// Returns true if tensorType can be detensored. |
| bool canBeDetensored(TensorType tensorType) { |
| return tensorType.hasRank() && tensorType.getRank() == 0; |
| } |
| |
| bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) { |
| GenericOp genericOp = dyn_cast_or_null<GenericOp>(op); |
| return genericOp && |
| llvm::all_of( |
| genericOp.getInputAndOutputOperands(), [&](OpOperand *opOperand) { |
| return !typeConverter.isLegal(opOperand->get().getType()); |
| }); |
| } |
| |
| /// A conversion patttern for detensoring `linalg.generic` ops. |
| class DetensorizeGenericOp : public OpConversionPattern<GenericOp> { |
| public: |
| using OpConversionPattern::OpConversionPattern; |
| LogicalResult |
| matchAndRewrite(GenericOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Block *originalBlock = op->getBlock(); |
| |
| // Gather some information about the op before inling its region. |
| Block *opEntryBlock = &*op.region().begin(); |
| YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator()); |
| |
| // Split the op's region before the op. This way, we have a clear insertion |
| // point in which the op can be inlined. |
| Block *newBlock = originalBlock->splitBlock(op); |
| rewriter.inlineRegionBefore(op.region(), newBlock); |
| // Now that op's region is inlined, the operands of its YieldOp are mapped |
| // to the materialized target values. Therefore, we can replace the op's |
| // uses with those of its YielOp's operands. |
| rewriter.replaceOp(op, yieldOp->getOperands()); |
| |
| // No need for these intermediate blocks, merge them into 1. |
| rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands()); |
| rewriter.mergeBlocks(newBlock, originalBlock, {}); |
| |
| rewriter.eraseOp(&*Block::iterator(yieldOp)); |
| |
| return success(); |
| } |
| }; |
| |
| /// A conversion pattern for detensoring internal (non-entry) blocks within a |
| /// function. |
| struct FunctionNonEntryBlockConversion : public ConversionPattern { |
| FunctionNonEntryBlockConversion(StringRef functionLikeOpName, |
| MLIRContext *ctx, TypeConverter &converter, |
| DenseSet<BlockArgument> blockArgsToDetensor) |
| : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx), |
| blockArgsToDetensor(blockArgsToDetensor) {} |
| |
| LogicalResult |
| matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| rewriter.startRootUpdate(op); |
| Region ®ion = function_like_impl::getFunctionBody(op); |
| SmallVector<TypeConverter::SignatureConversion, 2> conversions; |
| |
| for (Block &block : llvm::drop_begin(region, 1)) { |
| conversions.emplace_back(block.getNumArguments()); |
| TypeConverter::SignatureConversion &back = conversions.back(); |
| |
| for (BlockArgument blockArgument : block.getArguments()) { |
| int idx = blockArgument.getArgNumber(); |
| |
| if (blockArgsToDetensor.count(blockArgument)) |
| back.addInputs(idx, {getTypeConverter()->convertType( |
| block.getArgumentTypes()[idx])}); |
| else |
| back.addInputs(idx, {block.getArgumentTypes()[idx]}); |
| } |
| } |
| |
| if (failed(rewriter.convertNonEntryRegionTypes(®ion, *typeConverter, |
| conversions))) { |
| rewriter.cancelRootUpdate(op); |
| return failure(); |
| } |
| |
| rewriter.finalizeRootUpdate(op); |
| return success(); |
| } |
| |
| private: |
| const DenseSet<BlockArgument> blockArgsToDetensor; |
| }; |
| |
| class DetensorizeTypeConverter : public TypeConverter { |
| public: |
| DetensorizeTypeConverter() { |
| addConversion([](Type type) { return type; }); |
| |
| // A TensorType that can be detensored, is converted to the underlying |
| // element type. |
| addConversion([](TensorType tensorType) -> Type { |
| if (canBeDetensored(tensorType)) |
| return tensorType.getElementType(); |
| |
| return tensorType; |
| }); |
| |
| // A tensor value is detensoried by extracting its element(s). |
| addTargetMaterialization([](OpBuilder &builder, Type type, |
| ValueRange inputs, Location loc) -> Value { |
| return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{}); |
| }); |
| |
| addSourceMaterialization(sourceMaterializationCallback); |
| addArgumentMaterialization(sourceMaterializationCallback); |
| } |
| }; |
| |
| /// Canonicalizes the pattern of the form |
| /// |
| /// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32> |
| /// %reshaped_tensor = linalg.tensor_collapse_shape %tensor [] |
| /// : tensor<1xi32> into tensor<i32> |
| /// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32> |
| /// |
| /// to just %element. |
| struct ExtractFromReshapeFromElements |
| : public OpRewritePattern<tensor::ExtractOp> { |
| using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(tensor::ExtractOp extract, |
| PatternRewriter &rewriter) const final { |
| if (!extract.indices().empty()) |
| return failure(); |
| |
| auto tensorReshape = |
| extract.tensor().getDefiningOp<TensorCollapseShapeOp>(); |
| if (tensorReshape == nullptr) |
| return failure(); |
| |
| auto tensorFromElements = |
| tensorReshape.getOperand() |
| .getDefiningOp<mlir::tensor::FromElementsOp>(); |
| if (tensorFromElements == nullptr) |
| return failure(); |
| |
| rewriter.replaceOp(extract, tensorFromElements.getOperand(0)); |
| return success(); |
| } |
| }; |
| |
| /// @see LinalgDetensorize in Linalg/Passes.td for more details. |
| struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> { |
| LinalgDetensorize() = default; |
| LinalgDetensorize(const LinalgDetensorize &pass) |
| : LinalgDetensorizeBase<LinalgDetensorize>() {} |
| |
| class CostModel { |
| public: |
| virtual ~CostModel() = default; |
| |
| /// A cost model algorithm computes the following outputs: |
| /// |
| /// - opsToDetensor: the list of linalg ops that should be |
| /// detensored. |
| /// |
| /// - blockArgsToDetensor: since the operands and results of detensored |
| /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come |
| /// from a BB argument and a linalg op's output can be passed to successor |
| /// BBs), we need to maintain the sub-set of arguments that should be |
| /// detensored (i.e. converted by typeConverter) for each affected BB. |
| /// |
| /// Example: |
| /// |
| /// For the following snippet: |
| /// ... |
| /// ^bb1(%6: tensor<i32>, %9: tensor<i32>): |
| /// %7 = linalg.init_tensor [] : tensor<i32> |
| /// %8 = linalg.generic #attrs |
| /// ins(%6, %6 : tensor<i32>, tensor<i32>) |
| /// outs(%7 : tensor<i32>) { |
| /// ^bb0(%arg0: i32, %arg1: i32, %arg2: i32): |
| /// %9 = arith.addi %arg0, %arg1 : i32 |
| /// linalg.yield %9 : i32 |
| /// } -> tensor<i32> |
| /// %10 = "some.op"(%9) |
| /// br ^bb2(%8 : tensor<i32>) |
| /// ... |
| /// |
| /// if the cost model decides that the linalg.generic op should be |
| /// detensored, then: |
| /// - opsToDetensor should be = {linalg.generic{add}}. |
| /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}. |
| virtual void compute(FuncOp func, DetensorizeTypeConverter typeConverter, |
| DenseSet<Operation *> &opsToDetensor, |
| DenseSet<BlockArgument> &blockArgsToDetensor) = 0; |
| |
| /// From the blockArgsToDetensor set computed by a CostModel |
| /// implementation, this method computes the corresponding branch op |
| /// detensoring. The result is a map from a branch op to a subset of indices |
| /// of its operands. The indices specify which of the branch op's operands |
| /// should be detensored. |
| /// |
| /// For the previous example, this method would compute: {bb2 -> {0}}. |
| static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring( |
| const DenseSet<BlockArgument> &blockArgsToDetensor) { |
| DenseMap<Operation *, DenseSet<int>> detensorableBranchOps; |
| |
| for (auto blockArgumentElem : blockArgsToDetensor) { |
| Block *block = blockArgumentElem.getOwner(); |
| |
| for (PredecessorIterator pred = block->pred_begin(); |
| pred != block->pred_end(); ++pred) { |
| BranchOpInterface terminator = |
| dyn_cast<BranchOpInterface>((*pred)->getTerminator()); |
| auto blockOperands = |
| terminator.getSuccessorOperands(pred.getSuccessorIndex()); |
| |
| if (!blockOperands || blockOperands->empty()) |
| continue; |
| |
| detensorableBranchOps[terminator].insert( |
| blockOperands->getBeginOperandIndex() + |
| blockArgumentElem.getArgNumber()); |
| } |
| } |
| |
| return detensorableBranchOps; |
| } |
| }; |
| |
| /// Detensorize linalg ops involved in control-flow within a function. |
| /// |
| /// This model starts from BranchOps and CondBranchOps within a function. For |
| /// each such branch, the model then walks the use-def chain for the branch's |
| /// condition backwards in order to understand where the condition's value |
| /// comes from. If the condition value is (indirectly) computed by a linalg op |
| /// that can be detensored, the model then continues walking the use-def chain |
| /// in order to understand where the linalg op's operands come from. This |
| /// leads to discovering a "detensoring component". A detensoring component is |
| /// the set of operations + block arguments that are involved in control-flow |
| /// AND can be detensored. |
| class ControlFlowDetectionModel : public CostModel { |
| public: |
| void compute(FuncOp func, DetensorizeTypeConverter typeConverter, |
| DenseSet<Operation *> &opsToDetensor, |
| DenseSet<BlockArgument> &blockArgsToDetensor) override { |
| SmallVector<Value> workList; |
| |
| func.walk([&](CondBranchOp condBr) { |
| for (auto operand : condBr.getOperands()) { |
| workList.push_back(operand); |
| } |
| }); |
| |
| func.walk([&](BranchOp br) { |
| for (auto operand : br.getOperands()) { |
| workList.push_back(operand); |
| } |
| }); |
| |
| DenseSet<Value> visitedValues; |
| DenseSet<Operation *> visitedOps; |
| |
| // For a (to-be-detesored) value, check if it "escapes" the block by being |
| // passed to terminator. If it does, then workList is updated with the |
| // corresponding argument to the successor block. |
| auto updateWorkListWithSuccessorArguments = |
| [&](Value value, BranchOpInterface terminator) { |
| if (!terminator) |
| return; |
| |
| for (auto operandIdx : |
| llvm::seq<unsigned>(0, terminator->getOperands().size())) { |
| Value operand = terminator->getOperand(operandIdx); |
| |
| if (operand == value) { |
| auto succBlockArg = |
| terminator.getSuccessorBlockArgument(operandIdx); |
| |
| if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg)) |
| workList.push_back(*succBlockArg); |
| } |
| } |
| }; |
| |
| while (!workList.empty()) { |
| Value currentItem = workList.pop_back_val(); |
| |
| if (!visitedValues.insert(currentItem).second) |
| continue; |
| |
| // 1 - Look forward: |
| // 1.1 - If currentItem escapes to one or more successors, add |
| // the corresponding successor arguments to workList. |
| updateWorkListWithSuccessorArguments( |
| currentItem, dyn_cast<BranchOpInterface>( |
| currentItem.getParentBlock()->getTerminator())); |
| |
| // 1.2 - For each user of currentItem, add the defined values to |
| // workList. This way, the user ops can be inspected later if they are |
| // detensorable and if so, their operands will be added to workList to |
| // potentially discover other parts of the detensorable component. |
| for (auto *user : currentItem.getUsers()) |
| for (Value result : user->getResults()) |
| workList.push_back(result); |
| |
| // 2 - Look backward: |
| // 2.1 - The current item is defined by a block argument. If the owner |
| // block is a non-entry one, then: |
| // * Add the argument to blockArgsToDetensor. |
| // * Walk the use-def chain backwards to add each predecessor's |
| // terminator-operands corresponding to currentItem to workList. |
| if (currentItem.dyn_cast<BlockArgument>()) { |
| BlockArgument currentItemBlockArgument = |
| currentItem.cast<BlockArgument>(); |
| Block *ownerBlock = currentItemBlockArgument.getOwner(); |
| |
| // Function arguments are not detensored/converted. |
| if (&*ownerBlock->getParent()->begin() == ownerBlock) |
| continue; |
| |
| // This inner-block argument is involved in control-flow, it should be |
| // detensored. |
| blockArgsToDetensor.insert(currentItemBlockArgument); |
| |
| for (PredecessorIterator pred = ownerBlock->pred_begin(); |
| pred != ownerBlock->pred_end(); ++pred) { |
| BranchOpInterface predTerminator = |
| dyn_cast<BranchOpInterface>((*pred)->getTerminator()); |
| |
| // TODO: For now, we give up if any of the control-flow components |
| // in a function is not detensorable. Fix that. |
| if (!predTerminator) { |
| opsToDetensor.clear(); |
| blockArgsToDetensor.clear(); |
| return; |
| } |
| |
| auto ownerBlockOperands = |
| predTerminator.getSuccessorOperands(pred.getSuccessorIndex()); |
| |
| if (!ownerBlockOperands || ownerBlockOperands->empty()) |
| continue; |
| |
| // For each predecessor, add the value it passes to that argument to |
| // workList to find out how it's computed. |
| workList.push_back( |
| ownerBlockOperands |
| .getValue()[currentItemBlockArgument.getArgNumber()]); |
| } |
| |
| continue; |
| } |
| |
| Operation *currentItemDefiningOp = currentItem.getDefiningOp(); |
| |
| if (!visitedOps.insert(currentItemDefiningOp).second) |
| continue; |
| |
| // 2.2 - The current item is computed by a GenericOp. If the op should |
| // be detensored, then: |
| // * Add it to opsToDetensor. |
| // * Add its operands to workList to discover other parts of the |
| // potentially detensorable component. |
| if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) { |
| // The op was encountered already, no need to inspect it again. |
| if (opsToDetensor.count(genericOp)) |
| continue; |
| |
| // The op should not be detensored, give up on it but continue with |
| // discovering the rest of the control-flow component. |
| if (!shouldBeDetensored(genericOp, typeConverter)) { |
| continue; |
| } |
| |
| opsToDetensor.insert(genericOp); |
| |
| for (Value genericOpOperand : genericOp.inputs()) |
| workList.push_back(genericOpOperand); |
| |
| continue; |
| } |
| |
| // 2.3 - The current item is the result of a FromElementsOp, it will be |
| // trivially detensored later as part of canonicalization patterns |
| // applied at the end of detensoring. |
| // |
| // Note: No need to check whether the result type of this op is |
| // detensorable since if it wasn't we wouldn't reach that point in the |
| // work list. |
| if (dyn_cast<tensor::FromElementsOp>(currentItemDefiningOp)) |
| continue; |
| |
| // 2.4 - The current item is the result of a scalar op, add all its |
| // operands to the work list. |
| if (llvm::all_of( |
| currentItemDefiningOp->getResultTypes(), |
| [&](Type resultType) { return resultType.isIntOrFloat(); })) |
| for (Value scalarOpOperand : currentItemDefiningOp->getOperands()) |
| workList.push_back(scalarOpOperand); |
| } |
| |
| // Since the cost model gives up on some ops (see the details of step 2.2 |
| // above), block arguments that correspond to the values produced by those |
| // ops should not be detensored as well. |
| |
| DenseSet<BlockArgument> blockArgsToRemove; |
| |
| for (auto &blockArg : blockArgsToDetensor) { |
| Block *block = blockArg.getParentBlock(); |
| |
| // For the potentially detensorable block argument, find the |
| // correpsonding operands in predecessor blocks. |
| for (PredecessorIterator pred = block->pred_begin(); |
| pred != block->pred_end(); ++pred) { |
| BranchOpInterface terminator = |
| dyn_cast<BranchOpInterface>((*pred)->getTerminator()); |
| auto blockOperands = |
| terminator.getSuccessorOperands(pred.getSuccessorIndex()); |
| |
| if (!blockOperands || blockOperands->empty()) |
| continue; |
| |
| Operation *definingOp = |
| terminator |
| ->getOperand(blockOperands->getBeginOperandIndex() + |
| blockArg.getArgNumber()) |
| .getDefiningOp(); |
| |
| // If the operand is defined by a GenericOp that will not be |
| // detensored, then do not detensor the corresponding block argument. |
| if (dyn_cast_or_null<GenericOp>(definingOp) && |
| opsToDetensor.count(definingOp) == 0) { |
| blockArgsToRemove.insert(blockArg); |
| break; |
| } |
| } |
| } |
| |
| for (auto &blockArg : blockArgsToRemove) { |
| blockArgsToDetensor.erase(blockArg); |
| } |
| } |
| }; |
| |
| /// Detensorize everything that can detensored. |
| class AggressiveDetensoringModel : public CostModel { |
| public: |
| void compute(FuncOp func, DetensorizeTypeConverter typeConverter, |
| DenseSet<Operation *> &opsToDetensor, |
| DenseSet<BlockArgument> &blockArgsToDetensor) override { |
| func.walk([&](GenericOp genericOp) { |
| if (shouldBeDetensored(genericOp, typeConverter)) |
| opsToDetensor.insert(genericOp); |
| }); |
| |
| for (Block &block : llvm::drop_begin(func.getBody(), 1)) |
| for (BlockArgument blockArgument : block.getArguments()) |
| blockArgsToDetensor.insert(blockArgument); |
| } |
| }; |
| |
| void runOnFunction() override { |
| MLIRContext *context = &getContext(); |
| DetensorizeTypeConverter typeConverter; |
| RewritePatternSet patterns(context); |
| ConversionTarget target(*context); |
| DenseSet<Operation *> opsToDetensor; |
| DenseMap<Operation *, DenseSet<int>> detensorableBranchOps; |
| DenseSet<BlockArgument> blockArgsToDetensor; |
| |
| if (aggressiveMode.getValue()) { |
| AggressiveDetensoringModel costModel; |
| costModel.compute(getFunction(), typeConverter, opsToDetensor, |
| blockArgsToDetensor); |
| |
| } else { |
| ControlFlowDetectionModel costModel; |
| costModel.compute(getFunction(), typeConverter, opsToDetensor, |
| blockArgsToDetensor); |
| } |
| |
| detensorableBranchOps = |
| CostModel::computeBranchOpDetensoring(blockArgsToDetensor); |
| |
| target.addDynamicallyLegalOp<GenericOp>( |
| [&](GenericOp op) { return !opsToDetensor.count(op); }); |
| |
| target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { |
| // A function is legal if all of its non-entry blocks are legal. We |
| // don't legalize the entry block (i.e. the function's signature) |
| // since detensoring can't happen along external calling convention |
| // boundaries, which we conservatively approximate as all function |
| // signatures. |
| return llvm::all_of(llvm::drop_begin(op.getBody(), 1), [&](Block &block) { |
| if (llvm::any_of(blockArgsToDetensor, [&](BlockArgument blockArgument) { |
| return blockArgument.getOwner() == &block && |
| !typeConverter.isLegal(blockArgument.getType()); |
| })) { |
| return false; |
| } |
| return true; |
| }); |
| }); |
| |
| target.markUnknownOpDynamicallyLegal([&](Operation *op) { |
| if (isNotBranchOpInterfaceOrReturnLikeOp(op) || |
| isLegalForReturnOpTypeConversionPattern(op, typeConverter, |
| /*returnOpAlwaysLegal*/ true)) |
| return true; |
| |
| if (auto branchOp = dyn_cast<BranchOpInterface>(op)) { |
| if (!detensorableBranchOps.count(branchOp)) |
| return true; |
| |
| for (auto operandIdx : detensorableBranchOps[branchOp]) |
| if (!typeConverter.isLegal( |
| branchOp->getOperand(operandIdx).getType())) |
| return false; |
| |
| return true; |
| } |
| |
| return false; |
| }); |
| |
| patterns.insert<DetensorizeGenericOp>(typeConverter, context); |
| patterns.insert<FunctionNonEntryBlockConversion>(FuncOp::getOperationName(), |
| context, typeConverter, |
| blockArgsToDetensor); |
| // Since non-entry block arguments get detensorized, we also need to |
| // update the control flow inside the function to reflect the correct |
| // types. |
| auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp, |
| int operandIdx) -> bool { |
| return detensorableBranchOps.count(branchOp) && |
| detensorableBranchOps[branchOp].count(operandIdx); |
| }; |
| |
| populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter, |
| shouldConvertBranchOperand); |
| |
| if (failed(applyFullConversion(getFunction(), target, std::move(patterns)))) |
| signalPassFailure(); |
| |
| RewritePatternSet canonPatterns(context); |
| canonPatterns.add<ExtractFromReshapeFromElements>(context); |
| if (failed(applyPatternsAndFoldGreedily(getFunction(), |
| std::move(canonPatterns)))) |
| signalPassFailure(); |
| } |
| |
| Option<bool> aggressiveMode{ |
| *this, "aggressive-mode", |
| llvm::cl::desc("Detensorize all ops that qualify for detensoring along " |
| "with branch operands and basic-block arguments.")}; |
| }; |
| } // namespace |
| |
| std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() { |
| return std::make_unique<LinalgDetensorize>(); |
| } |