blob: 07078c77626e1348c750f03a501ed63445cf5bd0 [file] [log] [blame]
//===- ArithmeticToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/TypeUtilities.h"
using namespace mlir;
namespace {

//===----------------------------------------------------------------------===//
// Straightforward Op Lowerings
//===----------------------------------------------------------------------===//
// Every alias in this section pairs one arith op with the LLVM dialect op it
// is rewritten into; the generic VectorConvertToLLVMPattern template (from
// LLVMCommon/VectorPattern.h) carries out the rewrite itself.

// Integer arithmetic.
using AddIOpLowering = VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp>;
using SubIOpLowering = VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp>;
using MulIOpLowering = VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp>;
using DivUIOpLowering =
    VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>;
using DivSIOpLowering =
    VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
using RemUIOpLowering =
    VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>;
using RemSIOpLowering =
    VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;

// Bitwise logic and shifts.
using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>;
using ShLIOpLowering = VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp>;
using ShRUIOpLowering =
    VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>;
using ShRSIOpLowering =
    VectorConvertToLLVMPattern<arith::ShRSIOp, LLVM::AShrOp>;

// Floating-point arithmetic.
using NegFOpLowering = VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp>;
using AddFOpLowering = VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp>;
using SubFOpLowering = VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp>;
using MulFOpLowering = VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp>;
using DivFOpLowering = VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp>;
using RemFOpLowering = VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp>;

// Width changes (extension / truncation).
using ExtUIOpLowering =
    VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>;
using ExtSIOpLowering =
    VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>;
using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>;
using TruncIOpLowering =
    VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
using TruncFOpLowering =
    VectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp>;

// Integer <-> floating-point conversions.
using UIToFPOpLowering =
    VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>;
using SIToFPOpLowering =
    VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
using FPToUIOpLowering =
    VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
using FPToSIOpLowering =
    VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>;

// Bit-pattern reinterpretation.
using BitcastOpLowering =
    VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;

//===----------------------------------------------------------------------===//
// Op Lowering Patterns
//===----------------------------------------------------------------------===//
// Lowerings that need custom logic; their matchAndRewrite bodies are defined
// out of line below.

/// Rewrites arith.constant into LLVM::ConstantOp one-to-one.
struct ConstantOpLowering : ConvertOpToLLVMPattern<arith::ConstantOp> {
  using ConvertOpToLLVMPattern<arith::ConstantOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Rewrites arith.index_cast as a plain integer width change, because `index`
/// itself lowers to an integer type. Equal source/target widths drop the cast
/// entirely; a wider target sign-extends and a narrower one truncates.
struct IndexCastOpLowering : ConvertOpToLLVMPattern<arith::IndexCastOp> {
  using ConvertOpToLLVMPattern<arith::IndexCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Rewrites arith.cmpi into LLVM::ICmpOp.
struct CmpIOpLowering : ConvertOpToLLVMPattern<arith::CmpIOp> {
  using ConvertOpToLLVMPattern<arith::CmpIOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Rewrites arith.cmpf into LLVM::FCmpOp.
struct CmpFOpLowering : ConvertOpToLLVMPattern<arith::CmpFOp> {
  using ConvertOpToLLVMPattern<arith::CmpFOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

} // end anonymous namespace
//===----------------------------------------------------------------------===//
// ConstantOpLowering
//===----------------------------------------------------------------------===//
/// Lowers arith.constant by handing the op to the shared one-to-one rewrite
/// helper with LLVM::ConstantOp as the replacement op; the helper's result is
/// returned unchanged.
LogicalResult
ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  StringRef replacementName = LLVM::ConstantOp::getOperationName();
  ValueRange convertedOperands = adaptor.getOperands();
  return LLVM::detail::oneToOneRewrite(op, replacementName, convertedOperands,
                                       *getTypeConverter(), rewriter);
}
//===----------------------------------------------------------------------===//
// IndexCastOpLowering
//===----------------------------------------------------------------------===//
/// Lowers arith.index_cast to an integer width change on the converted types.
/// Equal widths forward the converted operand unchanged. A narrower target
/// emits LLVM::TruncOp, and a wider target emits LLVM::SExtOp.
///
/// Only operand/result types whose (converted) element type is an integer are
/// handled. Any other converted form is reported as a match failure instead of
/// tripping the assertion inside `cast<IntegerType>()`. Examples are a
/// non-shaped LLVM aggregate, which `getElementTypeOrSelf` returns as-is, or a
/// type the converter rejects.
LogicalResult IndexCastOpLowering::matchAndRewrite(
    arith::IndexCastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value source = adaptor.getIn();
  Type targetType = typeConverter->convertType(op.getResult().getType());

  // `convertType` yields a null Type on failure, hence dyn_cast_or_null.
  auto targetElementType =
      typeConverter->convertType(getElementTypeOrSelf(op.getResult()))
          .dyn_cast_or_null<IntegerType>();
  // getElementTypeOrSelf only looks through shaped types, so a converted
  // operand that is an LLVM aggregate comes back as-is and is not an integer.
  auto sourceElementType =
      getElementTypeOrSelf(source).dyn_cast<IntegerType>();
  if (!targetType || !targetElementType || !sourceElementType)
    return rewriter.notifyMatchFailure(
        op, "expected integer or 1-D vector of integer operand/result types");

  unsigned targetBits = targetElementType.getWidth();
  unsigned sourceBits = sourceElementType.getWidth();
  if (targetBits == sourceBits)
    rewriter.replaceOp(op, source);
  else if (targetBits < sourceBits)
    rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, source);
  else
    rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType, source);
  return success();
}
//===----------------------------------------------------------------------===//
// CmpIOpLowering
//===----------------------------------------------------------------------===//
// Maps an arith comparison predicate onto the LLVM dialect predicate enum.
// This is a plain value cast. It is only correct because the arith and LLVM
// predicate enums assign identical numeric values to matching predicates.
template <typename LLVMPredType, typename PredType>
static LLVMPredType convertCmpPredicate(PredType pred) {
  const LLVMPredType llvmPred = static_cast<LLVMPredType>(pred);
  return llvmPred;
}
/// Lowers arith.cmpi to LLVM::ICmpOp.
///
/// When the converted operand is not an LLVM array, this covers the scalar and
/// 1-D vector cases, and a single ICmpOp replaces the op. Otherwise the result
/// must be a builtin vector, and the work is delegated to
/// LLVM::detail::handleMultidimensionalVectors. The callback builds one ICmpOp
/// from the 1-D operand slices it is given. Its LogicalResult is returned
/// directly.
LogicalResult
CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  Type operandType = adaptor.getLhs().getType();
  Type resultType = op.getResult().getType();

  // Handle the scalar and 1D vector cases.
  if (!operandType.isa<LLVM::LLVMArrayType>()) {
    rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
        op, typeConverter->convertType(resultType),
        convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
        adaptor.getLhs(), adaptor.getRhs());
    return success();
  }

  if (!resultType.isa<VectorType>())
    return rewriter.notifyMatchFailure(op, "expected vector result type");

  return LLVM::detail::handleMultidimensionalVectors(
      op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
      [&](Type llvm1DVectorTy, ValueRange operands) {
        // Named distinctly so it does not shadow the pattern's `adaptor`.
        OpAdaptor sliceAdaptor(operands);
        return rewriter.create<LLVM::ICmpOp>(
            op.getLoc(), llvm1DVectorTy,
            convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
            sliceAdaptor.getLhs(), sliceAdaptor.getRhs());
      },
      rewriter);
}
//===----------------------------------------------------------------------===//
// CmpFOpLowering
//===----------------------------------------------------------------------===//
/// Lowers arith.cmpf to LLVM::FCmpOp.
///
/// When the converted operand is not an LLVM array, this covers the scalar and
/// 1-D vector cases, and a single FCmpOp replaces the op. Otherwise the result
/// must be a builtin vector, and the work is delegated to
/// LLVM::detail::handleMultidimensionalVectors. The callback builds one FCmpOp
/// from the 1-D operand slices it is given. Its LogicalResult is returned
/// directly.
LogicalResult
CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  Type operandType = adaptor.getLhs().getType();
  Type resultType = op.getResult().getType();

  // Handle the scalar and 1D vector cases.
  if (!operandType.isa<LLVM::LLVMArrayType>()) {
    rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
        op, typeConverter->convertType(resultType),
        convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
        adaptor.getLhs(), adaptor.getRhs());
    return success();
  }

  if (!resultType.isa<VectorType>())
    return rewriter.notifyMatchFailure(op, "expected vector result type");

  return LLVM::detail::handleMultidimensionalVectors(
      op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
      [&](Type llvm1DVectorTy, ValueRange operands) {
        // Named distinctly so it does not shadow the pattern's `adaptor`.
        OpAdaptor sliceAdaptor(operands);
        return rewriter.create<LLVM::FCmpOp>(
            op.getLoc(), llvm1DVectorTy,
            convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
            sliceAdaptor.getLhs(), sliceAdaptor.getRhs());
      },
      rewriter);
}
//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
namespace {
/// Function pass that lowers arith ops to the LLVM dialect. It builds an
/// LLVMTypeConverter and populates the arith-to-LLVM patterns, then runs a
/// partial conversion against an LLVMConversionTarget. If the conversion
/// fails, the pass is marked failed.
struct ConvertArithmeticToLLVMPass
    : public ConvertArithmeticToLLVMBase<ConvertArithmeticToLLVMPass> {
  ConvertArithmeticToLLVMPass() = default;

  void runOnFunction() override {
    MLIRContext &ctx = getContext();

    // Only override the index bitwidth when the pass option differs from the
    // "derive from data layout" sentinel; otherwise keep the options' default.
    LowerToLLVMOptions options(&ctx);
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    LLVMTypeConverter converter(&ctx, options);

    RewritePatternSet patterns(&ctx);
    mlir::arith::populateArithmeticToLLVMConversionPatterns(converter,
                                                            patterns);

    LLVMConversionTarget target(ctx);
    LogicalResult converted =
        applyPartialConversion(getFunction(), target, std::move(patterns));
    if (failed(converted))
      signalPassFailure();
  }
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//
/// Registers every arith-to-LLVM lowering pattern in this file into
/// `patterns`. Each pattern is constructed with `converter`. Patterns are added
/// in groups, but in the same overall order as a single combined `add` call.
void mlir::arith::populateArithmeticToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  // Constants.
  patterns.add<ConstantOpLowering>(converter);

  // Integer arithmetic.
  patterns.add<AddIOpLowering, SubIOpLowering, MulIOpLowering,
               DivUIOpLowering, DivSIOpLowering, RemUIOpLowering,
               RemSIOpLowering>(converter);

  // Bitwise logic and shifts.
  patterns.add<AndIOpLowering, OrIOpLowering, XOrIOpLowering, ShLIOpLowering,
               ShRUIOpLowering, ShRSIOpLowering>(converter);

  // Floating-point arithmetic.
  patterns.add<NegFOpLowering, AddFOpLowering, SubFOpLowering, MulFOpLowering,
               DivFOpLowering, RemFOpLowering>(converter);

  // Width changes, int/float conversions, index casts and bitcasts.
  patterns.add<ExtUIOpLowering, ExtSIOpLowering, ExtFOpLowering,
               TruncIOpLowering, TruncFOpLowering, UIToFPOpLowering,
               SIToFPOpLowering, FPToUIOpLowering, FPToSIOpLowering,
               IndexCastOpLowering, BitcastOpLowering>(converter);

  // Comparisons.
  patterns.add<CmpIOpLowering, CmpFOpLowering>(converter);
}
/// Creates a fresh instance of the arith-to-LLVM conversion pass, owned by the
/// caller.
std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToLLVMPass() {
  std::unique_ptr<Pass> pass = std::make_unique<ConvertArithmeticToLLVMPass>();
  return pass;
}