//===- BuiltinTypes.h - MLIR Builtin Type Classes ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_BUILTINTYPES_H
#define MLIR_IR_BUILTINTYPES_H

#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Support/ADTExtras.h"
namespace llvm {
class BitVector;
struct fltSemantics;
} // namespace llvm
//===----------------------------------------------------------------------===//
// Tablegen Interface Declarations
//===----------------------------------------------------------------------===//
namespace mlir {
class AffineExpr;
class AffineMap;
class FloatType;
class IndexType;
class IntegerType;
class MemRefType;
class RankedTensorType;
class StringAttr;
class TypeRange;
namespace detail {
struct FunctionTypeStorage;
struct IntegerTypeStorage;
struct TupleTypeStorage;
} // namespace detail
//===----------------------------------------------------------------------===//
// FloatType
//===----------------------------------------------------------------------===//
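/// A minimal usage sketch (illustrative only; `ctx` is assumed to be a valid
/// `MLIRContext *`):
///
///   FloatType f32 = FloatType::getF32(ctx);
///   assert(f32.getWidth() == 32);
///   // Doubling the bitwidth yields f64; a null type is returned if the
///   // scaled type cannot be represented.
///   FloatType f64 = f32.scaleElementBitwidth(2);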
class FloatType : public Type {
public:
  using Type::Type;

  // Convenience factories.
  static FloatType getBF16(MLIRContext *ctx);
  static FloatType getF16(MLIRContext *ctx);
  static FloatType getF32(MLIRContext *ctx);
  static FloatType getTF32(MLIRContext *ctx);
  static FloatType getF64(MLIRContext *ctx);
  static FloatType getF80(MLIRContext *ctx);
  static FloatType getF128(MLIRContext *ctx);
  static FloatType getFloat8E5M2(MLIRContext *ctx);
  static FloatType getFloat8E4M3FN(MLIRContext *ctx);
  static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
  static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
  static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);

  /// Methods for supporting type inquiry through isa, cast, and dyn_cast.
  static bool classof(Type type);

  /// Return the bitwidth of this float type.
  unsigned getWidth();

  /// Return the width of the mantissa of this type.
  /// The width includes the integer bit.
  unsigned getFPMantissaWidth();

  /// Get or create a new FloatType with bitwidth scaled by `scale`.
  /// Return null if the scaled element type cannot be represented.
  FloatType scaleElementBitwidth(unsigned scale);

  /// Return the floating semantics of this float type.
  const llvm::fltSemantics &getFloatSemantics();
};
//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//
/// Tensor types represent multi-dimensional arrays, and have two variants:
/// RankedTensorType and UnrankedTensorType.
/// Note: This class attaches the ShapedType trait to act as a mixin to
/// provide many useful utility functions. This inheritance has no effect
/// on derived tensor types.
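///
/// A minimal usage sketch (illustrative only; `t` is assumed to hold
/// `tensor<4x?xf32>` and `f64` a `Float64Type`):
///
///   RankedTensorType r = t.clone({2, 8});          // tensor<2x8xf32>
///   TensorType s = t.cloneWith(std::nullopt, f64); // tensor<4x?xf64>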
class TensorType : public Type, public ShapedType::Trait<TensorType> {
public:
  using Type::Type;

  /// Returns the element type of this tensor type.
  Type getElementType() const;

  /// Returns true if this type is ranked, i.e. it has a known number of
  /// dimensions.
  bool hasRank() const;

  /// Returns the shape of this tensor type.
  ArrayRef<int64_t> getShape() const;

  /// Clone this type with the given shape and element type. If the
  /// provided shape is `std::nullopt`, the current shape of the type is used.
  TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
                       Type elementType) const;

  // Make sure that base class overloads are visible.
  using ShapedType::Trait<TensorType>::clone;

  /// Return a clone of this type with the given new shape and element type.
  /// The returned type is ranked, even if this type is unranked.
  RankedTensorType clone(ArrayRef<int64_t> shape, Type elementType) const;

  /// Return a clone of this type with the given new shape. The returned type
  /// is ranked, even if this type is unranked.
  RankedTensorType clone(ArrayRef<int64_t> shape) const;

  /// Return true if the specified element type is a valid tensor element type.
  static bool isValidElementType(Type type);

  /// Methods for supporting type inquiry through isa, cast, and dyn_cast.
  static bool classof(Type type);

  /// Allow implicit conversion to ShapedType.
  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
};
//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//
/// This class provides a shared interface for ranked and unranked memref types.
/// Note: This class attaches the ShapedType trait to act as a mixin to
/// provide many useful utility functions. This inheritance has no effect
/// on derived memref types.
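///
/// A minimal usage sketch (illustrative only; `m` is assumed to hold
/// `memref<4x?xf32, 1>`):
///
///   Attribute space = m.getMemorySpace(); // memory space `1` as an attribute
///   MemRefType r = m.clone({2, 8});       // ranked clone of the memref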
class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
public:
  using Type::Type;

  /// Returns the element type of this memref type.
  Type getElementType() const;

  /// Returns true if this type is ranked, i.e. it has a known number of
  /// dimensions.
  bool hasRank() const;

  /// Returns the shape of this memref type.
  ArrayRef<int64_t> getShape() const;

  /// Clone this type with the given shape and element type. If the
  /// provided shape is `std::nullopt`, the current shape of the type is used.
  BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
                           Type elementType) const;

  // Make sure that base class overloads are visible.
  using ShapedType::Trait<BaseMemRefType>::clone;

  /// Return a clone of this type with the given new shape and element type.
  /// The returned type is ranked, even if this type is unranked.
  MemRefType clone(ArrayRef<int64_t> shape, Type elementType) const;

  /// Return a clone of this type with the given new shape. The returned type
  /// is ranked, even if this type is unranked.
  MemRefType clone(ArrayRef<int64_t> shape) const;

  /// Return true if the specified element type is a valid memref element type.
  static bool isValidElementType(Type type);

  /// Methods for supporting type inquiry through isa, cast, and dyn_cast.
  static bool classof(Type type);

  /// Returns the memory space in which data referred to by this memref
  /// resides.
  Attribute getMemorySpace() const;

  /// [deprecated] Returns the memory space in the old raw integer
  /// representation. The new `Attribute getMemorySpace()` method should be
  /// used instead.
  unsigned getMemorySpaceAsInt() const;

  /// Allow implicit conversion to ShapedType.
  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
};
} // namespace mlir
//===----------------------------------------------------------------------===//
// Tablegen Type Declarations
//===----------------------------------------------------------------------===//
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.h.inc"
namespace mlir {
//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//
/// This is a builder type that keeps local references to arguments. Arguments
/// that are passed into the builder must outlive the builder.
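///
/// A minimal usage sketch (illustrative only; `src` is assumed to be an
/// existing MemRefType and `newSpace` a memory space Attribute):
///
///   MemRefType dst = MemRefType::Builder(src)
///                        .setShape({4, 8})
///                        .setMemorySpace(newSpace);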
class MemRefType::Builder {
public:
  // Build from another MemRefType.
  explicit Builder(MemRefType other)
      : shape(other.getShape()), elementType(other.getElementType()),
        layout(other.getLayout()), memorySpace(other.getMemorySpace()) {}

  // Build from scratch.
  Builder(ArrayRef<int64_t> shape, Type elementType)
      : shape(shape), elementType(elementType) {}

  Builder &setShape(ArrayRef<int64_t> newShape) {
    shape = newShape;
    return *this;
  }

  Builder &setElementType(Type newElementType) {
    elementType = newElementType;
    return *this;
  }

  Builder &setLayout(MemRefLayoutAttrInterface newLayout) {
    layout = newLayout;
    return *this;
  }

  Builder &setMemorySpace(Attribute newMemorySpace) {
    memorySpace = newMemorySpace;
    return *this;
  }

  operator MemRefType() {
    return MemRefType::get(shape, elementType, layout, memorySpace);
  }

private:
  ArrayRef<int64_t> shape;
  Type elementType;
  MemRefLayoutAttrInterface layout;
  Attribute memorySpace;
};
//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//
/// This is a builder type that keeps local references to arguments. Arguments
/// that are passed into the builder must outlive the builder.
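///
/// For example (illustrative only; `src` is assumed to hold
/// `tensor<4x1x8xf32>`), dropping the unit dim at position 1 produces
/// `tensor<4x8xf32>`:
///
///   RankedTensorType reduced = RankedTensorType::Builder(src).dropDim(1);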
class RankedTensorType::Builder {
public:
  /// Build from another RankedTensorType.
  explicit Builder(RankedTensorType other)
      : shape(other.getShape()), elementType(other.getElementType()),
        encoding(other.getEncoding()) {}

  /// Build from scratch.
  Builder(ArrayRef<int64_t> shape, Type elementType, Attribute encoding)
      : shape(shape), elementType(elementType), encoding(encoding) {}

  Builder &setShape(ArrayRef<int64_t> newShape) {
    shape = newShape;
    return *this;
  }

  Builder &setElementType(Type newElementType) {
    elementType = newElementType;
    return *this;
  }

  Builder &setEncoding(Attribute newEncoding) {
    encoding = newEncoding;
    return *this;
  }

  /// Erase the dim at position `pos` from the shape.
  Builder &dropDim(unsigned pos) {
    assert(pos < shape.size() && "overflow");
    shape.erase(pos);
    return *this;
  }

  /// Insert `val` into the shape at position `pos`.
  Builder &insertDim(int64_t val, unsigned pos) {
    assert(pos <= shape.size() && "overflow");
    shape.insert(pos, val);
    return *this;
  }

  operator RankedTensorType() {
    return RankedTensorType::get(shape, elementType, encoding);
  }

private:
  CopyOnWriteArrayRef<int64_t> shape;
  Type elementType;
  Attribute encoding;
};
//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//
/// This is a builder type that keeps local references to arguments. Arguments
/// that are passed into the builder must outlive the builder.
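///
/// For example (illustrative only; `src` is assumed to hold
/// `vector<2x[4]xf32>`), dropping dim 0 erases the matching scalable-dims
/// entry as well and yields `vector<[4]xf32>`:
///
///   VectorType sub = VectorType::Builder(src).dropDim(0);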
class VectorType::Builder {
public:
  /// Build from another VectorType.
  explicit Builder(VectorType other)
      : elementType(other.getElementType()), shape(other.getShape()),
        scalableDims(other.getScalableDims()) {}

  /// Build from scratch.
  Builder(ArrayRef<int64_t> shape, Type elementType,
          ArrayRef<bool> scalableDims = {})
      : elementType(elementType), shape(shape), scalableDims(scalableDims) {}

  Builder &setShape(ArrayRef<int64_t> newShape,
                    ArrayRef<bool> newIsScalableDim = {}) {
    shape = newShape;
    scalableDims = newIsScalableDim;
    return *this;
  }

  Builder &setElementType(Type newElementType) {
    elementType = newElementType;
    return *this;
  }

  /// Erase the dim at position `pos` from the shape and, if present, from the
  /// scalable-dims mask.
  Builder &dropDim(unsigned pos) {
    assert(pos < shape.size() && "overflow");
    shape.erase(pos);
    if (!scalableDims.empty())
      scalableDims.erase(pos);
    return *this;
  }

  /// Set the dim at position `pos` to `val`.
  Builder &setDim(unsigned pos, int64_t val) {
    assert(pos < shape.size() && "overflow");
    shape.set(pos, val);
    return *this;
  }

  operator VectorType() {
    return VectorType::get(shape, elementType, scalableDims);
  }

private:
  Type elementType;
  CopyOnWriteArrayRef<int64_t> shape;
  CopyOnWriteArrayRef<bool> scalableDims;
};
/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
/// `originalShape` with some `1` entries erased, return the set of indices
/// that specifies which of the entries of `originalShape` are dropped to obtain
/// `reducedShape`. The returned mask can be applied as a projection to
/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
/// which dimensions must be kept when e.g. computing MemRef strides under
/// rank-reducing operations. Return `std::nullopt` if `reducedShape` cannot
/// be obtained by dropping only `1` entries in `originalShape`.
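///
/// For example (illustrative):
///   computeRankReductionMask({3, 1, 4, 1}, {3, 4})
/// returns the mask `{1, 3}`: dropping dimensions 1 and 3 of the original
/// shape yields the reduced shape.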
std::optional<llvm::SmallDenseSet<unsigned>>
computeRankReductionMask(ArrayRef<int64_t> originalShape,
                         ArrayRef<int64_t> reducedShape);
/// Enum that captures information related to verifier error conditions on
/// slice insert/extract type of ops.
enum class SliceVerificationResult {
  Success,
  RankTooLarge,
  SizeMismatch,
  ElemTypeMismatch,
  // Error codes for ops with a memory space and a layout annotation.
  MemSpaceMismatch,
  LayoutMismatch
};
/// Check if `originalType` can be rank-reduced to `candidateReducedType` by
/// dropping some dimensions with static size `1`.
/// Return `SliceVerificationResult::Success` on success or an appropriate error
/// code.
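///
/// For example (illustrative), `tensor<4x1x8xf32>` rank-reduces to
/// `tensor<4x8xf32>` (`Success`), whereas reducing it to `tensor<4x8xf64>`
/// yields `ElemTypeMismatch`.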
SliceVerificationResult isRankReducedType(ShapedType originalType,
                                          ShapedType candidateReducedType);
//===----------------------------------------------------------------------===//
// Deferred Method Definitions
//===----------------------------------------------------------------------===//
inline bool BaseMemRefType::classof(Type type) {
  return llvm::isa<MemRefType, UnrankedMemRefType>(type);
}

inline bool BaseMemRefType::isValidElementType(Type type) {
  return type.isIntOrIndexOrFloat() ||
         llvm::isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>(
             type) ||
         llvm::isa<MemRefElementTypeInterface>(type);
}

inline bool FloatType::classof(Type type) {
  return llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
                   Float8E4M3FNUZType, Float8E4M3B11FNUZType, BFloat16Type,
                   Float16Type, FloatTF32Type, Float32Type, Float64Type,
                   Float80Type, Float128Type>(type);
}

inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
  return Float8E5M2Type::get(ctx);
}

inline FloatType FloatType::getFloat8E4M3FN(MLIRContext *ctx) {
  return Float8E4M3FNType::get(ctx);
}

inline FloatType FloatType::getFloat8E5M2FNUZ(MLIRContext *ctx) {
  return Float8E5M2FNUZType::get(ctx);
}

inline FloatType FloatType::getFloat8E4M3FNUZ(MLIRContext *ctx) {
  return Float8E4M3FNUZType::get(ctx);
}

inline FloatType FloatType::getFloat8E4M3B11FNUZ(MLIRContext *ctx) {
  return Float8E4M3B11FNUZType::get(ctx);
}

inline FloatType FloatType::getBF16(MLIRContext *ctx) {
  return BFloat16Type::get(ctx);
}

inline FloatType FloatType::getF16(MLIRContext *ctx) {
  return Float16Type::get(ctx);
}

inline FloatType FloatType::getTF32(MLIRContext *ctx) {
  return FloatTF32Type::get(ctx);
}

inline FloatType FloatType::getF32(MLIRContext *ctx) {
  return Float32Type::get(ctx);
}

inline FloatType FloatType::getF64(MLIRContext *ctx) {
  return Float64Type::get(ctx);
}

inline FloatType FloatType::getF80(MLIRContext *ctx) {
  return Float80Type::get(ctx);
}

inline FloatType FloatType::getF128(MLIRContext *ctx) {
  return Float128Type::get(ctx);
}

inline bool TensorType::classof(Type type) {
  return llvm::isa<RankedTensorType, UnrankedTensorType>(type);
}
//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//
/// Returns the strides of the MemRef if the layout map is in strided form.
/// MemRefs with a layout map in strided form include:
/// 1. empty or identity layout map, in which case the stride information is
///    the canonical form computed from sizes;
/// 2. a StridedLayoutAttr layout;
/// 3. any other layout that can be converted into a single affine map layout
///    of the form `K + k0 * d0 + ... + kn * dn`, where K and ki's are
///    constants or symbols.
///
/// A stride specification is a list of integer values that are either static
/// or dynamic (encoded with ShapedType::kDynamic). Strides encode
/// the distance in the number of elements between successive entries along a
/// particular dimension.
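///
/// A minimal usage sketch (illustrative only; `m` is assumed to hold
/// `memref<3x4x5xf32>` with the identity layout):
///
///   SmallVector<int64_t> strides;
///   int64_t offset;
///   if (succeeded(getStridesAndOffset(m, strides, offset))) {
///     // strides == {20, 5, 1}, offset == 0
///   }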
LogicalResult getStridesAndOffset(MemRefType t,
                                  SmallVectorImpl<int64_t> &strides,
                                  int64_t &offset);
/// Wrapper around getStridesAndOffset(MemRefType, SmallVectorImpl<int64_t>,
/// int64_t) that asserts if the strides/offset could not be computed.
std::pair<SmallVector<int64_t>, int64_t> getStridesAndOffset(MemRefType t);
/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
MemRefType canonicalizeStridedLayout(MemRefType t);
/// Given MemRef `sizes` that are either static or dynamic, returns the
/// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
/// once a dynamic dimension is encountered, all canonical strides become
/// dynamic and need to be encoded with a different symbol.
/// For canonical strides expressions, the offset is always 0 and the fastest
/// varying stride is always `1`.
///
/// Examples:
/// - memref<3x4x5xf32> has canonical stride expression
/// `20*exprs[0] + 5*exprs[1] + exprs[2]`.
/// - memref<3x?x5xf32> has canonical stride expression
/// `s0*exprs[0] + 5*exprs[1] + exprs[2]`.
/// - memref<3x4x?xf32> has canonical stride expression
/// `s1*exprs[0] + s0*exprs[1] + exprs[2]`.
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                          ArrayRef<AffineExpr> exprs,
                                          MLIRContext *context);
/// Return the result of makeCanonicalStridedLayoutExpr for the common case
/// where `exprs` is `{d0, d1, ..., d_(sizes.size()-1)}`.
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                          MLIRContext *context);
/// Return "true" if the layout for `t` is compatible with strided semantics.
bool isStrided(MemRefType t);
/// Return "true" if the last dimension of the given type has a static unit
/// stride. Also return "true" for types with no strides.
bool isLastMemrefDimUnitStride(MemRefType type);
/// Return "true" if the last N dimensions of the given type are contiguous.
///
/// Examples:
/// - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]>> is contiguous when
///   considering both _all_ and _only_ the trailing 3 dims,
/// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]>> is _only_ contiguous when
///   considering the trailing 3 dims.
///
bool trailingNDimsContiguous(MemRefType type, int64_t n);
} // namespace mlir
#endif // MLIR_IR_BUILTINTYPES_H