blob: c5125860d43794077fba3fb9cef697cfadc1056c [file] [log] [blame]
//===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the folders and canonicalization patterns for SPIR-V ops.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Common utility functions
//===----------------------------------------------------------------------===//
/// Returns the boolean value under the hood if the given `boolAttr` is a scalar
/// or splat vector bool constant.
///
/// Returns `llvm::None` when `boolAttr` is null, or when it is neither an i1
/// scalar constant nor a splat constant over a vector of i1.
static Optional<bool> getScalarOrSplatBoolAttr(Attribute boolAttr) {
  if (!boolAttr)
    return llvm::None;

  Type type = boolAttr.getType();
  if (type.isInteger(1)) {
    auto attr = boolAttr.cast<BoolAttr>();
    return attr.getValue();
  }

  // This must be a checked `dyn_cast`: at this point the type is only known
  // not to be i1, so an unconditional `cast<VectorType>` would assert on any
  // non-vector type instead of falling through to the `llvm::None` below.
  if (auto vecType = type.dyn_cast<VectorType>()) {
    if (vecType.getElementType().isInteger(1))
      if (auto attr = boolAttr.dyn_cast<SplatElementsAttr>())
        return attr.getSplatValue<bool>();
  }
  return llvm::None;
}
// Walks into the constant `composite` by consuming `indices` front to back and
// returns the attribute reached once the index chain is exhausted. Returns a
// null Attribute when `composite` (or any nested element visited on the way)
// is null, or when an index is applied to something that is neither an
// ElementsAttr nor an ArrayAttr.
static Attribute extractCompositeElement(Attribute composite,
                                         ArrayRef<unsigned> indices) {
  Attribute current = composite;

  // Keep descending while there is a constant to look into and an index left
  // to apply. A null `current` means "not a constant" and is returned as is.
  while (current && !indices.empty()) {
    // Vectors are leaves of the index chain: the single remaining index
    // selects the element.
    if (auto elements = current.dyn_cast<ElementsAttr>()) {
      assert(indices.size() == 1 && "must have exactly one index for a vector");
      return elements.getValues<Attribute>()[indices.front()];
    }

    // Anything else that can be indexed must be an array.
    auto array = current.dyn_cast<ArrayAttr>();
    if (!array)
      return {};

    assert(!indices.empty() && "must have at least one index for an array");
    current = array.getValue()[indices.front()];
    indices = indices.drop_front();
  }

  return current;
}
//===----------------------------------------------------------------------===//
// TableGen'erated canonicalizers
//===----------------------------------------------------------------------===//
// The included patterns are placed in an anonymous namespace, which gives them
// internal linkage so they stay private to this translation unit.
// NOTE(review): pattern names used below without a visible definition (e.g.
// `ConvertChainedBitcast`) presumably come from this generated file — confirm.
namespace {
#include "SPIRVCanonicalization.inc"
}
//===----------------------------------------------------------------------===//
// spv.AccessChainOp
//===----------------------------------------------------------------------===//
namespace {
/// Merges a `spirv::AccessChainOp` whose base pointer is produced by another
/// `spirv::AccessChainOp` into a single `spirv::AccessChainOp`. The new op
/// starts from the parent's base pointer and applies the parent's indices
/// followed by this op's own indices.
struct CombineChainedAccessChain
    : public OpRewritePattern<spirv::AccessChainOp> {
  using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
                                PatternRewriter &rewriter) const override {
    // The pattern only applies when the base pointer itself comes from an
    // access chain; block arguments and other producers do not match.
    auto parentOp =
        accessChainOp.base_ptr().getDefiningOp<spirv::AccessChainOp>();
    if (!parentOp)
      return failure();

    // Concatenate the index lists: parent indices first, then ours.
    auto parentIndices = parentOp.indices();
    auto ownIndices = accessChainOp.indices();
    SmallVector<Value, 4> combinedIndices(parentIndices.begin(),
                                          parentIndices.end());
    combinedIndices.append(ownIndices.begin(), ownIndices.end());

    rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
        accessChainOp, parentOp.base_ptr(), combinedIndices);
    return success();
  }
};
} // end anonymous namespace
/// Registers the canonicalization patterns for `spirv::AccessChainOp` into
/// `results`: `CombineChainedAccessChain`, which merges an access chain into
/// the access chain that produces its base pointer.
void spirv::AccessChainOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<CombineChainedAccessChain>(context);
}
//===----------------------------------------------------------------------===//
// spv.BitcastOp
//===----------------------------------------------------------------------===//
/// Registers the canonicalization patterns for `spirv::BitcastOp` into
/// `results`.
/// NOTE(review): `ConvertChainedBitcast` has no definition in the visible
/// source; presumably it is provided by the included
/// "SPIRVCanonicalization.inc" — confirm.
void spirv::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ConvertChainedBitcast>(context);
}
//===----------------------------------------------------------------------===//
// spv.CompositeExtractOp
//===----------------------------------------------------------------------===//
/// Folds `spv.CompositeExtract` by looking up the element addressed by the
/// op's `indices` inside the constant composite operand. The result is null
/// (no fold) when the element cannot be extracted.
OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 1 && "spv.CompositeExtract expects one operand");

  // Convert the integer attributes of the index chain into plain unsigned
  // values, which is what the extraction helper consumes.
  SmallVector<unsigned, 8> indexChain;
  for (Attribute indexAttr : indices())
    indexChain.push_back(
        static_cast<unsigned>(indexAttr.cast<IntegerAttr>().getInt()));

  return extractCompositeElement(operands[0], indexChain);
}
//===----------------------------------------------------------------------===//
// spv.Constant
//===----------------------------------------------------------------------===//
/// Folds `spv.Constant` to the attribute returned by `value()`.
/// The op has no operands, so `operands` is expected to be empty.
OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
assert(operands.empty() && "spv.Constant has no operands");
return value();
}
//===----------------------------------------------------------------------===//
// spv.IAdd
//===----------------------------------------------------------------------===//
/// Folds `spv.IAdd`: returns the first operand when the second one is a
/// constant zero, otherwise tries to constant-fold the two operands.
OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "spv.IAdd expects two operands");

  // x + 0 = x
  const bool rhsIsZero = matchPattern(operand2(), m_Zero());
  if (rhsIsZero)
    return operand1();

  // According to the SPIR-V spec:
  //
  //  The resulting value will equal the low-order N bits of the correct result
  //  R, where N is the component width and R is computed with enough precision
  //  to avoid overflow and underflow.
  auto wrappingAdd = [](APInt lhs, APInt rhs) { return lhs + rhs; };
  return constFoldBinaryOp<IntegerAttr>(operands, wrappingAdd);
}
//===----------------------------------------------------------------------===//
// spv.IMul
//===----------------------------------------------------------------------===//
/// Folds `spv.IMul`: returns the second operand when it is a constant zero,
/// the first operand when the second one is a constant one, and otherwise
/// tries to constant-fold the two operands.
OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "spv.IMul expects two operands");

  Value rhs = operand2();

  // x * 0 == 0
  if (matchPattern(rhs, m_Zero()))
    return rhs;

  // x * 1 = x
  if (matchPattern(rhs, m_One()))
    return operand1();

  // According to the SPIR-V spec:
  //
  //  The resulting value will equal the low-order N bits of the correct result
  //  R, where N is the component width and R is computed with enough precision
  //  to avoid overflow and underflow.
  auto wrappingMul = [](APInt lhs, APInt rhs) { return lhs * rhs; };
  return constFoldBinaryOp<IntegerAttr>(operands, wrappingMul);
}
//===----------------------------------------------------------------------===//
// spv.ISub
//===----------------------------------------------------------------------===//
/// Folds `spv.ISub`: returns a zero constant of the result type when both
/// operands are the same SSA value, otherwise tries to constant-fold the two
/// operands.
OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "spv.ISub expects two operands");

  // x - x = 0
  //
  // The result type is not necessarily a scalar integer (it can be a vector of
  // integers), so build the zero with `getZeroAttr`, which also handles
  // vectors; `getIntegerAttr` requires a scalar integer or index type. If no
  // zero attribute can be built for the type, fall through to the generic
  // folding below instead of returning a bogus constant.
  if (operand1() == operand2())
    if (Attribute zero = Builder(getContext()).getZeroAttr(getType()))
      return zero;

  // According to the SPIR-V spec:
  //
  //  The resulting value will equal the low-order N bits of the correct result
  //  R, where N is the component width and R is computed with enough precision
  //  to avoid overflow and underflow.
  return constFoldBinaryOp<IntegerAttr>(operands,
                                        [](APInt a, APInt b) { return a - b; });
}
//===----------------------------------------------------------------------===//
// spv.LogicalAnd
//===----------------------------------------------------------------------===//
/// Folds `spv.LogicalAnd` when the right-hand side is a scalar or splat bool
/// constant: the result is the left-hand operand for `true` and the constant
/// itself for `false`. Returns a null attribute (no fold) otherwise.
OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "spv.LogicalAnd should take two operands");

  Attribute rhsAttr = operands.back();
  Optional<bool> rhsValue = getScalarOrSplatBoolAttr(rhsAttr);
  if (!rhsValue)
    return Attribute();

  // x && true = x
  if (*rhsValue)
    return operand1();

  // x && false = false
  return rhsAttr;
}
//===----------------------------------------------------------------------===//
// spv.LogicalNot
//===----------------------------------------------------------------------===//
/// Registers the canonicalization patterns for `spirv::LogicalNotOp` into
/// `results`.
/// NOTE(review): none of the four `ConvertLogicalNotOf*` patterns is defined
/// in the visible source; presumably they are provided by the included
/// "SPIRVCanonicalization.inc" — confirm.
void spirv::LogicalNotOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results
.add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
context);
}
//===----------------------------------------------------------------------===//
// spv.LogicalOr
//===----------------------------------------------------------------------===//
/// Folds `spv.LogicalOr` when the right-hand side is a scalar or splat bool
/// constant: the result is the constant itself for `true` and the left-hand
/// operand for `false`. Returns a null attribute (no fold) otherwise.
OpFoldResult spirv::LogicalOrOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2 && "spv.LogicalOr should take two operands");

  Attribute rhsAttr = operands.back();
  Optional<bool> rhsValue = getScalarOrSplatBoolAttr(rhsAttr);
  if (!rhsValue)
    return Attribute();

  // x || true = true
  if (*rhsValue)
    return rhsAttr;

  // x || false = x
  return operand1();
}
//===----------------------------------------------------------------------===//
// spv.mlir.selection
//===----------------------------------------------------------------------===//
namespace {
// Rewrites a `spv.mlir.selection` whose two branches only store a value to
// the same pointer into a `spv.Select` followed by a single `spv.Store`.
//
// Blocks from the given `spv.mlir.selection` operation must satisfy the
// following layout:
//
//            +-----------------------------------------------+
//            | header block                                  |
//            | spv.BranchConditionalOp %cond, ^case0, ^case1 |
//            +-----------------------------------------------+
//                      |                         |
//                      v                         v
//   +------------------------+         +------------------------+
//   | case #0                |         | case #1                |
//   | spv.Store %ptr %value0 |         | spv.Store %ptr %value1 |
//   | spv.Branch ^merge      |         | spv.Branch ^merge      |
//   +------------------------+         +------------------------+
//                      |                         |
//                      v                         v
//                          +-------------+
//                          | merge block |
//                          +-------------+
//
struct ConvertSelectionOpToSelect
    : public OpRewritePattern<spirv::SelectionOp> {
  using OpRewritePattern<spirv::SelectionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
                                PatternRewriter &rewriter) const override {
    Operation *selectionOperation = selectionOp.getOperation();
    Region &region = selectionOperation->getRegion(0);

    // The verifier allows an empty region for `spv.mlir.selection`. Otherwise
    // the region must consist of exactly 4 blocks: header block, `true`
    // block, `false` block and merge block.
    if (region.empty() || std::distance(region.begin(), region.end()) != 4)
      return failure();

    Block *header = selectionOp.getHeaderBlock();
    if (!onlyContainsBranchConditionalOp(header))
      return failure();

    auto condBranchOp = cast<spirv::BranchConditionalOp>(header->front());
    Block *trueBlock = condBranchOp.getSuccessor(0);
    Block *falseBlock = condBranchOp.getSuccessor(1);
    Block *mergeBlock = selectionOp.getMergeBlock();
    if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
      return failure();

    // Both stores were checked to agree on pointer and attributes, so the
    // ones of the `true` block are used for the combined store.
    Value valueIfTrue = getSrcValue(trueBlock);
    Value valueIfFalse = getSrcValue(falseBlock);
    Value destination = getDstPtr(trueBlock);
    auto storeAttrs = cast<spirv::StoreOp>(trueBlock->front())->getAttrs();

    auto selectOp = rewriter.create<spirv::SelectOp>(
        selectionOp.getLoc(), valueIfTrue.getType(), condBranchOp.condition(),
        valueIfTrue, valueIfFalse);
    rewriter.create<spirv::StoreOp>(selectOp.getLoc(), destination,
                                    selectOp.getResult(), storeAttrs);

    // `spv.mlir.selection` is not needed anymore.
    rewriter.eraseOp(selectionOperation);
    return success();
  }

private:
  // Checks that given blocks follow the following rules:
  // 1. Each conditional block consists of two operations, the first operation
  //    is a `spv.Store` and the last operation is a `spv.Branch`.
  // 2. Each `spv.Store` uses the same pointer and the same memory attributes.
  // 3. A control flow goes into the given merge block from the given
  //    conditional blocks.
  LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
                                         Block *mergeBlock) const;

  // Returns true if `block` holds exactly one operation and that operation is
  // a `spirv::BranchConditionalOp`.
  bool onlyContainsBranchConditionalOp(Block *block) const {
    if (std::next(block->begin()) != block->end())
      return false;
    return isa<spirv::BranchConditionalOp>(block->front());
  }

  // Returns true if both stores carry the same attribute dictionary.
  bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
    auto lhsAttrs = lhs->getAttrDictionary();
    auto rhsAttrs = rhs->getAttrDictionary();
    return lhsAttrs == rhsAttrs;
  }

  // Returns the value stored by the `spv.Store` that starts the given block.
  Value getSrcValue(Block *block) const {
    return cast<spirv::StoreOp>(block->front()).value();
  }

  // Returns the pointer written by the `spv.Store` that starts the given
  // block.
  Value getDstPtr(Block *block) const {
    return cast<spirv::StoreOp>(block->front()).ptr();
  }
};
// Returns success if the two conditional blocks each consist of a `spv.Store`
// followed by a `spv.Branch` to `mergeBlock`, both stores write through the
// same pointer with the same attributes, and the stored value has a scalar or
// vector type.
LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
    Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
  // Each block must consist of 2 operations.
  auto hasExactlyTwoOps = [](Block *block) {
    return std::distance(block->begin(), block->end()) == 2;
  };
  if (!hasExactlyTwoOps(trueBlock) || !hasExactlyTwoOps(falseBlock))
    return failure();

  // The first operation must be a store and the second one a branch.
  auto trueStore = dyn_cast<spirv::StoreOp>(trueBlock->front());
  auto trueBranch = dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
  auto falseStore = dyn_cast<spirv::StoreOp>(falseBlock->front());
  auto falseBranch =
      dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
  const bool shapesMatch =
      trueStore && trueBranch && falseStore && falseBranch;
  if (!shapesMatch)
    return failure();

  // Checks that given type is valid for `spv.SelectOp`.
  // According to SPIR-V spec:
  // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
  // Starting with version 1.4, Result Type can additionally be a composite type
  // other than a vector."
  auto storedType = trueStore.value().getType().cast<spirv::SPIRVType>();
  const bool isScalarOrVector = storedType.isScalarOrVector();

  // Check that each `spv.Store` uses the same pointer, memory access
  // attributes and a valid type of the value.
  if (trueStore.ptr() != falseStore.ptr())
    return failure();
  if (!isSameAttrList(trueStore, falseStore))
    return failure();
  if (!isScalarOrVector)
    return failure();

  // Both conditional blocks must jump straight to the merge block.
  const bool bothReachMerge = trueBranch->getSuccessor(0) == mergeBlock &&
                              falseBranch->getSuccessor(0) == mergeBlock;
  return bothReachMerge ? success() : failure();
}
} // end anonymous namespace
void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ConvertSelectionOpToSelect>(context);
}