//===- MaskedloadToLoad.cpp - Lowers maskedload to load ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/MathExtras.h"
namespace mlir::amdgpu {
#define GEN_PASS_DEF_AMDGPUMASKEDLOADTOLOADPASS
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
} // namespace mlir::amdgpu
using namespace mlir;
using namespace mlir::amdgpu;

/// Returns success if the base memref of `maskedOp` is in the AMDGPU fat raw
/// buffer address space. This is the precondition for lowering
/// `vector.maskedload` to `vector.load` plus `arith.select`: buffer accesses
/// have hardware out-of-bounds handling, so an unconditional load is safe.
static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
vector::MaskedLoadOp maskedOp) {
auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
if (!memRefType)
return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
Attribute addrSpace = memRefType.getMemorySpace();
if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
return rewriter.notifyMatchFailure(maskedOp, "no address space");
  if (cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
amdgpu::AddressSpace::FatRawBuffer)
return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
return success();
}
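
/// Creates an unconditional `vector.load` with the same base and indices as
/// `maskedOp`. If `passthru` is true, the loaded vector is additionally
/// blended with the maskedload's passthru operand through `arith.select`, so
/// lanes disabled by the mask keep their passthru value.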
static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
vector::MaskedLoadOp maskedOp,
bool passthru) {
VectorType vectorType = maskedOp.getVectorType();
Value load = builder.create<vector::LoadOp>(
loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
if (passthru)
load = builder.create<arith::SelectOp>(loc, vectorType, maskedOp.getMask(),
load, maskedOp.getPassThru());
return load;
}
/// Check if the given value comes from a broadcasted i1 condition.
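/// For example, given `%mask = vector.broadcast %cond : i1 to vector<4xi1>`,
/// this returns `%cond`; broadcasts whose source is itself a vector are
/// rejected. (The SSA names above are illustrative only.)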
static FailureOr<Value> matchFullMask(OpBuilder &b, Value val) {
auto broadcastOp = val.getDefiningOp<vector::BroadcastOp>();
if (!broadcastOp)
return failure();
if (isa<VectorType>(broadcastOp.getSourceType()))
return failure();
return broadcastOp.getSource();
}
static constexpr char kMaskedloadNeedsMask[] =
"amdgpu.buffer_maskedload_needs_mask";
namespace {
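
/// Lowers `vector.maskedload` on fat raw buffer memory to an unconditional
/// `vector.load` guarded by a bounds check, falling back to the original
/// maskedload only in the rare boundary case. Illustrative sketch of the
/// emitted structure (types abbreviated, SSA names invented):
///
///   %res = scf.if %oob_and_unaligned -> (vector<4xf16>) {
///     // Fallback: clone of the maskedload, tagged so this pattern
///     // does not fire on it again.
///     %v = vector.maskedload %base[%i], %mask, %pass
///     scf.yield %v
///   } else {
///     %v = vector.load %base[%i]
///     %s = arith.select %mask, %v, %pass
///     scf.yield %s
///   }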
struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
PatternRewriter &rewriter) const override {
if (maskedOp->hasAttr(kMaskedloadNeedsMask))
return failure();
if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
return failure();
}
    // If the mask is a broadcast of a single i1, the load is either fully
    // enabled (a full in-bounds load) or fully disabled (an empty,
    // out-of-bounds load). In either case the unconditional load is safe on
    // buffer memory, so take the fast path and skip generating the
    // bounds-check `scf.if` below.
if (succeeded(matchFullMask(rewriter, maskedOp.getMask()))) {
Value load = createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(),
maskedOp, /*passthru=*/true);
rewriter.replaceOp(maskedOp, load);
return success();
}
Location loc = maskedOp.getLoc();
Value src = maskedOp.getBase();
VectorType vectorType = maskedOp.getVectorType();
int64_t vectorSize = vectorType.getNumElements();
int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
SmallVector<OpFoldResult> indices = maskedOp.getIndices();
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
SmallVector<OpFoldResult> strides =
stridedMetadata.getConstifiedMixedStrides();
SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
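    // Linearize the load's multi-dimensional index and the memref's size into
    // element counts so the distance to the end of the buffer can be computed
    // below.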
memref::LinearizedMemRefInfo linearizedInfo;
OpFoldResult linearizedIndices;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
elementBitWidth, offset, sizes,
strides, indices);
// delta = bufferSize - linearizedOffset
Value vectorSizeOffset =
rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
Value linearIndex =
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
Value totalSize = getValueOrCreateConstantIndexOp(
rewriter, loc, linearizedInfo.linearizedSize);
Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
// 1) check if delta < vectorSize
Value isOutofBounds = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
    // 2) check if (delta % elements_per_word != 0)
Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
loc, llvm::divideCeil(32, elementBitWidth));
Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ne,
rewriter.create<arith::RemUIOp>(loc, delta, elementsPerWord),
rewriter.create<arith::ConstantIndexOp>(loc, 0));
    // We fall back to the default maskedload lowering only when the access is
    // both out-of-bounds and not word aligned. The fallback ensures correct
    // results when loading at the boundary of the buffer, since a buffer load
    // that crosses the boundary returns inconsistent zeros for the whole word.
Value ifCondition =
rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
auto thenBuilder = [&](OpBuilder &builder, Location loc) {
Operation *read = builder.clone(*maskedOp.getOperation());
read->setAttr(kMaskedloadNeedsMask, builder.getUnitAttr());
Value readResult = read->getResult(0);
builder.create<scf::YieldOp>(loc, readResult);
};
auto elseBuilder = [&](OpBuilder &builder, Location loc) {
Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp,
/*passthru=*/true);
      builder.create<scf::YieldOp>(loc, res);
};
auto ifOp =
rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);
rewriter.replaceOp(maskedOp, ifOp);
return success();
}
};
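
/// Rewrites a `vector.maskedload` whose mask is a broadcast of a single i1
/// into an `scf.if` on that condition: the then-branch issues a plain
/// `vector.load` and the else-branch yields the passthru vector. Illustrative
/// sketch (types abbreviated, SSA names invented):
///
///   %res = scf.if %cond -> (vector<4xf16>) {
///     %v = vector.load %base[%i]
///     scf.yield %v
///   } else {
///     scf.yield %pass
///   }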
struct FullMaskedLoadToConditionalLoad
: OpRewritePattern<vector::MaskedLoadOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
PatternRewriter &rewriter) const override {
FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
if (failed(maybeCond)) {
return failure();
}
Value cond = maybeCond.value();
auto trueBuilder = [&](OpBuilder &builder, Location loc) {
Value res = createVectorLoadForMaskedLoad(builder, loc, loadOp,
/*passthru=*/false);
      builder.create<scf::YieldOp>(loc, res);
};
auto falseBuilder = [&](OpBuilder &builder, Location loc) {
      builder.create<scf::YieldOp>(loc, loadOp.getPassThru());
};
auto ifOp = rewriter.create<scf::IfOp>(loadOp.getLoc(), cond, trueBuilder,
falseBuilder);
rewriter.replaceOp(loadOp, ifOp);
return success();
}
};
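
/// Rewrites a `vector.maskedstore` whose mask is a broadcast of a single i1
/// into an `scf.if` that issues an unmasked `vector.store` only when the
/// condition is true; when it is false, nothing is written.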
struct FullMaskedStoreToConditionalStore
: OpRewritePattern<vector::MaskedStoreOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
PatternRewriter &rewriter) const override {
FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
if (failed(maybeCond)) {
return failure();
}
Value cond = maybeCond.value();
auto trueBuilder = [&](OpBuilder &builder, Location loc) {
      builder.create<vector::StoreOp>(loc, storeOp.getValueToStore(),
                                      storeOp.getBase(), storeOp.getIndices());
      builder.create<scf::YieldOp>(loc);
};
auto ifOp = rewriter.create<scf::IfOp>(storeOp.getLoc(), cond, trueBuilder);
rewriter.replaceOp(storeOp, ifOp);
return success();
}
};
} // namespace
void mlir::amdgpu::populateAmdgpuMaskedloadToLoadPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<MaskedLoadLowering, FullMaskedLoadToConditionalLoad,
FullMaskedStoreToConditionalStore>(patterns.getContext(),
benefit);
}
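
/// Pass wrapper that applies the patterns registered above with the greedy
/// pattern rewrite driver.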
struct AmdgpuMaskedloadToLoadPass final
: amdgpu::impl::AmdgpuMaskedloadToLoadPassBase<AmdgpuMaskedloadToLoadPass> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateAmdgpuMaskedloadToLoadPatterns(patterns);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}
};