blob: c861935b4bc187b6d01409c2652640f8411ea0d2 [file] [log] [blame]
//===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert Vector dialect to SPIRV dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include <cassert>
#include <cstdint>
#include <numeric>
using namespace mlir;
/// Returns the zero-extended value of the first element of `attr`. The array
/// is read as a range of integer attributes and its first element is
/// dereferenced unconditionally, so `attr` must not be empty.
static uint64_t getFirstIntValue(ArrayAttr attr) {
  auto intValues = attr.getAsValueRange<IntegerAttr>();
  auto firstValue = *intValues.begin();
  return firstValue.getZExtValue();
}
/// Computes the total bit count of `type`: the int/float bitwidth for a
/// non-vector type, or element count times element bitwidth for a vector.
static int getNumBits(Type type) {
  auto vecTy = dyn_cast<VectorType>(type);
  if (!vecTy)
    return type.getIntOrFloatBitWidth();

  // TODO: This does not take into account any memory layout or widening
  // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even
  // though in practice it will likely be stored as in a 4xi64 vector register.
  return vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
}
namespace {
/// Lowers `vector.shape_cast` by forwarding the already-converted source when
/// the cast is a no-op after type conversion; all other cases are rejected.
struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type convertedResultTy =
        getTypeConverter()->convertType(shapeCastOp.getType());
    if (!convertedResultTy)
      return failure();

    Value convertedSrc = adaptor.getSource();
    // The cast can be dropped when the converted types already agree, or when
    // the result holds exactly one element.
    const bool typesAgree = convertedResultTy == convertedSrc.getType();
    const bool singleElement =
        shapeCastOp.getResultVectorType().getNumElements() == 1;

    // Lowering for size-n vectors when n > 1 hasn't been implemented.
    if (!typesAgree && !singleElement)
      return failure();

    rewriter.replaceOp(shapeCastOp, convertedSrc);
    return success();
  }
};
// Rewrites `vector.splat` as `vector.broadcast` with the same result type.
// Other patterns then take `vector.broadcast` the rest of the way to SPIR-V.
struct VectorSplatToBroadcast final
    : public OpConversionPattern<vector::SplatOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto splatResultTy = splat.getType();
    Value splatInput = adaptor.getInput();
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splatResultTy,
                                                     splatInput);
    return success();
  }
};
/// Lowers `vector.bitcast` to `spirv.Bitcast`, or forwards the source when the
/// converted source and result types are identical.
struct VectorBitcastConvert final
    : public OpConversionPattern<vector::BitCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstType = getTypeConverter()->convertType(bitcastOp.getType());
    if (!dstType)
      return failure();

    Value convertedSrc = adaptor.getSource();
    Type srcType = convertedSrc.getType();

    // Nothing to cast once both sides converted to the same type.
    if (dstType == srcType) {
      rewriter.replaceOp(bitcastOp, convertedSrc);
      return success();
    }

    // Check that the source and destination type have the same bitwidth.
    // Depending on the target environment, we may need to emulate certain
    // types, which can cause issue with bitcast.
    const bool sameBitwidth = getNumBits(dstType) == getNumBits(srcType);
    if (!sameBitwidth) {
      return rewriter.notifyMatchFailure(
          bitcastOp,
          llvm::formatv("different source ({0}) and target ({1}) bitwidth",
                        srcType, dstType));
    }

    rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType,
                                                  convertedSrc);
    return success();
  }
};
/// Lowers `vector.broadcast` to `spirv.CompositeConstruct` with the converted
/// source repeated once per result element. If the result converts to a
/// SPIR-V scalar, the converted source is forwarded instead.
struct VectorBroadcastConvert final
    : public OpConversionPattern<vector::BroadcastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType origResultTy = castOp.getResultVectorType();
    Type convertedResultTy = getTypeConverter()->convertType(origResultTy);
    if (!convertedResultTy)
      return failure();

    Value convertedSrc = adaptor.getSource();
    if (isa<spirv::ScalarType>(convertedResultTy)) {
      rewriter.replaceOp(castOp, convertedSrc);
      return success();
    }

    // One copy of the source per element of the original result vector.
    SmallVector<Value, 4> constituents(origResultTy.getNumElements(),
                                       convertedSrc);
    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
        castOp, convertedResultTy, constituents);
    return success();
  }
};
// The Vector dialect allows a dynamic position to be a poison index, but
// SPIR-V has no such notion for certain instructions: the index would simply
// be out-of-bounds, which is a UB hazard. This helper emits IR that maps the
// incoming index to some arbitrary in-bounds one:
//  * power-of-two `vectorSize`: `index & (vectorSize - 1)` (presumably the
//    cheaper option);
//  * otherwise: `index == kPoisonIndex ? 0 : index` (0 is always in-bounds).
static Value sanitizeDynamicIndex(ConversionPatternRewriter &rewriter,
                                  Location loc, Value dynamicIndex,
                                  int64_t kPoisonIndex, unsigned vectorSize) {
  Type indexTy = dynamicIndex.getType();

  if (!llvm::isPowerOf2_32(vectorSize)) {
    IntegerAttr poisonAttr = rewriter.getIntegerAttr(indexTy, kPoisonIndex);
    Value poisonConst =
        spirv::ConstantOp::create(rewriter, loc, indexTy, poisonAttr);
    Value isPoison =
        spirv::IEqualOp::create(rewriter, loc, dynamicIndex, poisonConst);
    Value zeroIndex = spirv::ConstantOp::getZero(indexTy, loc, rewriter);
    return spirv::SelectOp::create(rewriter, loc, isPoison, zeroIndex,
                                   dynamicIndex);
  }

  IntegerAttr maskAttr = rewriter.getIntegerAttr(indexTy, vectorSize - 1);
  Value maskConst = spirv::ConstantOp::create(rewriter, loc, indexTy, maskAttr);
  return spirv::BitwiseAndOp::create(rewriter, loc, dynamicIndex, maskConst);
}
/// Lowers `vector.extract` to `spirv.CompositeExtract` (constant position) or
/// `spirv.VectorExtractDynamic` (dynamic position). Only the first position
/// is consulted.
struct VectorExtractOpConvert final
    : public OpConversionPattern<vector::ExtractOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstType = getTypeConverter()->convertType(extractOp.getType());
    if (!dstType)
      return failure();

    Value convertedVec = adaptor.getVector();
    // The source was already scalarized by type conversion: forward it.
    if (isa<spirv::ScalarType>(convertedVec.getType())) {
      rewriter.replaceOp(extractOp, convertedVec);
      return success();
    }

    std::optional<int64_t> staticIdx =
        getConstantIntValue(extractOp.getMixedPosition()[0]);

    if (!staticIdx) {
      // Dynamic position: make the index in-bounds before extracting.
      Value safeIndex = sanitizeDynamicIndex(
          rewriter, extractOp.getLoc(), adaptor.getDynamicPosition()[0],
          vector::ExtractOp::kPoisonIndex,
          extractOp.getSourceVectorType().getNumElements());
      rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
          extractOp, dstType, convertedVec, safeIndex);
      return success();
    }

    if (*staticIdx == vector::ExtractOp::kPoisonIndex)
      return rewriter.notifyMatchFailure(
          extractOp,
          "Static use of poison index handled elsewhere (folded to poison)");

    rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
        extractOp, dstType, convertedVec,
        rewriter.getI32ArrayAttr(staticIdx.value()));
    return success();
  }
};
/// Lowers unit-stride `vector.extract_strided_slice`. A slice that converts to
/// a SPIR-V scalar becomes `spirv.CompositeExtract`; any other slice becomes a
/// `spirv.VectorShuffle` selecting `size` consecutive elements from `offset`.
struct VectorExtractStridedSliceOpConvert final
    : public OpConversionPattern<vector::ExtractStridedSliceOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type dstType = getTypeConverter()->convertType(extractOp.getType());
    if (!dstType)
      return failure();

    // Only unit strides are supported.
    if (getFirstIntValue(extractOp.getStrides()) != 1)
      return failure();

    uint64_t offset = getFirstIntValue(extractOp.getOffsets());
    Value convertedSrc = adaptor.getOperands().front();

    // Extract vector<1xT> case.
    if (isa<spirv::ScalarType>(dstType)) {
      rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
          extractOp, convertedSrc, offset);
      return success();
    }

    // Shuffle mask: offset, offset + 1, ..., offset + size - 1.
    uint64_t size = getFirstIntValue(extractOp.getSizes());
    SmallVector<int32_t, 2> shuffleMask(size);
    std::iota(shuffleMask.begin(), shuffleMask.end(), offset);

    rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
        extractOp, dstType, convertedSrc, convertedSrc,
        rewriter.getI32ArrayAttr(shuffleMask));
    return success();
  }
};
/// Lowers `vector.fma` to the SPIR-V FMA op given by `SPIRVFMAOp`, passing the
/// converted lhs, rhs and accumulator operands through unchanged.
template <class SPIRVFMAOp>
struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type convertedTy = getTypeConverter()->convertType(fmaOp.getType());
    if (!convertedTy)
      return failure();

    Value multiplicand = adaptor.getLhs();
    Value multiplier = adaptor.getRhs();
    Value addend = adaptor.getAcc();
    rewriter.replaceOpWithNewOp<SPIRVFMAOp>(fmaOp, convertedTy, multiplicand,
                                            multiplier, addend);
    return success();
  }
};
/// Lowers `vector.from_elements` to `spirv.CompositeConstruct`, or forwards
/// the first element when the result converts to a SPIR-V scalar.
struct VectorFromElementsOpConvert final
    : public OpConversionPattern<vector::FromElementsOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type convertedTy = getTypeConverter()->convertType(op.getType());
    if (!convertedTy)
      return failure();

    ValueRange convertedElems = adaptor.getElements();

    if (!isa<spirv::ScalarType>(convertedTy)) {
      // SPIRVTypeConverter rejects vectors with rank > 1, so multi-dimensional
      // vector.from_elements cases should not need to be handled, only 1d.
      assert(cast<VectorType>(convertedTy).getRank() == 1);
      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, convertedTy,
                                                               convertedElems);
      return success();
    }

    // Single scalar operand / single-element result: the scalar itself is the
    // replacement.
    rewriter.replaceOp(op, convertedElems[0]);
    return success();
  }
};
/// Lowers `vector.insert` of a non-vector value to `spirv.CompositeInsert`
/// (constant position) or `spirv.VectorInsertDynamic` (dynamic position).
/// Only the first position is consulted. Inserting a vector value, or a
/// destination type that fails conversion, is reported as a match failure.
struct VectorInsertOpConvert final
    : public OpConversionPattern<vector::InsertOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (isa<VectorType>(insertOp.getValueToStoreType()))
      return rewriter.notifyMatchFailure(insertOp, "unsupported vector source");
    if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
      return rewriter.notifyMatchFailure(insertOp,
                                         "unsupported dest vector type");

    // Special case for inserting scalar values into size-1 vectors.
    if (insertOp.getValueToStoreType().isIntOrFloat() &&
        insertOp.getDestVectorType().getNumElements() == 1) {
      rewriter.replaceOp(insertOp, adaptor.getValueToStore());
      return success();
    }

    if (std::optional<int64_t> id =
            getConstantIntValue(insertOp.getMixedPosition()[0])) {
      if (id == vector::InsertOp::kPoisonIndex)
        return rewriter.notifyMatchFailure(
            insertOp,
            "Static use of poison index handled elsewhere (folded to poison)");
      rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
          insertOp, adaptor.getValueToStore(), adaptor.getDest(), id.value());
    } else {
      // Dynamic position: make the index in-bounds before inserting.
      Value sanitizedIndex = sanitizeDynamicIndex(
          rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
          vector::InsertOp::kPoisonIndex,
          insertOp.getDestVectorType().getNumElements());
      // Use the adaptor's (converted) destination, matching the static path;
      // the original operand may still carry an unconverted type.
      rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
          insertOp, adaptor.getDest(), adaptor.getValueToStore(),
          sanitizedIndex);
    }
    return success();
  }
};
/// Lowers unit-stride `vector.insert_strided_slice`. A source that converted
/// to a SPIR-V scalar is inserted with `spirv.CompositeInsert`; a vector
/// source is merged into the destination with `spirv.VectorShuffle`.
struct VectorInsertStridedSliceOpConvert final
    : public OpConversionPattern<vector::InsertStridedSliceOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only unit strides are supported.
    if (getFirstIntValue(insertOp.getStrides()) != 1)
      return failure();

    // The first converted operand is the inserted value, the last one the
    // destination.
    Value inserted = adaptor.getOperands().front();
    Value destination = adaptor.getOperands().back();
    Type destinationTy = destination.getType();
    uint64_t offset = getFirstIntValue(insertOp.getOffsets());

    if (isa<spirv::ScalarType>(inserted.getType())) {
      assert(!isa<spirv::ScalarType>(dstVector.getType()));
      rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
          insertOp, destinationTy, inserted, destination,
          rewriter.getI32ArrayAttr(offset));
      return success();
    }

    uint64_t totalSize = cast<VectorType>(destinationTy).getNumElements();
    uint64_t insertSize =
        cast<VectorType>(inserted.getType()).getNumElements();

    // Start from the identity mask over the destination (shuffle operand 0),
    // then redirect the slice [offset, offset + insertSize) to the inserted
    // vector, whose elements are numbered from `totalSize` as operand 1.
    SmallVector<int32_t, 2> shuffleMask(totalSize);
    std::iota(shuffleMask.begin(), shuffleMask.end(), 0);
    auto sliceBegin = shuffleMask.begin() + offset;
    std::iota(sliceBegin, sliceBegin + insertSize, totalSize);

    rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
        insertOp, destinationTy, destination, inserted,
        rewriter.getI32ArrayAttr(shuffleMask));
    return success();
  }
};
/// Emits one `spirv.CompositeExtract` per element of the converted reduction
/// source (dimension 0 of `srcVectorType`) and returns the extracted values in
/// element order, followed by the accumulator when the reduction has one.
static SmallVector<Value> extractAllElements(
    vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
    VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
  Location loc = reduceOp.getLoc();
  Value accumulator = adaptor.getAcc();
  int elementCount = static_cast<int>(srcVectorType.getDimSize(0));

  SmallVector<Value> pieces;
  pieces.reserve(elementCount + (accumulator ? 1 : 0));

  for (int idx = 0; idx < elementCount; ++idx) {
    Value element = spirv::CompositeExtractOp::create(
        rewriter, loc, srcVectorType.getElementType(), adaptor.getVector(),
        rewriter.getI32ArrayAttr({idx}));
    pieces.push_back(element);
  }

  if (accumulator)
    pieces.push_back(accumulator);
  return pieces;
}
/// Data shared by the `vector.reduction` lowering patterns, as produced by
/// `getReductionInfo`. Field order matters: the struct is aggregate-initialized
/// and unpacked with structured bindings.
struct ReductionRewriteInfo {
  /// Converted type of the reduction result.
  Type resultType;
  /// Scalars extracted from the source vector, followed by the accumulator if
  /// the reduction has one.
  SmallVector<Value> extractedElements;
};
/// Converts the result type of `op` and extracts every element of its
/// converted source vector. Fails when the result type cannot be converted or
/// when the converted source is not a rank-1 vector.
static FailureOr<ReductionRewriteInfo>
getReductionInfo(vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
                 ConversionPatternRewriter &rewriter,
                 const TypeConverter &typeConverter) {
  Type convertedResultTy = typeConverter.convertType(op.getType());
  if (!convertedResultTy)
    return failure();

  auto convertedSrcTy = dyn_cast<VectorType>(adaptor.getVector().getType());
  const bool isRank1Vector = convertedSrcTy && convertedSrcTy.getRank() == 1;
  if (!isRank1Vector)
    return rewriter.notifyMatchFailure(op, "not a 1-D vector source");

  SmallVector<Value> pieces =
      extractAllElements(op, adaptor, convertedSrcTy, rewriter);
  return ReductionRewriteInfo{convertedResultTy, std::move(pieces)};
}
/// Lowers `vector.reduction` for ADD, MUL and the integer min/max kinds by
/// extracting every source element and folding them left-to-right with the
/// matching SPIR-V op. The unsigned/signed min/max ops are supplied as
/// template parameters. AND/OR/XOR are reported as unimplemented, and every
/// other kind is left to other patterns.
template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
          typename SPIRVSMinOp>
struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto reductionInfo =
        getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
    if (failed(reductionInfo))
      return failure();

    // Bind by reference: a by-value binding would deep-copy the vector of
    // extracted elements on every match.
    auto &[resultType, extractedElements] = *reductionInfo;
    Location loc = reduceOp->getLoc();
    // The combining kind does not change across iterations.
    const vector::CombiningKind kind = reduceOp.getKind();

    // Fold the remaining values (including the accumulator, if any) into the
    // first extracted element.
    Value result = extractedElements.front();
    for (Value next : llvm::drop_begin(extractedElements)) {
      switch (kind) {

// Picks the integer or the floating-point SPIR-V op based on the result type.
#define INT_AND_FLOAT_CASE(kind, iop, fop)                                     \
  case vector::CombiningKind::kind:                                            \
    if (llvm::isa<IntegerType>(resultType)) {                                  \
      result = spirv::iop::create(rewriter, loc, resultType, result, next);    \
    } else {                                                                   \
      assert(llvm::isa<FloatType>(resultType));                                \
      result = spirv::fop::create(rewriter, loc, resultType, result, next);    \
    }                                                                          \
    break

// Uses the single op `fop` for this kind regardless of the result type.
#define INT_OR_FLOAT_CASE(kind, fop)                                           \
  case vector::CombiningKind::kind:                                            \
    result = fop::create(rewriter, loc, resultType, result, next);             \
    break

        INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp);
        INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp);
        INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp);
        INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp);
        INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp);
        INT_OR_FLOAT_CASE(MAXSI, SPIRVSMaxOp);

      case vector::CombiningKind::AND:
      case vector::CombiningKind::OR:
      case vector::CombiningKind::XOR:
        return rewriter.notifyMatchFailure(reduceOp, "unimplemented");
      default:
        return rewriter.notifyMatchFailure(reduceOp, "not handled here");
      }
#undef INT_AND_FLOAT_CASE
#undef INT_OR_FLOAT_CASE
    }

    rewriter.replaceOp(reduceOp, result);
    return success();
  }
};
/// Lowers floating-point min/max `vector.reduction` kinds by extracting every
/// source element and folding them left-to-right with the SPIR-V max/min ops
/// supplied as template parameters. MAXIMUMF and MAXNUMF both map to
/// `SPIRVFMaxOp`; MINIMUMF and MINNUMF both map to `SPIRVFMinOp`. Every other
/// kind is left to other patterns.
template <typename SPIRVFMaxOp, typename SPIRVFMinOp>
struct VectorReductionFloatMinMax final
    : OpConversionPattern<vector::ReductionOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto reductionInfo =
        getReductionInfo(reduceOp, adaptor, rewriter, *getTypeConverter());
    if (failed(reductionInfo))
      return failure();

    // Bind by reference: a by-value binding would deep-copy the vector of
    // extracted elements on every match.
    auto &[resultType, extractedElements] = *reductionInfo;
    Location loc = reduceOp->getLoc();
    // The combining kind does not change across iterations.
    const vector::CombiningKind kind = reduceOp.getKind();

    // Fold the remaining values (including the accumulator, if any) into the
    // first extracted element.
    Value result = extractedElements.front();
    for (Value next : llvm::drop_begin(extractedElements)) {
      switch (kind) {

#define INT_OR_FLOAT_CASE(kind, fop)                                           \
  case vector::CombiningKind::kind:                                            \
    result = fop::create(rewriter, loc, resultType, result, next);             \
    break

        INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
        INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
        INT_OR_FLOAT_CASE(MAXNUMF, SPIRVFMaxOp);
        INT_OR_FLOAT_CASE(MINNUMF, SPIRVFMinOp);

      default:
        return rewriter.notifyMatchFailure(reduceOp, "not handled here");
      }
#undef INT_OR_FLOAT_CASE
    }

    rewriter.replaceOp(reduceOp, result);
    return success();
  }
};
/// Lowers `vector.broadcast` from a non-vector source. The converted source is
/// forwarded when the result converts to a SPIR-V scalar, and otherwise
/// replicated into a `spirv.CompositeConstruct`, one copy per element of the
/// converted result vector.
class VectorScalarBroadcastPattern final
    : public OpConversionPattern<vector::BroadcastOp> {
public:
  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (isa<VectorType>(op.getSourceType())) {
      return rewriter.notifyMatchFailure(
          op, "only conversion of 'broadcast from scalar' is supported");
    }

    Type convertedTy = getTypeConverter()->convertType(op.getType());
    if (!convertedTy)
      return failure();

    Value scalar = adaptor.getSource();
    if (isa<spirv::ScalarType>(convertedTy)) {
      rewriter.replaceOp(op, scalar);
      return success();
    }

    auto convertedVecTy = cast<VectorType>(convertedTy);
    SmallVector<Value, 4> constituents(convertedVecTy.getNumElements(), scalar);
    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, convertedTy,
                                                             constituents);
    return success();
  }
};
/// Lowers `vector.shuffle`. When both inputs and the result have more than
/// one element, a `spirv.VectorShuffle` is emitted directly. Otherwise the
/// selected elements are gathered one by one and reassembled.
struct VectorShuffleOpConvert final
    : public OpConversionPattern<vector::ShuffleOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType origResultTy = shuffleOp.getResultVectorType();
    Type convertedResultTy = getTypeConverter()->convertType(origResultTy);
    if (!convertedResultTy)
      return rewriter.notifyMatchFailure(shuffleOp,
                                         "unsupported result vector type");

    auto mask = llvm::to_vector_of<int32_t>(shuffleOp.getMask());
    VectorType origV1Ty = shuffleOp.getV1VectorType();
    VectorType origV2Ty = shuffleOp.getV2VectorType();

    // When both operands and the result are SPIR-V vectors, emit a SPIR-V
    // shuffle.
    const bool anySingleElement = origV1Ty.getNumElements() <= 1 ||
                                  origV2Ty.getNumElements() <= 1 ||
                                  origResultTy.getNumElements() <= 1;
    if (!anySingleElement) {
      rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
          shuffleOp, convertedResultTy, adaptor.getV1(), adaptor.getV2(),
          rewriter.getI32ArrayAttr(mask));
      return success();
    }

    // When at least one of the operands or the result becomes a scalar after
    // type conversion for SPIR-V, extract all the required elements and
    // construct the result vector.
    Location loc = shuffleOp.getLoc();
    auto pickElement = [&rewriter, loc](Value scalarOrVec,
                                        int32_t idx) -> Value {
      if (isa<VectorType>(scalarOrVec.getType()))
        return spirv::CompositeExtractOp::create(rewriter, loc, scalarOrVec,
                                                 idx);

      assert(idx == 0 && "Invalid scalar element index");
      return scalarOrVec;
    };

    // Mask values below the V1 element count select from V1; the rest select
    // from V2 after subtracting that count.
    const int32_t v1ElemCount = origV1Ty.getNumElements();
    SmallVector<Value> gathered;
    gathered.reserve(mask.size());
    for (int32_t maskValue : mask) {
      const bool takeFromV2 = maskValue >= v1ElemCount;
      Value input = takeFromV2 ? adaptor.getV2() : adaptor.getV1();
      int32_t localIdx = takeFromV2 ? maskValue - v1ElemCount : maskValue;
      gathered.push_back(pickElement(input, localIdx));
    }

    // Handle the scalar result corner case.
    if (gathered.size() == 1) {
      rewriter.replaceOp(shuffleOp, gathered.front());
      return success();
    }

    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
        shuffleOp, convertedResultTy, gathered);
    return success();
  }
};
/// Lowers `vector.interleave` to a `spirv.VectorShuffle` that alternates lhs
/// and rhs elements, or to `spirv.CompositeConstruct` for single-element
/// inputs.
struct VectorInterleaveOpConvert final
    : public OpConversionPattern<vector::InterleaveOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The result vector type must be convertible.
    VectorType origResultTy = interleaveOp.getResultVectorType();
    Type convertedResultTy = getTypeConverter()->convertType(origResultTy);
    if (!convertedResultTy)
      return rewriter.notifyMatchFailure(interleaveOp,
                                         "unsupported result vector type");

    Value lhs = adaptor.getLhs();
    Value rhs = adaptor.getRhs();
    int elemCount = interleaveOp.getSourceVectorType().getNumElements();

    // Input vectors of size 1 are converted to scalars by the type converter.
    // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
    // use `spirv::CompositeConstructOp`.
    if (elemCount == 1) {
      Value scalarPair[] = {lhs, rhs};
      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
          interleaveOp, convertedResultTy, scalarPair);
      return success();
    }

    // Shuffle mask 0, n, 1, n + 1, ...: even result slots read from lhs, odd
    // slots read from rhs (whose elements are numbered starting at n).
    SmallVector<int> shuffleMask;
    shuffleMask.reserve(2 * elemCount);
    for (int i = 0; i < elemCount; ++i) {
      shuffleMask.push_back(i);
      shuffleMask.push_back(elemCount + i);
    }

    // Emit a SPIR-V shuffle.
    rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
        interleaveOp, convertedResultTy, lhs, rhs,
        rewriter.getI32ArrayAttr(shuffleMask));
    return success();
  }
};
/// Lowers `vector.deinterleave` to two `spirv.VectorShuffle` ops, one picking
/// the even source elements and one the odd ones, or to two
/// `spirv.CompositeExtract` ops for a two-element source.
struct VectorDeinterleaveOpConvert final
    : public OpConversionPattern<vector::DeinterleaveOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The result vector type must be convertible.
    VectorType origResultTy = deinterleaveOp.getResultVectorType();
    Type convertedResultTy = getTypeConverter()->convertType(origResultTy);
    if (!convertedResultTy)
      return rewriter.notifyMatchFailure(deinterleaveOp,
                                         "unsupported result vector type");

    Location loc = deinterleaveOp->getLoc();
    Value src = adaptor.getSource();
    int srcElemCount = deinterleaveOp.getSourceVectorType().getNumElements();

    // Output vectors of size 1 are converted to scalars by the type converter.
    // We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
    // use `spirv::CompositeExtractOp`.
    if (srcElemCount == 2) {
      auto first = spirv::CompositeExtractOp::create(
          rewriter, loc, convertedResultTy, src, rewriter.getI32ArrayAttr({0}));
      auto second = spirv::CompositeExtractOp::create(
          rewriter, loc, convertedResultTy, src, rewriter.getI32ArrayAttr({1}));
      rewriter.replaceOp(deinterleaveOp, {first, second});
      return success();
    }

    // Result 0 takes source indices 0, 2, 4, ...; result 1 takes 1, 3, 5, ...
    const int halfCount = srcElemCount / 2;
    SmallVector<int> evenMask;
    SmallVector<int> oddMask;
    evenMask.reserve(halfCount);
    oddMask.reserve(halfCount);
    for (int i = 0; i < halfCount; ++i) {
      evenMask.push_back(2 * i);
      oddMask.push_back(2 * i + 1);
    }

    // Create two SPIR-V shuffles.
    auto evenShuffle = spirv::VectorShuffleOp::create(
        rewriter, loc, convertedResultTy, src, src,
        rewriter.getI32ArrayAttr(evenMask));
    auto oddShuffle = spirv::VectorShuffleOp::create(
        rewriter, loc, convertedResultTy, src, src,
        rewriter.getI32ArrayAttr(oddMask));

    rewriter.replaceOp(deinterleaveOp, {evenShuffle, oddShuffle});
    return success();
  }
};
/// Lowers `vector.load` from a memref whose memory space is a SPIR-V storage
/// class to `spirv.Load`. For multi-element vectors the element pointer is
/// first bitcast to a pointer to the converted vector type. An alignment on
/// the op is carried over as an `Aligned` memory access.
struct VectorLoadOpConverter final
    : public OpConversionPattern<vector::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcMemRefTy = loadOp.getMemRefType();
    auto storageClassAttr = dyn_cast_or_null<spirv::StorageClassAttr>(
        srcMemRefTy.getMemorySpace());
    if (!storageClassAttr)
      return rewriter.notifyMatchFailure(
          loadOp, "expected spirv.storage_class memory space");

    const auto &spirvConverter = *getTypeConverter<SPIRVTypeConverter>();
    Location loc = loadOp.getLoc();
    Value elementPtr =
        spirv::getElementPtr(spirvConverter, srcMemRefTy, adaptor.getBase(),
                             adaptor.getIndices(), loc, rewriter);
    if (!elementPtr)
      return rewriter.notifyMatchFailure(
          loadOp, "failed to get memref element pointer");

    // Use the converted vector type instead of original (single element vector
    // would get converted to scalar).
    auto origVectorTy = loadOp.getVectorType();
    Type loadedTy = spirvConverter.convertType(origVectorTy);
    if (!loadedTy)
      return rewriter.notifyMatchFailure(loadOp, "unsupported vector type");

    // The alignment is stored as a 32-bit attribute below, so it must fit.
    std::optional<uint64_t> alignment = loadOp.getAlignment();
    if (alignment && *alignment > std::numeric_limits<uint32_t>::max()) {
      return rewriter.notifyMatchFailure(loadOp,
                                         "invalid alignment requirement");
    }

    // Both attributes stay null unless the op specifies an alignment.
    spirv::MemoryAccessAttr memoryAccessAttr;
    IntegerAttr alignmentAttr;
    if (alignment) {
      memoryAccessAttr = spirv::MemoryAccessAttr::get(
          rewriter.getContext(), spirv::MemoryAccess::Aligned);
      alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
    }

    // For single element vectors, we don't need to bitcast the access chain to
    // the original vector type. Both is going to be the same, a pointer
    // to a scalar.
    Value loadPtr = elementPtr;
    if (origVectorTy.getNumElements() != 1) {
      auto vectorPtrTy =
          spirv::PointerType::get(loadedTy, storageClassAttr.getValue());
      loadPtr = spirv::BitcastOp::create(rewriter, loc, vectorPtrTy, elementPtr);
    }

    rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadedTy, loadPtr,
                                               memoryAccessAttr, alignmentAttr);
    return success();
  }
};
/// Lowers `vector.store` into a memref whose memory space is a SPIR-V storage
/// class to `spirv.Store`. For multi-element vectors the element pointer is
/// first bitcast to a pointer to the converted vector type. An alignment on
/// the op is carried over as an `Aligned` memory access.
struct VectorStoreOpConverter final
    : public OpConversionPattern<vector::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto memrefType = storeOp.getMemRefType();
    auto attr =
        dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
    if (!attr)
      return rewriter.notifyMatchFailure(
          storeOp, "expected spirv.storage_class memory space");

    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
    auto loc = storeOp.getLoc();
    Value accessChain =
        spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
                             adaptor.getIndices(), loc, rewriter);
    if (!accessChain)
      return rewriter.notifyMatchFailure(
          storeOp, "failed to get memref element pointer");

    // The alignment is stored as a 32-bit attribute below, so it must fit.
    std::optional<uint64_t> alignment = storeOp.getAlignment();
    if (alignment > std::numeric_limits<uint32_t>::max()) {
      return rewriter.notifyMatchFailure(storeOp,
                                         "invalid alignment requirement");
    }

    spirv::StorageClass storageClass = attr.getValue();
    auto vectorType = storeOp.getVectorType();
    // Use the converted vector type instead of the original one, mirroring the
    // load lowering: the value being stored comes from the adaptor and has the
    // converted type, so the pointer must point to that same type.
    Type spirvVectorType = typeConverter.convertType(vectorType);
    if (!spirvVectorType)
      return rewriter.notifyMatchFailure(storeOp, "unsupported vector type");
    auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);

    // For single element vectors, we don't need to bitcast the access chain to
    // the original vector type. Both is going to be the same, a pointer
    // to a scalar.
    Value castedAccessChain =
        (vectorType.getNumElements() == 1)
            ? accessChain
            : spirv::BitcastOp::create(rewriter, loc, vectorPtrType,
                                       accessChain);

    // Both attributes stay null unless the op specifies an alignment.
    auto memoryAccess = spirv::MemoryAccess::None;
    spirv::MemoryAccessAttr memoryAccessAttr;
    IntegerAttr alignmentAttr;
    if (alignment.has_value()) {
      memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
      memoryAccessAttr =
          spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess);
      alignmentAttr = rewriter.getI32IntegerAttr(alignment.value());
    }

    rewriter.replaceOpWithNewOp<spirv::StoreOp>(
        storeOp, castedAccessChain, adaptor.getValueToStore(), memoryAccessAttr,
        alignmentAttr);
    return success();
  }
};
/// Rewrites an integer 'add' `vector.reduction` over `arith.muli` of two
/// extended i8 vectors into a SPIR-V integer dot product op. The result must
/// be a 32- or 64-bit integer and the source a fixed 1-D vector of 3 or 4
/// elements. The accumulating op variant is used when the reduction has an
/// accumulator.
struct VectorReductionToIntDotProd final
    : OpRewritePattern<vector::ReductionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ReductionOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");

    auto intResultTy = dyn_cast<IntegerType>(op.getType());
    if (!intResultTy)
      return rewriter.notifyMatchFailure(op, "result is not an integer");

    int64_t resultBitwidth = intResultTy.getIntOrFloatBitWidth();
    if (!llvm::is_contained({32, 64}, resultBitwidth))
      return rewriter.notifyMatchFailure(op, "unsupported integer bitwidth");

    VectorType srcVecTy = op.getSourceVectorType();
    const bool supportedLength =
        llvm::is_contained({4, 3}, srcVecTy.getNumElements());
    const bool fixed1D =
        srcVecTy.getShape().size() == 1 && !srcVecTy.isScalable();
    if (!supportedLength || !fixed1D)
      return rewriter.notifyMatchFailure(op, "unsupported vector shape");

    auto mul = op.getVector().getDefiningOp<arith::MulIOp>();
    if (!mul)
      return rewriter.notifyMatchFailure(
          op, "reduction operand is not 'arith.muli'");

    // Try each signedness combination in turn; the first one that matches
    // performs the rewrite.
    const bool rewritten =
        succeeded(handleCase<arith::ExtSIOp, arith::ExtSIOp, spirv::SDotOp,
                             spirv::SDotAccSatOp, false>(op, mul, rewriter)) ||
        succeeded(handleCase<arith::ExtUIOp, arith::ExtUIOp, spirv::UDotOp,
                             spirv::UDotAccSatOp, false>(op, mul, rewriter)) ||
        succeeded(handleCase<arith::ExtSIOp, arith::ExtUIOp, spirv::SUDotOp,
                             spirv::SUDotAccSatOp, false>(op, mul, rewriter)) ||
        succeeded(handleCase<arith::ExtUIOp, arith::ExtSIOp, spirv::SUDotOp,
                             spirv::SUDotAccSatOp, true>(op, mul, rewriter));
    return success(rewritten);
  }

private:
  /// Matches `mul` with an lhs produced by `LhsExtensionOp` and an rhs
  /// produced by `RhsExtensionOp`, both extending i8-element vectors, and
  /// replaces `op` with `DotOp` (no accumulator) or `DotAccOp` (accumulator).
  /// When `SwapOperands` is set, the two inputs are exchanged before the
  /// dot-product op is created.
  template <typename LhsExtensionOp, typename RhsExtensionOp, typename DotOp,
            typename DotAccOp, bool SwapOperands>
  static LogicalResult handleCase(vector::ReductionOp op, arith::MulIOp mul,
                                  PatternRewriter &rewriter) {
    auto lhsExt = mul.getLhs().getDefiningOp<LhsExtensionOp>();
    if (!lhsExt)
      return failure();
    Value lhsNarrow = lhsExt.getIn();
    if (!cast<VectorType>(lhsNarrow.getType()).getElementType().isInteger(8))
      return failure();

    auto rhsExt = mul.getRhs().getDefiningOp<RhsExtensionOp>();
    if (!rhsExt)
      return failure();
    Value rhsNarrow = rhsExt.getIn();
    if (!cast<VectorType>(rhsNarrow.getType()).getElementType().isInteger(8))
      return failure();

    // Three-element inputs are widened to four elements by appending a zero.
    if (op.getSourceVectorType().getNumElements() == 3) {
      Location loc = op.getLoc();
      IntegerType i8Type = rewriter.getI8Type();
      auto v4i8Type = VectorType::get({4}, i8Type);
      Value zeroPad = spirv::ConstantOp::getZero(i8Type, loc, rewriter);
      lhsNarrow = spirv::CompositeConstructOp::create(
          rewriter, loc, v4i8Type, ValueRange{lhsNarrow, zeroPad});
      rhsNarrow = spirv::CompositeConstructOp::create(
          rewriter, loc, v4i8Type, ValueRange{rhsNarrow, zeroPad});
    }

    // There's no variant of dot prod ops for unsigned LHS and signed RHS, so
    // we have to swap operands instead in that case.
    if (SwapOperands)
      std::swap(lhsNarrow, rhsNarrow);

    Value acc = op.getAcc();
    if (!acc) {
      rewriter.replaceOpWithNewOp<DotOp>(op, op.getType(), lhsNarrow,
                                         rhsNarrow, nullptr);
      return success();
    }

    rewriter.replaceOpWithNewOp<DotAccOp>(op, op.getType(), lhsNarrow,
                                          rhsNarrow, acc, nullptr);
    return success();
  }
};
/// Lowers a floating-point 'add' `vector.reduction` to `spirv.Dot`. If the
/// source is defined by `arith.mulf`, its two operands become the dot
/// operands; otherwise the source is dotted with a splat of ones. An
/// accumulator, if present, is added with `spirv.FAdd`.
struct VectorReductionToFPDotProd final
    : OpConversionPattern<vector::ReductionOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(op, "combining kind is not 'add'");

    auto resultType = getTypeConverter()->convertType<FloatType>(op.getType());
    if (!resultType)
      return rewriter.notifyMatchFailure(op, "result is not a float");

    Value vec = adaptor.getVector();
    Value acc = adaptor.getAcc();

    auto vectorType = dyn_cast<VectorType>(vec.getType());
    if (!vectorType) {
      assert(isa<FloatType>(vec.getType()) &&
             "Expected the vector to be scalarized");
      // The source is already a single scalar: the reduction is that scalar,
      // plus the accumulator when there is one.
      if (!acc) {
        rewriter.replaceOp(op, vec);
        return success();
      }
      rewriter.replaceOpWithNewOp<spirv::FAddOp>(op, acc, vec);
      return success();
    }

    Location loc = op.getLoc();
    Value lhs;
    Value rhs;
    auto mul = vec.getDefiningOp<arith::MulFOp>();
    if (!mul) {
      // If the operand is not a mul, use a vector of ones for the dot operand
      // to just sum up all values.
      Attribute oneAttr =
          rewriter.getFloatAttr(vectorType.getElementType(), 1.0);
      oneAttr = SplatElementsAttr::get(vectorType, oneAttr);
      lhs = vec;
      rhs = spirv::ConstantOp::create(rewriter, loc, vectorType, oneAttr);
    } else {
      lhs = mul.getLhs();
      rhs = mul.getRhs();
    }
    assert(lhs);
    assert(rhs);

    Value dot = spirv::DotOp::create(rewriter, loc, resultType, lhs, rhs);
    Value replacement =
        acc ? Value(spirv::FAddOp::create(rewriter, loc, acc, dot)) : dot;

    rewriter.replaceOp(op, replacement);
    return success();
  }
};
/// Converts `vector.step` into a `spirv::CompositeConstructOp` whose operands
/// are the integer constants 0, 1, ..., N-1. The constants use an integer type
/// as wide as the type converter's index bitwidth.
struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
    Type dstType = typeConverter.convertType(stepOp.getType());
    if (!dstType)
      return failure();

    Location loc = stepOp.getLoc();
    auto intType =
        rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());
    int64_t numElements = stepOp.getType().getNumElements();

    // The type converter turns single-element vectors into scalars, so the
    // whole result is just the constant zero in that case.
    if (numElements == 1) {
      Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
      rewriter.replaceOp(stepOp, zero);
      return success();
    }

    // Materializes one step value as a SPIR-V integer constant.
    auto createStepConstant = [&](int64_t stepValue) -> Value {
      Attribute stepAttr = rewriter.getIntegerAttr(intType, stepValue);
      return spirv::ConstantOp::create(rewriter, loc, intType, stepAttr);
    };

    SmallVector<Value> source;
    source.reserve(numElements);
    for (int64_t stepValue = 0; stepValue < numElements; ++stepValue)
      source.push_back(createStepConstant(stepValue));

    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
                                                             source);
    return success();
  }
};
/// Converts `vector.to_elements` into one `spirv::CompositeExtractOp` per
/// result that still has uses. Results without uses are replaced with a null
/// value. A source that was converted to a SPIR-V scalar is forwarded as is.
struct VectorToElementOpConvert final
    : OpConversionPattern<vector::ToElementsOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value source = adaptor.getSource();
    SmallVector<Value> results(toElementsOp->getNumResults());

    // Input vectors of size 1 are converted to scalars by the type converter,
    // so `spirv::CompositeExtractOp` cannot be applied to them. The scalar
    // itself is the one and only element.
    if (isa<spirv::ScalarType>(source.getType())) {
      results[0] = source;
      rewriter.replaceOp(toElementsOp, results);
      return success();
    }

    Type srcElementType = toElementsOp.getElements().getType().front();
    Type elementType = getTypeConverter()->convertType(srcElementType);
    if (!elementType)
      return rewriter.notifyMatchFailure(
          toElementsOp,
          llvm::formatv("failed to convert element type '{0}' to SPIR-V",
                        srcElementType));

    Location loc = toElementsOp.getLoc();
    for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
      // Only live results get an extract; dead ones keep their null entry.
      if (!element.use_empty()) {
        results[idx] = spirv::CompositeExtractOp::create(
            rewriter, loc, elementType, source,
            rewriter.getI32ArrayAttr({static_cast<int32_t>(idx)}));
      }
    }

    rewriter.replaceOp(toElementsOp, results);
    return success();
  }
};
} // namespace
// Comma-separated lists of SPIR-V max/min ops. They are expanded as template
// argument lists for the reduction patterns registered below, so the order of
// the ops within each list is significant.
// NOTE(review): the CL_/GL_ prefixes presumably select the OpenCL and GLSL
// extended instruction set variants of the ops -- confirm against the SPIR-V
// dialect definitions.

// Integer ops, in the order: unsigned max, unsigned min, signed max, signed
// min.
#define CL_INT_MAX_MIN_OPS \
spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
#define GL_INT_MAX_MIN_OPS \
spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
// Floating-point ops, in the order: max, min.
#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
void mlir::populateVectorToSPIRVPatterns(
const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<
VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert,
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
VectorToElementOpConvert, VectorInsertOpConvert,
VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
VectorSplatToBroadcast, VectorInsertStridedSliceOpConvert,
VectorShuffleOpConvert, VectorInterleaveOpConvert,
VectorDeinterleaveOpConvert, VectorScalarBroadcastPattern,
VectorLoadOpConverter, VectorStoreOpConverter, VectorStepOpConvert>(
typeConverter, patterns.getContext(), PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
// than the generic one that extracts all elements.
patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
PatternBenefit(2));
}
/// Adds the integer dot product reduction pattern to `patterns`. The pattern
/// is constructed from the pattern set's context alone, with no type
/// converter.
void mlir::populateVectorReductionToSPIRVDotProductPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<VectorReductionToIntDotProd>(context);
}