blob: 160b6ae89215cc616ec379b4960177e47429294c [file] [log] [blame]
//===- FunctionCallUtils.cpp - Utilities for C function calls -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements helper functions to call common simple C functions in
// LLVMIR (e.g. among others to support printing and debugging).
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"
using namespace mlir;
using namespace mlir::LLVM;
/// Helper functions to lookup or create the declaration for commonly used
/// external C function calls. The list of functions provided here must be
/// implemented separately (e.g. as part of a support runtime library or as
/// part of the libc).
static constexpr llvm::StringRef kPrintI64 = "printI64";
static constexpr llvm::StringRef kPrintU64 = "printU64";
static constexpr llvm::StringRef kPrintF16 = "printF16";
static constexpr llvm::StringRef kPrintBF16 = "printBF16";
static constexpr llvm::StringRef kPrintF32 = "printF32";
static constexpr llvm::StringRef kPrintF64 = "printF64";
static constexpr llvm::StringRef kPrintApFloat = "printApFloat";
static constexpr llvm::StringRef kPrintString = "printString";
static constexpr llvm::StringRef kPrintOpen = "printOpen";
static constexpr llvm::StringRef kPrintClose = "printClose";
static constexpr llvm::StringRef kPrintComma = "printComma";
static constexpr llvm::StringRef kPrintNewline = "printNewline";
static constexpr llvm::StringRef kMalloc = "malloc";
static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc";
static constexpr llvm::StringRef kFree = "free";
static constexpr llvm::StringRef kGenericAlloc = "_mlir_memref_to_llvm_alloc";
static constexpr llvm::StringRef kGenericAlignedAlloc =
"_mlir_memref_to_llvm_aligned_alloc";
static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
namespace {
/// Finds the LLVMFuncOp called `name` inside `symbolTableOp`, an operation
/// with the SymbolTable trait. If a collection of cached symbol tables is
/// supplied through `symbolTables`, the lookup goes through it to avoid a
/// linear scan of the symbol table operation; otherwise the static
/// SymbolTable lookup is used. The result is a null op when the lookup does
/// not produce an LLVMFuncOp.
LLVM::LLVMFuncOp lookupFuncOp(StringRef name, Operation *symbolTableOp,
                              SymbolTableCollection *symbolTables = nullptr) {
  // Without a cache, query the symbol table operation directly.
  if (!symbolTables) {
    Operation *symbol = SymbolTable::lookupSymbolIn(symbolTableOp, name);
    return llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(symbol);
  }

  // The cached lookup is keyed by a StringAttr built from the name.
  StringAttr nameAttr = StringAttr::get(symbolTableOp->getContext(), name);
  return symbolTables->lookupSymbolIn<LLVM::LLVMFuncOp>(symbolTableOp,
                                                        nameAttr);
}
} // namespace
/// Generic lookupOrCreate helper: returns the LLVMFuncOp named `name` in
/// `moduleOp`, declaring it at the start of the module body when it does not
/// exist yet.
///
/// `moduleOp` must be a SymbolTable operation with a non-empty first region.
/// The expected signature is built from `resultType`, `paramTypes` and
/// `isVarArg`. If a function with that name already exists but has a
/// different type, an error is emitted on that function (worded according to
/// `isReserved`) and failure is returned. When `symbolTables` is provided, a
/// newly created function is also registered in the cached symbol table of
/// `moduleOp`. The builder's insertion point is restored on return.
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name,
                             ArrayRef<Type> paramTypes, Type resultType,
                             bool isVarArg, bool isReserved,
                             SymbolTableCollection *symbolTables) {
  assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
         "expected SymbolTable operation");
  LLVM::LLVMFuncOp existingFn = lookupFuncOp(name, moduleOp, symbolTables);
  LLVMFunctionType expectedType =
      LLVMFunctionType::get(resultType, paramTypes, isVarArg);

  if (existingFn) {
    // An existing declaration is reused only if its signature matches.
    if (existingFn.getFunctionType() == expectedType)
      return existingFn;

    if (isReserved) {
      existingFn.emitError("redefinition of reserved function '")
          << name << "' of different type " << existingFn.getFunctionType()
          << " is prohibited";
    } else {
      existingFn.emitError("redefinition of function '")
          << name << "' of different type " << expectedType
          << " is prohibited";
    }
    return failure();
  }

  // No such function yet: declare it at the very beginning of the module.
  assert(!moduleOp->getRegion(0).empty() && "expected non-empty region");
  Block &moduleBody = moduleOp->getRegion(0).front();
  OpBuilder::InsertionGuard restoreInsertionPoint(b);
  b.setInsertionPointToStart(&moduleBody);
  auto newFn =
      LLVM::LLVMFuncOp::create(b, moduleOp->getLoc(), name, expectedType);

  // Keep the cached symbol table, if any, in sync with the new declaration.
  if (symbolTables)
    symbolTables->getSymbolTable(moduleOp).insert(newFn, moduleBody.begin());
  return newFn;
}
/// Shorthand for lookupOrCreateFn that always requests a non-variadic,
/// reserved function; all other arguments are forwarded unchanged.
static FailureOr<LLVM::LLVMFuncOp>
lookupOrCreateReservedFn(OpBuilder &b, Operation *moduleOp, StringRef name,
                         ArrayRef<Type> paramTypes, Type resultType,
                         SymbolTableCollection *symbolTables) {
  constexpr bool isVarArg = false;
  constexpr bool isReserved = true;
  return mlir::LLVM::lookupOrCreateFn(b, moduleOp, name, paramTypes,
                                      resultType, isVarArg, isReserved,
                                      symbolTables);
}
/// Looks up or creates the reserved `printI64` declaration: one 64-bit
/// integer parameter, void result.
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp,
                                     SymbolTableCollection *symbolTables) {
  MLIRContext *ctx = moduleOp->getContext();
  Type i64Ty = IntegerType::get(ctx, 64);
  Type voidTy = LLVM::LLVMVoidType::get(ctx);
  return lookupOrCreateReservedFn(b, moduleOp, kPrintI64, i64Ty, voidTy,
                                  symbolTables);
}
/// Looks up or creates the reserved `printU64` declaration: one 64-bit
/// integer parameter, void result.
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp,
                                     SymbolTableCollection *symbolTables) {
  MLIRContext *ctx = moduleOp->getContext();
  Type i64Ty = IntegerType::get(ctx, 64);
  Type voidTy = LLVM::LLVMVoidType::get(ctx);
  return lookupOrCreateReservedFn(b, moduleOp, kPrintU64, i64Ty, voidTy,
                                  symbolTables);
}
/// Looks up or creates the reserved `printF16` declaration. The single
/// parameter is declared as a 16-bit integer rather than a float type, i.e.
/// the value is passed as bits; the result is void.
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp,
                                     SymbolTableCollection *symbolTables) {
  MLIRContext *ctx = moduleOp->getContext();
  Type bitsTy = IntegerType::get(ctx, 16); // bits!
  Type voidTy = LLVM::LLVMVoidType::get(ctx);
  return lookupOrCreateReservedFn(b, moduleOp, kPrintF16, bitsTy, voidTy,
                                  symbolTables);
}
/// Looks up or creates the reserved `printBF16` declaration. The single
/// parameter is declared as a 16-bit integer rather than a float type, i.e.
/// the value is passed as bits; the result is void.
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp,
                                      SymbolTableCollection *symbolTables) {
  MLIRContext *ctx = moduleOp->getContext();
  Type bitsTy = IntegerType::get(ctx, 16); // bits!
  Type voidTy = LLVM::LLVMVoidType::get(ctx);
  return lookupOrCreateReservedFn(b, moduleOp, kPrintBF16, bitsTy, voidTy,
                                  symbolTables);
}
/// Looks up or creates the reserved `printF32` declaration: one f32
/// parameter, void result.
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp,
                                     SymbolTableCollection *symbolTables) {
  MLIRContext *ctx = moduleOp->getContext();
  Type f32Ty = Float32Type::get(ctx);
  Type voidTy = LLVM::LLVMVoidType::get(ctx);
  return lookupOrCreateReservedFn(b, moduleOp, kPrintF32, f32Ty, voidTy,
                                  symbolTables);
}
/// Looks up or creates the reserved `printF64` declaration: one f64
/// parameter, void result.
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
                                     SymbolTableCollection *symbolTables) {
  MLIRContext *ctx = moduleOp->getContext();
  Type f64Ty = Float64Type::get(ctx);
  Type voidTy = LLVM::LLVMVoidType::get(ctx);
  return lookupOrCreateReservedFn(b, moduleOp, kPrintF64, f64Ty, voidTy,
                                  symbolTables);
}
/// Looks up or creates the reserved `printApFloat` declaration: a 32-bit
/// integer parameter followed by a 64-bit integer parameter, void result.
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
                                         SymbolTableCollection *symbolTables) {
  MLIRContext *ctx = moduleOp->getContext();
  Type i32Ty = IntegerType::get(ctx, 32);
  Type i64Ty = IntegerType::get(ctx, 64);
  Type voidTy = LLVM::LLVMVoidType::get(ctx);
  return lookupOrCreateReservedFn(b, moduleOp, kPrintApFloat, {i32Ty, i64Ty},
                                  voidTy, symbolTables);
}
/// Returns the LLVM pointer type, obtained from `context`, that this file
/// uses wherever a `char *` is expected.
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
  auto charPtrTy = LLVM::LLVMPointerType::get(context);
  return charPtrTy;
}
/// Returns the LLVM pointer type, obtained from `context`, that this file
/// uses wherever a `void *` is expected.
static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
  // A void pointer and a char pointer are the same in LLVM IR, so this is
  // the very same pointer type that is used for `char *`.
  return LLVM::LLVMPointerType::get(context);
}
/// Looks up or creates the reserved string-printing declaration: one pointer
/// parameter, void result. The function is named `runtimeFunctionName` when
/// that is provided and `printString` otherwise.
FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreatePrintStringFn(
    OpBuilder &b, Operation *moduleOp,
    std::optional<StringRef> runtimeFunctionName,
    SymbolTableCollection *symbolTables) {
  MLIRContext *ctx = moduleOp->getContext();
  StringRef fnName = runtimeFunctionName.value_or(kPrintString);
  LLVM::LLVMPointerType stringTy = getCharPtr(ctx);
  Type voidTy = LLVM::LLVMVoidType::get(ctx);
  return lookupOrCreateReservedFn(b, moduleOp, fnName, stringTy, voidTy,
                                  symbolTables);
}
/// Looks up or creates the reserved `printOpen` declaration: no parameters,
/// void result.
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp,
                                      SymbolTableCollection *symbolTables) {
  Type voidTy = LLVM::LLVMVoidType::get(moduleOp->getContext());
  return lookupOrCreateReservedFn(b, moduleOp, kPrintOpen, /*paramTypes=*/{},
                                  voidTy, symbolTables);
}
/// Looks up or creates the reserved `printClose` declaration: no parameters,
/// void result.
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp,
                                       SymbolTableCollection *symbolTables) {
  Type voidTy = LLVM::LLVMVoidType::get(moduleOp->getContext());
  return lookupOrCreateReservedFn(b, moduleOp, kPrintClose, /*paramTypes=*/{},
                                  voidTy, symbolTables);
}
/// Looks up or creates the reserved `printComma` declaration: no parameters,
/// void result.
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp,
                                       SymbolTableCollection *symbolTables) {
  Type voidTy = LLVM::LLVMVoidType::get(moduleOp->getContext());
  return lookupOrCreateReservedFn(b, moduleOp, kPrintComma, /*paramTypes=*/{},
                                  voidTy, symbolTables);
}
/// Looks up or creates the reserved `printNewline` declaration: no
/// parameters, void result.
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp,
                                         SymbolTableCollection *symbolTables) {
  Type voidTy = LLVM::LLVMVoidType::get(moduleOp->getContext());
  return lookupOrCreateReservedFn(b, moduleOp, kPrintNewline,
                                  /*paramTypes=*/{}, voidTy, symbolTables);
}
/// Looks up or creates the reserved `malloc` declaration: one parameter of
/// type `indexType`, pointer result.
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp,
                                   Type indexType,
                                   SymbolTableCollection *symbolTables) {
  LLVM::LLVMPointerType resultTy = getVoidPtr(moduleOp->getContext());
  return lookupOrCreateReservedFn(b, moduleOp, kMalloc,
                                  /*paramTypes=*/indexType,
                                  /*resultType=*/resultTy, symbolTables);
}
/// Looks up or creates the reserved `aligned_alloc` declaration: two
/// parameters of type `indexType`, pointer result.
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp,
                                         Type indexType,
                                         SymbolTableCollection *symbolTables) {
  LLVM::LLVMPointerType resultTy = getVoidPtr(moduleOp->getContext());
  return lookupOrCreateReservedFn(b, moduleOp, kAlignedAlloc,
                                  /*paramTypes=*/{indexType, indexType},
                                  /*resultType=*/resultTy, symbolTables);
}
/// Looks up or creates the reserved `free` declaration: one pointer
/// parameter, void result.
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp,
                                 SymbolTableCollection *symbolTables) {
  MLIRContext *ctx = moduleOp->getContext();
  LLVM::LLVMPointerType ptrTy = getVoidPtr(ctx);
  Type voidTy = LLVM::LLVMVoidType::get(ctx);
  return lookupOrCreateReservedFn(b, moduleOp, kFree, ptrTy, voidTy,
                                  symbolTables);
}
/// Looks up or creates the reserved `_mlir_memref_to_llvm_alloc`
/// declaration: one parameter of type `indexType`, pointer result.
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp,
                                         Type indexType,
                                         SymbolTableCollection *symbolTables) {
  LLVM::LLVMPointerType resultTy = getVoidPtr(moduleOp->getContext());
  return lookupOrCreateReservedFn(b, moduleOp, kGenericAlloc,
                                  /*paramTypes=*/indexType,
                                  /*resultType=*/resultTy, symbolTables);
}
/// Looks up or creates the reserved `_mlir_memref_to_llvm_aligned_alloc`
/// declaration: two parameters of type `indexType`, pointer result.
FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(
    OpBuilder &b, Operation *moduleOp, Type indexType,
    SymbolTableCollection *symbolTables) {
  LLVM::LLVMPointerType resultTy = getVoidPtr(moduleOp->getContext());
  return lookupOrCreateReservedFn(b, moduleOp, kGenericAlignedAlloc,
                                  /*paramTypes=*/{indexType, indexType},
                                  /*resultType=*/resultTy, symbolTables);
}
/// Looks up or creates the reserved `_mlir_memref_to_llvm_free` declaration:
/// one pointer parameter, void result.
FailureOr<LLVM::LLVMFuncOp>
mlir::LLVM::lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp,
                                        SymbolTableCollection *symbolTables) {
  MLIRContext *ctx = moduleOp->getContext();
  LLVM::LLVMPointerType ptrTy = getVoidPtr(ctx);
  Type voidTy = LLVM::LLVMVoidType::get(ctx);
  return lookupOrCreateReservedFn(b, moduleOp, kGenericFree, ptrTy, voidTy,
                                  symbolTables);
}
/// Looks up or creates the reserved `memrefCopy` declaration. Its parameters
/// are one `indexType` followed by two `unrankedDescriptorType` values; the
/// result is void.
FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreateMemRefCopyFn(
    OpBuilder &b, Operation *moduleOp, Type indexType,
    Type unrankedDescriptorType, SymbolTableCollection *symbolTables) {
  Type voidTy = LLVM::LLVMVoidType::get(moduleOp->getContext());
  // The parameter list is built inline so that it lives for the whole call.
  return lookupOrCreateReservedFn(
      b, moduleOp, kMemRefCopy,
      ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
      voidTy, symbolTables);
}