[mlir][arith] `arith-to-apfloat`: Add vector support (#171024)
Add support for vectorized operations such as `arith.addf ... :
vector<4xf4E2M1FN>`. The computation is scalarized: scalar operands are
extracted with `vector.to_elements`, one scalar computation is performed
per element, and the results are inserted back into a vector with
`vector.from_elements`.
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 75ab4b6..fcbaf3cc 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -198,7 +198,8 @@
calls (APFloatWrappers.cpp). APFloat is a software implementation of
floating-point arithmetic operations.
}];
- let dependentDialects = ["func::FuncDialect"];
+ let dependentDialects = ["arith::ArithDialect", "func::FuncDialect",
+ "vector::VectorDialect"];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
index 4776ba0..79816fc 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
+++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Utils/Utils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
@@ -90,6 +91,73 @@
b.getIntegerAttr(b.getI32Type(), sem));
}
+/// Scalarize an elementwise computation. If `operand1` is a vector, unpack
+/// it (and `operand2`, if non-null) with `vector.to_elements`, call `fn` on
+/// each pair of scalars with the element type of `resultType`, and pack the
+/// results into a `resultType` vector with `vector.from_elements`. Otherwise,
+/// call `fn` directly. `operand2` is optional; when null, `fn` receives null.
+template <typename Fn>
+static Value forEachScalarValue(RewriterBase &rewriter, Location loc,
+                                Value operand1, Value operand2, Type resultType,
+                                Fn fn) {
+  auto vecTy1 = dyn_cast<VectorType>(operand1.getType());
+  // Sanity check: Operand types must match.
+  assert((!operand2 || vecTy1 == dyn_cast<VectorType>(operand2.getType())) &&
+         "expected same vector types");
+  if (!vecTy1) {
+    // Not a vector. Call the function directly.
+    return fn(operand1, operand2, resultType);
+  }
+
+  // Elementwise ops preserve the shape; only the element type may differ.
+  auto resultVecType = cast<VectorType>(resultType);
+  assert(resultVecType.getShape() == vecTy1.getShape() && "shape mismatch");
+
+  // Prepare scalar operands.
+  ResultRange scalars1 =
+      vector::ToElementsOp::create(rewriter, loc, operand1)->getResults();
+  SmallVector<Value> scalars2;
+  if (!operand2) {
+    // No second operand. Create a vector of empty values.
+    scalars2.assign(vecTy1.getNumElements(), Value());
+  } else {
+    llvm::append_range(
+        scalars2,
+        vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
+  }
+
+  // Call the function for each pair of scalar operands.
+  SmallVector<Value> results;
+  results.reserve(vecTy1.getNumElements());
+  for (auto [scalar1, scalar2] : llvm::zip_equal(scalars1, scalars2)) {
+    Value result = fn(scalar1, scalar2, resultVecType.getElementType());
+    results.push_back(result);
+  }
+
+  // Package the results into a vector of the original result type.
+  return vector::FromElementsOp::create(rewriter, loc, resultVecType, results);
+}
+
+/// Check conversion preconditions: operands / results must be integers, floats
+/// or fixed-size (non-scalable) vectors thereof, with a bitwidth <= 64.
+static LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op) {
+  for (Value value : llvm::concat<Value>(op->getOperands(), op->getResults())) {
+    Type type = value.getType();
+    if (auto vecTy = dyn_cast<VectorType>(type)) {
+      if (vecTy.isScalable())
+        return rewriter.notifyMatchFailure(op, "scalable vectors unsupported");
+      type = vecTy.getElementType();
+    }
+    if (!type.isIntOrFloat())
+      return rewriter.notifyMatchFailure(
+          op, "only integers and floats (or vectors thereof) are supported");
+    if (type.getIntOrFloatBitWidth() > 64)
+      return rewriter.notifyMatchFailure(op,
+                                         "bitwidth > 64 bits is not supported");
+  }
+  return success();
+}
+
/// Rewrite a binary arithmetic operation to an APFloat function call.
template <typename OpTy>
struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
@@ -102,9 +170,8 @@
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- if (op.getType().getIntOrFloatBitWidth() > 64)
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
// Get APFloat function from runtime library.
FailureOr<FuncOp> fn =
@@ -112,31 +179,37 @@
if (failed(fn))
return fn;
- rewriter.setInsertionPoint(op);
- // Cast operands to 64-bit integers.
+ // Scalarize and convert to APFloat runtime calls.
Location loc = op.getLoc();
- auto floatTy = cast<FloatType>(op.getType());
- auto intWType = rewriter.getIntegerType(floatTy.getWidth());
- auto int64Type = rewriter.getI64Type();
- Value lhsBits = arith::ExtUIOp::create(
- rewriter, loc, int64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
- Value rhsBits = arith::ExtUIOp::create(
- rewriter, loc, int64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getLhs(), op.getRhs(), op.getType(),
+ [&](Value lhs, Value rhs, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto floatTy = cast<FloatType>(resultType);
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ auto int64Type = rewriter.getI64Type();
+ Value lhsBits = arith::ExtUIOp::create(
+ rewriter, loc, int64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, lhs));
+ Value rhsBits = arith::ExtUIOp::create(
+ rewriter, loc, int64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, rhs));
- // Call APFloat function.
- Value semValue = getSemanticsValue(rewriter, loc, floatTy);
- SmallVector<Value> params = {semValue, lhsBits, rhsBits};
- auto resultOp =
- func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
- SymbolRefAttr::get(*fn), params);
+ // Call APFloat function.
+ Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, lhsBits, rhsBits};
+ auto resultOp = func::CallOp::create(rewriter, loc,
+ TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
- // Truncate result to the original width.
- Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
- resultOp->getResult(0));
- rewriter.replaceOp(
- op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
+ // Truncate result to the original width.
+ Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
+ resultOp->getResult(0));
+ return arith::BitcastOp::create(rewriter, loc, floatTy,
+ truncatedBits);
+ });
+ rewriter.replaceOp(op, repl);
return success();
}
@@ -152,10 +225,8 @@
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- if (op.getType().getIntOrFloatBitWidth() > 64 ||
- op.getOperand().getType().getIntOrFloatBitWidth() > 64)
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
// Get APFloat function from runtime library.
auto i32Type = IntegerType::get(symTable->getContext(), 32);
@@ -165,30 +236,36 @@
if (failed(fn))
return fn;
- rewriter.setInsertionPoint(op);
- // Cast operands to 64-bit integers.
+ // Scalarize and convert to APFloat runtime calls.
Location loc = op.getLoc();
- auto inFloatTy = cast<FloatType>(op.getOperand().getType());
- auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
- Value operandBits = arith::ExtUIOp::create(
- rewriter, loc, i64Type,
- arith::BitcastOp::create(rewriter, loc, inIntWType, op.getOperand()));
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand1, Value operand2, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto inFloatTy = cast<FloatType>(operand1.getType());
+ auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
- // Call APFloat function.
- Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
- auto outFloatTy = cast<FloatType>(op.getType());
- Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
- std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
- auto resultOp =
- func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
- SymbolRefAttr::get(*fn), params);
+ // Call APFloat function.
+ Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+ auto outFloatTy = cast<FloatType>(resultType);
+ Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+ std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
+ auto resultOp = func::CallOp::create(rewriter, loc,
+ TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
- // Truncate result to the original width.
- auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
- Value truncatedBits = arith::TruncIOp::create(rewriter, loc, outIntWType,
- resultOp->getResult(0));
- rewriter.replaceOp(
- op, arith::BitcastOp::create(rewriter, loc, outFloatTy, truncatedBits));
+ // Truncate result to the original width.
+ auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
+ Value truncatedBits = arith::TruncIOp::create(
+ rewriter, loc, outIntWType, resultOp->getResult(0));
+ return arith::BitcastOp::create(rewriter, loc, outFloatTy,
+ truncatedBits);
+ });
+ rewriter.replaceOp(op, repl);
return success();
}
@@ -204,10 +281,8 @@
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- if (op.getType().getIntOrFloatBitWidth() > 64 ||
- op.getOperand().getType().getIntOrFloatBitWidth() > 64)
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
// Get APFloat function from runtime library.
auto i1Type = IntegerType::get(symTable->getContext(), 1);
@@ -219,33 +294,39 @@
if (failed(fn))
return fn;
- rewriter.setInsertionPoint(op);
- // Cast operands to 64-bit integers.
+ // Scalarize and convert to APFloat runtime calls.
Location loc = op.getLoc();
- auto inFloatTy = cast<FloatType>(op.getOperand().getType());
- auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
- Value operandBits = arith::ExtUIOp::create(
- rewriter, loc, i64Type,
- arith::BitcastOp::create(rewriter, loc, inIntWType, op.getOperand()));
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand1, Value operand2, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto inFloatTy = cast<FloatType>(operand1.getType());
+ auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
- // Call APFloat function.
- Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
- auto outIntTy = cast<IntegerType>(op.getType());
- Value outWidthValue = arith::ConstantOp::create(
- rewriter, loc, i32Type,
- rewriter.getIntegerAttr(i32Type, outIntTy.getWidth()));
- Value isUnsignedValue = arith::ConstantOp::create(
- rewriter, loc, i1Type, rewriter.getIntegerAttr(i1Type, isUnsigned));
- SmallVector<Value> params = {inSemValue, outWidthValue, isUnsignedValue,
- operandBits};
- auto resultOp =
- func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
- SymbolRefAttr::get(*fn), params);
+ // Call APFloat function.
+ Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
+ auto outIntTy = cast<IntegerType>(resultType);
+ Value outWidthValue = arith::ConstantOp::create(
+ rewriter, loc, i32Type,
+ rewriter.getIntegerAttr(i32Type, outIntTy.getWidth()));
+ Value isUnsignedValue = arith::ConstantOp::create(
+ rewriter, loc, i1Type,
+ rewriter.getIntegerAttr(i1Type, isUnsigned));
+ SmallVector<Value> params = {inSemValue, outWidthValue,
+ isUnsignedValue, operandBits};
+ auto resultOp = func::CallOp::create(rewriter, loc,
+ TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
- // Truncate result to the original width.
- Value truncatedBits = arith::TruncIOp::create(rewriter, loc, outIntTy,
- resultOp->getResult(0));
- rewriter.replaceOp(op, truncatedBits);
+ // Truncate result to the original width.
+ return arith::TruncIOp::create(rewriter, loc, outIntTy,
+ resultOp->getResult(0));
+ });
+ rewriter.replaceOp(op, repl);
return success();
}
@@ -262,10 +343,8 @@
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
- if (op.getType().getIntOrFloatBitWidth() > 64 ||
- op.getOperand().getType().getIntOrFloatBitWidth() > 64)
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
// Get APFloat function from runtime library.
auto i1Type = IntegerType::get(symTable->getContext(), 1);
@@ -277,42 +356,48 @@
if (failed(fn))
return fn;
- rewriter.setInsertionPoint(op);
- // Cast operands to 64-bit integers.
+ // Scalarize and convert to APFloat runtime calls.
Location loc = op.getLoc();
- auto inIntTy = cast<IntegerType>(op.getOperand().getType());
- Value operandBits = op.getOperand();
- if (operandBits.getType().getIntOrFloatBitWidth() < 64) {
- if (isUnsigned) {
- operandBits =
- arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits);
- } else {
- operandBits =
- arith::ExtSIOp::create(rewriter, loc, i64Type, operandBits);
- }
- }
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand1, Value operand2, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto inIntTy = cast<IntegerType>(operand1.getType());
+ Value operandBits = operand1;
+ if (operandBits.getType().getIntOrFloatBitWidth() < 64) {
+ if (isUnsigned) {
+ operandBits =
+ arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits);
+ } else {
+ operandBits =
+ arith::ExtSIOp::create(rewriter, loc, i64Type, operandBits);
+ }
+ }
- // Call APFloat function.
- auto outFloatTy = cast<FloatType>(op.getType());
- Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
- Value inWidthValue = arith::ConstantOp::create(
- rewriter, loc, i32Type,
- rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
- Value isUnsignedValue = arith::ConstantOp::create(
- rewriter, loc, i1Type, rewriter.getIntegerAttr(i1Type, isUnsigned));
- SmallVector<Value> params = {outSemValue, inWidthValue, isUnsignedValue,
- operandBits};
- auto resultOp =
- func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
- SymbolRefAttr::get(*fn), params);
+ // Call APFloat function.
+ auto outFloatTy = cast<FloatType>(resultType);
+ Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
+ Value inWidthValue = arith::ConstantOp::create(
+ rewriter, loc, i32Type,
+ rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
+ Value isUnsignedValue = arith::ConstantOp::create(
+ rewriter, loc, i1Type,
+ rewriter.getIntegerAttr(i1Type, isUnsigned));
+ SmallVector<Value> params = {outSemValue, inWidthValue,
+ isUnsignedValue, operandBits};
+ auto resultOp = func::CallOp::create(rewriter, loc,
+ TypeRange(rewriter.getI64Type()),
+ SymbolRefAttr::get(*fn), params);
- // Truncate result to the original width.
- auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
- Value truncatedBits = arith::TruncIOp::create(rewriter, loc, outIntWType,
- resultOp->getResult(0));
- Value result =
- arith::BitcastOp::create(rewriter, loc, outFloatTy, truncatedBits);
- rewriter.replaceOp(op, result);
+ // Truncate result to the original width.
+ auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
+ Value truncatedBits = arith::TruncIOp::create(
+ rewriter, loc, outIntWType, resultOp->getResult(0));
+ return arith::BitcastOp::create(rewriter, loc, outFloatTy,
+ truncatedBits);
+ });
+ rewriter.replaceOp(op, repl);
return success();
}
@@ -327,9 +412,8 @@
LogicalResult matchAndRewrite(arith::CmpFOp op,
PatternRewriter &rewriter) const override {
- if (op.getLhs().getType().getIntOrFloatBitWidth() > 64)
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
// Get APFloat function from runtime library.
auto i1Type = IntegerType::get(symTable->getContext(), 1);
@@ -342,121 +426,130 @@
if (failed(fn))
return fn;
- // Cast operands to 64-bit integers.
- rewriter.setInsertionPoint(op);
+ // Scalarize and convert to APFloat runtime calls.
Location loc = op.getLoc();
- auto floatTy = cast<FloatType>(op.getLhs().getType());
- auto intWType = rewriter.getIntegerType(floatTy.getWidth());
- Value lhsBits = arith::ExtUIOp::create(
- rewriter, loc, i64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
- Value rhsBits = arith::ExtUIOp::create(
- rewriter, loc, i64Type,
- arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getLhs(), op.getRhs(), op.getType(),
+ [&](Value lhs, Value rhs, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto floatTy = cast<FloatType>(lhs.getType());
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ Value lhsBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, lhs));
+ Value rhsBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, rhs));
- // Call APFloat function.
- Value semValue = getSemanticsValue(rewriter, loc, floatTy);
- SmallVector<Value> params = {semValue, lhsBits, rhsBits};
- Value comparisonResult =
- func::CallOp::create(rewriter, loc, TypeRange(i8Type),
- SymbolRefAttr::get(*fn), params)
- ->getResult(0);
+ // Call APFloat function.
+ Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, lhsBits, rhsBits};
+ Value comparisonResult =
+ func::CallOp::create(rewriter, loc, TypeRange(i8Type),
+ SymbolRefAttr::get(*fn), params)
+ ->getResult(0);
- // Generate an i1 SSA value that is "true" if the comparison result matches
- // the given `val`.
- auto checkResult = [&](llvm::APFloat::cmpResult val) {
- return arith::CmpIOp::create(
- rewriter, loc, arith::CmpIPredicate::eq, comparisonResult,
- arith::ConstantOp::create(
- rewriter, loc, i8Type,
- rewriter.getIntegerAttr(i8Type, static_cast<int8_t>(val)))
- .getResult());
- };
- // Generate an i1 SSA value that is "true" if the comparison result matches
- // any of the given `vals`.
- std::function<Value(ArrayRef<llvm::APFloat::cmpResult>)> checkResults =
- [&](ArrayRef<llvm::APFloat::cmpResult> vals) {
- Value first = checkResult(vals.front());
- if (vals.size() == 1)
- return first;
- Value rest = checkResults(vals.drop_front());
- return arith::OrIOp::create(rewriter, loc, first, rest).getResult();
- };
+ // Generate an i1 SSA value that is "true" if the comparison result
+ // matches the given `val`.
+ auto checkResult = [&](llvm::APFloat::cmpResult val) {
+ return arith::CmpIOp::create(
+ rewriter, loc, arith::CmpIPredicate::eq, comparisonResult,
+ arith::ConstantOp::create(
+ rewriter, loc, i8Type,
+ rewriter.getIntegerAttr(i8Type, static_cast<int8_t>(val)))
+ .getResult());
+ };
+ // Generate an i1 SSA value that is "true" if the comparison result
+ // matches any of the given `vals`.
+ std::function<Value(ArrayRef<llvm::APFloat::cmpResult>)>
+ checkResults = [&](ArrayRef<llvm::APFloat::cmpResult> vals) {
+ Value first = checkResult(vals.front());
+ if (vals.size() == 1)
+ return first;
+ Value rest = checkResults(vals.drop_front());
+ return arith::OrIOp::create(rewriter, loc, first, rest)
+ .getResult();
+ };
- // This switch-case statement was taken from arith::applyCmpPredicate.
- Value result;
- switch (op.getPredicate()) {
- case arith::CmpFPredicate::AlwaysFalse:
- result = arith::ConstantOp::create(rewriter, loc, i1Type,
- rewriter.getIntegerAttr(i1Type, 0))
- .getResult();
- break;
- case arith::CmpFPredicate::OEQ:
- result = checkResult(llvm::APFloat::cmpEqual);
- break;
- case arith::CmpFPredicate::OGT:
- result = checkResult(llvm::APFloat::cmpGreaterThan);
- break;
- case arith::CmpFPredicate::OGE:
- result = checkResults(
- {llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
- break;
- case arith::CmpFPredicate::OLT:
- result = checkResult(llvm::APFloat::cmpLessThan);
- break;
- case arith::CmpFPredicate::OLE:
- result =
- checkResults({llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual});
- break;
- case arith::CmpFPredicate::ONE:
- // Not cmpUnordered and not cmpUnordered.
- result = checkResults(
- {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan});
- break;
- case arith::CmpFPredicate::ORD:
- // Not cmpUnordered.
- result = checkResults({llvm::APFloat::cmpLessThan,
- llvm::APFloat::cmpGreaterThan,
- llvm::APFloat::cmpEqual});
- break;
- case arith::CmpFPredicate::UEQ:
- result =
- checkResults({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual});
- break;
- case arith::CmpFPredicate::UGT:
- result = checkResults(
- {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan});
- break;
- case arith::CmpFPredicate::UGE:
- result = checkResults({llvm::APFloat::cmpUnordered,
- llvm::APFloat::cmpGreaterThan,
- llvm::APFloat::cmpEqual});
- break;
- case arith::CmpFPredicate::ULT:
- result = checkResults(
- {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan});
- break;
- case arith::CmpFPredicate::ULE:
- result =
- checkResults({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan,
- llvm::APFloat::cmpEqual});
- break;
- case arith::CmpFPredicate::UNE:
- // Not cmpEqual.
- result = checkResults({llvm::APFloat::cmpLessThan,
- llvm::APFloat::cmpGreaterThan,
- llvm::APFloat::cmpUnordered});
- break;
- case arith::CmpFPredicate::UNO:
- result = checkResult(llvm::APFloat::cmpUnordered);
- break;
- case arith::CmpFPredicate::AlwaysTrue:
- result = arith::ConstantOp::create(rewriter, loc, i1Type,
- rewriter.getIntegerAttr(i1Type, 1))
- .getResult();
- break;
- }
- rewriter.replaceOp(op, result);
+ // This switch-case statement was taken from arith::applyCmpPredicate.
+ Value result;
+ switch (op.getPredicate()) {
+ case arith::CmpFPredicate::AlwaysFalse:
+ result =
+ arith::ConstantOp::create(rewriter, loc, i1Type,
+ rewriter.getIntegerAttr(i1Type, 0))
+ .getResult();
+ break;
+ case arith::CmpFPredicate::OEQ:
+ result = checkResult(llvm::APFloat::cmpEqual);
+ break;
+ case arith::CmpFPredicate::OGT:
+ result = checkResult(llvm::APFloat::cmpGreaterThan);
+ break;
+ case arith::CmpFPredicate::OGE:
+ result = checkResults(
+ {llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
+ break;
+ case arith::CmpFPredicate::OLT:
+ result = checkResult(llvm::APFloat::cmpLessThan);
+ break;
+ case arith::CmpFPredicate::OLE:
+ result = checkResults(
+ {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual});
+ break;
+ case arith::CmpFPredicate::ONE:
+ // Not cmpUnordered and not cmpEqual.
+ result = checkResults(
+ {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan});
+ break;
+ case arith::CmpFPredicate::ORD:
+ // Not cmpUnordered.
+ result = checkResults({llvm::APFloat::cmpLessThan,
+ llvm::APFloat::cmpGreaterThan,
+ llvm::APFloat::cmpEqual});
+ break;
+ case arith::CmpFPredicate::UEQ:
+ result = checkResults(
+ {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual});
+ break;
+ case arith::CmpFPredicate::UGT:
+ result = checkResults(
+ {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan});
+ break;
+ case arith::CmpFPredicate::UGE:
+ result = checkResults({llvm::APFloat::cmpUnordered,
+ llvm::APFloat::cmpGreaterThan,
+ llvm::APFloat::cmpEqual});
+ break;
+ case arith::CmpFPredicate::ULT:
+ result = checkResults(
+ {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan});
+ break;
+ case arith::CmpFPredicate::ULE:
+ result = checkResults({llvm::APFloat::cmpUnordered,
+ llvm::APFloat::cmpLessThan,
+ llvm::APFloat::cmpEqual});
+ break;
+ case arith::CmpFPredicate::UNE:
+ // Not cmpEqual.
+ result = checkResults({llvm::APFloat::cmpLessThan,
+ llvm::APFloat::cmpGreaterThan,
+ llvm::APFloat::cmpUnordered});
+ break;
+ case arith::CmpFPredicate::UNO:
+ result = checkResult(llvm::APFloat::cmpUnordered);
+ break;
+ case arith::CmpFPredicate::AlwaysTrue:
+ result =
+ arith::ConstantOp::create(rewriter, loc, i1Type,
+ rewriter.getIntegerAttr(i1Type, 1))
+ .getResult();
+ break;
+ }
+ return result;
+ });
+ rewriter.replaceOp(op, repl);
return success();
}
@@ -470,9 +563,8 @@
LogicalResult matchAndRewrite(arith::NegFOp op,
PatternRewriter &rewriter) const override {
- if (op.getOperand().getType().getIntOrFloatBitWidth() > 64)
- return rewriter.notifyMatchFailure(op,
- "bitwidth > 64 bits is not supported");
+ if (failed(checkPreconditions(rewriter, op)))
+ return failure();
// Get APFloat function from runtime library.
auto i32Type = IntegerType::get(symTable->getContext(), 32);
@@ -482,28 +574,34 @@
if (failed(fn))
return fn;
- // Cast operand to 64-bit integer.
- rewriter.setInsertionPoint(op);
+ // Scalarize and convert to APFloat runtime calls.
Location loc = op.getLoc();
- auto floatTy = cast<FloatType>(op.getOperand().getType());
- auto intWType = rewriter.getIntegerType(floatTy.getWidth());
- Value operandBits = arith::ExtUIOp::create(
- rewriter, loc, i64Type, arith::BitcastOp::create(rewriter, loc, intWType, op.getOperand()));
+ rewriter.setInsertionPoint(op);
+ Value repl = forEachScalarValue(
+ rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
+ [&](Value operand1, Value operand2, Type resultType) {
+ // Cast operands to 64-bit integers.
+ auto floatTy = cast<FloatType>(operand1.getType());
+ auto intWType = rewriter.getIntegerType(floatTy.getWidth());
+ Value operandBits = arith::ExtUIOp::create(
+ rewriter, loc, i64Type,
+ arith::BitcastOp::create(rewriter, loc, intWType, operand1));
- // Call APFloat function.
- Value semValue = getSemanticsValue(rewriter, loc, floatTy);
- SmallVector<Value> params = {semValue, operandBits};
- Value negatedBits =
- func::CallOp::create(rewriter, loc, TypeRange(i64Type),
- SymbolRefAttr::get(*fn), params)
- ->getResult(0);
+ // Call APFloat function.
+ Value semValue = getSemanticsValue(rewriter, loc, floatTy);
+ SmallVector<Value> params = {semValue, operandBits};
+ Value negatedBits =
+ func::CallOp::create(rewriter, loc, TypeRange(i64Type),
+ SymbolRefAttr::get(*fn), params)
+ ->getResult(0);
- // Truncate result to the original width.
- Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
- negatedBits);
- Value result =
- arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits);
- rewriter.replaceOp(op, result);
+ // Truncate result to the original width.
+ Value truncatedBits =
+ arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
+ return arith::BitcastOp::create(rewriter, loc, floatTy,
+ truncatedBits);
+ });
+ rewriter.replaceOp(op, repl);
return success();
}
diff --git a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
index b5ec49c08..31fce7a 100644
--- a/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
+++ b/mlir/lib/Conversion/ArithToAPFloat/CMakeLists.txt
@@ -15,4 +15,5 @@
MLIRArithTransforms
MLIRFuncDialect
MLIRFuncUtils
+ MLIRVectorDialect
)
diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
index ab05ede..bd4a9da 100644
--- a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
+++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir
@@ -288,3 +288,42 @@
%8 = arith.sitofp %6 : i92 to f32
return
}
+
+// -----
+
+// CHECK-LABEL: func.func @addf_vector
+// CHECK-COUNT-2: vector.to_elements
+
+// CHECK: arith.bitcast
+// CHECK: arith.extui
+// CHECK: arith.bitcast
+// CHECK: arith.extui
+// CHECK: call
+// CHECK: arith.trunci
+
+// CHECK: arith.bitcast
+// CHECK: arith.extui
+// CHECK: arith.bitcast
+// CHECK: arith.extui
+// CHECK: call
+// CHECK: arith.trunci
+
+// CHECK: arith.bitcast
+// CHECK: arith.extui
+// CHECK: arith.bitcast
+// CHECK: arith.extui
+// CHECK: call
+// CHECK: arith.trunci
+
+// CHECK: arith.bitcast
+// CHECK: arith.extui
+// CHECK: arith.bitcast
+// CHECK: arith.extui
+// CHECK: call
+// CHECK: arith.trunci
+
+// CHECK: vector.from_elements
+func.func @addf_vector(%arg0: vector<4xf4E2M1FN>, %arg1: vector<4xf4E2M1FN>) {
+  %0 = arith.addf %arg0, %arg1 : vector<4xf4E2M1FN>
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation-vector.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation-vector.mlir
new file mode 100644
index 0000000..3b94dc4
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation-vector.mlir
@@ -0,0 +1,26 @@
+// REQUIRES: system-linux
+// TODO: Run only on Linux until we figure out how to build
+// mlir_apfloat_wrappers in a platform-independent way.
+
+// All floating-point arithmetic is lowered through APFloat.
+// RUN: mlir-opt %s --convert-arith-to-apfloat --convert-vector-to-scf \
+// RUN:     --convert-scf-to-cf --convert-to-llvm | \
+// RUN: mlir-runner -e entry --entry-point-result=void \
+// RUN:     --shared-libs=%mlir_c_runner_utils \
+// RUN:     --shared-libs=%mlir_apfloat_wrappers | FileCheck %s
+
+// Put rhs into a separate function so that it won't be constant-folded.
+func.func @foo_vec() -> (vector<4xf8E4M3FN>, vector<4xf32>) {
+  %cst1 = arith.constant dense<[2.2, 2.2, 2.2, 2.2]> : vector<4xf8E4M3FN>
+  %cst2 = arith.constant dense<[2.2, 2.2, 2.2, 2.2]> : vector<4xf32>
+  return %cst1, %cst2 : vector<4xf8E4M3FN>, vector<4xf32>
+}
+
+func.func @entry() {
+  // CHECK: ( 3.5, 3.5, 3.5, 3.5 )
+  %a1_vec = arith.constant dense<[1.4, 1.4, 1.4, 1.4]> : vector<4xf8E4M3FN>
+  %b1_vec, %b2_vec = func.call @foo_vec() : () -> (vector<4xf8E4M3FN>, vector<4xf32>)
+  %c1_vec = arith.addf %a1_vec, %b1_vec : vector<4xf8E4M3FN> // not supported by LLVM
+  vector.print %c1_vec : vector<4xf8E4M3FN>
+  return
+}