[mlir][tensor] Add a PadOp::FoldReifiedShape canonicalization
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 22a25fd..c3147e2 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3791,6 +3791,60 @@
   }
 };
 
+/// Refine the result type of a `tensor.pad` with the static sizes found by
+/// reifying its result shape.
+///
+/// Every result dimension whose reified size is a constant is made static in
+/// a clone of the pad; a `tensor.cast` back to the original result type then
+/// replaces the original op, so existing users keep seeing the old type. Pads
+/// marked `nofold` are left untouched.
+struct FoldReifiedShape : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    if (padOp.getNofold())
+      return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
+
+    // NOTE(review): reification is handed `rewriter` and may create ops even
+    // when this pattern later reports failure -- confirm that is acceptable.
+    ReifiedRankedShapedTypeDims reifiedResultShapes;
+    if (failed(reifyResultShapes(rewriter, padOp, reifiedResultShapes)))
+      return rewriter.notifyMatchFailure(padOp,
+                                         "failed to reify result shapes");
+
+    RankedTensorType oldType = padOp.getResultType();
+    ArrayRef<int64_t> oldShape = oldType.getShape();
+    SmallVector<int64_t> newShape;
+    newShape.reserve(oldShape.size());
+    for (const auto &[s, ofr] :
+         llvm::zip_equal(oldShape, reifiedResultShapes.front())) {
+      std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
+      // Reification does not add static information, just use existing shape.
+      if (!maybeCst.has_value()) {
+        newShape.push_back(s);
+        continue;
+      }
+      int64_t cst = *maybeCst;
+      assert((ShapedType::isDynamic(s) || s == cst) && "constants must agree!");
+      newShape.push_back(cst);
+    }
+    // Nothing to refine: bail out so the pattern is not applied repeatedly.
+    if (newShape == oldShape)
+      return rewriter.notifyMatchFailure(
+          padOp, "reified shape adds no static information");
+
+    // Clone the pad with the refined type and cast back to the original type.
+    RankedTensorType newType =
+        RankedTensorType::Builder(oldType).setShape(newShape);
+    Operation *newPad = rewriter.clone(*padOp);
+    newPad->getResult(0).setType(newType);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, oldType,
+                                                newPad->getResult(0));
+    return success();
+  }
+};
+
 } // namespace
 
 LogicalResult
@@ -3820,7 +3874,7 @@
                                         MLIRContext *context) {
   results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
               FoldOrthogonalPaddings, FoldStaticPadding,
-              FoldConsecutiveConstantPadding>(context);
+              FoldConsecutiveConstantPadding, FoldReifiedShape>(context);
 }
 
 /// Return the padding value of the PadOp if it constant. In this context,
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3f92360..2a42a9a 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2561,3 +2561,27 @@
 //       CHECK:   %[[RES:.+]] = tensor.cast %[[EXPAND]]
 //  CHECK-SAME:     tensor<?x?x10xf32> to tensor<?x?x?xf32>
 //       CHECK:   return %[[RES]]
+
+// -----
+
+// The reified size of dimension 1 is %idx + (256 - %idx) = 256, so the pad
+// result type is expected to become static in that dimension, with a
+// tensor.cast restoring the original type for the return.
+// CHECK-LABEL:  func.func @pad_reification
+func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>)
+    -> tensor<1x?x64xf32> {
+  %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
+  %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+
+//       CHECK: %[[PAD:.+]] = tensor.pad
+//       CHECK:   : tensor<1x?x64xf32> to tensor<1x256x64xf32>
+  %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
+  ^bb0(%a: index, %b: index, %c: index):
+    tensor.yield %cst : f32
+  } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+
+//       CHECK: %[[CAST:.+]] = tensor.cast %[[PAD]]
+//  CHECK-SAME:   : tensor<1x256x64xf32> to tensor<1x?x64xf32>
+//       CHECK: return %[[CAST]]
+  return %padded : tensor<1x?x64xf32>
+}