[mlir][gpu] Warp execute terminator getter (#154729)
Adds a utility getter, `getTerminator()`, to `warp_execute_on_lane_0` that
simplifies access to the op's terminator.
Existing uses are refactored to use the new getter.
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index f946bb7..a5c3a92 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -3209,6 +3209,9 @@
bool isDefinedOutsideOfRegion(Value value) {
return !getRegion().isAncestor(value.getParentRegion());
}
+
+ /// Get the terminator of the warp region.
+ gpu::YieldOp getTerminator();
}];
}
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 2503ccb..cc77aa6 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2486,8 +2486,7 @@
if (getArgs().size() != getWarpRegion().getNumArguments())
return emitOpError(
"expected same number op arguments and block arguments.");
- auto yield =
- cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = getTerminator();
if (yield.getNumOperands() != getNumResults())
return emitOpError(
"expected same number of yield operands and return values.");
@@ -2511,6 +2510,10 @@
verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
}
+gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() {
+ return cast<gpu::YieldOp>(getBody()->getTerminator());
+}
+
//===----------------------------------------------------------------------===//
// GPU KernelMetadataAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
index be71bd0..88f531f 100644
--- a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
+++ b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
@@ -56,8 +56,7 @@
SmallVector<size_t> &indices) const {
SmallVector<Type> types(warpOp.getResultTypes().begin(),
warpOp.getResultTypes().end());
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
SmallVector<Value> yieldValues(yield.getOperands().begin(),
yield.getOperands().end());
llvm::SmallDenseMap<Value, unsigned> indexLookup;
@@ -89,8 +88,7 @@
OpOperand *WarpDistributionPattern::getWarpResult(
WarpExecuteOnLane0Op warpOp,
llvm::function_ref<bool(Operation *)> fn) const {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
for (OpOperand &yieldOperand : yield->getOpOperands()) {
Value yieldValues = yieldOperand.get();
Operation *definedOp = yieldValues.getDefiningOp();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index be0d28a..60aa0e9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -528,8 +528,7 @@
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
Operation *lastNode = yield->getPrevNode();
auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode);
if (!writeOp)
@@ -846,8 +845,7 @@
newYieldValues.reserve(warpOp->getNumResults());
DenseMap<Value, int64_t> dedupYieldOperandPositionMap;
DenseMap<OpResult, int64_t> dedupResultPositionMap;
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
// Some values may be yielded multiple times and correspond to multiple
// results. Deduplicating occurs by taking each result with its matching
@@ -901,8 +899,7 @@
using Base::Base;
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
Value valForwarded;
unsigned resultIndex;
for (OpOperand &operand : yield->getOpOperands()) {
@@ -1708,8 +1705,7 @@
: WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto warpOpYield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp warpOpYield = warpOp.getTerminator();
// Only pick up `ForOp` if it is the last op in the region.
Operation *lastNode = warpOpYield->getPrevNode();
auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 2088c3c..8e47968 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -336,8 +336,7 @@
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
Operation *lastNode = yield->getPrevNode();
auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
if (!storeOp)
@@ -449,8 +448,7 @@
// Make sure the same load op is the last operation in the warp op body.
// This ensure that load op is not sinked earlier violating any barrier
// synchronizations.
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
return yield->getPrevNode() == op;
});
@@ -752,8 +750,7 @@
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
Operation *lastNode = yield->getPrevNode();
auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode);
if (!prefetchOp)
@@ -794,8 +791,7 @@
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
- auto yield = cast<gpu::YieldOp>(
- warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ gpu::YieldOp yield = warpOp.getTerminator();
Operation *lastNode = yield->getPrevNode();
// The last node must be a gpu::BarrierOp.
auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode);