| //===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Implements gradual lowering of `gpu.subgroup_reduce` ops. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" |
| #include "mlir/Dialect/AMDGPU/Utils/Chipset.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| #include "mlir/Dialect/GPU/Transforms/Passes.h" |
| #include "mlir/Dialect/GPU/Utils/GPUUtils.h" |
| #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" |
| #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/Location.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "llvm/Support/MathExtras.h" |
| #include <cassert> |
| #include <cstdint> |
| |
| using namespace mlir; |
| |
| namespace { |
| |
| /// Example, assuming `maxShuffleBitwidth` is equal to 32: |
| /// ``` |
| /// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16> |
| /// ==> |
| /// %v0 = arith.constant dense<0.0> : vector<3xf16> |
| /// %e0 = vector.extract_strided_slice %x |
| ///   {offsets = [0], sizes = [2], strides = [1]} : vector<3xf16> to vector<2xf16> |
| /// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16> |
| /// %v1 = vector.insert_strided_slice %r0, %v0 |
| ///   {offsets = [0], strides = [1]} : vector<2xf16> into vector<3xf16> |
| /// %e1 = vector.extract %x[2] : f16 from vector<3xf16> |
| /// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16 |
| /// %a = vector.insert %r1, %v1[2] : f16 into vector<3xf16> |
| /// ``` |
| struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> { |
| BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth, |
| PatternBenefit benefit) |
| : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) { |
| } |
| |
| LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, |
| PatternRewriter &rewriter) const override { |
| auto vecTy = dyn_cast<VectorType>(op.getType()); |
| if (!vecTy || vecTy.getNumElements() < 2) |
| return rewriter.notifyMatchFailure(op, "not a multi-element reduction"); |
| |
| assert(vecTy.getRank() == 1 && "Unexpected vector type"); |
| assert(!vecTy.isScalable() && "Unexpected vector type"); |
| |
| Type elemTy = vecTy.getElementType(); |
| unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth(); |
| if (elemBitwidth >= maxShuffleBitwidth) |
| return rewriter.notifyMatchFailure( |
| op, llvm::formatv("element type too large ({0}), cannot break down " |
| "into vectors of bitwidth {1} or less", |
| elemBitwidth, maxShuffleBitwidth)); |
| |
| unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth; |
| assert(elementsPerShuffle >= 1); |
| |
| unsigned numNewReductions = |
| llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle); |
| assert(numNewReductions >= 1); |
| if (numNewReductions == 1) |
| return rewriter.notifyMatchFailure(op, "nothing to break down"); |
| |
| Location loc = op.getLoc(); |
| Value res = |
| arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(vecTy)); |
| |
| for (unsigned i = 0; i != numNewReductions; ++i) { |
| int64_t startIdx = i * elementsPerShuffle; |
| int64_t endIdx = |
| std::min(startIdx + elementsPerShuffle, vecTy.getNumElements()); |
| int64_t numElems = endIdx - startIdx; |
| |
| Value extracted; |
| if (numElems == 1) { |
| extracted = |
| vector::ExtractOp::create(rewriter, loc, op.getValue(), startIdx); |
| } else { |
| extracted = vector::ExtractStridedSliceOp::create( |
| rewriter, loc, op.getValue(), /*offsets=*/startIdx, |
| /*sizes=*/numElems, |
| /*strides=*/1); |
| } |
| |
| Value reduce = gpu::SubgroupReduceOp::create( |
| rewriter, loc, extracted, op.getOp(), op.getUniform(), |
| op.getClusterSize(), op.getClusterStride()); |
| if (numElems == 1) { |
| res = vector::InsertOp::create(rewriter, loc, reduce, res, startIdx); |
| continue; |
| } |
| |
| res = vector::InsertStridedSliceOp::create( |
| rewriter, loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1); |
| } |
| |
| rewriter.replaceOp(op, res); |
| return success(); |
| } |
| |
| private: |
| unsigned maxShuffleBitwidth = 0; |
| }; |
| |
| /// Example: |
| /// ``` |
| /// %a = gpu.subgroup_reduce add %x : (vector<1xf32>) -> vector<1xf32> |
| /// ==> |
| /// %e0 = vector.extract %x[0] : f32 from vector<1xf32> |
| /// %r0 = gpu.subgroup_reduce add %e0 : (f32) -> f32 |
| /// %a = vector.broadcast %r0 : f32 to vector<1xf32> |
| /// ``` |
| struct ScalarizeSingleElementReduce final |
| : OpRewritePattern<gpu::SubgroupReduceOp> { |
| using OpRewritePattern::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, |
| PatternRewriter &rewriter) const override { |
| auto vecTy = dyn_cast<VectorType>(op.getType()); |
| if (!vecTy || vecTy.getNumElements() != 1) |
| return rewriter.notifyMatchFailure(op, "not a single-element reduction"); |
| |
| assert(vecTy.getRank() == 1 && "Unexpected vector type"); |
| assert(!vecTy.isScalable() && "Unexpected vector type"); |
| Location loc = op.getLoc(); |
| Value extracted = |
| vector::ExtractOp::create(rewriter, loc, op.getValue(), 0); |
| Value reduce = gpu::SubgroupReduceOp::create( |
| rewriter, loc, extracted, op.getOp(), op.getUniform(), |
| op.getClusterSize(), op.getClusterStride()); |
| rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce); |
| return success(); |
| } |
| }; |
| |
| struct ClusterInfo { |
| unsigned clusterStride; |
| unsigned clusterSize; |
| unsigned subgroupSize; |
| }; |
| |
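| /// Returns the cluster stride, the effective cluster size (defaulting to the |
| /// full subgroup when `op` does not specify one), and the subgroup size, or |
| /// failure if the cluster configuration is invalid for `subgroupSize`. |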
| static FailureOr<ClusterInfo> |
| getAndValidateClusterInfo(gpu::SubgroupReduceOp op, unsigned subgroupSize) { |
| assert(llvm::isPowerOf2_32(subgroupSize)); |
| |
| std::optional<uint32_t> clusterSize = op.getClusterSize(); |
| assert(!clusterSize || |
| llvm::isPowerOf2_32(*clusterSize)); // Verifier should've caught this. |
| if (clusterSize && *clusterSize > subgroupSize) |
| return op.emitOpError() |
| << "cluster size " << *clusterSize |
| << " is greater than subgroup size " << subgroupSize; |
| unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize); |
| |
| auto clusterStride = op.getClusterStride(); |
| assert(llvm::isPowerOf2_32(clusterStride)); // Verifier should've caught this. |
| if (clusterStride >= subgroupSize) |
| return op.emitOpError() |
| << "cluster stride " << clusterStride |
| << " is not less than subgroup size " << subgroupSize; |
| |
| return ClusterInfo{clusterStride, effectiveClusterSize, subgroupSize}; |
| } |
| |
| /// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn` |
| /// and `unpackFn` to convert to the native shuffle type and to the reduction |
| /// type, respectively. For example, with `input` of type `f16`, `packFn` could |
| /// build ops to cast the value to `i32` to perform shuffles, while `unpackFn` |
| /// would cast it back to `f16` for the arithmetic reduction. Assumes that |
| /// the subgroup is `subgroupSize` lanes wide and divides it into clusters of |
| /// `clusterSize` lanes starting at lane 0 with a stride of `clusterStride` for |
| /// lanes within a cluster, reducing all lanes in each cluster in parallel. |
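| /// |
| /// As a rough sketch (not verbatim output), reducing an `f32` value with `add` |
| /// over a cluster of 4 contiguous lanes (`clusterStride = 1`, `clusterSize = 4`) |
| /// in a 32-lane subgroup, with identity pack/unpack functions, emits IR along |
| /// these lines: |
| /// ``` |
| /// %s0, %valid0 = gpu.shuffle xor %x, %c1, %c32 : f32 |
| /// %r0 = arith.addf %x, %s0 : f32 |
| /// %s1, %valid1 = gpu.shuffle xor %r0, %c2, %c32 : f32 |
| /// %r1 = arith.addf %r0, %s1 : f32 |
| /// ``` |
| /// where `%c1`, `%c2`, and `%c32` are `i32` constants holding the XOR offsets |
| /// and the subgroup width. |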
| Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc, |
| Value input, gpu::AllReduceOperation mode, |
| const ClusterInfo &ci, |
| function_ref<Value(Value)> packFn, |
| function_ref<Value(Value)> unpackFn) { |
| // Lane value always stays in the original type. We use it to perform arith |
| // reductions. |
| Value laneVal = input; |
| // Parallel reduction using butterfly shuffles. |
| for (unsigned i = ci.clusterStride; i < ci.clusterStride * ci.clusterSize; |
| i <<= 1) { |
| Value shuffled = gpu::ShuffleOp::create(builder, loc, packFn(laneVal), i, |
| /*width=*/ci.subgroupSize, |
| /*mode=*/gpu::ShuffleMode::XOR) |
| .getShuffleResult(); |
| laneVal = vector::makeArithReduction(builder, loc, |
| gpu::convertReductionKind(mode), |
| laneVal, unpackFn(shuffled)); |
| assert(laneVal.getType() == input.getType()); |
| } |
| |
| return laneVal; |
| } |
| |
| /// Lowers scalar gpu subgroup reductions to a series of shuffles. |
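| /// |
| /// When the element type is narrower than `shuffleBitwidth`, values are packed |
| /// for the shuffle and unpacked back for the arithmetic. A rough sketch of one |
| /// shuffle round for an `f16` `add` reduction with `shuffleBitwidth = 32` |
| /// (`%offset` and `%width` stand for the emitted `i32` constants): |
| /// ``` |
| /// %p0 = arith.bitcast %x : f16 to i16 |
| /// %p1 = arith.extui %p0 : i16 to i32 |
| /// %s, %valid = gpu.shuffle xor %p1, %offset, %width : i32 |
| /// %u0 = arith.trunci %s : i32 to i16 |
| /// %u1 = arith.bitcast %u0 : i16 to f16 |
| /// %r = arith.addf %x, %u1 : f16 |
| /// ``` |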
| struct ScalarSubgroupReduceToShuffles final |
| : OpRewritePattern<gpu::SubgroupReduceOp> { |
| ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize, |
| unsigned shuffleBitwidth, bool matchClustered, |
| PatternBenefit benefit) |
| : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize), |
| shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {} |
| |
| LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, |
| PatternRewriter &rewriter) const override { |
| if (op.getClusterSize().has_value() != matchClustered) { |
| return rewriter.notifyMatchFailure( |
| op, llvm::formatv("op is {0}clustered but pattern is configured to " |
| "only match {1}clustered ops", |
| matchClustered ? "non-" : "", |
| matchClustered ? "" : "non-")); |
| } |
| |
| auto ci = getAndValidateClusterInfo(op, subgroupSize); |
| if (failed(ci)) |
| return failure(); |
| |
| Type valueTy = op.getType(); |
| unsigned elemBitwidth = |
| getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth(); |
| if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth) |
| return rewriter.notifyMatchFailure( |
| op, "value type is not a compatible scalar"); |
| |
| Location loc = op.getLoc(); |
| // Since this is already a native shuffle scalar, no packing is necessary. |
| if (elemBitwidth == shuffleBitwidth) { |
| auto identityFn = [](Value v) { return v; }; |
| rewriter.replaceOp(op, createSubgroupShuffleReduction( |
| rewriter, loc, op.getValue(), op.getOp(), *ci, |
| identityFn, identityFn)); |
| return success(); |
| } |
| |
| auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth); |
| auto equivIntType = rewriter.getIntegerType(elemBitwidth); |
| auto packFn = [loc, &rewriter, equivIntType, |
| shuffleIntType](Value unpackedVal) -> Value { |
| auto asInt = |
| arith::BitcastOp::create(rewriter, loc, equivIntType, unpackedVal); |
| return arith::ExtUIOp::create(rewriter, loc, shuffleIntType, asInt); |
| }; |
| auto unpackFn = [loc, &rewriter, equivIntType, |
| valueTy](Value packedVal) -> Value { |
| auto asInt = |
| arith::TruncIOp::create(rewriter, loc, equivIntType, packedVal); |
| return arith::BitcastOp::create(rewriter, loc, valueTy, asInt); |
| }; |
| |
| rewriter.replaceOp( |
| op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(), |
| op.getOp(), *ci, packFn, unpackFn)); |
| return success(); |
| } |
| |
| private: |
| unsigned subgroupSize = 0; |
| unsigned shuffleBitwidth = 0; |
| bool matchClustered = false; |
| }; |
| |
| /// Lowers vector gpu subgroup reductions to a series of shuffles. |
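| /// |
| /// Small vectors are packed into a single shuffle-sized integer. For example, a |
| /// `vector<2xf16>` `add` reduction with `shuffleBitwidth = 32` performs each |
| /// round roughly as follows (a sketch; `%offset` and `%width` stand for the |
| /// emitted `i32` constants): |
| /// ``` |
| /// %iv = vector.bitcast %x : vector<2xf16> to vector<1xi32> |
| /// %i = vector.extract %iv[0] : i32 from vector<1xi32> |
| /// %s, %valid = gpu.shuffle xor %i, %offset, %width : i32 |
| /// %sv = vector.broadcast %s : i32 to vector<1xi32> |
| /// %y = vector.bitcast %sv : vector<1xi32> to vector<2xf16> |
| /// %r = arith.addf %x, %y : vector<2xf16> |
| /// ``` |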
| struct VectorSubgroupReduceToShuffles final |
| : OpRewritePattern<gpu::SubgroupReduceOp> { |
| VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize, |
| unsigned shuffleBitwidth, bool matchClustered, |
| PatternBenefit benefit) |
| : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize), |
| shuffleBitwidth(shuffleBitwidth), matchClustered(matchClustered) {} |
| |
| LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, |
| PatternRewriter &rewriter) const override { |
| if (op.getClusterSize().has_value() != matchClustered) { |
| return rewriter.notifyMatchFailure( |
| op, llvm::formatv("op is {0}clustered but pattern is configured to " |
| "only match {1}clustered ops", |
| matchClustered ? "non-" : "", |
| matchClustered ? "" : "non-")); |
| } |
| |
| auto ci = getAndValidateClusterInfo(op, subgroupSize); |
| if (failed(ci)) |
| return failure(); |
| |
| auto vecTy = dyn_cast<VectorType>(op.getType()); |
| if (!vecTy) |
| return rewriter.notifyMatchFailure(op, "value type is not a vector"); |
| |
| unsigned vecBitwidth = |
| vecTy.getNumElements() * vecTy.getElementTypeBitWidth(); |
| if (vecBitwidth > shuffleBitwidth) |
| return rewriter.notifyMatchFailure( |
| op, |
| llvm::formatv("vector type bitwidth too large ({0}), cannot lower " |
| "to shuffles of size {1}", |
| vecBitwidth, shuffleBitwidth)); |
| |
| unsigned elementsPerShuffle = |
| shuffleBitwidth / vecTy.getElementTypeBitWidth(); |
| if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth) |
| return rewriter.notifyMatchFailure( |
| op, "shuffle bitwidth is not a multiple of the element bitwidth"); |
| |
| Location loc = op.getLoc(); |
| |
| // If the reduced type is smaller than the native shuffle size, extend it, |
| // perform the shuffles, and extract at the end. |
| auto extendedVecTy = VectorType::get( |
| static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType()); |
| Value extendedInput = op.getValue(); |
| if (vecBitwidth < shuffleBitwidth) { |
| auto zero = arith::ConstantOp::create( |
| rewriter, loc, rewriter.getZeroAttr(extendedVecTy)); |
| extendedInput = vector::InsertStridedSliceOp::create( |
| rewriter, loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1); |
| } |
| |
| auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth); |
| auto shuffleVecType = VectorType::get(1, shuffleIntType); |
| |
| auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value { |
| auto asIntVec = |
| vector::BitCastOp::create(rewriter, loc, shuffleVecType, unpackedVal); |
| return vector::ExtractOp::create(rewriter, loc, asIntVec, 0); |
| }; |
| auto unpackFn = [loc, &rewriter, shuffleVecType, |
| extendedVecTy](Value packedVal) -> Value { |
| auto asIntVec = |
| vector::BroadcastOp::create(rewriter, loc, shuffleVecType, packedVal); |
| return vector::BitCastOp::create(rewriter, loc, extendedVecTy, asIntVec); |
| }; |
| |
| Value res = createSubgroupShuffleReduction( |
| rewriter, loc, extendedInput, op.getOp(), *ci, packFn, unpackFn); |
| |
| if (vecBitwidth < shuffleBitwidth) { |
| res = vector::ExtractStridedSliceOp::create( |
| rewriter, loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(), |
| /*strides=*/1); |
| } |
| |
| rewriter.replaceOp(op, res); |
| return success(); |
| } |
| |
| private: |
| unsigned subgroupSize = 0; |
| unsigned shuffleBitwidth = 0; |
| bool matchClustered = false; |
| }; |
| |
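| /// Emits a reduction of `input` across a cluster of contiguous lanes using |
| /// AMDGPU DPP ops, combining partial results with the arithmetic reduction for |
| /// `mode` after each permute. Clusters of up to 16 lanes use the quad_perm, |
| /// row_half_mirror, and row_mirror permutations; larger clusters additionally |
| /// use row broadcasts on gfx9 and permlanex16/readlane on gfx10 through gfx12. |
| /// Returns failure for chipsets without a supported lowering. |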
| static FailureOr<Value> |
| createSubgroupDPPReduction(PatternRewriter &rewriter, gpu::SubgroupReduceOp op, |
| Value input, gpu::AllReduceOperation mode, |
| const ClusterInfo &ci, amdgpu::Chipset chipset) { |
| Location loc = op.getLoc(); |
| Value dpp; |
| Value res = input; |
| constexpr int allRows = 0xf; |
| constexpr int allBanks = 0xf; |
| const bool boundCtrl = true; |
| if (ci.clusterSize >= 2) { |
| // Perform reduction between all lanes N <-> N+1. |
| dpp = amdgpu::DPPOp::create( |
| rewriter, loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm, |
| rewriter.getI32ArrayAttr({1, 0, 3, 2}), allRows, allBanks, boundCtrl); |
| res = vector::makeArithReduction(rewriter, loc, |
| gpu::convertReductionKind(mode), res, dpp); |
| } |
| |
| if (ci.clusterSize >= 4) { |
| // Perform reduction between all lanes N <-> N+2. |
| dpp = amdgpu::DPPOp::create( |
| rewriter, loc, res.getType(), res, res, amdgpu::DPPPerm::quad_perm, |
| rewriter.getI32ArrayAttr({2, 3, 0, 1}), allRows, allBanks, boundCtrl); |
| res = vector::makeArithReduction(rewriter, loc, |
| gpu::convertReductionKind(mode), res, dpp); |
| } |
| if (ci.clusterSize >= 8) { |
| // Perform reduction between all lanes N <-> 7-N, |
| // e.g. lane[0] <-> lane[7], lane[1] <-> lane[6], ..., lane[3] <-> lane[4]. |
| dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res, |
| amdgpu::DPPPerm::row_half_mirror, |
| rewriter.getUnitAttr(), allRows, allBanks, |
| boundCtrl); |
| res = vector::makeArithReduction(rewriter, loc, |
| gpu::convertReductionKind(mode), res, dpp); |
| } |
| if (ci.clusterSize >= 16) { |
| // Perform reduction between all lanes N <-> 15-N, |
| // e.g. lane[0] <-> lane[15], lane[1] <-> lane[14], ..., lane[7] <-> lane[8]. |
| dpp = amdgpu::DPPOp::create( |
| rewriter, loc, res.getType(), res, res, amdgpu::DPPPerm::row_mirror, |
| rewriter.getUnitAttr(), allRows, allBanks, boundCtrl); |
| res = vector::makeArithReduction(rewriter, loc, |
| gpu::convertReductionKind(mode), res, dpp); |
| } |
| if (ci.clusterSize >= 32) { |
| if (chipset.majorVersion <= 9) { |
| // Broadcast the last lane of each row into the next row. |
| // The row mask restricts writes to rows 1 and 3 so that rows 0 and 2 are |
| // not polluted. |
| dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res, |
| amdgpu::DPPPerm::row_bcast_15, |
| rewriter.getUnitAttr(), 0xa, allBanks, |
| /*bound_ctrl*/ false); |
| res = vector::makeArithReduction( |
| rewriter, loc, gpu::convertReductionKind(mode), res, dpp); |
| } else if (chipset.majorVersion <= 12) { |
| // Use permlanex16 to cross rows (row 0 <-> row 1, row 2 <-> row 3). |
| Value uint32Max = arith::ConstantOp::create( |
| rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(-1)); |
| dpp = ROCDL::PermlaneX16Op::create(rewriter, loc, res.getType(), res, res, |
| uint32Max, uint32Max, |
| /*fi=*/true, |
| /*bound_ctrl=*/false); |
| res = vector::makeArithReduction( |
| rewriter, loc, gpu::convertReductionKind(mode), res, dpp); |
| } else { |
| return rewriter.notifyMatchFailure( |
| op, "Subgroup reduce lowering to DPP not currently supported for " |
| "this device."); |
| } |
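| // On both paths above, lane 31 now holds the reduction over the first 32 |
| // lanes; for a 32-lane subgroup, read it so every lane gets the result. |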
| if (ci.subgroupSize == 32) { |
| Value lane31 = arith::ConstantOp::create( |
| rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31)); |
| res = |
| ROCDL::ReadlaneOp::create(rewriter, loc, res.getType(), res, lane31); |
| } |
| } |
| if (ci.clusterSize >= 64) { |
| if (chipset.majorVersion <= 9) { |
| // Broadcast the value of lane 31 to rows 2 and 3. |
| dpp = amdgpu::DPPOp::create(rewriter, loc, res.getType(), res, res, |
| amdgpu::DPPPerm::row_bcast_31, |
| rewriter.getUnitAttr(), 0xf, allBanks, |
| /*bound_ctrl*/ true); |
| res = vector::makeArithReduction( |
| rewriter, loc, gpu::convertReductionKind(mode), dpp, res); |
| // Read the final reduction from lane 63; the earlier rows are polluted. |
| Value lane63 = arith::ConstantOp::create( |
| rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63)); |
| res = |
| ROCDL::ReadlaneOp::create(rewriter, loc, res.getType(), res, lane63); |
| |
| } else if (chipset.majorVersion <= 12) { |
| // The reduction within each group of 32 lanes is already done. Finish by |
| // reading the partial results from lanes 31 and 63 and combining them. |
| Value lane31 = arith::ConstantOp::create( |
| rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(31)); |
| Value lane63 = arith::ConstantOp::create( |
| rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(63)); |
| lane31 = |
| ROCDL::ReadlaneOp::create(rewriter, loc, res.getType(), res, lane31); |
| lane63 = |
| ROCDL::ReadlaneOp::create(rewriter, loc, res.getType(), res, lane63); |
| res = vector::makeArithReduction( |
| rewriter, loc, gpu::convertReductionKind(mode), lane31, lane63); |
| } else { |
| return rewriter.notifyMatchFailure( |
| op, "Subgroup reduce lowering to DPP not currently supported for " |
| "this device."); |
| } |
| } |
| assert(res.getType() == input.getType()); |
| return res; |
| } |
| |
| /// Lowers scalar `gpu.subgroup_reduce` ops into `amdgpu.dpp` ops. Assumes that |
| /// the subgroup has `subgroupSize` lanes. Applicable only to AMD GPUs. |
| struct ScalarSubgroupReduceToDPP final |
| : OpRewritePattern<gpu::SubgroupReduceOp> { |
| ScalarSubgroupReduceToDPP(MLIRContext *ctx, unsigned subgroupSize, |
| bool matchClustered, amdgpu::Chipset chipset, |
| PatternBenefit benefit) |
| : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize), |
| matchClustered(matchClustered), chipset(chipset) {} |
| |
| LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, |
| PatternRewriter &rewriter) const override { |
| if (op.getClusterSize().has_value() != matchClustered) { |
| return rewriter.notifyMatchFailure( |
| op, llvm::formatv("op is {0}clustered but pattern is configured to " |
| "only match {1}clustered ops", |
| matchClustered ? "non-" : "", |
| matchClustered ? "" : "non-")); |
| } |
| auto ci = getAndValidateClusterInfo(op, subgroupSize); |
| if (failed(ci)) |
| return failure(); |
| |
| if (ci->clusterStride != 1) |
| return rewriter.notifyMatchFailure( |
| op, "Subgroup reductions using DPP are currently only available for " |
| "clusters of contiguous lanes."); |
| |
| Type valueTy = op.getType(); |
| if (!valueTy.isIntOrFloat()) |
| return rewriter.notifyMatchFailure( |
| op, "Value type is not a compatible scalar."); |
| |
| FailureOr<Value> dpp = createSubgroupDPPReduction( |
| rewriter, op, op.getValue(), op.getOp(), *ci, chipset); |
| if (failed(dpp)) |
| return failure(); |
| |
| rewriter.replaceOp(op, dpp.value()); |
| return success(); |
| } |
| |
| private: |
| unsigned subgroupSize = 0; |
| bool matchClustered = false; |
| amdgpu::Chipset chipset; |
| }; |
| } // namespace |
| |
| void mlir::populateGpuBreakDownSubgroupReducePatterns( |
| RewritePatternSet &patterns, unsigned maxShuffleBitwidth, |
| PatternBenefit benefit) { |
| patterns.add<BreakDownSubgroupReduce>(patterns.getContext(), |
| maxShuffleBitwidth, benefit); |
| patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit); |
| } |
| |
| void mlir::populateGpuLowerSubgroupReduceToDPPPatterns( |
| RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset, |
| PatternBenefit benefit) { |
| patterns.add<ScalarSubgroupReduceToDPP>(patterns.getContext(), subgroupSize, |
| /*matchClustered=*/false, chipset, |
| benefit); |
| } |
| |
| void mlir::populateGpuLowerClusteredSubgroupReduceToDPPPatterns( |
| RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset, |
| PatternBenefit benefit) { |
| patterns.add<ScalarSubgroupReduceToDPP>(patterns.getContext(), subgroupSize, |
| /*matchClustered=*/true, chipset, |
| benefit); |
| } |
| |
| void mlir::populateGpuLowerSubgroupReduceToShufflePatterns( |
| RewritePatternSet &patterns, unsigned subgroupSize, |
| unsigned shuffleBitwidth, PatternBenefit benefit) { |
| patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>( |
| patterns.getContext(), subgroupSize, shuffleBitwidth, |
| /*matchClustered=*/false, benefit); |
| } |
| |
| void mlir::populateGpuLowerClusteredSubgroupReduceToShufflePatterns( |
| RewritePatternSet &patterns, unsigned subgroupSize, |
| unsigned shuffleBitwidth, PatternBenefit benefit) { |
| patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>( |
| patterns.getContext(), subgroupSize, shuffleBitwidth, |
| /*matchClustered=*/true, benefit); |
| } |