// blob: 4d67e07c12a9f4f0617f1e69e999c9eb89fe46ee [file] [log] [blame]
//===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include <type_traits>
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::vector;
namespace {
/// Test pass that applies the vector-to-vector canonicalization, bitcast
/// bubbling and leading-unit-dim cast-away patterns to a function. When the
/// `unroll` option is set, vector unrolling patterns are added as well.
struct TestVectorToVectorLowering
    : public PassWrapper<TestVectorToVectorLowering, FunctionPass> {
  TestVectorToVectorLowering() = default;
  /// The `unroll` option is re-created through its default member initializer;
  /// only the base sub-object is copy-initialized from `pass`.
  TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
      : PassWrapper(pass) {}
  StringRef getArgument() const final {
    return "test-vector-to-vector-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns between ops in the vector dialect";
  }
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect>();
  }
  Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
                      llvm::cl::init(false)};

  /// Collects the patterns under test and applies them greedily. The result of
  /// the greedy driver is intentionally ignored.
  void runOnFunction() override {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    if (unroll) {
      populateVectorUnrollPatterns(
          patterns,
          UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
              filter));
    }
    populateVectorToVectorCanonicalizationPatterns(patterns);
    populateBubbleVectorBitCastOpPatterns(patterns);
    populateCastAwayVectorLeadingOneDimPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }

private:
  /// Returns the unrolling target shape for `op`, or `llvm::None` when no
  /// target shape can be derived for it.
  static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
    // Two elements, each equal to 2, i.e. a 2x2 target shape.
    if (isa<arith::AddFOp, SelectOp, arith::CmpFOp>(op))
      return SmallVector<int64_t, 4>(2, 2);
    // Three elements, each equal to 2, i.e. a 2x2x2 target shape.
    if (isa<vector::ContractionOp>(op))
      return SmallVector<int64_t, 4>(3, 2);
    // For transfer ops, just propagate the shape coming from
    // InsertStridedSlices/ExtractStridedSlices.
    if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
      VectorType dstVec;
      for (Operation *user : readOp->getUsers()) {
        auto extract = dyn_cast<ExtractStridedSliceOp>(user);
        if (!extract)
          return llvm::None;
        auto vecType = extract.getResult().getType().cast<VectorType>();
        // All extracting users must agree on a single result type.
        if (dstVec && dstVec != vecType)
          return llvm::None;
        dstVec = vecType;
      }
      // A read without any user leaves `dstVec` null: there is no shape to
      // propagate, and querying the shape of a null type must be avoided.
      if (!dstVec)
        return llvm::None;
      ArrayRef<int64_t> shape = dstVec.getShape();
      return SmallVector<int64_t, 4>(shape.begin(), shape.end());
    }
    if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
      auto insert = writeOp.vector().getDefiningOp<InsertStridedSliceOp>();
      if (!insert)
        return llvm::None;
      ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
      return SmallVector<int64_t, 4>(shape.begin(), shape.end());
    }
    return llvm::None;
  }

  /// Restricts unrolling to the op kinds that `getShape` knows how to handle.
  static LogicalResult filter(Operation *op) {
    return success(isa<arith::AddFOp, SelectOp, arith::CmpFOp, ContractionOp,
                       TransferReadOp, TransferWriteOp>(op));
  }
};
/// Test pass for the lowerings of vector.contract. Depending on the options it
/// either runs the outer-product lowering pattern in isolation (optionally with
/// a filter), or the full set of contraction lowering patterns.
struct TestVectorContractionLowering
    : public PassWrapper<TestVectorContractionLowering, FunctionPass> {
  StringRef getArgument() const final {
    return "test-vector-contraction-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns that lower contract ops in the vector "
           "dialect";
  }
  TestVectorContractionLowering() = default;
  /// The option members are re-created through their default member
  /// initializers; only the base sub-object is copy-initialized from `pass`.
  TestVectorContractionLowering(const TestVectorContractionLowering &pass)
      : PassWrapper(pass) {}

  Option<bool> lowerToFlatMatrix{
      *this, "vector-lower-matrix-intrinsics",
      llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
      llvm::cl::init(false)};
  Option<bool> lowerToOuterProduct{
      *this, "vector-outerproduct",
      llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
      llvm::cl::init(false)};
  Option<bool> lowerToFilterOuterProduct{
      *this, "vector-filter-outerproduct",
      llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
                     "vectors of size 4."),
      llvm::cl::init(false)};

  /// Selects the pattern set according to the options (checked in the order
  /// `vector-outerproduct`, `vector-filter-outerproduct`, then the default
  /// full set) and applies it greedily.
  void runOnFunction() override {
    RewritePatternSet patterns(&getContext());

    // Test on one pattern in isolation.
    if (lowerToOuterProduct) {
      VectorContractLowering lowering = VectorContractLowering::OuterProduct;
      VectorTransformsOptions options{lowering};
      patterns.add<ContractionOpToOuterProductOpLowering>(options,
                                                          &getContext());
      (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
      return;
    }

    // Test on one pattern in isolation, with a filter attached.
    if (lowerToFilterOuterProduct) {
      VectorContractLowering lowering = VectorContractLowering::OuterProduct;
      VectorTransformsOptions options{lowering};
      patterns.add<ContractionOpToOuterProductOpLowering>(
          options, &getContext(), [](vector::ContractionOp op) {
            // Reject vector.contract ops whose rhs vector type has a leading
            // dimension of size 4; every other contraction is accepted.
            if (op.getRhsType().getShape()[0] == 4)
              return failure();
            return success();
          });
      (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
      return;
    }

    // Test on all contract lowering patterns. The contraction strategy is Dot
    // unless the matrix-intrinsics option asks for Matmul.
    VectorContractLowering contractLowering = VectorContractLowering::Dot;
    if (lowerToFlatMatrix)
      contractLowering = VectorContractLowering::Matmul;
    VectorMultiReductionLowering vectorMultiReductionLowering =
        VectorMultiReductionLowering::InnerParallel;
    VectorTransformsOptions options{contractLowering,
                                    vectorMultiReductionLowering,
                                    VectorTransposeLowering()};
    populateVectorBroadcastLoweringPatterns(patterns);
    populateVectorContractLoweringPatterns(patterns, options);
    populateVectorMaskOpLoweringPatterns(patterns);
    populateVectorShapeCastLoweringPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};
/// Test pass for the lowerings of vector.transpose. It configures a
/// LinalgVectorLoweringOptions object from the pass options and runs the
/// Linalg "lower vectors" strategy pass as a nested dynamic pipeline.
struct TestVectorTransposeLowering
    : public PassWrapper<TestVectorTransposeLowering, FunctionPass> {
  StringRef getArgument() const final {
    return "test-vector-transpose-lowering";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns that lower transpose ops in the vector "
           "dialect";
  }
  TestVectorTransposeLowering() = default;
  /// The option members are re-created through their default member
  /// initializers; only the base sub-object is copy-initialized from `pass`.
  TestVectorTransposeLowering(const TestVectorTransposeLowering &pass)
      : PassWrapper(pass) {}

  Option<bool> lowerToEltwise{
      *this, "eltwise",
      llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
      llvm::cl::init(false)};
  Option<bool> lowerToFlatTranspose{
      *this, "flat",
      llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
      llvm::cl::init(false)};
  Option<bool> lowerToShuffleTranspose{
      *this, "shuffle",
      llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
      llvm::cl::init(false)};
  Option<bool> lowerToAvx2{
      *this, "avx2",
      llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
      llvm::cl::init(false)};

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<LLVM::LLVMDialect>();
  }

  /// Builds the lowering options and runs the nested pipeline; signals pass
  /// failure if the pipeline fails.
  void runOnFunction() override {
    RewritePatternSet patterns(&getContext());
    // Enable only the transpose lowering and explicitly disable the
    // shape_cast lowering.
    LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions()
                                              .enableVectorTransposeLowering()
                                              .enableShapeCastLowering(false);
    // Each of the following flags overwrites the vector transforms options, so
    // when several are set the last one checked here wins.
    if (lowerToEltwise) {
      options = options.setVectorTransformsOptions(
          VectorTransformsOptions().setVectorTransposeLowering(
              VectorTransposeLowering::EltWise));
    }
    if (lowerToFlatTranspose) {
      options = options.setVectorTransformsOptions(
          VectorTransformsOptions().setVectorTransposeLowering(
              VectorTransposeLowering::Flat));
    }
    if (lowerToShuffleTranspose) {
      options = options.setVectorTransformsOptions(
          VectorTransformsOptions().setVectorTransposeLowering(
              VectorTransposeLowering::Shuffle));
    }
    // The AVX2 lowering is enabled in addition to whatever was selected above.
    if (lowerToAvx2) {
      options = options.enableAVX2Lowering().setAVX2LoweringOptions(
          x86vector::avx2::LoweringOptions().setTransposeOptions(
              x86vector::avx2::TransposeLoweringOptions()
                  .lower4x8xf32()
                  .lower8x8xf32()));
    }

    OpPassManager dynamicPM("builtin.func");
    dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options));
    if (failed(runPipeline(dynamicPM, getFunction())))
      return signalPassFailure();
  }
};
/// Test pass for vector unrolling. AddF/FMA ops are unrolled to a 2x2 native
/// shape; contraction ops are unrolled either to a fixed 2x2x2 shape or, with
/// `unroll-based-on-type`, to a shape chosen from the lhs element type.
struct TestVectorUnrollingPatterns
    : public PassWrapper<TestVectorUnrollingPatterns, FunctionPass> {
  StringRef getArgument() const final {
    return "test-vector-unrolling-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to unroll contract ops in the vector "
           "dialect";
  }
  TestVectorUnrollingPatterns() = default;
  /// The `unrollBasedOnType` option is re-created through its default member
  /// initializer; only the base sub-object is copy-initialized from `pass`.
  TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
      : PassWrapper(pass) {}

  /// Registers the unrolling patterns described on the struct, adds the
  /// vector-to-vector canonicalizations and applies everything greedily.
  void runOnFunction() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    populateVectorUnrollPatterns(
        patterns, UnrollVectorOptions()
                      .setNativeShape(ArrayRef<int64_t>{2, 2})
                      .setFilterConstraint([](Operation *op) {
                        return success(isa<arith::AddFOp, vector::FMAOp>(op));
                      }));

    if (unrollBasedOnType) {
      // Native shape is {4, 4, 2}, except that the last dimension becomes 4
      // when the lhs element type is a 16-bit float.
      UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
          [](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
        // The filter constraint below only lets ContractionOp through, so the
        // unconditional cast is expected to succeed.
        vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
        SmallVector<int64_t, 4> nativeShape = {4, 4, 2};
        if (auto floatType = contractOp.getLhsType()
                                 .getElementType()
                                 .dyn_cast<FloatType>()) {
          if (floatType.getWidth() == 16) {
            nativeShape[2] = 4;
          }
        }
        return nativeShape;
      };
      populateVectorUnrollPatterns(patterns,
                                   UnrollVectorOptions()
                                       .setNativeShapeFn(nativeShapeFn)
                                       .setFilterConstraint([](Operation *op) {
                                         return success(isa<ContractionOp>(op));
                                       }));
    } else {
      populateVectorUnrollPatterns(
          patterns, UnrollVectorOptions()
                        .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
                        .setFilterConstraint([](Operation *op) {
                          return success(isa<ContractionOp>(op));
                        }));
    }
    populateVectorToVectorCanonicalizationPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }

  Option<bool> unrollBasedOnType{
      *this, "unroll-based-on-type",
      llvm::cl::desc("Set the unroll factor based on type of the operation"),
      llvm::cl::init(false)};
};
/// Test pass for vector distribution. Every vector-typed arith.addf in the
/// function is distributed according to the `distribution-multiplicity` list,
/// using the leading function arguments as distribution ids, and the
/// distribution propagation patterns are then applied.
struct TestVectorDistributePatterns
    : public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
  StringRef getArgument() const final {
    return "test-vector-distribute-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to distribute vector ops in the vector "
           "dialect";
  }
  TestVectorDistributePatterns() = default;
  /// The `multiplicity` option is re-created through its default member
  /// initializer; only the base sub-object is copy-initialized from `pass`.
  TestVectorDistributePatterns(const TestVectorDistributePatterns &pass)
      : PassWrapper(pass) {}
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<VectorDialect>();
    registry.insert<AffineDialect>();
  }
  ListOption<int32_t> multiplicity{
      *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated,
      llvm::cl::desc("Set the multiplicity used for distributing vector")};

  /// Distributes the matching ops, then greedily applies the propagation
  /// patterns. The result of the greedy driver is intentionally ignored.
  void runOnFunction() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    FuncOp func = getFunction();
    func.walk([&](arith::AddFOp op) {
      auto vecType = op.getType().dyn_cast<VectorType>();
      if (!vecType)
        return;
      OpBuilder builder(op);
      SmallVector<int64_t, 2> mul;
      SmallVector<AffineExpr, 2> perm;
      SmallVector<Value, 2> ids;
      unsigned count = 0;
      // Remove the multiplicity of 1 and calculate the affine map based on
      // the multiplicity.
      SmallVector<int32_t, 4> m(multiplicity.begin(), multiplicity.end());
      for (unsigned i = 0, e = vecType.getRank(); i < e; i++) {
        // Only multiplicities greater than 1 distribute a dimension. Skipping
        // non-positive values also keeps the modulo below from dividing by
        // zero on a user-provided 0.
        if (i < m.size() && m[i] > 1 && vecType.getDimSize(i) % m[i] == 0) {
          mul.push_back(m[i]);
          // The n-th distributed dimension uses the n-th function argument as
          // its id.
          ids.push_back(func.getArgument(count++));
          perm.push_back(getAffineDimExpr(i, ctx));
        }
      }
      auto map = AffineMap::get(vecType.getRank(), 0, perm, ctx);
      Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
          builder, op.getOperation(), ids, mul, map);
      if (ops.hasValue()) {
        // Redirect all uses of the original result to the inserted value,
        // except the uses inside the freshly created extract/insert ops.
        SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
        op.getResult().replaceAllUsesExcept(ops->insert.getResult(),
                                            extractOp);
      }
    });
    populatePropagateVectorDistributionPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};
/// Test pass that breaks the first eligible 1-D vector arith.addf of a function
/// into an scf.for loop with `distribution-multiplicity` iterations, moving the
/// op's slice into the loop and distributing the op over the induction
/// variable.
struct TestVectorToLoopPatterns
    : public PassWrapper<TestVectorToLoopPatterns, FunctionPass> {
  StringRef getArgument() const final { return "test-vector-to-forloop"; }
  StringRef getDescription() const final {
    return "Test lowering patterns to break up a vector op into a for loop";
  }
  TestVectorToLoopPatterns() = default;
  /// The `multiplicity` option is re-created through its default member
  /// initializer; only the base sub-object is copy-initialized from `pass`.
  TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass)
      : PassWrapper(pass) {}
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<VectorDialect>();
    registry.insert<AffineDialect>();
  }
  Option<int32_t> multiplicity{
      *this, "distribution-multiplicity",
      llvm::cl::desc("Set the multiplicity used for distributing vector"),
      llvm::cl::init(32)};

  /// Rewrites at most one op (the walk is interrupted after the first
  /// rewrite), then greedily applies the distribution propagation patterns.
  void runOnFunction() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    FuncOp func = getFunction();
    // Read the option once, widened to the integer type used below.
    const int64_t factor = multiplicity;
    func.walk([&](arith::AddFOp op) {
      // Check that the operation type can be broken down into a loop. A
      // non-positive multiplicity cannot describe a loop and a zero value
      // would divide by zero in the modulo, so such values are skipped.
      VectorType type = op.getType().dyn_cast<VectorType>();
      if (!type || type.getRank() != 1 || factor <= 0 ||
          type.getNumElements() % factor != 0)
        return mlir::WalkResult::advance();
      // Constants, allocations and calls are filtered out of the slice.
      auto filterAlloc = [](Operation *candidate) {
        if (isa<arith::ConstantOp, memref::AllocOp, CallOp>(candidate))
          return false;
        return true;
      };
      auto dependentOps = getSlice(op, filterAlloc);
      // Create a loop and move instructions from the Op slice into the loop.
      OpBuilder builder(op);
      auto zero = builder.create<arith::ConstantIndexOp>(op.getLoc(), 0);
      auto one = builder.create<arith::ConstantIndexOp>(op.getLoc(), 1);
      auto numIter =
          builder.create<arith::ConstantIndexOp>(op.getLoc(), factor);
      auto forOp = builder.create<scf::ForOp>(op.getLoc(), zero, numIter, one);
      for (Operation *it : dependentOps) {
        it->moveBefore(forOp.getBody()->getTerminator());
      }
      auto map = AffineMap::getMultiDimIdentityMap(1, ctx);
      // Break up the original op and let the patterns propagate.
      Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
          builder, op.getOperation(), {forOp.getInductionVar()}, {factor},
          map);
      if (ops.hasValue()) {
        // Redirect all uses of the original result to the inserted value,
        // except the uses inside the freshly created extract/insert ops.
        SmallPtrSet<Operation *, 1> extractOp({ops->extract, ops->insert});
        op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
      }
      return mlir::WalkResult::interrupt();
    });
    populatePropagateVectorDistributionPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};
/// Test pass that unrolls vector transfer read/write ops to a 2x2 native shape
/// and then applies the vector-to-vector canonicalization patterns.
struct TestVectorTransferUnrollingPatterns
    : public PassWrapper<TestVectorTransferUnrollingPatterns, FunctionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect>();
  }
  StringRef getArgument() const final {
    return "test-vector-transfer-unrolling-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to unroll transfer ops in the vector "
           "dialect";
  }

  /// Builds the unrolling options step by step, registers the patterns and
  /// runs the greedy driver, whose result is intentionally ignored.
  void runOnFunction() override {
    RewritePatternSet unrollPatterns(&getContext());

    // Only transfer reads and writes are candidates for unrolling.
    auto isTransferOp = [](Operation *candidate) {
      return success(
          isa<vector::TransferReadOp, vector::TransferWriteOp>(candidate));
    };

    UnrollVectorOptions unrollOptions;
    unrollOptions.setNativeShape(ArrayRef<int64_t>{2, 2});
    unrollOptions.setFilterConstraint(isTransferOp);
    populateVectorUnrollPatterns(unrollPatterns, unrollOptions);

    populateVectorToVectorCanonicalizationPatterns(unrollPatterns);
    (void)applyPatternsAndFoldGreedily(getFunction(),
                                       std::move(unrollPatterns));
  }
};
/// Test pass for the full/partial split rewrite of vector transfer ops. The
/// `use-linalg-copy` option chooses between the LinalgCopy and VectorTransfer
/// split strategies.
struct TestVectorTransferFullPartialSplitPatterns
    : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
                         FunctionPass> {
  StringRef getArgument() const final {
    return "test-vector-transfer-full-partial-split";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to split "
           "transfer ops via scf.if + linalg ops";
  }
  TestVectorTransferFullPartialSplitPatterns() = default;
  /// The `useLinalgOps` option is re-created through its default member
  /// initializer; only the base sub-object is copy-initialized from `pass`.
  TestVectorTransferFullPartialSplitPatterns(
      const TestVectorTransferFullPartialSplitPatterns &pass)
      : PassWrapper(pass) {}
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
                    scf::SCFDialect>();
  }
  Option<bool> useLinalgOps{
      *this, "use-linalg-copy",
      llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
                     "linalg.copy operations."),
      llvm::cl::init(false)};

  /// Registers the split rewriter with the selected strategy and applies it
  /// greedily; the result of the greedy driver is intentionally ignored.
  void runOnFunction() override {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    VectorTransformsOptions options;
    options.setVectorTransferSplit(useLinalgOps
                                       ? VectorTransferSplit::LinalgCopy
                                       : VectorTransferSplit::VectorTransfer);
    patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};
/// Test pass that runs the vector transfer op flow optimization on a function.
struct TestVectorTransferOpt
    : public PassWrapper<TestVectorTransferOpt, FunctionPass> {
  StringRef getArgument() const final { return "test-vector-transferop-opt"; }
  StringRef getDescription() const final {
    return "Test optimization transformations for transfer ops";
  }

  /// Hands the current function to `transferOpflowOpt`.
  void runOnFunction() override {
    FuncOp currentFunc = getFunction();
    transferOpflowOpt(currentFunc);
  }
};
/// Test pass that applies the vector transfer lowering patterns together with
/// the transfer permutation-map lowering patterns.
struct TestVectorTransferLoweringPatterns
    : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
  }
  StringRef getArgument() const final {
    return "test-vector-transfer-lowering-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to lower transfer ops to other vector ops";
  }

  /// Gathers both pattern groups into one set and runs the greedy driver,
  /// whose result is intentionally ignored.
  void runOnFunction() override {
    MLIRContext *context = &getContext();
    RewritePatternSet loweringPatterns(context);
    populateVectorTransferLoweringPatterns(loweringPatterns);
    populateVectorTransferPermutationMapLoweringPatterns(loweringPatterns);

    FuncOp target = getFunction();
    (void)applyPatternsAndFoldGreedily(target, std::move(loweringPatterns));
  }
};
/// Test pass for the lowering of vector.multi_reduction. The
/// `use-outer-reductions` option selects the InnerParallel strategy; otherwise
/// the InnerReduction strategy is used.
struct TestVectorMultiReductionLoweringPatterns
    : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
                         FunctionPass> {
  TestVectorMultiReductionLoweringPatterns() = default;
  /// The `useOuterReductions` option is re-created through its default member
  /// initializer; only the base sub-object is copy-initialized from `pass`.
  TestVectorMultiReductionLoweringPatterns(
      const TestVectorMultiReductionLoweringPatterns &pass)
      : PassWrapper(pass) {}
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  StringRef getArgument() const final {
    return "test-vector-multi-reduction-lowering-patterns";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns to lower vector.multi_reduction to other "
           "vector ops";
  }
  Option<bool> useOuterReductions{
      *this, "use-outer-reductions",
      llvm::cl::desc("Move reductions to outer most dimensions"),
      llvm::cl::init(false)};

  /// Registers the multi-reduction lowering patterns with the selected
  /// strategy and applies them greedily; the driver result is ignored.
  void runOnFunction() override {
    RewritePatternSet patterns(&getContext());
    populateVectorMultiReductionLoweringPatterns(
        patterns, useOuterReductions
                      ? vector::VectorMultiReductionLowering::InnerParallel
                      : vector::VectorMultiReductionLowering::InnerReduction);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};
/// Test pass that applies the patterns collapsing the inner-most contiguous
/// dimensions of vector transfer ops.
struct TestVectorTransferCollapseInnerMostContiguousDims
    : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
                         FunctionPass> {
  TestVectorTransferCollapseInnerMostContiguousDims() = default;
  /// Only the base sub-object is copy-initialized from `pass`; the pass has no
  /// other state of its own.
  TestVectorTransferCollapseInnerMostContiguousDims(
      const TestVectorTransferCollapseInnerMostContiguousDims &pass)
      : PassWrapper(pass) {}
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, AffineDialect>();
  }
  StringRef getArgument() const final {
    return "test-vector-transfer-collapse-inner-most-dims";
  }
  StringRef getDescription() const final {
    return "Test lowering patterns that reduce the rank of the vector "
           "transfer memory and vector operands.";
  }

  /// Registers the collapse patterns and applies them greedily; the result of
  /// the greedy driver is intentionally ignored.
  void runOnFunction() override {
    RewritePatternSet patterns(&getContext());
    populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};
/// Test pass that applies the patterns turning vector reductions into
/// vector.contract ops.
struct TestVectorReduceToContractPatternsPatterns
    : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
                         FunctionPass> {
  StringRef getArgument() const final {
    return "test-vector-reduction-to-contract-patterns";
  }
  StringRef getDescription() const final {
    return "Test patterns to convert multireduce op to contract and combine "
           "broadcast/transpose to contract";
  }

  /// Populates the reduction-to-contract patterns and runs the greedy driver
  /// on the current function, ignoring its result.
  void runOnFunction() override {
    MLIRContext *context = &getContext();
    RewritePatternSet contractPatterns(context);
    populateVectorReductionToContractPatterns(contractPatterns);

    FuncOp target = getFunction();
    (void)applyPatternsAndFoldGreedily(target, std::move(contractPatterns));
  }
};
} // end anonymous namespace
namespace mlir {
namespace test {
/// Registers every vector test pass defined in this file. Each statement
/// constructs a temporary PassRegistration object for one pass type.
void registerTestVectorLowerings() {
// Lowering passes.
PassRegistration<TestVectorToVectorLowering>();
PassRegistration<TestVectorContractionLowering>();
PassRegistration<TestVectorTransposeLowering>();
// Unrolling passes.
PassRegistration<TestVectorUnrollingPatterns>();
PassRegistration<TestVectorTransferUnrollingPatterns>();
// Transfer split, distribution and loop-breaking passes.
PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
PassRegistration<TestVectorDistributePatterns>();
PassRegistration<TestVectorToLoopPatterns>();
// Transfer optimization and lowering passes.
PassRegistration<TestVectorTransferOpt>();
PassRegistration<TestVectorTransferLoweringPatterns>();
// Reduction-related and transfer-collapse passes.
PassRegistration<TestVectorMultiReductionLoweringPatterns>();
PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
PassRegistration<TestVectorReduceToContractPatternsPatterns>();
}
} // namespace test
} // namespace mlir