//===- BuiltinTypes.cpp - C Interface to MLIR Builtin Types ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LogicalResult.h"

#include <algorithm>

using namespace mlir;

//===----------------------------------------------------------------------===//
// Integer types.
//===----------------------------------------------------------------------===//

/// Returns the TypeID shared by all builtin IntegerType instances.
MlirTypeID mlirIntegerTypeGetTypeID() {
  TypeID id = IntegerType::getTypeID();
  return wrap(id);
}

/// Returns true when `type` is a builtin IntegerType.
bool mlirTypeIsAInteger(MlirType type) {
  Type unwrapped = unwrap(type);
  return llvm::isa<IntegerType>(unwrapped);
}

/// Builds an integer type of `bitwidth` bits using IntegerType::get's default
/// signedness (presumably signless — see mlirIntegerTypeIsSignless).
MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) {
  auto *context = unwrap(ctx);
  return wrap(IntegerType::get(context, bitwidth));
}

/// Builds a signed integer type of `bitwidth` bits.
MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) {
  auto *context = unwrap(ctx);
  return wrap(IntegerType::get(context, bitwidth, IntegerType::Signed));
}

/// Builds an unsigned integer type of `bitwidth` bits.
MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) {
  auto *context = unwrap(ctx);
  return wrap(IntegerType::get(context, bitwidth, IntegerType::Unsigned));
}

/// Returns the bit width. `type` must be an IntegerType (llvm::cast).
unsigned mlirIntegerTypeGetWidth(MlirType type) {
  auto intType = llvm::cast<IntegerType>(unwrap(type));
  return intType.getWidth();
}

/// Returns true if the IntegerType carries no signedness.
bool mlirIntegerTypeIsSignless(MlirType type) {
  auto intType = llvm::cast<IntegerType>(unwrap(type));
  return intType.isSignless();
}

/// Returns true if the IntegerType is explicitly signed.
bool mlirIntegerTypeIsSigned(MlirType type) {
  auto intType = llvm::cast<IntegerType>(unwrap(type));
  return intType.isSigned();
}

/// Returns true if the IntegerType is explicitly unsigned.
bool mlirIntegerTypeIsUnsigned(MlirType type) {
  auto intType = llvm::cast<IntegerType>(unwrap(type));
  return intType.isUnsigned();
}

//===----------------------------------------------------------------------===//
// Index type.
//===----------------------------------------------------------------------===//

/// Returns the TypeID of the builtin IndexType.
MlirTypeID mlirIndexTypeGetTypeID() {
  TypeID id = IndexType::getTypeID();
  return wrap(id);
}

/// Returns true when `type` is the builtin IndexType.
bool mlirTypeIsAIndex(MlirType type) {
  Type unwrapped = unwrap(type);
  return llvm::isa<IndexType>(unwrapped);
}

/// Returns the (context-uniqued) IndexType.
MlirType mlirIndexTypeGet(MlirContext ctx) {
  auto *context = unwrap(ctx);
  return wrap(IndexType::get(context));
}

//===----------------------------------------------------------------------===//
// Floating-point types.
//===----------------------------------------------------------------------===//

/// Returns true when `type` is any builtin floating-point type.
bool mlirTypeIsAFloat(MlirType type) {
  Type unwrapped = unwrap(type);
  return llvm::isa<FloatType>(unwrapped);
}

/// Returns the bit width of a floating-point type. `type` must be a FloatType
/// (llvm::cast).
unsigned mlirFloatTypeGetWidth(MlirType type) {
  auto floatType = llvm::cast<FloatType>(unwrap(type));
  return floatType.getWidth();
}

/// Returns the TypeID of Float8E5M2Type.
MlirTypeID mlirFloat8E5M2TypeGetTypeID() {
  TypeID id = Float8E5M2Type::getTypeID();
  return wrap(id);
}

/// Returns true when `type` is Float8E5M2Type.
bool mlirTypeIsAFloat8E5M2(MlirType type) {
  return llvm::isa<Float8E5M2Type>(unwrap(type));
}

/// Returns the Float8E5M2 type in `ctx`.
MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
  auto *context = unwrap(ctx);
  return wrap(FloatType::getFloat8E5M2(context));
}

/// Returns the TypeID of Float8E4M3FNType.
MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() {
  TypeID id = Float8E4M3FNType::getTypeID();
  return wrap(id);
}

/// Returns true when `type` is Float8E4M3FNType.
bool mlirTypeIsAFloat8E4M3FN(MlirType type) {
  return llvm::isa<Float8E4M3FNType>(unwrap(type));
}

/// Returns the Float8E4M3FN type in `ctx`.
MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
  auto *context = unwrap(ctx);
  return wrap(FloatType::getFloat8E4M3FN(context));
}

/// Returns the TypeID of Float8E5M2FNUZType.
MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() {
  TypeID id = Float8E5M2FNUZType::getTypeID();
  return wrap(id);
}

/// Returns true when `type` is Float8E5M2FNUZType.
bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) {
  return llvm::isa<Float8E5M2FNUZType>(unwrap(type));
}

/// Returns the Float8E5M2FNUZ type in `ctx`.
MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) {
  auto *context = unwrap(ctx);
  return wrap(FloatType::getFloat8E5M2FNUZ(context));
}

/// Returns the TypeID of Float8E4M3FNUZType.
MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() {
  TypeID id = Float8E4M3FNUZType::getTypeID();
  return wrap(id);
}

/// Returns true when `type` is Float8E4M3FNUZType.
bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) {
  return llvm::isa<Float8E4M3FNUZType>(unwrap(type));
}

/// Returns the Float8E4M3FNUZ type in `ctx`.
MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) {
  auto *context = unwrap(ctx);
  return wrap(FloatType::getFloat8E4M3FNUZ(context));
}

/// Returns the TypeID of Float8E4M3B11FNUZType.
MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() {
  TypeID id = Float8E4M3B11FNUZType::getTypeID();
  return wrap(id);
}

/// Returns true when `type` is Float8E4M3B11FNUZType.
bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) {
  return llvm::isa<Float8E4M3B11FNUZType>(unwrap(type));
}

/// Returns the Float8E4M3B11FNUZ type in `ctx`.
MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) {
  auto *context = unwrap(ctx);
  return wrap(FloatType::getFloat8E4M3B11FNUZ(context));
}

/// Returns the TypeID of BFloat16Type.
MlirTypeID mlirBFloat16TypeGetTypeID() {
  TypeID id = BFloat16Type::getTypeID();
  return wrap(id);
}

/// Returns true when `type` is bf16.
bool mlirTypeIsABF16(MlirType type) {
  return llvm::isa<BFloat16Type>(unwrap(type));
}

/// Returns the bf16 type in `ctx`.
MlirType mlirBF16TypeGet(MlirContext ctx) {
  auto *context = unwrap(ctx);
  return wrap(FloatType::getBF16(context));
}

/// Returns the TypeID of Float16Type.
MlirTypeID mlirFloat16TypeGetTypeID() {
  TypeID id = Float16Type::getTypeID();
  return wrap(id);
}

/// Returns true when `type` is f16.
bool mlirTypeIsAF16(MlirType type) {
  return llvm::isa<Float16Type>(unwrap(type));
}

/// Returns the f16 type in `ctx`.
MlirType mlirF16TypeGet(MlirContext ctx) {
  auto *context = unwrap(ctx);
  return wrap(FloatType::getF16(context));
}

/// Returns the TypeID of FloatTF32Type.
MlirTypeID mlirFloatTF32TypeGetTypeID() {
  TypeID id = FloatTF32Type::getTypeID();
  return wrap(id);
}

/// Returns true when `type` is tf32.
bool mlirTypeIsATF32(MlirType type) {
  return llvm::isa<FloatTF32Type>(unwrap(type));
}

/// Returns the tf32 type in `ctx`.
MlirType mlirTF32TypeGet(MlirContext ctx) {
  auto *context = unwrap(ctx);
  return wrap(FloatType::getTF32(context));
}

/// Returns the TypeID of Float32Type.
MlirTypeID mlirFloat32TypeGetTypeID() {
  TypeID id = Float32Type::getTypeID();
  return wrap(id);
}

/// Returns true when `type` is f32.
bool mlirTypeIsAF32(MlirType type) {
  return llvm::isa<Float32Type>(unwrap(type));
}

/// Returns the f32 type in `ctx`.
MlirType mlirF32TypeGet(MlirContext ctx) {
  auto *context = unwrap(ctx);
  return wrap(FloatType::getF32(context));
}

/// Returns the TypeID of Float64Type.
MlirTypeID mlirFloat64TypeGetTypeID() {
  TypeID id = Float64Type::getTypeID();
  return wrap(id);
}

/// Returns true when `type` is f64.
bool mlirTypeIsAF64(MlirType type) {
  return llvm::isa<Float64Type>(unwrap(type));
}

/// Returns the f64 type in `ctx`.
MlirType mlirF64TypeGet(MlirContext ctx) {
  auto *context = unwrap(ctx);
  return wrap(FloatType::getF64(context));
}

//===----------------------------------------------------------------------===//
// None type.
//===----------------------------------------------------------------------===//

/// Returns the TypeID of the builtin NoneType.
MlirTypeID mlirNoneTypeGetTypeID() {
  TypeID id = NoneType::getTypeID();
  return wrap(id);
}

/// Returns true when `type` is the builtin NoneType.
bool mlirTypeIsANone(MlirType type) {
  Type unwrapped = unwrap(type);
  return llvm::isa<NoneType>(unwrapped);
}

/// Returns the NoneType in `ctx`.
MlirType mlirNoneTypeGet(MlirContext ctx) {
  auto *context = unwrap(ctx);
  return wrap(NoneType::get(context));
}

//===----------------------------------------------------------------------===//
// Complex type.
//===----------------------------------------------------------------------===//

/// Returns the TypeID of the builtin ComplexType.
MlirTypeID mlirComplexTypeGetTypeID() {
  TypeID id = ComplexType::getTypeID();
  return wrap(id);
}

/// Returns true when `type` is a ComplexType.
bool mlirTypeIsAComplex(MlirType type) {
  Type unwrapped = unwrap(type);
  return llvm::isa<ComplexType>(unwrapped);
}

/// Builds complex<elementType>.
MlirType mlirComplexTypeGet(MlirType elementType) {
  Type element = unwrap(elementType);
  return wrap(ComplexType::get(element));
}

/// Returns the element type of a ComplexType (llvm::cast on `type`).
MlirType mlirComplexTypeGetElementType(MlirType type) {
  auto complexType = llvm::cast<ComplexType>(unwrap(type));
  return wrap(complexType.getElementType());
}

//===----------------------------------------------------------------------===//
// Shaped type.
//===----------------------------------------------------------------------===//

/// Returns true when `type` implements the ShapedType interface.
bool mlirTypeIsAShaped(MlirType type) {
  Type unwrapped = unwrap(type);
  return llvm::isa<ShapedType>(unwrapped);
}

/// Returns the element type of a shaped type.
MlirType mlirShapedTypeGetElementType(MlirType type) {
  auto shaped = llvm::cast<ShapedType>(unwrap(type));
  return wrap(shaped.getElementType());
}

/// Returns true if the shaped type has a rank.
bool mlirShapedTypeHasRank(MlirType type) {
  auto shaped = llvm::cast<ShapedType>(unwrap(type));
  return shaped.hasRank();
}

/// Returns the rank of the shaped type.
int64_t mlirShapedTypeGetRank(MlirType type) {
  auto shaped = llvm::cast<ShapedType>(unwrap(type));
  return shaped.getRank();
}

/// Returns true if every dimension of the shaped type is static.
bool mlirShapedTypeHasStaticShape(MlirType type) {
  auto shaped = llvm::cast<ShapedType>(unwrap(type));
  return shaped.hasStaticShape();
}

/// Returns true if dimension `dim` is dynamic. `dim` is narrowed to unsigned.
bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) {
  auto shaped = llvm::cast<ShapedType>(unwrap(type));
  unsigned index = static_cast<unsigned>(dim);
  return shaped.isDynamicDim(index);
}

/// Returns the size of dimension `dim`. `dim` is narrowed to unsigned.
int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) {
  auto shaped = llvm::cast<ShapedType>(unwrap(type));
  unsigned index = static_cast<unsigned>(dim);
  return shaped.getDimSize(index);
}

/// Returns the sentinel value ShapedType::kDynamic used for dynamic sizes.
int64_t mlirShapedTypeGetDynamicSize() {
  const int64_t sentinel = ShapedType::kDynamic;
  return sentinel;
}

/// Returns true if `size` is the dynamic sentinel.
bool mlirShapedTypeIsDynamicSize(int64_t size) {
  bool dynamic = ShapedType::isDynamic(size);
  return dynamic;
}

/// Returns true if the stride/offset `val` is the dynamic sentinel.
bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) {
  bool dynamic = ShapedType::isDynamic(val);
  return dynamic;
}

/// Returns the sentinel ShapedType::kDynamic used for strides and offsets.
int64_t mlirShapedTypeGetDynamicStrideOrOffset() {
  const int64_t sentinel = ShapedType::kDynamic;
  return sentinel;
}

//===----------------------------------------------------------------------===//
// Vector type.
//===----------------------------------------------------------------------===//

/// Returns the TypeID of the builtin VectorType.
MlirTypeID mlirVectorTypeGetTypeID() {
  TypeID id = VectorType::getTypeID();
  return wrap(id);
}

/// Returns true when `type` is a VectorType.
bool mlirTypeIsAVector(MlirType type) {
  Type unwrapped = unwrap(type);
  return llvm::isa<VectorType>(unwrapped);
}

/// Builds a fixed-size vector type from `rank` entries of `shape`.
MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
                           MlirType elementType) {
  llvm::ArrayRef<int64_t> dims(shape, static_cast<size_t>(rank));
  Type element = unwrap(elementType);
  return wrap(VectorType::get(dims, element));
}

/// Like mlirVectorTypeGet, but uses VectorType::getChecked so failures are
/// reported against `loc`.
MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
                                  const int64_t *shape, MlirType elementType) {
  llvm::ArrayRef<int64_t> dims(shape, static_cast<size_t>(rank));
  Type element = unwrap(elementType);
  return wrap(VectorType::getChecked(unwrap(loc), dims, element));
}

/// Builds a vector type where `scalable[i]` flags dimension i as scalable.
/// Both `shape` and `scalable` are read for `rank` entries.
MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
                                   const bool *scalable, MlirType elementType) {
  const size_t count = static_cast<size_t>(rank);
  llvm::ArrayRef<int64_t> dims(shape, count);
  llvm::ArrayRef<bool> scalableDims(scalable, count);
  return wrap(VectorType::get(dims, unwrap(elementType), scalableDims));
}

/// Checked variant of mlirVectorTypeGetScalable; failures reported at `loc`.
MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
                                          const int64_t *shape,
                                          const bool *scalable,
                                          MlirType elementType) {
  const size_t count = static_cast<size_t>(rank);
  llvm::ArrayRef<int64_t> dims(shape, count);
  llvm::ArrayRef<bool> scalableDims(scalable, count);
  return wrap(VectorType::getChecked(unwrap(loc), dims, unwrap(elementType),
                                     scalableDims));
}

/// Returns true if the vector type has at least one scalable dimension.
/// `type` must be a VectorType (llvm::cast asserts otherwise).
bool mlirVectorTypeIsScalable(MlirType type) {
  // Use the free-function llvm::cast, like the rest of this file, instead of
  // the deprecated Type::cast<> member.
  return llvm::cast<VectorType>(unwrap(type)).isScalable();
}

/// Returns true if dimension `dim` of the vector type is scalable.
/// `dim` must be a valid, non-negative dimension index.
bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) {
  assert(dim >= 0 && "dim must be non-negative");
  return llvm::cast<VectorType>(unwrap(type))
      .getScalableDims()[static_cast<size_t>(dim)];
}

//===----------------------------------------------------------------------===//
// Ranked / Unranked tensor type.
//===----------------------------------------------------------------------===//

/// Returns true when `type` is any TensorType (ranked or unranked).
bool mlirTypeIsATensor(MlirType type) {
  Type unwrapped = unwrap(type);
  return llvm::isa<TensorType>(unwrapped);
}

/// Returns the TypeID of RankedTensorType.
MlirTypeID mlirRankedTensorTypeGetTypeID() {
  TypeID id = RankedTensorType::getTypeID();
  return wrap(id);
}

/// Returns true when `type` is a RankedTensorType.
bool mlirTypeIsARankedTensor(MlirType type) {
  Type unwrapped = unwrap(type);
  return llvm::isa<RankedTensorType>(unwrapped);
}

/// Returns the TypeID of UnrankedTensorType.
MlirTypeID mlirUnrankedTensorTypeGetTypeID() {
  TypeID id = UnrankedTensorType::getTypeID();
  return wrap(id);
}

/// Returns true when `type` is an UnrankedTensorType.
bool mlirTypeIsAUnrankedTensor(MlirType type) {
  Type unwrapped = unwrap(type);
  return llvm::isa<UnrankedTensorType>(unwrapped);
}

/// Builds a ranked tensor type from `rank` entries of `shape`, with an
/// optional `encoding` attribute.
MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape,
                                 MlirType elementType, MlirAttribute encoding) {
  llvm::ArrayRef<int64_t> dims(shape, static_cast<size_t>(rank));
  Type element = unwrap(elementType);
  Attribute enc = unwrap(encoding);
  return wrap(RankedTensorType::get(dims, element, enc));
}

/// Checked variant of mlirRankedTensorTypeGet; failures reported at `loc`.
MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank,
                                        const int64_t *shape,
                                        MlirType elementType,
                                        MlirAttribute encoding) {
  llvm::ArrayRef<int64_t> dims(shape, static_cast<size_t>(rank));
  Type element = unwrap(elementType);
  Attribute enc = unwrap(encoding);
  return wrap(RankedTensorType::getChecked(unwrap(loc), dims, element, enc));
}

/// Returns the encoding attribute of a RankedTensorType.
MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type) {
  auto tensorType = llvm::cast<RankedTensorType>(unwrap(type));
  return wrap(tensorType.getEncoding());
}

/// Builds tensor<*xelementType>.
MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
  Type element = unwrap(elementType);
  return wrap(UnrankedTensorType::get(element));
}

/// Checked variant of mlirUnrankedTensorTypeGet; failures reported at `loc`.
MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc,
                                          MlirType elementType) {
  Type element = unwrap(elementType);
  return wrap(UnrankedTensorType::getChecked(unwrap(loc), element));
}

/// Returns the element type of an UnrankedTensorType.
MlirType mlirUnrankedTensorTypeGetElementType(MlirType type) {
  auto tensorType = llvm::cast<UnrankedTensorType>(unwrap(type));
  return wrap(tensorType.getElementType());
}

//===----------------------------------------------------------------------===//
// Ranked / Unranked MemRef type.
//===----------------------------------------------------------------------===//

/// Returns the TypeID of the builtin MemRefType.
MlirTypeID mlirMemRefTypeGetTypeID() {
  TypeID id = MemRefType::getTypeID();
  return wrap(id);
}

/// Returns true when `type` is a (ranked) MemRefType.
bool mlirTypeIsAMemRef(MlirType type) {
  Type unwrapped = unwrap(type);
  return llvm::isa<MemRefType>(unwrapped);
}

MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank,
                           const int64_t *shape, MlirAttribute layout,
                           MlirAttribute memorySpace) {
  return wrap(MemRefType::get(
      llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
      mlirAttributeIsNull(layout)
          ? MemRefLayoutAttrInterface()
          : llvm::cast<MemRefLayoutAttrInterface>(unwrap(layout)),
      unwrap(memorySpace)));
}

MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType,
                                  intptr_t rank, const int64_t *shape,
                                  MlirAttribute layout,
                                  MlirAttribute memorySpace) {
  return wrap(MemRefType::getChecked(
      unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
      unwrap(elementType),
      mlirAttributeIsNull(layout)
          ? MemRefLayoutAttrInterface()
          : llvm::cast<MemRefLayoutAttrInterface>(unwrap(layout)),
      unwrap(memorySpace)));
}

/// Builds a memref with a null layout interface. MemRefType::get presumably
/// substitutes its default (contiguous/identity) layout; confirm.
MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
                                     const int64_t *shape,
                                     MlirAttribute memorySpace) {
  llvm::ArrayRef<int64_t> dims(shape, static_cast<size_t>(rank));
  MemRefLayoutAttrInterface defaultLayout;
  return wrap(MemRefType::get(dims, unwrap(elementType), defaultLayout,
                              unwrap(memorySpace)));
}

/// Checked variant of mlirMemRefTypeContiguousGet; failures reported at `loc`.
MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc,
                                            MlirType elementType, intptr_t rank,
                                            const int64_t *shape,
                                            MlirAttribute memorySpace) {
  llvm::ArrayRef<int64_t> dims(shape, static_cast<size_t>(rank));
  MemRefLayoutAttrInterface defaultLayout;
  return wrap(MemRefType::getChecked(unwrap(loc), dims, unwrap(elementType),
                                     defaultLayout, unwrap(memorySpace)));
}

/// Returns the layout attribute of a MemRefType.
MlirAttribute mlirMemRefTypeGetLayout(MlirType type) {
  auto memref = llvm::cast<MemRefType>(unwrap(type));
  return wrap(memref.getLayout());
}

/// Returns the affine map obtained from the MemRefType's layout.
MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type) {
  auto memref = llvm::cast<MemRefType>(unwrap(type));
  MemRefLayoutAttrInterface layout = memref.getLayout();
  return wrap(layout.getAffineMap());
}

/// Returns the memory space attribute of a MemRefType.
MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
  auto memref = llvm::cast<MemRefType>(unwrap(type));
  return wrap(memref.getMemorySpace());
}

/// Computes the strides and offset of a memref via getStridesAndOffset.
///
/// On success, the computed strides are copied into `strides` and success is
/// returned. The strides are presumably one per dimension, so `strides` should
/// hold at least rank entries (confirm). On failure nothing is copied into
/// `strides`.
///
/// `*offset` is passed straight to getStridesAndOffset. It may therefore be
/// written even when failure is returned.
MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type,
                                                    int64_t *strides,
                                                    int64_t *offset) {
  auto memrefType = llvm::cast<MemRefType>(unwrap(type));
  SmallVector<int64_t> computedStrides;
  if (succeeded(getStridesAndOffset(memrefType, computedStrides, *offset))) {
    (void)std::copy(computedStrides.begin(), computedStrides.end(), strides);
    return mlirLogicalResultSuccess();
  }
  return mlirLogicalResultFailure();
}

/// Returns the TypeID of UnrankedMemRefType.
MlirTypeID mlirUnrankedMemRefTypeGetTypeID() {
  TypeID id = UnrankedMemRefType::getTypeID();
  return wrap(id);
}

/// Returns true when `type` is an UnrankedMemRefType.
bool mlirTypeIsAUnrankedMemRef(MlirType type) {
  Type unwrapped = unwrap(type);
  return llvm::isa<UnrankedMemRefType>(unwrapped);
}

/// Builds memref<*xelementType> in `memorySpace`.
MlirType mlirUnrankedMemRefTypeGet(MlirType elementType,
                                   MlirAttribute memorySpace) {
  Type element = unwrap(elementType);
  Attribute space = unwrap(memorySpace);
  return wrap(UnrankedMemRefType::get(element, space));
}

/// Checked variant of mlirUnrankedMemRefTypeGet; failures reported at `loc`.
MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc,
                                          MlirType elementType,
                                          MlirAttribute memorySpace) {
  Type element = unwrap(elementType);
  Attribute space = unwrap(memorySpace);
  return wrap(UnrankedMemRefType::getChecked(unwrap(loc), element, space));
}

/// Returns the memory space attribute of an UnrankedMemRefType.
MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) {
  auto memref = llvm::cast<UnrankedMemRefType>(unwrap(type));
  return wrap(memref.getMemorySpace());
}

//===----------------------------------------------------------------------===//
// Tuple type.
//===----------------------------------------------------------------------===//

/// Returns the TypeID of the builtin TupleType.
MlirTypeID mlirTupleTypeGetTypeID() {
  TypeID id = TupleType::getTypeID();
  return wrap(id);
}

/// Returns true when `type` is a TupleType.
bool mlirTypeIsATuple(MlirType type) {
  Type unwrapped = unwrap(type);
  return llvm::isa<TupleType>(unwrapped);
}

/// Builds a tuple type from `numElements` entries of `elements`.
MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements,
                          MlirType const *elements) {
  SmallVector<Type, 4> storage;
  ArrayRef<Type> elementTypes = unwrapList(numElements, elements, storage);
  auto *context = unwrap(ctx);
  return wrap(TupleType::get(context, elementTypes));
}

/// Returns the number of element types in the tuple.
intptr_t mlirTupleTypeGetNumTypes(MlirType type) {
  auto tuple = llvm::cast<TupleType>(unwrap(type));
  return tuple.size();
}

/// Returns the element type at position `pos` of the tuple.
MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) {
  auto tuple = llvm::cast<TupleType>(unwrap(type));
  const size_t index = static_cast<size_t>(pos);
  return wrap(tuple.getType(index));
}

//===----------------------------------------------------------------------===//
// Function type.
//===----------------------------------------------------------------------===//

/// Returns the TypeID of the builtin FunctionType.
MlirTypeID mlirFunctionTypeGetTypeID() {
  TypeID id = FunctionType::getTypeID();
  return wrap(id);
}

/// Returns true when `type` is a FunctionType.
bool mlirTypeIsAFunction(MlirType type) {
  Type unwrapped = unwrap(type);
  return llvm::isa<FunctionType>(unwrapped);
}

/// Builds a function type from `numInputs` input types and `numResults`
/// result types.
MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs,
                             MlirType const *inputs, intptr_t numResults,
                             MlirType const *results) {
  SmallVector<Type, 4> inputStorage;
  SmallVector<Type, 4> resultStorage;
  ArrayRef<Type> inputTypes = unwrapList(numInputs, inputs, inputStorage);
  ArrayRef<Type> resultTypes = unwrapList(numResults, results, resultStorage);
  auto *context = unwrap(ctx);
  return wrap(FunctionType::get(context, inputTypes, resultTypes));
}

/// Returns the number of inputs of the function type.
intptr_t mlirFunctionTypeGetNumInputs(MlirType type) {
  auto funcType = llvm::cast<FunctionType>(unwrap(type));
  return funcType.getNumInputs();
}

/// Returns the number of results of the function type.
intptr_t mlirFunctionTypeGetNumResults(MlirType type) {
  auto funcType = llvm::cast<FunctionType>(unwrap(type));
  return funcType.getNumResults();
}

/// Returns the type of the `pos`-th input of a FunctionType.
/// `pos` must be in [0, number of inputs); this is checked by assertions.
MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos) {
  assert(pos >= 0 && "pos in array must be non-negative");
  auto funcType = llvm::cast<FunctionType>(unwrap(type));
  // Check the range before narrowing to unsigned. A huge `pos` could otherwise
  // wrap around to a small, seemingly valid index.
  assert(static_cast<size_t>(pos) < funcType.getNumInputs() &&
         "pos out of range of function inputs");
  return wrap(funcType.getInput(static_cast<unsigned>(pos)));
}

/// Returns the type of the `pos`-th result of a FunctionType.
/// `pos` must be in [0, number of results); this is checked by assertions.
MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) {
  assert(pos >= 0 && "pos in array must be non-negative");
  auto funcType = llvm::cast<FunctionType>(unwrap(type));
  // Same pre-narrowing range check as mlirFunctionTypeGetInput.
  assert(static_cast<size_t>(pos) < funcType.getNumResults() &&
         "pos out of range of function results");
  return wrap(funcType.getResult(static_cast<unsigned>(pos)));
}

//===----------------------------------------------------------------------===//
// Opaque type.
//===----------------------------------------------------------------------===//

/// Returns the TypeID of the builtin OpaqueType.
MlirTypeID mlirOpaqueTypeGetTypeID() {
  TypeID id = OpaqueType::getTypeID();
  return wrap(id);
}

/// Returns true when `type` is an OpaqueType.
bool mlirTypeIsAOpaque(MlirType type) {
  Type unwrapped = unwrap(type);
  return llvm::isa<OpaqueType>(unwrapped);
}

/// Builds an opaque type. The dialect namespace becomes a StringAttr in `ctx`,
/// and `typeData` is the unparsed type body.
MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace,
                           MlirStringRef typeData) {
  auto *context = unwrap(ctx);
  StringAttr dialectAttr = StringAttr::get(context, unwrap(dialectNamespace));
  return wrap(OpaqueType::get(dialectAttr, unwrap(typeData)));
}

/// Returns the dialect namespace string of an OpaqueType.
MlirStringRef mlirOpaqueTypeGetDialectNamespace(MlirType type) {
  auto opaque = llvm::cast<OpaqueType>(unwrap(type));
  StringAttr dialectAttr = opaque.getDialectNamespace();
  return wrap(dialectAttr.strref());
}

/// Returns the raw type data string of an OpaqueType.
MlirStringRef mlirOpaqueTypeGetData(MlirType type) {
  auto opaque = llvm::cast<OpaqueType>(unwrap(type));
  return wrap(opaque.getTypeData());
}
