blob: d16b19bf309bebf1a38ec0a5604f848395638dd4 [file] [log] [blame]
//===- OperationSupport.h ---------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines a number of support types that Operation and related
// classes build on top of.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_OPERATION_SUPPORT_H
#define MLIR_IR_OPERATION_SUPPORT_H
#include "mlir/IR/BlockSupport.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/InterfaceSupport.h"
#include "llvm/ADT/BitmaskEnum.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
#include "llvm/Support/TrailingObjects.h"
#include <memory>
namespace llvm {
class BitVector;
} // end namespace llvm
namespace mlir {
class Dialect;
class DictionaryAttr;
class ElementsAttr;
class MutableOperandRangeRange;
class Operation;
struct OperationState;
class OpAsmParser;
class OpAsmParserResult;
class OpAsmPrinter;
class OperandRange;
class OperandRangeRange;
class OpFoldResult;
class ParseResult;
class Pattern;
class Region;
class ResultRange;
class RewritePattern;
class Type;
class Value;
class ValueRange;
template <typename ValueRangeT>
class ValueTypeRange;
class RewritePatternSet;
using OwningRewritePatternList = RewritePatternSet;
//===----------------------------------------------------------------------===//
// OperationName
//===----------------------------------------------------------------------===//
class OperationName {
public:
using GetCanonicalizationPatternsFn =
llvm::unique_function<void(RewritePatternSet &, MLIRContext *) const>;
using FoldHookFn = llvm::unique_function<LogicalResult(
Operation *, ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &) const>;
using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
using ParseAssemblyFn =
llvm::unique_function<ParseResult(OpAsmParser &, OperationState &) const>;
using PrintAssemblyFn =
llvm::unique_function<void(Operation *, OpAsmPrinter &, StringRef) const>;
using VerifyInvariantsFn =
llvm::unique_function<LogicalResult(Operation *) const>;
protected:
/// This class represents a type erased version of an operation. It contains
/// all of the components necessary for opaquely interacting with an
/// operation. If the operation is not registered, some of these components
/// may not be populated.
struct Impl {
Impl(StringAttr name)
: name(name), dialect(nullptr), interfaceMap(llvm::None) {}
/// The name of the operation.
StringAttr name;
//===------------------------------------------------------------------===//
// Registered Operation Info
/// The following fields are only populated when the operation is
/// registered.
/// Returns true if the operation has been registered, i.e. if the
/// registration info has been populated.
bool isRegistered() const { return dialect; }
/// This is the dialect that this operation belongs to.
Dialect *dialect;
/// The unique identifier of the derived Op class.
TypeID typeID;
/// A map of interfaces that were registered to this operation.
detail::InterfaceMap interfaceMap;
/// Internal callback hooks provided by the op implementation.
FoldHookFn foldHookFn;
GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
HasTraitFn hasTraitFn;
ParseAssemblyFn parseAssemblyFn;
PrintAssemblyFn printAssemblyFn;
VerifyInvariantsFn verifyInvariantsFn;
/// A list of attribute names registered to this operation in StringAttr
/// form. This allows for operation classes to use StringAttr for attribute
/// lookup/creation/etc., as opposed to raw strings.
ArrayRef<StringAttr> attributeNames;
};
public:
OperationName(StringRef name, MLIRContext *context);
/// Return if this operation is registered.
bool isRegistered() const { return impl->isRegistered(); }
/// If this operation is registered, returns the registered information, None
/// otherwise.
Optional<RegisteredOperationName> getRegisteredInfo() const;
/// Returns true if the operation was registered with a particular trait, e.g.
/// hasTrait<OperandsAreSignlessIntegerLike>(). Returns false if the operation
/// is unregistered.
template <template <typename T> class Trait> bool hasTrait() const {
return hasTrait(TypeID::get<Trait>());
}
bool hasTrait(TypeID traitID) const {
return isRegistered() && impl->hasTraitFn(traitID);
}
/// Returns true if the operation *might* have the provided trait. This
/// means that either the operation is unregistered, or it was registered with
/// the provide trait.
template <template <typename T> class Trait> bool mightHaveTrait() const {
return mightHaveTrait(TypeID::get<Trait>());
}
bool mightHaveTrait(TypeID traitID) const {
return !isRegistered() || impl->hasTraitFn(traitID);
}
/// Returns an instance of the concept object for the given interface if it
/// was registered to this operation, null otherwise. This should not be used
/// directly.
template <typename T> typename T::Concept *getInterface() const {
return impl->interfaceMap.lookup<T>();
}
/// Returns true if this operation has the given interface registered to it.
template <typename T> bool hasInterface() const {
return hasInterface(TypeID::get<T>());
}
bool hasInterface(TypeID interfaceID) const {
return impl->interfaceMap.contains(interfaceID);
}
/// Return the dialect this operation is registered to if the dialect is
/// loaded in the context, or nullptr if the dialect isn't loaded.
Dialect *getDialect() const {
return isRegistered() ? impl->dialect : impl->name.getReferencedDialect();
}
/// Return the name of the dialect this operation is registered to.
StringRef getDialectNamespace() const;
/// Return the operation name with dialect name stripped, if it has one.
StringRef stripDialect() const { return getStringRef().split('.').second; }
/// Return the name of this operation. This always succeeds.
StringRef getStringRef() const { return getIdentifier(); }
/// Return the name of this operation as a StringAttr.
StringAttr getIdentifier() const { return impl->name; }
void print(raw_ostream &os) const;
void dump() const;
/// Represent the operation name as an opaque pointer. (Used to support
/// PointerLikeTypeTraits).
void *getAsOpaquePointer() const { return const_cast<Impl *>(impl); }
static OperationName getFromOpaquePointer(const void *pointer) {
return OperationName(
const_cast<Impl *>(reinterpret_cast<const Impl *>(pointer)));
}
bool operator==(const OperationName &rhs) const { return impl == rhs.impl; }
bool operator!=(const OperationName &rhs) const { return !(*this == rhs); }
protected:
OperationName(Impl *impl) : impl(impl) {}
/// The internal implementation of the operation name.
Impl *impl;
/// Allow access to the Impl struct.
friend MLIRContextImpl;
};
inline raw_ostream &operator<<(raw_ostream &os, OperationName info) {
info.print(os);
return os;
}
// Make operation names hashable.
inline llvm::hash_code hash_value(OperationName arg) {
return llvm::hash_value(arg.getAsOpaquePointer());
}
//===----------------------------------------------------------------------===//
// RegisteredOperationName
//===----------------------------------------------------------------------===//
/// This is a "type erased" representation of a registered operation. This
/// should only be used by things like the AsmPrinter and other things that need
/// to be parameterized by generic operation hooks. Most user code should use
/// the concrete operation types.
class RegisteredOperationName : public OperationName {
public:
/// Lookup the registered operation information for the given operation.
/// Returns None if the operation isn't registered.
static Optional<RegisteredOperationName> lookup(StringRef name,
MLIRContext *ctx) {
return OperationName(name, ctx).getRegisteredInfo();
}
/// Register a new operation in a Dialect object.
/// This constructor is used by Dialect objects when they register the list of
/// operations they contain.
template <typename T>
static void insert(Dialect &dialect) {
insert(T::getOperationName(), dialect, TypeID::get<T>(),
T::getParseAssemblyFn(), T::getPrintAssemblyFn(),
T::getVerifyInvariantsFn(), T::getFoldHookFn(),
T::getGetCanonicalizationPatternsFn(), T::getInterfaceMap(),
T::getHasTraitFn(), T::getAttributeNames());
}
/// The use of this method is in general discouraged in favor of
/// 'insert<CustomOp>(dialect)'.
static void
insert(StringRef name, Dialect &dialect, TypeID typeID,
ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
ArrayRef<StringRef> attrNames);
/// Return the dialect this operation is registered to.
Dialect &getDialect() const { return *impl->dialect; }
/// Return the unique identifier of the derived Op class.
TypeID getTypeID() const { return impl->typeID; }
/// Use the specified object to parse this ops custom assembly format.
ParseResult parseAssembly(OpAsmParser &parser, OperationState &result) const;
/// Return the static hook for parsing this operation assembly.
const ParseAssemblyFn &getParseAssemblyFn() const {
return impl->parseAssemblyFn;
}
/// This hook implements the AsmPrinter for this operation.
void printAssembly(Operation *op, OpAsmPrinter &p,
StringRef defaultDialect) const {
return impl->printAssemblyFn(op, p, defaultDialect);
}
/// This hook implements the verifier for this operation. It should emits an
/// error message and returns failure if a problem is detected, or returns
/// success if everything is ok.
LogicalResult verifyInvariants(Operation *op) const {
return impl->verifyInvariantsFn(op);
}
/// This hook implements a generalized folder for this operation. Operations
/// can implement this to provide simplifications rules that are applied by
/// the Builder::createOrFold API and the canonicalization pass.
///
/// This is an intentionally limited interface - implementations of this hook
/// can only perform the following changes to the operation:
///
/// 1. They can leave the operation alone and without changing the IR, and
/// return failure.
/// 2. They can mutate the operation in place, without changing anything else
/// in the IR. In this case, return success.
/// 3. They can return a list of existing values that can be used instead of
/// the operation. In this case, fill in the results list and return
/// success. The caller will remove the operation and use those results
/// instead.
///
/// This allows expression of some simple in-place canonicalizations (e.g.
/// "x+0 -> x", "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y", etc), as well as
/// generalized constant folding.
LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) const {
return impl->foldHookFn(op, operands, results);
}
/// This hook returns any canonicalization pattern rewrites that the operation
/// supports, for use by the canonicalization pass.
void getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) const {
return impl->getCanonicalizationPatternsFn(results, context);
}
/// Attach the given models as implementations of the corresponding interfaces
/// for the concrete operation.
template <typename... Models>
void attachInterface() {
impl->interfaceMap.insert<Models...>();
}
/// Returns true if the operation has a particular trait.
template <template <typename T> class Trait> bool hasTrait() const {
return hasTrait(TypeID::get<Trait>());
}
/// Returns true if the operation has a particular trait.
bool hasTrait(TypeID traitID) const { return impl->hasTraitFn(traitID); }
/// Return the list of cached attribute names registered to this operation.
/// The order of attributes cached here is unique to each type of operation,
/// and the interpretation of this attribute list should generally be driven
/// by the respective operation. In many cases, this caching removes the need
/// to use the raw string name of a known attribute.
///
/// For example the ODS generator, with an op defining the following
/// attributes:
///
/// let arguments = (ins I32Attr:$attr1, I32Attr:$attr2);
///
/// ... may produce an order here of ["attr1", "attr2"]. This allows for the
/// ODS generator to directly access the cached name for a known attribute,
/// greatly simplifying the cost and complexity of attribute usage produced by
/// the generator.
///
ArrayRef<StringAttr> getAttributeNames() const {
return impl->attributeNames;
}
/// Represent the operation name as an opaque pointer. (Used to support
/// PointerLikeTypeTraits).
static RegisteredOperationName getFromOpaquePointer(const void *pointer) {
return RegisteredOperationName(
const_cast<Impl *>(reinterpret_cast<const Impl *>(pointer)));
}
private:
RegisteredOperationName(Impl *impl) : OperationName(impl) {}
/// Allow access to the constructor.
friend OperationName;
};
inline Optional<RegisteredOperationName>
OperationName::getRegisteredInfo() const {
return isRegistered() ? RegisteredOperationName(impl)
: Optional<RegisteredOperationName>();
}
//===----------------------------------------------------------------------===//
// Attribute Dictionary-Like Interface
//===----------------------------------------------------------------------===//
/// Attribute collections provide a dictionary-like interface. Define common
/// lookup functions.
namespace impl {
/// Unsorted string search or identifier lookups are linear scans.
template <typename IteratorT, typename NameT>
std::pair<IteratorT, bool> findAttrUnsorted(IteratorT first, IteratorT last,
NameT name) {
for (auto it = first; it != last; ++it)
if (it->getName() == name)
return {it, true};
return {last, false};
}
/// Using llvm::lower_bound requires an extra string comparison to check whether
/// the returned iterator points to the found element or whether it indicates
/// the lower bound. Skip this redundant comparison by checking if `compare ==
/// 0` during the binary search.
template <typename IteratorT>
std::pair<IteratorT, bool> findAttrSorted(IteratorT first, IteratorT last,
StringRef name) {
ptrdiff_t length = std::distance(first, last);
while (length > 0) {
ptrdiff_t half = length / 2;
IteratorT mid = first + half;
int compare = mid->getName().strref().compare(name);
if (compare < 0) {
first = mid + 1;
length = length - half - 1;
} else if (compare > 0) {
length = half;
} else {
return {mid, true};
}
}
return {first, false};
}
/// StringAttr lookups on large attribute lists will switch to string binary
/// search. String binary searches become significantly faster than linear scans
/// with the identifier when the attribute list becomes very large.
template <typename IteratorT>
std::pair<IteratorT, bool> findAttrSorted(IteratorT first, IteratorT last,
StringAttr name) {
constexpr unsigned kSmallAttributeList = 16;
if (std::distance(first, last) > kSmallAttributeList)
return findAttrSorted(first, last, name.strref());
return findAttrUnsorted(first, last, name);
}
} // end namespace impl
//===----------------------------------------------------------------------===//
// NamedAttrList
//===----------------------------------------------------------------------===//
/// NamedAttrList is array of NamedAttributes that tracks whether it is sorted
/// and does some basic work to remain sorted.
class NamedAttrList {
public:
using iterator = SmallVectorImpl<NamedAttribute>::iterator;
using const_iterator = SmallVectorImpl<NamedAttribute>::const_iterator;
using reference = NamedAttribute &;
using const_reference = const NamedAttribute &;
using size_type = size_t;
NamedAttrList() : dictionarySorted({}, true) {}
NamedAttrList(ArrayRef<NamedAttribute> attributes);
NamedAttrList(DictionaryAttr attributes);
NamedAttrList(const_iterator in_start, const_iterator in_end);
bool operator!=(const NamedAttrList &other) const {
return !(*this == other);
}
bool operator==(const NamedAttrList &other) const {
return attrs == other.attrs;
}
/// Add an attribute with the specified name.
void append(StringRef name, Attribute attr);
/// Add an attribute with the specified name.
void append(StringAttr name, Attribute attr) {
append(NamedAttribute(name, attr));
}
/// Append the given named attribute.
void append(NamedAttribute attr) { push_back(attr); }
/// Add an array of named attributes.
template <typename RangeT>
void append(RangeT &&newAttributes) {
append(std::begin(newAttributes), std::end(newAttributes));
}
/// Add a range of named attributes.
template <typename IteratorT>
void append(IteratorT in_start, IteratorT in_end) {
// TODO: expand to handle case where values appended are in order & after
// end of current list.
dictionarySorted.setPointerAndInt(nullptr, false);
attrs.append(in_start, in_end);
}
/// Replaces the attributes with new list of attributes.
void assign(const_iterator in_start, const_iterator in_end);
/// Replaces the attributes with new list of attributes.
void assign(ArrayRef<NamedAttribute> range) {
assign(range.begin(), range.end());
}
bool empty() const { return attrs.empty(); }
void reserve(size_type N) { attrs.reserve(N); }
/// Add an attribute with the specified name.
void push_back(NamedAttribute newAttribute);
/// Pop last element from list.
void pop_back() { attrs.pop_back(); }
/// Returns an entry with a duplicate name the list, if it exists, else
/// returns llvm::None.
Optional<NamedAttribute> findDuplicate() const;
/// Return a dictionary attribute for the underlying dictionary. This will
/// return an empty dictionary attribute if empty rather than null.
DictionaryAttr getDictionary(MLIRContext *context) const;
/// Return all of the attributes on this operation.
ArrayRef<NamedAttribute> getAttrs() const;
/// Return the specified attribute if present, null otherwise.
Attribute get(StringAttr name) const;
Attribute get(StringRef name) const;
/// Return the specified named attribute if present, None otherwise.
Optional<NamedAttribute> getNamed(StringRef name) const;
Optional<NamedAttribute> getNamed(StringAttr name) const;
/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
/// Returns the previous attribute value of `name`, or null if no
/// attribute previously existed with `name`.
Attribute set(StringAttr name, Attribute value);
Attribute set(StringRef name, Attribute value);
/// Erase the attribute with the given name from the list. Return the
/// attribute that was erased, or nullptr if there was no attribute with such
/// name.
Attribute erase(StringAttr name);
Attribute erase(StringRef name);
iterator begin() { return attrs.begin(); }
iterator end() { return attrs.end(); }
const_iterator begin() const { return attrs.begin(); }
const_iterator end() const { return attrs.end(); }
NamedAttrList &operator=(const SmallVectorImpl<NamedAttribute> &rhs);
operator ArrayRef<NamedAttribute>() const;
private:
/// Return whether the attributes are sorted.
bool isSorted() const { return dictionarySorted.getInt(); }
/// Erase the attribute at the given iterator position.
Attribute eraseImpl(SmallVectorImpl<NamedAttribute>::iterator it);
/// Lookup an attribute in the list.
template <typename AttrListT, typename NameT>
static auto findAttr(AttrListT &attrs, NameT name) {
return attrs.isSorted()
? impl::findAttrSorted(attrs.begin(), attrs.end(), name)
: impl::findAttrUnsorted(attrs.begin(), attrs.end(), name);
}
// These are marked mutable as they may be modified (e.g., sorted)
mutable SmallVector<NamedAttribute, 4> attrs;
// Pair with cached DictionaryAttr and status of whether attrs is sorted.
// Note: just because sorted does not mean a DictionaryAttr has been created
// but the case where there is a DictionaryAttr but attrs isn't sorted should
// not occur.
mutable llvm::PointerIntPair<Attribute, 1, bool> dictionarySorted;
};
//===----------------------------------------------------------------------===//
// OperationState
//===----------------------------------------------------------------------===//
/// This represents an operation in an abstracted form, suitable for use with
/// the builder APIs. This object is a large and heavy weight object meant to
/// be used as a temporary object on the stack. It is generally unwise to put
/// this in a collection.
struct OperationState {
Location location;
OperationName name;
SmallVector<Value, 4> operands;
/// Types of the results of this operation.
SmallVector<Type, 4> types;
NamedAttrList attributes;
/// Successors of this operation and their respective operands.
SmallVector<Block *, 1> successors;
/// Regions that the op will hold.
SmallVector<std::unique_ptr<Region>, 1> regions;
public:
OperationState(Location location, StringRef name);
OperationState(Location location, OperationName name);
OperationState(Location location, StringRef name, ValueRange operands,
TypeRange types, ArrayRef<NamedAttribute> attributes,
BlockRange successors = {},
MutableArrayRef<std::unique_ptr<Region>> regions = {});
void addOperands(ValueRange newOperands);
void addTypes(ArrayRef<Type> newTypes) {
types.append(newTypes.begin(), newTypes.end());
}
template <typename RangeT>
std::enable_if_t<!std::is_convertible<RangeT, ArrayRef<Type>>::value>
addTypes(RangeT &&newTypes) {
types.append(newTypes.begin(), newTypes.end());
}
/// Add an attribute with the specified name.
void addAttribute(StringRef name, Attribute attr) {
addAttribute(StringAttr::get(getContext(), name), attr);
}
/// Add an attribute with the specified name.
void addAttribute(StringAttr name, Attribute attr) {
attributes.append(name, attr);
}
/// Add an array of named attributes.
void addAttributes(ArrayRef<NamedAttribute> newAttributes) {
attributes.append(newAttributes);
}
void addSuccessors(Block *successor) { successors.push_back(successor); }
void addSuccessors(BlockRange newSuccessors);
/// Create a region that should be attached to the operation. These regions
/// can be filled in immediately without waiting for Operation to be
/// created. When it is, the region bodies will be transferred.
Region *addRegion();
/// Take a region that should be attached to the Operation. The body of the
/// region will be transferred when the Operation is constructed. If the
/// region is null, a new empty region will be attached to the Operation.
void addRegion(std::unique_ptr<Region> &&region);
/// Take ownership of a set of regions that should be attached to the
/// Operation.
void addRegions(MutableArrayRef<std::unique_ptr<Region>> regions);
/// Get the context held by this operation state.
MLIRContext *getContext() const { return location->getContext(); }
};
//===----------------------------------------------------------------------===//
// OperandStorage
//===----------------------------------------------------------------------===//
namespace detail {
/// This class handles the management of operation operands. Operands are
/// stored either in a trailing array, or a dynamically resizable vector.
class alignas(8) OperandStorage {
public:
OperandStorage(Operation *owner, OpOperand *trailingOperands,
ValueRange values);
~OperandStorage();
/// Replace the operands contained in the storage with the ones provided in
/// 'values'.
void setOperands(Operation *owner, ValueRange values);
/// Replace the operands beginning at 'start' and ending at 'start' + 'length'
/// with the ones provided in 'operands'. 'operands' may be smaller or larger
/// than the range pointed to by 'start'+'length'.
void setOperands(Operation *owner, unsigned start, unsigned length,
ValueRange operands);
/// Erase the operands held by the storage within the given range.
void eraseOperands(unsigned start, unsigned length);
/// Erase the operands held by the storage that have their corresponding bit
/// set in `eraseIndices`.
void eraseOperands(const llvm::BitVector &eraseIndices);
/// Get the operation operands held by the storage.
MutableArrayRef<OpOperand> getOperands() { return {operandStorage, size()}; }
/// Return the number of operands held in the storage.
unsigned size() { return numOperands; }
private:
/// Resize the storage to the given size. Returns the array containing the new
/// operands.
MutableArrayRef<OpOperand> resize(Operation *owner, unsigned newSize);
/// The total capacity number of operands that the storage can hold.
unsigned capacity : 31;
/// A flag indicating if the operand storage was dynamically allocated, as
/// opposed to inlined into the owning operation.
unsigned isStorageDynamic : 1;
/// The number of operands within the storage.
unsigned numOperands;
/// A pointer to the operand storage.
OpOperand *operandStorage;
};
} // end namespace detail
//===----------------------------------------------------------------------===//
// OpPrintingFlags
//===----------------------------------------------------------------------===//
/// Set of flags used to control the behavior of the various IR print methods
/// (e.g. Operation::Print).
class OpPrintingFlags {
public:
OpPrintingFlags();
OpPrintingFlags(llvm::NoneType) : OpPrintingFlags() {}
/// Enables the elision of large elements attributes by printing a lexically
/// valid but otherwise meaningless form instead of the element data. The
/// `largeElementLimit` is used to configure what is considered to be a
/// "large" ElementsAttr by providing an upper limit to the number of
/// elements.
OpPrintingFlags &elideLargeElementsAttrs(int64_t largeElementLimit = 16);
/// Enable printing of debug information. If 'prettyForm' is set to true,
/// debug information is printed in a more readable 'pretty' form. Note: The
/// IR generated with 'prettyForm' is not parsable.
OpPrintingFlags &enableDebugInfo(bool prettyForm = false);
/// Always print operations in the generic form.
OpPrintingFlags &printGenericOpForm();
/// Use local scope when printing the operation. This allows for using the
/// printer in a more localized and thread-safe setting, but may not
/// necessarily be identical to what the IR will look like when dumping
/// the full module.
OpPrintingFlags &useLocalScope();
/// Return if the given ElementsAttr should be elided.
bool shouldElideElementsAttr(ElementsAttr attr) const;
/// Return the size limit for printing large ElementsAttr.
Optional<int64_t> getLargeElementsAttrLimit() const;
/// Return if debug information should be printed.
bool shouldPrintDebugInfo() const;
/// Return if debug information should be printed in the pretty form.
bool shouldPrintDebugInfoPrettyForm() const;
/// Return if operations should be printed in the generic form.
bool shouldPrintGenericOpForm() const;
/// Return if the printer should use local scope when dumping the IR.
bool shouldUseLocalScope() const;
private:
/// Elide large elements attributes if the number of elements is larger than
/// the upper limit.
Optional<int64_t> elementsAttrElementLimit;
/// Print debug information.
bool printDebugInfoFlag : 1;
bool printDebugInfoPrettyFormFlag : 1;
/// Print operations in the generic form.
bool printGenericOpFormFlag : 1;
/// Print operations with numberings local to the current operation.
bool printLocalScope : 1;
};
//===----------------------------------------------------------------------===//
// Operation Value-Iterators
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// OperandRange
/// This class implements the operand iterators for the Operation class.
class OperandRange final : public llvm::detail::indexed_accessor_range_base<
OperandRange, OpOperand *, Value, Value, Value> {
public:
using RangeBaseT::RangeBaseT;
/// Returns the types of the values within this range.
using type_iterator = ValueTypeIterator<iterator>;
using type_range = ValueTypeRange<OperandRange>;
type_range getTypes() const { return {begin(), end()}; }
auto getType() const { return getTypes(); }
/// Return the operand index of the first element of this range. The range
/// must not be empty.
unsigned getBeginOperandIndex() const;
/// Split this range into a set of contiguous subranges using the given
/// elements attribute, which contains the sizes of the sub ranges.
OperandRangeRange split(ElementsAttr segmentSizes) const;
private:
/// See `llvm::detail::indexed_accessor_range_base` for details.
static OpOperand *offset_base(OpOperand *object, ptrdiff_t index) {
return object + index;
}
/// See `llvm::detail::indexed_accessor_range_base` for details.
static Value dereference_iterator(OpOperand *object, ptrdiff_t index) {
return object[index].get();
}
/// Allow access to `offset_base` and `dereference_iterator`.
friend RangeBaseT;
};
//===----------------------------------------------------------------------===//
// OperandRangeRange
/// This class represents a contiguous range of operand ranges, e.g. from a
/// VariadicOfVariadic operand group.
class OperandRangeRange final
: public llvm::indexed_accessor_range<
OperandRangeRange, std::pair<OpOperand *, Attribute>, OperandRange,
OperandRange, OperandRange> {
using OwnerT = std::pair<OpOperand *, Attribute>;
using RangeBaseT =
llvm::indexed_accessor_range<OperandRangeRange, OwnerT, OperandRange,
OperandRange, OperandRange>;
public:
using RangeBaseT::RangeBaseT;
/// Returns the range of types of the values within this range.
TypeRangeRange getTypes() const { return TypeRangeRange(*this); }
auto getType() const { return getTypes(); }
/// Construct a range given a parent set of operands, and an I32 elements
/// attribute containing the sizes of the sub ranges.
OperandRangeRange(OperandRange operands, Attribute operandSegments);
/// Flatten all of the sub ranges into a single contiguous operand range.
OperandRange join() const;
private:
/// See `llvm::indexed_accessor_range` for details.
static OperandRange dereference(const OwnerT &object, ptrdiff_t index);
/// Allow access to `dereference_iterator`.
friend RangeBaseT;
};
//===----------------------------------------------------------------------===//
// MutableOperandRange
/// This class provides a mutable adaptor for a range of operands. It allows for
/// setting, inserting, and erasing operands from the given range.
class MutableOperandRange {
public:
/// A pair of a named attribute corresponding to an operand segment attribute,
/// and the index within that attribute. The attribute should correspond to an
/// i32 DenseElementsAttr.
using OperandSegment = std::pair<unsigned, NamedAttribute>;
/// Construct a new mutable range from the given operand, operand start index,
/// and range length. `operandSegments` is an optional set of operand segments
/// to be updated when mutating the operand list.
MutableOperandRange(Operation *owner, unsigned start, unsigned length,
ArrayRef<OperandSegment> operandSegments = llvm::None);
MutableOperandRange(Operation *owner);
/// Slice this range into a sub range, with the additional operand segment.
MutableOperandRange
slice(unsigned subStart, unsigned subLen,
Optional<OperandSegment> segment = llvm::None) const;
/// Append the given values to the range.
void append(ValueRange values);
/// Assign this range to the given values.
void assign(ValueRange values);
/// Assign the range to the given value.
void assign(Value value);
/// Erase the operands within the given sub-range.
void erase(unsigned subStart, unsigned subLen = 1);
/// Clear this range and erase all of the operands.
void clear();
/// Returns the current size of the range.
unsigned size() const { return length; }
/// Returns if the current range is empty.
bool empty() const { return size() == 0; }
/// Allow implicit conversion to an OperandRange.
operator OperandRange() const;
/// Returns the owning operation.
Operation *getOwner() const { return owner; }
/// Split this range into a set of contiguous subranges using the given
/// elements attribute, which contains the sizes of the sub ranges.
MutableOperandRangeRange split(NamedAttribute segmentSizes) const;
private:
/// Update the length of this range to the one provided.
void updateLength(unsigned newLength);
/// The owning operation of this range.
Operation *owner;
/// The start index of the operand range within the owner operand list, and
/// the length starting from `start`.
unsigned start, length;
/// Optional set of operand segments that should be updated when mutating the
/// length of this range.
SmallVector<OperandSegment, 1> operandSegments;
};
//===----------------------------------------------------------------------===//
// MutableOperandRangeRange
/// This class represents a contiguous range of mutable operand ranges, e.g.
/// from a VariadicOfVariadic operand group.
class MutableOperandRangeRange final
: public llvm::indexed_accessor_range<
MutableOperandRangeRange,
std::pair<MutableOperandRange, NamedAttribute>, MutableOperandRange,
MutableOperandRange, MutableOperandRange> {
using OwnerT = std::pair<MutableOperandRange, NamedAttribute>;
using RangeBaseT =
llvm::indexed_accessor_range<MutableOperandRangeRange, OwnerT,
MutableOperandRange, MutableOperandRange,
MutableOperandRange>;
public:
using RangeBaseT::RangeBaseT;
/// Construct a range given a parent set of operands, and an I32 tensor
/// elements attribute containing the sizes of the sub ranges.
MutableOperandRangeRange(const MutableOperandRange &operands,
NamedAttribute operandSegmentAttr);
/// Flatten all of the sub ranges into a single contiguous mutable operand
/// range.
MutableOperandRange join() const;
/// Allow implicit conversion to an OperandRangeRange.
operator OperandRangeRange() const;
private:
/// See `llvm::indexed_accessor_range` for details.
static MutableOperandRange dereference(const OwnerT &object, ptrdiff_t index);
/// Allow access to `dereference_iterator`.
friend RangeBaseT;
};
//===----------------------------------------------------------------------===//
// ResultRange
/// This class implements the result iterators for the Operation class.
class ResultRange final
: public llvm::detail::indexed_accessor_range_base<
ResultRange, detail::OpResultImpl *, OpResult, OpResult, OpResult> {
public:
using RangeBaseT::RangeBaseT;
ResultRange(OpResult result);
//===--------------------------------------------------------------------===//
// Types
//===--------------------------------------------------------------------===//
/// Returns the types of the values within this range.
using type_iterator = ValueTypeIterator<iterator>;
using type_range = ValueTypeRange<ResultRange>;
type_range getTypes() const { return {begin(), end()}; }
auto getType() const { return getTypes(); }
//===--------------------------------------------------------------------===//
// Uses
//===--------------------------------------------------------------------===//
class UseIterator;
using use_iterator = UseIterator;
using use_range = iterator_range<use_iterator>;
/// Returns a range of all uses of results within this range, which is useful
/// for iterating over all uses.
use_range getUses() const;
use_iterator use_begin() const;
use_iterator use_end() const;
/// Returns true if no results in this range have uses.
bool use_empty() const {
return llvm::all_of(*this,
[](OpResult result) { return result.use_empty(); });
}
/// Replace all uses of results of this range with the provided 'values'. The
/// size of `values` must match the size of this range.
template <typename ValuesT>
std::enable_if_t<!std::is_convertible<ValuesT, Operation *>::value>
replaceAllUsesWith(ValuesT &&values) {
assert(static_cast<size_t>(std::distance(values.begin(), values.end())) ==
size() &&
"expected 'values' to correspond 1-1 with the number of results");
for (auto it : llvm::zip(*this, values))
std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
}
/// Replace all uses of results of this range with results of 'op'.
void replaceAllUsesWith(Operation *op);
//===--------------------------------------------------------------------===//
// Users
//===--------------------------------------------------------------------===//
using user_iterator = ValueUserIterator<use_iterator, OpOperand>;
using user_range = iterator_range<user_iterator>;
/// Returns a range of all users.
user_range getUsers();
user_iterator user_begin();
user_iterator user_end();
private:
/// See `llvm::detail::indexed_accessor_range_base` for details.
static detail::OpResultImpl *offset_base(detail::OpResultImpl *object,
ptrdiff_t index) {
return object->getNextResultAtOffset(index);
}
/// See `llvm::detail::indexed_accessor_range_base` for details.
static OpResult dereference_iterator(detail::OpResultImpl *object,
ptrdiff_t index) {
return offset_base(object, index);
}
/// Allow access to `offset_base` and `dereference_iterator`.
friend RangeBaseT;
};
/// This class implements a use iterator for a range of operation results.
/// This iterates over all uses of all results within the given result range.
class ResultRange::UseIterator final
: public llvm::iterator_facade_base<UseIterator, std::forward_iterator_tag,
OpOperand> {
public:
/// Initialize the UseIterator. Specify `end` to return iterator to last
/// use, otherwise this is an iterator to the first use.
explicit UseIterator(ResultRange results, bool end = false);
using llvm::iterator_facade_base<UseIterator, std::forward_iterator_tag,
OpOperand>::operator++;
UseIterator &operator++();
OpOperand *operator->() const { return use.getOperand(); }
OpOperand &operator*() const { return *use.getOperand(); }
bool operator==(const UseIterator &rhs) const { return use == rhs.use; }
bool operator!=(const UseIterator &rhs) const { return !(*this == rhs); }
private:
void skipOverResultsWithNoUsers();
/// The range of results being iterated over.
ResultRange::iterator it, endIt;
/// The use of the result.
Value::use_iterator use;
};
//===----------------------------------------------------------------------===//
// ValueRange
/// This class provides an abstraction over the different types of ranges over
/// Values. In many cases, this prevents the need to explicitly materialize a
/// SmallVector/std::vector. This class should be used in places that are not
/// suitable for a more derived type (e.g. ArrayRef) or a template range
/// parameter.
class ValueRange final
: public llvm::detail::indexed_accessor_range_base<
ValueRange,
PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *>,
Value, Value, Value> {
public:
/// The type representing the owner of a ValueRange. This is either a list of
/// values, operands, or results.
using OwnerT =
PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *>;
using RangeBaseT::RangeBaseT;
template <typename Arg,
typename = typename std::enable_if_t<
std::is_constructible<ArrayRef<Value>, Arg>::value &&
!std::is_convertible<Arg, Value>::value>>
ValueRange(Arg &&arg) : ValueRange(ArrayRef<Value>(std::forward<Arg>(arg))) {}
ValueRange(const Value &value) : ValueRange(&value, /*count=*/1) {}
ValueRange(const std::initializer_list<Value> &values)
: ValueRange(ArrayRef<Value>(values)) {}
ValueRange(iterator_range<OperandRange::iterator> values)
: ValueRange(OperandRange(values)) {}
ValueRange(iterator_range<ResultRange::iterator> values)
: ValueRange(ResultRange(values)) {}
ValueRange(ArrayRef<BlockArgument> values)
: ValueRange(ArrayRef<Value>(values.data(), values.size())) {}
ValueRange(ArrayRef<Value> values = llvm::None);
ValueRange(OperandRange values);
ValueRange(ResultRange values);
/// Returns the types of the values within this range.
using type_iterator = ValueTypeIterator<iterator>;
using type_range = ValueTypeRange<ValueRange>;
type_range getTypes() const { return {begin(), end()}; }
auto getType() const { return getTypes(); }
private:
/// See `llvm::detail::indexed_accessor_range_base` for details.
static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index);
/// See `llvm::detail::indexed_accessor_range_base` for details.
static Value dereference_iterator(const OwnerT &owner, ptrdiff_t index);
/// Allow access to `offset_base` and `dereference_iterator`.
friend RangeBaseT;
};
//===----------------------------------------------------------------------===//
// Operation Equivalency
//===----------------------------------------------------------------------===//
/// This class provides utilities for computing if two operations are
/// equivalent.
struct OperationEquivalence {
enum Flags {
None = 0,
// When provided, the location attached to the operation are ignored.
IgnoreLocations = 1,
LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ IgnoreLocations)
};
/// Compute a hash for the given operation.
/// The `hashOperands` and `hashResults` callbacks are expected to return a
/// unique hash_code for a given Value.
static llvm::hash_code computeHash(
Operation *op,
function_ref<llvm::hash_code(Value)> hashOperands =
[](Value v) { return hash_value(v); },
function_ref<llvm::hash_code(Value)> hashResults =
[](Value v) { return hash_value(v); },
Flags flags = Flags::None);
/// Helper that can be used with `computeHash` above to ignore operation
/// operands/result mapping.
static llvm::hash_code ignoreHashValue(Value) { return llvm::hash_code{}; }
/// Helper that can be used with `computeHash` above to ignore operation
/// operands/result mapping.
static llvm::hash_code directHashValue(Value v) { return hash_value(v); }
/// Compare two operations and return if they are equivalent.
/// `mapOperands` and `mapResults` are optional callbacks that allows the
/// caller to check the mapping of SSA value between the lhs and rhs
/// operations. It is expected to return success if the mapping is valid and
/// failure if it conflicts with a previous mapping.
static bool
isEquivalentTo(Operation *lhs, Operation *rhs,
function_ref<LogicalResult(Value, Value)> mapOperands,
function_ref<LogicalResult(Value, Value)> mapResults,
Flags flags = Flags::None);
/// Helper that can be used with `isEquivalentTo` above to ignore operation
/// operands/result mapping.
static LogicalResult ignoreValueEquivalence(Value lhs, Value rhs) {
return success();
}
/// Helper that can be used with `isEquivalentTo` above to ignore operation
/// operands/result mapping.
static LogicalResult exactValueMatch(Value lhs, Value rhs) {
return success(lhs == rhs);
}
};
/// Enable Bitmask enums for OperationEquivalence::Flags.
LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
} // end namespace mlir
namespace llvm {
template <>
struct DenseMapInfo<mlir::OperationName> {
static mlir::OperationName getEmptyKey() {
void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::OperationName::getFromOpaquePointer(pointer);
}
static mlir::OperationName getTombstoneKey() {
void *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::OperationName::getFromOpaquePointer(pointer);
}
static unsigned getHashValue(mlir::OperationName val) {
return DenseMapInfo<void *>::getHashValue(val.getAsOpaquePointer());
}
static bool isEqual(mlir::OperationName lhs, mlir::OperationName rhs) {
return lhs == rhs;
}
};
template <>
struct DenseMapInfo<mlir::RegisteredOperationName>
: public DenseMapInfo<mlir::OperationName> {
static mlir::RegisteredOperationName getEmptyKey() {
void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::RegisteredOperationName::getFromOpaquePointer(pointer);
}
static mlir::RegisteredOperationName getTombstoneKey() {
void *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::RegisteredOperationName::getFromOpaquePointer(pointer);
}
};
template <>
struct PointerLikeTypeTraits<mlir::OperationName> {
static inline void *getAsVoidPointer(mlir::OperationName I) {
return const_cast<void *>(I.getAsOpaquePointer());
}
static inline mlir::OperationName getFromVoidPointer(void *P) {
return mlir::OperationName::getFromOpaquePointer(P);
}
static constexpr int NumLowBitsAvailable =
PointerLikeTypeTraits<void *>::NumLowBitsAvailable;
};
template <>
struct PointerLikeTypeTraits<mlir::RegisteredOperationName>
: public PointerLikeTypeTraits<mlir::OperationName> {
static inline mlir::RegisteredOperationName getFromVoidPointer(void *P) {
return mlir::RegisteredOperationName::getFromOpaquePointer(P);
}
};
} // end namespace llvm
#endif