[mlir][linalg][bufferize] Bufferization of tensor.insert

This is a lightweight operation, useful for writing unit tests. It will be utilized for testing in subsequent commits.

Differential Revision: https://reviews.llvm.org/D114693

GitOrigin-RevId: 4479138de8e662f0dc64a92008b126f050e18b77
diff --git a/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index 0cecb1d..695e448 100644
--- a/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -230,6 +230,50 @@
   }
 };
 
+struct InsertOpInterface
+    : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
+                                                    tensor::InsertOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+    return true;
+  }
+
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+    assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
+           "expected dest OpOperand");
+    return op->getOpResult(0);
+  }
+
+  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+                                                OpResult opResult) const {
+    return {&op->getOpOperand(1) /*dest*/};
+  }
+
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    auto insertOp = cast<tensor::InsertOp>(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(insertOp);
+
+    Location loc = insertOp.getLoc();
+    Value destMemref = getResultBuffer(b, insertOp->getOpResult(0), state);
+    b.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
+                              insertOp.indices());
+    state.mapBuffer(insertOp, destMemref);
+    state.aliasInfo.insertNewBufferAlias(insertOp, destMemref);
+    return success();
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+    return BufferRelation::Equivalent;
+  }
+};
+
 /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
 /// equivalent operand / result and same offset/sizes/strides specification).
 ///
@@ -459,6 +503,7 @@
   registry.addOpInterface<tensor::ExtractSliceOp,
                           tensor_ext::ExtractSliceOpInterface>();
   registry.addOpInterface<tensor::ExtractOp, tensor_ext::ExtractOpInterface>();
+  registry.addOpInterface<tensor::InsertOp, tensor_ext::InsertOpInterface>();
   registry.addOpInterface<tensor::InsertSliceOp,
                           tensor_ext::InsertSliceOpInterface>();
 }
diff --git a/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 583c58e..3a70adb 100644
--- a/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -916,3 +916,15 @@
   }
   return %r : tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @insert_op
+//  CHECK-SAME:     %[[t1:.*]]: memref<?xf32, {{.*}}>, %[[s:.*]]: f32, %[[i:.*]]: index
+func @insert_op(%t1 : tensor<?xf32> {linalg.inplaceable = true},
+                %s : f32, %i : index) -> tensor<?xf32> {
+  // CHECK: memref.store %[[s]], %[[t1]][%[[i]]]
+  %0 = tensor.insert %s into %t1[%i] : tensor<?xf32>
+  // CHECK: return
+  return %0 : tensor<?xf32>
+}