blob: 424bbae179c3352d2f36b63d8a7729695465341e [file] [log] [blame]
//===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Utils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser.h"
#include "llvm/Support/Debug.h"
#include <cstddef>
using namespace mlir;
//===----------------------------------------------------------------------===//
// Context API.
//===----------------------------------------------------------------------===//
// Heap-allocates a new MLIRContext and returns it wrapped as a C handle.
MlirContext mlirContextCreate() { return wrap(new MLIRContext); }
// Two handles are equal when they wrap the same MLIRContext pointer.
bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) {
  MLIRContext *lhs = unwrap(ctx1);
  MLIRContext *rhs = unwrap(ctx2);
  return lhs == rhs;
}
// Deletes the MLIRContext behind the handle.
void mlirContextDestroy(MlirContext context) {
  MLIRContext *cppContext = unwrap(context);
  delete cppContext;
}
// Forwards `allow` to MLIRContext::allowUnregisteredDialects.
void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) {
  MLIRContext *cppContext = unwrap(context);
  cppContext->allowUnregisteredDialects(allow);
}
// Reports the value of MLIRContext::allowsUnregisteredDialects.
bool mlirContextGetAllowUnregisteredDialects(MlirContext context) {
  MLIRContext *cppContext = unwrap(context);
  return cppContext->allowsUnregisteredDialects();
}
// Returns the number of entries in MLIRContext::getAvailableDialects().
intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) {
  MLIRContext *cppContext = unwrap(context);
  auto numAvailable = cppContext->getAvailableDialects().size();
  return static_cast<intptr_t>(numAvailable);
}
// Returns the number of entries in MLIRContext::getLoadedDialects().
// TODO: expose a cheaper way than constructing + sorting a vector only to take
// its size.
intptr_t mlirContextGetNumLoadedDialects(MlirContext context) {
  MLIRContext *cppContext = unwrap(context);
  auto numLoaded = cppContext->getLoadedDialects().size();
  return static_cast<intptr_t>(numLoaded);
}
// Looks up (or loads) the dialect called `name` through
// MLIRContext::getOrLoadDialect and wraps the returned pointer.
MlirDialect mlirContextGetOrLoadDialect(MlirContext context,
                                        MlirStringRef name) {
  MLIRContext *cppContext = unwrap(context);
  Dialect *dialect = cppContext->getOrLoadDialect(unwrap(name));
  return wrap(dialect);
}
// Returns the answer of MLIRContext::isOperationRegistered for `name`.
bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) {
  MLIRContext *cppContext = unwrap(context);
  return cppContext->isOperationRegistered(unwrap(name));
}
// Forwards `enable` to MLIRContext::enableMultithreading.
void mlirContextEnableMultithreading(MlirContext context, bool enable) {
  MLIRContext *cppContext = unwrap(context);
  cppContext->enableMultithreading(enable);
}
//===----------------------------------------------------------------------===//
// Dialect API.
//===----------------------------------------------------------------------===//
// Returns the context reported by Dialect::getContext for this dialect.
MlirContext mlirDialectGetContext(MlirDialect dialect) {
  Dialect *cppDialect = unwrap(dialect);
  return wrap(cppDialect->getContext());
}
// Two handles are equal when they wrap the same Dialect pointer.
bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) {
  Dialect *lhs = unwrap(dialect1);
  Dialect *rhs = unwrap(dialect2);
  return lhs == rhs;
}
// Returns the string produced by Dialect::getNamespace, wrapped for C.
MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
  Dialect *cppDialect = unwrap(dialect);
  StringRef dialectNamespace = cppDialect->getNamespace();
  return wrap(dialectNamespace);
}
//===----------------------------------------------------------------------===//
// Printing flags API.
//===----------------------------------------------------------------------===//
// Heap-allocates a default-constructed OpPrintingFlags and wraps it. The
// matching release call in this file is mlirOpPrintingFlagsDestroy.
MlirOpPrintingFlags mlirOpPrintingFlagsCreate() {
  OpPrintingFlags *cppFlags = new OpPrintingFlags();
  return wrap(cppFlags);
}
// Deletes the OpPrintingFlags object behind the handle.
void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) {
  OpPrintingFlags *cppFlags = unwrap(flags);
  delete cppFlags;
}
// Forwards `largeElementLimit` to OpPrintingFlags::elideLargeElementsAttrs.
void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags,
                                                intptr_t largeElementLimit) {
  OpPrintingFlags *cppFlags = unwrap(flags);
  cppFlags->elideLargeElementsAttrs(largeElementLimit);
}
// Calls OpPrintingFlags::enableDebugInfo, passing `prettyForm` through.
void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags,
                                        bool prettyForm) {
  OpPrintingFlags *cppFlags = unwrap(flags);
  cppFlags->enableDebugInfo(/*prettyForm=*/prettyForm);
}
// Calls OpPrintingFlags::printGenericOpForm on the wrapped flags.
void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) {
  OpPrintingFlags *cppFlags = unwrap(flags);
  cppFlags->printGenericOpForm();
}
// Calls OpPrintingFlags::useLocalScope on the wrapped flags.
void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) {
  OpPrintingFlags *cppFlags = unwrap(flags);
  cppFlags->useLocalScope();
}
//===----------------------------------------------------------------------===//
// Location API.
//===----------------------------------------------------------------------===//
// Builds a FileLineColLoc from the given file name, line and column in
// `context` and returns it as a generic location handle.
MlirLocation mlirLocationFileLineColGet(MlirContext context,
                                        MlirStringRef filename, unsigned line,
                                        unsigned col) {
  MLIRContext *cppContext = unwrap(context);
  Location loc(FileLineColLoc::get(cppContext, unwrap(filename), line, col));
  return wrap(loc);
}
// Builds a CallSiteLoc from the `callee` and `caller` locations.
MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) {
  Location calleeLoc = unwrap(callee);
  Location callerLoc = unwrap(caller);
  Location callSite(CallSiteLoc::get(calleeLoc, callerLoc));
  return wrap(callSite);
}
// Builds a fused location out of `nLocations` locations and `metadata`.
// `storage` backs the unwrapped list for the duration of the call.
MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
                                  MlirLocation const *locations,
                                  MlirAttribute metadata) {
  SmallVector<Location, 4> storage;
  return wrap(FusedLoc::get(unwrapList(nLocations, locations, storage),
                            unwrap(metadata), unwrap(ctx)));
}
// Builds a NameLoc called `name`. The child location is only passed to
// NameLoc::get when `childLoc` is not the null handle.
MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name,
                                 MlirLocation childLoc) {
  auto nameAttr = StringAttr::get(unwrap(context), unwrap(name));
  if (!mlirLocationIsNull(childLoc))
    return wrap(Location(NameLoc::get(nameAttr, unwrap(childLoc))));
  return wrap(Location(NameLoc::get(nameAttr)));
}
// Returns the UnknownLoc of `context` as a generic location handle.
MlirLocation mlirLocationUnknownGet(MlirContext context) {
  Location unknown(UnknownLoc::get(unwrap(context)));
  return wrap(unknown);
}
// Compares the two unwrapped Location values with operator==.
bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) {
  Location lhs = unwrap(l1);
  Location rhs = unwrap(l2);
  return lhs == rhs;
}
// Returns the context reported by the location's getContext().
MlirContext mlirLocationGetContext(MlirLocation location) {
  Location cppLocation = unwrap(location);
  return wrap(cppLocation.getContext());
}
// Prints the location into a stream that forwards its output to `callback`
// together with `userData`.
void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
                       void *userData) {
  Location cppLocation = unwrap(location);
  detail::CallbackOstream os(callback, userData);
  cppLocation.print(os);
}
//===----------------------------------------------------------------------===//
// Module API.
//===----------------------------------------------------------------------===//
// Creates a module at `location` via ModuleOp::create and wraps it.
MlirModule mlirModuleCreateEmpty(MlirLocation location) {
  ModuleOp created = ModuleOp::create(unwrap(location));
  return wrap(created);
}
// Parses `module` as MLIR source in `context`. Returns a null module handle
// when parsing yields no module; otherwise releases the parsed module from its
// owning reference and returns it.
MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) {
  OwningModuleRef parsed = parseSourceString(unwrap(module), unwrap(context));
  if (parsed)
    return MlirModule{parsed.release().getOperation()};
  return MlirModule{nullptr};
}
// Returns the context reported by the module's getContext().
MlirContext mlirModuleGetContext(MlirModule module) {
  ModuleOp cppModule = unwrap(module);
  return wrap(cppModule.getContext());
}
// Returns the block reported by ModuleOp::getBody.
MlirBlock mlirModuleGetBody(MlirModule module) {
  ModuleOp cppModule = unwrap(module);
  return wrap(cppModule.getBody());
}
// Destroys the module by handing it to an OwningModuleRef whose destructor
// runs when this function returns.
void mlirModuleDestroy(MlirModule module) {
  OwningModuleRef owner(unwrap(module));
  (void)owner;
}
// Returns the module viewed as a generic operation handle.
MlirOperation mlirModuleGetOperation(MlirModule module) {
  ModuleOp cppModule = unwrap(module);
  return wrap(cppModule.getOperation());
}
// Converts an operation handle to a module handle using dyn_cast<ModuleOp>;
// the wrapped result of that cast is returned as is.
MlirModule mlirModuleFromOperation(MlirOperation op) {
  Operation *cppOp = unwrap(op);
  ModuleOp asModule = dyn_cast<ModuleOp>(cppOp);
  return wrap(asModule);
}
//===----------------------------------------------------------------------===//
// Operation state API.
//===----------------------------------------------------------------------===//
// Produces an operation state for an operation called `name` at `loc` with
// every list empty and result type inference turned off.
MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) {
  MlirOperationState state;

  // Identity of the operation to build.
  state.name = name;
  state.location = loc;

  // All element counts start at zero...
  state.nResults = 0;
  state.nOperands = 0;
  state.nRegions = 0;
  state.nSuccessors = 0;
  state.nAttributes = 0;

  // ...and no array has been allocated yet.
  state.results = nullptr;
  state.operands = nullptr;
  state.regions = nullptr;
  state.successors = nullptr;
  state.attributes = nullptr;

  state.enableResultTypeInference = false;
  return state;
}
// Appends the `n` elements pointed to by the local pointer `elemName` to the
// realloc-managed array `state->elemName` and bumps `state->sizeName` by `n`.
// Expects `state`, `n` and a pointer named `elemName` to be in scope at the
// expansion site.
//
// Appending zero (or a negative number of) elements is a no-op: this avoids
// calling realloc with a zero byte count and memcpy with a source pointer that
// may be null. The body is wrapped in do/while so that the expansion behaves
// as a single statement.
#define APPEND_ELEMS(type, sizeName, elemName)                                 \
  do {                                                                         \
    if (n > 0) {                                                               \
      state->elemName = static_cast<type *>(realloc(                           \
          state->elemName, (state->sizeName + n) * sizeof(type)));             \
      memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type));   \
      state->sizeName += n;                                                    \
    }                                                                          \
  } while (false)
// Appends the `n` result types in `results` to `state->results` and increases
// `state->nResults` accordingly (see APPEND_ELEMS). The input array is copied.
void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n,
MlirType const *results) {
APPEND_ELEMS(MlirType, nResults, results);
}
// Appends the `n` values in `operands` to `state->operands` and increases
// `state->nOperands` accordingly (see APPEND_ELEMS). The input array is copied.
void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n,
MlirValue const *operands) {
APPEND_ELEMS(MlirValue, nOperands, operands);
}
// Appends the `n` region handles in `regions` to `state->regions` and
// increases `state->nRegions` accordingly (see APPEND_ELEMS). Only the handles
// are copied here; mlirOperationCreate later wraps each region in a
// std::unique_ptr.
void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n,
MlirRegion const *regions) {
APPEND_ELEMS(MlirRegion, nRegions, regions);
}
// Appends the `n` block handles in `successors` to `state->successors` and
// increases `state->nSuccessors` accordingly (see APPEND_ELEMS).
void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
MlirBlock const *successors) {
APPEND_ELEMS(MlirBlock, nSuccessors, successors);
}
// Appends the `n` named attributes in `attributes` to `state->attributes` and
// increases `state->nAttributes` accordingly (see APPEND_ELEMS).
void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
MlirNamedAttribute const *attributes) {
APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
}
// Sets the flag that makes mlirOperationCreate infer the result types.
void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) {
  MlirOperationState &target = *state;
  target.enableResultTypeInference = true;
}
//===----------------------------------------------------------------------===//
// Operation API.
//===----------------------------------------------------------------------===//
// Fills `state.types` through the operation's InferTypeOpInterface.
// Emits an error at `state.location` and fails when the operation is not
// registered or does not provide the interface; otherwise returns whatever
// the interface's inferReturnTypes reports.
static LogicalResult inferOperationTypes(OperationState &state) {
  Optional<RegisteredOperationName> registeredInfo =
      state.name.getRegisteredInfo();
  if (!registeredInfo) {
    emitError(state.location)
        << "type inference was requested for the operation " << state.name
        << ", but the operation was not registered. Ensure that the dialect "
           "containing the operation is linked into MLIR and registered with "
           "the context";
    return failure();
  }

  // Inference goes through the op interface.
  auto *typeInference = registeredInfo->getInterface<InferTypeOpInterface>();
  if (!typeInference) {
    emitError(state.location)
        << "type inference was requested for the operation " << state.name
        << ", but the operation does not support type inference. Result "
           "types must be specified explicitly.";
    return failure();
  }

  // On failure the diagnostic has already been emitted by the interface.
  MLIRContext *ctx = state.getContext();
  LogicalResult inferred = typeInference->inferReturnTypes(
      ctx, state.location, state.operands, state.attributes.getDictionary(ctx),
      state.regions, state.types);
  return failed(inferred) ? failure() : success();
}
// Builds an Operation from the C-side description in `state`.
//
// Result types, operands, successors and attributes recorded in `state` are
// copied into a C++ OperationState; each region is wrapped in a
// std::unique_ptr and handed to that OperationState. Afterwards the arrays
// held by `state` are freed and its pointers/counts are reset, so `state` is
// left describing empty lists rather than holding dangling pointers (which
// would turn any later reuse of the state into a use-after-free/double free).
//
// When result type inference was requested, the types are inferred before the
// operation is created; if inference fails a null operation handle is
// returned.
MlirOperation mlirOperationCreate(MlirOperationState *state) {
  assert(state);
  OperationState cppState(unwrap(state->location), unwrap(state->name));
  SmallVector<Type, 4> resultStorage;
  SmallVector<Value, 8> operandStorage;
  SmallVector<Block *, 2> successorStorage;
  cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage));
  cppState.addOperands(
      unwrapList(state->nOperands, state->operands, operandStorage));
  cppState.addSuccessors(
      unwrapList(state->nSuccessors, state->successors, successorStorage));

  cppState.attributes.reserve(state->nAttributes);
  for (intptr_t i = 0; i < state->nAttributes; ++i)
    cppState.addAttribute(unwrap(state->attributes[i].name),
                          unwrap(state->attributes[i].attribute));

  // From here on the regions are owned by `cppState`.
  for (intptr_t i = 0; i < state->nRegions; ++i)
    cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));

  // Release the C-side arrays and reset the bookkeeping so that no dangling
  // pointer or stale element count survives in `state`.
  free(state->results);
  state->results = nullptr;
  state->nResults = 0;
  free(state->operands);
  state->operands = nullptr;
  state->nOperands = 0;
  free(state->successors);
  state->successors = nullptr;
  state->nSuccessors = 0;
  free(state->regions);
  state->regions = nullptr;
  state->nRegions = 0;
  free(state->attributes);
  state->attributes = nullptr;
  state->nAttributes = 0;

  // Infer result types.
  if (state->enableResultTypeInference) {
    assert(cppState.types.empty() &&
           "result type inference enabled and result types provided");
    if (failed(inferOperationTypes(cppState)))
      return {nullptr};
  }

  return wrap(Operation::create(cppState));
}
// Returns the result of Operation::clone on the wrapped operation.
MlirOperation mlirOperationClone(MlirOperation op) {
  Operation *original = unwrap(op);
  Operation *copy = original->clone();
  return wrap(copy);
}
// Calls Operation::erase on the wrapped operation.
void mlirOperationDestroy(MlirOperation op) {
  Operation *cppOp = unwrap(op);
  cppOp->erase();
}
// Calls Operation::remove on the wrapped operation.
void mlirOperationRemoveFromParent(MlirOperation op) {
  Operation *cppOp = unwrap(op);
  cppOp->remove();
}
// Two handles are equal when they wrap the same Operation pointer.
bool mlirOperationEqual(MlirOperation op, MlirOperation other) {
  Operation *lhs = unwrap(op);
  Operation *rhs = unwrap(other);
  return lhs == rhs;
}
// Returns the context reported by Operation::getContext.
MlirContext mlirOperationGetContext(MlirOperation op) {
  Operation *cppOp = unwrap(op);
  return wrap(cppOp->getContext());
}
// Returns the location reported by Operation::getLoc.
MlirLocation mlirOperationGetLocation(MlirOperation op) {
  Operation *cppOp = unwrap(op);
  Location loc = cppOp->getLoc();
  return wrap(loc);
}
// Returns the TypeID from the operation's registered info, or a null TypeID
// handle when getRegisteredInfo() yields nothing.
MlirTypeID mlirOperationGetTypeID(MlirOperation op) {
  auto registeredInfo = unwrap(op)->getRegisteredInfo();
  if (!registeredInfo)
    return {nullptr};
  return wrap(registeredInfo->getTypeID());
}
// Returns the identifier of the operation's name.
MlirIdentifier mlirOperationGetName(MlirOperation op) {
  Operation *cppOp = unwrap(op);
  auto identifier = cppOp->getName().getIdentifier();
  return wrap(identifier);
}
// Returns the block reported by Operation::getBlock.
MlirBlock mlirOperationGetBlock(MlirOperation op) {
  Operation *cppOp = unwrap(op);
  return wrap(cppOp->getBlock());
}
// Returns the operation reported by Operation::getParentOp.
MlirOperation mlirOperationGetParentOperation(MlirOperation op) {
  Operation *cppOp = unwrap(op);
  return wrap(cppOp->getParentOp());
}
// Returns Operation::getNumRegions converted to intptr_t.
intptr_t mlirOperationGetNumRegions(MlirOperation op) {
  Operation *cppOp = unwrap(op);
  return static_cast<intptr_t>(cppOp->getNumRegions());
}
// Returns a handle to the region at index `pos`; no bounds check is done here.
MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) {
  Operation *cppOp = unwrap(op);
  Region &region = cppOp->getRegion(static_cast<unsigned>(pos));
  return wrap(&region);
}
// Returns region 0 of the operation, or a null region handle when the
// operation has no regions.
MlirRegion mlirOperationGetFirstRegion(MlirOperation op) {
  Operation *cppOp = unwrap(op);
  if (cppOp->getNumRegions() != 0)
    return wrap(&cppOp->getRegion(0));
  return wrap(static_cast<Region *>(nullptr));
}
// Returns the region that follows `region` in its parent operation, or a null
// region handle when `region` is the last one.
MlirRegion mlirRegionGetNextInOperation(MlirRegion region) {
  Region *current = unwrap(region);
  Operation *owner = current->getParentOp();
  intptr_t nextIndex = current->getRegionNumber() + 1;
  if (!(owner->getNumRegions() > nextIndex))
    return wrap(static_cast<Region *>(nullptr));
  return wrap(&owner->getRegion(nextIndex));
}
// Returns the operation reported by getNextNode() on the wrapped operation.
MlirOperation mlirOperationGetNextInBlock(MlirOperation op) {
  Operation *cppOp = unwrap(op);
  return wrap(cppOp->getNextNode());
}
// Returns Operation::getNumOperands converted to intptr_t.
intptr_t mlirOperationGetNumOperands(MlirOperation op) {
  Operation *cppOp = unwrap(op);
  return static_cast<intptr_t>(cppOp->getNumOperands());
}
// Returns the operand at index `pos`; no bounds check is done here.
MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) {
  Operation *cppOp = unwrap(op);
  Value operand = cppOp->getOperand(static_cast<unsigned>(pos));
  return wrap(operand);
}
// Replaces the operand at index `pos` with `newValue`.
void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
                             MlirValue newValue) {
  Operation *cppOp = unwrap(op);
  cppOp->setOperand(static_cast<unsigned>(pos), unwrap(newValue));
}
// Returns Operation::getNumResults converted to intptr_t.
intptr_t mlirOperationGetNumResults(MlirOperation op) {
  Operation *cppOp = unwrap(op);
  return static_cast<intptr_t>(cppOp->getNumResults());
}
// Returns the result at index `pos`; no bounds check is done here.
MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) {
  Operation *cppOp = unwrap(op);
  OpResult result = cppOp->getResult(static_cast<unsigned>(pos));
  return wrap(result);
}
// Returns Operation::getNumSuccessors converted to intptr_t.
intptr_t mlirOperationGetNumSuccessors(MlirOperation op) {
  Operation *cppOp = unwrap(op);
  return static_cast<intptr_t>(cppOp->getNumSuccessors());
}
// Returns the successor block at index `pos`; no bounds check is done here.
MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) {
  Operation *cppOp = unwrap(op);
  Block *successor = cppOp->getSuccessor(static_cast<unsigned>(pos));
  return wrap(successor);
}
// Returns the size of the list produced by Operation::getAttrs.
intptr_t mlirOperationGetNumAttributes(MlirOperation op) {
  Operation *cppOp = unwrap(op);
  auto numAttrs = cppOp->getAttrs().size();
  return static_cast<intptr_t>(numAttrs);
}
// Returns the name/value pair stored at index `pos` of the operation's
// attribute list; no bounds check is done here.
MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) {
  Operation *cppOp = unwrap(op);
  NamedAttribute entry = cppOp->getAttrs()[pos];
  MlirNamedAttribute result;
  result.name = wrap(entry.getName());
  result.attribute = wrap(entry.getValue());
  return result;
}
// Returns whatever Operation::getAttr yields for `name`, wrapped for C.
MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
                                              MlirStringRef name) {
  Operation *cppOp = unwrap(op);
  Attribute found = cppOp->getAttr(unwrap(name));
  return wrap(found);
}
// Sets the attribute called `name` on the operation to `attr`.
void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name,
                                     MlirAttribute attr) {
  Operation *cppOp = unwrap(op);
  cppOp->setAttr(unwrap(name), unwrap(attr));
}
// Removes the attribute called `name`. Returns true when Operation::removeAttr
// hands back a non-null attribute.
bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) {
  Operation *cppOp = unwrap(op);
  auto removed = cppOp->removeAttr(unwrap(name));
  return !!removed;
}
// Prints the operation into a stream that forwards its output to `callback`
// together with `userData`.
void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
                        void *userData) {
  Operation *cppOp = unwrap(op);
  detail::CallbackOstream os(callback, userData);
  cppOp->print(os);
}
// Same as mlirOperationPrint, but passes the given printing flags to
// Operation::print.
void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
                                 MlirStringCallback callback, void *userData) {
  Operation *cppOp = unwrap(op);
  detail::CallbackOstream os(callback, userData);
  cppOp->print(os, *unwrap(flags));
}
// Calls Operation::dump on the wrapped operation.
void mlirOperationDump(MlirOperation op) {
  Operation *cppOp = unwrap(op);
  cppOp->dump();
}
// Runs the verifier on the operation and returns true when it succeeds.
bool mlirOperationVerify(MlirOperation op) {
  LogicalResult status = verify(unwrap(op));
  return succeeded(status);
}
// Calls Operation::moveAfter with `other` as the reference operation.
void mlirOperationMoveAfter(MlirOperation op, MlirOperation other) {
  Operation *moved = unwrap(op);
  Operation *anchor = unwrap(other);
  moved->moveAfter(anchor);
}
// Calls Operation::moveBefore with `other` as the reference operation.
void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) {
  Operation *moved = unwrap(op);
  Operation *anchor = unwrap(other);
  moved->moveBefore(anchor);
}
//===----------------------------------------------------------------------===//
// Region API.
//===----------------------------------------------------------------------===//
// Heap-allocates a new Region and wraps it. The matching release call in this
// file is mlirRegionDestroy.
MlirRegion mlirRegionCreate() {
  Region *created = new Region;
  return wrap(created);
}
// Two handles are equal when they wrap the same Region pointer.
bool mlirRegionEqual(MlirRegion region, MlirRegion other) {
  Region *lhs = unwrap(region);
  Region *rhs = unwrap(other);
  return lhs == rhs;
}
// Returns the front block of the region, or a null block handle when the
// region is empty.
MlirBlock mlirRegionGetFirstBlock(MlirRegion region) {
  Region *cppRegion = unwrap(region);
  if (!cppRegion->empty())
    return wrap(&cppRegion->front());
  return wrap(static_cast<Block *>(nullptr));
}
// Appends `block` at the end of the region's block list.
void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) {
  Region *cppRegion = unwrap(region);
  Block *cppBlock = unwrap(block);
  cppRegion->push_back(cppBlock);
}
// Inserts `block` in front of the block currently at index `pos` of the
// region's block list (the position is found by advancing from begin()).
void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
                                MlirBlock block) {
  auto &blocks = unwrap(region)->getBlocks();
  auto insertionPoint = std::next(blocks.begin(), pos);
  blocks.insert(insertionPoint, unwrap(block));
}
// Inserts `block` right after `reference` in the region. A null `reference`
// means the block is inserted at the beginning of the region.
void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference,
                                     MlirBlock block) {
  Region *cppRegion = unwrap(region);
  if (!mlirBlockIsNull(reference)) {
    assert(unwrap(reference)->getParent() == unwrap(region) &&
           "expected reference block to belong to the region");
    cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)),
                                       unwrap(block));
    return;
  }
  cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block));
}
// Inserts `block` right before `reference` in the region. A null `reference`
// means the block is appended at the end of the region.
void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference,
                                      MlirBlock block) {
  if (mlirBlockIsNull(reference)) {
    mlirRegionAppendOwnedBlock(region, block);
    return;
  }
  assert(unwrap(reference)->getParent() == unwrap(region) &&
         "expected reference block to belong to the region");
  auto &blocks = unwrap(region)->getBlocks();
  blocks.insert(Region::iterator(unwrap(reference)), unwrap(block));
}
// Deletes the Region stored in the handle's raw pointer.
void mlirRegionDestroy(MlirRegion region) {
  Region *cppRegion = static_cast<Region *>(region.ptr);
  delete cppRegion;
}
//===----------------------------------------------------------------------===//
// Block API.
//===----------------------------------------------------------------------===//
// Heap-allocates a new Block and adds one block argument per entry of `args`
// (in order). The matching release call in this file is mlirBlockDestroy.
MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args) {
  Block *created = new Block;
  for (intptr_t argIndex = 0; argIndex < nArgs; ++argIndex) {
    MlirType argType = args[argIndex];
    created->addArgument(unwrap(argType));
  }
  return wrap(created);
}
// Two handles are equal when they wrap the same Block pointer.
bool mlirBlockEqual(MlirBlock block, MlirBlock other) {
  Block *lhs = unwrap(block);
  Block *rhs = unwrap(other);
  return lhs == rhs;
}
// Returns the operation reported by Block::getParentOp.
MlirOperation mlirBlockGetParentOperation(MlirBlock block) {
  Block *cppBlock = unwrap(block);
  return wrap(cppBlock->getParentOp());
}
// Returns the region reported by Block::getParent.
MlirRegion mlirBlockGetParentRegion(MlirBlock block) {
  Block *cppBlock = unwrap(block);
  return wrap(cppBlock->getParent());
}
// Returns the block reported by getNextNode() on the wrapped block.
MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
  Block *cppBlock = unwrap(block);
  return wrap(cppBlock->getNextNode());
}
// Returns the front operation of the block, or a null operation handle when
// the block is empty.
MlirOperation mlirBlockGetFirstOperation(MlirBlock block) {
  Block *cppBlock = unwrap(block);
  if (!cppBlock->empty())
    return wrap(&cppBlock->front());
  return wrap(static_cast<Operation *>(nullptr));
}
// Returns the last operation of the block if it carries the IsTerminator
// trait. Returns a null operation handle when the block is empty or its last
// operation does not have that trait.
MlirOperation mlirBlockGetTerminator(MlirBlock block) {
  Block *cppBlock = unwrap(block);
  if (cppBlock->empty() ||
      !cppBlock->back().hasTrait<OpTrait::IsTerminator>())
    return wrap(static_cast<Operation *>(nullptr));
  return wrap(&cppBlock->back());
}
// Appends `operation` at the end of the block.
void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) {
  Block *cppBlock = unwrap(block);
  Operation *cppOp = unwrap(operation);
  cppBlock->push_back(cppOp);
}
// Inserts `operation` in front of the operation currently at index `pos` of
// the block (the position is found by advancing from begin()).
void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
                                   MlirOperation operation) {
  auto &operations = unwrap(block)->getOperations();
  auto insertionPoint = std::next(operations.begin(), pos);
  operations.insert(insertionPoint, unwrap(operation));
}
// Inserts `operation` right after `reference` in the block. A null
// `reference` means the operation is inserted at the beginning of the block.
void mlirBlockInsertOwnedOperationAfter(MlirBlock block,
                                        MlirOperation reference,
                                        MlirOperation operation) {
  Block *cppBlock = unwrap(block);
  if (!mlirOperationIsNull(reference)) {
    assert(unwrap(reference)->getBlock() == unwrap(block) &&
           "expected reference operation to belong to the block");
    cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)),
                                          unwrap(operation));
    return;
  }
  cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation));
}
// Inserts `operation` right before `reference` in the block. A null
// `reference` means the operation is appended at the end of the block.
void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
                                         MlirOperation reference,
                                         MlirOperation operation) {
  if (mlirOperationIsNull(reference)) {
    mlirBlockAppendOwnedOperation(block, operation);
    return;
  }
  assert(unwrap(reference)->getBlock() == unwrap(block) &&
         "expected reference operation to belong to the block");
  auto &operations = unwrap(block)->getOperations();
  operations.insert(Block::iterator(unwrap(reference)), unwrap(operation));
}
// Deletes the Block behind the handle.
void mlirBlockDestroy(MlirBlock block) {
  Block *cppBlock = unwrap(block);
  delete cppBlock;
}
// Returns Block::getNumArguments converted to intptr_t.
intptr_t mlirBlockGetNumArguments(MlirBlock block) {
  Block *cppBlock = unwrap(block);
  return static_cast<intptr_t>(cppBlock->getNumArguments());
}
// Adds an argument of type `type` to the block and returns the new argument
// as a value handle.
MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type) {
  Block *cppBlock = unwrap(block);
  BlockArgument added = cppBlock->addArgument(unwrap(type));
  return wrap(added);
}
// Returns the block argument at index `pos`; no bounds check is done here.
MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
  Block *cppBlock = unwrap(block);
  BlockArgument argument = cppBlock->getArgument(static_cast<unsigned>(pos));
  return wrap(argument);
}
// Prints the block into a stream that forwards its output to `callback`
// together with `userData`.
void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
                    void *userData) {
  Block *cppBlock = unwrap(block);
  detail::CallbackOstream os(callback, userData);
  cppBlock->print(os);
}
//===----------------------------------------------------------------------===//
// Value API.
//===----------------------------------------------------------------------===//
// Compares the two unwrapped Value objects with operator==.
bool mlirValueEqual(MlirValue value1, MlirValue value2) {
  Value lhs = unwrap(value1);
  Value rhs = unwrap(value2);
  return lhs == rhs;
}
// True when the value is a BlockArgument (isa<> check).
bool mlirValueIsABlockArgument(MlirValue value) {
  Value cppValue = unwrap(value);
  return cppValue.isa<BlockArgument>();
}
// True when the value is an OpResult (isa<> check).
bool mlirValueIsAOpResult(MlirValue value) {
  Value cppValue = unwrap(value);
  return cppValue.isa<OpResult>();
}
// Casts the value to BlockArgument (unchecked cast<>) and returns its owner
// block.
MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
  BlockArgument argument = unwrap(value).cast<BlockArgument>();
  return wrap(argument.getOwner());
}
// Casts the value to BlockArgument (unchecked cast<>) and returns its argument
// number as intptr_t.
intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
  BlockArgument argument = unwrap(value).cast<BlockArgument>();
  return static_cast<intptr_t>(argument.getArgNumber());
}
// Casts the value to BlockArgument (unchecked cast<>) and sets its type.
void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
  BlockArgument argument = unwrap(value).cast<BlockArgument>();
  argument.setType(unwrap(type));
}
// Casts the value to OpResult (unchecked cast<>) and returns the operation
// that owns it.
MlirOperation mlirOpResultGetOwner(MlirValue value) {
  OpResult result = unwrap(value).cast<OpResult>();
  return wrap(result.getOwner());
}
// Casts the value to OpResult (unchecked cast<>) and returns its result number
// as intptr_t.
intptr_t mlirOpResultGetResultNumber(MlirValue value) {
  OpResult result = unwrap(value).cast<OpResult>();
  return static_cast<intptr_t>(result.getResultNumber());
}
// Returns the type reported by Value::getType.
MlirType mlirValueGetType(MlirValue value) {
  Value cppValue = unwrap(value);
  Type valueType = cppValue.getType();
  return wrap(valueType);
}
// Calls Value::dump on the wrapped value.
void mlirValueDump(MlirValue value) {
  Value cppValue = unwrap(value);
  cppValue.dump();
}
// Prints the value into a stream that forwards its output to `callback`
// together with `userData`.
void mlirValuePrint(MlirValue value, MlirStringCallback callback,
                    void *userData) {
  Value cppValue = unwrap(value);
  detail::CallbackOstream os(callback, userData);
  cppValue.print(os);
}
//===----------------------------------------------------------------------===//
// Type API.
//===----------------------------------------------------------------------===//
// Parses `type` in `context` with mlir::parseType and wraps the result.
MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) {
  Type parsed = mlir::parseType(unwrap(type), unwrap(context));
  return wrap(parsed);
}
// Returns the context reported by Type::getContext.
MlirContext mlirTypeGetContext(MlirType type) {
  Type cppType = unwrap(type);
  return wrap(cppType.getContext());
}
// Returns the TypeID reported by Type::getTypeID.
MlirTypeID mlirTypeGetTypeID(MlirType type) {
  Type cppType = unwrap(type);
  return wrap(cppType.getTypeID());
}
// Compares the two unwrapped Type objects with operator==.
bool mlirTypeEqual(MlirType t1, MlirType t2) {
  Type lhs = unwrap(t1);
  Type rhs = unwrap(t2);
  return lhs == rhs;
}
// Prints the type into a stream that forwards its output to `callback`
// together with `userData`.
void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
  Type cppType = unwrap(type);
  detail::CallbackOstream os(callback, userData);
  cppType.print(os);
}
// Calls Type::dump on the wrapped type.
void mlirTypeDump(MlirType type) {
  Type cppType = unwrap(type);
  cppType.dump();
}
//===----------------------------------------------------------------------===//
// Attribute API.
//===----------------------------------------------------------------------===//
// Parses `attr` in `context` with mlir::parseAttribute and wraps the result.
MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) {
  Attribute parsed = mlir::parseAttribute(unwrap(attr), unwrap(context));
  return wrap(parsed);
}
// Returns the context reported by Attribute::getContext.
MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
  Attribute cppAttribute = unwrap(attribute);
  return wrap(cppAttribute.getContext());
}
// Returns the type reported by Attribute::getType.
MlirType mlirAttributeGetType(MlirAttribute attribute) {
  Attribute cppAttribute = unwrap(attribute);
  return wrap(cppAttribute.getType());
}
// Returns the TypeID reported by Attribute::getTypeID.
MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) {
  Attribute cppAttribute = unwrap(attr);
  return wrap(cppAttribute.getTypeID());
}
// Compares the two unwrapped Attribute objects with operator==.
bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
  Attribute lhs = unwrap(a1);
  Attribute rhs = unwrap(a2);
  return lhs == rhs;
}
// Prints the attribute into a stream that forwards its output to `callback`
// together with `userData`.
void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
                        void *userData) {
  Attribute cppAttribute = unwrap(attr);
  detail::CallbackOstream os(callback, userData);
  cppAttribute.print(os);
}
// Calls Attribute::dump on the wrapped attribute.
void mlirAttributeDump(MlirAttribute attr) {
  Attribute cppAttribute = unwrap(attr);
  cppAttribute.dump();
}
// Pairs `name` with `attr` in a MlirNamedAttribute.
MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name,
                                         MlirAttribute attr) {
  MlirNamedAttribute named;
  named.name = name;
  named.attribute = attr;
  return named;
}
//===----------------------------------------------------------------------===//
// Identifier API.
//===----------------------------------------------------------------------===//
// Returns the StringAttr for `str` in `context` as an identifier handle.
MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) {
  MLIRContext *cppContext = unwrap(context);
  auto identifier = StringAttr::get(cppContext, unwrap(str));
  return wrap(identifier);
}
// Returns the context reported by the identifier's getContext().
MlirContext mlirIdentifierGetContext(MlirIdentifier ident) {
  auto cppIdent = unwrap(ident);
  return wrap(cppIdent.getContext());
}
// Compares the two unwrapped identifiers with operator==.
bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) {
  auto lhs = unwrap(ident);
  auto rhs = unwrap(other);
  return lhs == rhs;
}
// Returns the identifier's strref(), wrapped for C.
MlirStringRef mlirIdentifierStr(MlirIdentifier ident) {
  auto cppIdent = unwrap(ident);
  return wrap(cppIdent.strref());
}
//===----------------------------------------------------------------------===//
// TypeID API.
//===----------------------------------------------------------------------===//
// Compares the two unwrapped TypeID objects with operator==.
bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) {
  TypeID lhs = unwrap(typeID1);
  TypeID rhs = unwrap(typeID2);
  return lhs == rhs;
}
// Returns hash_value() of the unwrapped TypeID, converted to size_t.
size_t mlirTypeIDHashValue(MlirTypeID typeID) {
  TypeID cppTypeID = unwrap(typeID);
  return hash_value(cppTypeID);
}
//===----------------------------------------------------------------------===//
// Symbol and SymbolTable API.
//===----------------------------------------------------------------------===//
// Returns SymbolTable::getSymbolAttrName(), wrapped for C.
MlirStringRef mlirSymbolTableGetSymbolAttributeName() {
  StringRef attrName = SymbolTable::getSymbolAttrName();
  return wrap(attrName);
}
// Returns SymbolTable::getVisibilityAttrName(), wrapped for C.
MlirStringRef mlirSymbolTableGetVisibilityAttributeName() {
  StringRef attrName = SymbolTable::getVisibilityAttrName();
  return wrap(attrName);
}
// Heap-allocates a SymbolTable for `operation` when the operation has the
// SymbolTable trait; otherwise returns a null symbol table handle. The
// matching release call in this file is mlirSymbolTableDestroy.
MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) {
  Operation *cppOp = unwrap(operation);
  if (cppOp->hasTrait<OpTrait::SymbolTable>())
    return wrap(new SymbolTable(cppOp));
  return wrap(static_cast<SymbolTable *>(nullptr));
}
// Deletes the SymbolTable behind the handle.
void mlirSymbolTableDestroy(MlirSymbolTable symbolTable) {
  SymbolTable *cppTable = unwrap(symbolTable);
  delete cppTable;
}
// Returns whatever SymbolTable::lookup yields for `name`, wrapped for C.
MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable,
                                    MlirStringRef name) {
  SymbolTable *cppTable = unwrap(symbolTable);
  StringRef symbolName(name.data, name.length);
  return wrap(cppTable->lookup(symbolName));
}
// Inserts `operation` into the symbol table and returns the value produced by
// SymbolTable::insert, converted to a generic attribute handle.
MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable,
                                    MlirOperation operation) {
  SymbolTable *cppTable = unwrap(symbolTable);
  return wrap(static_cast<Attribute>(cppTable->insert(unwrap(operation))));
}
// Calls SymbolTable::erase for `operation`.
void mlirSymbolTableErase(MlirSymbolTable symbolTable,
                          MlirOperation operation) {
  SymbolTable *cppTable = unwrap(symbolTable);
  cppTable->erase(unwrap(operation));
}
// Replaces uses of the symbol `oldSymbol` with `newSymbol`, starting from the
// operation `from`, by delegating to SymbolTable::replaceAllSymbolUses. Both
// names are turned into StringAttrs in the context of `from`; the result of
// the underlying call is returned as a C logical result.
MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol,
                                                      MlirStringRef newSymbol,
                                                      MlirOperation from) {
  auto *cppFrom = unwrap(from);
  auto *context = cppFrom->getContext();
  // Use the (context, string) builder, like every other StringAttr::get call
  // in this file.
  auto oldSymbolAttr = StringAttr::get(context, unwrap(oldSymbol));
  auto newSymbolAttr = StringAttr::get(context, unwrap(newSymbol));
  return wrap(SymbolTable::replaceAllSymbolUses(oldSymbolAttr, newSymbolAttr,
                                                cppFrom));
}
// Runs SymbolTable::walkSymbolTables from `from` and forwards every reported
// (operation, visibility flag) pair to the C `callback` along with `userData`.
void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible,
                                     void (*callback)(MlirOperation, bool,
                                                      void *userData),
                                     void *userData) {
  auto forwardToC = [callback, userData](Operation *visited, bool visible) {
    callback(wrap(visited), visible, userData);
  };
  SymbolTable::walkSymbolTables(unwrap(from), allSymUsesVisible, forwardToC);
}