| //===- MathToAPFloat.cpp - Math to APFloat Conversion ---------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "Utils.h" |
| |
| #include "mlir/Conversion/ArithAndMathToAPFloat/MathToAPFloat.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/Dialect/Func/Utils/Utils.h" |
| #include "mlir/Dialect/Math/IR/Math.h" |
| #include "mlir/Dialect/Math/Transforms/Passes.h" |
| #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/IR/Verifier.h" |
| #include "mlir/Transforms/WalkPatternRewriteDriver.h" |
| |
| namespace mlir { |
| #define GEN_PASS_DEF_MATHTOAPFLOATCONVERSIONPASS |
| #include "mlir/Conversion/Passes.h.inc" |
| } // namespace mlir |
| |
| using namespace mlir; |
| using namespace mlir::func; |
| |
| struct AbsFOpToAPFloatConversion final : OpRewritePattern<math::AbsFOp> { |
| AbsFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, |
| PatternBenefit benefit = 1) |
| : OpRewritePattern<math::AbsFOp>(context, benefit), symTable(symTable) {} |
| |
| LogicalResult matchAndRewrite(math::AbsFOp op, |
| PatternRewriter &rewriter) const override { |
| // Cast operands to 64-bit integers. |
| auto operand = op.getOperand(); |
| auto floatTy = dyn_cast<FloatType>(operand.getType()); |
| if (!floatTy) |
| return rewriter.notifyMatchFailure(op, |
| "only scalar FloatTypes supported"); |
| if (floatTy.getIntOrFloatBitWidth() > 64) { |
| return rewriter.notifyMatchFailure(op, |
| "bitwidth > 64 bits is not supported"); |
| } |
| // Get APFloat function from runtime library. |
| auto i32Type = IntegerType::get(symTable->getContext(), 32); |
| auto i64Type = IntegerType::get(symTable->getContext(), 64); |
| FailureOr<FuncOp> fn = lookupOrCreateFnDecl( |
| rewriter, symTable, "_mlir_apfloat_abs", {i32Type, i64Type}); |
| if (failed(fn)) |
| return fn; |
| Location loc = op.getLoc(); |
| rewriter.setInsertionPoint(op); |
| auto intWType = rewriter.getIntegerType(floatTy.getWidth()); |
| Value operandBits = arith::ExtUIOp::create( |
| rewriter, loc, i64Type, |
| arith::BitcastOp::create(rewriter, loc, intWType, operand)); |
| |
| // Call APFloat function. |
| Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy); |
| SmallVector<Value> params = {semValue, operandBits}; |
| Value negatedBits = func::CallOp::create(rewriter, loc, TypeRange(i64Type), |
| SymbolRefAttr::get(*fn), params) |
| ->getResult(0); |
| |
| // Truncate result to the original width. |
| Value truncatedBits = |
| arith::TruncIOp::create(rewriter, loc, intWType, negatedBits); |
| rewriter.replaceOp( |
| op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits)); |
| return success(); |
| } |
| |
| SymbolOpInterface symTable; |
| }; |
| |
| template <typename OpTy> |
| struct IsOpToAPFloatConversion final : OpRewritePattern<OpTy> { |
| IsOpToAPFloatConversion(MLIRContext *context, const char *APFloatName, |
| SymbolOpInterface symTable, |
| PatternBenefit benefit = 1) |
| : OpRewritePattern<OpTy>(context, benefit), symTable(symTable), |
| APFloatName(APFloatName) {}; |
| |
| LogicalResult matchAndRewrite(OpTy op, |
| PatternRewriter &rewriter) const override { |
| // Cast operands to 64-bit integers. |
| auto operand = op.getOperand(); |
| auto floatTy = dyn_cast<FloatType>(operand.getType()); |
| if (!floatTy) |
| return rewriter.notifyMatchFailure(op, |
| "only scalar FloatTypes supported"); |
| if (floatTy.getIntOrFloatBitWidth() > 64) { |
| return rewriter.notifyMatchFailure(op, |
| "bitwidth > 64 bits is not supported"); |
| } |
| // Get APFloat function from runtime library. |
| auto i1 = IntegerType::get(symTable->getContext(), 1); |
| auto i32Type = IntegerType::get(symTable->getContext(), 32); |
| auto i64Type = IntegerType::get(symTable->getContext(), 64); |
| std::string funcName = |
| (llvm::Twine("_mlir_apfloat_is") + APFloatName).str(); |
| FailureOr<FuncOp> fn = lookupOrCreateFnDecl( |
| rewriter, symTable, funcName, {i32Type, i64Type}, nullptr, i1); |
| if (failed(fn)) |
| return fn; |
| Location loc = op.getLoc(); |
| rewriter.setInsertionPoint(op); |
| auto intWType = rewriter.getIntegerType(floatTy.getWidth()); |
| Value operandBits = arith::ExtUIOp::create( |
| rewriter, loc, i64Type, |
| arith::BitcastOp::create(rewriter, loc, intWType, operand)); |
| |
| // Call APFloat function. |
| Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy); |
| SmallVector<Value> params = {semValue, operandBits}; |
| rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(i1), |
| SymbolRefAttr::get(*fn), params); |
| return success(); |
| } |
| |
| SymbolOpInterface symTable; |
| const char *APFloatName; |
| }; |
| |
| struct FmaOpToAPFloatConversion final : OpRewritePattern<math::FmaOp> { |
| FmaOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, |
| PatternBenefit benefit = 1) |
| : OpRewritePattern<math::FmaOp>(context, benefit), symTable(symTable) {}; |
| |
| LogicalResult matchAndRewrite(math::FmaOp op, |
| PatternRewriter &rewriter) const override { |
| // Cast operands to 64-bit integers. |
| auto floatTy = cast<FloatType>(op.getResult().getType()); |
| if (!floatTy) |
| return rewriter.notifyMatchFailure(op, |
| "only scalar FloatTypes supported"); |
| if (floatTy.getIntOrFloatBitWidth() > 64) { |
| return rewriter.notifyMatchFailure(op, |
| "bitwidth > 64 bits is not supported"); |
| } |
| |
| auto i32Type = IntegerType::get(symTable->getContext(), 32); |
| auto i64Type = IntegerType::get(symTable->getContext(), 64); |
| FailureOr<FuncOp> fn = lookupOrCreateFnDecl( |
| rewriter, symTable, "_mlir_apfloat_fused_multiply_add", |
| {i32Type, i64Type, i64Type, i64Type}); |
| if (failed(fn)) |
| return fn; |
| Location loc = op.getLoc(); |
| rewriter.setInsertionPoint(op); |
| |
| auto intWType = rewriter.getIntegerType(floatTy.getWidth()); |
| auto int64Type = rewriter.getI64Type(); |
| Value operand = arith::ExtUIOp::create( |
| rewriter, loc, int64Type, |
| arith::BitcastOp::create(rewriter, loc, intWType, op.getA())); |
| Value multiplicand = arith::ExtUIOp::create( |
| rewriter, loc, int64Type, |
| arith::BitcastOp::create(rewriter, loc, intWType, op.getB())); |
| Value addend = arith::ExtUIOp::create( |
| rewriter, loc, int64Type, |
| arith::BitcastOp::create(rewriter, loc, intWType, op.getC())); |
| |
| // Call APFloat function. |
| Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy); |
| SmallVector<Value> params = {semValue, operand, multiplicand, addend}; |
| auto resultOp = |
| func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()), |
| SymbolRefAttr::get(*fn), params); |
| |
| // Truncate result to the original width. |
| Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType, |
| resultOp->getResult(0)); |
| rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, floatTy, truncatedBits); |
| return success(); |
| } |
| |
| SymbolOpInterface symTable; |
| }; |
| |
| namespace { |
| struct MathToAPFloatConversionPass final |
| : impl::MathToAPFloatConversionPassBase<MathToAPFloatConversionPass> { |
| using Base::Base; |
| |
| void runOnOperation() override; |
| }; |
| |
| void MathToAPFloatConversionPass::runOnOperation() { |
| MLIRContext *context = &getContext(); |
| RewritePatternSet patterns(context); |
| |
| patterns.add<AbsFOpToAPFloatConversion>(context, getOperation()); |
| patterns.add<IsOpToAPFloatConversion<math::IsFiniteOp>>(context, "finite", |
| getOperation()); |
| patterns.add<IsOpToAPFloatConversion<math::IsInfOp>>(context, "infinite", |
| getOperation()); |
| patterns.add<IsOpToAPFloatConversion<math::IsNaNOp>>(context, "nan", |
| getOperation()); |
| patterns.add<IsOpToAPFloatConversion<math::IsNormalOp>>(context, "normal", |
| getOperation()); |
| patterns.add<FmaOpToAPFloatConversion>(context, getOperation()); |
| |
| LogicalResult result = success(); |
| ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) { |
| if (diag.getSeverity() == DiagnosticSeverity::Error) { |
| result = failure(); |
| } |
| // NB: if you don't return failure, no other diag handlers will fire (see |
| // mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit). |
| return failure(); |
| }); |
| walkAndApplyPatterns(getOperation(), std::move(patterns)); |
| if (failed(result)) |
| return signalPassFailure(); |
| } |
| } // namespace |