blob: c47701957944735db86eb69e83914c3be270f65e [file] [log] [blame]
//===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::arm_sve;
// Extract an LLVM IR type from the LLVM IR dialect type.
static Type unwrap(Type type) {
if (!type)
return nullptr;
auto *mlirContext = type.getContext();
if (!LLVM::isCompatibleType(type))
emitError(UnknownLoc::get(mlirContext),
"conversion resulted in a non-LLVM type");
return type;
}
static Optional<Type>
convertScalableVectorTypeToLLVM(ScalableVectorType svType,
LLVMTypeConverter &converter) {
auto elementType = unwrap(converter.convertType(svType.getElementType()));
if (!elementType)
return {};
auto sVectorType =
LLVM::LLVMScalableVectorType::get(elementType, svType.getShape().back());
return sVectorType;
}
template <typename OpTy>
class ForwardOperands : public OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult
matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
if (adaptor.getOperands().getTypes() == op->getOperands().getTypes())
return rewriter.notifyMatchFailure(op, "operand types already match");
rewriter.updateRootInPlace(
op, [&]() { op->setOperands(adaptor.getOperands()); });
return success();
}
};
class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
public:
using OpConversionPattern<ReturnOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
rewriter.updateRootInPlace(
op, [&]() { op->setOperands(adaptor.getOperands()); });
return success();
}
};
static Optional<Value> addUnrealizedCast(OpBuilder &builder,
ScalableVectorType svType,
ValueRange inputs, Location loc) {
if (inputs.size() != 1 ||
!inputs[0].getType().isa<LLVM::LLVMScalableVectorType>())
return Value();
return builder.create<UnrealizedConversionCastOp>(loc, svType, inputs)
.getResult(0);
}
using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using VectorScaleOpLowering =
OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>;
using ScalableMaskedAddIOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
ScalableMaskedAddIIntrOp>;
using ScalableMaskedAddFOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
ScalableMaskedAddFIntrOp>;
using ScalableMaskedSubIOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
ScalableMaskedSubIIntrOp>;
using ScalableMaskedSubFOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
ScalableMaskedSubFIntrOp>;
using ScalableMaskedMulIOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
ScalableMaskedMulIIntrOp>;
using ScalableMaskedMulFOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
ScalableMaskedMulFIntrOp>;
using ScalableMaskedSDivIOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
ScalableMaskedSDivIIntrOp>;
using ScalableMaskedUDivIOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
ScalableMaskedUDivIIntrOp>;
using ScalableMaskedDivFOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
ScalableMaskedDivFIntrOp>;
// Load operation is lowered to code that obtains a pointer to the indexed
// element and loads from it.
struct ScalableLoadOpLowering : public ConvertOpToLLVMPattern<ScalableLoadOp> {
using ConvertOpToLLVMPattern<ScalableLoadOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(ScalableLoadOp loadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto type = loadOp.getMemRefType();
if (!isConvertibleAndHasIdentityMaps(type))
return failure();
LLVMTypeConverter converter(loadOp.getContext());
auto resultType = loadOp.result().getType();
LLVM::LLVMPointerType llvmDataTypePtr;
if (resultType.isa<VectorType>()) {
llvmDataTypePtr =
LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
} else if (resultType.isa<ScalableVectorType>()) {
llvmDataTypePtr = LLVM::LLVMPointerType::get(
convertScalableVectorTypeToLLVM(resultType.cast<ScalableVectorType>(),
converter)
.getValue());
}
Value dataPtr = getStridedElementPtr(loadOp.getLoc(), type, adaptor.base(),
adaptor.index(), rewriter);
Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
loadOp.getLoc(), llvmDataTypePtr, dataPtr);
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, bitCastedPtr);
return success();
}
};
// Store operation is lowered to code that obtains a pointer to the indexed
// element, and stores the given value to it.
struct ScalableStoreOpLowering
: public ConvertOpToLLVMPattern<ScalableStoreOp> {
using ConvertOpToLLVMPattern<ScalableStoreOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(ScalableStoreOp storeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto type = storeOp.getMemRefType();
if (!isConvertibleAndHasIdentityMaps(type))
return failure();
LLVMTypeConverter converter(storeOp.getContext());
auto resultType = storeOp.value().getType();
LLVM::LLVMPointerType llvmDataTypePtr;
if (resultType.isa<VectorType>()) {
llvmDataTypePtr =
LLVM::LLVMPointerType::get(resultType.cast<VectorType>());
} else if (resultType.isa<ScalableVectorType>()) {
llvmDataTypePtr = LLVM::LLVMPointerType::get(
convertScalableVectorTypeToLLVM(resultType.cast<ScalableVectorType>(),
converter)
.getValue());
}
Value dataPtr = getStridedElementPtr(storeOp.getLoc(), type, adaptor.base(),
adaptor.index(), rewriter);
Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
storeOp.getLoc(), llvmDataTypePtr, dataPtr);
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.value(),
bitCastedPtr);
return success();
}
};
static void
populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter,
OwningRewritePatternList &patterns) {
// clang-format off
patterns.add<OneToOneConvertToLLVMPattern<ScalableAddIOp, LLVM::AddOp>,
OneToOneConvertToLLVMPattern<ScalableAddFOp, LLVM::FAddOp>,
OneToOneConvertToLLVMPattern<ScalableSubIOp, LLVM::SubOp>,
OneToOneConvertToLLVMPattern<ScalableSubFOp, LLVM::FSubOp>,
OneToOneConvertToLLVMPattern<ScalableMulIOp, LLVM::MulOp>,
OneToOneConvertToLLVMPattern<ScalableMulFOp, LLVM::FMulOp>,
OneToOneConvertToLLVMPattern<ScalableSDivIOp, LLVM::SDivOp>,
OneToOneConvertToLLVMPattern<ScalableUDivIOp, LLVM::UDivOp>,
OneToOneConvertToLLVMPattern<ScalableDivFOp, LLVM::FDivOp>
>(converter);
// clang-format on
}
static void
configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) {
// clang-format off
target.addIllegalOp<ScalableAddIOp,
ScalableAddFOp,
ScalableSubIOp,
ScalableSubFOp,
ScalableMulIOp,
ScalableMulFOp,
ScalableSDivIOp,
ScalableUDivIOp,
ScalableDivFOp>();
// clang-format on
}
static void
populateSVEMaskGenerationExportPatterns(LLVMTypeConverter &converter,
OwningRewritePatternList &patterns) {
// clang-format off
patterns.add<OneToOneConvertToLLVMPattern<ScalableCmpFOp, LLVM::FCmpOp>,
OneToOneConvertToLLVMPattern<ScalableCmpIOp, LLVM::ICmpOp>
>(converter);
// clang-format on
}
static void
configureSVEMaskGenerationLegalizations(LLVMConversionTarget &target) {
// clang-format off
target.addIllegalOp<ScalableCmpFOp,
ScalableCmpIOp>();
// clang-format on
}
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
// Populate conversion patterns
// Remove any ArmSVE-specific types from function signatures and results.
populateFuncOpTypeConversionPattern(patterns, converter);
converter.addConversion([&converter](ScalableVectorType svType) {
return convertScalableVectorTypeToLLVM(svType, converter);
});
converter.addSourceMaterialization(addUnrealizedCast);
// clang-format off
patterns.add<ForwardOperands<CallOp>,
ForwardOperands<CallIndirectOp>,
ForwardOperands<ReturnOp>>(converter,
&converter.getContext());
patterns.add<SdotOpLowering,
SmmlaOpLowering,
UdotOpLowering,
UmmlaOpLowering,
VectorScaleOpLowering,
ScalableMaskedAddIOpLowering,
ScalableMaskedAddFOpLowering,
ScalableMaskedSubIOpLowering,
ScalableMaskedSubFOpLowering,
ScalableMaskedMulIOpLowering,
ScalableMaskedMulFOpLowering,
ScalableMaskedSDivIOpLowering,
ScalableMaskedUDivIOpLowering,
ScalableMaskedDivFOpLowering>(converter);
patterns.add<ScalableLoadOpLowering,
ScalableStoreOpLowering>(converter);
// clang-format on
populateBasicSVEArithmeticExportPatterns(converter, patterns);
populateSVEMaskGenerationExportPatterns(converter, patterns);
}
void mlir::configureArmSVELegalizeForExportTarget(
LLVMConversionTarget &target) {
// clang-format off
target.addLegalOp<SdotIntrOp,
SmmlaIntrOp,
UdotIntrOp,
UmmlaIntrOp,
VectorScaleIntrOp,
ScalableMaskedAddIIntrOp,
ScalableMaskedAddFIntrOp,
ScalableMaskedSubIIntrOp,
ScalableMaskedSubFIntrOp,
ScalableMaskedMulIIntrOp,
ScalableMaskedMulFIntrOp,
ScalableMaskedSDivIIntrOp,
ScalableMaskedUDivIIntrOp,
ScalableMaskedDivFIntrOp>();
target.addIllegalOp<SdotOp,
SmmlaOp,
UdotOp,
UmmlaOp,
VectorScaleOp,
ScalableMaskedAddIOp,
ScalableMaskedAddFOp,
ScalableMaskedSubIOp,
ScalableMaskedSubFOp,
ScalableMaskedMulIOp,
ScalableMaskedMulFOp,
ScalableMaskedSDivIOp,
ScalableMaskedUDivIOp,
ScalableMaskedDivFOp,
ScalableLoadOp,
ScalableStoreOp>();
// clang-format on
auto hasScalableVectorType = [](TypeRange types) {
for (Type type : types)
if (type.isa<arm_sve::ScalableVectorType>())
return true;
return false;
};
target.addDynamicallyLegalOp<FuncOp>([hasScalableVectorType](FuncOp op) {
return !hasScalableVectorType(op.getType().getInputs()) &&
!hasScalableVectorType(op.getType().getResults());
});
target.addDynamicallyLegalOp<CallOp, CallIndirectOp, ReturnOp>(
[hasScalableVectorType](Operation *op) {
return !hasScalableVectorType(op->getOperandTypes()) &&
!hasScalableVectorType(op->getResultTypes());
});
configureBasicSVEArithmeticLegalizations(target);
configureSVEMaskGenerationLegalizations(target);
}