| //===- VectorDistribute.cpp - patterns to do vector distribution ----------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
| #include "mlir/Dialect/GPU/Utils/DistributionUtils.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/SCF/IR/SCF.h" |
| #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" |
| #include "mlir/IR/AffineExpr.h" |
| #include "mlir/Interfaces/SideEffectInterfaces.h" |
| #include "mlir/Transforms/RegionUtils.h" |
| #include "llvm/ADT/SetVector.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include <utility> |
| |
| using namespace mlir; |
| using namespace mlir::vector; |
| using namespace mlir::gpu; |
| |
| /// Currently the distribution map is implicit based on the vector shape. In the |
| /// future it will be part of the op. |
| /// Example: |
| /// ``` |
| /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) { |
| /// ... |
| /// gpu.yield %3 : vector<32x16x64xf32> |
| /// } |
| /// ``` |
| /// Would have an implicit map of: |
| /// `(d0, d1, d2) -> (d0, d2)` |
| static AffineMap calculateImplicitMap(VectorType sequentialType, |
| VectorType distributedType) { |
| SmallVector<AffineExpr> perm; |
| perm.reserve(1); |
| // Check which dimensions of the sequential type are different from the |
| // dimensions of the distributed type to identify the distributed |
| // dimensions. Then associate each distributed dimension with a dimension |
| // expression, in order. |
| for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) { |
| if (sequentialType.getDimSize(i) != distributedType.getDimSize(i)) |
| perm.push_back(getAffineDimExpr(i, distributedType.getContext())); |
| } |
| auto map = AffineMap::get(sequentialType.getRank(), 0, perm, |
| distributedType.getContext()); |
| return map; |
| } |
| |
| namespace { |
| |
| /// Helper struct to create the load / store operations that permit transit |
| /// through the parallel / sequential and the sequential / parallel boundaries |
| /// when lowering a `WarpExecuteOnLane0Op` to an scf.if (see |
| /// `WarpOpToScfIfPattern`). |
| /// |
| /// The vector distribution dimension is inferred from the vector types. |
| struct DistributedLoadStoreHelper { |
| DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal, |
| Value laneId, Value zero) |
| : sequentialVal(sequentialVal), distributedVal(distributedVal), |
| laneId(laneId), zero(zero) { |
| sequentialVectorType = dyn_cast<VectorType>(sequentialVal.getType()); |
| distributedVectorType = dyn_cast<VectorType>(distributedVal.getType()); |
| if (sequentialVectorType && distributedVectorType) |
| distributionMap = |
| calculateImplicitMap(sequentialVectorType, distributedVectorType); |
| } |
| |
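| /// Build the offset `laneId * distributedSize` for dimension `index` of the |
| /// distributed vector, as an affine apply on the lane id. |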
| Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) { |
| int64_t distributedSize = distributedVectorType.getDimSize(index); |
| AffineExpr tid = getAffineSymbolExpr(0, b.getContext()); |
| return b.createOrFold<affine::AffineApplyOp>(loc, tid * distributedSize, |
| ArrayRef<Value>{laneId}); |
| } |
| |
| /// Create a store during the process of distributing the |
| /// `gpu.warp_execute_on_lane_0` op. |
| /// Vector distribution assumes the following convention regarding the |
| /// temporary buffers that are created to transition values. This **must** |
| /// be properly specified in the `options.warpAllocationFn`: |
| /// 1. scalars of type T transit through a memref<1xT>. |
| /// 2. vectors of type vector<shape x T> transit through a memref<shape x T>. |
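| /// |
| /// As an illustrative sketch (assuming a warp size of 32, a vector<2xf32> |
| /// distributed slice per lane, and a %buffer produced by |
| /// `options.warpAllocationFn`), storing the distributed value emits: |
| /// ``` |
| /// %off = affine.apply affine_map<()[s0] -> (s0 * 2)>()[%laneid] |
| /// vector.transfer_write %dist, %buffer[%off] : vector<2xf32>, memref<64xf32> |
| /// ``` |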
| Operation *buildStore(RewriterBase &b, Location loc, Value val, |
| Value buffer) { |
| assert((val == distributedVal || val == sequentialVal) && |
| "Must store either the preregistered distributed or the " |
| "preregistered sequential value."); |
| // Scalar case can directly use memref.store. |
| if (!isa<VectorType>(val.getType())) |
| return b.create<memref::StoreOp>(loc, val, buffer, zero); |
| |
| // Vector case must use vector::TransferWriteOp, which will later lower to |
| // vector.store or memref.store depending on further lowerings. |
| int64_t rank = sequentialVectorType.getRank(); |
| SmallVector<Value> indices(rank, zero); |
| if (val == distributedVal) { |
| for (auto dimExpr : distributionMap.getResults()) { |
| int64_t index = cast<AffineDimExpr>(dimExpr).getPosition(); |
| indices[index] = buildDistributedOffset(b, loc, index); |
| } |
| } |
| SmallVector<bool> inBounds(indices.size(), true); |
| return b.create<vector::TransferWriteOp>( |
| loc, val, buffer, indices, |
| ArrayRef<bool>(inBounds.begin(), inBounds.end())); |
| } |
| |
| /// Create a load during the process of distributing the |
| /// `gpu.warp_execute_on_lane_0` op. |
| /// Vector distribution assumes the following convention regarding the |
| /// temporary buffers that are created to transition values. This **must** |
| /// be properly specified in the `options.warpAllocationFn`: |
| /// 1. scalars of type T transit through a memref<1xT>. |
| /// 2. vectors of type vector<shape x T> transit through a memref<shape x T>. |
| /// |
| /// When the requested `type` is the sequential (yielded) type or a scalar, |
| /// the load is not distributed, to account for the broadcast semantics of |
| /// the `gpu.warp_execute_on_lane_0` op. |
| /// |
| /// Example: |
| /// |
| /// ``` |
| /// %r = gpu.warp_execute_on_lane_0(...) -> (f32) { |
| /// gpu.yield %cst : f32 |
| /// } |
| /// // Both types are f32. The constant %cst is broadcasted to all lanes. |
| /// ``` |
| /// This behavior is described in more detail in the documentation of the op. |
| Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) { |
| |
| // Scalar case can directly use memref.load. |
| if (!isa<VectorType>(type)) |
| return b.create<memref::LoadOp>(loc, buffer, zero); |
| |
| // All other cases must be vectors at the moment. |
| // Vector case must use vector::TransferReadOp, which will later lower to |
| // vector.load or memref.load depending on further lowerings. |
| assert((type == distributedVectorType || type == sequentialVectorType) && |
| "Must store either the preregistered distributed or the " |
| "preregistered sequential type."); |
| SmallVector<Value> indices(sequentialVectorType.getRank(), zero); |
| if (type == distributedVectorType) { |
| for (auto dimExpr : distributionMap.getResults()) { |
| int64_t index = cast<AffineDimExpr>(dimExpr).getPosition(); |
| indices[index] = buildDistributedOffset(b, loc, index); |
| } |
| } |
| SmallVector<bool> inBounds(indices.size(), true); |
| return b.create<vector::TransferReadOp>( |
| loc, cast<VectorType>(type), buffer, indices, |
| ArrayRef<bool>(inBounds.begin(), inBounds.end())); |
| } |
| |
| Value sequentialVal, distributedVal, laneId, zero; |
| VectorType sequentialVectorType, distributedVectorType; |
| AffineMap distributionMap; |
| }; |
| |
| } // namespace |
| |
| // Clones `op` into a new operation that takes `operands` and returns |
| // `resultTypes`. |
| static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter, |
| Location loc, Operation *op, |
| ArrayRef<Value> operands, |
| ArrayRef<Type> resultTypes) { |
| OperationState res(loc, op->getName().getStringRef(), operands, resultTypes, |
| op->getAttrs()); |
| return rewriter.create(res); |
| } |
| |
| namespace { |
| |
| /// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where only the |
| /// single thread with `laneId == 0` executes the entirety of the computation. |
| /// |
| /// After the transformation: |
| /// - the IR within the scf.if op can be thought of as executing sequentially |
| /// (from the point of view of threads along `laneId`). |
| /// - the IR outside of the scf.if op can be thought of as executing in |
| /// parallel (from the point of view of threads along `laneId`). |
| /// |
| /// Values that need to transit through the parallel / sequential and the |
| /// sequential / parallel boundaries do so via reads and writes to a temporary |
| /// memory location. |
| /// |
| /// The transformation proceeds in multiple steps: |
| /// 1. Create the scf.if op. |
| /// 2. Insert appropriate (alloc, write)-pairs before the scf.if and reads |
| /// within the scf.if to transit the values captured from above. |
| /// 3. Synchronize before the scf.if to ensure all writes inserted in 2. are |
| /// consistent within the scf.if. |
| /// 4. Move the body of the WarpExecuteOnLane0Op inside the scf.if. |
| /// 5. Insert appropriate writes within scf.if and reads after the scf.if to |
| /// transit the values returned by the op. |
| /// 6. Synchronize after the scf.if to ensure all writes inserted in 5. are |
| /// consistent after the scf.if. |
| /// 7. Perform late cleanups. |
| /// |
| /// All this assumes the vector distribution occurs along the most minor |
| /// distributed vector dimension. |
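| /// |
| /// As a rough sketch (assuming a warp size of 32 and suitable |
| /// `warpAllocationFn` / `warpSyncronizationFn` hooks; names are |
| /// illustrative, not the verbatim output), a warp op such as |
| /// ``` |
| /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) { |
| /// %v = "compute"() : () -> vector<32xf32> |
| /// gpu.yield %v : vector<32xf32> |
| /// } |
| /// ``` |
| /// becomes approximately |
| /// ``` |
| /// %c0 = arith.constant 0 : index |
| /// %buf = memref.alloc() : memref<32xf32> // via warpAllocationFn |
| /// %is_lane0 = arith.cmpi eq, %laneid, %c0 : index |
| /// scf.if %is_lane0 { |
| /// %v = "compute"() : () -> vector<32xf32> |
| /// vector.transfer_write %v, %buf[%c0] : vector<32xf32>, memref<32xf32> |
| /// } |
| /// // synchronization via warpSyncronizationFn |
| /// %cst = arith.constant 0.0 : f32 |
| /// %r = vector.transfer_read %buf[%laneid], %cst : memref<32xf32>, vector<1xf32> |
| /// ``` |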
| struct WarpOpToScfIfPattern : public WarpDistributionPattern { |
| WarpOpToScfIfPattern(MLIRContext *context, |
| const WarpExecuteOnLane0LoweringOptions &options, |
| PatternBenefit benefit = 1) |
| : WarpDistributionPattern(context, benefit), options(options) {} |
| |
| LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
| PatternRewriter &rewriter) const override { |
| assert(warpOp.getBodyRegion().hasOneBlock() && |
| "expected WarpOp with single block"); |
| Block *warpOpBody = &warpOp.getBodyRegion().front(); |
| Location loc = warpOp.getLoc(); |
| |
| // Passed all checks. Start rewriting. |
| OpBuilder::InsertionGuard g(rewriter); |
| rewriter.setInsertionPoint(warpOp); |
| |
| // Step 1: Create scf.if op. |
| Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
| Value isLane0 = rewriter.create<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0); |
| auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0, |
| /*withElseRegion=*/false); |
| rewriter.eraseOp(ifOp.thenBlock()->getTerminator()); |
| |
| // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and |
| // reads within the scf.if to transit the values captured from above. |
| SmallVector<Value> bbArgReplacements; |
| for (const auto &it : llvm::enumerate(warpOp.getArgs())) { |
| Value sequentialVal = warpOpBody->getArgument(it.index()); |
| Value distributedVal = it.value(); |
| DistributedLoadStoreHelper helper(sequentialVal, distributedVal, |
| warpOp.getLaneid(), c0); |
| |
| // Create buffer before the ifOp. |
| rewriter.setInsertionPoint(ifOp); |
| Value buffer = options.warpAllocationFn(loc, rewriter, warpOp, |
| sequentialVal.getType()); |
| // Store distributed vector into buffer, before the ifOp. |
| helper.buildStore(rewriter, loc, distributedVal, buffer); |
| // Load sequential vector from buffer, inside the ifOp. |
| rewriter.setInsertionPointToStart(ifOp.thenBlock()); |
| bbArgReplacements.push_back( |
| helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer)); |
| } |
| |
| // Step 3. Insert sync after all the stores and before all the loads. |
| if (!warpOp.getArgs().empty()) { |
| rewriter.setInsertionPoint(ifOp); |
| options.warpSyncronizationFn(loc, rewriter, warpOp); |
| } |
| |
| // Step 4. Move body of warpOp to ifOp. |
| rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements); |
| |
| // Step 5. Insert appropriate writes within scf.if and reads after the |
| // scf.if to transit the values returned by the op. |
| // TODO: at this point, we can reuse the shared memory from previous |
| // buffers. |
| SmallVector<Value> replacements; |
| auto yieldOp = cast<gpu::YieldOp>(ifOp.thenBlock()->getTerminator()); |
| Location yieldLoc = yieldOp.getLoc(); |
| for (const auto &it : llvm::enumerate(yieldOp.getOperands())) { |
| Value sequentialVal = it.value(); |
| Value distributedVal = warpOp->getResult(it.index()); |
| DistributedLoadStoreHelper helper(sequentialVal, distributedVal, |
| warpOp.getLaneid(), c0); |
| |
| // Create buffer before the ifOp. |
| rewriter.setInsertionPoint(ifOp); |
| Value buffer = options.warpAllocationFn(loc, rewriter, warpOp, |
| sequentialVal.getType()); |
| |
| // Store yielded value into buffer, inside the ifOp, before the |
| // terminator. |
| rewriter.setInsertionPoint(yieldOp); |
| helper.buildStore(rewriter, loc, sequentialVal, buffer); |
| |
| // Load distributed value from buffer, after the warpOp. |
| rewriter.setInsertionPointAfter(ifOp); |
| // Result type and yielded value type are the same. This is a broadcast. |
| // E.g.: |
| // %r = gpu.warp_execute_on_lane_0(...) -> (f32) { |
| // gpu.yield %cst : f32 |
| // } |
| // Both types are f32. The constant %cst is broadcasted to all lanes. |
| // This is described in more detail in the documentation of the op. |
| replacements.push_back( |
| helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer)); |
| } |
| |
| // Step 6. Insert sync after all the stores and before all the loads. |
| if (!yieldOp.getOperands().empty()) { |
| rewriter.setInsertionPointAfter(ifOp); |
| options.warpSyncronizationFn(loc, rewriter, warpOp); |
| } |
| |
| // Step 7. Delete terminator and add empty scf.yield. |
| rewriter.eraseOp(yieldOp); |
| rewriter.setInsertionPointToEnd(ifOp.thenBlock()); |
| rewriter.create<scf::YieldOp>(yieldLoc); |
| |
| // Compute replacements for WarpOp results. |
| rewriter.replaceOp(warpOp, replacements); |
| |
| return success(); |
| } |
| |
| private: |
| const WarpExecuteOnLane0LoweringOptions &options; |
| }; |
| |
| /// Return the distributed vector type based on the original type and the |
| /// distribution map. The map is expected to have as many input dimensions as |
| /// the rank of the original type and should be a projection whose results |
| /// are the distributed dimensions. The number of results should equal the |
| /// number of warp sizes, which is currently limited to 1. |
| /// Example: a vector<16x32x64> distributed with the map (d0, d1, d2) -> (d1) |
| /// and a warp size of 16 distributes the second dimension (associated with |
| /// d1) and returns vector<16x2x64>. |
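| /// When the warp size does not divide a distributed dimension but is itself |
| /// divisible by it, that dimension collapses to 1 and the remaining factor |
| /// carries over to the next distributed dimension. Illustrative case: |
| /// vector<4x32> with map (d0, d1) -> (d0, d1) and warp size 128 yields |
| /// vector<1x1>, since 128 = 4 * 32. If the sizes cannot be matched up this |
| /// way, a null VectorType is returned to signal failure. |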
| static VectorType getDistributedType(VectorType originalType, AffineMap map, |
| int64_t warpSize) { |
| SmallVector<int64_t> targetShape(originalType.getShape()); |
| for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { |
| unsigned position = map.getDimPosition(i); |
| if (targetShape[position] % warpSize != 0) { |
| if (warpSize % targetShape[position] != 0) { |
| return VectorType(); |
| } |
| warpSize /= targetShape[position]; |
| targetShape[position] = 1; |
| continue; |
| } |
| targetShape[position] = targetShape[position] / warpSize; |
| warpSize = 1; |
| break; |
| } |
| if (warpSize != 1) { |
| return VectorType(); |
| } |
| VectorType targetType = |
| VectorType::get(targetShape, originalType.getElementType()); |
| return targetType; |
| } |
| |
| /// Distribute transfer_write ops based on the affine map returned by |
| /// `distributionMapFn`. Writes of more than `maxNumElementsToExtract` |
| /// elements will not be extracted (the limit should be less than the warp |
| /// size). |
| /// |
| /// Example: |
| /// ``` |
| /// gpu.warp_execute_on_lane_0(%id) { |
| /// ... |
| /// vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32> |
| /// gpu.yield |
| /// } |
| /// ``` |
| /// To |
| /// ``` |
| /// %r = gpu.warp_execute_on_lane_0(%id) -> (vector<1xf32>) { |
| /// ... |
| /// gpu.yield %v : vector<32xf32> |
| /// } |
| /// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32> |
| /// ``` |
| struct WarpOpTransferWrite : public WarpDistributionPattern { |
| WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn, |
| unsigned maxNumElementsToExtract, PatternBenefit b = 1) |
| : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)), |
| maxNumElementsToExtract(maxNumElementsToExtract) {} |
| |
| /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that |
| /// are multiples of the distribution ratio are supported at the moment. |
| LogicalResult tryDistributeOp(RewriterBase &rewriter, |
| vector::TransferWriteOp writeOp, |
| WarpExecuteOnLane0Op warpOp) const { |
| VectorType writtenVectorType = writeOp.getVectorType(); |
| |
| // 1. If the write is 0-D, distribution does not apply; bail out and let |
| // `tryExtractOp` handle it. |
| if (writtenVectorType.getRank() == 0) |
| return failure(); |
| |
| // 2. Compute the distributed type. |
| AffineMap map = distributionMapFn(writeOp.getVector()); |
| VectorType targetType = |
| getDistributedType(writtenVectorType, map, warpOp.getWarpSize()); |
| if (!targetType) |
| return failure(); |
| |
| // 2.5 Compute the distributed type for the new mask. |
| VectorType maskType; |
| if (writeOp.getMask()) { |
| // TODO: Distribution of masked writes with non-trivial permutation maps |
| // requires the distribution of the mask to elementwise match the |
| // distribution of the permuted written vector. Currently the details of |
| // which lane is responsible for which element are captured strictly by |
| // shape information on the warp op, and thus require materializing the |
| // permutation in IR. |
| if (!writeOp.getPermutationMap().isMinorIdentity()) |
| return failure(); |
| maskType = |
| getDistributedType(writeOp.getMaskType(), map, warpOp.getWarpSize()); |
| } |
| |
| // 3. Clone the write into a new WarpExecuteOnLane0Op to separate it from |
| // the rest. |
| vector::TransferWriteOp newWriteOp = |
| cloneWriteOp(rewriter, warpOp, writeOp, targetType, maskType); |
| |
| // 4. Reindex the write using the distribution map. |
| auto newWarpOp = |
| newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>(); |
| |
| // Delinearize the lane id based on the way threads are divided across the |
| // vector. To get the number of threads per vector dimension, divide the |
| // sequential size by the distributed size along each dim. |
| rewriter.setInsertionPoint(newWriteOp); |
| SmallVector<OpFoldResult> delinearizedIdSizes; |
| for (auto [seqSize, distSize] : |
| llvm::zip_equal(writtenVectorType.getShape(), targetType.getShape())) { |
| assert(seqSize % distSize == 0 && "Invalid distributed vector shape"); |
| delinearizedIdSizes.push_back(rewriter.getIndexAttr(seqSize / distSize)); |
| } |
| SmallVector<Value> delinearized; |
| if (map.getNumResults() > 1) { |
| delinearized = rewriter |
| .create<mlir::affine::AffineDelinearizeIndexOp>( |
| newWarpOp.getLoc(), newWarpOp.getLaneid(), |
| delinearizedIdSizes) |
| .getResults(); |
| } else { |
| // If there is only one map result, we can elide the delinearization |
| // op and use the lane id directly. |
| delinearized.append(targetType.getRank(), newWarpOp.getLaneid()); |
| } |
| |
| AffineMap indexMap = map.compose(newWriteOp.getPermutationMap()); |
| Location loc = newWriteOp.getLoc(); |
| SmallVector<Value> indices(newWriteOp.getIndices().begin(), |
| newWriteOp.getIndices().end()); |
| for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) { |
| AffineExpr d0, d1; |
| bindDims(newWarpOp.getContext(), d0, d1); |
| auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it)); |
| if (!indexExpr) |
| continue; |
| unsigned indexPos = indexExpr.getPosition(); |
| unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition(); |
| Value laneId = delinearized[vectorPos]; |
| auto scale = |
| rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos)); |
| indices[indexPos] = affine::makeComposedAffineApply( |
| rewriter, loc, d0 + scale * d1, {indices[indexPos], laneId}); |
| } |
| newWriteOp.getIndicesMutable().assign(indices); |
| |
| return success(); |
| } |
| |
| /// Extract TransferWriteOps of at most `maxNumElementsToExtract` elements |
| /// (e.g. vector<1xf32>) into a separate warp op. |
| LogicalResult tryExtractOp(RewriterBase &rewriter, |
| vector::TransferWriteOp writeOp, |
| WarpExecuteOnLane0Op warpOp) const { |
| Location loc = writeOp.getLoc(); |
| VectorType vecType = writeOp.getVectorType(); |
| |
| if (vecType.getNumElements() > maxNumElementsToExtract) { |
| return rewriter.notifyMatchFailure( |
| warpOp, |
| llvm::formatv( |
| "writes more elements ({0}) than allowed to extract ({1})", |
| vecType.getNumElements(), maxNumElementsToExtract)); |
| } |
| |
| // Do not process warp ops that contain only TransferWriteOps. |
| if (llvm::all_of(warpOp.getOps(), |
| llvm::IsaPred<vector::TransferWriteOp, gpu::YieldOp>)) |
| return failure(); |
| |
| SmallVector<Value> yieldValues = {writeOp.getVector()}; |
| SmallVector<Type> retTypes = {vecType}; |
| SmallVector<size_t> newRetIndices; |
| WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
| rewriter, warpOp, yieldValues, retTypes, newRetIndices); |
| rewriter.setInsertionPointAfter(newWarpOp); |
| |
| // Create a second warp op that contains only writeOp. |
| auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>( |
| loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize()); |
| Block &body = secondWarpOp.getBodyRegion().front(); |
| rewriter.setInsertionPointToStart(&body); |
| auto newWriteOp = |
| cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation())); |
| newWriteOp.getValueToStoreMutable().assign( |
| newWarpOp.getResult(newRetIndices[0])); |
| rewriter.eraseOp(writeOp); |
| rewriter.create<gpu::YieldOp>(newWarpOp.getLoc()); |
| return success(); |
| } |
| |
| LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
| PatternRewriter &rewriter) const override { |
| auto yield = cast<gpu::YieldOp>( |
| warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); |
| Operation *lastNode = yield->getPrevNode(); |
| auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode); |
| if (!writeOp) |
| return failure(); |
| |
| Value maybeMask = writeOp.getMask(); |
| if (!llvm::all_of(writeOp->getOperands(), [&](Value value) { |
| return writeOp.getVector() == value || |
| (maybeMask && maybeMask == value) || |
| warpOp.isDefinedOutsideOfRegion(value); |
| })) |
| return failure(); |
| |
| if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp))) |
| return success(); |
| |
| // Masked writes not supported for extraction. |
| if (writeOp.getMask()) |
| return failure(); |
| |
| if (succeeded(tryExtractOp(rewriter, writeOp, warpOp))) |
| return success(); |
| |
| return failure(); |
| } |
| |
| private: |
| /// Clone `writeOp` assumed to be nested under `warpOp` into a new warp |
| /// execute op with the proper return type. The new write op is updated to |
| /// write the result of the new warp execute op. The old `writeOp` is deleted. |
| vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter, |
| WarpExecuteOnLane0Op warpOp, |
| vector::TransferWriteOp writeOp, |
| VectorType targetType, |
| VectorType maybeMaskType) const { |
| assert(writeOp->getParentOp() == warpOp && |
| "write must be nested immediately under warp"); |
| OpBuilder::InsertionGuard g(rewriter); |
| SmallVector<size_t> newRetIndices; |
| WarpExecuteOnLane0Op newWarpOp; |
| if (maybeMaskType) { |
| newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
| rewriter, warpOp, ValueRange{writeOp.getVector(), writeOp.getMask()}, |
| TypeRange{targetType, maybeMaskType}, newRetIndices); |
| } else { |
| newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
| rewriter, warpOp, ValueRange{{writeOp.getVector()}}, |
| TypeRange{targetType}, newRetIndices); |
| } |
| rewriter.setInsertionPointAfter(newWarpOp); |
| auto newWriteOp = |
| cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation())); |
| rewriter.eraseOp(writeOp); |
| newWriteOp.getValueToStoreMutable().assign( |
| newWarpOp.getResult(newRetIndices[0])); |
| if (maybeMaskType) |
| newWriteOp.getMaskMutable().assign(newWarpOp.getResult(newRetIndices[1])); |
| return newWriteOp; |
| } |
| |
| DistributionMapFn distributionMapFn; |
| unsigned maxNumElementsToExtract = 1; |
| }; |
| |
| /// Sink out elementwise op feeding into a warp op yield. |
| /// ``` |
| /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) { |
| /// ... |
| /// %3 = arith.addf %1, %2 : vector<32xf32> |
| /// gpu.yield %3 : vector<32xf32> |
| /// } |
| /// ``` |
| /// To |
| /// ``` |
| /// %r:3 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>, |
| /// vector<1xf32>, vector<1xf32>) { |
| /// ... |
| /// %4 = arith.addf %2, %3 : vector<32xf32> |
| /// gpu.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>, |
| /// vector<32xf32> |
| /// } |
| /// %0 = arith.addf %r#1, %r#2 : vector<1xf32> |
| /// ``` |
| struct WarpOpElementwise : public WarpDistributionPattern { |
| using Base::Base; |
| LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
| PatternRewriter &rewriter) const override { |
| OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) { |
| return OpTrait::hasElementwiseMappableTraits(op); |
| }); |
| if (!yieldOperand) |
| return failure(); |
| |
| Operation *elementWise = yieldOperand->get().getDefiningOp(); |
| unsigned operandIndex = yieldOperand->getOperandNumber(); |
| Value distributedVal = warpOp.getResult(operandIndex); |
| SmallVector<Value> yieldValues; |
| SmallVector<Type> retTypes; |
| Location loc = warpOp.getLoc(); |
| for (OpOperand &operand : elementWise->getOpOperands()) { |
| Type targetType; |
| if (auto vecType = dyn_cast<VectorType>(distributedVal.getType())) { |
| // If the result type is a vector, the operands must also be vectors. |
| auto operandType = cast<VectorType>(operand.get().getType()); |
| targetType = |
| VectorType::get(vecType.getShape(), operandType.getElementType()); |
| } else { |
| auto operandType = operand.get().getType(); |
| assert(!isa<VectorType>(operandType) && |
| "unexpected yield of vector from op with scalar result type"); |
| targetType = operandType; |
| } |
| retTypes.push_back(targetType); |
| yieldValues.push_back(operand.get()); |
| } |
| SmallVector<size_t> newRetIndices; |
| WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
| rewriter, warpOp, yieldValues, retTypes, newRetIndices); |
| rewriter.setInsertionPointAfter(newWarpOp); |
| SmallVector<Value> newOperands(elementWise->getOperands().begin(), |
| elementWise->getOperands().end()); |
| for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) { |
| newOperands[i] = newWarpOp.getResult(newRetIndices[i]); |
| } |
| OpBuilder::InsertionGuard g(rewriter); |
| rewriter.setInsertionPointAfter(newWarpOp); |
| Operation *newOp = cloneOpWithOperandsAndTypes( |
| rewriter, loc, elementWise, newOperands, |
| {newWarpOp.getResult(operandIndex).getType()}); |
| rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), |
| newOp->getResult(0)); |
| return success(); |
| } |
| }; |
| |
| /// Sink out splat constant op feeding into a warp op yield. |
| /// ``` |
| /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) { |
| /// ... |
| /// %cst = arith.constant dense<2.0> : vector<32xf32> |
| /// gpu.yield %cst : vector<32xf32> |
| /// } |
| /// ``` |
| /// To |
| /// ``` |
| /// gpu.warp_execute_on_lane_0(%arg0) { |
| /// ... |
| /// } |
| /// %0 = arith.constant dense<2.0> : vector<1xf32> |
| /// ``` |
| struct WarpOpConstant : public WarpDistributionPattern { |
| using Base::Base; |
| LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
| PatternRewriter &rewriter) const override { |
| OpOperand *yieldOperand = |
| getWarpResult(warpOp, llvm::IsaPred<arith::ConstantOp>); |
| if (!yieldOperand) |
| return failure(); |
| auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>(); |
| auto dense = dyn_cast<SplatElementsAttr>(constantOp.getValue()); |
| if (!dense) |
| return failure(); |
| // Notify the rewriter that the warp op is changing (see the comment on |
| // the WarpOpTransferRead pattern). |
| rewriter.startOpModification(warpOp); |
| unsigned operandIndex = yieldOperand->getOperandNumber(); |
| Attribute scalarAttr = dense.getSplatValue<Attribute>(); |
| auto newAttr = DenseElementsAttr::get( |
| cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr); |
| Location loc = warpOp.getLoc(); |
| rewriter.setInsertionPointAfter(warpOp); |
| Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr); |
| rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant); |
| rewriter.finalizeOpModification(warpOp); |
| return success(); |
| } |
| }; |
| |
| /// Sink out transfer_read op feeding into a warp op yield. |
| /// ``` |
| /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) { |
| /// ... |
| /// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, |
| /// vector<32xf32> |
| /// gpu.yield %2 : vector<32xf32> |
| /// } |
| /// ``` |
| /// To |
| /// ``` |
| /// %dead = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>, |
| /// vector<1xf32>, vector<1xf32>) { |
| /// ... |
| /// %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, |
| /// vector<32xf32> |
| /// gpu.yield %2 : vector<32xf32> |
| /// } |
| /// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32> |
| /// ``` |
| struct WarpOpTransferRead : public WarpDistributionPattern { |
| using Base::Base; |
| LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
| PatternRewriter &rewriter) const override { |
| // Try to find a distributable yielded read. Note that this pattern can |
| // still fail at the end after distribution, in which case this might have |
| // missed another distributable read. |
| OpOperand *operand = getWarpResult(warpOp, [](Operation *op) { |
| // Don't duplicate transfer_read ops when distributing. |
| return isa<vector::TransferReadOp>(op) && op->hasOneUse(); |
| }); |
| if (!operand) |
| return rewriter.notifyMatchFailure( |
| warpOp, "warp result is not a vector.transfer_read op"); |
| auto read = operand->get().getDefiningOp<vector::TransferReadOp>(); |
| |
| // Source must be defined outside of the region. |
| if (!warpOp.isDefinedOutsideOfRegion(read.getBase())) |
| return rewriter.notifyMatchFailure( |
| read, "source must be defined outside of the region"); |
| |
| unsigned operandIndex = operand->getOperandNumber(); |
| Value distributedVal = warpOp.getResult(operandIndex); |
| |
| SmallVector<Value, 4> indices(read.getIndices().begin(), |
| read.getIndices().end()); |
| auto sequentialType = cast<VectorType>(read.getResult().getType()); |
| auto distributedType = cast<VectorType>(distributedVal.getType()); |
| AffineMap map = calculateImplicitMap(sequentialType, distributedType); |
| AffineMap indexMap = map.compose(read.getPermutationMap()); |
| |
| // Try to delinearize the lane ID to match the rank expected for |
| // distribution. |
| SmallVector<Value> delinearizedIds; |
| if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(), |
| distributedType.getShape(), warpOp.getWarpSize(), |
| warpOp.getLaneid(), delinearizedIds)) { |
| return rewriter.notifyMatchFailure( |
| read, "cannot delinearize lane ID for distribution"); |
| } |
| assert(!delinearizedIds.empty() || map.getNumResults() == 0); |
| |
| // Distribute indices and the mask (if present). |
| OpBuilder::InsertionGuard g(rewriter); |
| SmallVector<Value> additionalResults(indices.begin(), indices.end()); |
| SmallVector<Type> additionalResultTypes(indices.size(), |
| rewriter.getIndexType()); |
| additionalResults.push_back(read.getPadding()); |
| additionalResultTypes.push_back(read.getPadding().getType()); |
| |
| bool hasMask = false; |
| if (read.getMask()) { |
| hasMask = true; |
| // TODO: Distribution of masked reads with non-trivial permutation maps |
| // requires the distribution of the mask to elementwise match the |
| // distribution of the permuted written vector. Currently the details |
| // of which lane is responsible for which element is captured strictly |
| // by shape information on the warp op, and thus requires materializing |
| // the permutation in IR. |
| if (!mlir::compressUnusedDims(read.getPermutationMap()).isIdentity()) |
| return rewriter.notifyMatchFailure( |
| read, "non-trivial permutation maps not supported"); |
| VectorType maskType = |
| getDistributedType(read.getMaskType(), map, warpOp.getWarpSize()); |
| additionalResults.push_back(read.getMask()); |
| additionalResultTypes.push_back(maskType); |
| } |
| |
| SmallVector<size_t> newRetIndices; |
| WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
| rewriter, warpOp, additionalResults, additionalResultTypes, |
| newRetIndices); |
| distributedVal = newWarpOp.getResult(operandIndex); |
| |
| // Distributed indices were appended first. |
| SmallVector<Value> newIndices; |
| for (int64_t i = 0, e = indices.size(); i < e; ++i) |
| newIndices.push_back(newWarpOp.getResult(newRetIndices[i])); |
| |
| rewriter.setInsertionPointAfter(newWarpOp); |
| for (auto it : llvm::zip_equal(indexMap.getResults(), map.getResults())) { |
| AffineExpr d0, d1; |
| bindDims(read.getContext(), d0, d1); |
| auto indexExpr = dyn_cast<AffineDimExpr>(std::get<0>(it)); |
| if (!indexExpr) |
| continue; |
| unsigned indexPos = indexExpr.getPosition(); |
| unsigned vectorPos = cast<AffineDimExpr>(std::get<1>(it)).getPosition(); |
| int64_t scale = distributedType.getDimSize(vectorPos); |
| newIndices[indexPos] = affine::makeComposedAffineApply( |
| rewriter, read.getLoc(), d0 + scale * d1, |
| {newIndices[indexPos], delinearizedIds[vectorPos]}); |
| } |
| |
| // Distributed padding value was appended right after the indices. |
| Value newPadding = newWarpOp.getResult(newRetIndices[indices.size()]); |
| // Distributed mask value was added at the end (if the op has a mask). |
| Value newMask = |
| hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1]) |
| : Value(); |
| auto newRead = rewriter.create<vector::TransferReadOp>( |
| read.getLoc(), distributedVal.getType(), read.getBase(), newIndices, |
| read.getPermutationMapAttr(), newPadding, newMask, |
| read.getInBoundsAttr()); |
| |
| rewriter.replaceAllUsesWith(distributedVal, newRead); |
| return success(); |
| } |
| }; |
| |
| /// Remove any result that has no use along with the matching yieldOp operand. |
| // TODO: Move this in WarpExecuteOnLane0Op canonicalization. |
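| /// For example (a sketch): if the warp op yields the same value %v as both |
| /// %r#0 and %r#1 and only %r#1 has uses, the rewritten warp op yields %v |
| /// exactly once, %r#1 is remapped to that single result, and the dead %r#0 |
| /// is dropped. |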
| struct WarpOpDeadResult : public WarpDistributionPattern { |
| using Base::Base; |
| LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
| PatternRewriter &rewriter) const override { |
| SmallVector<Type> newResultTypes; |
| newResultTypes.reserve(warpOp->getNumResults()); |
| SmallVector<Value> newYieldValues; |
| newYieldValues.reserve(warpOp->getNumResults()); |
| DenseMap<Value, int64_t> dedupYieldOperandPositionMap; |
| DenseMap<OpResult, int64_t> dedupResultPositionMap; |
| auto yield = cast<gpu::YieldOp>( |
| warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); |
| |
| // Some values may be yielded multiple times and correspond to multiple |
| // results. Deduplicating occurs by taking each result with its matching |
| // yielded value, and: |
| // 1. recording the unique first position at which the value is yielded. |
| // 2. recording for the result, the first position at which the dedup'ed |
| // value is yielded. |
| // 3. skipping from the new result types / new yielded values any result |
| // that has no use or whose yielded value has already been seen. |
| for (OpResult result : warpOp.getResults()) { |
| Value yieldOperand = yield.getOperand(result.getResultNumber()); |
| auto it = dedupYieldOperandPositionMap.insert( |
| std::make_pair(yieldOperand, newResultTypes.size())); |
| dedupResultPositionMap.insert(std::make_pair(result, it.first->second)); |
| if (result.use_empty() || !it.second) |
| continue; |
| newResultTypes.push_back(result.getType()); |
| newYieldValues.push_back(yieldOperand); |
| } |
| // No modification, exit early. |
| if (yield.getNumOperands() == newYieldValues.size()) |
| return failure(); |
| // Move the body of the old warpOp to a new warpOp. |
| WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( |
| rewriter, warpOp, newYieldValues, newResultTypes); |
| |
| // Simplify the new warp op after dropping dead results. |
| newWarpOp.getBody()->walk([&](Operation *op) { |
| if (isOpTriviallyDead(op)) |
| rewriter.eraseOp(op); |
| }); |
| |
| // Replace results of the old warpOp by the new, deduplicated results. |
| SmallVector<Value> newValues; |
| newValues.reserve(warpOp->getNumResults()); |
| for (OpResult result : warpOp.getResults()) { |
| if (result.use_empty()) |
| newValues.push_back(Value()); |
| else |
| newValues.push_back( |
| newWarpOp.getResult(dedupResultPositionMap.lookup(result))); |
| } |
| rewriter.replaceOp(warpOp, newValues); |
| return success(); |
| } |
| }; |
| |
| /// If an operand is directly yielded out of the region, we can forward it |
| /// directly, as it does not need to go through the region. |
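| /// E.g. (a sketch), for a value %v defined above the warp op: |
| /// ``` |
| /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) { |
| /// gpu.yield %v : f32 |
| /// } |
| /// ``` |
| /// All uses of %r are replaced by %v directly, relying on the assumption |
| /// that values defined above the region are uniform across lanes. |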
| struct WarpOpForwardOperand : public WarpDistributionPattern { |
| using Base::Base; |
| LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
| PatternRewriter &rewriter) const override { |
| auto yield = cast<gpu::YieldOp>( |
| warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); |
| Value valForwarded; |
| unsigned resultIndex; |
| for (OpOperand &operand : yield->getOpOperands()) { |
| Value result = warpOp.getResult(operand.getOperandNumber()); |
| if (result.use_empty()) |
| continue; |
| |
| // Assume all the values coming from above are uniform. |
| if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) { |
| if (result.getType() != operand.get().getType()) |
| continue; |
| valForwarded = operand.get(); |
| resultIndex = operand.getOperandNumber(); |
| break; |
| } |
| auto arg = dyn_cast<BlockArgument>(operand.get()); |
| if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation()) |
| continue; |
| Value warpOperand = warpOp.getArgs()[arg.getArgNumber()]; |
| if (result.getType() != warpOperand.getType()) |
| continue; |
| valForwarded = warpOperand; |
| resultIndex = operand.getOperandNumber(); |
| break; |
| } |
| if (!valForwarded) |
| return failure(); |
| // Notify the rewriter that the warp op is changing (see the comment on |
| // the WarpOpTransferRead pattern). |
| rewriter.startOpModification(warpOp); |
| rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded); |
| rewriter.finalizeOpModification(warpOp); |
| return success(); |
| } |
| }; |
| |
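| /// Sink out a vector.broadcast feeding into a warp op yield, when the |
| /// broadcast source is broadcastable to the distributed result type (i.e. |
| /// each lane can rebuild the broadcast locally). E.g. (a sketch): |
| /// ``` |
| /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) { |
| /// %b = vector.broadcast %s : f32 to vector<32xf32> |
| /// gpu.yield %b : vector<32xf32> |
| /// } |
| /// ``` |
| /// becomes a warp op yielding %s, followed by |
| /// `vector.broadcast %s : f32 to vector<1xf32>` outside the region. |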
| struct WarpOpBroadcast : public WarpDistributionPattern { |
| using Base::Base; |
| LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
| PatternRewriter &rewriter) const override { |
| OpOperand *operand = |
| getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>); |
| if (!operand) |
| return failure(); |
| unsigned int operandNumber = operand->getOperandNumber(); |
| auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>(); |
| Location loc = broadcastOp.getLoc(); |
| auto destVecType = |
| cast<VectorType>(warpOp->getResultTypes()[operandNumber]); |
| Value broadcastSrc = broadcastOp.getSource(); |
| Type broadcastSrcType = broadcastSrc.getType(); |
| |
| // Check that the broadcast actually spans a set of values uniformly across |
| // all threads. In other words, check that each thread can reconstruct |
| // its own broadcast. |
| // For that we simply check that the broadcast we want to build makes sense. |
| if (vector::isBroadcastableTo(broadcastSrcType, destVecType) != |
| vector::BroadcastableToResult::Success) |
| return failure(); |
| SmallVector<size_t> newRetIndices; |
| WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
| rewriter, warpOp, {broadcastSrc}, {broadcastSrcType}, newRetIndices); |
| rewriter.setInsertionPointAfter(newWarpOp); |
| Value broadcasted = rewriter.create<vector::BroadcastOp>( |
| loc, destVecType, newWarpOp->getResult(newRetIndices[0])); |
| rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), |
| broadcasted); |
| return success(); |
| } |
| }; |
| |
| /// Pattern to move a shape cast out of the warp op. A shape cast is |
| /// basically a no-op for warp distribution; we only need to handle the |
| /// shape. |
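| /// E.g. (a sketch, warp size 32): a yielded |
| /// `vector.shape_cast %v : vector<1x32xf32> to vector<32xf32>` whose result |
| /// is distributed as vector<1xf32> is rewritten to yield %v as |
| /// vector<1x1xf32> and to apply |
| /// `vector.shape_cast : vector<1x1xf32> to vector<1xf32>` outside the |
| /// region. |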
| struct WarpOpShapeCast : public WarpDistributionPattern { |
| using Base::Base; |
| LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
| PatternRewriter &rewriter) const override { |
| OpOperand *operand = |
| getWarpResult(warpOp, llvm::IsaPred<vector::ShapeCastOp>); |
| if (!operand) |
| return failure(); |
| |
| auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>(); |
| |
| unsigned int operandNumber = operand->getOperandNumber(); |
| auto castDistributedType = |
| cast<VectorType>(warpOp->getResultTypes()[operandNumber]); |
| VectorType castOriginalType = oldCastOp.getSourceVectorType(); |
| VectorType castResultType = castDistributedType; |
| |
| // We expect the distributed type to have a smaller rank than the original |
| // type. Prepend with size-one dimensions to make them the same. |
| unsigned castDistributedRank = castDistributedType.getRank(); |
| unsigned castOriginalRank = castOriginalType.getRank(); |
| if (castDistributedRank < castOriginalRank) { |
| SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1); |
| llvm::append_range(shape, castDistributedType.getShape()); |
| castDistributedType = |
| VectorType::get(shape, castDistributedType.getElementType()); |
| } |
| |
| SmallVector<size_t> newRetIndices; |
| WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
| rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType}, |
| newRetIndices); |
| rewriter.setInsertionPointAfter(newWarpOp); |
| Value newCast = rewriter.create<vector::ShapeCastOp>( |
| oldCastOp.getLoc(), castResultType, |
| newWarpOp->getResult(newRetIndices[0])); |
| rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast); |
| return success(); |
| } |
| }; |
| |
| /// Sink out vector.create_mask op feeding into a warp op yield. |
| /// ``` |
| /// %0 = ... |
| /// %1 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) { |
| /// ... |
| /// %mask = vector.create_mask %0 : vector<32xi1> |
| /// gpu.yield %mask : vector<32xi1> |
| /// } |
| /// ``` |
| /// To |
| /// ``` |
| /// %0 = ... |
| /// gpu.warp_execute_on_lane_0(%arg0) { |
| /// ... |
| /// } |
| /// %cmp = arith.cmpi ult, %laneid, %0 |
| /// %ub = arith.select %cmp, %c0, %c1 |
| /// %1 = vector.create_mask %ub : vector<1xi1> |
| /// ``` |
| struct WarpOpCreateMask : public WarpDistributionPattern { |
| using Base::Base; |
| LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
| PatternRewriter &rewriter) const override { |
| OpOperand *yieldOperand = |
| getWarpResult(warpOp, llvm::IsaPred<vector::CreateMaskOp>); |
| if (!yieldOperand) |
| return failure(); |
| |
| auto mask = yieldOperand->get().getDefiningOp<vector::CreateMaskOp>(); |
| |
| // Early exit if any values needed for calculating the new mask indices |
| // are defined inside the warp op. |
| if (!llvm::all_of(mask->getOperands(), [&](Value value) { |
| return warpOp.isDefinedOutsideOfRegion(value); |
| })) |
| return failure(); |
| |
| Location loc = mask.getLoc(); |
| unsigned operandIndex = yieldOperand->getOperandNumber(); |
| |
| auto distType = cast<VectorType>(warpOp.getResult(operandIndex).getType()); |
| VectorType seqType = mask.getVectorType(); |
| ArrayRef<int64_t> seqShape = seqType.getShape(); |
| ArrayRef<int64_t> distShape = distType.getShape(); |
| |
| rewriter.setInsertionPointAfter(warpOp); |
| |
| // Delinearize the lane ID for constructing the distributed mask sizes. |
| SmallVector<Value> delinearizedIds; |
| if (!delinearizeLaneId(rewriter, loc, seqShape, distShape, |
| warpOp.getWarpSize(), warpOp.getLaneid(), |
| delinearizedIds)) |
| return rewriter.notifyMatchFailure( |
| mask, "cannot delinearize lane ID for distribution"); |
| assert(!delinearizedIds.empty()); |
| |
| // Notify the rewriter that the warp op is changing (see the comment on |
| // the WarpOpTransferRead pattern). |
| rewriter.startOpModification(warpOp); |
| |
| AffineExpr s0, s1; |
| bindSymbols(rewriter.getContext(), s0, s1); |
| SmallVector<Value> newOperands; |
| for (int i = 0, e = distShape.size(); i < e; ++i) { |
| // Get `mask_dim_range_upper_limit[i] - lane_id[i] * dist_sizes[i]` to |
| // find the distance from the largest mask index owned by this lane to the |
| // original mask size. `vector.create_mask` implicitly clamps mask |
| // operands to the range [0, mask_vector_size[i]]; in other words, the |
| // effective mask sizes always fall within [0, mask_vector_size[i]]. |
| Value maskDimIdx = affine::makeComposedAffineApply( |
| rewriter, loc, s1 - s0 * distShape[i], |
| {delinearizedIds[i], mask.getOperand(i)}); |
| newOperands.push_back(maskDimIdx); |
| } |
| |
| auto newMask = |
| rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands); |
| rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask); |
| rewriter.finalizeOpModification(warpOp); |
| return success(); |
| } |
| }; |
| |
| /// Pattern to move out vector.extract ops with a 2-D or higher source |
| /// vector. The extract is rebuilt outside of the region on a suitably |
| /// distributed source; 0-D and 1-D sources are handled by |
| /// WarpOpExtractScalar. |
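| /// E.g. (a sketch, warp size 32): for |
| /// `%e = vector.extract %v[2] : vector<32xf32> from vector<4x32xf32>` |
| /// yielded as vector<1xf32>, the source is instead yielded as |
| /// vector<4x1xf32> and each lane rebuilds |
| /// `vector.extract %v'[2] : vector<1xf32> from vector<4x1xf32>` outside the |
| /// warp op. |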
| struct WarpOpExtract : public WarpDistributionPattern { |
| using Base::Base; |
| LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
| PatternRewriter &rewriter) const override { |
| OpOperand *operand = |
| getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>); |
| if (!operand) |
| return failure(); |
| unsigned int operandNumber = operand->getOperandNumber(); |
| auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>(); |
| VectorType extractSrcType = extractOp.getSourceVectorType(); |
| Location loc = extractOp.getLoc(); |
| |
| // For 0-D or 1-D source cases, we rely on the WarpOpExtractScalar pattern. |
| if (extractSrcType.getRank() <= 1) { |
| return failure(); |
| } |
| |
| // All following cases are 2d or higher dimensional source vectors. |
| |
| if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) { |
| // There is no distribution, this is a broadcast. Simply move the extract |
| // out of the warp op. |
| // TODO: This could be optimized. E.g., in case of a scalar result, let |
| // one lane extract and shuffle the result to all other lanes (same as |
| // the 1d case). |
| SmallVector<size_t> newRetIndices; |
| WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
| rewriter, warpOp, {extractOp.getVector()}, |
| {extractOp.getSourceVectorType()}, newRetIndices); |
| rewriter.setInsertionPointAfter(newWarpOp); |
| Value distributedVec = newWarpOp->getResult(newRetIndices[0]); |
| // Extract from distributed vector. |
| Value newExtract = rewriter.create<vector::ExtractOp>( |
| loc, distributedVec, extractOp.getMixedPosition()); |
| rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), |
| newExtract); |
| return success(); |
| } |
| |
| // Find the distributed dimension. There should be exactly one. |
| auto distributedType = |
| cast<VectorType>(warpOp.getResult(operandNumber).getType()); |
| auto yieldedType = cast<VectorType>(operand->get().getType()); |
| int64_t distributedDim = -1; |
| for (int64_t i = 0; i < yieldedType.getRank(); ++i) { |
| if (distributedType.getDimSize(i) != yieldedType.getDimSize(i)) { |
| // Keep this assert here in case WarpExecuteOnLane0Op gets extended to |
| // support distributing multiple dimensions in the future. |
| assert(distributedDim == -1 && "found multiple distributed dims"); |
| distributedDim = i; |
| } |
| } |
| assert(distributedDim != -1 && "could not find distributed dimension"); |
| (void)distributedDim; |
| |
| // Yield source vector from warp op. |
| SmallVector<int64_t> newDistributedShape(extractSrcType.getShape()); |
| for (int i = 0; i < distributedType.getRank(); ++i) |
| newDistributedShape[i + extractOp.getNumIndices()] = |
| distributedType.getDimSize(i); |
| auto newDistributedType = |
| VectorType::get(newDistributedShape, distributedType.getElementType()); |
| SmallVector<size_t> newRetIndices; |
| WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
| rewriter, warpOp, {extractOp.getVector()}, {newDistributedType}, |
| newRetIndices); |
| rewriter.setInsertionPointAfter(newWarpOp); |
| Value distributedVec = newWarpOp->getResult(newRetIndices[0]); |
| // Extract from distributed vector. |
| Value newExtract = rewriter.create<vector::ExtractOp>( |
| loc, distributedVec, extractOp.getMixedPosition()); |
| rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), |
| newExtract); |
| return success(); |
| } |
| }; |
| |
| /// Pattern to move out vector.extract with a scalar result. |
| /// Only supports 1-D and 0-D sources for now. |
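| /// E.g. (a sketch, warp size 32): for a yielded |
| /// `%e = vector.extract %v[9] : f32 from vector<64xf32>`, the source is |
| /// distributed as vector<2xf32> per lane; lane 9 / 2 = 4 extracts its local |
| /// element 9 % 2 = 1, and the scalar is then shuffled to all lanes via |
| /// `warpShuffleFromIdxFn`. |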
| struct WarpOpExtractScalar : public WarpDistributionPattern { |
| WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn, |
| PatternBenefit b = 1) |
| : WarpDistributionPattern(ctx, b), warpShuffleFromIdxFn(std::move(fn)) {} |
| LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
| PatternRewriter &rewriter) const override { |
| OpOperand *operand = |
| getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>); |
| if (!operand) |
| return failure(); |
| unsigned int operandNumber = operand->getOperandNumber(); |
| auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>(); |
| VectorType extractSrcType = extractOp.getSourceVectorType(); |
| // Only supports 1-D or 0-D sources for now. |
| if (extractSrcType.getRank() > 1) { |
| return rewriter.notifyMatchFailure( |
| extractOp, "only 0-D or 1-D source supported for now"); |
| } |
| // TODO: Supported shuffle types should be parameterizable, similar to |
| // `WarpShuffleFromIdxFn`. |
| if (!extractSrcType.getElementType().isF32() && |
| !extractSrcType.getElementType().isInteger(32)) |
| return rewriter.notifyMatchFailure( |
| extractOp, "only f32/i32 element types are supported"); |
| bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1; |
| Type elType = extractSrcType.getElementType(); |
| VectorType distributedVecType; |
| if (!is0dOrVec1Extract) { |
| assert(extractSrcType.getRank() == 1 && |
| "expected that extract src rank is 0 or 1"); |
| if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0) |
| return failure(); |
| int64_t elementsPerLane = |
| extractSrcType.getShape()[0] / warpOp.getWarpSize(); |
| distributedVecType = VectorType::get({elementsPerLane}, elType); |
| } else { |
| distributedVecType = extractSrcType; |
| } |
| // Yield source vector and position (if present) from warp op. |
| SmallVector<Value> additionalResults{extractOp.getVector()}; |
| SmallVector<Type> additionalResultTypes{distributedVecType}; |
| additionalResults.append( |
| SmallVector<Value>(extractOp.getDynamicPosition())); |
| additionalResultTypes.append( |
| SmallVector<Type>(extractOp.getDynamicPosition().getTypes())); |
| |
| Location loc = extractOp.getLoc(); |
| SmallVector<size_t> newRetIndices; |
| WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
| rewriter, warpOp, additionalResults, additionalResultTypes, |
| newRetIndices); |
| rewriter.setInsertionPointAfter(newWarpOp); |
| Value distributedVec = newWarpOp->getResult(newRetIndices[0]); |
| |
| // 0d extract: The new warp op broadcasts the source vector to all lanes. |
| // All lanes extract the scalar. |
| if (is0dOrVec1Extract) { |
| Value newExtract; |
| SmallVector<int64_t> indices(extractSrcType.getRank(), 0); |
| newExtract = |
| rewriter.create<vector::ExtractOp>(loc, distributedVec, indices); |
| rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), |
| newExtract); |
| return success(); |
| } |
| |
| int64_t staticPos = extractOp.getStaticPosition()[0]; |
| OpFoldResult pos = ShapedType::isDynamic(staticPos) |
| ? (newWarpOp->getResult(newRetIndices[1])) |
| : OpFoldResult(rewriter.getIndexAttr(staticPos)); |
| // 1d extract: Distribute the source vector. One lane extracts and shuffles |
| // the value to all other lanes. |
| int64_t elementsPerLane = distributedVecType.getShape()[0]; |
| AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext()); |
| // tid of extracting thread: pos / elementsPerLane |
| Value broadcastFromTid = affine::makeComposedAffineApply( |
| rewriter, loc, sym0.floorDiv(elementsPerLane), pos); |
| // Extract at position: pos % elementsPerLane |
| Value newPos = |
| elementsPerLane == 1 |
| ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult() |
| : affine::makeComposedAffineApply(rewriter, loc, |
| sym0 % elementsPerLane, pos); |
| Value extracted = |
| rewriter.create<vector::ExtractOp>(loc, distributedVec, newPos); |
| |
| // Shuffle the extracted value to all lanes. |
| Value shuffled = warpShuffleFromIdxFn( |
| loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize()); |
| rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), shuffled); |
| return success(); |
| } |
| |
| private: |
| WarpShuffleFromIdxFn warpShuffleFromIdxFn; |
| }; |
| |
| /// Pattern to convert vector.extractelement to vector.extract. |
| struct WarpOpExtractElement : public WarpDistributionPattern { |
| using Base::Base; |
| LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
| PatternRewriter &rewriter) const override { |
| OpOperand *operand = |
| getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>); |
| if (!operand) |
| return failure(); |
| auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>(); |
| SmallVector<OpFoldResult> indices; |
| if (auto pos = extractOp.getPosition()) { |
| indices.push_back(pos); |
| } |
| rewriter.setInsertionPoint(extractOp); |
| rewriter.replaceOpWithNewOp<vector::ExtractOp>( |
| extractOp, extractOp.getVector(), indices); |
| return success(); |
| } |
| }; |
| |
| /// Pattern to move out vector.insert with a scalar input. |
| /// Only supports 1-D and 0-D destinations for now. |
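| /// E.g. (a sketch, warp size 32): inserting a scalar at position 9 of a |
| /// vector<64xf32> destination distributed as vector<2xf32> per lane is |
| /// rewritten so that only lane 9 / 2 = 4 performs the insert, at local |
| /// position 9 % 2 = 1, inside an scf.if; all other lanes yield their |
| /// distributed slice unchanged. |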
| struct WarpOpInsertScalar : public WarpDistributionPattern { |
| using Base::Base; |
| LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
| PatternRewriter &rewriter) const override { |
| OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>); |
| if (!operand) |
| return failure(); |
| unsigned int operandNumber = operand->getOperandNumber(); |
| auto insertOp = operand->get().getDefiningOp<vector::InsertOp>(); |
| VectorType vecType = insertOp.getDestVectorType(); |
| VectorType distrType = |
| cast<VectorType>(warpOp.getResult(operandNumber).getType()); |
| |
| // Only supports 1-D or 0-D destinations for now. |
| if (vecType.getRank() > 1) { |
| return rewriter.notifyMatchFailure( |
| insertOp, "only 0-D or 1-D source supported for now"); |
| } |
| |
| // Yield destination vector, source scalar and position from warp op. |
| SmallVector<Value> additionalResults{insertOp.getDest(), |
| insertOp.getValueToStore()}; |
| SmallVector<Type> additionalResultTypes{ |
| distrType, insertOp.getValueToStore().getType()}; |
| additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition())); |
| additionalResultTypes.append( |
| SmallVector<Type>(insertOp.getDynamicPosition().getTypes())); |
| |
| Location loc = insertOp.getLoc(); |
| SmallVector<size_t> newRetIndices; |
| WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
| rewriter, warpOp, additionalResults, additionalResultTypes, |
| newRetIndices); |
| rewriter.setInsertionPointAfter(newWarpOp); |
| Value distributedVec = newWarpOp->getResult(newRetIndices[0]); |
| Value newSource = newWarpOp->getResult(newRetIndices[1]); |
| rewriter.setInsertionPointAfter(newWarpOp); |
| |
| OpFoldResult pos; |
| if (vecType.getRank() != 0) { |
| int64_t staticPos = insertOp.getStaticPosition()[0]; |
| pos = ShapedType::isDynamic(staticPos) |
| ? (newWarpOp->getResult(newRetIndices[2])) |
| : OpFoldResult(rewriter.getIndexAttr(staticPos)); |
| } |
| |
| // This condition is always true for 0-d vectors. |
| if (vecType == distrType) { |
| Value newInsert; |
| SmallVector<OpFoldResult> indices; |
| if (pos) { |
| indices.push_back(pos); |
| } |
| newInsert = rewriter.create<vector::InsertOp>(loc, newSource, |
| distributedVec, indices); |
| // Broadcast: Simply move the vector.insert op out. |
| rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), |
| newInsert); |
| return success(); |
| } |
| |
| // This is a distribution. Only one lane should insert. |
| int64_t elementsPerLane = distrType.getShape()[0]; |
| AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext()); |
| // tid of inserting thread: pos / elementsPerLane |
| Value insertingLane = affine::makeComposedAffineApply( |
| rewriter, loc, sym0.floorDiv(elementsPerLane), pos); |
| // Insert position: pos % elementsPerLane |
| OpFoldResult newPos = affine::makeComposedFoldedAffineApply( |
| rewriter, loc, sym0 % elementsPerLane, pos); |
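| // E.g. with elementsPerLane = 2, pos = 5 maps to inserting lane 2 |
| // (5 floordiv 2) and in-lane insert position 1 (5 mod 2). |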
| Value isInsertingLane = rewriter.create<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane); |
| Value newResult = |
| rewriter |
| .create<scf::IfOp>( |
| loc, isInsertingLane, |
| /*thenBuilder=*/ |
| [&](OpBuilder &builder, Location loc) { |
| Value newInsert = builder.create<vector::InsertOp>( |
| loc, newSource, distributedVec, newPos); |
| builder.create<scf::YieldOp>(loc, newInsert); |
| }, |
| /*elseBuilder=*/ |
| [&](OpBuilder &builder, Location loc) { |
| builder.create<scf::YieldOp>(loc, distributedVec); |
| }) |
| .getResult(0); |
| rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult); |
| return success(); |
| } |
| }; |
| |
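| /// Pattern to move out vector.insert with a vector input. Only handles 2-D |
| /// and higher-dimensional destinations; 0-D and 1-D destinations are handled |
| /// by WarpOpInsertScalar above. |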
| struct WarpOpInsert : public WarpDistributionPattern { |
| using Base::Base; |
| LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
| PatternRewriter &rewriter) const override { |
| OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>); |
| if (!operand) |
| return failure(); |
| unsigned int operandNumber = operand->getOperandNumber(); |
| auto insertOp = operand->get().getDefiningOp<vector::InsertOp>(); |
| Location loc = insertOp.getLoc(); |
| |
| // For 1-D or 0-D destination cases, we rely on the WarpOpInsertScalar |
| // pattern. |
| if (insertOp.getDestVectorType().getRank() <= 1) { |
| return failure(); |
| } |
| |
| // All following cases have 2-D or higher-dimensional destination vectors. |
| |
| if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) { |
| // There is no distribution, this is a broadcast. Simply move the insert |
| // out of the warp op. |
| SmallVector<size_t> newRetIndices; |
| WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
| rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()}, |
| {insertOp.getValueToStoreType(), insertOp.getDestVectorType()}, |
| newRetIndices); |
| rewriter.setInsertionPointAfter(newWarpOp); |
| Value distributedSrc = newWarpOp->getResult(newRetIndices[0]); |
| Value distributedDest = newWarpOp->getResult(newRetIndices[1]); |
| Value newResult = rewriter.create<vector::InsertOp>( |
| loc, distributedSrc, distributedDest, insertOp.getMixedPosition()); |
| rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), |
| newResult); |
| return success(); |
| } |
| |
| // Find the distributed dimension. There should be exactly one. |
| auto distrDestType = |
| cast<VectorType>(warpOp.getResult(operandNumber).getType()); |
| auto yieldedType = cast<VectorType>(operand->get().getType()); |
| int64_t distrDestDim = -1; |
| for (int64_t i = 0; i < yieldedType.getRank(); ++i) { |
| if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) { |
| // Keep this assert here in case WarpExecuteOnLane0Op gets extended to |
| // support distributing multiple dimensions in the future. |
| assert(distrDestDim == -1 && "found multiple distributed dims"); |
| distrDestDim = i; |
| } |
| } |
| assert(distrDestDim != -1 && "could not find distributed dimension"); |
| |
| // Compute the distributed source vector type. |
| VectorType srcVecType = cast<VectorType>(insertOp.getValueToStoreType()); |
| SmallVector<int64_t> distrSrcShape(srcVecType.getShape()); |
| // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32> |
| // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will |
| // insert a smaller vector<3xf32>. |
| // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that |
| // case, one lane will insert the source vector<96xf32>. The other |
| // lanes will not do anything. |
| int64_t distrSrcDim = distrDestDim - insertOp.getNumIndices(); |
| if (distrSrcDim >= 0) |
| distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim); |
| auto distrSrcType = |
| VectorType::get(distrSrcShape, distrDestType.getElementType()); |
| |
| // Yield source and dest vectors from warp op. |
| SmallVector<size_t> newRetIndices; |
| WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
| rewriter, warpOp, {insertOp.getValueToStore(), insertOp.getDest()}, |
| {distrSrcType, distrDestType}, newRetIndices); |
| rewriter.setInsertionPointAfter(newWarpOp); |
| Value distributedSrc = newWarpOp->getResult(newRetIndices[0]); |
| Value distributedDest = newWarpOp->getResult(newRetIndices[1]); |
| |
| // Insert into the distributed vector. |
| Value newResult; |
| if (distrSrcDim >= 0) { |
| // Every lane inserts a small piece. |
| newResult = rewriter.create<vector::InsertOp>( |
| loc, distributedSrc, distributedDest, insertOp.getMixedPosition()); |
| } else { |
| // One lane inserts the entire source vector. |
| int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim); |
| SmallVector<OpFoldResult> pos = insertOp.getMixedPosition(); |
| SmallVector<int64_t> newPos = getAsIntegers(pos); |
| // tid of inserting lane: pos / elementsPerLane |
| Value insertingLane = rewriter.create<arith::ConstantIndexOp>( |
| loc, newPos[distrDestDim] / elementsPerLane); |
| Value isInsertingLane = rewriter.create<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane); |
| // Insert position: pos % elementsPerLane |
| newPos[distrDestDim] %= elementsPerLane; |
| auto insertingBuilder = [&](OpBuilder &builder, Location loc) { |
| Value newInsert = builder.create<vector::InsertOp>( |
| loc, distributedSrc, distributedDest, newPos); |
| builder.create<scf::YieldOp>(loc, newInsert); |
| }; |
| auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) { |
| builder.create<scf::YieldOp>(loc, distributedDest); |
| }; |
| newResult = rewriter |
| .create<scf::IfOp>(loc, isInsertingLane, |
| /*thenBuilder=*/insertingBuilder, |
| /*elseBuilder=*/nonInsertingBuilder) |
| .getResult(0); |
| } |
| |
| rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult); |
| return success(); |
| } |
| }; |
| |
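| /// Pattern to rewrite vector.insertelement into an equivalent vector.insert |
| /// so that the vector.insert distribution patterns above can apply. |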
| struct WarpOpInsertElement : public WarpDistributionPattern { |
| using Base::Base; |
| LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
| PatternRewriter &rewriter) const override { |
| OpOperand *operand = |
| getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>); |
| if (!operand) |
| return failure(); |
| auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>(); |
| SmallVector<OpFoldResult> indices; |
| if (auto pos = insertOp.getPosition()) { |
| indices.push_back(pos); |
| } |
| rewriter.setInsertionPoint(insertOp); |
| rewriter.replaceOpWithNewOp<vector::InsertOp>( |
| insertOp, insertOp.getSource(), insertOp.getDest(), indices); |
| return success(); |
| } |
| }; |
| |
| /// Sink an scf.for op out of WarpExecuteOnLane0Op. This can be done only if |
| /// the scf.ForOp is the last operation in the region, so that sinking it does |
| /// not change the order of execution. This creates a new scf.for op after the |
| /// WarpExecuteOnLane0Op, whose body contains a new WarpExecuteOnLane0Op |
| /// region. Example: |
| /// ``` |
| /// %w = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) { |
| /// ... |
| /// %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v) |
| /// -> (vector<128xf32>) { |
| /// ... |
| /// scf.yield %r : vector<128xf32> |
| /// } |
| /// gpu.yield %v1 : vector<128xf32> |
| /// } |
| /// ``` |
| /// To: |
| /// ``` |
| /// %w0 = gpu.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) { |
| /// ... |
| /// gpu.yield %v : vector<128xf32> |
| /// } |
| /// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0) |
| /// -> (vector<4xf32>) { |
| /// %iw = gpu.warp_execute_on_lane_0(%laneid) |
| /// args(%varg : vector<4xf32>) -> (vector<4xf32>) { |
| /// ^bb0(%arg: vector<128xf32>): |
| /// ... |
| /// gpu.yield %ir : vector<128xf32> |
| /// } |
| /// scf.yield %iw : vector<4xf32> |
| /// } |
| /// ``` |
| struct WarpOpScfForOp : public WarpDistributionPattern { |
| |
| WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1) |
| : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {} |
| LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
| PatternRewriter &rewriter) const override { |
| auto yield = cast<gpu::YieldOp>( |
| warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); |
| // Only pick up forOp if it is the last op in the region. |
| Operation *lastNode = yield->getPrevNode(); |
| auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode); |
| if (!forOp) |
| return failure(); |
| // Collect values that are defined inside the warp op region but outside the |
| // forOp, and that are used inside the forOp. These values need to be |
| // returned by the original warpOp and passed on to the new op. |
| llvm::SmallSetVector<Value, 32> escapingValues; |
| SmallVector<Type> inputTypes; |
| SmallVector<Type> distTypes; |
| mlir::visitUsedValuesDefinedAbove( |
| forOp.getBodyRegion(), [&](OpOperand *operand) { |
| Operation *parent = operand->get().getParentRegion()->getParentOp(); |
| if (warpOp->isAncestor(parent)) { |
| if (!escapingValues.insert(operand->get())) |
| return; |
| Type distType = operand->get().getType(); |
| if (auto vecType = dyn_cast<VectorType>(distType)) { |
| AffineMap map = distributionMapFn(operand->get()); |
| distType = getDistributedType(vecType, map, warpOp.getWarpSize()); |
| } |
| inputTypes.push_back(operand->get().getType()); |
| distTypes.push_back(distType); |
| } |
| }); |
| |
| if (llvm::is_contained(distTypes, Type{})) |
| return failure(); |
| |
| SmallVector<size_t> newRetIndices; |
| WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
| rewriter, warpOp, escapingValues.getArrayRef(), distTypes, |
| newRetIndices); |
| yield = cast<gpu::YieldOp>( |
| newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator()); |
| |
| SmallVector<Value> newOperands; |
| SmallVector<unsigned> resultIdx; |
| // Collect all the outputs coming from the forOp. |
| for (OpOperand &yieldOperand : yield->getOpOperands()) { |
| if (yieldOperand.get().getDefiningOp() != forOp.getOperation()) |
| continue; |
| auto forResult = cast<OpResult>(yieldOperand.get()); |
| newOperands.push_back( |
| newWarpOp.getResult(yieldOperand.getOperandNumber())); |
| yieldOperand.set(forOp.getInitArgs()[forResult.getResultNumber()]); |
| resultIdx.push_back(yieldOperand.getOperandNumber()); |
| } |
| |
| OpBuilder::InsertionGuard g(rewriter); |
| rewriter.setInsertionPointAfter(newWarpOp); |
| |
| // Create a new for op outside the region with a WarpExecuteOnLane0Op |
| // region inside. |
| auto newForOp = rewriter.create<scf::ForOp>( |
| forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), |
| forOp.getStep(), newOperands); |
| rewriter.setInsertionPointToStart(newForOp.getBody()); |
| |
| SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(), |
| newForOp.getRegionIterArgs().end()); |
| SmallVector<Type> warpInputType(forOp.getResultTypes().begin(), |
| forOp.getResultTypes().end()); |
| llvm::SmallDenseMap<Value, int64_t> argIndexMapping; |
| for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) { |
| warpInput.push_back(newWarpOp.getResult(retIdx)); |
| argIndexMapping[escapingValues[i]] = warpInputType.size(); |
| warpInputType.push_back(inputTypes[i]); |
| } |
| auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>( |
| newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(), |
| newWarpOp.getWarpSize(), warpInput, warpInputType); |
| |
| SmallVector<Value> argMapping; |
| argMapping.push_back(newForOp.getInductionVar()); |
| for (Value arg : innerWarp.getBody()->getArguments()) |
| argMapping.push_back(arg); |
| argMapping.resize(forOp.getBody()->getNumArguments()); |
| SmallVector<Value> yieldOperands; |
| for (Value operand : forOp.getBody()->getTerminator()->getOperands()) |
| yieldOperands.push_back(operand); |
| rewriter.eraseOp(forOp.getBody()->getTerminator()); |
| rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping); |
| rewriter.setInsertionPointToEnd(innerWarp.getBody()); |
| rewriter.create<gpu::YieldOp>(innerWarp.getLoc(), yieldOperands); |
| rewriter.setInsertionPointAfter(innerWarp); |
| if (!innerWarp.getResults().empty()) |
| rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults()); |
| rewriter.eraseOp(forOp); |
| // Replace the warpOp result coming from the original ForOp. |
| for (const auto &res : llvm::enumerate(resultIdx)) { |
| rewriter.replaceAllUsesWith(newWarpOp.getResult(res.value()), |
| newForOp.getResult(res.index())); |
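| // Offset by 3 to account for the scf.for lower bound, upper bound and |
| // step operands that precede the init args. |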
| newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value())); |
| } |
| newForOp.walk([&](Operation *op) { |
| for (OpOperand &operand : op->getOpOperands()) { |
| auto it = argIndexMapping.find(operand.get()); |
| if (it == argIndexMapping.end()) |
| continue; |
| operand.set(innerWarp.getBodyRegion().getArgument(it->second)); |
| } |
| }); |
| |
| // Finally, hoist out any now-uniform code from the inner warp op. |
| mlir::vector::moveScalarUniformCode(innerWarp); |
| return success(); |
| } |
| |
| private: |
| DistributionMapFn distributionMapFn; |
| }; |
| |
| /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op. |
| /// The vector is reduced in parallel. Currently limited to vector sizes that |
| /// are a multiple of the warp size. E.g.: |
| /// ``` |
| /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (f32) { |
| /// %0 = "some_def"() : () -> (vector<32xf32>) |
| /// %1 = vector.reduction "add", %0 : vector<32xf32> into f32 |
| /// gpu.yield %1 : f32 |
| /// } |
| /// ``` |
| /// is lowered to: |
| /// ``` |
| /// %0 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) { |
| /// %1 = "some_def"() : () -> (vector<32xf32>) |
| /// gpu.yield %1 : vector<32xf32> |
| /// } |
| /// %a = vector.extract %0[0] : f32 from vector<1xf32> |
| /// %r = ("warp.reduction %a") |
| /// ``` |
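| /// With a vector that is a larger multiple of the warp size, e.g. |
| /// vector<64xf32> with warp size 32, each lane is left with a vector<2xf32> |
| /// slice that the distributed reduction then combines across lanes. |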
| struct WarpOpReduction : public WarpDistributionPattern { |
| WarpOpReduction(MLIRContext *context, |
| DistributedReductionFn distributedReductionFn, |
| PatternBenefit benefit = 1) |
| : WarpDistributionPattern(context, benefit), |
| distributedReductionFn(std::move(distributedReductionFn)) {} |
| |
| LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, |
| PatternRewriter &rewriter) const override { |
| OpOperand *yieldOperand = |
| getWarpResult(warpOp, llvm::IsaPred<vector::ReductionOp>); |
| if (!yieldOperand) |
| return failure(); |
| |
| auto reductionOp = |
| cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp()); |
| auto vectorType = cast<VectorType>(reductionOp.getVector().getType()); |
| // Only rank 1 vectors supported. |
| if (vectorType.getRank() != 1) |
| return rewriter.notifyMatchFailure( |
| warpOp, "Only rank 1 reductions can be distributed."); |
| // Only vector sizes that are a multiple of the warp size are supported. |
| if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0) |
| return rewriter.notifyMatchFailure( |
| warpOp, "Reduction vector dimension must be a multiple of the warp size."); |
| if (!reductionOp.getType().isIntOrFloat()) |
| return rewriter.notifyMatchFailure( |
| warpOp, "Reduction distribution currently only supports floats and " |
| "integer types."); |
| |
| int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize(); |
| // Return vector that will be reduced from the WarpExecuteOnLane0Op. |
| unsigned operandIndex = yieldOperand->getOperandNumber(); |
| SmallVector<Value> yieldValues = {reductionOp.getVector()}; |
| SmallVector<Type> retTypes = { |
| VectorType::get({numElements}, reductionOp.getType())}; |
| if (reductionOp.getAcc()) { |
| yieldValues.push_back(reductionOp.getAcc()); |
| retTypes.push_back(reductionOp.getAcc().getType()); |
| } |
| SmallVector<size_t> newRetIndices; |
| WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( |
| rewriter, warpOp, yieldValues, retTypes, newRetIndices); |
| rewriter.setInsertionPointAfter(newWarpOp); |
| |
| // Obtain data to reduce for a single lane. |
| Value laneValVec = newWarpOp.getResult(newRetIndices[0]); |
| // Distribute and reduce across threads. |
| Value fullReduce = |
| distributedReductionFn(reductionOp.getLoc(), rewriter, laneValVec, |
| reductionOp.getKind(), newWarpOp.getWarpSize()); |
| if (reductionOp.getAcc()) { |
| fullReduce = vector::makeArithReduction( |
| rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce, |
| newWarpOp.getResult(newRetIndices[1])); |
| } |
| rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIndex), fullReduce); |
| return success(); |
| } |
| |
| private: |
| DistributedReductionFn distributedReductionFn; |
| }; |
| |
| } // namespace |
| |
| void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern( |
| RewritePatternSet &patterns, |
| const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) { |
| patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit); |
| } |
| |
| void mlir::vector::populateDistributeTransferWriteOpPatterns( |
| RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn, |
| unsigned maxNumElementsToExtract, PatternBenefit benefit) { |
| patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn, |
| maxNumElementsToExtract, benefit); |
| } |
| |
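| /// Typical drivers combine these populate entry points into a single greedy |
| /// rewrite. Illustrative sketch only; `distributionMapFn` and `shuffleFn` |
| /// are client-provided callbacks, not defined in this file: |
| /// ``` |
| /// RewritePatternSet patterns(op->getContext()); |
| /// vector::populatePropagateWarpVectorDistributionPatterns( |
| ///     patterns, distributionMapFn, shuffleFn); |
| /// (void)applyPatternsGreedily(op, std::move(patterns)); |
| /// ``` |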
| void mlir::vector::populatePropagateWarpVectorDistributionPatterns( |
| RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn, |
| const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit, |
| PatternBenefit readBenefit) { |
| patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit); |
| patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast, |
| WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, |
| WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement, |
| WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>( |
| patterns.getContext(), benefit); |
| patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn, |
| benefit); |
| patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn, |
| benefit); |
| } |
| |
| void mlir::vector::populateDistributeReduction( |
| RewritePatternSet &patterns, |
| const DistributedReductionFn &distributedReductionFn, |
| PatternBenefit benefit) { |
| patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn, |
| benefit); |
| } |
| |
| /// Helper to check whether an op can be hoisted out of the region. |
| static bool canBeHoisted(Operation *op, |
| function_ref<bool(Value)> definedOutside) { |
| return llvm::all_of(op->getOperands(), definedOutside) && |
| isMemoryEffectFree(op) && op->getNumRegions() == 0; |
| } |
| |
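| /// Hoist scalar, side-effect-free ops whose operands are all defined outside |
| /// the warp op region (directly, or transitively through other hoisted ops): |
| /// such code is uniform across lanes and need not execute inside the region. |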
| void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) { |
| Block *body = warpOp.getBody(); |
| |
| // Keep track of the ops we want to hoist. |
| llvm::SmallSetVector<Operation *, 8> opsToMove; |
| |
| // Helper to check if a value is or will be defined outside of the region. |
| auto isDefinedOutsideOfBody = [&](Value value) { |
| auto *definingOp = value.getDefiningOp(); |
| return (definingOp && opsToMove.count(definingOp)) || |
| warpOp.isDefinedOutsideOfRegion(value); |
| }; |
| |
| // Do not use walk here, as we do not want to go into nested regions and hoist |
| // operations from there. |
| for (auto &op : body->without_terminator()) { |
| bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) { |
| return isa<VectorType>(result.getType()); |
| }); |
| if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody)) |
| opsToMove.insert(&op); |
| } |
| |
| // Move all the ops marked as uniform outside of the region. |
| for (Operation *op : opsToMove) |
| op->moveBefore(warpOp); |
| } |