[mlir][x86vector] Improve intrinsic operands creation (#138666)

Refactors intrinsic op interface to delegate initial operands mapping to
the dialect converter and allow intrinsic operands getters to only
perform last mile post-processing.
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
index 4f8301f..25d9c40 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -83,7 +83,10 @@
     }
   }];
   let extraClassDeclaration = [{
-    SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
 }
 
@@ -404,7 +407,10 @@
     }
   }];
   let extraClassDeclaration = [{
-    SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
 }
 
@@ -452,7 +458,10 @@
   }];
 
   let extraClassDeclaration = [{
-        SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
 
 }
@@ -500,7 +509,10 @@
   }];
 
   let extraClassDeclaration = [{
-        SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
 }
 
@@ -543,7 +555,10 @@
   }];
 
   let extraClassDeclaration = [{
-        SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+    SmallVector<Value> getIntrinsicOperands(
+        ::mlir::ArrayRef<Value> operands,
+        const ::mlir::LLVMTypeConverter &typeConverter,
+        ::mlir::RewriterBase &rewriter);
   }];
 }
 #endif // X86VECTOR_OPS
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
index 5176f4a..cde9d1d 100644
--- a/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
+++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td
@@ -58,9 +58,11 @@
       }],
       /*retType=*/"SmallVector<Value>",
       /*methodName=*/"getIntrinsicOperands",
-      /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "const LLVMTypeConverter &":$typeConverter),
+      /*args=*/(ins "::mlir::ArrayRef<Value>":$operands,
+                    "const ::mlir::LLVMTypeConverter &":$typeConverter,
+                    "::mlir::RewriterBase &":$rewriter),
       /*methodBody=*/"",
-      /*defaultImplementation=*/"return SmallVector<Value>($_op->getOperands());"
+      /*defaultImplementation=*/"return SmallVector<Value>(operands);"
     >,
   ];
 }
diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
index 8d383b1..cc7ab7f 100644
--- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -31,24 +31,11 @@
       >();
 }
 
-static SmallVector<Value>
-getMemrefBuffPtr(Location loc, ::mlir::TypedValue<::mlir::MemRefType> memrefVal,
-                 RewriterBase &rewriter,
-                 const LLVMTypeConverter &typeConverter) {
-  SmallVector<Value> operands;
-  auto opType = memrefVal.getType();
-
-  Type llvmStructType = typeConverter.convertType(opType);
-  Value llvmStruct =
-      rewriter
-          .create<UnrealizedConversionCastOp>(loc, llvmStructType, memrefVal)
-          .getResult(0);
-  MemRefDescriptor memRefDescriptor(llvmStruct);
-
-  Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, opType);
-  operands.push_back(ptr);
-
-  return operands;
+static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer,
+                              const LLVMTypeConverter &typeConverter,
+                              RewriterBase &rewriter) {
+  MemRefDescriptor memRefDescriptor(buffer);
+  return memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type);
 }
 
 LogicalResult x86vector::MaskCompressOp::verify() {
@@ -66,48 +53,61 @@
 }
 
 SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands(
-    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+    ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+    RewriterBase &rewriter) {
   auto loc = getLoc();
+  Adaptor adaptor(operands, *this);
 
-  auto opType = getA().getType();
+  auto opType = adaptor.getA().getType();
   Value src;
-  if (getSrc()) {
-    src = getSrc();
-  } else if (getConstantSrc()) {
-    src = rewriter.create<LLVM::ConstantOp>(loc, opType, getConstantSrcAttr());
+  if (adaptor.getSrc()) {
+    src = adaptor.getSrc();
+  } else if (adaptor.getConstantSrc()) {
+    src = rewriter.create<LLVM::ConstantOp>(loc, opType,
+                                            adaptor.getConstantSrcAttr());
   } else {
     auto zeroAttr = rewriter.getZeroAttr(opType);
     src = rewriter.create<LLVM::ConstantOp>(loc, opType, zeroAttr);
   }
 
-  return SmallVector<Value>{getA(), src, getK()};
+  return SmallVector<Value>{adaptor.getA(), src, adaptor.getK()};
 }
 
 SmallVector<Value>
-x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter,
-                                       const LLVMTypeConverter &typeConverter) {
-  SmallVector<Value> operands(getOperands());
+x86vector::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
+                                       const LLVMTypeConverter &typeConverter,
+                                       RewriterBase &rewriter) {
+  SmallVector<Value> intrinsicOperands(operands);
   // Dot product of all elements, broadcasted to all elements.
   Value scale =
       rewriter.create<LLVM::ConstantOp>(getLoc(), rewriter.getI8Type(), 0xff);
-  operands.push_back(scale);
+  intrinsicOperands.push_back(scale);
 
-  return operands;
+  return intrinsicOperands;
 }
 
 SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
-    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
-  return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+    ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+    RewriterBase &rewriter) {
+  Adaptor adaptor(operands, *this);
+  return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
+                           typeConverter, rewriter)};
 }
 
 SmallVector<Value> x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
-    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
-  return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+    ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+    RewriterBase &rewriter) {
+  Adaptor adaptor(operands, *this);
+  return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
+                           typeConverter, rewriter)};
 }
 
 SmallVector<Value> x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
-    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
-  return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+    ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
+    RewriterBase &rewriter) {
+  Adaptor adaptor(operands, *this);
+  return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
+                           typeConverter, rewriter)};
 }
 
 #define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index 9ee44a6..483c1f5 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -84,20 +84,23 @@
 /// Generic one-to-one conversion of simply mappable operations into calls
 /// to their respective LLVM intrinsics.
 struct OneToOneIntrinsicOpConversion
-    : public OpInterfaceRewritePattern<x86vector::OneToOneIntrinsicOp> {
-  using OpInterfaceRewritePattern<
-      x86vector::OneToOneIntrinsicOp>::OpInterfaceRewritePattern;
+    : public OpInterfaceConversionPattern<x86vector::OneToOneIntrinsicOp> {
+  using OpInterfaceConversionPattern<
+      x86vector::OneToOneIntrinsicOp>::OpInterfaceConversionPattern;
 
   OneToOneIntrinsicOpConversion(const LLVMTypeConverter &typeConverter,
                                 PatternBenefit benefit = 1)
-      : OpInterfaceRewritePattern(&typeConverter.getContext(), benefit),
+      : OpInterfaceConversionPattern(typeConverter, &typeConverter.getContext(),
+                                     benefit),
         typeConverter(typeConverter) {}
 
-  LogicalResult matchAndRewrite(x86vector::OneToOneIntrinsicOp op,
-                                PatternRewriter &rewriter) const override {
-    return intrinsicRewrite(op, rewriter.getStringAttr(op.getIntrinsicName()),
-                            op.getIntrinsicOperands(rewriter, typeConverter),
-                            typeConverter, rewriter);
+  LogicalResult
+  matchAndRewrite(x86vector::OneToOneIntrinsicOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    return intrinsicRewrite(
+        op, rewriter.getStringAttr(op.getIntrinsicName()),
+        op.getIntrinsicOperands(operands, typeConverter, rewriter),
+        typeConverter, rewriter);
   }
 
 private: