[mlir][spirv] Add support for SPV_ARM_graph extension - part 3 (#156845)
This is the third patch to add support for the `SPV_ARM_graph` SPIR-V
extension to MLIR’s SPIR-V dialect. The extension introduces a new
`Graph` abstraction for expressing dataflow computations over full
resources.
The part 3 implementation includes:
- ABI lowering support for graph entry points via
`LowerABIAttributesPass`.
- Tests covering ABI handling.
Graphs currently support only `SPV_ARM_tensors`, but are designed to
generalize to other resource types, such as images.
Spec: https://github.com/KhronosGroup/SPIRV-Registry/pull/346
RFC:
https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947
---------
Signed-off-by: Davide Grohmann <davide.grohmann@arm.com>
GitOrigin-RevId: 8036edb21dbedf79687613caef3d40aa5a50ddf2
diff --git a/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index fcf1526..44c86bc 100644
--- a/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -1066,7 +1066,12 @@
}
LogicalResult SPIRVDialect::verifyRegionResultAttribute(
- Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
+ Operation *op, unsigned /*regionIndex*/, unsigned resultIndex,
NamedAttribute attribute) {
- return op->emitError("cannot attach SPIR-V attributes to region result");
+ if (auto graphOp = dyn_cast<spirv::GraphARMOp>(op))
+ return verifyRegionAttribute(
+ op->getLoc(), graphOp.getResultTypes()[resultIndex], attribute);
+ return op->emitError(
+ "cannot attach SPIR-V attributes to region result which is "
+ "not part of a spirv::GraphARMOp type");
}
diff --git a/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 3911ec0..5607a3c 100644
--- a/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -22,6 +22,7 @@
#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
namespace mlir {
namespace spirv {
@@ -85,10 +86,36 @@
abiInfo.getBinding());
}
+/// Creates a global variable for an argument or result based on the ABI info.
+static spirv::GlobalVariableOp
+createGlobalVarForGraphEntryPoint(OpBuilder &builder, spirv::GraphARMOp graphOp,
+ unsigned index, bool isArg,
+ spirv::InterfaceVarABIAttr abiInfo) {
+ auto spirvModule = graphOp->getParentOfType<spirv::ModuleOp>();
+ if (!spirvModule)
+ return nullptr;
+
+ OpBuilder::InsertionGuard moduleInsertionGuard(builder);
+ builder.setInsertionPoint(graphOp.getOperation());
+ std::string varName = llvm::formatv("{}_{}_{}", graphOp.getName(),
+ isArg ? "arg" : "res", index);
+
+ Type varType = isArg ? graphOp.getFunctionType().getInput(index)
+ : graphOp.getFunctionType().getResult(index);
+
+ auto pointerType = spirv::PointerType::get(
+ varType,
+ abiInfo.getStorageClass().value_or(spirv::StorageClass::UniformConstant));
+
+ return spirv::GlobalVariableOp::create(builder, graphOp.getLoc(), pointerType,
+ varName, abiInfo.getDescriptorSet(),
+ abiInfo.getBinding());
+}
+
/// Gets the global variables that need to be specified as interface variable
/// with an spirv.EntryPointOp. Traverses the body of a entry function to do so.
static LogicalResult
-getInterfaceVariables(spirv::FuncOp funcOp,
+getInterfaceVariables(mlir::FunctionOpInterface funcOp,
SmallVectorImpl<Attribute> &interfaceVars) {
auto module = funcOp->getParentOfType<spirv::ModuleOp>();
if (!module) {
@@ -224,6 +251,21 @@
ConversionPatternRewriter &rewriter) const override;
};
+/// A pattern to convert graph signature according to interface variable ABI
+/// attributes.
+///
+/// Specifically, this pattern creates global variables according to interface
+/// variable ABI attributes attached to graph arguments and results.
+class ProcessGraphInterfaceVarABI final
+ : public OpConversionPattern<spirv::GraphARMOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(spirv::GraphARMOp graphOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
/// Pass to implement the ABI information specified as attributes.
class LowerABIAttributesPass final
: public spirv::impl::SPIRVLowerABIAttributesPassBase<
@@ -297,6 +339,63 @@
return success();
}
+LogicalResult ProcessGraphInterfaceVarABI::matchAndRewrite(
+ spirv::GraphARMOp graphOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // Non-entry point graphs are not handled.
+ if (!graphOp.getEntryPoint().value_or(false))
+ return failure();
+
+ TypeConverter::SignatureConversion signatureConverter(
+ graphOp.getFunctionType().getNumInputs());
+
+ StringRef attrName = spirv::getInterfaceVarABIAttrName();
+ SmallVector<Attribute, 4> interfaceVars;
+
+ // Convert arguments.
+ unsigned numInputs = graphOp.getFunctionType().getNumInputs();
+ unsigned numResults = graphOp.getFunctionType().getNumResults();
+ for (unsigned index = 0; index < numInputs; ++index) {
+ auto abiInfo =
+ graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(index, attrName);
+ if (!abiInfo)
+ return failure();
+ spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint(
+ rewriter, graphOp, index, true, abiInfo);
+ if (!var)
+ return failure();
+ interfaceVars.push_back(
+ SymbolRefAttr::get(rewriter.getContext(), var.getSymName()));
+ }
+
+ for (unsigned index = 0; index < numResults; ++index) {
+ auto abiInfo = graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(
+ index, attrName);
+ if (!abiInfo)
+ return failure();
+ spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint(
+ rewriter, graphOp, index, false, abiInfo);
+ if (!var)
+ return failure();
+ interfaceVars.push_back(
+ SymbolRefAttr::get(rewriter.getContext(), var.getSymName()));
+ }
+
+ // Update graph signature.
+ rewriter.modifyOpInPlace(graphOp, [&] {
+ for (unsigned index = 0; index < numInputs; ++index) {
+ graphOp.removeArgAttr(index, attrName);
+ }
+ for (unsigned index = 0; index < numResults; ++index) {
+ graphOp.removeResultAttr(index, rewriter.getStringAttr(attrName));
+ }
+ });
+
+ spirv::GraphEntryPointARMOp::create(rewriter, graphOp.getLoc(), graphOp,
+ interfaceVars);
+ return success();
+}
+
void LowerABIAttributesPass::runOnOperation() {
// Uses the signature conversion methodology of the dialect conversion
// framework to implement the conversion.
@@ -322,7 +421,8 @@
});
RewritePatternSet patterns(context);
- patterns.add<ProcessInterfaceVarABI>(typeConverter, context);
+ patterns.add<ProcessInterfaceVarABI, ProcessGraphInterfaceVarABI>(
+ typeConverter, context);
ConversionTarget target(*context);
// "Legal" function ops should have no interface variable ABI attributes.
@@ -333,6 +433,17 @@
return false;
return true;
});
+ target.addDynamicallyLegalOp<spirv::GraphARMOp>([&](spirv::GraphARMOp op) {
+ StringRef attrName = spirv::getInterfaceVarABIAttrName();
+ for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
+ if (op.getArgAttr(i, attrName))
+ return false;
+ for (unsigned i = 0, e = op.getNumResults(); i < e; ++i)
+ if (op.getResultAttr(i, attrName))
+ return false;
+ return true;
+ });
+
// All other SPIR-V ops are legal.
target.markUnknownOpDynamicallyLegal([](Operation *op) {
return op->getDialect()->getNamespace() ==
diff --git a/test/Dialect/SPIRV/IR/target-and-abi.mlir b/test/Dialect/SPIRV/IR/target-and-abi.mlir
index 10fbcf0..63dea6a 100644
--- a/test/Dialect/SPIRV/IR/target-and-abi.mlir
+++ b/test/Dialect/SPIRV/IR/target-and-abi.mlir
@@ -101,6 +101,14 @@
// -----
+// CHECK: {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
+// CHECK: {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
+spirv.ARM.Graph @interface_var(%arg: !spirv.arm.tensor<1xf32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (
+ !spirv.arm.tensor<1xf32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
+) { spirv.ARM.GraphOutputs %arg : !spirv.arm.tensor<1xf32> }
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.resource_limits
//===----------------------------------------------------------------------===//
diff --git a/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/test/Dialect/SPIRV/Transforms/abi-interface.mlir
index f3a3218..04667c8 100644
--- a/test/Dialect/SPIRV/Transforms/abi-interface.mlir
+++ b/test/Dialect/SPIRV/Transforms/abi-interface.mlir
@@ -35,6 +35,28 @@
// -----
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [VulkanMemoryModel, Shader, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: spirv.module
+spirv.module Logical Vulkan {
+ // CHECK-DAG: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x16x16x16xi8>, UniformConstant>
+ // CHECK-DAG: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<1x16x16x16xi8>, UniformConstant>
+
+ // CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]]
+ // CHECK: spirv.ARM.Graph [[GN]]([[ARG0:%.*]]: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> attributes {entry_point = true}
+ spirv.ARM.Graph @main(%arg0: !spirv.arm.tensor<1x16x16x16xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>})
+ -> (!spirv.arm.tensor<1x16x16x16xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+ spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
+ }
+} // end spirv.module
+
+} // end module
+
+// -----
+
module {
// expected-error@+1 {{'spirv.module' op missing SPIR-V target env attribute}}
spirv.module Logical GLSL450 {}