blob: 47fe4d9c5b21c35f7f67e020ec3686da366879f6 [file] [edit]
//===- ArmGraphOps.cpp - MLIR SPIR-V SPV_ARM_graph operations -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the SPV_ARM_graph operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "SPIRVParsingUtils.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "llvm/Support/InterleavedRange.h"
using namespace mlir;
using namespace mlir::spirv::AttrNames;
//===----------------------------------------------------------------------===//
// spirv.GraphARM
//===----------------------------------------------------------------------===//
/// Parses the custom assembly form of a graph op: a symbol name, a
/// non-variadic function-style signature, an optional keyword-introduced
/// attribute dictionary, and an optional body region whose entry block
/// arguments are the parsed signature arguments.
ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  Builder &builder = parser.getBuilder();

  // The graph name is stored as the symbol attribute of the op.
  StringAttr symbolName;
  if (parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  // Signature: arguments (with optional attributes) and result types.
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<Type> resultTypes;
  SmallVector<DictionaryAttr> resultAttrs;
  bool isVariadic = false;
  if (function_interface_impl::parseFunctionSignatureWithArguments(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs))
    return failure();

  // Collect the argument types to build the graph type attribute.
  SmallVector<Type> inputTypes;
  inputTypes.reserve(entryArgs.size());
  for (const OpAsmParser::Argument &entryArg : entryArgs)
    inputTypes.push_back(entryArg.type);
  GraphType graphType = builder.getGraphType(inputTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(graphType));

  // Any extra attributes follow the signature.
  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  // Attach the per-argument and per-result attribute dictionaries.
  assert(resultAttrs.size() == resultTypes.size());
  call_interface_impl::addArgAndResultAttrs(
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  // The body is optional; only a present-but-malformed region is an error.
  Region *bodyRegion = result.addRegion();
  OptionalParseResult bodyResult =
      parser.parseOptionalRegion(*bodyRegion, entryArgs);
  if (bodyResult.has_value() && failed(*bodyResult))
    return failure();
  return success();
}
/// Prints the custom assembly form of a graph op: symbol name, signature,
/// remaining attributes, and the body region when one is present.
void spirv::GraphARMOp::print(OpAsmPrinter &printer) {
  // Graph name and signature.
  printer << " ";
  printer.printSymbolName(getSymName());

  GraphType graphType = getFunctionType();
  function_interface_impl::printFunctionSignature(
      printer, *this, graphType.getInputs(),
      /*isVariadic=*/false, graphType.getResults());

  // The type and the argument/result attributes are passed as the elided
  // attribute names for the attribute printer.
  function_interface_impl::printFunctionAttributes(printer, *this,
                                                   {getFunctionTypeAttrName(),
                                                    getArgAttrsAttrName(),
                                                    getResAttrsAttrName()});

  // Nothing more to print when the body region is empty.
  Region &bodyRegion = this->getBody();
  if (bodyRegion.empty())
    return;

  printer << ' ';
  printer.printRegion(bodyRegion, /*printEntryBlockArgs=*/false,
                      /*printBlockTerminators=*/true);
}
/// Verifies the graph type: a graph must declare at least one result.
LogicalResult spirv::GraphARMOp::verifyType() {
  if (getFunctionType().getNumResults() >= 1)
    return success();
  return emitOpError("there should be at least one result");
}
/// Verifies the graph signature against its body:
///   - every argument type and every result type is a TensorArmType;
///   - for a non-external graph, the entry block arguments match the
///     signature in number and in type;
///   - every nested spirv::GraphOutputsARMOp returns values that match the
///     graph result types in number and in type.
LogicalResult spirv::GraphARMOp::verifyBody() {
  ArrayRef<Type> signatureArgTypes = getArgumentTypes();
  ArrayRef<Type> signatureResTypes = getResultTypes();

  // All graph arguments must be tensors.
  for (size_t argIdx = 0, numArgs = signatureArgTypes.size(); argIdx != numArgs;
       ++argIdx) {
    Type argType = signatureArgTypes[argIdx];
    if (isa<spirv::TensorArmType>(argType))
      continue;
    return emitOpError("type of argument #")
           << argIdx << " must be a TensorArmType, but got " << argType;
  }

  // All graph results must be tensors.
  for (size_t resIdx = 0, numRes = signatureResTypes.size(); resIdx != numRes;
       ++resIdx) {
    Type resType = signatureResTypes[resIdx];
    if (isa<spirv::TensorArmType>(resType))
      continue;
    return emitOpError("type of result #")
           << resIdx << " must be a TensorArmType, but got " << resType;
  }

  // A graph with a body must have an entry block mirroring the signature.
  if (!isExternal()) {
    Block &entry = front();
    unsigned expectedArgCount = this->getNumArguments();
    if (entry.getNumArguments() != expectedArgCount)
      return emitOpError("entry block must have ")
             << expectedArgCount << " arguments to match graph signature";

    for (size_t argIdx = 0, numArgs = signatureArgTypes.size();
         argIdx != numArgs; ++argIdx) {
      Type signatureType = signatureArgTypes[argIdx];
      Type blockArgType = entry.getArgument(argIdx).getType();
      if (blockArgType == signatureType)
        continue;
      return emitOpError("type of entry block argument #")
             << argIdx << '(' << blockArgType
             << ") must match the type of the corresponding argument in "
             << "graph signature(" << signatureType << ')';
    }
  }

  // Each graph-outputs op inside the body must agree with the result types.
  GraphType graphType = getFunctionType();
  WalkResult outputsCheck =
      walk([graphType](spirv::GraphOutputsARMOp outputsOp) -> WalkResult {
        if (graphType.getNumResults() != outputsOp.getNumOperands())
          return outputsOp.emitOpError("is returning ")
                 << outputsOp.getNumOperands()
                 << " value(s) but enclosing spirv.ARM.Graph requires "
                 << graphType.getNumResults() << " result(s)";

        ValueTypeRange<OperandRange> returnedTypes =
            outputsOp.getValue().getType();
        size_t resultIdx = 0;
        for (Type returnedType : returnedTypes) {
          if (returnedType != graphType.getResult(resultIdx))
            return outputsOp.emitError("type of return operand ")
                   << resultIdx << " (" << returnedType
                   << ") doesn't match graph result type ("
                   << graphType.getResult(resultIdx) << ")";
          ++resultIdx;
        }
        return WalkResult::advance();
      });

  if (outputsCheck.wasInterrupted())
    return failure();
  return success();
}
/// Populates `state` for a graph op named `name` with graph type `type`.
/// The caller-supplied `attrs` are appended after the symbol name and type
/// attributes, then the entry-point flag is added, and finally one region
/// is added to the state.
void spirv::GraphARMOp::build(OpBuilder &builder, OperationState &state,
                              StringRef name, GraphType type,
                              ArrayRef<NamedAttribute> attrs, bool entryPoint) {
  // Symbol name.
  StringAttr symbolName = builder.getStringAttr(name);
  state.addAttribute(SymbolTable::getSymbolAttrName(), symbolName);

  // Graph type.
  TypeAttr graphTypeAttr = TypeAttr::get(type);
  state.addAttribute(getFunctionTypeAttrName(state.name), graphTypeAttr);

  // Caller-provided attributes.
  state.attributes.append(attrs);

  // Entry-point flag.
  BoolAttr entryPointAttr = builder.getBoolAttr(entryPoint);
  state.addAttribute(getEntryPointAttrName(state.name), entryPointAttr);

  // The body region.
  state.addRegion();
}
/// Returns the input types of the graph type of this op.
ArrayRef<Type> spirv::GraphARMOp::getArgumentTypes() {
  GraphType graphType = getFunctionType();
  return graphType.getInputs();
}
/// Returns the result types of the graph type of this op.
ArrayRef<Type> spirv::GraphARMOp::getResultTypes() {
  GraphType graphType = getFunctionType();
  return graphType.getResults();
}
/// Returns the body region, or nullptr when the graph is external.
Region *spirv::GraphARMOp::getCallableRegion() {
  if (isExternal())
    return nullptr;
  return &getBody();
}
//===----------------------------------------------------------------------===//
// spirv.GraphOutputsARM
//===----------------------------------------------------------------------===//
/// Verifies that the operands of this op match, in number and in type, the
/// result types declared by the enclosing graph op.
LogicalResult spirv::GraphOutputsARMOp::verify() {
  auto parentGraph = cast<GraphARMOp>((*this)->getParentOp());

  // Operand count must equal the number of graph results.
  ArrayRef<Type> expectedTypes = parentGraph.getFunctionType().getResults();
  if (getNumOperands() != expectedTypes.size())
    return emitOpError("has ")
           << getNumOperands() << " operands, but enclosing spirv.ARM.Graph (@"
           << parentGraph.getName() << ") returns " << expectedTypes.size();

  // Each operand type must equal the result type at the same position.
  for (size_t operandIdx = 0, numExpected = expectedTypes.size();
       operandIdx != numExpected; ++operandIdx) {
    Type expectedType = expectedTypes[operandIdx];
    if (getOperand(operandIdx).getType() == expectedType)
      continue;
    return emitError() << "type of return operand " << operandIdx << " ("
                       << getOperand(operandIdx).getType()
                       << ") doesn't match spirv.ARM.Graph result type ("
                       << expectedType << ")"
                       << " in graph @" << parentGraph.getName();
  }

  return success();
}
//===----------------------------------------------------------------------===//
// spirv.GraphEntryPointARM
//===----------------------------------------------------------------------===//
/// Convenience builder: forwards to the build overload that takes a symbol
/// reference to `graph` and an array attribute holding `interfaceVars`.
void spirv::GraphEntryPointARMOp::build(OpBuilder &builder,
                                        OperationState &state,
                                        spirv::GraphARMOp graph,
                                        ArrayRef<Attribute> interfaceVars) {
  auto graphRef = SymbolRefAttr::get(graph);
  auto interfaceAttr = builder.getArrayAttr(interfaceVars);
  build(builder, state, graphRef, interfaceAttr);
}
/// Parses a graph entry point: a symbol reference to the graph, optionally
/// followed by a comma and a comma-separated list of interface variable
/// symbol references. The `interface` attribute is always added, and holds
/// an empty array when no interface variables were given.
ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser,
                                               OperationState &result) {
  // Referenced graph symbol.
  FlatSymbolRefAttr graphRef;
  if (parser.parseAttribute(graphRef, Type(), kFnNameAttrName,
                            result.attributes))
    return failure();

  SmallVector<Attribute, 4> interfaceVars;

  // Parses one interface variable reference. The attribute name used while
  // parsing does not matter: it lands in a scratch list that is discarded.
  auto parseInterfaceVar = [&]() -> ParseResult {
    FlatSymbolRefAttr varRef;
    NamedAttrList scratchAttrs;
    if (parser.parseAttribute(varRef, Type(), "var_symbol", scratchAttrs))
      return failure();
    interfaceVars.push_back(varRef);
    return success();
  };

  // The interface list is only parsed when a comma follows the graph symbol.
  if (succeeded(parser.parseOptionalComma()) &&
      parser.parseCommaSeparatedList(parseInterfaceVar))
    return failure();

  ArrayAttr interfaceAttr = parser.getBuilder().getArrayAttr(interfaceVars);
  result.addAttribute("interface", interfaceAttr);
  return success();
}
/// Prints a graph entry point: the graph symbol, followed by the interface
/// variables as a comma-introduced list when there are any.
void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) {
  printer << " ";
  printer.printSymbolName(getFn());

  // An empty interface list prints nothing.
  ArrayRef<Attribute> interfaceVars = getInterface().getValue();
  if (interfaceVars.empty())
    return;

  printer << ", " << llvm::interleaved(interfaceVars);
}