[mlir][Vector] Improve vector.transferx store-to-load-forwarding (#171840)

This patch changes the transfer_write -> transfer_read load store
forwarding canonicalization pattern to work based on permutation maps
and less on adhoc logic. The old logic couldn't canonicalize a simple
unit dim broadcast through transfer_write/transfer_read which is added
as a test in this patch.

This patch also details what would be needed to support cases which are
not yet implemented better.
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2789f63..58b3fe0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5118,6 +5118,22 @@
   return Speculation::NotSpeculatable;
 }
 
+/// Given a projected permutation, inverse an affine map, making the unused dims
+/// 0 in the result.
+static AffineMap inverseWithUnusedDims(AffineMap map) {
+  assert(map.isProjectedPermutation() &&
+         "expected a projected permutation map");
+  SmallVector<AffineExpr> results(map.getNumInputs(),
+                                  getAffineConstantExpr(0, map.getContext()));
+  for (auto [idx, result] : llvm::enumerate(map.getResults())) {
+    // We should only have dim exprs because this is a projected permutation.
+    int64_t pos = cast<AffineDimExpr>(result).getPosition();
+    results[pos] = getAffineDimExpr(idx, map.getContext());
+  }
+  return AffineMap::get(/*dimCount=*/map.getNumResults(), /*symbolCount=*/0,
+                        results, map.getContext());
+}
+
 namespace {
 /// Store to load forwarding for transfer operations with permuation maps.
 /// Even if the permutation maps are different we can still propagate the store
@@ -5153,6 +5169,13 @@
     // Bail if we need an alias analysis.
     if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
       return failure();
+    // Bail in the masked case (too complex atm and needed to properly account
+    // for padding).
+    if (readOp.getMask() || defWrite.getMask())
+      return failure();
+    // If indices are not the same a shift may be required, bail.
+    if (readOp.getIndices() != defWrite.getIndices())
+      return failure();
     // Bail if we need a bounds analysis.
     if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
       return failure();
@@ -5161,60 +5184,50 @@
     if (readOp.getTransferChunkAccessed() !=
         defWrite.getTransferChunkAccessed())
       return failure();
-    // TODO: Support cases where a dim is explicitly written but implicitly
-    // read (i.e., a unit dim that is rank reduced).
-    if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
-        getUnusedDimsBitVector({defWrite.getPermutationMap()}))
+    // WriteMap: tensor -> w_vec
+    // ReadMap: tensor -> r_vec
+    //
+    // inv(WriteMap): w_vec -> tensor
+    // inv(WriteMap) o ReadMap: w_vec -> r_vec
+    AffineMap readMap = readOp.getPermutationMap();
+    AffineMap writeMap = defWrite.getPermutationMap();
+    AffineMap invWriteMap = inverseWithUnusedDims(writeMap);
+    AffineMap composedMap = readMap.compose(invWriteMap);
+    // If there are any unused dims in the composedMap, we have to drop some
+    // unit dims from the written vector before we can do transpose(broadcast).
+    // TODO: Support this case.
+    if (getUnusedDimsBitVector(composedMap).any())
       return failure();
-    // This pattern should only catch the broadcast case, the non-broadcast case
-    // should be done separately to keep application conditions clean and
-    // separate.
-    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
-    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
-    bool bcast = !readMap.getBroadcastDims().empty() ||
-                 !writeMap.getBroadcastDims().empty();
-    if (!bcast)
-      return failure();
-    // At this point, we know we have a bcast.
-    // Bail in the masked case (too complex atm and needed to properly account
-    // for padding).
-    if (readOp.getMask() || defWrite.getMask())
-      return failure();
-    // If indices are not the same a shift may be required, bail.
-    if (readOp.getIndices() != defWrite.getIndices())
-      return failure();
-
-    Value vec = defWrite.getVector();
-    // TODO: loop through the chain of transfer_write if we can prove that they
-    // don't overlap with the transfer_read. This requires improving
-    // `isDisjointTransferIndices` helper.
-    AffineMap map = readMap.compose(writeMap);
-    if (map.getNumResults() == 0)
-      return failure();
-    // Calculate the permutation to apply to go from the vector stored to the
-    // vector read.
-    SmallVector<unsigned> permutation;
-    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
-      return failure();
-
-    Location loc = readOp.getLoc();
-    // Calculate the broadcast shape by applying the reverse permutation to the
-    // final shape we want.
-    ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
-    SmallVector<int64_t> broadcastShape(destShape.size());
-    SmallVector<bool> broadcastScalableFlags(destShape.size());
-    for (const auto &pos : llvm::enumerate(permutation)) {
-      broadcastShape[pos.value()] = destShape[pos.index()];
-      broadcastScalableFlags[pos.value()] =
-          readOp.getVectorType().getScalableDims()[pos.index()];
+    // readVec = transpose(broadcast(writeVec))
+    //
+    // Build a transpose permutation for the above transpose operation.
+    //
+    // Treat the composed map as having extra leading dimensions which are
+    // the broadcasted dimensions, and treat the zeros as these new broadcasted
+    // dimensions.
+    SmallVector<unsigned> broadcastedDims = composedMap.getBroadcastDims();
+    int64_t numBroadcastedDims = broadcastedDims.size();
+    auto invPerm = llvm::to_vector_of<int64_t>(broadcastedDims);
+    invPerm.resize(composedMap.getNumResults());
+    for (auto [idx, expr] : llvm::enumerate(composedMap.getResults())) {
+      if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
+        int64_t effectiveDim = dim.getPosition() + numBroadcastedDims;
+        invPerm[effectiveDim] = idx;
+      }
     }
-    VectorType broadcastedType = VectorType::get(
-        broadcastShape, defWrite.getVectorType().getElementType(),
-        broadcastScalableFlags);
-    vec = vector::BroadcastOp::create(rewriter, loc, broadcastedType, vec);
-    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
-    rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
-                                                     transposePerm);
+    // Applying the inverse permutation on the readVecTy will give us the
+    // broadcast result type.
+    VectorType readVecTy = readOp.getVectorType();
+    SmallVector<int64_t> permutation = invertPermutationVector(invPerm);
+    auto broadcastedVecTy =
+        VectorType::get(applyPermutation(readVecTy.getShape(), invPerm),
+                        readVecTy.getElementType(),
+                        applyPermutation(readVecTy.getScalableDims(), invPerm));
+    // Build the transpose(broadcast) transformation.
+    Value vec = defWrite.getVector();
+    Location loc = readOp.getLoc();
+    vec = vector::BroadcastOp::create(rewriter, loc, broadcastedVecTy, vec);
+    rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec, permutation);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 084f49f..50d52c9 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1914,6 +1914,25 @@
 
 // -----
 
+// CHECK-LABEL: func @store_to_load_tensor_forwarding_unit_dim_broadcast
+//  CHECK-SAME: (%[[V0:.*]]: vector<4x8xf32>, %[[MEM:.*]]: tensor<1x1x4x8xf32>)
+// CHECK-NOT: vector.transfer_write
+// CHECK-NOT: vector.transfer_read
+// CHECK: %[[RET:.+]] = vector.broadcast %[[V0]] : vector<4x8xf32> to vector<1x1x4x8xf32>
+// CHECK: return %[[RET]]
+func.func @store_to_load_tensor_forwarding_unit_dim_broadcast(
+    %vec: vector<4x8xf32>,
+    %mem : tensor<1x1x4x8xf32>
+  ) -> vector<1x1x4x8xf32> {
+  %c0 = arith.constant 0 : index
+  %cst_0 = arith.constant 0.0 : f32
+  %write = vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] : vector<4x8xf32>, tensor<1x1x4x8xf32>
+  %read = vector.transfer_read %write[%c0, %c0, %c0, %c0], %cst_0 : tensor<1x1x4x8xf32>, vector<1x1x4x8xf32>
+  return %read : vector<1x1x4x8xf32>
+}
+
+// -----
+
 
 // CHECK-LABEL: func @dead_store_tensor
 //   CHECK-DAG:      %[[C0:.*]] = arith.constant 0 : index