[mlir][Vector] Improve vector.transferx store-to-load-forwarding (#171840)
This patch changes the transfer_write -> transfer_read load store
forwarding canonicalization pattern to work based on permutation maps
and less on adhoc logic. The old logic couldn't canonicalize a simple
unit dim broadcast through transfer_write/transfer_read which is added
as a test in this patch.
This patch also details what would be needed to support cases which are
not yet implemented better.
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2789f63..58b3fe0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5118,6 +5118,22 @@
return Speculation::NotSpeculatable;
}
+/// Given a projected permutation, inverse an affine map, making the unused dims
+/// 0 in the result.
+static AffineMap inverseWithUnusedDims(AffineMap map) {
+ assert(map.isProjectedPermutation() &&
+ "expected a projected permutation map");
+ SmallVector<AffineExpr> results(map.getNumInputs(),
+ getAffineConstantExpr(0, map.getContext()));
+ for (auto [idx, result] : llvm::enumerate(map.getResults())) {
+ // We should only have dim exprs because this is a projected permutation.
+ int64_t pos = cast<AffineDimExpr>(result).getPosition();
+ results[pos] = getAffineDimExpr(idx, map.getContext());
+ }
+ return AffineMap::get(/*dimCount=*/map.getNumResults(), /*symbolCount=*/0,
+ results, map.getContext());
+}
+
namespace {
/// Store to load forwarding for transfer operations with permuation maps.
/// Even if the permutation maps are different we can still propagate the store
@@ -5153,6 +5169,13 @@
// Bail if we need an alias analysis.
if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
return failure();
+ // Bail in the masked case (too complex atm and needed to properly account
+ // for padding).
+ if (readOp.getMask() || defWrite.getMask())
+ return failure();
+ // If indices are not the same a shift may be required, bail.
+ if (readOp.getIndices() != defWrite.getIndices())
+ return failure();
// Bail if we need a bounds analysis.
if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
return failure();
@@ -5161,60 +5184,50 @@
if (readOp.getTransferChunkAccessed() !=
defWrite.getTransferChunkAccessed())
return failure();
- // TODO: Support cases where a dim is explicitly written but implicitly
- // read (i.e., a unit dim that is rank reduced).
- if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
- getUnusedDimsBitVector({defWrite.getPermutationMap()}))
+ // WriteMap: tensor -> w_vec
+ // ReadMap: tensor -> r_vec
+ //
+ // inv(WriteMap): w_vec -> tensor
+ // inv(WriteMap) o ReadMap: w_vec -> r_vec
+ AffineMap readMap = readOp.getPermutationMap();
+ AffineMap writeMap = defWrite.getPermutationMap();
+ AffineMap invWriteMap = inverseWithUnusedDims(writeMap);
+ AffineMap composedMap = readMap.compose(invWriteMap);
+ // If there are any unused dims in the composedMap, we have to drop some
+ // unit dims from the written vector before we can do transpose(broadcast).
+ // TODO: Support this case.
+ if (getUnusedDimsBitVector(composedMap).any())
return failure();
- // This pattern should only catch the broadcast case, the non-broadcast case
- // should be done separately to keep application conditions clean and
- // separate.
- AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
- AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
- bool bcast = !readMap.getBroadcastDims().empty() ||
- !writeMap.getBroadcastDims().empty();
- if (!bcast)
- return failure();
- // At this point, we know we have a bcast.
- // Bail in the masked case (too complex atm and needed to properly account
- // for padding).
- if (readOp.getMask() || defWrite.getMask())
- return failure();
- // If indices are not the same a shift may be required, bail.
- if (readOp.getIndices() != defWrite.getIndices())
- return failure();
-
- Value vec = defWrite.getVector();
- // TODO: loop through the chain of transfer_write if we can prove that they
- // don't overlap with the transfer_read. This requires improving
- // `isDisjointTransferIndices` helper.
- AffineMap map = readMap.compose(writeMap);
- if (map.getNumResults() == 0)
- return failure();
- // Calculate the permutation to apply to go from the vector stored to the
- // vector read.
- SmallVector<unsigned> permutation;
- if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
- return failure();
-
- Location loc = readOp.getLoc();
- // Calculate the broadcast shape by applying the reverse permutation to the
- // final shape we want.
- ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
- SmallVector<int64_t> broadcastShape(destShape.size());
- SmallVector<bool> broadcastScalableFlags(destShape.size());
- for (const auto &pos : llvm::enumerate(permutation)) {
- broadcastShape[pos.value()] = destShape[pos.index()];
- broadcastScalableFlags[pos.value()] =
- readOp.getVectorType().getScalableDims()[pos.index()];
+ // readVec = transpose(broadcast(writeVec))
+ //
+ // Build a transpose permutation for the above transpose operation.
+ //
+ // Treat the composed map as having extra leading dimensions which are
+ // the broadcasted dimensions, and treat the zeros as these new broadcasted
+ // dimensions.
+ SmallVector<unsigned> broadcastedDims = composedMap.getBroadcastDims();
+ int64_t numBroadcastedDims = broadcastedDims.size();
+ auto invPerm = llvm::to_vector_of<int64_t>(broadcastedDims);
+ invPerm.resize(composedMap.getNumResults());
+ for (auto [idx, expr] : llvm::enumerate(composedMap.getResults())) {
+ if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
+ int64_t effectiveDim = dim.getPosition() + numBroadcastedDims;
+ invPerm[effectiveDim] = idx;
+ }
}
- VectorType broadcastedType = VectorType::get(
- broadcastShape, defWrite.getVectorType().getElementType(),
- broadcastScalableFlags);
- vec = vector::BroadcastOp::create(rewriter, loc, broadcastedType, vec);
- SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
- rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
- transposePerm);
+ // Applying the inverse permutation on the readVecTy will give us the
+ // broadcast result type.
+ VectorType readVecTy = readOp.getVectorType();
+ SmallVector<int64_t> permutation = invertPermutationVector(invPerm);
+ auto broadcastedVecTy =
+ VectorType::get(applyPermutation(readVecTy.getShape(), invPerm),
+ readVecTy.getElementType(),
+ applyPermutation(readVecTy.getScalableDims(), invPerm));
+ // Build the transpose(broadcast) transformation.
+ Value vec = defWrite.getVector();
+ Location loc = readOp.getLoc();
+ vec = vector::BroadcastOp::create(rewriter, loc, broadcastedVecTy, vec);
+ rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec, permutation);
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 084f49f..50d52c9 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1914,6 +1914,25 @@
// -----
+// CHECK-LABEL: func @store_to_load_tensor_forwarding_unit_dim_broadcast
+// CHECK-SAME: (%[[V0:.*]]: vector<4x8xf32>, %[[MEM:.*]]: tensor<1x1x4x8xf32>)
+// CHECK-NOT: vector.transfer_write
+// CHECK-NOT: vector.transfer_read
+// CHECK: %[[RET:.+]] = vector.broadcast %[[V0]] : vector<4x8xf32> to vector<1x1x4x8xf32>
+// CHECK: return %[[RET]]
+func.func @store_to_load_tensor_forwarding_unit_dim_broadcast(
+ %vec: vector<4x8xf32>,
+ %mem : tensor<1x1x4x8xf32>
+ ) -> vector<1x1x4x8xf32> {
+ %c0 = arith.constant 0 : index
+ %cst_0 = arith.constant 0.0 : f32
+ %write = vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] : vector<4x8xf32>, tensor<1x1x4x8xf32>
+ %read = vector.transfer_read %write[%c0, %c0, %c0, %c0], %cst_0 : tensor<1x1x4x8xf32>, vector<1x1x4x8xf32>
+ return %read : vector<1x1x4x8xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @dead_store_tensor
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index