blob: 4c5dbad4550f926e162207302bd419849df994c7 [file] [log] [blame]
//===- StandardTypes.h - MLIR Standard Type Classes -------------*- C++ -*-===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_STANDARDTYPES_H
#define MLIR_IR_STANDARDTYPES_H
#include "mlir/IR/Types.h"
namespace llvm {
// Only referenced by const reference below (FloatType::getFloatSemantics()),
// so a forward declaration is sufficient here.
struct fltSemantics;
} // end namespace llvm
namespace mlir {
// Forward declarations. FloatType, IndexType and IntegerType are defined later
// in this header; the remaining classes are only referred to by name here.
class AffineExpr;
class AffineMap;
class FloatType;
class IndexType;
class IntegerType;
class Location;
class MLIRContext;

namespace detail {
// Storage classes backing the parametric types declared below. They are not
// defined in this header.
struct ComplexTypeStorage;
struct IntegerTypeStorage;
struct MemRefTypeStorage;
struct RankedTensorTypeStorage;
struct ShapedTypeStorage;
struct TupleTypeStorage;
struct UnrankedMemRefTypeStorage;
struct UnrankedTensorTypeStorage;
struct VectorTypeStorage;
} // end namespace detail
namespace StandardTypes {
/// Kind values for the types declared in this header, as compared against
/// Type::getKind() by the kindof/classof helpers below. Values start at
/// Type::Kind::FIRST_STANDARD_TYPE; TensorType::isValidElementType treats any
/// kind above Type::Kind::LAST_STANDARD_TYPE as a non-standard type.
enum Kind {
// Floating point.
BF16 = Type::Kind::FIRST_STANDARD_TYPE,
F16,
F32,
F64,
// Inclusive range markers consumed by FloatType::kindof. They alias BF16 and
// F64 and do not introduce new values.
FIRST_FLOATING_POINT_TYPE = BF16,
LAST_FLOATING_POINT_TYPE = F64,
// Target pointer sized integer, used (e.g.) in affine mappings.
// Because the preceding enumerator is explicitly F64, Index == F64 + 1, so the
// floating-point range above stays contiguous.
Index,
// Derived types.
Integer,
Vector,
RankedTensor,
UnrankedTensor,
MemRef,
UnrankedMemRef,
Complex,
Tuple,
None,
};
} // namespace StandardTypes
/// IndexType is an integer-like type whose bit width is platform dependent
/// and therefore unknown at this level.
class IndexType : public Type::TypeBase<IndexType, Type> {
public:
  using Base::Base;

  /// Returns the IndexType instance associated with `context`.
  static IndexType get(MLIRContext *context);

  /// LLVM-style RTTI hook (isa/cast/dyn_cast): true only for
  /// StandardTypes::Index.
  static bool kindof(unsigned kind) { return StandardTypes::Index == kind; }
};
/// IntegerType is an integer of arbitrary bit width, bounded above by
/// kMaxWidth.
class IntegerType
    : public Type::TypeBase<IntegerType, Type, detail::IntegerTypeStorage> {
public:
  using Base::Base;

  /// Largest bit width an integer type may be given.
  static constexpr unsigned kMaxWidth = 4096;

  /// Returns the IntegerType of bit width `width` in `context`, creating it if
  /// needed. The width is assumed to be valid and failures assert; use
  /// getChecked when the width may be out of range.
  static IntegerType get(unsigned width, MLIRContext *context);

  /// Like get(), but an out-of-range `width` is reported as an error at
  /// `location` (which may be unknown) and a null type is returned instead.
  static IntegerType getChecked(unsigned width, MLIRContext *context,
                                Location location);

  /// Checks the invariants required to construct an IntegerType of `width`.
  static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
                                                    MLIRContext *context,
                                                    unsigned width);

  /// Bit width of this integer type.
  unsigned getWidth() const;

  /// LLVM-style RTTI hook (isa/cast/dyn_cast): true only for
  /// StandardTypes::Integer.
  static bool kindof(unsigned kind) { return StandardTypes::Integer == kind; }
};
/// FloatType covers the standard floating-point kinds, i.e. every kind in the
/// inclusive range [StandardTypes::FIRST_FLOATING_POINT_TYPE,
/// StandardTypes::LAST_FLOATING_POINT_TYPE] (BF16, F16, F32 and F64).
class FloatType : public Type::TypeBase<FloatType, Type> {
public:
  using Base::Base;

  /// Returns the floating-point type of the given `kind` in `context`.
  static FloatType get(StandardTypes::Kind kind, MLIRContext *context);

  /// Shorthands for get() with a fixed kind.
  static FloatType getBF16(MLIRContext *ctx) {
    return get(StandardTypes::BF16, ctx);
  }
  static FloatType getF16(MLIRContext *ctx) {
    return get(StandardTypes::F16, ctx);
  }
  static FloatType getF32(MLIRContext *ctx) {
    return get(StandardTypes::F32, ctx);
  }
  static FloatType getF64(MLIRContext *ctx) {
    return get(StandardTypes::F64, ctx);
  }

  /// LLVM-style RTTI hook (isa/cast/dyn_cast): true for any kind inside the
  /// floating-point range.
  static bool kindof(unsigned kind) {
    const bool belowRange = kind < StandardTypes::FIRST_FLOATING_POINT_TYPE;
    const bool aboveRange = kind > StandardTypes::LAST_FLOATING_POINT_TYPE;
    return !belowRange && !aboveRange;
  }

  /// Bit width of this floating-point type.
  unsigned getWidth();

  /// The LLVM floating-point semantics corresponding to this type.
  const llvm::fltSemantics &getFloatSemantics();
};
/// The 'complex' type represents a complex number with a parameterized element
/// type, which is composed of a real and imaginary value of that element type.
///
/// The element must be a floating point or integer scalar type.
///
class ComplexType
    : public Type::TypeBase<ComplexType, Type, detail::ComplexTypeStorage> {
public:
  using Base::Base;

  /// Get or create a ComplexType with the provided element type.
  static ComplexType get(Type elementType);

  /// Get or create a ComplexType with the provided element type. This emits
  /// an error at the specified location and returns null if the element type
  /// isn't supported.
  static ComplexType getChecked(Type elementType, Location location);

  /// Verify the construction of a complex type.
  static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
                                                    MLIRContext *context,
                                                    Type elementType);

  /// Return the element type shared by the real and imaginary components.
  Type getElementType();

  /// Methods for support type inquiry through isa, cast, and dyn_cast.
  static bool kindof(unsigned kind) { return kind == StandardTypes::Complex; }
};
/// This is a common base class between Vector, UnrankedTensor, RankedTensor,
/// UnrankedMemRef and MemRef types because they share behavior and semantics
/// around shape, rank, and fixed element type. Any type with these semantics
/// should inherit from ShapedType.
class ShapedType : public Type {
public:
  using ImplType = detail::ShapedTypeStorage;
  using Type::Type;

  // TODO(ntv): merge these two special values in a single one used everywhere.
  // Unfortunately, uses of `-1` have crept deep into the codebase now and are
  // hard to track.
  /// Sentinel for a dimension whose size is unknown statically. Note that
  /// isDynamic() treats any negative size as dynamic, not just this value.
  static constexpr int64_t kDynamicSize = -1;
  /// Sentinel for a stride or offset that is unknown statically.
  static constexpr int64_t kDynamicStrideOrOffset =
      std::numeric_limits<int64_t>::min();

  /// Return the element type.
  Type getElementType() const;

  /// If an element type is an integer or a float, return its width. Otherwise,
  /// abort.
  unsigned getElementTypeBitWidth() const;

  /// If it has static shape, return the number of elements. Otherwise, abort.
  int64_t getNumElements() const;

  /// If this is a ranked type, return the rank. Otherwise, abort.
  int64_t getRank() const;

  /// Whether or not this is a ranked type. Memrefs, vectors and ranked tensors
  /// have a rank, while unranked tensors and unranked memrefs do not.
  bool hasRank() const;

  /// If this is a ranked type, return the shape. Otherwise, abort.
  ArrayRef<int64_t> getShape() const;

  /// If this is unranked type or any dimension has unknown size (<0), it
  /// doesn't have static shape. If all dimensions have known size (>= 0), it
  /// has static shape.
  bool hasStaticShape() const;

  /// If this has a static shape and the shape is equal to `shape` return true.
  bool hasStaticShape(ArrayRef<int64_t> shape) const;

  /// If this is a ranked type, return the number of dimensions with dynamic
  /// size. Otherwise, abort.
  int64_t getNumDynamicDims() const;

  /// If this is ranked type, return the size of the specified dimension.
  /// Otherwise, abort.
  int64_t getDimSize(int64_t i) const;

  /// Returns the position of the dynamic dimension relative to just the dynamic
  /// dimensions, given its `index` within the shape.
  unsigned getDynamicDimIndex(unsigned index) const;

  /// Get the total amount of bits occupied by a value of this type. This does
  /// not take into account any memory layout or widening constraints, e.g. a
  /// vector<3xi57> is reported to occupy 3x57=171 bits, even though in practice
  /// it will likely be stored as a 4xi64 vector register. Fail an assertion
  /// if the size cannot be computed statically, i.e. if the type has a dynamic
  /// shape or if its elemental type does not have a known bit width.
  int64_t getSizeInBits() const;

  /// Methods for support type inquiry through isa, cast, and dyn_cast.
  static bool classof(Type type) {
    return type.getKind() == StandardTypes::Vector ||
           type.getKind() == StandardTypes::RankedTensor ||
           type.getKind() == StandardTypes::UnrankedTensor ||
           type.getKind() == StandardTypes::UnrankedMemRef ||
           type.getKind() == StandardTypes::MemRef;
  }

  /// Whether the given dimension size indicates a dynamic dimension; any
  /// negative size counts as dynamic.
  static constexpr bool isDynamic(int64_t dSize) { return dSize < 0; }

  /// Whether the given stride or offset is the kDynamicStrideOrOffset
  /// sentinel.
  static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset) {
    return dStrideOrOffset == kDynamicStrideOrOffset;
  }
};
/// Vector types represent multi-dimensional SIMD vectors, and have a fixed
/// known constant shape with one or more dimensions.
class VectorType
    : public Type::TypeBase<VectorType, ShapedType, detail::VectorTypeStorage> {
public:
  using Base::Base;

  /// Get or create a new VectorType of the provided shape and element type.
  /// Assumes the arguments define a well-formed VectorType.
  static VectorType get(ArrayRef<int64_t> shape, Type elementType);

  /// Get or create a new VectorType of the provided shape and element type
  /// declared at the given, potentially unknown, location. If the VectorType
  /// defined by the arguments would be ill-formed, emit errors and return
  /// nullptr-wrapping type.
  static VectorType getChecked(ArrayRef<int64_t> shape, Type elementType,
                               Location location);

  /// Verify the construction of a vector type.
  static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
                                                    MLIRContext *context,
                                                    ArrayRef<int64_t> shape,
                                                    Type elementType);

  /// Returns true if the given type can be used as an element of a vector
  /// type. Only integer and floating-point primitives qualify.
  static bool isValidElementType(Type t) { return t.isIntOrFloat(); }

  /// Return the (fully static) shape of this vector.
  ArrayRef<int64_t> getShape() const;

  /// Methods for support type inquiry through isa, cast, and dyn_cast.
  static bool kindof(unsigned kind) { return kind == StandardTypes::Vector; }
};
/// Common base for the two tensor flavours, RankedTensorType and
/// UnrankedTensorType, which both describe multi-dimensional arrays.
class TensorType : public ShapedType {
public:
  using ShapedType::ShapedType;

  /// Whether `type` may be used as the element type of a tensor.
  static bool isValidElementType(Type type) {
    // Scalars, complex numbers and vectors are always accepted.
    if (type.isIntOrFloat() || type.isa<ComplexType>() ||
        type.isa<VectorType>())
      return true;
    // Opaque types and any non-standard (dialect-defined) type are accepted
    // here as well; dialects are expected to verify that their tensor types
    // use an element type that is valid within that dialect.
    if (type.isa<OpaqueType>())
      return true;
    return type.getKind() > Type::Kind::LAST_STANDARD_TYPE;
  }

  /// LLVM-style RTTI hook covering both ranked and unranked tensors.
  static bool classof(Type type) {
    switch (type.getKind()) {
    case StandardTypes::RankedTensor:
    case StandardTypes::UnrankedTensor:
      return true;
    default:
      return false;
    }
  }
};
/// Ranked tensor types represent multi-dimensional arrays that have a shape
/// with a fixed number of dimensions. Each shape element can be a non-negative
/// integer or unknown (represented by -1).
class RankedTensorType
    : public Type::TypeBase<RankedTensorType, TensorType,
                            detail::RankedTensorTypeStorage> {
public:
  using Base::Base;

  /// Get or create a new RankedTensorType of the provided shape and element
  /// type. Assumes the arguments define a well-formed type.
  static RankedTensorType get(ArrayRef<int64_t> shape, Type elementType);

  /// Get or create a new RankedTensorType of the provided shape and element
  /// type declared at the given, potentially unknown, location. If the
  /// RankedTensorType defined by the arguments would be ill-formed, emit errors
  /// and return a nullptr-wrapping type.
  static RankedTensorType getChecked(ArrayRef<int64_t> shape, Type elementType,
                                     Location location);

  /// Verify the construction of a ranked tensor type.
  static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
                                                    MLIRContext *context,
                                                    ArrayRef<int64_t> shape,
                                                    Type elementType);

  /// Return the shape of this tensor; unknown dimensions are encoded as -1.
  ArrayRef<int64_t> getShape() const;

  /// Methods for support type inquiry through isa, cast, and dyn_cast.
  static bool kindof(unsigned kind) {
    return kind == StandardTypes::RankedTensor;
  }
};
/// Unranked tensor types represent multi-dimensional arrays that have an
/// unknown shape.
class UnrankedTensorType
    : public Type::TypeBase<UnrankedTensorType, TensorType,
                            detail::UnrankedTensorTypeStorage> {
public:
  using Base::Base;

  /// Get or create a new UnrankedTensorType of the provided element type.
  /// Assumes the argument defines a well-formed type.
  static UnrankedTensorType get(Type elementType);

  /// Get or create a new UnrankedTensorType of the provided element type
  /// declared at the given, potentially unknown, location. If the
  /// UnrankedTensorType defined by the argument would be ill-formed, emit
  /// errors and return a nullptr-wrapping type.
  static UnrankedTensorType getChecked(Type elementType, Location location);

  /// Verify the construction of an unranked tensor type.
  static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
                                                    MLIRContext *context,
                                                    Type elementType);

  /// An unranked tensor has no known shape, so this is always empty.
  ArrayRef<int64_t> getShape() const { return llvm::None; }

  /// Methods for support type inquiry through isa, cast, and dyn_cast.
  static bool kindof(unsigned kind) {
    return kind == StandardTypes::UnrankedTensor;
  }
};
/// Shared base of the ranked (MemRefType) and unranked (UnrankedMemRefType)
/// memref variants.
class BaseMemRefType : public ShapedType {
public:
  using ShapedType::ShapedType;

  /// LLVM-style RTTI hook (isa/cast/dyn_cast) accepting both memref kinds.
  static bool classof(Type type) {
    switch (type.getKind()) {
    case StandardTypes::MemRef:
    case StandardTypes::UnrankedMemRef:
      return true;
    default:
      return false;
    }
  }
};
/// MemRef types represent a region of memory that have a shape with a fixed
/// number of dimensions. Each shape element can be a non-negative integer or
/// unknown (represented by any negative integer). MemRef types also have an
/// affine map composition, represented as an array of AffineMaps.
class MemRefType : public Type::TypeBase<MemRefType, BaseMemRefType,
                                         detail::MemRefTypeStorage> {
public:
  using Base::Base;

  /// Get or create a new MemRefType based on shape, element type, affine
  /// map composition, and memory space. Assumes the arguments define a
  /// well-formed MemRef type. Use getChecked to gracefully handle MemRefType
  /// construction failures.
  static MemRefType get(ArrayRef<int64_t> shape, Type elementType,
                        ArrayRef<AffineMap> affineMapComposition = {},
                        unsigned memorySpace = 0);

  /// Get or create a new MemRefType based on shape, element type, affine
  /// map composition, and memory space declared at the given location.
  /// If the location is unknown, the last argument should be an instance of
  /// UnknownLoc. If the MemRefType defined by the arguments would be
  /// ill-formed, emits errors (to the handler registered with the context or to
  /// the error stream) and returns a null type.
  static MemRefType getChecked(ArrayRef<int64_t> shape, Type elementType,
                               ArrayRef<AffineMap> affineMapComposition,
                               unsigned memorySpace, Location location);

  /// Return the shape of this memref; unknown dimensions are negative.
  ArrayRef<int64_t> getShape() const;

  /// Returns the affine maps that make up the memref's affine map
  /// composition.
  ArrayRef<AffineMap> getAffineMaps() const;

  /// Returns the memory space in which data referred to by this memref resides.
  unsigned getMemorySpace() const;

  // TODO(ntv): merge these two special values in a single one used everywhere.
  // Unfortunately, uses of `-1` have crept deep into the codebase now and are
  // hard to track.
  /// Same sentinel as ShapedType::kDynamicSize (-1). It is spelled in terms of
  /// the base class constant so that the two cannot drift apart.
  static constexpr int64_t kDynamicSize = ShapedType::kDynamicSize;

  /// Returns the sentinel used for strides/offsets unknown statically.
  static int64_t getDynamicStrideOrOffset() {
    return ShapedType::kDynamicStrideOrOffset;
  }

  /// Methods for support type inquiry through isa, cast, and dyn_cast.
  static bool kindof(unsigned kind) { return kind == StandardTypes::MemRef; }

private:
  /// Get or create a new MemRefType defined by the arguments. If the resulting
  /// type would be ill-formed, return nullptr. If the location is provided,
  /// emit detailed error messages.
  static MemRefType getImpl(ArrayRef<int64_t> shape, Type elementType,
                            ArrayRef<AffineMap> affineMapComposition,
                            unsigned memorySpace, Optional<Location> location);
  using Base::getImpl;
};
/// Unranked MemRef types represent multi-dimensional MemRefs that
/// have an unknown rank.
class UnrankedMemRefType
    : public Type::TypeBase<UnrankedMemRefType, BaseMemRefType,
                            detail::UnrankedMemRefTypeStorage> {
public:
  using Base::Base;

  /// Get or create a new UnrankedMemRefType of the provided element
  /// type and memory space.
  static UnrankedMemRefType get(Type elementType, unsigned memorySpace);

  /// Get or create a new UnrankedMemRefType of the provided element
  /// type and memory space declared at the given, potentially unknown,
  /// location. If the UnrankedMemRefType defined by the arguments would be
  /// ill-formed, emit errors and return a nullptr-wrapping type.
  static UnrankedMemRefType getChecked(Type elementType, unsigned memorySpace,
                                       Location location);

  /// Verify the construction of an unranked memref type.
  static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
                                                    MLIRContext *context,
                                                    Type elementType,
                                                    unsigned memorySpace);

  /// An unranked memref has no known shape, so this is always empty.
  ArrayRef<int64_t> getShape() const { return llvm::None; }

  /// Returns the memory space in which data referred to by this memref resides.
  unsigned getMemorySpace() const;

  /// Methods for support type inquiry through isa, cast, and dyn_cast.
  static bool kindof(unsigned kind) {
    return kind == StandardTypes::UnrankedMemRef;
  }
};
/// Tuple types represent a collection of other types. Note: This type merely
/// provides a common mechanism for representing tuples in MLIR. It is up to
/// dialect authors to provide operations for manipulating them, e.g.
/// extract_tuple_element. When possible, users should prefer multi-result
/// operations in the place of tuples.
class TupleType
    : public Type::TypeBase<TupleType, Type, detail::TupleTypeStorage> {
public:
  using Base::Base;

  /// Get or create a new TupleType with the provided element types. Assumes the
  /// arguments define a well-formed type.
  static TupleType get(ArrayRef<Type> elementTypes, MLIRContext *context);

  /// Get or create an empty tuple type.
  static TupleType get(MLIRContext *context) { return get({}, context); }

  /// Return the element types of this tuple.
  ArrayRef<Type> getTypes() const;

  /// Accumulate the types contained in this tuple and tuples nested within it.
  /// Note that this only flattens nested tuples, not any other container type,
  /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
  /// (i32, tensor<i32>, f32, i64).
  void getFlattenedTypes(SmallVectorImpl<Type> &types);

  /// Return the number of held types.
  size_t size() const;

  /// Iterate over the held element types.
  using iterator = ArrayRef<Type>::iterator;
  iterator begin() const { return getTypes().begin(); }
  iterator end() const { return getTypes().end(); }

  /// Return the element type at position `index`; asserts that `index` is in
  /// range.
  Type getType(size_t index) const {
    assert(index < size() && "invalid index for tuple type");
    return getTypes()[index];
  }

  /// Methods for support type inquiry through isa, cast, and dyn_cast.
  static bool kindof(unsigned kind) { return kind == StandardTypes::Tuple; }
};
/// NoneType is a unit type: it has exactly one possible value, and that value
/// has no defined dynamic representation.
class NoneType : public Type::TypeBase<NoneType, Type> {
public:
  using Base::Base;

  /// Returns the NoneType instance associated with `context`.
  static NoneType get(MLIRContext *context);

  /// LLVM-style RTTI hook (isa/cast/dyn_cast): true only for
  /// StandardTypes::None.
  static bool kindof(unsigned kind) { return StandardTypes::None == kind; }
};
/// Computes the strides and offset of the MemRef `t`, returning them through
/// `strides` and `offset`, if its layout map is in strided form.
/// MemRefs with layout maps in strided form include:
///   1. empty or identity layout map, in which case the stride information is
///      the canonical form computed from sizes;
///   2. single affine map layout of the form `K + k0 * d0 + ... kn * dn`,
///      where K and ki's are constants or symbols.
///
/// A stride specification is a list of integer values that are either static
/// or dynamic (encoded with getDynamicStrideOrOffset()). Strides encode the
/// distance in the number of elements between successive entries along a
/// particular dimension. For example, `memref<42x16xf32, (64 * d0 + d1)>`
/// specifies a view into a non-contiguous memory region of `42` by `16` `f32`
/// elements in which the distance between two consecutive elements along the
/// outer dimension is `64` and the distance between two consecutive elements
/// along the inner dimension is `1`.
///
/// If a simple strided form cannot be extracted from the composition of the
/// layout map, returns failure.
///
/// The strides are reported in dimension order d0, ..., dn, so `strides[i]`
/// is the stride of dimension `i`.
LogicalResult getStridesAndOffset(MemRefType t,
                                  SmallVectorImpl<int64_t> &strides,
                                  int64_t &offset);

/// Variant of the above that reports the strides and the offset as
/// AffineExprs rather than as integers.
LogicalResult getStridesAndOffset(MemRefType t,
                                  SmallVectorImpl<AffineExpr> &strides,
                                  AffineExpr &offset);
/// Given a list of strides and an offset (in which
/// MemRefType::getDynamicStrideOrOffset() represents a dynamic value), return
/// the single result AffineMap which represents the linearized strided layout
/// map. There is one dimension per stride, in order. A symbol is inserted for
/// the offset if it is dynamic, followed by one symbol per dynamic stride, in
/// order. A stride cannot take value `0`.
///
/// Examples:
/// =========
///
/// 1. For offset: 0 strides: ?, ?, 1 return
///         (i, j, k)[M, N]->(M * i + N * j + k)
///
/// 2. For offset: 3 strides: 32, ?, 16 return
///         (i, j, k)[M]->(3 + 32 * i + M * j + 16 * k)
///
/// 3. For offset: ? strides: ?, ?, ? return
///         (i, j, k)[off, M, N, P]->(off + M * i + N * j + P * k)
AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset,
                                     MLIRContext *context);
/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
MemRefType canonicalizeStridedLayout(MemRefType t);

/// Return true if the layout for `t` is compatible with strided semantics.
bool isStrided(MemRefType t);
} // end namespace mlir
#endif // MLIR_IR_STANDARDTYPES_H