blob: 0f1c4011b231c427a3bfeb032608bc406890725b [file] [log] [blame]
//===- VectorMultiDimReductionTransforms.cpp - Multi-Reduction Transforms -===//
//
/// Part of the LLVM Project, under the Apache License v2.0 with LLVM
/// Exceptions. See https://llvm.org/LICENSE.txt for license information.
/// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// This file implements target-independent rewrites of MultiDimReductionOp.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#define DEBUG_TYPE "vector-multi-reduction"
using namespace mlir;
// The transformations below are implemented as composable, atomic rewrite
// patterns.
/// Converts vector.multi_reduction into inner-most/outer-most reduction form
/// by using vector.transpose.
///
/// In inner-most mode the source is transposed so that all parallel dims come
/// first (in increasing order) followed by the reduction dims. In outer-most
/// mode the reduction dims come first, followed by the parallel dims. The
/// replacement multi_reduction then reduces the trailing (resp. leading)
/// `reductionSize` dims of the transposed vector.
///
/// The pattern does not fire when the op is already in the requested form, or
/// when every dim is reduced (there is nothing to move).
class InnerOuterDimReductionConversion
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  explicit InnerOuterDimReductionConversion(
      MLIRContext *context, vector::VectorMultiReductionLowering options)
      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto src = multiReductionOp.source();
    auto loc = multiReductionOp.getLoc();
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();

    // Separate reduction and parallel dims.
    auto reductionDimsRange =
        multiReductionOp.reduction_dims().getAsValueRange<IntegerAttr>();
    auto reductionDims = llvm::to_vector<4>(llvm::map_range(
        reductionDimsRange, [](APInt a) { return a.getZExtValue(); }));
    llvm::SmallDenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
                                                  reductionDims.end());
    int64_t reductionSize = reductionDims.size();
    SmallVector<int64_t, 4> parallelDims;
    for (int64_t i = 0; i < srcRank; ++i)
      if (!reductionDimsSet.contains(i))
        parallelDims.push_back(i);

    // With no parallel dims the reduction is already both inner-most and
    // outer-most. Transposing would only rebuild an equivalent op on which
    // this pattern would fire again.
    if (parallelDims.empty())
      return failure();

    // Add a transpose only if the reduction dims are not already in place.
    // The op is already in place when the parallel dims are exactly
    // [0, #parallel) for inner-most reduction, or
    // [#reduction, #reduction + #parallel) for outer-most reduction.
    int64_t parallelBegin = useInnerDimsForReduction ? 0 : reductionSize;
    int64_t parallelEnd =
        parallelBegin + static_cast<int64_t>(parallelDims.size());
    if (parallelDims ==
        llvm::to_vector<4>(llvm::seq<int64_t>(parallelBegin, parallelEnd)))
      return failure();

    // Permutation that moves the parallel dims before (inner-most) or after
    // (outer-most) the reduction dims.
    SmallVector<int64_t, 4> indices;
    if (useInnerDimsForReduction) {
      indices.append(parallelDims.begin(), parallelDims.end());
      indices.append(reductionDims.begin(), reductionDims.end());
    } else {
      indices.append(reductionDims.begin(), reductionDims.end());
      indices.append(parallelDims.begin(), parallelDims.end());
    }
    auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);

    // After the transpose the reduction dims are the last (resp. first)
    // `reductionSize` dims.
    SmallVector<bool> reductionMask(srcRank, false);
    for (int64_t i = 0; i < reductionSize; ++i) {
      if (useInnerDimsForReduction)
        reductionMask[srcRank - i - 1] = true;
      else
        reductionMask[i] = true;
    }
    rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
        multiReductionOp, transposeOp.result(), reductionMask,
        multiReductionOp.kind());
    return success();
  }

private:
  /// True when lowering towards inner-most reductions, false for outer-most.
  const bool useInnerDimsForReduction;
};
/// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
/// dimensions are either inner most or outer most.
///
/// The pattern works in three steps:
///   1. The source is shape_cast to a vector with one flattened parallel dim
///      and/or one flattened reduction dim, ordered per the lowering option.
///   2. A new multi_reduction is built on that vector.
///   3. Its result is shape_cast back to the original parallel shape, or used
///      directly when it is a scalar.
class ReduceMultiDimReductionRank
    : public OpRewritePattern<vector::MultiDimReductionOp> {
public:
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  explicit ReduceMultiDimReductionRank(
      MLIRContext *context, vector::VectorMultiReductionLowering options)
      : mlir::OpRewritePattern<vector::MultiDimReductionOp>(context),
        useInnerDimsForReduction(
            options == vector::VectorMultiReductionLowering::InnerReduction) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    auto srcShape = multiReductionOp.getSourceVectorType().getShape();
    auto loc = multiReductionOp.getLoc();

    // If rank less than 2, nothing to do.
    if (srcRank < 2)
      return failure();

    // If already rank-2 ["parallel", "reduce"] or ["reduce", "parallel"] bail.
    SmallVector<bool> reductionMask = multiReductionOp.getReductionMask();
    if (srcRank == 2 && reductionMask.front() != reductionMask.back())
      return failure();

    // 1. Separate reduction and parallel dims.
    SmallVector<int64_t, 4> parallelDims, parallelShapes;
    SmallVector<int64_t, 4> reductionDims, reductionShapes;
    for (auto it : llvm::enumerate(reductionMask)) {
      int64_t i = it.index();
      bool isReduction = it.value();
      if (isReduction) {
        reductionDims.push_back(i);
        reductionShapes.push_back(srcShape[i]);
      } else {
        parallelDims.push_back(i);
        parallelShapes.push_back(srcShape[i]);
      }
    }

    // 2. Compute flattened parallel and reduction sizes. A value of 0 means
    // "no such dim". Accumulate in int64_t to match the shape element type;
    // an `int` product would silently truncate large shapes.
    int64_t flattenedParallelDim = 0;
    int64_t flattenedReductionDim = 0;
    if (!parallelShapes.empty()) {
      flattenedParallelDim = 1;
      for (int64_t d : parallelShapes)
        flattenedParallelDim *= d;
    }
    if (!reductionShapes.empty()) {
      flattenedReductionDim = 1;
      for (int64_t d : reductionShapes)
        flattenedReductionDim *= d;
    }
    // We must at least have some parallel or some reduction.
    assert((flattenedParallelDim || flattenedReductionDim) &&
           "expected at least one parallel or reduction dim");

    // 3. Fail if reduction/parallel dims are not contiguous.
    // Inner-most reduction: parallelDims must be exactly [0 .. size).
    int64_t counter = 0;
    if (useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();
    // Outer-most reduction: parallelDims must be exactly
    // [reductionDims.size() .. reductionDims.size() + size).
    counter = reductionDims.size();
    if (!useInnerDimsForReduction &&
        llvm::any_of(parallelDims, [&](int64_t i) { return i != counter++; }))
      return failure();

    // 4. Shape cast to collapse consecutive parallel (resp. reduction) dims
    // into a single parallel (resp. reduction) dim.
    SmallVector<bool, 2> mask;
    SmallVector<int64_t, 2> vectorShape;
    if (flattenedParallelDim) {
      mask.push_back(false);
      vectorShape.push_back(flattenedParallelDim);
    }
    if (flattenedReductionDim) {
      mask.push_back(true);
      vectorShape.push_back(flattenedReductionDim);
    }
    // Outer-most lowering puts the flattened reduction dim first.
    if (!useInnerDimsForReduction && vectorShape.size() == 2) {
      std::swap(mask.front(), mask.back());
      std::swap(vectorShape.front(), vectorShape.back());
    }
    auto castedType = VectorType::get(
        vectorShape, multiReductionOp.getSourceVectorType().getElementType());
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.source());

    // 5. Create the flattened form of vector.multi_reduction with the
    // inner/outer most dim as reduction.
    auto newOp = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, mask, multiReductionOp.kind());

    // 6. If there are no parallel shapes, the result is a scalar.
    // TODO: support 0-d vectors when available.
    if (parallelShapes.empty()) {
      rewriter.replaceOp(multiReductionOp, newOp.dest());
      return success();
    }

    // 7. Shape cast the flattened 1-D result back to the n-D parallel shape.
    VectorType outputCastedType = VectorType::get(
        parallelShapes,
        multiReductionOp.getSourceVectorType().getElementType());
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        multiReductionOp, outputCastedType, newOp.dest());
    return success();
  }

private:
  /// True when lowering towards inner-most reductions, false for outer-most.
  const bool useInnerDimsForReduction;
};
/// Unrolls vector.multi_reduction with outermost reductions
/// and combines results.
///
/// Only rank-2 ["reduce", "parallel"] ops with an int/index/float element type
/// are matched. Each row of the source (a 1-D vector over the parallel dim) is
/// extracted and folded into an accumulator with the arith op that matches
/// the combining kind.
struct TwoDimMultiReductionToElementWise
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-2 ["reduce", "parallel"] or bail.
    if (srcRank != 2)
      return failure();
    if (multiReductionOp.isReducedDim(1) || !multiReductionOp.isReducedDim(0))
      return failure();

    auto loc = multiReductionOp.getLoc();
    ArrayRef<int64_t> srcShape =
        multiReductionOp.getSourceVectorType().getShape();

    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
    if (!elementType.isIntOrIndexOrFloat())
      return failure();

    // Seed the accumulator with row 0, then combine rows 1..N-1 into it.
    Value result =
        rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), 0)
            .getResult();
    for (int64_t i = 1; i < srcShape[0]; i++) {
      auto operand =
          rewriter.create<vector::ExtractOp>(loc, multiReductionOp.source(), i);
      switch (multiReductionOp.kind()) {
      case vector::CombiningKind::ADD:
        if (elementType.isIntOrIndex())
          result = rewriter.create<arith::AddIOp>(loc, operand, result);
        else
          result = rewriter.create<arith::AddFOp>(loc, operand, result);
        break;
      case vector::CombiningKind::MUL:
        if (elementType.isIntOrIndex())
          result = rewriter.create<arith::MulIOp>(loc, operand, result);
        else
          result = rewriter.create<arith::MulFOp>(loc, operand, result);
        break;
      case vector::CombiningKind::MINUI:
        result = rewriter.create<arith::MinUIOp>(loc, operand, result);
        break;
      case vector::CombiningKind::MINSI:
        result = rewriter.create<arith::MinSIOp>(loc, operand, result);
        break;
      case vector::CombiningKind::MINF:
        result = rewriter.create<arith::MinFOp>(loc, operand, result);
        break;
      case vector::CombiningKind::MAXUI:
        result = rewriter.create<arith::MaxUIOp>(loc, operand, result);
        break;
      case vector::CombiningKind::MAXSI:
        result = rewriter.create<arith::MaxSIOp>(loc, operand, result);
        break;
      case vector::CombiningKind::MAXF:
        result = rewriter.create<arith::MaxFOp>(loc, operand, result);
        break;
      case vector::CombiningKind::AND:
        result = rewriter.create<arith::AndIOp>(loc, operand, result);
        break;
      case vector::CombiningKind::OR:
        result = rewriter.create<arith::OrIOp>(loc, operand, result);
        break;
      case vector::CombiningKind::XOR:
        result = rewriter.create<arith::XOrIOp>(loc, operand, result);
        break;
      }
    }

    rewriter.replaceOp(multiReductionOp, result);
    return success();
  }
};
/// Converts 2d vector.multi_reduction with inner most reduction dimension into
/// a sequence of vector.reduction ops.
///
/// Only rank-2 ["parallel", "reduce"] ops are matched. Each row of the source
/// is extracted and reduced to a scalar with vector.reduction. The scalar is
/// then written with vector.insertelement at the row's index of a result
/// vector that starts out as the zero attribute of the destination type.
struct TwoDimMultiReductionToReduction
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    bool isParallelThenReduce =
        multiReductionOp.getSourceVectorType().getRank() == 2 &&
        !multiReductionOp.isReducedDim(0) && multiReductionOp.isReducedDim(1);
    if (!isParallelThenReduce)
      return failure();

    Location loc = multiReductionOp.getLoc();
    auto destType = multiReductionOp.getDestType();
    Value result = rewriter.create<ConstantOp>(loc, destType,
                                               rewriter.getZeroAttr(destType));
    int numRows = multiReductionOp.getSourceVectorType().getShape()[0];
    Type scalarType = getElementTypeOrSelf(destType);

    for (int row = 0; row < numRows; ++row) {
      Value rowVector = rewriter.create<vector::ExtractOp>(
          loc, multiReductionOp.source(), ArrayRef<int64_t>{row});
      // TODO: Add vector::CombiningKind attribute instead of string to
      // vector.reduction.
      Value rowReduction = rewriter.create<vector::ReductionOp>(
          loc, scalarType,
          rewriter.getStringAttr(reductionKindName(multiReductionOp.kind())),
          rowVector, ValueRange{});
      Value position = rewriter.create<arith::ConstantIndexOp>(loc, row);
      result = rewriter.create<vector::InsertElementOp>(loc, rowReduction,
                                                        result, position);
    }

    rewriter.replaceOp(multiReductionOp, result);
    return success();
  }

private:
  /// Maps a combining kind to the kind string that vector.reduction expects.
  static const char *reductionKindName(vector::CombiningKind kind) {
    switch (kind) {
    case vector::CombiningKind::ADD:
      return "add";
    case vector::CombiningKind::MUL:
      return "mul";
    case vector::CombiningKind::MINUI:
      return "minui";
    case vector::CombiningKind::MINSI:
      return "minsi";
    case vector::CombiningKind::MINF:
      return "minf";
    case vector::CombiningKind::MAXUI:
      return "maxui";
    case vector::CombiningKind::MAXSI:
      return "maxsi";
    case vector::CombiningKind::MAXF:
      return "maxf";
    case vector::CombiningKind::AND:
      return "and";
    case vector::CombiningKind::OR:
      return "or";
    case vector::CombiningKind::XOR:
      return "xor";
    }
    llvm_unreachable("unknown combining kind");
  }
};
/// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
/// form with both a single parallel and reduction dimension.
/// This is achieved with a simple vector.shape_cast that inserts a leading 1.
/// The case with a single parallel dimension is a noop and folds away
/// separately; this pattern explicitly does not match it.
struct OneDimMultiReductionToTwoDim
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                PatternRewriter &rewriter) const override {
    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
    // Rank-1 or bail.
    if (srcRank != 1)
      return failure();
    // Only the reducing case is rewritten. A rank-1 multi_reduction whose
    // single dim is parallel has a vector result, which the {false, true} mask
    // and the scalar extract below cannot represent. Bail instead of relying
    // on the folder having removed it first.
    if (!multiReductionOp.isReducedDim(0))
      return failure();

    auto loc = multiReductionOp.getLoc();
    auto srcVectorType = multiReductionOp.getSourceVectorType();
    auto srcShape = srcVectorType.getShape();
    auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
                                      srcVectorType.getElementType());
    assert(!multiReductionOp.getDestType().isa<VectorType>() &&
           "multi_reduction with a single dimension expects a scalar result");

    // If the unique dim is reduced and we insert a parallel in front, we need a
    // {false, true} mask.
    SmallVector<bool, 2> mask{false, true};

    /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
    Value cast = rewriter.create<vector::ShapeCastOp>(
        loc, castedType, multiReductionOp.source());
    Value reduced = rewriter.create<vector::MultiDimReductionOp>(
        loc, cast, mask, multiReductionOp.kind());
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(multiReductionOp, reduced,
                                                   ArrayRef<int64_t>{0});
    return success();
  }
};
/// Populates `patterns` with the vector.multi_reduction lowering patterns.
///
/// The transpose and rank-reduction patterns are configured with `options`.
/// The final rank-2 step uses vector.reduction ops for inner-most reductions
/// and row-wise elementwise arith ops for outer-most reductions.
void mlir::vector::populateVectorMultiReductionLoweringPatterns(
    RewritePatternSet &patterns, VectorMultiReductionLowering options) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<InnerOuterDimReductionConversion, ReduceMultiDimReductionRank>(
      ctx, options);
  patterns.add<OneDimMultiReductionToTwoDim>(ctx);

  if (options != VectorMultiReductionLowering::InnerReduction) {
    patterns.add<TwoDimMultiReductionToElementWise>(ctx);
    return;
  }
  patterns.add<TwoDimMultiReductionToReduction>(ctx);
}