blob: ba69fa75cf2b89b2ce7e1ffa113d6d91c06a21ac [file] [edit]
//===- SPIRVOpDefinition.cpp - MLIR SPIR-V Op Definition Implementation ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the TableGen'erated SPIR-V op implementation in the SPIR-V dialect.
// These are placed in a separate file to reduce the total amount of code in
// SPIRVOps.cpp and make that file faster to recompile.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "SPIRVParsingUtils.h"
#include "mlir/IR/TypeUtilities.h"
namespace mlir::spirv {
/// Walks from `op` outwards through its ancestors and reports whether a
/// FunctionOpInterface op is reached before any op carrying the SymbolTable
/// trait (a module-like op). `op` itself is checked first; a null `op` yields
/// false.
static bool isNestedInFunctionOpInterface(Operation *op) {
  for (Operation *current = op; current != nullptr;
       current = current->getParentOp()) {
    // A symbol-table op in between cuts the search off.
    if (current->hasTrait<OpTrait::SymbolTable>())
      return false;
    if (isa<FunctionOpInterface>(current))
      return true;
  }
  return false;
}
/// Walks from `op` outwards through its ancestors and reports whether a
/// spirv::GraphARMOp is reached before any op carrying the SymbolTable trait
/// (a module-like op). `op` itself is checked first; a null `op` yields false.
static bool isNestedInGraphARMOpInterface(Operation *op) {
  for (Operation *current = op; current != nullptr;
       current = current->getParentOp()) {
    // A symbol-table op in between cuts the search off.
    if (current->hasTrait<OpTrait::SymbolTable>())
      return false;
    if (isa<spirv::GraphARMOp>(current))
      return true;
  }
  return false;
}
/// Returns true if `op` is non-null and is a module-like op, i.e. it carries
/// the SymbolTable trait. Only `op` itself is inspected, not its ancestors.
/// NOTE(review): given the name, callers presumably pass the parent of the op
/// being checked — confirm at the call sites.
static bool isDirectInModuleLikeOp(Operation *op) {
  if (op == nullptr)
    return false;
  return op->hasTrait<OpTrait::SymbolTable>();
}
/// Computes the boolean result type of a logical op from its operand type.
/// A scalar operand yields `i1`. A vector operand yields a vector of `i1` with
/// the same number of elements.
/// NOTE(review): only the element count is carried over, so a multi-dimensional
/// or scalable vector operand would produce a fixed 1-D result. This is
/// presumably fine since SPIR-V vectors are 1-D — confirm.
static Type getUnaryOpResultType(Type operandType) {
  Type boolType = Builder(operandType.getContext()).getIntegerType(1);
  auto vectorOperand = dyn_cast<VectorType>(operandType);
  if (!vectorOperand)
    return boolType;
  return VectorType::get(vectorOperand.getNumElements(), boolType);
}
/// Parses an optional `[ "<image-operands>" ]` group into `attr`.
/// If no `[` is present, nothing is consumed, `attr` is left as-is, and
/// parsing succeeds. Otherwise the enum string is parsed via parseEnumStrAttr
/// and the closing `]` is required.
static ParseResult parseImageOperands(OpAsmParser &parser,
                                      spirv::ImageOperandsAttr &attr) {
  if (failed(parser.parseOptionalLSquare()))
    return success();

  spirv::ImageOperands parsedOperands;
  if (failed(parseEnumStrAttr(parsedOperands, parser)))
    return failure();

  attr = spirv::ImageOperandsAttr::get(parser.getContext(), parsedOperands);
  return parser.parseRSquare();
}
/// Prints `attr` as `["<image-operands>"]`, the form accepted by
/// parseImageOperands. Prints nothing when `attr` is null. `imageOp` is
/// unused.
static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
                               spirv::ImageOperandsAttr attr) {
  if (!attr)
    return;
  printer << "[\"" << stringifyImageOperands(attr.getValue()) << "\"]";
}
/// Adapted from the cf.switch implementation.
/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
///             ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
///
/// Parses the default successor (with optional operands) followed by any
/// number of `literal: successor` cases. For each case, the successor, its
/// operands and its operand types are appended to `targets`, `targetOperands`
/// and `targetOperandTypes`. When at least one case is present, `literals` is
/// set to a vector attribute of the case values whose element type is
/// `selectorType`. With no cases, `literals` is left untouched.
///
/// `selectorType` must be an integer type; anything else is a parse error.
/// Each case literal must fit the selector's bit width either as a signed
/// value or as a non-negative unsigned value. The unsigned form is also what
/// printSwitchOpCases emits for negative cases, because it prints the
/// zero-extended bit pattern. Out-of-range literals are diagnosed instead of
/// being silently truncated, or asserting inside the APInt constructor.
static ParseResult parseSwitchOpCases(
    OpAsmParser &parser, Type &selectorType, Block *&defaultTarget,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands,
    SmallVectorImpl<Type> &defaultOperandTypes, DenseIntElementsAttr &literals,
    SmallVectorImpl<Block *> &targets,
    SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>>
        &targetOperands,
    SmallVectorImpl<SmallVector<Type>> &targetOperandTypes) {
  llvm::SMLoc casesLoc = parser.getCurrentLocation();
  if (parser.parseKeyword("default") || parser.parseColon() ||
      parser.parseSuccessor(defaultTarget))
    return failure();
  if (succeeded(parser.parseOptionalLParen())) {
    if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
                                /*allowResultNumber=*/false) ||
        parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
      return failure();
  }

  // The case literals are stored as integers of the selector type. A
  // non-integer selector would otherwise hit the assertion in
  // getIntOrFloatBitWidth() (or in DenseIntElementsAttr::get) while parsing
  // user input, so reject it with a diagnostic instead.
  auto selectorIntType = dyn_cast<IntegerType>(selectorType);
  if (!selectorIntType)
    return parser.emitError(casesLoc,
                            "expected integer selector type, but found ")
           << selectorType;
  unsigned bitWidth = selectorIntType.getWidth();

  SmallVector<APInt> values;
  while (succeeded(parser.parseOptionalComma())) {
    llvm::SMLoc valueLoc = parser.getCurrentLocation();
    int64_t value = 0;
    if (failed(parser.parseInteger(value)))
      return failure();

    // Accept the literal if it is representable in `bitWidth` bits as either
    // a signed or an unsigned value; otherwise it would be silently truncated.
    bool fitsSigned = llvm::isIntN(bitWidth, value);
    bool fitsUnsigned =
        value >= 0 && llvm::isUIntN(bitWidth, static_cast<uint64_t>(value));
    if (!fitsSigned && !fitsUnsigned)
      return parser.emitError(valueLoc, "case value ")
             << value << " does not fit in the " << bitWidth
             << "-bit selector type";
    // Choose the signedness the value actually satisfies so the APInt
    // constructor never has to truncate; the resulting bits are identical to
    // the previous signed construction for every in-range literal.
    values.push_back(APInt(bitWidth, static_cast<uint64_t>(value),
                           /*isSigned=*/fitsSigned));

    Block *target = nullptr;
    SmallVector<OpAsmParser::UnresolvedOperand> operands;
    SmallVector<Type> operandTypes;
    if (failed(parser.parseColon()) || failed(parser.parseSuccessor(target)))
      return failure();
    if (succeeded(parser.parseOptionalLParen())) {
      if (failed(parser.parseOperandList(operands,
                                         OpAsmParser::Delimiter::None)) ||
          failed(parser.parseColonTypeList(operandTypes)) ||
          failed(parser.parseRParen()))
        return failure();
    }
    targets.push_back(target);
    targetOperands.emplace_back(operands);
    targetOperandTypes.emplace_back(operandTypes);
  }

  if (!values.empty()) {
    ShapedType literalType =
        VectorType::get(static_cast<int64_t>(values.size()), selectorType);
    literals = DenseIntElementsAttr::get(literalType, values);
  }
  return success();
}
/// Prints the `default: ^bb(...)` successor and then, when `literals` is
/// non-null, one `literal: ^bb(...)` line per case, ending with a newline.
/// The literal is printed via APInt::getLimitedValue(), so it is shown as a
/// zero-extended unsigned value (e.g. a negative case shows its unsigned bit
/// pattern). `op`, `selectorType` and the operand-type ranges are unused here;
/// presumably they are required by the custom directive's signature.
static void
printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type selectorType,
                   Block *defaultTarget, OperandRange defaultOperands,
                   TypeRange defaultOperandTypes, DenseIntElementsAttr literals,
                   SuccessorRange targets, OperandRangeRange targetOperands,
                   const TypeRangeRange &targetOperandTypes) {
  p << " default: ";
  p.printSuccessorAndUseList(defaultTarget, defaultOperands);
  if (literals) {
    unsigned caseIdx = 0;
    for (APInt caseValue : literals.getValues<APInt>()) {
      p << ',';
      p.printNewline();
      p << " " << caseValue.getLimitedValue() << ": ";
      p.printSuccessorAndUseList(targets[caseIdx], targetOperands[caseIdx]);
      ++caseIdx;
    }
    p.printNewline();
  }
}
} // namespace mlir::spirv
// TableGen'erated operation definitions.
#define GET_OP_CLASSES
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"