| //===- AffineExpandIndexOps.cpp - Affine expand index ops pass ------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements a pass to expand affine index ops into one or more more |
| // fundamental operations. |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Affine/LoopUtils.h" |
| #include "mlir/Dialect/Affine/Passes.h" |
| |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/Affine/Transforms/Transforms.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| |
| namespace mlir { |
| namespace affine { |
| #define GEN_PASS_DEF_AFFINEEXPANDINDEXOPS |
| #include "mlir/Dialect/Affine/Passes.h.inc" |
| } // namespace affine |
| } // namespace mlir |
| |
| using namespace mlir; |
| using namespace mlir::affine; |
| |
/// Given a basis (in static and dynamic components), return the sequence of
/// suffix products of the basis, including the product of the entire basis,
/// which must **not** contain an outer bound.
///
/// If excess dynamic values are provided, the values at the beginning
/// will be ignored. This allows for dropping the outer bound without
/// needing to manipulate the dynamic value array. `knownNonNegative`
/// indicates that the values being used to compute the strides are known
/// to be non-negative.
static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
                                         ValueRange dynamicBasis,
                                         ArrayRef<int64_t> staticBasis,
                                         bool knownNonNegative) {
  if (staticBasis.empty())
    return {};

  SmallVector<Value> result;
  result.reserve(staticBasis.size());
  // Dynamic basis entries are consumed back-to-front; excess entries at the
  // front (an outer bound, per the function comment) are simply never read.
  size_t dynamicIndex = dynamicBasis.size();
  Value dynamicPart = nullptr;
  int64_t staticPart = 1;
  // The products of the strides can't have overflow by definition of
  // affine.*_index.
  arith::IntegerOverflowFlags ovflags = arith::IntegerOverflowFlags::nsw;
  if (knownNonNegative)
    ovflags = ovflags | arith::IntegerOverflowFlags::nuw;
  // Walk the basis innermost-first, maintaining the running suffix product as
  // a dynamic SSA factor times a static constant factor.
  for (int64_t elem : llvm::reverse(staticBasis)) {
    if (ShapedType::isDynamic(elem)) {
      // Fold this dynamic basis element into the dynamic factor. `nuw` is
      // only set when the caller guarantees non-negative values (see
      // `knownNonNegative` above).
      if (dynamicPart)
        dynamicPart =
            arith::MulIOp::create(rewriter, loc, dynamicPart,
                                  dynamicBasis[dynamicIndex - 1], ovflags);
      else
        dynamicPart = dynamicBasis[dynamicIndex - 1];
      --dynamicIndex;
    } else {
      staticPart *= elem;
    }

    if (dynamicPart && staticPart == 1) {
      // Purely dynamic stride: avoid emitting a multiplication by 1.
      result.push_back(dynamicPart);
    } else {
      Value stride =
          rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart);
      if (dynamicPart)
        stride =
            arith::MulIOp::create(rewriter, loc, dynamicPart, stride, ovflags);
      result.push_back(stride);
    }
  }
  // Strides were accumulated innermost-first; return them outermost-first to
  // line up with the basis order.
  std::reverse(result.begin(), result.end());
  return result;
}
| |
/// Lower `affine.delinearize_index` into divisions and remainders of the
/// linear index by the suffix products (strides) of the basis.
LogicalResult
affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
                                      AffineDelinearizeIndexOp op) {
  Location loc = op.getLoc();
  Value linearIdx = op.getLinearIndex();
  unsigned numResults = op.getNumResults();
  ArrayRef<int64_t> staticBasis = op.getStaticBasis();
  // If the basis has one entry per result, the first entry is an outer bound,
  // which doesn't participate in the strides (see `computeStrides`), so drop
  // it.
  if (numResults == staticBasis.size())
    staticBasis = staticBasis.drop_front();

  // A one-result delinearization is the identity on the linear index.
  if (numResults == 1) {
    rewriter.replaceOp(op, linearIdx);
    return success();
  }

  SmallVector<Value> results;
  results.reserve(numResults);
  SmallVector<Value> strides =
      computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
                     /*knownNonNegative=*/true);

  Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);

  // Outermost result: no modulus is taken, only a division by the largest
  // stride.
  Value initialPart =
      arith::FloorDivSIOp::create(rewriter, loc, linearIdx, strides.front());
  results.push_back(initialPart);

  // Emit `linearIdx mod stride` with non-negative (Euclidean) semantics:
  // `arith.remsi` may yield a negative remainder, which is corrected by
  // adding the stride back when the remainder is negative.
  auto emitModTerm = [&](Value stride) -> Value {
    Value remainder = arith::RemSIOp::create(rewriter, loc, linearIdx, stride);
    Value remainderNegative = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::slt, remainder, zero);
    // If the correction is relevant, this term is <= stride, which is known
    // to be positive in `index`. Otherwise, while 2 * stride might overflow,
    // this branch won't be taken, so the risk of `poison` is fine.
    Value corrected = arith::AddIOp::create(rewriter, loc, remainder, stride,
                                            arith::IntegerOverflowFlags::nsw);
    Value mod = arith::SelectOp::create(rewriter, loc, remainderNegative,
                                        corrected, remainder);
    return mod;
  };

  // Generate all the intermediate parts
  for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
    Value thisStride = strides[i];
    Value nextStride = strides[i + 1];
    Value modulus = emitModTerm(thisStride);
    // We know both inputs are positive, so floorDiv == div.
    // This could potentially be a divui, but it's not clear if that would
    // cause issues.
    Value divided = arith::DivSIOp::create(rewriter, loc, modulus, nextStride);
    results.push_back(divided);
  }

  // Innermost result: the remainder modulo the smallest stride.
  results.push_back(emitModTerm(strides.back()));

  rewriter.replaceOp(op, results);
  return success();
}
| |
/// Lower `affine.linearize_index` into per-index multiplications by the basis
/// strides followed by a sum, ordering the additions so that terms invariant
/// in more enclosing loops are combined first (and are therefore hoistable).
LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
                                                  AffineLinearizeIndexOp op) {
  // Should be folded away, included here for safety.
  if (op.getMultiIndex().empty()) {
    rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
    return success();
  }

  Location loc = op.getLoc();
  ValueRange multiIndex = op.getMultiIndex();
  size_t numIndexes = multiIndex.size();
  ArrayRef<int64_t> staticBasis = op.getStaticBasis();
  // If the basis has one entry per index, the first entry is an outer bound
  // that does not contribute to any stride; drop it (see `computeStrides`).
  if (numIndexes == staticBasis.size())
    staticBasis = staticBasis.drop_front();

  // `nuw` on the stride products is only requested for `disjoint`
  // linearizations, where the operands are treated as known-non-negative.
  SmallVector<Value> strides =
      computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
                     /*knownNonNegative=*/op.getDisjoint());
  // Pairs of (scaled index value, number of enclosing loops the operand is
  // invariant in), used to order the final summation.
  SmallVector<std::pair<Value, int64_t>> scaledValues;
  scaledValues.reserve(numIndexes);

  // Note: strides doesn't contain a value for the final element (stride 1)
  // and everything else lines up. We use the "mutable" accessor so we can get
  // our hands on an `OpOperand&` for the loop invariant counting function.
  for (auto [stride, idxOp] :
       llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
    Value scaledIdx = arith::MulIOp::create(rewriter, loc, idxOp.get(), stride,
                                            arith::IntegerOverflowFlags::nsw);
    int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
    scaledValues.emplace_back(scaledIdx, numHoistableLoops);
  }
  // The last index has an implicit stride of 1 and needs no multiplication.
  scaledValues.emplace_back(
      multiIndex.back(),
      numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));

  // Sort by how many enclosing loops there are, ties implicitly broken by
  // size of the stride.
  llvm::stable_sort(scaledValues,
                    [&](auto l, auto r) { return l.second > r.second; });

  // Sum the terms most-hoistable-first so partial sums of loop-invariant
  // terms can be moved out of enclosing loops.
  Value result = scaledValues.front().first;
  for (auto [scaledValue, numHoistableLoops] : llvm::drop_begin(scaledValues)) {
    std::ignore = numHoistableLoops;
    result = arith::AddIOp::create(rewriter, loc, result, scaledValue,
                                   arith::IntegerOverflowFlags::nsw);
  }
  rewriter.replaceOp(op, result);
  return success();
}
| |
| namespace { |
| struct LowerDelinearizeIndexOps |
| : public OpRewritePattern<AffineDelinearizeIndexOp> { |
| using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern; |
| LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op, |
| PatternRewriter &rewriter) const override { |
| return affine::lowerAffineDelinearizeIndexOp(rewriter, op); |
| } |
| }; |
| |
| struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> { |
| using OpRewritePattern::OpRewritePattern; |
| LogicalResult matchAndRewrite(AffineLinearizeIndexOp op, |
| PatternRewriter &rewriter) const override { |
| return affine::lowerAffineLinearizeIndexOp(rewriter, op); |
| } |
| }; |
| |
| class ExpandAffineIndexOpsPass |
| : public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> { |
| public: |
| ExpandAffineIndexOpsPass() = default; |
| |
| void runOnOperation() override { |
| MLIRContext *context = &getContext(); |
| RewritePatternSet patterns(context); |
| populateAffineExpandIndexOpsPatterns(patterns); |
| if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) |
| return signalPassFailure(); |
| } |
| }; |
| |
| } // namespace |
| |
| void mlir::affine::populateAffineExpandIndexOpsPatterns( |
| RewritePatternSet &patterns) { |
| patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>( |
| patterns.getContext()); |
| } |
| |
/// Factory for the affine index-op expansion pass; returns an owning pointer
/// to a fresh `ExpandAffineIndexOpsPass` instance.
std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsPass() {
  return std::make_unique<ExpandAffineIndexOpsPass>();
}