blob: 318b8eb10c16fa1a224aa0001240f9995979a475 [file] [log] [blame]
//===- BuiltinTypes.cpp - C Interface to MLIR Builtin Types ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/IR.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Integer types.
//===----------------------------------------------------------------------===//
/// Returns true if the given type is a builtin IntegerType (any signedness).
bool mlirTypeIsAInteger(MlirType type) {
  auto unwrapped = unwrap(type);
  return unwrapped.isa<IntegerType>();
}

/// Creates an integer type of `bitwidth` bits in `ctx`, using the signedness
/// that IntegerType::get applies when none is passed explicitly.
MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) {
  auto *context = unwrap(ctx);
  return wrap(IntegerType::get(context, bitwidth));
}

/// Creates an explicitly signed integer type of `bitwidth` bits in `ctx`.
MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) {
  auto *context = unwrap(ctx);
  return wrap(IntegerType::get(context, bitwidth, IntegerType::Signed));
}

/// Creates an explicitly unsigned integer type of `bitwidth` bits in `ctx`.
MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) {
  auto *context = unwrap(ctx);
  return wrap(IntegerType::get(context, bitwidth, IntegerType::Unsigned));
}

/// Returns the bit width of the given integer type.
/// The argument must be an IntegerType (checked by cast<>).
unsigned mlirIntegerTypeGetWidth(MlirType type) {
  auto integerType = unwrap(type).cast<IntegerType>();
  return integerType.getWidth();
}

/// Returns true if the given integer type carries no signedness semantics.
bool mlirIntegerTypeIsSignless(MlirType type) {
  auto integerType = unwrap(type).cast<IntegerType>();
  return integerType.isSignless();
}

/// Returns true if the given integer type is explicitly signed.
bool mlirIntegerTypeIsSigned(MlirType type) {
  auto integerType = unwrap(type).cast<IntegerType>();
  return integerType.isSigned();
}

/// Returns true if the given integer type is explicitly unsigned.
bool mlirIntegerTypeIsUnsigned(MlirType type) {
  auto integerType = unwrap(type).cast<IntegerType>();
  return integerType.isUnsigned();
}
//===----------------------------------------------------------------------===//
// Index type.
//===----------------------------------------------------------------------===//
/// Returns true if the given type is the builtin `index` type.
bool mlirTypeIsAIndex(MlirType type) {
  auto unwrapped = unwrap(type);
  return unwrapped.isa<IndexType>();
}

/// Returns the `index` type in the given context.
MlirType mlirIndexTypeGet(MlirContext ctx) {
  auto *context = unwrap(ctx);
  return wrap(IndexType::get(context));
}
//===----------------------------------------------------------------------===//
// Floating-point types.
//===----------------------------------------------------------------------===//
/// Returns true if the given type is the bfloat16 floating-point type.
bool mlirTypeIsABF16(MlirType type) {
  auto unwrapped = unwrap(type);
  return unwrapped.isBF16();
}

/// Returns the bfloat16 floating-point type in the given context.
MlirType mlirBF16TypeGet(MlirContext ctx) {
  auto *context = unwrap(ctx);
  return wrap(FloatType::getBF16(context));
}

/// Returns true if the given type is the 16-bit floating-point type.
bool mlirTypeIsAF16(MlirType type) {
  auto unwrapped = unwrap(type);
  return unwrapped.isF16();
}

/// Returns the 16-bit floating-point type in the given context.
MlirType mlirF16TypeGet(MlirContext ctx) {
  auto *context = unwrap(ctx);
  return wrap(FloatType::getF16(context));
}

/// Returns true if the given type is the 32-bit floating-point type.
bool mlirTypeIsAF32(MlirType type) {
  auto unwrapped = unwrap(type);
  return unwrapped.isF32();
}

/// Returns the 32-bit floating-point type in the given context.
MlirType mlirF32TypeGet(MlirContext ctx) {
  auto *context = unwrap(ctx);
  return wrap(FloatType::getF32(context));
}

/// Returns true if the given type is the 64-bit floating-point type.
bool mlirTypeIsAF64(MlirType type) {
  auto unwrapped = unwrap(type);
  return unwrapped.isF64();
}

/// Returns the 64-bit floating-point type in the given context.
MlirType mlirF64TypeGet(MlirContext ctx) {
  auto *context = unwrap(ctx);
  return wrap(FloatType::getF64(context));
}
//===----------------------------------------------------------------------===//
// None type.
//===----------------------------------------------------------------------===//
/// Returns true if the given type is the builtin `none` type.
bool mlirTypeIsANone(MlirType type) {
  auto unwrapped = unwrap(type);
  return unwrapped.isa<NoneType>();
}

/// Returns the `none` type in the given context.
MlirType mlirNoneTypeGet(MlirContext ctx) {
  auto *context = unwrap(ctx);
  return wrap(NoneType::get(context));
}
//===----------------------------------------------------------------------===//
// Complex type.
//===----------------------------------------------------------------------===//
/// Returns true if the given type is a builtin complex type.
bool mlirTypeIsAComplex(MlirType type) {
  auto unwrapped = unwrap(type);
  return unwrapped.isa<ComplexType>();
}

/// Creates a complex type whose components have the given element type.
MlirType mlirComplexTypeGet(MlirType elementType) {
  auto element = unwrap(elementType);
  return wrap(ComplexType::get(element));
}

/// Returns the element type of the given complex type.
/// The argument must be a ComplexType (checked by cast<>).
MlirType mlirComplexTypeGetElementType(MlirType type) {
  auto complexType = unwrap(type).cast<ComplexType>();
  return wrap(complexType.getElementType());
}
//===----------------------------------------------------------------------===//
// Shaped type.
//===----------------------------------------------------------------------===//
/// Returns true if the given type is a ShapedType.
bool mlirTypeIsAShaped(MlirType type) { return unwrap(type).isa<ShapedType>(); }

/// Returns the element type of the given shaped type.
MlirType mlirShapedTypeGetElementType(MlirType type) {
  return wrap(unwrap(type).cast<ShapedType>().getElementType());
}

/// Returns true if the given shaped type has a rank.
bool mlirShapedTypeHasRank(MlirType type) {
  return unwrap(type).cast<ShapedType>().hasRank();
}

/// Returns the rank of the given shaped type.
/// NOTE(review): presumably only meaningful when mlirShapedTypeHasRank is
/// true - confirm against ShapedType::getRank.
int64_t mlirShapedTypeGetRank(MlirType type) {
  return unwrap(type).cast<ShapedType>().getRank();
}

/// Returns true if the given shaped type has a fully static shape.
bool mlirShapedTypeHasStaticShape(MlirType type) {
  return unwrap(type).cast<ShapedType>().hasStaticShape();
}

/// Returns true if dimension `dim` of the given shaped type is dynamic.
/// `dim` is converted to `unsigned` below, so a negative value would wrap to
/// a huge index; reject it explicitly.
bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) {
  assert(dim >= 0 && "dim must be non-negative");
  return unwrap(type).cast<ShapedType>().isDynamicDim(
      static_cast<unsigned>(dim));
}

/// Returns the size of dimension `dim` of the given shaped type.
/// As above, a negative `dim` would wrap when converted to `unsigned`.
int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) {
  assert(dim >= 0 && "dim must be non-negative");
  return unwrap(type).cast<ShapedType>().getDimSize(static_cast<unsigned>(dim));
}

/// Returns true if `size` denotes a dynamic dimension size, as decided by
/// ShapedType::isDynamic.
bool mlirShapedTypeIsDynamicSize(int64_t size) {
  return ShapedType::isDynamic(size);
}

/// Returns true if `val` denotes a dynamic stride or offset, as decided by
/// ShapedType::isDynamicStrideOrOffset.
bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) {
  return ShapedType::isDynamicStrideOrOffset(val);
}
//===----------------------------------------------------------------------===//
// Vector type.
//===----------------------------------------------------------------------===//
/// Returns true if the given type is a builtin vector type.
bool mlirTypeIsAVector(MlirType type) {
  auto unwrapped = unwrap(type);
  return unwrapped.isa<VectorType>();
}

/// Creates a vector type with `rank` dimensions, whose sizes are read from
/// `shape`, and the given element type.
MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
                           MlirType elementType) {
  auto dims = llvm::makeArrayRef(shape, static_cast<size_t>(rank));
  auto element = unwrap(elementType);
  return wrap(VectorType::get(dims, element));
}

/// Like mlirVectorTypeGet, but builds the type through
/// VectorType::getChecked using location `loc`.
/// NOTE(review): presumably yields a null type and a diagnostic at `loc` on
/// invalid input - confirm against the C header.
MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
                                  const int64_t *shape, MlirType elementType) {
  auto location = unwrap(loc);
  auto dims = llvm::makeArrayRef(shape, static_cast<size_t>(rank));
  auto element = unwrap(elementType);
  return wrap(VectorType::getChecked(location, dims, element));
}
//===----------------------------------------------------------------------===//
// Ranked / Unranked tensor type.
//===----------------------------------------------------------------------===//
/// Returns true if the given type is a tensor type (ranked or unranked).
bool mlirTypeIsATensor(MlirType type) {
  auto unwrapped = unwrap(type);
  return unwrapped.isa<TensorType>();
}

/// Returns true if the given type is a ranked tensor type.
bool mlirTypeIsARankedTensor(MlirType type) {
  auto unwrapped = unwrap(type);
  return unwrapped.isa<RankedTensorType>();
}

/// Returns true if the given type is an unranked tensor type.
bool mlirTypeIsAUnrankedTensor(MlirType type) {
  auto unwrapped = unwrap(type);
  return unwrapped.isa<UnrankedTensorType>();
}

/// Creates a ranked tensor type with `rank` dimensions, whose sizes are read
/// from `shape`, plus the given element type and encoding attribute.
MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape,
                                 MlirType elementType, MlirAttribute encoding) {
  auto dims = llvm::makeArrayRef(shape, static_cast<size_t>(rank));
  auto element = unwrap(elementType);
  auto encodingAttr = unwrap(encoding);
  return wrap(RankedTensorType::get(dims, element, encodingAttr));
}

/// Like mlirRankedTensorTypeGet, but builds the type through
/// RankedTensorType::getChecked using location `loc`.
MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank,
                                        const int64_t *shape,
                                        MlirType elementType,
                                        MlirAttribute encoding) {
  auto location = unwrap(loc);
  auto dims = llvm::makeArrayRef(shape, static_cast<size_t>(rank));
  auto element = unwrap(elementType);
  auto encodingAttr = unwrap(encoding);
  return wrap(
      RankedTensorType::getChecked(location, dims, element, encodingAttr));
}

/// Returns the encoding attribute of the given ranked tensor type.
MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type) {
  auto tensorType = unwrap(type).cast<RankedTensorType>();
  return wrap(tensorType.getEncoding());
}

/// Creates an unranked tensor type with the given element type.
MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
  auto element = unwrap(elementType);
  return wrap(UnrankedTensorType::get(element));
}

/// Like mlirUnrankedTensorTypeGet, but builds the type through
/// UnrankedTensorType::getChecked using location `loc`.
MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc,
                                          MlirType elementType) {
  auto location = unwrap(loc);
  auto element = unwrap(elementType);
  return wrap(UnrankedTensorType::getChecked(location, element));
}
//===----------------------------------------------------------------------===//
// Ranked / Unranked MemRef type.
//===----------------------------------------------------------------------===//
bool mlirTypeIsAMemRef(MlirType type) { return unwrap(type).isa<MemRefType>(); }
MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank,
const int64_t *shape, MlirAttribute layout,
MlirAttribute memorySpace) {
return wrap(MemRefType::get(
llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
mlirAttributeIsNull(layout)
? MemRefLayoutAttrInterface()
: unwrap(layout).cast<MemRefLayoutAttrInterface>(),
unwrap(memorySpace)));
}
MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType,
intptr_t rank, const int64_t *shape,
MlirAttribute layout,
MlirAttribute memorySpace) {
return wrap(MemRefType::getChecked(
unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
unwrap(elementType),
mlirAttributeIsNull(layout)
? MemRefLayoutAttrInterface()
: unwrap(layout).cast<MemRefLayoutAttrInterface>(),
unwrap(memorySpace)));
}
MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
const int64_t *shape,
MlirAttribute memorySpace) {
return wrap(MemRefType::get(
llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
MemRefLayoutAttrInterface(), unwrap(memorySpace)));
}
MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc,
MlirType elementType, intptr_t rank,
const int64_t *shape,
MlirAttribute memorySpace) {
return wrap(MemRefType::getChecked(
unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
unwrap(elementType), MemRefLayoutAttrInterface(), unwrap(memorySpace)));
}
MlirAttribute mlirMemRefTypeGetLayout(MlirType type) {
return wrap(unwrap(type).cast<MemRefType>().getLayout());
}
MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type) {
return wrap(unwrap(type).cast<MemRefType>().getLayout().getAffineMap());
}
MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
return wrap(unwrap(type).cast<MemRefType>().getMemorySpace());
}
/// Returns true if the given type is an unranked memref type.
bool mlirTypeIsAUnrankedMemRef(MlirType type) {
  auto unwrapped = unwrap(type);
  return unwrapped.isa<UnrankedMemRefType>();
}

/// Creates an unranked memref type with the given element type and memory
/// space attribute.
MlirType mlirUnrankedMemRefTypeGet(MlirType elementType,
                                   MlirAttribute memorySpace) {
  auto element = unwrap(elementType);
  auto space = unwrap(memorySpace);
  return wrap(UnrankedMemRefType::get(element, space));
}

/// Like mlirUnrankedMemRefTypeGet, but builds the type through
/// UnrankedMemRefType::getChecked using location `loc`.
MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc,
                                          MlirType elementType,
                                          MlirAttribute memorySpace) {
  auto location = unwrap(loc);
  auto element = unwrap(elementType);
  auto space = unwrap(memorySpace);
  return wrap(UnrankedMemRefType::getChecked(location, element, space));
}

/// Returns the memory space attribute of the given unranked memref type.
MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) {
  auto memrefType = unwrap(type).cast<UnrankedMemRefType>();
  return wrap(memrefType.getMemorySpace());
}
//===----------------------------------------------------------------------===//
// Tuple type.
//===----------------------------------------------------------------------===//
/// Returns true if the given type is a builtin tuple type.
bool mlirTypeIsATuple(MlirType type) { return unwrap(type).isa<TupleType>(); }

/// Creates a tuple type in `ctx` from the `numElements` types in `elements`.
MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements,
                          MlirType const *elements) {
  // `types` is the local backing storage handed to unwrapList.
  SmallVector<Type, 4> types;
  ArrayRef<Type> typeRef = unwrapList(numElements, elements, types);
  return wrap(TupleType::get(unwrap(ctx), typeRef));
}

/// Returns the number of element types in the given tuple type.
intptr_t mlirTupleTypeGetNumTypes(MlirType type) {
  // Make the size_t -> intptr_t narrowing explicit.
  return static_cast<intptr_t>(unwrap(type).cast<TupleType>().size());
}

/// Returns the element type at position `pos` of the given tuple type.
/// A negative `pos` would wrap when converted to size_t, so it is rejected.
MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) {
  assert(pos >= 0 && "pos in array must be non-negative");
  return wrap(unwrap(type).cast<TupleType>().getType(static_cast<size_t>(pos)));
}
//===----------------------------------------------------------------------===//
// Function type.
//===----------------------------------------------------------------------===//
/// Returns true if the given type is a builtin function type.
bool mlirTypeIsAFunction(MlirType type) {
  return unwrap(type).isa<FunctionType>();
}

/// Creates a function type in `ctx` with `numInputs` argument types read from
/// `inputs` and `numResults` result types read from `results`.
MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs,
                             MlirType const *inputs, intptr_t numResults,
                             MlirType const *results) {
  // The vectors are filled by unwrapList; the returned views are not needed.
  SmallVector<Type, 4> inputsList;
  SmallVector<Type, 4> resultsList;
  (void)unwrapList(numInputs, inputs, inputsList);
  (void)unwrapList(numResults, results, resultsList);
  return wrap(FunctionType::get(unwrap(ctx), inputsList, resultsList));
}

/// Returns the number of inputs of the given function type.
intptr_t mlirFunctionTypeGetNumInputs(MlirType type) {
  return static_cast<intptr_t>(
      unwrap(type).cast<FunctionType>().getNumInputs());
}

/// Returns the number of results of the given function type.
intptr_t mlirFunctionTypeGetNumResults(MlirType type) {
  return static_cast<intptr_t>(
      unwrap(type).cast<FunctionType>().getNumResults());
}

/// Returns the input type at position `pos` of the given function type.
/// Position 0 is valid, so the requirement is non-negativity (not
/// positivity); a negative value would wrap when converted to unsigned.
MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos) {
  assert(pos >= 0 && "pos in array must be non-negative");
  return wrap(
      unwrap(type).cast<FunctionType>().getInput(static_cast<unsigned>(pos)));
}

/// Returns the result type at position `pos` of the given function type.
/// `pos` must be non-negative, for the same reason as in the input accessor.
MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) {
  assert(pos >= 0 && "pos in array must be non-negative");
  return wrap(
      unwrap(type).cast<FunctionType>().getResult(static_cast<unsigned>(pos)));
}