| //===- RewriteAsConstant.cpp - Patterns to rewrite tensor ops as constants ===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/Dialect/Tensor/Transforms/Transforms.h" |
| #include "mlir/Dialect/Utils/IndexingUtils.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/PatternMatch.h" |
| |
| #include "llvm/ADT/TypeSwitch.h" |
| |
| using namespace mlir; |
| using namespace mlir::tensor; |
| |
| namespace { |
| |
| /// Rewrite tensor.generate with arith.constant if the yielded value is a |
| /// constant and the tensor type is static. |
| struct GenerateToConstant : public OpRewritePattern<GenerateOp> { |
| using OpRewritePattern<GenerateOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(GenerateOp generateOp, |
| PatternRewriter &rewriter) const override { |
| auto tensorType = |
| llvm::cast<RankedTensorType>(generateOp.getResult().getType()); |
| if (!tensorType.hasStaticShape()) |
| return failure(); |
| auto terminatorOp = |
| cast<tensor::YieldOp>(generateOp.getBody().front().getTerminator()); |
| Attribute attr; |
| if (!matchPattern(terminatorOp.getValue(), m_Constant(&attr))) |
| return failure(); |
| Operation *constantOp = |
| rewriter.getContext() |
| ->getLoadedDialect<TensorDialect>() |
| ->materializeConstant(rewriter, |
| DenseElementsAttr::get(tensorType, attr), |
| tensorType, generateOp->getLoc()); |
| if (!constantOp) |
| return failure(); |
| rewriter.replaceOp(generateOp, constantOp->getResults()); |
| return success(); |
| } |
| }; |
| |
| /// Transform a linear index from one indexing space to another given: |
| /// |
| /// - the shape of the source indexing space, |
| /// - the strides of the target indexing space, |
| /// - a linear index into the source indexing space. |
| /// |
| /// This function is logically a sequence of linearize/delinearize over |
| /// different bases but avoids allocating intermediate SmallVectors. |
| int64_t transformIndexSpace(ArrayRef<int64_t> inputShape, |
| ArrayRef<int64_t> outputStrides, |
| int64_t srcLinearIndex) { |
| assert(inputShape.size() == outputStrides.size()); |
| |
| int64_t dstLinearIndex = 0; |
| |
| for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) { |
| // Compute the index into the current dimension of the source tensor. |
| // `quotient` is the remaining linear index after accounting for the |
| // current dimension. |
| // |
| // `remainder` is the index into the source tensor for the current |
| // dimension. |
| auto [quotient, remainder] = std::div(srcLinearIndex, inputShape[dim]); |
| |
| srcLinearIndex = quotient; |
| |
| // Add the contribution of the current dimension to the output using the |
| // permutation map. |
| dstLinearIndex += outputStrides[dim] * remainder; |
| } |
| |
| return dstLinearIndex; |
| } |
| |
| template <typename ElemType, typename AttrType> |
| Value constantFoldPadOp(PatternRewriter &rewriter, Location loc, |
| DenseElementsAttr input, AttrType padValue, |
| ArrayRef<int64_t> padLow, ArrayRef<int64_t> padHigh) { |
| auto inputValues = input.tryGetValues<ElemType>(); |
| if (failed(inputValues)) |
| return nullptr; |
| |
| auto oldShape = input.getType().getShape(); |
| |
| // Compute the output shape of the new value. |
| auto newShape = |
| llvm::map_to_vector(llvm::zip(oldShape, padLow, padHigh), |
| [](std::tuple<int64_t, int64_t, int64_t> pack) { |
| auto [old, low, high] = pack; |
| return old + low + high; |
| }); |
| |
| int64_t outputSize = computeProduct(newShape); |
| |
| // Fully initialize the vector with the padding value. |
| // The non-padded area will then be copied. |
| SmallVector<ElemType> values(outputSize, padValue.getValue()); |
| |
| // Strides for input and output are used to transform between the indexing |
| // space of the input and output tensors. |
| SmallVector<int64_t> outputStrides = computeStrides(newShape); |
| |
| // The contribution of the low padding to the offset in the output tensor. |
| // This is the starting position of the source tensor within the padding |
| // tensor. |
| int64_t startingOffset = linearize(padLow, outputStrides); |
| |
| // Copy values from the input tensor to the corresponding sub-region |
| // of the output tensor. |
| for (auto [inputIndex, inputValue] : llvm::enumerate(*inputValues)) { |
| auto outputIndex = transformIndexSpace(oldShape, outputStrides, inputIndex); |
| values[outputIndex + startingOffset] = inputValue; |
| } |
| |
| // Create an attribute for the folded value. |
| auto newType = input.getType().clone(newShape); |
| auto newAttr = DenseElementsAttr::get(newType, values); |
| |
| Operation *constantOp = |
| rewriter.getContext() |
| ->getLoadedDialect<TensorDialect>() |
| ->materializeConstant(rewriter, newAttr, newType, loc); |
| |
| return constantOp ? constantOp->getResult(0) : nullptr; |
| } |
| |
| struct PadOpToConstant final : public OpRewritePattern<PadOp> { |
| |
| PadOpToConstant(MLIRContext *context, const ControlFoldFn &controlFn, |
| PatternBenefit benefit = 1) |
| : OpRewritePattern<PadOp>(context, benefit), controlFn{controlFn} {} |
| |
| LogicalResult matchAndRewrite(PadOp padTensorOp, |
| PatternRewriter &rewriter) const override { |
| if (padTensorOp.getNofold()) |
| return rewriter.notifyMatchFailure( |
| padTensorOp, "refusing to fold nofold pad operation"); |
| |
| TypedValue<RankedTensorType> input = padTensorOp.getSource(); |
| RankedTensorType resultType = padTensorOp.getResult().getType(); |
| |
| DenseElementsAttr inputAttr = nullptr; |
| if (!matchPattern(input, m_Constant(&inputAttr))) |
| return failure(); |
| |
| Value paddingValue = padTensorOp.getConstantPaddingValue(); |
| |
| // Extract the constant value used for padding or bail out. |
| Attribute paddingAttr = nullptr; |
| if (!paddingValue || !matchPattern(paddingValue, m_Constant(&paddingAttr))) |
| return rewriter.notifyMatchFailure(padTensorOp, |
| "unable to get constant value"); |
| |
| // Try to extract the constant values of the low and high padding. |
| auto lowPad = getConstantIntValues(padTensorOp.getMixedLowPad()); |
| auto highPad = getConstantIntValues(padTensorOp.getMixedHighPad()); |
| |
| // If the padding cannot be extracted, bail out. |
| if (!lowPad || !highPad) |
| return rewriter.notifyMatchFailure(padTensorOp, |
| "unable to extract constant padding"); |
| |
| // We have a potential candidate, consult the control function to |
| // determine if the op should fold. |
| if (!controlFn(&padTensorOp.getSourceMutable())) |
| return rewriter.notifyMatchFailure(padTensorOp, |
| "not folding due to cost function"); |
| |
| Location loc = padTensorOp.getLoc(); |
| |
| // Try constant folding the supported cases of integer and float values. |
| Value newOp = |
| llvm::TypeSwitch<Attribute, Value>(paddingAttr) |
| .Case([&](FloatAttr floatAttr) { |
| return constantFoldPadOp<llvm::APFloat>( |
| rewriter, loc, inputAttr, floatAttr, *lowPad, *highPad); |
| }) |
| .Case([&](IntegerAttr integerAttr) { |
| return constantFoldPadOp<llvm::APInt>( |
| rewriter, loc, inputAttr, integerAttr, *lowPad, *highPad); |
| }) |
| .Default(Value()); |
| |
| if (!newOp) |
| return rewriter.notifyMatchFailure(padTensorOp, |
| "tensor type not supported"); |
| |
| if (newOp.getType() != resultType) |
| newOp = tensor::CastOp::create(rewriter, loc, resultType, newOp); |
| |
| rewriter.replaceOp(padTensorOp, newOp); |
| return success(); |
| } |
| |
| private: |
| ControlFoldFn controlFn; |
| }; |
| |
| } // namespace |
| |
| void mlir::tensor::populateRewriteAsConstantPatterns( |
| RewritePatternSet &patterns, const ControlFoldFn &controlFn) { |
| patterns.add<GenerateToConstant>(patterns.getContext()); |
| |
| patterns.add<PadOpToConstant>(patterns.getContext(), controlFn); |
| } |