[mlir][arith] `arith-to-apfloat`: Add vector support (#171024)

Add support for vectorized operations such as `arith.addf ... :
vector<4xf4E2M1FN>`. The computation is scalarized: the scalar operands are
extracted with `vector.to_elements`, one scalar computation (an APFloat runtime
call) is performed per element, and the results are packed back into a vector
with `vector.from_elements`.
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 75ab4b6..fcbaf3cc 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -198,7 +198,8 @@
     calls (APFloatWrappers.cpp). APFloat is a software implementation of
     floating-point arithmetic operations.
   }];
-  let dependentDialects = ["func::FuncDialect"];
+  let dependentDialects = ["arith::ArithDialect", "func::FuncDialect",
+                           "vector::VectorDialect"];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
index 4776ba0..79816fc 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
@@ -90,6 +91,73 @@
                                    b.getIntegerAttr(b.getI32Type(), sem));
 }
 
+/// Given two operands of vector type and vector result type (with the same
+/// shape), call the given function for each pair of scalar operands and
+/// package the results into a vector. If the operands are not vectors, call the
+/// function directly. The second operand is optional: when it is null, `fn`
+/// receives a null Value as its second argument.
+template <typename Fn>
+static Value forEachScalarValue(RewriterBase &rewriter, Location loc,
+                                Value operand1, Value operand2, Type resultType,
+                                Fn fn) {
+  auto vecTy1 = dyn_cast<VectorType>(operand1.getType());
+  // Sanity check: Operand types must match.
+  assert((!operand2 || vecTy1 == dyn_cast<VectorType>(operand2.getType())) &&
+         "expected same vector types");
+  if (!vecTy1) {
+    // Not a vector. Call the function directly.
+    return fn(operand1, operand2, resultType);
+  }
+
+  // Prepare scalar operands.
+  ResultRange scalars1 =
+      vector::ToElementsOp::create(rewriter, loc, operand1)->getResults();
+  SmallVector<Value> scalars2;
+  if (!operand2) {
+    // No second operand. Create a vector of empty values.
+    scalars2.assign(vecTy1.getNumElements(), Value());
+  } else {
+    llvm::append_range(
+        scalars2,
+        vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
+  }
+
+  // Call the function for each pair of scalar operands.
+  auto resultVecType = cast<VectorType>(resultType);
+  SmallVector<Value> results;
+  results.reserve(scalars1.size());
+  for (auto [scalar1, scalar2] : llvm::zip_equal(scalars1, scalars2)) {
+    Value result = fn(scalar1, scalar2, resultVecType.getElementType());
+    results.push_back(result);
+  }
+
+  // Package the results into a vector of the requested result type.
+  return vector::FromElementsOp::create(rewriter, loc, resultVecType, results);
+}
+
+/// Check preconditions for the conversion:
+/// 1. All operands / results must be integers or floats (or vectors thereof).
+/// 2. The bitwidth of the operands / results must be <= 64.
+/// 3. Vectors must have a fixed size: scalable vectors cannot be scalarized.
+static LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op) {
+  for (Value value : llvm::concat<Value>(op->getOperands(), op->getResults())) {
+    Type type = value.getType();
+    if (auto vecTy = dyn_cast<VectorType>(type)) {
+      if (vecTy.isScalable())
+        return rewriter.notifyMatchFailure(
+            op, "scalable vectors are not supported");
+      type = vecTy.getElementType();
+    }
+    if (!type.isIntOrFloat())
+      return rewriter.notifyMatchFailure(
+          op, "only integers and floats (or vectors thereof) are supported");
+    if (type.getIntOrFloatBitWidth() > 64)
+      return rewriter.notifyMatchFailure(op,
+                                         "bitwidth > 64 bits is not supported");
+  }
+  return success();
+}
+
 /// Rewrite a binary arithmetic operation to an APFloat function call.
 template <typename OpTy>
 struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
@@ -102,9 +170,8 @@
 
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    if (op.getType().getIntOrFloatBitWidth() > 64)
-      return rewriter.notifyMatchFailure(op,
-                                         "bitwidth > 64 bits is not supported");
+    if (failed(checkPreconditions(rewriter, op)))
+      return failure();
 
     // Get APFloat function from runtime library.
     FailureOr<FuncOp> fn =
@@ -112,31 +179,37 @@
     if (failed(fn))
       return fn;
 
-    rewriter.setInsertionPoint(op);
-    // Cast operands to 64-bit integers.
+    // Scalarize and convert to APFloat runtime calls.
     Location loc = op.getLoc();
-    auto floatTy = cast<FloatType>(op.getType());
-    auto intWType = rewriter.getIntegerType(floatTy.getWidth());
-    auto int64Type = rewriter.getI64Type();
-    Value lhsBits = arith::ExtUIOp::create(
-        rewriter, loc, int64Type,
-        arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
-    Value rhsBits = arith::ExtUIOp::create(
-        rewriter, loc, int64Type,
-        arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
+    rewriter.setInsertionPoint(op);
+    Value repl = forEachScalarValue(
+        rewriter, loc, op.getLhs(), op.getRhs(), op.getType(),
+        [&](Value lhs, Value rhs, Type resultType) {
+          // Cast operands to 64-bit integers.
+          auto floatTy = cast<FloatType>(resultType);
+          auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+          auto int64Type = rewriter.getI64Type();
+          Value lhsBits = arith::ExtUIOp::create(
+              rewriter, loc, int64Type,
+              arith::BitcastOp::create(rewriter, loc, intWType, lhs));
+          Value rhsBits = arith::ExtUIOp::create(
+              rewriter, loc, int64Type,
+              arith::BitcastOp::create(rewriter, loc, intWType, rhs));
 
-    // Call APFloat function.
-    Value semValue = getSemanticsValue(rewriter, loc, floatTy);
-    SmallVector<Value> params = {semValue, lhsBits, rhsBits};
-    auto resultOp =
-        func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
-                             SymbolRefAttr::get(*fn), params);
+          // Call APFloat function.
+          Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+          SmallVector<Value> params = {semValue, lhsBits, rhsBits};
+          auto resultOp = func::CallOp::create(rewriter, loc,
+                                               TypeRange(rewriter.getI64Type()),
+                                               SymbolRefAttr::get(*fn), params);
 
-    // Truncate result to the original width.
-    Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
-                                                  resultOp->getResult(0));
-    rewriter.replaceOp(
-        op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
+          // Truncate result to the original width.
+          Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
+                                                        resultOp->getResult(0));
+          return arith::BitcastOp::create(rewriter, loc, floatTy,
+                                          truncatedBits);
+        });
+    rewriter.replaceOp(op, repl);
     return success();
   }
 
@@ -152,10 +225,8 @@
 
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    if (op.getType().getIntOrFloatBitWidth() > 64 ||
-        op.getOperand().getType().getIntOrFloatBitWidth() > 64)
-      return rewriter.notifyMatchFailure(op,
-                                         "bitwidth > 64 bits is not supported");
+    if (failed(checkPreconditions(rewriter, op)))
+      return failure();
 
     // Get APFloat function from runtime library.
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
@@ -165,30 +236,36 @@
     if (failed(fn))
       return fn;
 
-    rewriter.setInsertionPoint(op);
-    // Cast operands to 64-bit integers.
+    // Scalarize and convert to APFloat runtime calls.
     Location loc = op.getLoc();
-    auto inFloatTy = cast<FloatType>(op.getOperand().getType());
-    auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
-    Value operandBits = arith::ExtUIOp::create(
-        rewriter, loc, i64Type,
-        arith::BitcastOp::create(rewriter, loc, inIntWType, op.getOperand()));
+    rewriter.setInsertionPoint(op);
+    Value repl = forEachScalarValue(
+        rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+        [&](Value operand1, Value operand2, Type resultType) {
+          // Cast operands to 64-bit integers.
+          auto inFloatTy = cast<FloatType>(operand1.getType());
+          auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
+          Value operandBits = arith::ExtUIOp::create(
+              rewriter, loc, i64Type,
+              arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
 
-    // Call APFloat function.
-    Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
-    auto outFloatTy = cast<FloatType>(op.getType());
-    Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
-    std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
-    auto resultOp =
-        func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
-                             SymbolRefAttr::get(*fn), params);
+          // Call APFloat function.
+          Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+          auto outFloatTy = cast<FloatType>(resultType);
+          Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+          std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
+          auto resultOp = func::CallOp::create(rewriter, loc,
+                                               TypeRange(rewriter.getI64Type()),
+                                               SymbolRefAttr::get(*fn), params);
 
-    // Truncate result to the original width.
-    auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
-    Value truncatedBits = arith::TruncIOp::create(rewriter, loc, outIntWType,
-                                                  resultOp->getResult(0));
-    rewriter.replaceOp(
-        op, arith::BitcastOp::create(rewriter, loc, outFloatTy, truncatedBits));
+          // Truncate result to the original width.
+          auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
+          Value truncatedBits = arith::TruncIOp::create(
+              rewriter, loc, outIntWType, resultOp->getResult(0));
+          return arith::BitcastOp::create(rewriter, loc, outFloatTy,
+                                          truncatedBits);
+        });
+    rewriter.replaceOp(op, repl);
     return success();
   }
 
@@ -204,10 +281,8 @@
 
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    if (op.getType().getIntOrFloatBitWidth() > 64 ||
-        op.getOperand().getType().getIntOrFloatBitWidth() > 64)
-      return rewriter.notifyMatchFailure(op,
-                                         "bitwidth > 64 bits is not supported");
+    if (failed(checkPreconditions(rewriter, op)))
+      return failure();
 
     // Get APFloat function from runtime library.
     auto i1Type = IntegerType::get(symTable->getContext(), 1);
@@ -219,33 +294,39 @@
     if (failed(fn))
       return fn;
 
-    rewriter.setInsertionPoint(op);
-    // Cast operands to 64-bit integers.
+    // Scalarize and convert to APFloat runtime calls.
     Location loc = op.getLoc();
-    auto inFloatTy = cast<FloatType>(op.getOperand().getType());
-    auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
-    Value operandBits = arith::ExtUIOp::create(
-        rewriter, loc, i64Type,
-        arith::BitcastOp::create(rewriter, loc, inIntWType, op.getOperand()));
+    rewriter.setInsertionPoint(op);
+    Value repl = forEachScalarValue(
+        rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+        [&](Value operand1, Value operand2, Type resultType) {
+          // Cast operands to 64-bit integers.
+          auto inFloatTy = cast<FloatType>(operand1.getType());
+          auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
+          Value operandBits = arith::ExtUIOp::create(
+              rewriter, loc, i64Type,
+              arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
 
-    // Call APFloat function.
-    Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
-    auto outIntTy = cast<IntegerType>(op.getType());
-    Value outWidthValue = arith::ConstantOp::create(
-        rewriter, loc, i32Type,
-        rewriter.getIntegerAttr(i32Type, outIntTy.getWidth()));
-    Value isUnsignedValue = arith::ConstantOp::create(
-        rewriter, loc, i1Type, rewriter.getIntegerAttr(i1Type, isUnsigned));
-    SmallVector<Value> params = {inSemValue, outWidthValue, isUnsignedValue,
-                                 operandBits};
-    auto resultOp =
-        func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
-                             SymbolRefAttr::get(*fn), params);
+          // Call APFloat function.
+          Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+          auto outIntTy = cast<IntegerType>(resultType);
+          Value outWidthValue = arith::ConstantOp::create(
+              rewriter, loc, i32Type,
+              rewriter.getIntegerAttr(i32Type, outIntTy.getWidth()));
+          Value isUnsignedValue = arith::ConstantOp::create(
+              rewriter, loc, i1Type,
+              rewriter.getIntegerAttr(i1Type, isUnsigned));
+          SmallVector<Value> params = {inSemValue, outWidthValue,
+                                       isUnsignedValue, operandBits};
+          auto resultOp = func::CallOp::create(rewriter, loc,
+                                               TypeRange(rewriter.getI64Type()),
+                                               SymbolRefAttr::get(*fn), params);
 
-    // Truncate result to the original width.
-    Value truncatedBits = arith::TruncIOp::create(rewriter, loc, outIntTy,
-                                                  resultOp->getResult(0));
-    rewriter.replaceOp(op, truncatedBits);
+          // Truncate result to the original width.
+          return arith::TruncIOp::create(rewriter, loc, outIntTy,
+                                         resultOp->getResult(0));
+        });
+    rewriter.replaceOp(op, repl);
     return success();
   }
 
@@ -262,10 +343,8 @@
 
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    if (op.getType().getIntOrFloatBitWidth() > 64 ||
-        op.getOperand().getType().getIntOrFloatBitWidth() > 64)
-      return rewriter.notifyMatchFailure(op,
-                                         "bitwidth > 64 bits is not supported");
+    if (failed(checkPreconditions(rewriter, op)))
+      return failure();
 
     // Get APFloat function from runtime library.
     auto i1Type = IntegerType::get(symTable->getContext(), 1);
@@ -277,42 +356,48 @@
     if (failed(fn))
       return fn;
 
-    rewriter.setInsertionPoint(op);
-    // Cast operands to 64-bit integers.
+    // Scalarize and convert to APFloat runtime calls.
     Location loc = op.getLoc();
-    auto inIntTy = cast<IntegerType>(op.getOperand().getType());
-    Value operandBits = op.getOperand();
-    if (operandBits.getType().getIntOrFloatBitWidth() < 64) {
-      if (isUnsigned) {
-        operandBits =
-            arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits);
-      } else {
-        operandBits =
-            arith::ExtSIOp::create(rewriter, loc, i64Type, operandBits);
-      }
-    }
+    rewriter.setInsertionPoint(op);
+    Value repl = forEachScalarValue(
+        rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+        [&](Value operand1, Value operand2, Type resultType) {
+          // Cast operands to 64-bit integers.
+          auto inIntTy = cast<IntegerType>(operand1.getType());
+          Value operandBits = operand1;
+          if (operandBits.getType().getIntOrFloatBitWidth() < 64) {
+            if (isUnsigned) {
+              operandBits =
+                  arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits);
+            } else {
+              operandBits =
+                  arith::ExtSIOp::create(rewriter, loc, i64Type, operandBits);
+            }
+          }
 
-    // Call APFloat function.
-    auto outFloatTy = cast<FloatType>(op.getType());
-    Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
-    Value inWidthValue = arith::ConstantOp::create(
-        rewriter, loc, i32Type,
-        rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
-    Value isUnsignedValue = arith::ConstantOp::create(
-        rewriter, loc, i1Type, rewriter.getIntegerAttr(i1Type, isUnsigned));
-    SmallVector<Value> params = {outSemValue, inWidthValue, isUnsignedValue,
-                                 operandBits};
-    auto resultOp =
-        func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
-                             SymbolRefAttr::get(*fn), params);
+          // Call APFloat function.
+          auto outFloatTy = cast<FloatType>(resultType);
+          Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+          Value inWidthValue = arith::ConstantOp::create(
+              rewriter, loc, i32Type,
+              rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
+          Value isUnsignedValue = arith::ConstantOp::create(
+              rewriter, loc, i1Type,
+              rewriter.getIntegerAttr(i1Type, isUnsigned));
+          SmallVector<Value> params = {outSemValue, inWidthValue,
+                                       isUnsignedValue, operandBits};
+          auto resultOp = func::CallOp::create(rewriter, loc,
+                                               TypeRange(rewriter.getI64Type()),
+                                               SymbolRefAttr::get(*fn), params);
 
-    // Truncate result to the original width.
-    auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
-    Value truncatedBits = arith::TruncIOp::create(rewriter, loc, outIntWType,
-                                                  resultOp->getResult(0));
-    Value result =
-        arith::BitcastOp::create(rewriter, loc, outFloatTy, truncatedBits);
-    rewriter.replaceOp(op, result);
+          // Truncate result to the original width.
+          auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
+          Value truncatedBits = arith::TruncIOp::create(
+              rewriter, loc, outIntWType, resultOp->getResult(0));
+          return arith::BitcastOp::create(rewriter, loc, outFloatTy,
+                                          truncatedBits);
+        });
+    rewriter.replaceOp(op, repl);
     return success();
   }
 
@@ -327,9 +412,8 @@
 
   LogicalResult matchAndRewrite(arith::CmpFOp op,
                                 PatternRewriter &rewriter) const override {
-    if (op.getLhs().getType().getIntOrFloatBitWidth() > 64)
-      return rewriter.notifyMatchFailure(op,
-                                         "bitwidth > 64 bits is not supported");
+    if (failed(checkPreconditions(rewriter, op)))
+      return failure();
 
     // Get APFloat function from runtime library.
     auto i1Type = IntegerType::get(symTable->getContext(), 1);
@@ -342,121 +426,130 @@
     if (failed(fn))
       return fn;
 
-    // Cast operands to 64-bit integers.
-    rewriter.setInsertionPoint(op);
+    // Scalarize and convert to APFloat runtime calls.
     Location loc = op.getLoc();
-    auto floatTy = cast<FloatType>(op.getLhs().getType());
-    auto intWType = rewriter.getIntegerType(floatTy.getWidth());
-    Value lhsBits = arith::ExtUIOp::create(
-        rewriter, loc, i64Type,
-        arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
-    Value rhsBits = arith::ExtUIOp::create(
-        rewriter, loc, i64Type,
-        arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
+    rewriter.setInsertionPoint(op);
+    Value repl = forEachScalarValue(
+        rewriter, loc, op.getLhs(), op.getRhs(), op.getType(),
+        [&](Value lhs, Value rhs, Type resultType) {
+          // Cast operands to 64-bit integers.
+          auto floatTy = cast<FloatType>(lhs.getType());
+          auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+          Value lhsBits = arith::ExtUIOp::create(
+              rewriter, loc, i64Type,
+              arith::BitcastOp::create(rewriter, loc, intWType, lhs));
+          Value rhsBits = arith::ExtUIOp::create(
+              rewriter, loc, i64Type,
+              arith::BitcastOp::create(rewriter, loc, intWType, rhs));
 
-    // Call APFloat function.
-    Value semValue = getSemanticsValue(rewriter, loc, floatTy);
-    SmallVector<Value> params = {semValue, lhsBits, rhsBits};
-    Value comparisonResult =
-        func::CallOp::create(rewriter, loc, TypeRange(i8Type),
-                             SymbolRefAttr::get(*fn), params)
-            ->getResult(0);
+          // Call APFloat function.
+          Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+          SmallVector<Value> params = {semValue, lhsBits, rhsBits};
+          Value comparisonResult =
+              func::CallOp::create(rewriter, loc, TypeRange(i8Type),
+                                   SymbolRefAttr::get(*fn), params)
+                  ->getResult(0);
 
-    // Generate an i1 SSA value that is "true" if the comparison result matches
-    // the given `val`.
-    auto checkResult = [&](llvm::APFloat::cmpResult val) {
-      return arith::CmpIOp::create(
-          rewriter, loc, arith::CmpIPredicate::eq, comparisonResult,
-          arith::ConstantOp::create(
-              rewriter, loc, i8Type,
-              rewriter.getIntegerAttr(i8Type, static_cast<int8_t>(val)))
-              .getResult());
-    };
-    // Generate an i1 SSA value that is "true" if the comparison result matches
-    // any of the given `vals`.
-    std::function<Value(ArrayRef<llvm::APFloat::cmpResult>)> checkResults =
-        [&](ArrayRef<llvm::APFloat::cmpResult> vals) {
-          Value first = checkResult(vals.front());
-          if (vals.size() == 1)
-            return first;
-          Value rest = checkResults(vals.drop_front());
-          return arith::OrIOp::create(rewriter, loc, first, rest).getResult();
-        };
+          // Generate an i1 SSA value that is "true" if the comparison result
+          // matches the given `val`.
+          auto checkResult = [&](llvm::APFloat::cmpResult val) {
+            return arith::CmpIOp::create(
+                rewriter, loc, arith::CmpIPredicate::eq, comparisonResult,
+                arith::ConstantOp::create(
+                    rewriter, loc, i8Type,
+                    rewriter.getIntegerAttr(i8Type, static_cast<int8_t>(val)))
+                    .getResult());
+          };
+          // Generate an i1 SSA value that is "true" if the comparison result
+          // matches any of the given `vals`.
+          std::function<Value(ArrayRef<llvm::APFloat::cmpResult>)>
+              checkResults = [&](ArrayRef<llvm::APFloat::cmpResult> vals) {
+                Value first = checkResult(vals.front());
+                if (vals.size() == 1)
+                  return first;
+                Value rest = checkResults(vals.drop_front());
+                return arith::OrIOp::create(rewriter, loc, first, rest)
+                    .getResult();
+              };
 
-    // This switch-case statement was taken from arith::applyCmpPredicate.
-    Value result;
-    switch (op.getPredicate()) {
-    case arith::CmpFPredicate::AlwaysFalse:
-      result = arith::ConstantOp::create(rewriter, loc, i1Type,
-                                         rewriter.getIntegerAttr(i1Type, 0))
-                   .getResult();
-      break;
-    case arith::CmpFPredicate::OEQ:
-      result = checkResult(llvm::APFloat::cmpEqual);
-      break;
-    case arith::CmpFPredicate::OGT:
-      result = checkResult(llvm::APFloat::cmpGreaterThan);
-      break;
-    case arith::CmpFPredicate::OGE:
-      result = checkResults(
-          {llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
-      break;
-    case arith::CmpFPredicate::OLT:
-      result = checkResult(llvm::APFloat::cmpLessThan);
-      break;
-    case arith::CmpFPredicate::OLE:
-      result =
-          checkResults({llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual});
-      break;
-    case arith::CmpFPredicate::ONE:
-      // Not cmpUnordered and not cmpUnordered.
-      result = checkResults(
-          {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan});
-      break;
-    case arith::CmpFPredicate::ORD:
-      // Not cmpUnordered.
-      result = checkResults({llvm::APFloat::cmpLessThan,
-                             llvm::APFloat::cmpGreaterThan,
-                             llvm::APFloat::cmpEqual});
-      break;
-    case arith::CmpFPredicate::UEQ:
-      result =
-          checkResults({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual});
-      break;
-    case arith::CmpFPredicate::UGT:
-      result = checkResults(
-          {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan});
-      break;
-    case arith::CmpFPredicate::UGE:
-      result = checkResults({llvm::APFloat::cmpUnordered,
-                             llvm::APFloat::cmpGreaterThan,
-                             llvm::APFloat::cmpEqual});
-      break;
-    case arith::CmpFPredicate::ULT:
-      result = checkResults(
-          {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan});
-      break;
-    case arith::CmpFPredicate::ULE:
-      result =
-          checkResults({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan,
-                        llvm::APFloat::cmpEqual});
-      break;
-    case arith::CmpFPredicate::UNE:
-      // Not cmpEqual.
-      result = checkResults({llvm::APFloat::cmpLessThan,
-                             llvm::APFloat::cmpGreaterThan,
-                             llvm::APFloat::cmpUnordered});
-      break;
-    case arith::CmpFPredicate::UNO:
-      result = checkResult(llvm::APFloat::cmpUnordered);
-      break;
-    case arith::CmpFPredicate::AlwaysTrue:
-      result = arith::ConstantOp::create(rewriter, loc, i1Type,
-                                         rewriter.getIntegerAttr(i1Type, 1))
-                   .getResult();
-      break;
-    }
-    rewriter.replaceOp(op, result);
+          // This switch-case statement was taken from arith::applyCmpPredicate.
+          Value result;
+          switch (op.getPredicate()) {
+          case arith::CmpFPredicate::AlwaysFalse:
+            result =
+                arith::ConstantOp::create(rewriter, loc, i1Type,
+                                          rewriter.getIntegerAttr(i1Type, 0))
+                    .getResult();
+            break;
+          case arith::CmpFPredicate::OEQ:
+            result = checkResult(llvm::APFloat::cmpEqual);
+            break;
+          case arith::CmpFPredicate::OGT:
+            result = checkResult(llvm::APFloat::cmpGreaterThan);
+            break;
+          case arith::CmpFPredicate::OGE:
+            result = checkResults(
+                {llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
+            break;
+          case arith::CmpFPredicate::OLT:
+            result = checkResult(llvm::APFloat::cmpLessThan);
+            break;
+          case arith::CmpFPredicate::OLE:
+            result = checkResults(
+                {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual});
+            break;
+          case arith::CmpFPredicate::ONE:
+            // Not cmpUnordered and not cmpEqual.
+            result = checkResults(
+                {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan});
+            break;
+          case arith::CmpFPredicate::ORD:
+            // Not cmpUnordered.
+            result = checkResults({llvm::APFloat::cmpLessThan,
+                                   llvm::APFloat::cmpGreaterThan,
+                                   llvm::APFloat::cmpEqual});
+            break;
+          case arith::CmpFPredicate::UEQ:
+            result = checkResults(
+                {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual});
+            break;
+          case arith::CmpFPredicate::UGT:
+            result = checkResults(
+                {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan});
+            break;
+          case arith::CmpFPredicate::UGE:
+            result = checkResults({llvm::APFloat::cmpUnordered,
+                                   llvm::APFloat::cmpGreaterThan,
+                                   llvm::APFloat::cmpEqual});
+            break;
+          case arith::CmpFPredicate::ULT:
+            result = checkResults(
+                {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan});
+            break;
+          case arith::CmpFPredicate::ULE:
+            result = checkResults({llvm::APFloat::cmpUnordered,
+                                   llvm::APFloat::cmpLessThan,
+                                   llvm::APFloat::cmpEqual});
+            break;
+          case arith::CmpFPredicate::UNE:
+            // Not cmpEqual.
+            result = checkResults({llvm::APFloat::cmpLessThan,
+                                   llvm::APFloat::cmpGreaterThan,
+                                   llvm::APFloat::cmpUnordered});
+            break;
+          case arith::CmpFPredicate::UNO:
+            result = checkResult(llvm::APFloat::cmpUnordered);
+            break;
+          case arith::CmpFPredicate::AlwaysTrue:
+            result =
+                arith::ConstantOp::create(rewriter, loc, i1Type,
+                                          rewriter.getIntegerAttr(i1Type, 1))
+                    .getResult();
+            break;
+          }
+          return result;
+        });
+    rewriter.replaceOp(op, repl);
     return success();
   }
 
@@ -470,9 +563,8 @@
 
   LogicalResult matchAndRewrite(arith::NegFOp op,
                                 PatternRewriter &rewriter) const override {
-    if (op.getOperand().getType().getIntOrFloatBitWidth() > 64)
-      return rewriter.notifyMatchFailure(op,
-                                         "bitwidth > 64 bits is not supported");
+    if (failed(checkPreconditions(rewriter, op)))
+      return failure();
 
     // Get APFloat function from runtime library.
     auto i32Type = IntegerType::get(symTable->getContext(), 32);
@@ -482,28 +574,34 @@
     if (failed(fn))
       return fn;
 
-    // Cast operand to 64-bit integer.
-    rewriter.setInsertionPoint(op);
+    // Scalarize and convert to APFloat runtime calls.
     Location loc = op.getLoc();
-    auto floatTy = cast<FloatType>(op.getOperand().getType());
-    auto intWType = rewriter.getIntegerType(floatTy.getWidth());
-    Value operandBits = arith::ExtUIOp::create(
-        rewriter, loc, i64Type, arith::BitcastOp::create(rewriter, loc, intWType, op.getOperand()));
+    rewriter.setInsertionPoint(op);
+    Value repl = forEachScalarValue(
+        rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+        [&](Value operand1, Value operand2, Type resultType) {
+          // Cast operands to 64-bit integers.
+          auto floatTy = cast<FloatType>(operand1.getType());
+          auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+          Value operandBits = arith::ExtUIOp::create(
+              rewriter, loc, i64Type,
+              arith::BitcastOp::create(rewriter, loc, intWType, operand1));
 
-    // Call APFloat function.
-    Value semValue = getSemanticsValue(rewriter, loc, floatTy);
-    SmallVector<Value> params = {semValue, operandBits};
-    Value negatedBits =
-        func::CallOp::create(rewriter, loc, TypeRange(i64Type),
-                             SymbolRefAttr::get(*fn), params)
-            ->getResult(0);
+          // Call APFloat function.
+          Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+          SmallVector<Value> params = {semValue, operandBits};
+          Value negatedBits =
+              func::CallOp::create(rewriter, loc, TypeRange(i64Type),
+                                   SymbolRefAttr::get(*fn), params)
+                  ->getResult(0);
 
-    // Truncate result to the original width.
-    Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
-                                                  negatedBits);
-    Value result =
-        arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits);
-    rewriter.replaceOp(op, result);
+          // Truncate result to the original width.
+          Value truncatedBits =
+              arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
+          return arith::BitcastOp::create(rewriter, loc, floatTy,
+                                          truncatedBits);
+        });
+    rewriter.replaceOp(op, repl);
     return success();
   }
 
diff --git a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
index b5ec49c08..31fce7a 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
+++ b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
@@ -15,4 +15,5 @@
   MLIRArithTransforms
   MLIRFuncDialect
   MLIRFuncUtils
+  MLIRVectorDialect
   )
diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
index ab05ede..bd4a9da 100644
--- a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
+++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
@@ -288,3 +288,42 @@
   %8 = arith.sitofp %6 : i92 to f32
   return
 }
+
+// -----
+
+// CHECK-LABEL: func.func @addf_vector
+// CHECK-COUNT-2:   vector.to_elements
+// Per lane (4 lanes): bitcast+extui of each operand, runtime call, trunci.
+//       CHECK:   arith.bitcast
+//       CHECK:   arith.extui
+//       CHECK:   arith.bitcast
+//       CHECK:   arith.extui
+//       CHECK:   call
+//       CHECK:   arith.trunci
+
+//       CHECK:   arith.bitcast
+//       CHECK:   arith.extui
+//       CHECK:   arith.bitcast
+//       CHECK:   arith.extui
+//       CHECK:   call
+//       CHECK:   arith.trunci
+
+//       CHECK:   arith.bitcast
+//       CHECK:   arith.extui
+//       CHECK:   arith.bitcast
+//       CHECK:   arith.extui
+//       CHECK:   call
+//       CHECK:   arith.trunci
+
+//       CHECK:   arith.bitcast
+//       CHECK:   arith.extui
+//       CHECK:   arith.bitcast
+//       CHECK:   arith.extui
+//       CHECK:   call
+//       CHECK:   arith.trunci
+
+//       CHECK:   vector.from_elements
+func.func @addf_vector(%arg0: vector<4xf4E2M1FN>, %arg1: vector<4xf4E2M1FN>) {
+  %0 = arith.addf %arg0, %arg1 : vector<4xf4E2M1FN>
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation-vector.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation-vector.mlir
new file mode 100644
index 0000000..3b94dc4
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation-vector.mlir
@@ -0,0 +1,26 @@
+// REQUIRES: system-linux
+// TODO: Run only on Linux until we figure out how to build
+// mlir_apfloat_wrappers in a platform-independent way.
+
+// All floating-point arithmetic is lowered through APFloat.
+// RUN: mlir-opt %s --convert-arith-to-apfloat --convert-vector-to-scf \
+// RUN:     --convert-scf-to-cf --convert-to-llvm | \
+// RUN: mlir-runner -e entry --entry-point-result=void \
+// RUN:             --shared-libs=%mlir_c_runner_utils \
+// RUN:             --shared-libs=%mlir_apfloat_wrappers | FileCheck %s
+
+// Put rhs into separate function so that it won't be constant-folded.
+func.func @foo_vec() -> (vector<4xf8E4M3FN>, vector<4xf32>) {
+  %cst1 = arith.constant dense<[2.2, 2.2, 2.2, 2.2]> : vector<4xf8E4M3FN>
+  %cst2 = arith.constant dense<[2.2, 2.2, 2.2, 2.2]> : vector<4xf32>
+  return %cst1, %cst2 : vector<4xf8E4M3FN>, vector<4xf32>
+}
+
+func.func @entry() {
+  // CHECK: ( 3.5, 3.5, 3.5, 3.5 )
+  %a1_vec = arith.constant dense<[1.4, 1.4, 1.4, 1.4]> : vector<4xf8E4M3FN>
+  %b1_vec, %b2_vec = func.call @foo_vec() : () -> (vector<4xf8E4M3FN>, vector<4xf32>)
+  %c1_vec = arith.addf %a1_vec, %b1_vec : vector<4xf8E4M3FN>  // not supported by LLVM
+  vector.print %c1_vec : vector<4xf8E4M3FN>
+  return
+}