[mlir][MemRef] Fix SubViewOp canonicalization when a subset of unit-dims are dropped.

The canonical type of the result of the `memref.subview` needs to make
sure that the previously dropped unit-dimensions are the ones dropped
for the canonicalized type as well. This means the generic
`inferRankReducedResultType` cannot be used. Instead the current
dropped dimensions need to be queried and the same ones need to be dropped.

Reviewed By: nicolasvasilache, ThomasRaoux

Differential Revision: https://reviews.llvm.org/D114751
diff --git a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
index 11d81d7..4c3799c 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Utils/Utils.h
@@ -63,6 +63,8 @@
     ResultTypeFunc resultTypeFunc;
     auto resultType =
         resultTypeFunc(op, mixedOffsets, mixedSizes, mixedStrides);
+    if (!resultType)
+      return failure();
     auto newOp =
         rewriter.create<OpType>(op.getLoc(), resultType, op.source(),
                                 mixedOffsets, mixedSizes, mixedStrides);
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index ac36d1d..b9bd01b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -511,14 +511,16 @@
 /// dimension is dropped the stride must be dropped too.
 static llvm::Optional<llvm::SmallDenseSet<unsigned>>
 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
-                               ArrayAttr staticSizes) {
+                               ArrayRef<OpFoldResult> sizes) {
   llvm::SmallDenseSet<unsigned> unusedDims;
   if (originalType.getRank() == reducedType.getRank())
     return unusedDims;
 
-  for (auto dim : llvm::enumerate(staticSizes))
-    if (dim.value().cast<IntegerAttr>().getInt() == 1)
-      unusedDims.insert(dim.index());
+  for (auto dim : llvm::enumerate(sizes))
+    if (auto attr = dim.value().dyn_cast<Attribute>())
+      if (attr.cast<IntegerAttr>().getInt() == 1)
+        unusedDims.insert(dim.index());
+
   SmallVector<int64_t> originalStrides, candidateStrides;
   int64_t originalOffset, candidateOffset;
   if (failed(
@@ -574,7 +576,7 @@
   MemRefType sourceType = getSourceType();
   MemRefType resultType = getType();
   llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
-      computeMemRefRankReductionMask(sourceType, resultType, static_sizes());
+      computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
   assert(unusedDims && "unable to find unused dims of subview");
   return *unusedDims;
 }
@@ -1718,7 +1720,7 @@
 /// not matching dimension must be 1.
 static SubViewVerificationResult
 isRankReducedType(Type originalType, Type candidateReducedType,
-                  ArrayAttr staticSizes, std::string *errMsg = nullptr) {
+                  ArrayRef<OpFoldResult> sizes, std::string *errMsg = nullptr) {
   if (originalType == candidateReducedType)
     return SubViewVerificationResult::Success;
   if (!originalType.isa<MemRefType>())
@@ -1743,7 +1745,7 @@
   MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
 
   auto optionalUnusedDimsMask =
-      computeMemRefRankReductionMask(original, candidateReduced, staticSizes);
+      computeMemRefRankReductionMask(original, candidateReduced, sizes);
 
   // Sizes cannot be matched in case empty vector is returned.
   if (!optionalUnusedDimsMask.hasValue())
@@ -1813,7 +1815,7 @@
 
   std::string errMsg;
   auto result =
-      isRankReducedType(expectedType, subViewType, op.static_sizes(), &errMsg);
+      isRankReducedType(expectedType, subViewType, op.getMixedSizes(), &errMsg);
   return produceSubViewErrorMsg(result, op, expectedType, errMsg);
 }
 
@@ -1854,21 +1856,29 @@
 /// Infer the canonical type of the result of a subview operation. Returns a
 /// type with rank `resultRank` that is either the rank of the rank-reduced
 /// type, or the non-rank-reduced type.
-static MemRefType
-getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
-                              ArrayRef<OpFoldResult> mixedOffsets,
-                              ArrayRef<OpFoldResult> mixedSizes,
-                              ArrayRef<OpFoldResult> mixedStrides) {
-  auto resultType =
-      SubViewOp::inferRankReducedResultType(
-          resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
-          .cast<MemRefType>();
-  if (resultType.getRank() != resultRank) {
-    resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
-                                            mixedSizes, mixedStrides)
-                     .cast<MemRefType>();
+static MemRefType getCanonicalSubViewResultType(
+    MemRefType currentResultType, MemRefType sourceType,
+    ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
+    ArrayRef<OpFoldResult> mixedStrides) {
+  auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
+                                                       mixedSizes, mixedStrides)
+                                .cast<MemRefType>();
+  llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
+      computeMemRefRankReductionMask(sourceType, currentResultType, mixedSizes);
+  // Return nullptr as failure mode.
+  if (!unusedDims)
+    return nullptr;
+  SmallVector<int64_t> shape;
+  for (auto sizes : llvm::enumerate(nonRankReducedType.getShape())) {
+    if (unusedDims->count(sizes.index()))
+      continue;
+    shape.push_back(sizes.value());
   }
-  return resultType;
+  AffineMap layoutMap = nonRankReducedType.getLayout().getAffineMap();
+  if (!layoutMap.isIdentity())
+    layoutMap = getProjectedMap(layoutMap, unusedDims.getValue());
+  return MemRefType::get(shape, nonRankReducedType.getElementType(), layoutMap,
+                         nonRankReducedType.getMemorySpace());
 }
 
 namespace {
@@ -1911,8 +1921,7 @@
     /// the cast source operand type and the SubViewOp static information. This
     /// is the resulting type if the MemRefCastOp were folded.
     auto resultType = getCanonicalSubViewResultType(
-        subViewOp.getType().getRank(),
-        castOp.source().getType().cast<MemRefType>(),
+        subViewOp.getType(), castOp.source().getType().cast<MemRefType>(),
         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
         subViewOp.getMixedStrides());
     Value newSubView = rewriter.create<SubViewOp>(
@@ -1931,9 +1940,9 @@
   MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                         ArrayRef<OpFoldResult> mixedSizes,
                         ArrayRef<OpFoldResult> mixedStrides) {
-    return getCanonicalSubViewResultType(op.getType().getRank(),
-                                         op.getSourceType(), mixedOffsets,
-                                         mixedSizes, mixedStrides);
+    return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
+                                         mixedOffsets, mixedSizes,
+                                         mixedStrides);
   }
 };
 
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 71f4c3e..a568d5f 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -47,7 +47,7 @@
 
 // -----
 
-#map0 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+#map0 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
     %arg2 : index) -> memref<?x?xf32, #map0>
 {
@@ -395,3 +395,25 @@
   %collapsed = memref.collapse_shape %dynamic [[0], [1, 2, 3]] : memref<?x?x?x?xf32> into memref<?x?xf32>
   return %collapsed : memref<?x?xf32>
 }
+
+// -----
+
+func @reduced_memref(%arg0: memref<2x5x7x1xf32>, %arg1 :index)
+    -> memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>> {
+  %c0 = arith.constant 0 : index
+  %c5 = arith.constant 5 : index
+  %c4 = arith.constant 4 : index
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+  %0 = memref.subview %arg0[%arg1, %arg1, %arg1, 0] [%c1, %c4, %c1, 1] [1, 1, 1, 1]
+      : memref<2x5x7x1xf32> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>>
+  %1 = memref.cast %0
+      : memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>> to
+        memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>>
+  return %1 : memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>>
+}
+
+// CHECK-LABEL: func @reduced_memref
+//       CHECK:   %[[RESULT:.+]] = memref.subview
+//  CHECK-SAME:       memref<2x5x7x1xf32> to memref<1x4x1xf32, #{{.+}}>
+//       CHECK:   return %[[RESULT]]