diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 03d1cab..e354195 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -34,11 +34,22 @@
 
   /// A map for looking up bufferized function types.
   DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
+
+  /// A mapping of return values to equivalent BlockArguments.
+  DenseMap<Value, BlockArgument> equivalentReturnValToBBArg;
 };
 } // namespace
 
 static bool isaTensor(Type t) { return t.isa<TensorType>(); }
 
+/// Look through any chain of memref::CastOps defining `value` and return the
+/// underlying source. Return `value` unchanged if it is not such a cast.
+static Value getNonCastedValue(Value value) {
+  while (auto castOp = value.getDefiningOp<memref::CastOp>())
+    value = castOp.source();
+  return value;
+}
+
 /// Remove the attribute that triggers inplace bufferization on a FuncOp
 /// argument `bbArg`.
 static void removeBufferizationFuncArguments(BlockArgument bbArg) {
@@ -113,62 +124,40 @@
   return it2.first->second;
 }
 
-/// Return the op with Allocate MemoryEffect if `v` is equivalent to such an
-/// an op. Return null otherwise.
-static Operation *getEquivalentAlloc(Value value,
-                                     const BufferizationAliasInfo &aliasInfo) {
-  Operation *res = nullptr;
-  aliasInfo.applyOnEquivalenceClass(value, [&](Value v) {
-    if (!res)
-      if (auto interface =
-              dyn_cast_or_null<MemoryEffectOpInterface>(v.getDefiningOp()))
-        if (auto effect =
-                interface.getEffectOnValue<MemoryEffects::Allocate>(v))
-          res = v.getDefiningOp();
-  });
-  return res;
-}
+/// Record in `state`, for each ranked-tensor return value of `funcOp`, the
+/// ranked-tensor BlockArgument (last match wins) that it is equivalent to.
+static void populateEquivalentFuncOpBBArgs(FuncOp funcOp,
+                                           ModuleBufferizationState &state) {
+  // Only functions with a single return-terminated block are supported.
+  ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+  assert(returnOp && "expected func with single return op");
 
-/// Return the first argument of the enclosing FuncOp that is equivalent to `v`.
-/// Return null if no such bbArg can be found.
-static BlockArgument
-getEquivalentEnclosingFuncBBArg(Value v,
-                                const BufferizationAliasInfo &aliasInfo) {
-  if (!v.getType().isa<RankedTensorType>())
-    return nullptr;
-  Operation *op = v.getParentBlock()->getParentOp();
-  FuncOp funcOp = dyn_cast<FuncOp>(op);
-  if (!funcOp)
-    funcOp = op->getParentOfType<FuncOp>();
-  assert(funcOp && "expected non-null FuncOp");
-  for (BlockArgument bbArg : funcOp.getArguments()) {
-    if (!bbArg.getType().isa<RankedTensorType>())
-      continue;
-    if (aliasInfo.areEquivalentBufferizedValues(v, bbArg))
-      return bbArg;
-  }
-  return nullptr;
+  for (Value returnVal : returnOp.operands())
+    if (returnVal.getType().isa<RankedTensorType>())
+      for (BlockArgument bbArg : funcOp.getArguments())
+        if (bbArg.getType().isa<RankedTensorType>())
+          if (state.aliasInfo.areEquivalentBufferizedValues(returnVal, bbArg))
+            state.equivalentReturnValToBBArg[returnVal] = bbArg;
 }
 
 /// Rewrite the `funcOp` arguments analysis return values and terminator into
 /// buffer form (using the canonical memref layout for now), according to the
 /// inPlace-bufferizable information of the function arguments.
+///
 /// This relies on a buffer equivalence analysis of each return operand. When a
-/// result buffer is equivalent to:
-///   1. a BlockArgument of `funcOp`, it can be dropped from the return values
-///      and becomes inplaceable at all callers. This assumes all CallOp perform
-///      the necessary work to clone operands so as to make them inplaceable.
-//       Reliance on this logic will need to be relaxed in thefuture.
-///   2. an op with an Alloc effect, this currently fails bufferization but is a
-///      candidate for hoisting and creating a new inplace operand at all caller
-///      sites.
-///   3. if such a hoisting for 2. is not possible (e.g. data-dependent that
-///      prevents hoisting), this is currently unsupported and will require a
-///      refcounted buffer type.
-static LogicalResult bufferizeFuncOpBoundary(
-    FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
-    DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
+/// result buffer is equivalent to a BlockArgument of `funcOp`, it can be
+/// dropped from the return values and becomes inplaceable at all callers. This
+/// assumes all CallOp perform the necessary work to clone operands so as to
+/// make them inplaceable. Reliance on this logic will need to be relaxed in the
+/// future.
+///
+/// Note: Returning a memref currently fails bufferization. If such memrefs
+/// originate from an op with an Alloc effect, they could be hoisted in the
+/// future.
+static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
+                                             ModuleBufferizationState &state) {
   LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
+  BufferizationAliasInfo &aliasInfo = state.aliasInfo;
 
   // If nothing to do then we are done.
   if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) &&
@@ -197,9 +186,9 @@
     if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
       return funcOp->emitError() << "cannot bufferize bodiless function that "
                                  << "returns a tensor";
-    FunctionType bufferizedFuncType =
-        getOrCreateBufferizedFunctionType(funcOp, funcOp.getType().getInputs(),
-                                          TypeRange{}, bufferizedFunctionTypes);
+    FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
+        funcOp, funcOp.getType().getInputs(), TypeRange{},
+        state.bufferizedFunctionTypes);
     funcOp.setType(bufferizedFuncType);
     LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary no fun body: " << funcOp);
     return success();
@@ -212,37 +201,27 @@
   // 1. For each FuncOp result, keep track of which inplace argument it reuses.
   SmallVector<Value> returnValues;
   for (OpOperand &returnOperand : returnOp->getOpOperands()) {
+    Value returnVal = returnOperand.get();
+
     // If not a renturn tensor type just forward it.
-    if (!returnOperand.get().getType().isa<RankedTensorType>()) {
-      returnValues.push_back(returnOperand.get());
+    if (!returnVal.getType().isa<RankedTensorType>()) {
+      returnValues.push_back(returnVal);
       continue;
     }
 
     // If return operand is equivalent to some bbArg, no need to return it.
-    Value returnVal = returnOperand.get();
-    if (getEquivalentEnclosingFuncBBArg(returnVal, aliasInfo))
+    if (state.equivalentReturnValToBBArg.count(returnVal))
       continue;
 
-    // TODO: Need to hoist above function boundary.
-    if (Operation *allocOp = getEquivalentAlloc(returnVal, aliasInfo)) {
-      returnValues.push_back(allocOp->getResult(0));
-      continue;
-    }
-
-    // Other cases legitimately need to return a tensor, this is currently not
-    // supported. For instance, if hoisting across function boundary has
-    // failed, it may be due to e.g. data-dependent sizes. In such a case, we
-    // would need a better type than memref.
-    int64_t returnIdx = returnOperand.getOperandNumber();
-    return returnOp->emitError()
-           << "buffer result #" << returnIdx << " not produced by an alloc\n";
+    // Strip casts here; values are cast at the call site if necessary.
+    returnValues.push_back(getNonCastedValue(state.lookupBuffer(returnVal)));
   }
 
   // 2. Rewrite the terminator without the inPlace bufferizable values.
   ValueRange retValues{returnValues};
   FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
       funcOp, funcOp.getType().getInputs(), retValues.getTypes(),
-      bufferizedFunctionTypes);
+      state.bufferizedFunctionTypes);
   OpBuilder b(returnOp);
   b.create<ReturnOp>(returnOp.getLoc(), returnValues);
   returnOp->erase();
@@ -495,6 +474,7 @@
     FuncOp funcOp = getCalledFunction(callOp);
     assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
            "expected Callop to a FuncOp");
+    auto &moduleState = static_cast<ModuleBufferizationState &>(state);
 
     // Take a guard before anything else.
     OpBuilder::InsertionGuard g(b);
@@ -507,11 +487,10 @@
     //      semantics is ill-defined.
     //    - if the callee has a body, we perform inter-procedural equivalence
     //      analysis. When successful, a result folds onto an operand. When
-    //      unsuccessful, additional work is needed to either:
+    //      unsuccessful, additional work is needed (TODO) to either:
     //        * hoist a result into an inplaceable operand or
     //        * devise a better representation to truly return a buffer.
     SmallVector<Type> resultTypes;
-    SmallVector<Value> hoistedArguments;
     if (funcOp.body().empty()) {
       if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
         return callOp->emitError()
@@ -530,8 +509,9 @@
 
         // If return operand is equivalent to some bbArg, no need to return it.
         Value returnVal = returnOperand.get();
-        if (BlockArgument bbArg =
-                getEquivalentEnclosingFuncBBArg(returnVal, state.aliasInfo)) {
+        if (moduleState.equivalentReturnValToBBArg.count(returnVal)) {
+          BlockArgument bbArg =
+              moduleState.equivalentReturnValToBBArg[returnVal];
           Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
           int64_t idx = bbArg.getArgNumber();
           Value buffer = state.lookupBuffer(callOp->getOperand(idx));
@@ -552,35 +532,17 @@
           continue;
         }
 
-        // TODO: Need to hoist above function boundary.
-        if (Operation *allocOp =
-                getEquivalentAlloc(returnVal, state.aliasInfo)) {
-          hoistedArguments.push_back(allocOp->getResult(0));
-          continue;
-        }
-
-        // Other cases legitimately need to return a tensor, this is currently
-        // not supported. For instance, if hoisting across function boundary has
-        // failed, it may be due to e.g. data-dependent sizes. In such a case,
-        // we would we need a better type than memref.
         resultTypes.push_back(returnType);
-
-        int64_t returnIdx = returnOperand.getOperandNumber();
-        return returnOp->emitError() << "buffer result #" << returnIdx
-                                     << " not produced by an alloc\n";
       }
     }
 
     // 2. Compute bufferized FunctionType.
     SmallVector<Type> argumentTypes{callOp->getOperandTypes()};
-    ValueRange hoistedArgs{hoistedArguments};
-    llvm::append_range(argumentTypes, hoistedArgs.getTypes());
     // Get the bufferized FunctionType for funcOp or construct it if not yet
     // available.
-    // TODO: Assert that `state` is a ModuleBufferizationState.
-    FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
-        funcOp, argumentTypes, resultTypes,
-        static_cast<ModuleBufferizationState &>(state).bufferizedFunctionTypes);
+    FunctionType bufferizedFuncType =
+        getOrCreateBufferizedFunctionType(funcOp, argumentTypes, resultTypes,
+                                          moduleState.bufferizedFunctionTypes);
 
     // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
     SmallVector<Value> newOperands;
@@ -713,6 +675,8 @@
     // Analyze and bufferize funcOp.
     if (failed(runComprehensiveBufferize(funcOp, options, state)))
       return failure();
+
+    populateEquivalentFuncOpBBArgs(funcOp, state);
   }
 
   if (options.testAnalysisOnly)
@@ -721,8 +685,7 @@
   for (FuncOp funcOp : orderedFuncOps) {
     // Note: It would be good to apply cleanups here but we cannot as aliasInfo
     // would be invalidated.
-    if (failed(bufferizeFuncOpBoundary(funcOp, aliasInfo,
-                                       state.bufferizedFunctionTypes)))
+    if (failed(bufferizeFuncOpBoundary(funcOp, state)))
       return failure();
 
     if (!options.allowReturnMemref &&
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index d9e0027..c0a91e3 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -110,6 +110,7 @@
 
 // -----
 
+// expected-error @+1 {{memref return type is unsupported}}
 func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})
   ->  tensor<4xf32>
 {
@@ -121,7 +122,6 @@
   //     argument aliasing).
   %r0 = tensor.extract_slice %A[0][4][1] : tensor<?xf32> to tensor<4xf32>
 
-  // expected-error @+1 {{buffer result #0 not produced by an alloc}}
   return %r0: tensor<4xf32>
 }
 
