blob: 80619f204df7063b748bd235f972898c2ea12f38 [file] [log] [blame] [edit]
//===-- MathToNVVM.cpp - conversion from Math to CUDA libdevice calls ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/MathToNVVM/MathToNVVM.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/Pass/Pass.h"
#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTONVVM
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
#define DEBUG_TYPE "math-to-nvvm"
template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns,
PatternBenefit benefit, StringRef f32Func,
StringRef f64Func, StringRef f32ApproxFunc = "",
StringRef f16Func = "") {
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
f32ApproxFunc, f16Func,
/*i32Func=*/"", benefit);
}
template <typename OpTy>
static void populateIntOpPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns,
PatternBenefit benefit, StringRef i32Func) {
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
patterns.add<OpToFuncCallLowering<OpTy>>(converter, "", "", "", "", i32Func,
benefit);
}
template <typename OpTy>
static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns,
PatternBenefit benefit,
StringRef f32Func, StringRef f64Func) {
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, "", "",
/*i32Func=*/"", benefit);
}
/// Custom pattern for sincos since it returns two values.
///
/// The callee (`__nv_sincosf`, `__nv_fast_sincosf` or `__nv_sincos`) is
/// declared as a void function taking the input and two pointers. The pattern
/// therefore allocates two stack slots, calls the function, and loads the
/// sine and cosine back from those slots. f16/bf16 operands are extended to
/// f32 for the call and the results are truncated back to the operand type.
struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
  using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;

  /// Rewrites `op` into allocas + call + loads.
  ///
  /// Returns a match failure, without having created any IR, when the operand
  /// type is not f16/bf16/f32/f64 or when `op` has no enclosing operation with
  /// the AutomaticAllocationScope trait to host the allocas.
  LogicalResult
  matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = adaptor.getOperand();
    Type inputType = input.getType();

    // Everything that can make the pattern fail is decided from types and
    // parents alone, *before* any operation is created. Creating the fpext
    // first and bailing out afterwards would leave the pattern modifying IR
    // on a failed match.
    Type computeType = inputType;
    if (needsExtension(inputType))
      computeType = Float32Type::get(rewriter.getContext());

    StringRef sincosFunc;
    if (isa<Float32Type>(computeType)) {
      // The approximate variant is only picked when `afn` is set on the op.
      const arith::FastMathFlags flag = op.getFastmath();
      const bool useApprox =
          mlir::arith::bitEnumContainsAny(flag, arith::FastMathFlags::afn);
      sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf";
    } else if (isa<Float64Type>(computeType)) {
      sincosFunc = "__nv_sincos";
    } else {
      return rewriter.notifyMatchFailure(op,
                                         "unsupported operand type for sincos");
    }

    Operation *scope =
        op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
    if (!scope)
      return rewriter.notifyMatchFailure(
          op, "expected op to be inside automatic allocation scope");

    // From here on the rewrite emits IR.
    Value convertedInput = maybeExt(input, rewriter);

    auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
    Value sinPtr, cosPtr;
    {
      // The output slots are allocated at the start of the first block of the
      // enclosing allocation scope; the guard restores the insertion point so
      // the call and loads are emitted at the op itself.
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
      auto one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
                                          rewriter.getI32IntegerAttr(1));
      sinPtr =
          LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
      cosPtr =
          LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
    }

    createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
                     op);

    auto sinResult = LLVM::LoadOp::create(rewriter, loc, computeType, sinPtr);
    auto cosResult = LLVM::LoadOp::create(rewriter, loc, computeType, cosPtr);

    rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter),
                            maybeTrunc(cosResult, inputType, rewriter)});
    return success();
  }

private:
  /// Returns true for the operand types that are computed in f32 instead of
  /// their own width (f16 and bf16).
  static bool needsExtension(Type type) {
    return isa<Float16Type, BFloat16Type>(type);
  }

  /// Extends an f16/bf16 `operand` to f32; any other value is returned as is.
  Value maybeExt(Value operand, PatternRewriter &rewriter) const {
    if (needsExtension(operand.getType()))
      return LLVM::FPExtOp::create(rewriter, operand.getLoc(),
                                   Float32Type::get(rewriter.getContext()),
                                   operand);
    return operand;
  }

  /// Truncates `operand` to `type` when the two types differ; otherwise
  /// returns `operand` unchanged.
  Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const {
    if (operand.getType() != type)
      return LLVM::FPTruncOp::create(rewriter, operand.getLoc(), type, operand);
    return operand;
  }

  /// Emits `call void @funcName(input, sinPtr, cosPtr)`.
  ///
  /// The callee is looked up as the nearest `llvm.func` symbol visible from
  /// `op`. If none exists, a declaration is created immediately before the
  /// function that contains `op`.
  void createSincosCall(ConversionPatternRewriter &rewriter, Location loc,
                        StringRef funcName, Value input, Value sinPtr,
                        Value cosPtr, Operation *op) const {
    auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext());
    auto ptrType = sinPtr.getType();

    SmallVector<Type> operandTypes = {input.getType(), ptrType, ptrType};
    auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes);

    auto funcAttr = StringAttr::get(op->getContext(), funcName);
    auto funcOp =
        SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(op, funcAttr);

    if (!funcOp) {
      auto parentFunc = op->getParentOfType<FunctionOpInterface>();
      assert(parentFunc && "expected there to be a parent function");
      // The declaration is built with a separate builder positioned at the
      // parent function, so it lands next to it rather than inside its body.
      OpBuilder b(parentFunc);

      auto globalloc = loc->findInstanceOfOrUnknown<FileLineColLoc>();
      funcOp = LLVM::LLVMFuncOp::create(b, globalloc, funcName, funcType);
    }

    SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
    LLVM::CallOp::create(rewriter, loc, funcOp, callOperands);
  }
};
/// Adds to `patterns` the lowerings from arith/math ops to calls of the
/// `__nv_*` functions named below.
///
/// Every op listed here gets the pattern pair registered by the file-local
/// `populate*Patterns` helpers; `math::SincosOp` gets the dedicated
/// `SincosOpLowering` pattern. All patterns are created with `converter` and
/// `benefit`.
///
/// In the `populateOpPatterns` calls the string arguments are, in order, the
/// f32 callee, the f64 callee and (when present) the approximate f32 callee.
/// NOTE(review): the approximate callee is presumably chosen under the `afn`
/// fast-math flag, as SincosOpLowering does -- confirm in
/// OpToFuncCallLowering.
void mlir::populateLibDeviceConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
PatternBenefit benefit) {
// arith dialect floating-point ops.
populateOpPatterns<arith::RemFOp>(converter, patterns, benefit, "__nv_fmodf",
"__nv_fmod");
populateOpPatterns<arith::MaxNumFOp>(converter, patterns, benefit,
"__nv_fmaxf", "__nv_fmax");
populateOpPatterns<arith::MinNumFOp>(converter, patterns, benefit,
"__nv_fminf", "__nv_fmin");
// math dialect ops. AbsIOp is the only one lowered through an i32 callee.
populateIntOpPatterns<math::AbsIOp>(converter, patterns, benefit, "__nv_abs");
populateOpPatterns<math::AbsFOp>(converter, patterns, benefit, "__nv_fabsf",
"__nv_fabs");
populateOpPatterns<math::AcosOp>(converter, patterns, benefit, "__nv_acosf",
"__nv_acos");
populateOpPatterns<math::AcoshOp>(converter, patterns, benefit, "__nv_acoshf",
"__nv_acosh");
populateOpPatterns<math::AsinOp>(converter, patterns, benefit, "__nv_asinf",
"__nv_asin");
populateOpPatterns<math::AsinhOp>(converter, patterns, benefit, "__nv_asinhf",
"__nv_asinh");
populateOpPatterns<math::AtanOp>(converter, patterns, benefit, "__nv_atanf",
"__nv_atan");
populateOpPatterns<math::Atan2Op>(converter, patterns, benefit, "__nv_atan2f",
"__nv_atan2");
populateOpPatterns<math::AtanhOp>(converter, patterns, benefit, "__nv_atanhf",
"__nv_atanh");
populateOpPatterns<math::CbrtOp>(converter, patterns, benefit, "__nv_cbrtf",
"__nv_cbrt");
populateOpPatterns<math::CeilOp>(converter, patterns, benefit, "__nv_ceilf",
"__nv_ceil");
populateOpPatterns<math::CopySignOp>(converter, patterns, benefit,
"__nv_copysignf", "__nv_copysign");
populateOpPatterns<math::CosOp>(converter, patterns, benefit, "__nv_cosf",
"__nv_cos", "__nv_fast_cosf");
populateOpPatterns<math::CoshOp>(converter, patterns, benefit, "__nv_coshf",
"__nv_cosh");
populateOpPatterns<math::ErfOp>(converter, patterns, benefit, "__nv_erff",
"__nv_erf");
populateOpPatterns<math::ErfcOp>(converter, patterns, benefit, "__nv_erfcf",
"__nv_erfc");
populateOpPatterns<math::ExpOp>(converter, patterns, benefit, "__nv_expf",
"__nv_exp", "__nv_fast_expf");
populateOpPatterns<math::Exp2Op>(converter, patterns, benefit, "__nv_exp2f",
"__nv_exp2");
populateOpPatterns<math::ExpM1Op>(converter, patterns, benefit, "__nv_expm1f",
"__nv_expm1");
populateOpPatterns<math::FloorOp>(converter, patterns, benefit, "__nv_floorf",
"__nv_floor");
populateOpPatterns<math::FmaOp>(converter, patterns, benefit, "__nv_fmaf",
"__nv_fma");
// Note: libdevice uses a different name for 32-bit finite checking
populateOpPatterns<math::IsFiniteOp>(converter, patterns, benefit,
"__nv_finitef", "__nv_isfinited");
populateOpPatterns<math::IsInfOp>(converter, patterns, benefit, "__nv_isinff",
"__nv_isinfd");
populateOpPatterns<math::IsNaNOp>(converter, patterns, benefit, "__nv_isnanf",
"__nv_isnand");
populateOpPatterns<math::LogOp>(converter, patterns, benefit, "__nv_logf",
"__nv_log", "__nv_fast_logf");
populateOpPatterns<math::Log10Op>(converter, patterns, benefit, "__nv_log10f",
"__nv_log10", "__nv_fast_log10f");
populateOpPatterns<math::Log1pOp>(converter, patterns, benefit, "__nv_log1pf",
"__nv_log1p");
populateOpPatterns<math::Log2Op>(converter, patterns, benefit, "__nv_log2f",
"__nv_log2", "__nv_fast_log2f");
populateOpPatterns<math::PowFOp>(converter, patterns, benefit, "__nv_powf",
"__nv_pow", "__nv_fast_powf");
// FPowIOp goes through the dedicated float/int helper.
populateFloatIntOpPatterns<math::FPowIOp>(converter, patterns, benefit,
"__nv_powif", "__nv_powi");
populateOpPatterns<math::RoundOp>(converter, patterns, benefit, "__nv_roundf",
"__nv_round");
// RoundEvenOp is mapped to the rint callees.
populateOpPatterns<math::RoundEvenOp>(converter, patterns, benefit,
"__nv_rintf", "__nv_rint");
populateOpPatterns<math::RsqrtOp>(converter, patterns, benefit, "__nv_rsqrtf",
"__nv_rsqrt");
populateOpPatterns<math::SinOp>(converter, patterns, benefit, "__nv_sinf",
"__nv_sin", "__nv_fast_sinf");
populateOpPatterns<math::SinhOp>(converter, patterns, benefit, "__nv_sinhf",
"__nv_sinh");
populateOpPatterns<math::SqrtOp>(converter, patterns, benefit, "__nv_sqrtf",
"__nv_sqrt");
populateOpPatterns<math::TanOp>(converter, patterns, benefit, "__nv_tanf",
"__nv_tan", "__nv_fast_tanf");
populateOpPatterns<math::TanhOp>(converter, patterns, benefit, "__nv_tanhf",
"__nv_tanh");
// Custom pattern for sincos since it returns two values
patterns.add<SincosOpLowering>(converter, benefit);
}
namespace {
/// Pass wrapper around `populateLibDeviceConversionPatterns`. The base class
/// is generated from the pass definition (GEN_PASS_DEF_CONVERTMATHTONVVM).
struct ConvertMathToNVVMPass final
    : impl::ConvertMathToNVVMBase<ConvertMathToNVVMPass> {
  // Reuse the constructors of the generated base class.
  using Base::Base;

  /// Runs the partial conversion on the operation this pass is attached to.
  void runOnOperation() override;
};
} // namespace
void ConvertMathToNVVMPass::runOnOperation() {
auto m = getOperation();
MLIRContext *ctx = m.getContext();
RewritePatternSet patterns(&getContext());
LowerToLLVMOptions options(ctx, DataLayout(m));
LLVMTypeConverter converter(ctx, options);
populateLibDeviceConversionPatterns(converter, patterns, /*benefit=*/1);
ConversionTarget target(getContext());
target
.addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect,
LLVM::LLVMDialect, NVVM::NVVMDialect>();
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
LLVM::SqrtOp>();
if (failed(applyPartialConversion(m, target, std::move(patterns))))
signalPassFailure();
}