//=- VectorEmulateMaskedLoadStore.cpp - Emulate 'vector.maskedload/store' op =//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to emulate the
// 'vector.maskedload' and 'vector.maskedstore' operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
using namespace mlir;
namespace {
/// Convert vector.maskedload
///
/// Before:
///
/// vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
///
/// After:
///
///   %ivalue = %pass_thru
///   %m0 = vector.extract %mask[0]
///   %result0 = scf.if %m0 {
///     %v = memref.load %base[%idx_0, %idx_1]
///     %combined = vector.insert %v, %ivalue[0]
///     scf.yield %combined
///   } else {
///     scf.yield %ivalue
///   }
///   %m1 = vector.extract %mask[1]
///   %result1 = scf.if %m1 {
///     %v = memref.load %base[%idx_0, %idx_1 + 1]
///     %combined = vector.insert %v, %result0[1]
///     scf.yield %combined
///   } else {
///     scf.yield %result0
///   }
/// ...
///
struct VectorMaskedLoadOpConverter final
: OpRewritePattern<vector::MaskedLoadOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
PatternRewriter &rewriter) const override {
VectorType maskVType = maskedLoadOp.getMaskVectorType();
if (maskVType.getShape().size() != 1)
return rewriter.notifyMatchFailure(
maskedLoadOp, "expected vector.maskedstore with 1-D mask");
Location loc = maskedLoadOp.getLoc();
int64_t maskLength = maskVType.getShape()[0];
Type indexType = rewriter.getIndexType();
Value mask = maskedLoadOp.getMask();
Value base = maskedLoadOp.getBase();
Value iValue = maskedLoadOp.getPassThru();
std::optional<uint64_t> alignment = maskedLoadOp.getAlignment();
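// Take a mutable copy of the indices; the innermost index is advanced in
// place as successive lanes are emitted.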
auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
Value one = arith::ConstantOp::create(rewriter, loc, indexType,
IntegerAttr::get(indexType, 1));
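// Unroll over the mask lanes: each lane becomes an scf.if that either loads
// the scalar and inserts it into the result vector, or forwards the
// pass-through value unchanged.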
for (int64_t i = 0; i < maskLength; ++i) {
auto maskBit = vector::ExtractOp::create(rewriter, loc, mask, i);
auto ifOp = scf::IfOp::create(
rewriter, loc, maskBit,
[&](OpBuilder &builder, Location loc) {
auto loadedValue = memref::LoadOp::create(
builder, loc, base, indices, /*nontemporal=*/false,
alignment.value_or(0));
auto combinedValue =
vector::InsertOp::create(builder, loc, loadedValue, iValue, i);
scf::YieldOp::create(builder, loc, combinedValue.getResult());
},
[&](OpBuilder &builder, Location loc) {
scf::YieldOp::create(builder, loc, iValue);
});
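// Thread the partially assembled vector into the next lane and step the
// innermost index by one.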
iValue = ifOp.getResult(0);
indices.back() =
arith::AddIOp::create(rewriter, loc, indices.back(), one);
}
rewriter.replaceOp(maskedLoadOp, iValue);
return success();
}
};
/// Convert vector.maskedstore
///
/// Before:
///
/// vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
///
/// After:
///
///   %m0 = vector.extract %mask[0]
///   scf.if %m0 {
///     %extracted = vector.extract %value[0]
///     memref.store %extracted, %base[%idx_0, %idx_1]
///   }
///   %m1 = vector.extract %mask[1]
///   scf.if %m1 {
///     %extracted = vector.extract %value[1]
///     memref.store %extracted, %base[%idx_0, %idx_1 + 1]
///   }
/// ...
///
struct VectorMaskedStoreOpConverter final
: OpRewritePattern<vector::MaskedStoreOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp,
PatternRewriter &rewriter) const override {
VectorType maskVType = maskedStoreOp.getMaskVectorType();
if (maskVType.getShape().size() != 1)
return rewriter.notifyMatchFailure(
maskedStoreOp, "expected vector.maskedstore with 1-D mask");
Location loc = maskedStoreOp.getLoc();
int64_t maskLength = maskVType.getShape()[0];
Type indexType = rewriter.getIndexType();
Value mask = maskedStoreOp.getMask();
Value base = maskedStoreOp.getBase();
Value value = maskedStoreOp.getValueToStore();
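// The emulation emits ordinary (temporal) scalar stores; any alignment hint
// on the masked op is forwarded to them.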
bool nontemporal = false;
std::optional<uint64_t> alignment = maskedStoreOp.getAlignment();
auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
Value one = arith::ConstantOp::create(rewriter, loc, indexType,
IntegerAttr::get(indexType, 1));
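// Unroll over the mask lanes: each lane becomes an scf.if that extracts the
// scalar and stores it only when the lane's mask bit is set.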
for (int64_t i = 0; i < maskLength; ++i) {
auto maskBit = vector::ExtractOp::create(rewriter, loc, mask, i);
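// Guard the scalar store with this lane's mask bit; no else region is needed.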
auto ifOp = scf::IfOp::create(rewriter, loc, maskBit, /*else=*/false);
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
auto extractedValue = vector::ExtractOp::create(rewriter, loc, value, i);
memref::StoreOp::create(rewriter, loc, extractedValue, base, indices,
nontemporal, alignment.value_or(0));
rewriter.setInsertionPointAfter(ifOp);
indices.back() =
arith::AddIOp::create(rewriter, loc, indices.back(), one);
}
rewriter.eraseOp(maskedStoreOp);
return success();
}
};
} // namespace
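
/// Populate `patterns` with the masked load/store emulation patterns defined
/// above.
///
/// A minimal usage sketch, assuming a pass that runs the greedy rewrite driver
/// (the surrounding pass and driver call are illustrative, not part of this
/// file):
///
///   RewritePatternSet patterns(ctx);
///   vector::populateVectorMaskedLoadStoreEmulationPatterns(patterns);
///   if (failed(applyPatternsGreedily(op, std::move(patterns))))
///     signalPassFailure();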
void mlir::vector::populateVectorMaskedLoadStoreEmulationPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<VectorMaskedLoadOpConverter, VectorMaskedStoreOpConverter>(
patterns.getContext(), benefit);
}