[mlir] canonicalizer: shape_cast(poison) -> poison  (#133988)

Based on the ShapeCastConstantFolder, this pattern replaces

%0 = ub.poison : vector<2x3xf32>
%1 = vector.shape_cast %0 vector<2x3xf32> to vector<6xf32>

with 

%1 = ub.poison : vector<6xf32>

---------

Signed-off-by: James Newling <james.newling@gmail.com>
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 98d98f0..59f3b78 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -42,6 +42,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
 
 #include <cassert>
 #include <cstdint>
@@ -5611,18 +5612,20 @@
 }
 
 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
+
   // No-op shape cast.
-  if (getSource().getType() == getResult().getType())
+  if (getSource().getType() == getType())
     return getSource();
 
+  VectorType resultType = getType();
+
   // Canceling shape casts.
   if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
-    if (getResult().getType() == otherOp.getSource().getType())
-      return otherOp.getSource();
 
-    // Only allows valid transitive folding.
-    VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
-    VectorType resultType = llvm::cast<VectorType>(getResult().getType());
+    // Only allows valid transitive folding (expand/collapse dimensions).
+    VectorType srcType = otherOp.getSource().getType();
+    if (resultType == srcType)
+      return otherOp.getSource();
     if (srcType.getRank() < resultType.getRank()) {
       if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
         return {};
@@ -5632,43 +5635,32 @@
     } else {
       return {};
     }
-
     setOperand(otherOp.getSource());
     return getResult();
   }
 
   // Cancelling broadcast and shape cast ops.
   if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
-    if (bcastOp.getSourceType() == getType())
+    if (bcastOp.getSourceType() == resultType)
       return bcastOp.getSource();
   }
 
+  // shape_cast(constant) -> constant
+  if (auto splatAttr =
+          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
+    return DenseElementsAttr::get(resultType,
+                                  splatAttr.getSplatValue<Attribute>());
+  }
+
+  // shape_cast(poison) -> poison
+  if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
+    return ub::PoisonAttr::get(getContext());
+  }
+
   return {};
 }
 
 namespace {
-// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
-class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
-                                PatternRewriter &rewriter) const override {
-    auto constantOp =
-        shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
-    if (!constantOp)
-      return failure();
-    // Only handle splat for now.
-    auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
-    if (!dense)
-      return failure();
-    auto newAttr =
-        DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
-                               dense.getSplatValue<Attribute>());
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
-    return success();
-  }
-};
 
 /// Helper function that computes a new vector type based on the input vector
 /// type by removing the trailing one dims:
@@ -5828,8 +5820,9 @@
 
 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
-              ShapeCastBroadcastFolder>(context);
+  results
+      .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
+          context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7db8ec..72064fb 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1167,6 +1167,20 @@
 
 // -----
 
+// CHECK-LABEL: shape_cast_poison
+//       CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32>
+//       CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32>
+//       CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
+func.func @shape_cast_poison() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
+  %poison = ub.poison : vector<5x4x2xf32>
+  %poison_1 = ub.poison : vector<12x2xi32>
+  %0 = vector.shape_cast %poison : vector<5x4x2xf32> to vector<20x2xf32>
+  %1 = vector.shape_cast %poison_1 : vector<12x2xi32> to vector<3x4x2xi32>
+  return %0, %1 : vector<20x2xf32>, vector<3x4x2xi32>
+}
+
+// -----
+
 // CHECK-LABEL: extract_strided_constant
 //       CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<2x13x3xi32>
 //       CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<12x2xf32>