blob: 681e2448fb31474800b646e4dccc25c388a2021b [file] [log] [blame]
//===- TypeFromLLVM.cpp - type translation from LLVM to MLIR IR -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Target/LLVMIR/TypeFromLLVM.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Type.h"
using namespace mlir;
namespace mlir {
namespace LLVM {
namespace detail {
/// Support for translating LLVM IR types to MLIR LLVM dialect types.
class TypeFromLLVMIRTranslatorImpl {
public:
/// Constructs a class creating types in the given MLIR context.
TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {}
/// Translates the given type.
Type translateType(llvm::Type *type) {
if (knownTranslations.count(type))
return knownTranslations.lookup(type);
Type translated =
llvm::TypeSwitch<llvm::Type *, Type>(type)
.Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType,
llvm::PointerType, llvm::StructType, llvm::FixedVectorType,
llvm::ScalableVectorType>(
[this](auto *type) { return this->translate(type); })
.Default([this](llvm::Type *type) {
return translatePrimitiveType(type);
});
knownTranslations.try_emplace(type, translated);
return translated;
}
private:
/// Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
/// type.
Type translatePrimitiveType(llvm::Type *type) {
if (type->isVoidTy())
return LLVM::LLVMVoidType::get(&context);
if (type->isHalfTy())
return Float16Type::get(&context);
if (type->isBFloatTy())
return BFloat16Type::get(&context);
if (type->isFloatTy())
return Float32Type::get(&context);
if (type->isDoubleTy())
return Float64Type::get(&context);
if (type->isFP128Ty())
return Float128Type::get(&context);
if (type->isX86_FP80Ty())
return Float80Type::get(&context);
if (type->isPPC_FP128Ty())
return LLVM::LLVMPPCFP128Type::get(&context);
if (type->isX86_MMXTy())
return LLVM::LLVMX86MMXType::get(&context);
if (type->isLabelTy())
return LLVM::LLVMLabelType::get(&context);
if (type->isMetadataTy())
return LLVM::LLVMMetadataType::get(&context);
llvm_unreachable("not a primitive type");
}
/// Translates the given array type.
Type translate(llvm::ArrayType *type) {
return LLVM::LLVMArrayType::get(translateType(type->getElementType()),
type->getNumElements());
}
/// Translates the given function type.
Type translate(llvm::FunctionType *type) {
SmallVector<Type, 8> paramTypes;
translateTypes(type->params(), paramTypes);
return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()),
paramTypes, type->isVarArg());
}
/// Translates the given integer type.
Type translate(llvm::IntegerType *type) {
return IntegerType::get(&context, type->getBitWidth());
}
/// Translates the given pointer type.
Type translate(llvm::PointerType *type) {
return LLVM::LLVMPointerType::get(translateType(type->getElementType()),
type->getAddressSpace());
}
/// Translates the given structure type.
Type translate(llvm::StructType *type) {
SmallVector<Type, 8> subtypes;
if (type->isLiteral()) {
translateTypes(type->subtypes(), subtypes);
return LLVM::LLVMStructType::getLiteral(&context, subtypes,
type->isPacked());
}
if (type->isOpaque())
return LLVM::LLVMStructType::getOpaque(type->getName(), &context);
LLVM::LLVMStructType translated =
LLVM::LLVMStructType::getIdentified(&context, type->getName());
knownTranslations.try_emplace(type, translated);
translateTypes(type->subtypes(), subtypes);
LogicalResult bodySet = translated.setBody(subtypes, type->isPacked());
assert(succeeded(bodySet) &&
"could not set the body of an identified struct");
(void)bodySet;
return translated;
}
/// Translates the given fixed-vector type.
Type translate(llvm::FixedVectorType *type) {
return LLVM::getFixedVectorType(translateType(type->getElementType()),
type->getNumElements());
}
/// Translates the given scalable-vector type.
Type translate(llvm::ScalableVectorType *type) {
return LLVM::LLVMScalableVectorType::get(
translateType(type->getElementType()), type->getMinNumElements());
}
/// Translates a list of types.
void translateTypes(ArrayRef<llvm::Type *> types,
SmallVectorImpl<Type> &result) {
result.reserve(result.size() + types.size());
for (llvm::Type *type : types)
result.push_back(translateType(type));
}
/// Map of known translations. Serves as a cache and as recursion stopper for
/// translating recursive structs.
llvm::DenseMap<llvm::Type *, Type> knownTranslations;
/// The context in which MLIR types are created.
MLIRContext &context;
};
} // end namespace detail
} // end namespace LLVM
} // end namespace mlir
LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context)
: impl(new detail::TypeFromLLVMIRTranslatorImpl(context)) {}
LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() {}
Type LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) {
return impl->translateType(type);
}