[mlir][linalg][bufferize][NFC] Allow returning arbitrary memrefs

If `allowReturnMemref` is set to true, arbitrary memrefs may be returned from FuncOps. Also remove the allocation hoisting code, which was only partially implemented.
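
For illustration, a sketch of the new behavior based on the updated test case at the end of this diff (the exact post-bufferization memref types are approximate and their layout maps are elided):

  func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
      -> tensor<4xf32>
  {
    %r0 = tensor.extract_slice %A[0][4][1] : tensor<?xf32> to tensor<4xf32>
    return %r0 : tensor<4xf32>
  }

Previously, this failed with "buffer result #0 not produced by an alloc", since the slice is neither equivalent to a bbArg nor produced by an alloc. With `allowReturnMemref` set, the function now bufferizes to one returning roughly a `memref<4xf32, ...>` (a subview of the bufferized `%A`); without the flag, it is rejected with "memref return type is unsupported".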

The purpose of this commit is to untangle `bufferize` from `aliasInfo`. (Even with this change, they are not fully untangled yet.)

Differential Revision: https://reviews.llvm.org/D114507

GitOrigin-RevId: c94b80b4380ce851b5cf406a961eab472a43b3df
diff --git a/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 03d1cab..e354195 100644
--- a/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -34,11 +34,22 @@
 
   /// A map for looking up bufferized function types.
   DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
+
+  /// A mapping of return values to equivalent BlockArguments.
+  DenseMap<Value, BlockArgument> equivalentReturnValToBBArg;
 };
 } // namespace
 
 static bool isaTensor(Type t) { return t.isa<TensorType>(); }
 
+/// Look through a (possibly empty) chain of memref::CastOps and return the
+/// underlying non-casted value.
+static Value getNonCastedValue(Value value) {
+  while (auto castOp = value.getDefiningOp<memref::CastOp>())
+    value = castOp.source();
+  return value;
+}
+
 /// Remove the attribute that triggers inplace bufferization on a FuncOp
 /// argument `bbArg`.
 static void removeBufferizationFuncArguments(BlockArgument bbArg) {
@@ -113,62 +124,40 @@
   return it2.first->second;
 }
 
-/// Return the op with Allocate MemoryEffect if `v` is equivalent to such an
-/// an op. Return null otherwise.
-static Operation *getEquivalentAlloc(Value value,
-                                     const BufferizationAliasInfo &aliasInfo) {
-  Operation *res = nullptr;
-  aliasInfo.applyOnEquivalenceClass(value, [&](Value v) {
-    if (!res)
-      if (auto interface =
-              dyn_cast_or_null<MemoryEffectOpInterface>(v.getDefiningOp()))
-        if (auto effect =
-                interface.getEffectOnValue<MemoryEffects::Allocate>(v))
-          res = v.getDefiningOp();
-  });
-  return res;
-}
+/// Record in the given ModuleBufferizationState which function BlockArgument,
+/// if any, each of `funcOp`'s returned tensors is equivalent to.
+static void populateEquivalentFuncOpBBArgs(FuncOp funcOp,
+                                           ModuleBufferizationState &state) {
+  // Only functions with a single ReturnOp-terminated block are supported.
+  ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+  assert(returnOp && "expected func with single return op");
 
-/// Return the first argument of the enclosing FuncOp that is equivalent to `v`.
-/// Return null if no such bbArg can be found.
-static BlockArgument
-getEquivalentEnclosingFuncBBArg(Value v,
-                                const BufferizationAliasInfo &aliasInfo) {
-  if (!v.getType().isa<RankedTensorType>())
-    return nullptr;
-  Operation *op = v.getParentBlock()->getParentOp();
-  FuncOp funcOp = dyn_cast<FuncOp>(op);
-  if (!funcOp)
-    funcOp = op->getParentOfType<FuncOp>();
-  assert(funcOp && "expected non-null FuncOp");
-  for (BlockArgument bbArg : funcOp.getArguments()) {
-    if (!bbArg.getType().isa<RankedTensorType>())
-      continue;
-    if (aliasInfo.areEquivalentBufferizedValues(v, bbArg))
-      return bbArg;
-  }
-  return nullptr;
+  for (Value returnVal : returnOp.operands())
+    if (returnVal.getType().isa<RankedTensorType>())
+      for (BlockArgument bbArg : funcOp.getArguments())
+        if (bbArg.getType().isa<RankedTensorType>())
+          if (state.aliasInfo.areEquivalentBufferizedValues(returnVal, bbArg))
+            state.equivalentReturnValToBBArg[returnVal] = bbArg;
 }
 
 /// Rewrite the `funcOp` arguments analysis return values and terminator into
 /// buffer form (using the canonical memref layout for now), according to the
 /// inPlace-bufferizable information of the function arguments.
+///
 /// This relies on a buffer equivalence analysis of each return operand. When a
-/// result buffer is equivalent to:
-///   1. a BlockArgument of `funcOp`, it can be dropped from the return values
-///      and becomes inplaceable at all callers. This assumes all CallOp perform
-///      the necessary work to clone operands so as to make them inplaceable.
-//       Reliance on this logic will need to be relaxed in thefuture.
-///   2. an op with an Alloc effect, this currently fails bufferization but is a
-///      candidate for hoisting and creating a new inplace operand at all caller
-///      sites.
-///   3. if such a hoisting for 2. is not possible (e.g. data-dependent that
-///      prevents hoisting), this is currently unsupported and will require a
-///      refcounted buffer type.
-static LogicalResult bufferizeFuncOpBoundary(
-    FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
-    DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
+/// result buffer is equivalent to a BlockArgument of `funcOp`, it can be
+/// dropped from the return values and becomes inplaceable at all callers. This
+/// assumes all CallOp perform the necessary work to clone operands so as to
+/// make them inplaceable. Reliance on this logic will need to be relaxed in the
+/// future.
+///
+/// Note: Returning a memref is only supported when `allowReturnMemref` is
+/// set; otherwise, bufferization fails. If such memrefs originate from an op
+/// with an Alloc effect, they could be hoisted in the future.
+static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
+                                             ModuleBufferizationState &state) {
   LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
+  BufferizationAliasInfo &aliasInfo = state.aliasInfo;
 
   // If nothing to do then we are done.
   if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) &&
@@ -197,9 +186,9 @@
     if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
       return funcOp->emitError() << "cannot bufferize bodiless function that "
                                  << "returns a tensor";
-    FunctionType bufferizedFuncType =
-        getOrCreateBufferizedFunctionType(funcOp, funcOp.getType().getInputs(),
-                                          TypeRange{}, bufferizedFunctionTypes);
+    FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
+        funcOp, funcOp.getType().getInputs(), TypeRange{},
+        state.bufferizedFunctionTypes);
     funcOp.setType(bufferizedFuncType);
     LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary no fun body: " << funcOp);
     return success();
@@ -212,37 +201,27 @@
   // 1. For each FuncOp result, keep track of which inplace argument it reuses.
   SmallVector<Value> returnValues;
   for (OpOperand &returnOperand : returnOp->getOpOperands()) {
+    Value returnVal = returnOperand.get();
+
     // If not a return tensor type, just forward it.
-    if (!returnOperand.get().getType().isa<RankedTensorType>()) {
-      returnValues.push_back(returnOperand.get());
+    if (!returnVal.getType().isa<RankedTensorType>()) {
+      returnValues.push_back(returnVal);
       continue;
     }
 
     // If return operand is equivalent to some bbArg, no need to return it.
-    Value returnVal = returnOperand.get();
-    if (getEquivalentEnclosingFuncBBArg(returnVal, aliasInfo))
+    if (state.equivalentReturnValToBBArg.count(returnVal))
       continue;
 
-    // TODO: Need to hoist above function boundary.
-    if (Operation *allocOp = getEquivalentAlloc(returnVal, aliasInfo)) {
-      returnValues.push_back(allocOp->getResult(0));
-      continue;
-    }
-
-    // Other cases legitimately need to return a tensor, this is currently not
-    // supported. For instance, if hoisting across function boundary has
-    // failed, it may be due to e.g. data-dependent sizes. In such a case, we
-    // would need a better type than memref.
-    int64_t returnIdx = returnOperand.getOperandNumber();
-    return returnOp->emitError()
-           << "buffer result #" << returnIdx << " not produced by an alloc\n";
+    // Return the non-casted buffer; casts happen at the call site if needed.
+    returnValues.push_back(getNonCastedValue(state.lookupBuffer(returnVal)));
   }
 
   // 2. Rewrite the terminator without the inPlace bufferizable values.
   ValueRange retValues{returnValues};
   FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
       funcOp, funcOp.getType().getInputs(), retValues.getTypes(),
-      bufferizedFunctionTypes);
+      state.bufferizedFunctionTypes);
   OpBuilder b(returnOp);
   b.create<ReturnOp>(returnOp.getLoc(), returnValues);
   returnOp->erase();
@@ -495,6 +474,7 @@
     FuncOp funcOp = getCalledFunction(callOp);
     assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
            "expected Callop to a FuncOp");
+    auto &moduleState = static_cast<ModuleBufferizationState &>(state);
 
     // Take a guard before anything else.
     OpBuilder::InsertionGuard g(b);
@@ -507,11 +487,10 @@
     //      semantics is ill-defined.
     //    - if the callee has a body, we perform inter-procedural equivalence
     //      analysis. When successful, a result folds onto an operand. When
-    //      unsuccessful, additional work is needed to either:
+    //      unsuccessful, additional work is needed (TODO) to either:
     //        * hoist a result into an inplaceable operand or
     //        * devise a better representation to truly return a buffer.
     SmallVector<Type> resultTypes;
-    SmallVector<Value> hoistedArguments;
     if (funcOp.body().empty()) {
       if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
         return callOp->emitError()
@@ -530,8 +509,9 @@
 
         // If return operand is equivalent to some bbArg, no need to return it.
         Value returnVal = returnOperand.get();
-        if (BlockArgument bbArg =
-                getEquivalentEnclosingFuncBBArg(returnVal, state.aliasInfo)) {
+        if (moduleState.equivalentReturnValToBBArg.count(returnVal)) {
+          BlockArgument bbArg =
+              moduleState.equivalentReturnValToBBArg[returnVal];
           Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
           int64_t idx = bbArg.getArgNumber();
           Value buffer = state.lookupBuffer(callOp->getOperand(idx));
@@ -552,35 +532,17 @@
           continue;
         }
 
-        // TODO: Need to hoist above function boundary.
-        if (Operation *allocOp =
-                getEquivalentAlloc(returnVal, state.aliasInfo)) {
-          hoistedArguments.push_back(allocOp->getResult(0));
-          continue;
-        }
-
-        // Other cases legitimately need to return a tensor, this is currently
-        // not supported. For instance, if hoisting across function boundary has
-        // failed, it may be due to e.g. data-dependent sizes. In such a case,
-        // we would we need a better type than memref.
         resultTypes.push_back(returnType);
-
-        int64_t returnIdx = returnOperand.getOperandNumber();
-        return returnOp->emitError() << "buffer result #" << returnIdx
-                                     << " not produced by an alloc\n";
       }
     }
 
     // 2. Compute bufferized FunctionType.
     SmallVector<Type> argumentTypes{callOp->getOperandTypes()};
-    ValueRange hoistedArgs{hoistedArguments};
-    llvm::append_range(argumentTypes, hoistedArgs.getTypes());
     // Get the bufferized FunctionType for funcOp or construct it if not yet
     // available.
-    // TODO: Assert that `state` is a ModuleBufferizationState.
-    FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
-        funcOp, argumentTypes, resultTypes,
-        static_cast<ModuleBufferizationState &>(state).bufferizedFunctionTypes);
+    FunctionType bufferizedFuncType =
+        getOrCreateBufferizedFunctionType(funcOp, argumentTypes, resultTypes,
+                                          moduleState.bufferizedFunctionTypes);
 
     // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
     SmallVector<Value> newOperands;
@@ -713,6 +675,8 @@
     // Analyze and bufferize funcOp.
     if (failed(runComprehensiveBufferize(funcOp, options, state)))
       return failure();
+
+    populateEquivalentFuncOpBBArgs(funcOp, state);
   }
 
   if (options.testAnalysisOnly)
@@ -721,8 +685,7 @@
   for (FuncOp funcOp : orderedFuncOps) {
     // Note: It would be good to apply cleanups here but we cannot as aliasInfo
     // would be invalidated.
-    if (failed(bufferizeFuncOpBoundary(funcOp, aliasInfo,
-                                       state.bufferizedFunctionTypes)))
+    if (failed(bufferizeFuncOpBoundary(funcOp, state)))
       return failure();
 
     if (!options.allowReturnMemref &&
diff --git a/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index d9e0027..c0a91e3 100644
--- a/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -110,6 +110,7 @@
 
 // -----
 
+// expected-error @+1 {{memref return type is unsupported}}
 func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
   ->  tensor<4xf32>
 {
@@ -121,7 +122,6 @@
   //     argument aliasing).
   %r0 = tensor.extract_slice %A[0][4][1] : tensor<?xf32> to tensor<4xf32>
 
-  // expected-error @+1 {{buffer result #0 not produced by an alloc}}
   return %r0: tensor<4xf32>
 }