blob: e8ed4b339a0cba32b8a47d988361b3e3a3babe2b [file] [log] [blame]
//===- CallInterfaces.cpp - Call Interfaces -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/IR/Builders.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Argument and result attributes utilities
//===----------------------------------------------------------------------===//
/// Parses a comma-separated list of `type attr-dict?` entries. Each entry
/// appends exactly one element to `types` and one to `attrs`, so the two
/// vectors stay index-aligned. This helper is shared by the argument and the
/// result list parsers.
static ParseResult
parseTypeAndAttrList(OpAsmParser &parser, SmallVectorImpl<Type> &types,
                     SmallVectorImpl<DictionaryAttr> &attrs) {
  auto parseEntry = [&]() -> ParseResult {
    // Create both output slots before parsing so the vectors grow in lockstep,
    // even when parsing this entry fails part-way through.
    Type &entryType = types.emplace_back();
    DictionaryAttr &entryAttrs = attrs.emplace_back();

    NamedAttrList parsedAttrs;
    if (failed(parser.parseType(entryType)))
      return failure();
    if (failed(parser.parseOptionalAttrDict(parsedAttrs)))
      return failure();

    entryAttrs = parsedAttrs.getDictionary(parser.getContext());
    return success();
  };
  return parser.parseCommaSeparatedList(parseEntry);
}
/// Parses a function result list, appending to `resultTypes` and
/// `resultAttrs`. Two spellings are accepted:
///   * a parenthesized list `(` ... `)`, which may be empty, where every entry
///     is a type optionally followed by an attribute dictionary;
///   * a single bare type, for which a default-constructed (null)
///     DictionaryAttr is recorded.
ParseResult call_interface_impl::parseFunctionResultList(
    OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
    SmallVectorImpl<DictionaryAttr> &resultAttrs) {
  if (succeeded(parser.parseOptionalLParen())) {
    // `()` is an empty result list: nothing is appended.
    if (succeeded(parser.parseOptionalRParen()))
      return success();

    if (failed(parseTypeAndAttrList(parser, resultTypes, resultAttrs)))
      return failure();
    return parser.parseRParen();
  }

  // No `(` was seen, so exactly one type follows. Since a function type would
  // itself begin with `(`, the type parsed here cannot be a function type.
  Type singleResult;
  if (failed(parser.parseType(singleResult)))
    return failure();
  resultTypes.push_back(singleResult);
  resultAttrs.emplace_back();
  return success();
}
/// Parses a function signature of the form
///   `(` [type attr-dict? {`,` type attr-dict?}] `)` [`->` result-list]
/// appending the parsed argument types/attributes to `argTypes`/`argAttrs` and
/// the parsed result types/attributes to `resultTypes`/`resultAttrs`.
///
/// When `mustParseEmptyResult` is true the `->` result-list part is mandatory
/// (an empty result list must be spelled `-> ()`); a missing arrow is then
/// reported as a parse error. When it is false, an absent arrow simply means
/// there are no results.
ParseResult call_interface_impl::parseFunctionSignature(
    OpAsmParser &parser, SmallVectorImpl<Type> &argTypes,
    SmallVectorImpl<DictionaryAttr> &argAttrs,
    SmallVectorImpl<Type> &resultTypes,
    SmallVectorImpl<DictionaryAttr> &resultAttrs, bool mustParseEmptyResult) {
  // Parse arguments.
  if (parser.parseLParen())
    return failure();
  // An immediate `)` is an empty argument list; otherwise parse the entries
  // followed by the closing paren.
  if (failed(parser.parseOptionalRParen())) {
    if (parseTypeAndAttrList(parser, argTypes, argAttrs))
      return failure();
    if (parser.parseRParen())
      return failure();
  }

  // Parse results.
  if (succeeded(parser.parseOptionalArrow()))
    return call_interface_impl::parseFunctionResultList(parser, resultTypes,
                                                        resultAttrs);

  // No arrow is present. If one is required, use the non-optional parse so
  // that a diagnostic is emitted at the current location instead of failing
  // silently.
  if (mustParseEmptyResult)
    return parser.parseArrow();
  return success();
}
/// Prints a non-empty function result list. `attrs` must be either null or an
/// array of DictionaryAttrs with one entry per element of `types`.
///
/// The list is wrapped in parentheses when there is more than one result, when
/// the sole result is a function type, or when the sole result carries a
/// non-empty attribute dictionary.
static void printFunctionResultList(OpAsmPrinter &p, TypeRange types,
                                    ArrayAttr attrs) {
  assert(!types.empty() && "Should not be called for empty result list.");
  assert((!attrs || attrs.size() == types.size()) &&
         "Invalid number of attributes.");

  // Parentheses are the default; they are only dropped for a single,
  // non-function result without attributes.
  bool wrapInParens = true;
  if (types.size() <= 1 && !llvm::isa<FunctionType>(types[0]))
    wrapInParens = attrs && !llvm::cast<DictionaryAttr>(attrs[0]).empty();

  auto &stream = p.getStream();
  if (wrapInParens)
    stream << '(';

  const size_t numResults = types.size();
  for (size_t index = 0; index != numResults; ++index) {
    if (index != 0)
      stream << ", ";
    p.printType(types[index]);
    if (attrs)
      p.printOptionalAttrDict(
          llvm::cast<DictionaryAttr>(attrs[index]).getValue());
  }

  if (wrapInParens)
    stream << ')';
}
/// Prints a function signature: the argument list, an optional trailing
/// `...` when `isVariadic` is set, and the result list.
///
/// * If `body` is non-null and non-empty, arguments are printed as region
///   arguments taken from `body`; otherwise only their types (plus any
///   attributes from `argAttrs`) are printed.
/// * If there are no results, ` -> ()` is printed only when
///   `printEmptyResult` is set.
/// * When a body is present, nothing is variadic, no argument/result
///   attributes exist, and `printEmptyResult` is set, the whole signature is
///   emitted through `printFunctionalType` instead.
void call_interface_impl::printFunctionSignature(
    OpAsmPrinter &p, TypeRange argTypes, ArrayAttr argAttrs, bool isVariadic,
    TypeRange resultTypes, ArrayAttr resultAttrs, Region *body,
    bool printEmptyResult) {
  const bool hasBody = body && !body->empty();

  // Shortcut: plain functional-type spelling.
  if (hasBody && printEmptyResult && !isVariadic && !argAttrs &&
      !resultAttrs) {
    p.printFunctionalType(argTypes, resultTypes);
    return;
  }

  // Attributes recorded for the argument at `index`; empty when the signature
  // carries no argument attributes at all.
  auto attrsForArg = [&](unsigned index) -> ArrayRef<NamedAttribute> {
    if (!argAttrs)
      return {};
    return llvm::cast<DictionaryAttr>(argAttrs[index]).getValue();
  };

  p << '(';
  const unsigned numArgs = argTypes.size();
  for (unsigned index = 0; index != numArgs; ++index) {
    if (index != 0)
      p << ", ";

    if (hasBody) {
      p.printRegionArgument(body->getArgument(index), attrsForArg(index));
      continue;
    }

    // Without a body there are no region arguments: print the bare type.
    p.printType(argTypes[index]);
    if (argAttrs)
      p.printOptionalAttrDict(attrsForArg(index));
  }

  if (isVariadic) {
    if (!argTypes.empty())
      p << ", ";
    p << "...";
  }
  p << ')';

  if (resultTypes.empty()) {
    if (printEmptyResult)
      p << " -> ()";
    return;
  }

  p << " -> ";
  printFunctionResultList(p, resultTypes, resultAttrs);
}
/// Attaches argument and result attribute dictionaries to `result`.
///
/// For each of `argAttrs` (stored under `argAttrsName`) and `resultAttrs`
/// (stored under `resAttrsName`), an ArrayAttr is added only if at least one
/// dictionary in the list is non-null and non-empty. When added, null entries
/// are replaced by empty dictionaries so that the array has one
/// DictionaryAttr per input entry.
void call_interface_impl::addArgAndResultAttrs(
    Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
    ArrayRef<DictionaryAttr> resultAttrs, StringAttr argAttrsName,
    StringAttr resAttrsName) {
  auto attachIfMeaningful = [&](StringAttr attrName,
                                ArrayRef<DictionaryAttr> dictAttrs) {
    const bool hasContent = llvm::any_of(dictAttrs, [](DictionaryAttr dict) {
      return dict && !dict.empty();
    });
    if (!hasContent)
      return;

    // Normalize null entries to empty dictionaries.
    SmallVector<Attribute> normalized;
    normalized.reserve(dictAttrs.size());
    for (DictionaryAttr dict : dictAttrs) {
      if (!dict)
        dict = builder.getDictionaryAttr({});
      normalized.push_back(dict);
    }
    result.addAttribute(attrName, builder.getArrayAttr(normalized));
  };

  // Arguments first, then results.
  attachIfMeaningful(argAttrsName, argAttrs);
  attachIfMeaningful(resAttrsName, resultAttrs);
}
/// Convenience overload taking parsed arguments: collects the `attrs` field of
/// every entry in `args` (in order) and forwards to the overload that accepts
/// a list of argument dictionaries.
void call_interface_impl::addArgAndResultAttrs(
    Builder &builder, OperationState &result,
    ArrayRef<OpAsmParser::Argument> args, ArrayRef<DictionaryAttr> resultAttrs,
    StringAttr argAttrsName, StringAttr resAttrsName) {
  SmallVector<DictionaryAttr> collectedArgAttrs;
  collectedArgAttrs.reserve(args.size());
  for (const OpAsmParser::Argument &parsedArg : args)
    collectedArgAttrs.push_back(parsedArg.attrs);

  addArgAndResultAttrs(builder, result, collectedArgAttrs, resultAttrs,
                       argAttrsName, resAttrsName);
}
//===----------------------------------------------------------------------===//
// CallOpInterface
//===----------------------------------------------------------------------===//
/// Resolves the operation referred to by the callee of `call`.
///
/// If the callee is an SSA value, the result is that value's defining
/// operation. Otherwise the callee is a symbol reference, which is looked up
/// starting from the call operation, going through `symbolTable` when one is
/// provided and through the static SymbolTable lookup otherwise.
Operation *
call_interface_impl::resolveCallable(CallOpInterface call,
                                     SymbolTableCollection *symbolTable) {
  CallInterfaceCallable callee = call.getCallableForCallee();

  // Value callee: hand back whatever defines the value.
  if (auto calleeValue = dyn_cast<Value>(callee))
    return calleeValue.getDefiningOp();

  // Symbol callee: resolve relative to the call site.
  auto calleeSymbol = cast<SymbolRefAttr>(callee);
  Operation *callOp = call.getOperation();
  return symbolTable
             ? symbolTable->lookupNearestSymbolFrom(callOp, calleeSymbol)
             : SymbolTable::lookupNearestSymbolFrom(callOp, calleeSymbol);
}
//===----------------------------------------------------------------------===//
// CallInterfaces
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/CallInterfaces.cpp.inc"