//===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::arm_sve;
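
// 1-1 lowerings from ArmSVE dialect ops to their matching LLVM intrinsic ops.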
using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using DupQLaneLowering =
    OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
using ScalableMaskedAddIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
                                 ScalableMaskedAddIIntrOp>;
using ScalableMaskedAddFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
                                 ScalableMaskedAddFIntrOp>;
using ScalableMaskedSubIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
                                 ScalableMaskedSubIIntrOp>;
using ScalableMaskedSubFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
                                 ScalableMaskedSubFIntrOp>;
using ScalableMaskedMulIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
                                 ScalableMaskedMulIIntrOp>;
using ScalableMaskedMulFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
                                 ScalableMaskedMulFIntrOp>;
using ScalableMaskedSDivIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
                                 ScalableMaskedSDivIIntrOp>;
using ScalableMaskedUDivIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
                                 ScalableMaskedUDivIIntrOp>;
using ScalableMaskedDivFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
                                 ScalableMaskedDivFIntrOp>;

namespace {
/// Unrolls a conversion to/from equivalent vector types, to allow using a
/// conversion intrinsic that only supports 1-D vector types.
///
/// Example:
/// ```
/// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1>
/// ```
/// is rewritten into:
/// ```
/// %cst = arith.constant dense<false> : vector<2x[16]xi1>
/// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1>
/// %2 = "arm_sve.intr.convert.to.svbool"(%1)
///                : (vector<[4]xi1>) -> vector<[16]xi1>
/// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1>
/// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1>
/// %5 = "arm_sve.intr.convert.to.svbool"(%4)
///                : (vector<[4]xi1>) -> vector<[16]xi1>
/// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1>
/// ```
template <typename Op, typename IntrOp>
struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
  using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(Op convertOp, typename Op::Adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = convertOp.getLoc();

    auto source = convertOp.getSource();
    VectorType sourceType = source.getType();
    VectorType resultType = convertOp.getResult().getType();

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultType, rewriter.getZeroAttr(resultType));

    // We want to iterate over the input vector in steps of the trailing
    // dimension. So this creates a tile shape where all leading dimensions
    // are 1, and the trailing dimension step is the size of the dimension.
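    // For example, for a source type of vector<2x3x[4]xi1> the tile shape is
    // [1, 1, 4], so StaticTileOffsetRange below yields the offsets
    // [0, 0, 0], [0, 1, 0], [0, 2, 0], [1, 0, 0], and so on.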
    SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
    tileShape.back() = sourceType.getShape().back();

    // Iterate over all scalable mask/predicate slices of the source vector.
    for (SmallVector<int64_t> index :
         StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
      auto extractOrInsertPosition = ArrayRef(index).drop_back();
      auto sourceVector = rewriter.create<vector::ExtractOp>(
          loc, source, extractOrInsertPosition);
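      // The intrinsic operates on 1-D vectors: its result has the same rank
      // as the extracted slice, but with the trailing dimension of the result
      // type (e.g. [16] when converting to svbool).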
      VectorType convertedType =
          VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType()))
              .setDim(0, resultType.getShape().back());
      auto convertedVector =
          rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector);
      result = rewriter.create<vector::InsertOp>(loc, convertedVector, result,
                                                 extractOrInsertPosition);
    }

    rewriter.replaceOp(convertOp, result);
    return success();
  }
};
using ConvertToSvboolOpLowering =
    SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;

using ConvertFromSvboolOpLowering =
    SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;

using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;

/// Lower `arm_sve.psel` to LLVM intrinsics. This is almost a 1-to-1
/// conversion, but the first input (P1) and the result predicates need to be
/// converted to/from svbool.
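///
/// For example (illustrative; the predicate types here are hypothetical):
/// ```
/// %0 = arm_sve.psel %p1, %p2[%index] : vector<[4]xi1>, vector<[8]xi1>
/// ```
/// is rewritten into (roughly):
/// ```
/// %svp1 = "arm_sve.intr.convert.to.svbool"(%p1)
///                : (vector<[4]xi1>) -> vector<[16]xi1>
/// %idx = arith.index_cast %index : index to i32
/// %sel = "arm_sve.intr.psel"(%svp1, %p2, %idx)
///                : (vector<[16]xi1>, vector<[8]xi1>, i32) -> vector<[16]xi1>
/// %0 = "arm_sve.intr.convert.from.svbool"(%sel)
///                : (vector<[16]xi1>) -> vector<[4]xi1>
/// ```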
struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto svboolType = VectorType::get(16, rewriter.getI1Type(), true);
    auto loc = pselOp.getLoc();
    auto svboolP1 = rewriter.create<ConvertToSvboolIntrOp>(loc, svboolType,
                                                           adaptor.getP1());
    auto indexI32 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI32Type(), pselOp.getIndex());
    auto pselIntr = rewriter.create<PselIntrOp>(loc, svboolType, svboolP1,
                                                pselOp.getP2(), indexI32);
    rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>(
        pselOp, adaptor.getP1().getType(), pselIntr);
    return success();
  }
};

/// Converts `vector.create_mask` ops that match the size of an SVE predicate
/// to the `whilelt` intrinsic. This produces more canonical codegen than the
/// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840
/// for more details. Note that we can't use (the more general)
/// `active.lane.mask`, as its semantics don't neatly map onto
/// `vector.create_mask`: it does an unsigned comparison (whereas
/// `create_mask` is signed), and it is UB/poison if `n` is zero (whereas
/// `create_mask` simply returns an all-false mask).
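///
/// For example (illustrative):
/// ```
/// %mask = vector.create_mask %n : vector<[8]xi1>
/// ```
/// is rewritten into (roughly, with %n converted to i64):
/// ```
/// %zero = llvm.mlir.zero : i64
/// %mask = "arm_sve.intr.whilelt"(%zero, %n) : (i64, i64) -> vector<[8]xi1>
/// ```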
struct CreateMaskOpLowering
    : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CreateMaskOp createMaskOp,
                  vector::CreateMaskOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto maskType = createMaskOp.getVectorType();
    if (maskType.getRank() != 1 || !maskType.isScalable())
      return rewriter.notifyMatchFailure(createMaskOp, "not 1-D and scalable");

    // TODO: Support masks which are multiples of SVE predicates.
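    // A legal SVE predicate size is a power of two between 2 and 16, i.e.
    // vector<[2]xi1>, vector<[4]xi1>, vector<[8]xi1>, or vector<[16]xi1>,
    // which is what the check below enforces.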
    auto maskBaseSize = maskType.getDimSize(0);
    if (maskBaseSize < 2 || maskBaseSize > 16 ||
        !llvm::isPowerOf2_32(uint32_t(maskBaseSize)))
      return rewriter.notifyMatchFailure(createMaskOp,
                                         "not SVE predicate-sized");

    auto loc = createMaskOp.getLoc();
    auto zero = rewriter.create<LLVM::ZeroOp>(loc, rewriter.getI64Type());
    rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero,
                                               adaptor.getOperands()[0]);
    return success();
  }
};

} // namespace
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  // Populate conversion patterns.
  // clang-format off
  patterns.add<ConvertFromSvboolOpLowering,
               ConvertToSvboolOpLowering,
               DupQLaneLowering,
               PselOpLowering,
               ScalableMaskedAddFOpLowering,
               ScalableMaskedAddIOpLowering,
               ScalableMaskedDivFOpLowering,
               ScalableMaskedMulFOpLowering,
               ScalableMaskedMulIOpLowering,
               ScalableMaskedSDivIOpLowering,
               ScalableMaskedSubFOpLowering,
               ScalableMaskedSubIOpLowering,
               ScalableMaskedUDivIOpLowering,
               SdotOpLowering,
               SmmlaOpLowering,
               UdotOpLowering,
               UmmlaOpLowering,
               ZipX2OpLowering,
               ZipX4OpLowering>(converter);
  // Add vector.create_mask conversion with a high benefit as it produces much
  // nicer code than the generic lowering.
  patterns.add<CreateMaskOpLowering>(converter, /*benefit=*/4096);
  // clang-format on
}

void mlir::configureArmSVELegalizeForExportTarget(
    LLVMConversionTarget &target) {
  // clang-format off
  target.addLegalOp<ConvertFromSvboolIntrOp,
                    ConvertToSvboolIntrOp,
                    DupQLaneIntrOp,
                    PselIntrOp,
                    ScalableMaskedAddFIntrOp,
                    ScalableMaskedAddIIntrOp,
                    ScalableMaskedDivFIntrOp,
                    ScalableMaskedMulFIntrOp,
                    ScalableMaskedMulIIntrOp,
                    ScalableMaskedSDivIIntrOp,
                    ScalableMaskedSubFIntrOp,
                    ScalableMaskedSubIIntrOp,
                    ScalableMaskedUDivIIntrOp,
                    SdotIntrOp,
                    SmmlaIntrOp,
                    UdotIntrOp,
                    UmmlaIntrOp,
                    WhileLTIntrOp,
                    ZipX2IntrOp,
                    ZipX4IntrOp>();
  target.addIllegalOp<ConvertFromSvboolOp,
                      ConvertToSvboolOp,
                      DupQLaneOp,
                      PselOp,
                      ScalableMaskedAddFOp,
                      ScalableMaskedAddIOp,
                      ScalableMaskedDivFOp,
                      ScalableMaskedMulFOp,
                      ScalableMaskedMulIOp,
                      ScalableMaskedSDivIOp,
                      ScalableMaskedSubFOp,
                      ScalableMaskedSubIOp,
                      ScalableMaskedUDivIOp,
                      SdotOp,
                      SmmlaOp,
                      UdotOp,
                      UmmlaOp,
                      ZipX2Op,
                      ZipX4Op>();
  // clang-format on
}
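
// Minimal usage sketch (illustrative, not part of this file): a conversion
// pass would typically wire the two entry points above together as follows.
// The surrounding pass boilerplate (getContext(), getOperation(),
// signalPassFailure()) is assumed; only the two calls into this file are the
// APIs defined here.
//
//   LLVMConversionTarget target(getContext());
//   LLVMTypeConverter converter(&getContext());
//   RewritePatternSet patterns(&getContext());
//   configureArmSVELegalizeForExportTarget(target);
//   populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
//   if (failed(applyPartialConversion(getOperation(), target,
//                                     std::move(patterns))))
//     signalPassFailure();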