blob: c4f8f2d905e56ce2d108b0310658b288ca4bb491 [file] [log] [blame]
//===- InferTypeOpInterface.h - Infer Type Interfaces -----------*- C++ -*-===//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// This file contains the definitions of the infer op interfaces defined in
// `InferTypeOpInterface.td`.
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/SmallVector.h"
namespace mlir {
/// A list of lists of Values (outer SmallVector of inner SmallVector<Value>).
/// NOTE(review): the name suggests one inner list of dimension Values per
/// reified ranked shaped result -- confirm against the users of this alias.
using ReifiedRankedShapedTypeDims = SmallVector<SmallVector<Value>>;
/// ShapedTypeComponents that represents the components of a ShapedType.
/// The components consist of
/// - A ranked or unranked shape with the dimension specification match those
/// of ShapeType's getShape() (e.g., dynamic dimension represented using
/// ShapedType::kDynamicSize)
/// - A element type, may be unset (nullptr)
/// - A attribute, may be unset (nullptr)
/// Used by ShapedType type inferences.
class ShapedTypeComponents {
/// Internal storage type for shape.
using ShapeStorageT = SmallVector<int64_t, 3>;
/// Default construction is an unranked shape.
ShapedTypeComponents() : elementType(nullptr), attr(nullptr), ranked(false){};
ShapedTypeComponents(Type elementType)
: elementType(elementType), attr(nullptr), ranked(false) {}
ShapedTypeComponents(ShapedType shapedType) : attr(nullptr) {
ranked = shapedType.hasRank();
elementType = shapedType.getElementType();
if (ranked)
dims = llvm::to_vector<4>(shapedType.getShape());
template <typename Arg, typename = typename std::enable_if_t<
std::is_constructible<ShapeStorageT, Arg>::value>>
ShapedTypeComponents(Arg &&arg, Type elementType = nullptr,
Attribute attr = nullptr)
: dims(std::forward<Arg>(arg)), elementType(elementType), attr(attr),
ranked(true) {}
ShapedTypeComponents(ArrayRef<int64_t> vec, Type elementType = nullptr,
Attribute attr = nullptr)
: dims(vec.begin(), vec.end()), elementType(elementType), attr(attr),
ranked(true) {}
/// Return the dimensions of the shape.
/// Requires: shape is ranked.
ArrayRef<int64_t> getDims() const {
assert(ranked && "requires ranked shape");
return dims;
/// Return whether the shape has a rank.
bool hasRank() const { return ranked; };
/// Return the element type component.
Type getElementType() const { return elementType; };
/// Return the raw attribute component.
Attribute getAttribute() const { return attr; };
friend class ShapeAdaptor;
ShapeStorageT dims;
Type elementType;
Attribute attr;
bool ranked;
/// Adaptor class to abstract the differences between whether value is from
/// a ShapedType or ShapedTypeComponents or DenseIntElementsAttribute.
class ShapeAdaptor {
ShapeAdaptor(Type t) {
if (auto st = t.dyn_cast<ShapedType>())
val = st;
ShapeAdaptor(Attribute t) {
if (auto da = t.dyn_cast<DenseIntElementsAttr>())
val = da;
ShapeAdaptor(ShapedTypeComponents *components) : val(components) {}
ShapeAdaptor(ShapedTypeComponents &components) : val(&components) {}
/// Returns whether the shape has a rank.
bool hasRank() const;
/// Returns the element type.
Type getElementType() const;
/// Populates the dimensions from shape referenced.
/// Requires: shape is ranked.
void getDims(SmallVectorImpl<int64_t> &res) const;
/// Populates the dimensions of the ShapeTypeComponents.
/// Requires: shape is ranked.
void getDims(ShapedTypeComponents &res) const;
/// Returns the size of the index'th dimension.
/// Requires: shape is ranked.
int64_t getDimSize(int index) const;
/// Returns whether the index'th dimension is dynamic.
/// Requires: shape is ranked.
bool isDynamicDim(int index) const {
return ShapedType::isDynamic(getDimSize(index));
/// Returns whether the shape is fully static.
bool hasStaticShape() const;
/// Returns the rank of the shape.
/// Requires: shape is ranked.
int64_t getRank() const;
/// Returns the number of elements in the shape.
/// Requires: hasStaticShape
int64_t getNumElements() const;
/// Returns whether valid (non-null) shape.
operator bool() const { return !val.isNull(); }
/// Dumps textual repesentation to stderr.
void dump() const;
// Union storing either ShapedTypeComponents, ShapedType (stored as Type and
// casted), or DenseIntElementsAttribute (stored as Atrtribute).
PointerUnion<ShapedTypeComponents *, Type, Attribute> val = nullptr;
/// Range of values and shapes (corresponding effectively to Shapes dialect's
/// ValueShape type concept).
// Currently this exposes the Value (of operands) and Type of the Value. This is
// not ideal as then one can accidentally reference an out of date shape. This
// is done to both enable gradual switch and also as OpAdaptor doesn't currently
// allow returning anything other than Value.
class ValueShapeRange : public ValueRange::RangeBaseT {
using ValueShapeMapFn = function_ref<ShapeAdaptor(Value)>;
ValueShapeRange(ValueRange values, ValueShapeMapFn operandShape = nullptr,
ValueShapeMapFn valueToShape = nullptr)
: RangeBaseT(values), operandShape(operandShape),
valueToShape(valueToShape) {}
ValueShapeRange(const std::initializer_list<Value> &values)
: ValueShapeRange(ValueRange(values)) {}
ValueShapeRange(const ValueShapeRange &other) : RangeBaseT(other) {
operandShape = other.operandShape;
valueToShape = other.valueToShape;
/// Sets the Value to ShapeAdaptor mapping function and returns this.
ValueShapeRange &setValueToShapeMapping(ValueShapeMapFn fn) {
valueToShape = fn;
return *this;
ValueShapeRange &setOperandShapeMapping(ValueShapeMapFn fn) {
operandShape = fn;
return *this;
/// Returns the set Value to ShapeAdaptor mapping function.
ValueShapeMapFn getValueToShapeMapping() const { return valueToShape; }
ValueShapeMapFn getOperandShapeMapping() const { return operandShape; }
// Accessors.
/// Returns the types of the values within this range.
/// Note: This returns only the types of Values in the ValueRange and not a
/// more refined type.
using type_iterator = ValueTypeIterator<iterator>;
using type_range = ValueTypeRange<ValueRange>;
type_range getTypes() const { return {begin(), end()}; }
auto getType() const { return getTypes(); }
/// Returns the Values in the ValueRange.
/// To query the most up to date shape of a Value, query the shape
/// using getShape below rather than using the type of the Value.
ValueRange getValues() const { return ValueRange(begin(), end()); };
/// Returns an argument as shape. If the argument is not constant or not a
/// shape, then the function returns a nullptr.
/// This will first query the valueToShape mapping (if set), before querying
/// the ValueRange.
ShapeAdaptor getValueAsShape(int index);
/// Returns the shape of index'th operand.
// TODO: Update so that operator[] references these instead to avoid
// accidentally refering to less refined shape.
ShapeAdaptor getShape(int index) const;
/// Returns the shape of the given Value.
ShapeAdaptor getShape(Value val) const;
// Mapping from Value to ShapedTypeComponents corresponding to shape of type
// of Value.
ValueShapeMapFn operandShape;
// Mapping from Value to ShapedTypeComponents corresponding to constant Value
// if interpreted as shape.
ValueShapeMapFn valueToShape;
namespace detail {
// Helper function to infer return tensor types given element and shape
// inference function.
//
// `componentTypeFn` is the shape/element inference callback; it receives the
// context, the optional location, the operands (as a ValueShapeRange), the
// attributes and the regions, and fills `retComponents`. The remaining
// parameters are the inputs of the inference itself and the output list
// `inferredReturnTypes`.
// NOTE(review): how each inferred component is turned into a Type is decided
// by the out-of-line definition, which is not visible here.
//
// TODO: Consider generating typedefs for trait member functions if this usage
// becomes more common.
LogicalResult inferReturnTensorTypes(
    function_ref<LogicalResult(
        MLIRContext *, Optional<Location> location, ValueShapeRange operands,
        DictionaryAttr attributes, RegionRange regions,
        SmallVectorImpl<ShapedTypeComponents> &retComponents)>
        componentTypeFn,
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes);

/// Verifies that the inferred result types match the actual result types for
/// the op. Precondition: op implements InferTypeOpInterface.
LogicalResult verifyInferredResultTypes(Operation *op);
} // namespace detail
namespace OpTrait {
// Forward declaration only; the trait itself is defined further below, after
// the generated interface declarations have been included.
template <typename ConcreteType>
class InferTensorType;
} // namespace OpTrait
} // namespace mlir
/// Include the generated interface declarations.
#include "mlir/Interfaces/"
namespace mlir {
namespace OpTrait {
/// Tensor type inference trait that constructs a tensor from the inferred
/// shape and elemental types.
/// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface.
///   Less strict is possible (e.g., implements inferReturnTypeComponents and
///   these always populate all element types and shapes or fail, but this
///   trait is currently only used where the interfaces are, so keep it
///   restricted for now).
template <typename ConcreteType>
class InferTensorType : public TraitBase<ConcreteType, InferTensorType> {
public:
  /// Infers the return types of the op by forwarding
  /// ConcreteType::inferReturnTypeComponents, together with all arguments, to
  /// detail::inferReturnTensorTypes, which fills `inferredReturnTypes`.
  static LogicalResult
  inferReturnTypes(MLIRContext *context, Optional<Location> location,
                   ValueRange operands, DictionaryAttr attributes,
                   RegionRange regions,
                   SmallVectorImpl<Type> &inferredReturnTypes) {
    // Both interfaces must be present on the op; reject at compile time
    // otherwise.
    static_assert(
        ConcreteType::template hasTrait<InferShapedTypeOpInterface::Trait>(),
        "requires InferShapedTypeOpInterface to ensure succesful invocation");
    static_assert(
        ConcreteType::template hasTrait<InferTypeOpInterface::Trait>(),
        "requires InferTypeOpInterface to ensure succesful invocation");
    return ::mlir::detail::inferReturnTensorTypes(
        ConcreteType::inferReturnTypeComponents, context, location, operands,
        attributes, regions, inferredReturnTypes);
  }
};
} // namespace OpTrait
} // namespace mlir