[mlir][linalg][bufferize] Bufferization of tensor.insert
This is a lightweight operation, useful for writing unit tests. It will be utilized for testing in subsequent commits.
Differential Revision: https://reviews.llvm.org/D114693
GitOrigin-RevId: 4479138de8e662f0dc64a92008b126f050e18b77
diff --git a/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index 0cecb1d..695e448 100644
--- a/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -230,6 +230,50 @@
}
};
+struct InsertOpInterface
+ : public BufferizableOpInterface::ExternalModel<InsertOpInterface,
+ tensor::InsertOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+ return true;
+ }
+
+ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+ assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
+ "expected dest OpOperand");
+ return op->getOpResult(0);
+ }
+
+ SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+ OpResult opResult) const {
+ return {&op->getOpOperand(1) /*dest*/};
+ }
+
+ LogicalResult bufferize(Operation *op, OpBuilder &b,
+ BufferizationState &state) const {
+ auto insertOp = cast<tensor::InsertOp>(op);
+
+ // Take a guard before anything else.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(insertOp);
+
+ Location loc = insertOp.getLoc();
+ Value destMemref = getResultBuffer(b, insertOp->getOpResult(0), state);
+ b.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
+ insertOp.indices());
+ state.mapBuffer(insertOp, destMemref);
+ state.aliasInfo.insertNewBufferAlias(insertOp, destMemref);
+ return success();
+ }
+
+ BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const {
+ return BufferRelation::Equivalent;
+ }
+};
+
/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
///
@@ -459,6 +503,7 @@
registry.addOpInterface<tensor::ExtractSliceOp,
tensor_ext::ExtractSliceOpInterface>();
registry.addOpInterface<tensor::ExtractOp, tensor_ext::ExtractOpInterface>();
+ registry.addOpInterface<tensor::InsertOp, tensor_ext::InsertOpInterface>();
registry.addOpInterface<tensor::InsertSliceOp,
tensor_ext::InsertSliceOpInterface>();
}
diff --git a/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 583c58e..3a70adb 100644
--- a/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -916,3 +916,15 @@
}
return %r : tensor<?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @insert_op
+// CHECK-SAME: %[[t1:.*]]: memref<?xf32, {{.*}}>, %[[s:.*]]: f32, %[[i:.*]]: index
+func @insert_op(%t1 : tensor<?xf32> {linalg.inplaceable = true},
+ %s : f32, %i : index) -> tensor<?xf32> {
+ // CHECK: memref.store %[[s]], %[[t1]][%[[i]]]
+ %0 = tensor.insert %s into %t1[%i] : tensor<?xf32>
+ // CHECK: return
+ return %0 : tensor<?xf32>
+}