blob: b1772c2645c7e68f63d3322dcb9f86ffec3c6563 [file] [log] [blame]
//===- DecorateCompositeTypeLayoutPass.cpp - Decorate composite type ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to decorate the composite types used by
// composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
// PushConstant storage classes with layout information. See SPIR-V spec
// "2.16.2. Validation Rules for Shader Capabilities" for more details.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
namespace {
/// Rewrites a spv.GlobalVariable whose pointee struct type is replaced by the
/// result of VulkanLayoutUtils::decorateType. The new global is created with
/// the decorated pointer type; every attribute except "type" is copied over
/// unchanged.
class SPIRVGlobalVariableOpLayoutInfoDecoration
    : public OpRewritePattern<spirv::GlobalVariableOp> {
public:
  using OpRewritePattern<spirv::GlobalVariableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(spirv::GlobalVariableOp op,
                                PatternRewriter &rewriter) const override {
    // Use checked casts: a global that is not a pointer-to-struct has nothing
    // this pattern can decorate, so reject the match rather than asserting.
    auto ptrType = op.type().dyn_cast<spirv::PointerType>();
    if (!ptrType)
      return rewriter.notifyMatchFailure(op, "type is not a SPIR-V pointer");
    auto pointeeType = ptrType.getPointeeType().dyn_cast<spirv::StructType>();
    if (!pointeeType)
      return rewriter.notifyMatchFailure(op, "pointee type is not a struct");

    spirv::StructType structType = VulkanLayoutUtils::decorateType(pointeeType);
    // A null result means the struct could not be decorated; name the
    // offending type so the failure is actionable for the user.
    if (!structType)
      return op.emitError("failed to decorate (unsupported pointee type: ")
             << pointeeType << ")";

    auto decoratedType =
        spirv::PointerType::get(structType, ptrType.getStorageClass());

    // Save all named attributes except "type"; the new type is passed
    // explicitly to the builder below.
    SmallVector<NamedAttribute, 4> globalVarAttrs;
    for (const NamedAttribute &attr : op->getAttrs()) {
      if (attr.getName() == "type")
        continue;
      globalVarAttrs.push_back(attr);
    }

    rewriter.replaceOpWithNewOp<spirv::GlobalVariableOp>(
        op, TypeAttr::get(decoratedType), globalVarAttrs);
    return success();
  }
};
/// Recreates a spv.mlir.addressof so that its result type matches the current
/// type of the global variable it references. The variable is found by symbol
/// lookup in the enclosing spv.module.
class SPIRVAddressOfOpLayoutInfoDecoration
    : public OpRewritePattern<spirv::AddressOfOp> {
public:
  using OpRewritePattern<spirv::AddressOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(spirv::AddressOfOp op,
                                PatternRewriter &rewriter) const override {
    // Guard both lookups: dereferencing a null module or variable would crash
    // instead of producing a recoverable match failure.
    auto spirvModule = op->getParentOfType<spirv::ModuleOp>();
    if (!spirvModule)
      return rewriter.notifyMatchFailure(op, "not nested in a spv.module");

    auto varName = op.variableAttr();
    auto varOp = spirvModule.lookupSymbol<spirv::GlobalVariableOp>(varName);
    if (!varOp)
      return rewriter.notifyMatchFailure(
          op, "referenced spv.GlobalVariable not found");

    rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(
        op, varOp.type(), SymbolRefAttr::get(varName.getAttr()));
    return success();
  }
};
/// Generic conversion pattern that keeps `OpT` as-is but swaps its operands
/// for the values the conversion framework has remapped them to. Only the
/// operands are updated in place; result types are left untouched.
template <typename OpT>
class SPIRVPassThroughConversion : public OpConversionPattern<OpT> {
public:
  using OpConversionPattern<OpT>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpT op, typename OpT::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto remappedOperands = adaptor.getOperands();
    // The in-place update must be wrapped so the rewriter can track (and roll
    // back, if needed) the modification.
    auto replaceOperands = [&]() { op->setOperands(remappedOperands); };
    rewriter.updateRootInPlace(op, replaceOperands);
    return success();
  }
};
} // namespace
/// Adds every pattern used by the layout-decoration pass to `patterns`:
/// the two rewrites that change types (global variables and their addressof
/// users) plus operand-forwarding conversions for access chains, loads and
/// stores.
static void populateSPIRVLayoutInfoPatterns(RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();

  // Patterns that actually produce new, decorated types.
  patterns.add<SPIRVGlobalVariableOpLayoutInfoDecoration,
               SPIRVAddressOfOpLayoutInfoDecoration>(ctx);

  // Patterns that only forward remapped operands of downstream users.
  patterns.add<SPIRVPassThroughConversion<spirv::AccessChainOp>,
               SPIRVPassThroughConversion<spirv::LoadOp>,
               SPIRVPassThroughConversion<spirv::StoreOp>>(ctx);
}
namespace {
/// Pass that decorates composite types used by global variables in the
/// spv.module ops directly nested in the module it runs on.
class DecorateSPIRVCompositeTypeLayoutPass
    : public SPIRVCompositeTypeLayoutBase<
          DecorateSPIRVCompositeTypeLayoutPass> {
private:
  /// Applies the layout-decoration conversion to each spv.module child.
  void runOnOperation() override;
};
} // namespace
void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
auto module = getOperation();
RewritePatternSet patterns(module.getContext());
populateSPIRVLayoutInfoPatterns(patterns);
ConversionTarget target(*(module.getContext()));
target.addLegalDialect<spirv::SPIRVDialect>();
target.addLegalOp<FuncOp>();
target.addDynamicallyLegalOp<spirv::GlobalVariableOp>(
[](spirv::GlobalVariableOp op) {
return VulkanLayoutUtils::isLegalType(op.type());
});
// Change the type for the direct users.
target.addDynamicallyLegalOp<spirv::AddressOfOp>([](spirv::AddressOfOp op) {
return VulkanLayoutUtils::isLegalType(op.pointer().getType());
});
// Change the type for the indirect users.
target.addDynamicallyLegalOp<spirv::AccessChainOp, spirv::LoadOp,
spirv::StoreOp>([&](Operation *op) {
for (Value operand : op->getOperands()) {
auto addrOp = operand.getDefiningOp<spirv::AddressOfOp>();
if (addrOp && !VulkanLayoutUtils::isLegalType(addrOp.pointer().getType()))
return false;
}
return true;
});
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
for (auto spirvModule : module.getOps<spirv::ModuleOp>())
if (failed(applyFullConversion(spirvModule, target, frozenPatterns)))
signalPassFailure();
}
/// Factory for the pass that decorates SPIR-V composite types with layout
/// information.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::spirv::createDecorateSPIRVCompositeTypeLayoutPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<DecorateSPIRVCompositeTypeLayoutPass>();
  return pass;
}