blob: 0fbd54b2375be29686d942ecffdf147f7cab7f63 [file] [log] [blame]
//===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::x86vector;
/// Returns the element type of the vector that drives intrinsic selection for
/// `op`. In the generic case this is the vector type of the op's `src`
/// operand; ops without a `src` operand are expected to specialize this
/// function.
template <typename OpTy>
static Type getSrcVectorElementType(OpTy op) {
  // The operand is cast unconditionally, i.e. it must already be a vector.
  auto srcVectorType = op.src().getType().template cast<VectorType>();
  return srcVectorType.getElementType();
}
/// Specialization for Vp2IntersectOp: the element type is read from the
/// vector type of operand `a` instead of `src`.
template <>
Type getSrcVectorElementType(Vp2IntersectOp op) {
  auto lhsVectorType = op.a().getType().cast<VectorType>();
  return lhsVectorType.getElementType();
}
namespace {
/// Base conversion for ops that can be lowered to one of two intrinsics
/// depending on the bitwidth of their "main" vector element type (as reported
/// by `getSrcVectorElementType`): `Intr32OpTy` is used for 32-bit elements
/// and `Intr64OpTy` for 64-bit elements. This relies on the to-LLVM-dialect
/// conversion helpers to correctly pack the results of multi-result intrinsic
/// ops.
template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
struct LowerToIntrinsic : public OpConversionPattern<OpTy> {
  /// Builds the pattern on top of the given LLVM type converter, using the
  /// converter's context.
  explicit LowerToIntrinsic(LLVMTypeConverter &converter)
      : OpConversionPattern<OpTy>(converter, &converter.getContext()) {}

  /// Returns the type converter downcast to the LLVM type converter this
  /// pattern was constructed with.
  LLVMTypeConverter &getTypeConverter() const {
    return *static_cast<LLVMTypeConverter *>(
        OpConversionPattern<OpTy>::getTypeConverter());
  }

  /// Replaces `op` with the 32-bit or 64-bit intrinsic, forwarding the
  /// already-converted operands. Fails to match for any other element type.
  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type elementType = getSrcVectorElementType<OpTy>(op);
    // Querying the bitwidth is only meaningful for integer and floating-point
    // types; bail out gracefully on anything else instead of asserting.
    if (!elementType.isIntOrFloat())
      return rewriter.notifyMatchFailure(
          op, "expected the main vector operand to have integer or "
              "floating-point elements");

    unsigned bitwidth = elementType.getIntOrFloatBitWidth();
    if (bitwidth == 32)
      return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(),
                                           adaptor.getOperands(),
                                           getTypeConverter(), rewriter);
    if (bitwidth == 64)
      return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(),
                                           adaptor.getOperands(),
                                           getTypeConverter(), rewriter);
    // The message is operand-agnostic: not every op handled by this pattern
    // has a 'src' operand or floating-point elements.
    return rewriter.notifyMatchFailure(
        op, "expected the main vector operand to have 32-bit or 64-bit "
            "elements");
  }
};
/// Lowers MaskCompressOp to MaskCompressIntrOp. The intrinsic always takes an
/// explicit source vector, so one is materialized when the op does not carry
/// a `src` operand.
struct MaskCompressOpConversion
    : public ConvertOpToLLVMPattern<MaskCompressOp> {
  using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(MaskCompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type vectorType = adaptor.a().getType();
    Location loc = op.getLoc();

    // Picks the source vector in priority order: the explicit `src` operand,
    // then the `constant_src` attribute, and finally an all-zero constant.
    auto materializeSource = [&]() -> Value {
      if (op.src())
        return adaptor.src();
      if (op.constant_src())
        return rewriter.create<arith::ConstantOp>(loc, vectorType,
                                                  op.constant_srcAttr());
      Attribute zeros = rewriter.getZeroAttr(vectorType);
      return rewriter.create<arith::ConstantOp>(loc, vectorType, zeros);
    };

    Value source = materializeSource();
    rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, vectorType,
                                                    adaptor.a(), source,
                                                    adaptor.k());
    return success();
  }
};
/// Lowers RsqrtOp to RsqrtIntrOp; the result has the same type as the
/// converted operand `a`.
struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
  using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(RsqrtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value input = adaptor.a();
    rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, input.getType(), input);
    return success();
  }
};
/// Lowers DotOp to DotIntrOp, supplying the extra 8-bit constant operand the
/// intrinsic takes in addition to the two vectors.
struct DotOpConversion : public ConvertOpToLLVMPattern<DotOp> {
  using ConvertOpToLLVMPattern<DotOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(DotOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value lhs = adaptor.a();
    Value rhs = adaptor.b();
    Type resultType = lhs.getType();

    // 0xff selects the dot product of all elements, broadcasted to all
    // elements.
    auto &context = getTypeConverter()->getContext();
    Type i8Type = IntegerType::get(&context, 8);
    auto maskAttr = rewriter.getI8IntegerAttr(static_cast<int8_t>(0xff));
    Value mask = rewriter.create<LLVM::ConstantOp>(op.getLoc(), i8Type,
                                                   maskAttr);

    rewriter.replaceOpWithNewOp<DotIntrOp>(op, resultType, lhs, rhs, mask);
    return success();
  }
};
/// Compile-time record tying a "main" op to the two intrinsic ops that
/// implement it: one for vectors of 32-bit elements and one for vectors of
/// 64-bit elements. The three types are exposed as member aliases.
template <typename MainOpT, typename Intr32OpT, typename Intr64OpT>
struct RegEntry {
  /// The op being lowered.
  using MainOp = MainOpT;
  /// The intrinsic used for 32-bit elements.
  using Intr32Op = Intr32OpT;
  /// The intrinsic used for 64-bit elements.
  using Intr64Op = Intr64OpT;
};
/// A container for op association entries facilitating the configuration of
/// dialect conversion.
template <typename... Args>
struct RegistryImpl {
/// Registers the patterns specializing the "main" op to one of the
/// "intrinsic" ops depending on elemental type.
static void registerPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
patterns
.add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op,
typename Args::Intr64Op>...>(converter);
}
/// Configures the conversion target to lower out "main" ops.
static void configureTarget(LLVMConversionTarget &target) {
target.addIllegalOp<typename Args::MainOp...>();
target.addLegalOp<typename Args::Intr32Op...>();
target.addLegalOp<typename Args::Intr64Op...>();
}
};
/// The concrete registry of ops lowered through `LowerToIntrinsic`. Each entry
/// lists the "main" op followed by its intrinsic for 32-bit elements and its
/// intrinsic for 64-bit elements.
using Registry = RegistryImpl<
    RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>,
    RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>,
    RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>;
} // namespace
/// Adds to `patterns` every pattern that converts X86Vector ops to their
/// LLVM intrinsic counterparts: first the registry-driven patterns, then the
/// hand-written conversions.
void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  // Ops whose intrinsic is chosen by element bitwidth.
  Registry::registerPatterns(converter, patterns);

  // Ops that need a dedicated lowering.
  patterns.add<MaskCompressOpConversion>(converter);
  patterns.add<RsqrtOpConversion>(converter);
  patterns.add<DotOpConversion>(converter);
}
/// Configures `target` so that the X86Vector "main" ops are illegal and the
/// corresponding intrinsic ops are legal.
void mlir::configureX86VectorLegalizeForExportTarget(
    LLVMConversionTarget &target) {
  // Ops handled through the registry.
  Registry::configureTarget(target);

  // Ops with dedicated conversion patterns: the intrinsic forms are allowed,
  // the high-level forms must be lowered away.
  target.addLegalOp<MaskCompressIntrOp, RsqrtIntrOp, DotIntrOp>();
  target.addIllegalOp<MaskCompressOp, RsqrtOp, DotOp>();
}