blob: 2fc3cfbf08092facc47dc3103b39bc2e142f6058 [file] [log] [blame]
//===- OpDefinition.h - Classes for defining concrete Op types --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements helper classes for implementing the "Op" types. This
// includes the Op type, which is the base class for Op class definitions,
// as well as a number of traits in the OpTrait namespace that provide a
// declarative way to specify properties of Ops.
//
// The purpose of these types is to allow light-weight implementation of
// concrete ops (like DimOp) with very little boilerplate.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_OPDEFINITION_H
#define MLIR_IR_OPDEFINITION_H
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
#include <type_traits>
namespace mlir {
// Forward-declared builder classes: the declarations that follow only take
// them by reference.
class OpBuilder;
class Builder;
/// Success/failure result produced by operation parsing.
///
/// ParseResult is a thin wrapper around LogicalResult. Unlike LogicalResult,
/// it can be explicitly converted to bool, and that conversion yields `true`
/// on *failure*. Parser code can therefore chain parse rules with `||` and
/// stop at the first failure without writing `failed(...)` around each rule.
class ParseResult : public LogicalResult {
public:
  /// Wraps an existing LogicalResult; a default-constructed ParseResult
  /// represents success.
  ParseResult(LogicalResult result = success()) : LogicalResult(result) {}

  /// A diagnostic emitted while parsing converts implicitly into a failed
  /// ParseResult.
  ParseResult(const InFlightDiagnostic &) : LogicalResult(failure()) {}
  ParseResult(const Diagnostic &) : LogicalResult(failure()) {}

  /// In a boolean context a ParseResult is `true` exactly when it failed.
  explicit operator bool() const { return failed(); }
};
/// Tri-state result for `parseOptional`-style functions. The three states are:
/// nothing was present (empty), something was present and parsed successfully,
/// or something was present but failed to parse. The last state must not be
/// confused with "not present".
///
/// `Optional<ParseResult>` is deliberately not exposed directly, because it
/// converts implicitly to bool, which is easy to misuse here.
class OptionalParseResult {
public:
  /// Empty result: the optional construct was not present.
  OptionalParseResult() = default;
  OptionalParseResult(llvm::NoneType) : impl(llvm::None) {}

  /// Non-empty result holding the given success/failure state.
  OptionalParseResult(LogicalResult result) : impl(result) {}
  OptionalParseResult(ParseResult result) : impl(result) {}

  /// An emitted diagnostic yields a present-but-failed result.
  OptionalParseResult(const InFlightDiagnostic &)
      : OptionalParseResult(failure()) {}

  /// True if a ParseResult is held (i.e. the construct was present).
  bool hasValue() const { return impl.hasValue(); }

  /// The held ParseResult; callers should check hasValue() first.
  ParseResult getValue() const { return impl.getValue(); }
  ParseResult operator*() const { return getValue(); }

private:
  Optional<ParseResult> impl;
};
// Out-of-line (non-template) helpers, so their bodies are not instantiated
// and duplicated for every op type that uses the templated traits.
namespace impl {
/// Makes sure the only block of `region` ends with a terminator.
///
/// If `region` has no block yet, a new block is inserted first. If that block
/// does not already have a terminator, the operation returned by
/// `buildTerminatorOp` is inserted at its end.
///
/// This overload builds through the given OpBuilder.
void ensureRegionTerminator(
    Region &region, OpBuilder &builder, Location loc,
    function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);

/// Variant of the above for call sites that only have a plain Builder.
void ensureRegionTerminator(
    Region &region, Builder &builder, Location loc,
    function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
} // namespace impl
/// This is the concrete base class that holds the operation pointer and has
/// non-generic methods that only depend on State (to avoid having them
/// instantiated on template types that don't affect them).
///
/// This also has the fallback implementations of customization hooks for when
/// they aren't customized.
class OpState {
public:
  /// Ops are pointer-like, so we allow conversion to bool. The result is true
  /// when this wraps a non-null Operation. The conversion only reads the
  /// wrapped pointer, so it is `const` and usable on const op handles.
  explicit operator bool() const { return state != nullptr; }

  /// This implicitly converts to Operation*.
  operator Operation *() const { return state; }

  /// Shortcut of `->` to access a member of Operation.
  Operation *operator->() const { return state; }

  /// Return the operation that this refers to.
  Operation *getOperation() { return state; }

  /// Return the context this operation belongs to.
  MLIRContext *getContext() { return getOperation()->getContext(); }

  /// Print the operation to the given stream.
  void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) {
    state->print(os, flags);
  }
  /// Print the operation to the given stream using the provided `asmState`.
  void print(raw_ostream &os, AsmState &asmState,
             OpPrintingFlags flags = llvm::None) {
    state->print(os, asmState, flags);
  }

  /// Dump this operation.
  void dump() { state->dump(); }

  /// The source location the operation was defined or derived from.
  Location getLoc() { return state->getLoc(); }

  /// Return true if there are no users of any results of this operation.
  bool use_empty() { return state->use_empty(); }

  /// Remove this operation from its parent block and delete it.
  void erase() { state->erase(); }

  /// Emit an error with the op name prefixed, like "'dim' op " which is
  /// convenient for verifiers.
  InFlightDiagnostic emitOpError(const Twine &message = {});

  /// Emit an error about fatal conditions with this operation, reporting up to
  /// any diagnostic handlers that may be listening.
  InFlightDiagnostic emitError(const Twine &message = {});

  /// Emit a warning about this operation, reporting up to any diagnostic
  /// handlers that may be listening.
  InFlightDiagnostic emitWarning(const Twine &message = {});

  /// Emit a remark about this operation, reporting up to any diagnostic
  /// handlers that may be listening.
  InFlightDiagnostic emitRemark(const Twine &message = {});

  /// Walk the operation by calling the callback for each nested operation
  /// (including this one), block or region, depending on the callback provided.
  /// Regions, blocks and operations at the same nesting level are visited in
  /// lexicographical order. The walk order for enclosing regions, blocks and
  /// operations with respect to their nested ones is specified by 'Order'
  /// (post-order by default). A callback on a block or operation is allowed to
  /// erase that block or operation if either:
  ///   * the walk is in post-order, or
  ///   * the walk is in pre-order and the walk is skipped after the erasure.
  /// See Operation::walk for more details.
  template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
            typename RetT = detail::walkResultType<FnT>>
  RetT walk(FnT &&callback) {
    return state->walk<Order>(std::forward<FnT>(callback));
  }

  // These are default implementations of customization hooks.
public:
  /// This hook returns any canonicalization pattern rewrites that the operation
  /// supports, for use by the canonicalization pass. The default adds none.
  static void getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {}

protected:
  /// If the concrete type didn't implement a custom verifier hook, just fall
  /// back to this one which accepts everything.
  LogicalResult verify() { return success(); }

  /// Unless overridden, the custom assembly form of an op is always rejected.
  /// Op implementations should override this to parse their custom assembly
  /// form and return failure on error. On success, they should fill in
  /// `result` with the fields to use.
  static ParseResult parse(OpAsmParser &parser, OperationState &result);

  /// The fallback for the printer is to print the operation in the generic
  /// assembly form.
  static void print(Operation *op, OpAsmPrinter &p);

  /// Prints the name of `op`. NOTE(review): `defaultDialect` presumably names
  /// a dialect whose prefix may be elided; confirm against the definition.
  static void printOpName(Operation *op, OpAsmPrinter &p,
                          StringRef defaultDialect);

  /// Wraps `state`, which may be null (see `operator bool`). The constructor
  /// is protected, so OpState objects are created through derived op classes
  /// (or friends).
  explicit OpState(Operation *state) : state(state) {}

private:
  Operation *state;

  /// Allow access to internal hook implementation methods.
  friend RegisteredOperationName;
};
// Op handles compare equal exactly when they wrap the same Operation pointer.
inline bool operator==(OpState lhs, OpState rhs) {
  return lhs.getOperation() == rhs.getOperation();
}
inline bool operator!=(OpState lhs, OpState rhs) { return !(lhs == rhs); }
// Declared ahead of the class so that OpFoldResult::dump() can use it; the
// inline definition follows the class.
raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr);

/// The result of folding an operation: either an Attribute or a Value, stored
/// in a PointerUnion whose constructors are all inherited.
class OpFoldResult : public PointerUnion<Attribute, Value> {
  using PointerUnion<Attribute, Value>::PointerUnion;

public:
  /// Print this result to llvm::errs(), followed by a newline.
  void dump() { llvm::errs() << *this << '\n'; }
};

/// Allow printing to a stream. A held Value is printed as a Value; in every
/// other case (a held Attribute, or an empty result) the Attribute printer is
/// used.
inline raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr) {
  Value asValue = ofr.dyn_cast<Value>();
  if (!asValue) {
    ofr.dyn_cast<Attribute>().print(os);
    return os;
  }
  asValue.print(os);
  return os;
}
/// Allow printing an op handle to a stream. The op is printed with the
/// `useLocalScope()` printing flag enabled.
inline raw_ostream &operator<<(raw_ostream &os, OpState op) {
  OpPrintingFlags flags;
  flags.useLocalScope();
  op.print(os, flags);
  return os;
}
//===----------------------------------------------------------------------===//
// Operation Trait Types
//===----------------------------------------------------------------------===//
namespace OpTrait {
// These functions are out-of-line implementations of the methods in the
// corresponding trait classes. This avoids them being template
// instantiated/duplicated.
namespace impl {
// Folding hooks and their companion verifiers.
OpFoldResult foldIdempotent(Operation *op);
OpFoldResult foldInvolution(Operation *op);
LogicalResult verifyIsIdempotent(Operation *op);
LogicalResult verifyIsInvolution(Operation *op);

// Operand-count verifiers.
LogicalResult verifyZeroOperands(Operation *op);
LogicalResult verifyOneOperand(Operation *op);
LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);

// Operand-type verifiers.
LogicalResult verifyOperandsAreFloatLike(Operation *op);
LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op);
LogicalResult verifySameTypeOperands(Operation *op);

// Region-count verifiers.
LogicalResult verifyZeroRegion(Operation *op);
LogicalResult verifyOneRegion(Operation *op);
LogicalResult verifyNRegions(Operation *op, unsigned numRegions);
LogicalResult verifyAtLeastNRegions(Operation *op, unsigned numRegions);

// Result-count verifiers. The count parameter is a number of *results*.
LogicalResult verifyZeroResult(Operation *op);
LogicalResult verifyOneResult(Operation *op);
LogicalResult verifyNResults(Operation *op, unsigned numResults);
LogicalResult verifyAtLeastNResults(Operation *op, unsigned numResults);

// Shape / element-type / type agreement verifiers.
LogicalResult verifySameOperandsShape(Operation *op);
LogicalResult verifySameOperandsAndResultShape(Operation *op);
LogicalResult verifySameOperandsElementType(Operation *op);
LogicalResult verifySameOperandsAndResultElementType(Operation *op);
LogicalResult verifySameOperandsAndResultType(Operation *op);

// Result-type verifiers.
LogicalResult verifyResultsAreBoolLike(Operation *op);
LogicalResult verifyResultsAreFloatLike(Operation *op);
LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);

// Terminator and successor verifiers.
LogicalResult verifyIsTerminator(Operation *op);
LogicalResult verifyZeroSuccessor(Operation *op);
LogicalResult verifyOneSuccessor(Operation *op);
LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors);
LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);

// Verifiers for size attributes that describe value-group sizes.
LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName,
                                  StringRef valueGroupName,
                                  size_t expectedCount);
LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);

// Miscellaneous structural verifiers.
LogicalResult verifyNoRegionArguments(Operation *op);
LogicalResult verifyElementwise(Operation *op);
LogicalResult verifyIsIsolatedFromAbove(Operation *op);
} // namespace impl
/// Common base of all op traits. Clients are not expected to use it directly,
/// which is why its only member is protected.
template <typename ConcreteType, template <typename> class TraitType>
class TraitBase {
protected:
  /// Returns the Operation wrapped by the concrete op that carries this trait.
  Operation *getOperation() {
    // A concrete op inherits from many (content-free) TraitBase
    // specializations, so a single cast to OpState would be ambiguous. Instead,
    // hop explicitly: this -> trait type -> concrete op -> OpState.
    TraitType<ConcreteType> *traitPtr =
        static_cast<TraitType<ConcreteType> *>(this);
    ConcreteType *opPtr = static_cast<ConcreteType *>(traitPtr);
    return static_cast<OpState *>(opPtr)->getOperation();
  }
};
//===----------------------------------------------------------------------===//
// Operand Traits
namespace detail {
/// Mix-in base for traits describing ops with several (or a variable number
/// of) operands. Every accessor forwards to the wrapped Operation.
template <typename ConcreteType, template <typename> class TraitType>
struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
  using operand_iterator = Operation::operand_iterator;
  using operand_range = Operation::operand_range;
  using operand_type_iterator = Operation::operand_type_iterator;
  using operand_type_range = Operation::operand_type_range;

  /// Number of operands of the op.
  unsigned getNumOperands() {
    Operation *owner = this->getOperation();
    return owner->getNumOperands();
  }

  /// Operand number `i`.
  Value getOperand(unsigned i) {
    Operation *owner = this->getOperation();
    return owner->getOperand(i);
  }

  /// Replaces operand number `i` with `value`.
  void setOperand(unsigned i, Value value) {
    Operation *owner = this->getOperation();
    owner->setOperand(i, value);
  }

  /// Iteration over the operands.
  operand_iterator operand_begin() {
    Operation *owner = this->getOperation();
    return owner->operand_begin();
  }
  operand_iterator operand_end() {
    Operation *owner = this->getOperation();
    return owner->operand_end();
  }
  operand_range getOperands() {
    Operation *owner = this->getOperation();
    return owner->getOperands();
  }

  /// Iteration over the operand types.
  operand_type_iterator operand_type_begin() {
    Operation *owner = this->getOperation();
    return owner->operand_type_begin();
  }
  operand_type_iterator operand_type_end() {
    Operation *owner = this->getOperation();
    return owner->operand_type_end();
  }
  operand_type_range getOperandTypes() {
    Operation *owner = this->getOperation();
    return owner->getOperandTypes();
  }
};
} // end namespace detail
/// Trait for ops that take no SSA operands. Verification is delegated to
/// impl::verifyZeroOperands.
template <typename ConcreteType>
class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> {
  // Intentionally private, do-nothing members: they exist to disable the
  // operand accessors for ops carrying this trait.
  void getOperand() {}
  void setOperand() {}

public:
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyZeroOperands(op);
  }
};
/// Trait for ops with exactly one SSA operand. Verification is delegated to
/// impl::verifyOneOperand.
template <typename ConcreteType>
class OneOperand : public TraitBase<ConcreteType, OneOperand> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyOneOperand(op);
  }

  /// The op's single operand (operand #0).
  Value getOperand() {
    Operation *owner = this->getOperation();
    return owner->getOperand(0);
  }

  /// Replaces the op's single operand with `value`.
  void setOperand(Value value) {
    Operation *owner = this->getOperation();
    owner->setOperand(0, value);
  }
};
/// Trait for ops with exactly N (N >= 2) operands, used like:
///
///   class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> {
///
/// For zero or one operand use ZeroOperands / OneOperand instead.
template <unsigned N>
class NOperands {
public:
  template <typename ConcreteType>
  class Impl
      : public detail::MultiOperandTraitBase<ConcreteType, NOperands<N>::Impl> {
  public:
    /// Delegates to impl::verifyNOperands with the expected count N.
    static LogicalResult verifyTrait(Operation *op) {
      return impl::verifyNOperands(op, N);
    }
  };

  static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2");
};
/// Trait for ops with at least N operands, used like:
///
///   class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> {
///
/// It adds the multi-operand accessors plus a minimum-count verifier.
template <unsigned N>
class AtLeastNOperands {
public:
  template <typename ConcreteType>
  class Impl : public detail::MultiOperandTraitBase<ConcreteType,
                                                    AtLeastNOperands<N>::Impl> {
  public:
    /// Delegates to impl::verifyAtLeastNOperands with the minimum count N.
    static LogicalResult verifyTrait(Operation *op) {
      return impl::verifyAtLeastNOperands(op, N);
    }
  };
};
/// Trait for ops whose number of SSA operands is not fixed. It only adds the
/// multi-operand accessors and declares no verifyTrait hook of its own.
template <typename ConcreteType>
class VariadicOperands
    : public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {
};
//===----------------------------------------------------------------------===//
// Region Traits
/// Trait for ops that carry no regions; it only contributes a verifier.
template <typename ConcreteType>
class ZeroRegion : public TraitBase<ConcreteType, ZeroRegion> {
public:
  /// Delegates to impl::verifyZeroRegion.
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult verdict = impl::verifyZeroRegion(op);
    return verdict;
  }
};
namespace detail {
/// Utility trait base that provides accessors for derived traits that have
/// multiple regions. Every accessor forwards to the wrapped Operation.
template <typename ConcreteType, template <typename> class TraitType>
struct MultiRegionTraitBase : public TraitBase<ConcreteType, TraitType> {
  /// Iterator over the op's regions. This must be the *iterator* type of
  /// MutableArrayRef<Region> (not the range type itself), so that
  /// region_begin()/region_end() below can return what iterating the
  /// Operation's regions produces.
  using region_iterator = MutableArrayRef<Region>::iterator;
  using region_range = RegionRange;

  /// Return the number of regions.
  unsigned getNumRegions() { return this->getOperation()->getNumRegions(); }

  /// Return the region at index `i`.
  Region &getRegion(unsigned i) { return this->getOperation()->getRegion(i); }

  /// Region iterator access. Both ends are derived from the same
  /// Operation::getRegions() range that getRegions() below exposes.
  region_iterator region_begin() {
    return this->getOperation()->getRegions().begin();
  }
  region_iterator region_end() {
    return this->getOperation()->getRegions().end();
  }
  region_range getRegions() { return this->getOperation()->getRegions(); }
};
} // end namespace detail
/// Trait for ops with exactly one region. Verification is delegated to
/// impl::verifyOneRegion.
template <typename ConcreteType>
class OneRegion : public TraitBase<ConcreteType, OneRegion> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyOneRegion(op);
  }

  /// The op's only region (region #0).
  Region &getRegion() { return this->getOperation()->getRegion(0); }

  /// Range over the operations inside the op's region.
  auto getOps() { return getRegion().getOps(); }

  /// Typed variant: forwards to Region::getOps<OpT>() on the op's region.
  template <typename OpT>
  auto getOps() {
    Region &body = getRegion();
    return body.template getOps<OpT>();
  }
};
/// Trait for ops with exactly N (N >= 2) regions. For zero or one region use
/// ZeroRegion / OneRegion instead.
template <unsigned N>
class NRegions {
public:
  template <typename ConcreteType>
  class Impl
      : public detail::MultiRegionTraitBase<ConcreteType, NRegions<N>::Impl> {
  public:
    /// Delegates to impl::verifyNRegions with the expected count N.
    static LogicalResult verifyTrait(Operation *op) {
      return impl::verifyNRegions(op, N);
    }
  };

  static_assert(N > 1, "use ZeroRegion/OneRegion for N < 2");
};
/// Trait for ops with N or more regions. It adds the multi-region accessors
/// plus a minimum-count verifier.
template <unsigned N>
class AtLeastNRegions {
public:
  template <typename ConcreteType>
  class Impl : public detail::MultiRegionTraitBase<ConcreteType,
                                                   AtLeastNRegions<N>::Impl> {
  public:
    /// Delegates to impl::verifyAtLeastNRegions with the minimum count N.
    static LogicalResult verifyTrait(Operation *op) {
      return impl::verifyAtLeastNRegions(op, N);
    }
  };
};
/// Trait for ops whose number of regions is not fixed. It only adds the
/// multi-region accessors and declares no verifyTrait hook of its own.
template <typename ConcreteType>
class VariadicRegions
    : public detail::MultiRegionTraitBase<ConcreteType, VariadicRegions> {
};
//===----------------------------------------------------------------------===//
// Result Traits
/// Trait for ops that produce no results; it only contributes a verifier.
template <typename ConcreteType>
class ZeroResult : public TraitBase<ConcreteType, ZeroResult> {
public:
  /// Delegates to impl::verifyZeroResult.
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult verdict = impl::verifyZeroResult(op);
    return verdict;
  }
};
namespace detail {
/// Mix-in base for traits describing ops with several (or a variable number
/// of) results. Every accessor forwards to the wrapped Operation.
template <typename ConcreteType, template <typename> class TraitType>
struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
  using result_iterator = Operation::result_iterator;
  using result_range = Operation::result_range;
  using result_type_iterator = Operation::result_type_iterator;
  using result_type_range = Operation::result_type_range;

  /// Number of results produced by the op.
  unsigned getNumResults() {
    Operation *owner = this->getOperation();
    return owner->getNumResults();
  }

  /// Result number `i`.
  Value getResult(unsigned i) {
    Operation *owner = this->getOperation();
    return owner->getResult(i);
  }

  /// Replaces every use of this op's results with `values`, which may be
  /// another operation or a range of Values.
  template <typename ValuesT>
  void replaceAllUsesWith(ValuesT &&values) {
    Operation *owner = this->getOperation();
    owner->replaceAllUsesWith(std::forward<ValuesT>(values));
  }

  /// Type of result number `i`.
  Type getType(unsigned i) {
    Value result = getResult(i);
    return result.getType();
  }

  /// Iteration over the results.
  result_iterator result_begin() {
    Operation *owner = this->getOperation();
    return owner->result_begin();
  }
  result_iterator result_end() {
    Operation *owner = this->getOperation();
    return owner->result_end();
  }
  result_range getResults() {
    Operation *owner = this->getOperation();
    return owner->getResults();
  }

  /// Iteration over the result types.
  result_type_iterator result_type_begin() {
    Operation *owner = this->getOperation();
    return owner->result_type_begin();
  }
  result_type_iterator result_type_end() {
    Operation *owner = this->getOperation();
    return owner->result_type_end();
  }
  result_type_range getResultTypes() {
    Operation *owner = this->getOperation();
    return owner->getResultTypes();
  }
};
} // end namespace detail
/// This class provides return value APIs for ops that are known to have a
/// single result. Its results are exposed as plain `Value`s; combine it with
/// OneTypedResult<T> to obtain a getType() returning a more specific type.
template <typename ConcreteType>
class OneResult : public TraitBase<ConcreteType, OneResult> {
public:
  /// Return the op's only result (result #0).
  Value getResult() { return this->getOperation()->getResult(0); }

  /// If the operation returns a single value, then the Op can be implicitly
  /// converted to a Value. This yields the value of the only result.
  operator Value() { return getResult(); }

  /// Replace all uses of 'this' value with the new value, updating anything
  /// in the IR that uses 'this' to use the other value instead. When this
  /// returns there are zero uses of 'this'.
  void replaceAllUsesWith(Value newValue) {
    getResult().replaceAllUsesWith(newValue);
  }

  /// Replace all uses of 'this' value with the result of 'op'.
  void replaceAllUsesWith(Operation *op) {
    this->getOperation()->replaceAllUsesWith(op);
  }

  /// Delegates to impl::verifyOneResult.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyOneResult(op);
  }
};
/// Refines getType() for single-result ops whose result is known to have type
/// `ResultType` rather than a generic `Type`. Intended to be used together
/// with OneResult, listed before OneResult in the trait list.
template <typename ResultType>
class OneTypedResult {
public:
  /// The actual trait: getType() returns the type of result #0 cast to
  /// `ResultType`.
  template <typename ConcreteType>
  class Impl
      : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
  public:
    ResultType getType() {
      return this->getOperation()
          ->getResult(0)
          .getType()
          .template cast<ResultType>();
    }
  };
};
/// Trait for ops with exactly N (N >= 2) results, used like:
///
///   class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> {
///
/// For zero or one result use ZeroResult / OneResult instead.
template <unsigned N>
class NResults {
public:
  template <typename ConcreteType>
  class Impl
      : public detail::MultiResultTraitBase<ConcreteType, NResults<N>::Impl> {
  public:
    /// Delegates to impl::verifyNResults with the expected count N.
    static LogicalResult verifyTrait(Operation *op) {
      return impl::verifyNResults(op, N);
    }
  };

  static_assert(N > 1, "use ZeroResult/OneResult for N < 2");
};
/// Trait for ops producing N or more results, used like:
///
///   class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> {
///
/// It adds the multi-result accessors plus a minimum-count verifier.
template <unsigned N>
class AtLeastNResults {
public:
  template <typename ConcreteType>
  class Impl : public detail::MultiResultTraitBase<ConcreteType,
                                                   AtLeastNResults<N>::Impl> {
  public:
    /// Delegates to impl::verifyAtLeastNResults with the minimum count N.
    static LogicalResult verifyTrait(Operation *op) {
      return impl::verifyAtLeastNResults(op, N);
    }
  };
};
/// Trait for ops whose number of results is not fixed. It only adds the
/// multi-result accessors and declares no verifyTrait hook of its own.
template <typename ConcreteType>
class VariadicResults
    : public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {
};
//===----------------------------------------------------------------------===//
// Terminator Traits
/// Marker trait: the blocks in this op's regions are not expected to end with
/// a terminator. It adds no API and no verifyTrait hook; other traits (e.g.
/// SingleBlock) query it via hasTrait.
template <typename ConcreteType>
class NoTerminator : public TraitBase<ConcreteType, NoTerminator> {
};
/// Trait marking an op as a block terminator; it contributes a verifier only.
template <typename ConcreteType>
class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
public:
  /// Delegates to impl::verifyIsTerminator.
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult verdict = impl::verifyIsTerminator(op);
    return verdict;
  }
};
/// Trait for ops without successor blocks; it contributes a verifier only.
template <typename ConcreteType>
class ZeroSuccessor : public TraitBase<ConcreteType, ZeroSuccessor> {
public:
  /// Delegates to impl::verifyZeroSuccessor.
  static LogicalResult verifyTrait(Operation *op) {
    LogicalResult verdict = impl::verifyZeroSuccessor(op);
    return verdict;
  }
};
namespace detail {
/// Mix-in base for traits describing ops with several (or a variable number
/// of) successor blocks. Every accessor forwards to the wrapped Operation.
template <typename ConcreteType, template <typename> class TraitType>
struct MultiSuccessorTraitBase : public TraitBase<ConcreteType, TraitType> {
  using succ_iterator = Operation::succ_iterator;
  using succ_range = SuccessorRange;

  /// Number of successor blocks.
  unsigned getNumSuccessors() {
    Operation *owner = this->getOperation();
    return owner->getNumSuccessors();
  }

  /// Successor number `i`.
  Block *getSuccessor(unsigned i) {
    Operation *owner = this->getOperation();
    return owner->getSuccessor(i);
  }

  /// Makes `block` successor number `i`.
  void setSuccessor(Block *block, unsigned i) {
    Operation *owner = this->getOperation();
    owner->setSuccessor(block, i);
  }

  /// Iteration over the successors.
  succ_iterator succ_begin() {
    Operation *owner = this->getOperation();
    return owner->succ_begin();
  }
  succ_iterator succ_end() {
    Operation *owner = this->getOperation();
    return owner->succ_end();
  }
  succ_range getSuccessors() {
    Operation *owner = this->getOperation();
    return owner->getSuccessors();
  }
};
} // end namespace detail
/// Trait for ops with exactly one successor block. Verification is delegated
/// to impl::verifyOneSuccessor.
template <typename ConcreteType>
class OneSuccessor : public TraitBase<ConcreteType, OneSuccessor> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyOneSuccessor(op);
  }

  /// The op's only successor (successor #0).
  Block *getSuccessor() {
    Operation *owner = this->getOperation();
    return owner->getSuccessor(0);
  }

  /// Makes `succ` the op's only successor.
  void setSuccessor(Block *succ) {
    Operation *owner = this->getOperation();
    owner->setSuccessor(succ, 0);
  }
};
/// Trait for ops with exactly N (N >= 2) successor blocks. For zero or one
/// successor use ZeroSuccessor / OneSuccessor instead.
template <unsigned N>
class NSuccessors {
public:
  template <typename ConcreteType>
  class Impl : public detail::MultiSuccessorTraitBase<ConcreteType,
                                                      NSuccessors<N>::Impl> {
  public:
    /// Delegates to impl::verifyNSuccessors with the expected count N.
    static LogicalResult verifyTrait(Operation *op) {
      return impl::verifyNSuccessors(op, N);
    }
  };

  static_assert(N > 1, "use ZeroSuccessor/OneSuccessor for N < 2");
};
/// Trait for ops with N or more successor blocks. It adds the multi-successor
/// accessors plus a minimum-count verifier.
template <unsigned N>
class AtLeastNSuccessors {
public:
  template <typename ConcreteType>
  class Impl
      : public detail::MultiSuccessorTraitBase<ConcreteType,
                                               AtLeastNSuccessors<N>::Impl> {
  public:
    /// Delegates to impl::verifyAtLeastNSuccessors with the minimum count N.
    static LogicalResult verifyTrait(Operation *op) {
      return impl::verifyAtLeastNSuccessors(op, N);
    }
  };
};
/// Trait for ops whose number of successor blocks is not fixed. It only adds
/// the multi-successor accessors and declares no verifyTrait hook of its own.
template <typename ConcreteType>
class VariadicSuccessors
    : public detail::MultiSuccessorTraitBase<ConcreteType, VariadicSuccessors> {
};
//===----------------------------------------------------------------------===//
// SingleBlock
/// Trait for ops whose regions each hold at most one block.
///
/// The verifier accepts empty regions and rejects regions with more than one
/// block. Unless the op also has the NoTerminator trait, it also rejects a
/// single block that contains no operations.
template <typename ConcreteType>
struct SingleBlock : public TraitBase<ConcreteType, SingleBlock> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    // Only ops with NoTerminator may have an empty body block.
    const bool allowsEmptyBlock =
        ConcreteType::template hasTrait<NoTerminator>();
    for (unsigned idx = 0, numRegions = op->getNumRegions(); idx != numRegions;
         ++idx) {
      Region &region = op->getRegion(idx);
      if (region.empty())
        continue; // A region without blocks is always acceptable.
      if (!llvm::hasSingleElement(region))
        return op->emitOpError("expects region #")
               << idx << " to have 0 or 1 blocks";
      if (!allowsEmptyBlock && region.front().empty())
        return op->emitOpError() << "expects a non-empty block";
    }
    return success();
  }

  /// The only block of region `idx`. Asserts that the region is not empty.
  Block *getBody(unsigned idx = 0) {
    Region &region = getBodyRegion(idx);
    assert(!region.empty() && "unexpected empty region");
    return &region.front();
  }

  /// Region number `idx` of the op.
  Region &getBodyRegion(unsigned idx = 0) {
    return this->getOperation()->getRegion(idx);
  }

  //===------------------------------------------------------------------===//
  // Single Region Utilities
  //===------------------------------------------------------------------===//

  /// The helpers below are only available when the op has exactly one region
  /// (the OneRegion trait). Each takes the concrete op as a defaulted template
  /// parameter so SFINAE can remove the methods for other ops.
  template <typename OpT, typename T = void>
  using enable_if_single_region =
      typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>;

  /// Iterator to the first operation of the body block.
  template <typename OpT = ConcreteType>
  enable_if_single_region<OpT, Block::iterator> begin() {
    Block *body = getBody();
    return body->begin();
  }

  /// Past-the-end iterator of the body block.
  template <typename OpT = ConcreteType>
  enable_if_single_region<OpT, Block::iterator> end() {
    Block *body = getBody();
    return body->end();
  }

  /// First operation of the body block (the block must not be empty).
  template <typename OpT = ConcreteType>
  enable_if_single_region<OpT, Operation &> front() {
    return *getBody()->begin();
  }

  /// Appends `op` at the very end of the body block.
  template <typename OpT = ConcreteType>
  enable_if_single_region<OpT> push_back(Operation *op) {
    Block::iterator tail = getBody()->end();
    insert(tail, op);
  }

  /// Inserts `op` right before `insertPt` in the body block.
  template <typename OpT = ConcreteType>
  enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) {
    insert(Block::iterator(insertPt), op);
  }
  template <typename OpT = ConcreteType>
  enable_if_single_region<OpT> insert(Block::iterator insertPt, Operation *op) {
    Block *body = getBody();
    body->getOperations().insert(insertPt, op);
  }
};
//===----------------------------------------------------------------------===//
// SingleBlockImplicitTerminator
/// Trait for ops whose regions each hold at most one block, where that block
/// must end with an operation of type `TerminatorOpType`. In the custom
/// textual format that terminator may be left implicit.
template <typename TerminatorOpType>
struct SingleBlockImplicitTerminator {
  template <typename ConcreteType>
  class Impl : public SingleBlock<ConcreteType> {
  private:
    using Base = SingleBlock<ConcreteType>;

    /// Creates a new `TerminatorOpType` at `loc`. It is assembled from an
    /// OperationState and Operation::create rather than through OpBuilder
    /// helpers, to avoid a cyclic header inclusion. The op is not inserted
    /// into any block here.
    static Operation *buildTerminator(OpBuilder &builder, Location loc) {
      OperationState state(loc, TerminatorOpType::getOperationName());
      TerminatorOpType::build(builder, state);
      return Operation::create(state);
    }

  public:
    /// The type of the operation used as the implicit terminator type.
    using ImplicitTerminatorOpT = TerminatorOpType;

    /// Runs the SingleBlock checks first, then requires the last operation of
    /// every non-empty region's block to be a `TerminatorOpType`.
    static LogicalResult verifyTrait(Operation *op) {
      if (failed(Base::verifyTrait(op)))
        return failure();
      for (unsigned regionIdx = 0, numRegions = op->getNumRegions();
           regionIdx < numRegions; ++regionIdx) {
        Region &region = op->getRegion(regionIdx);
        if (region.empty())
          continue; // Regions without blocks need no terminator.
        Operation &terminator = region.front().back();
        if (!isa<TerminatorOpType>(terminator)) {
          return op->emitOpError("expects regions to end with '" +
                                 TerminatorOpType::getOperationName() +
                                 "', found '" +
                                 terminator.getName().getStringRef() + "'")
                     .attachNote()
                 << "in custom textual format, the absence of terminator implies "
                    "'"
                 << TerminatorOpType::getOperationName() << '\'';
        }
      }
      return success();
    }

    /// Ensure that the given region has the terminator required by this trait.
    /// With an OpBuilder, that builder is used, so its listeners are notified.
    /// With only a Builder, an OpBuilder without listeners is constructed
    /// locally; use this overload only where no OpBuilder is available at the
    /// call site (e.g. in the parser).
    static void ensureTerminator(Region &region, Builder &builder,
                                 Location loc) {
      ::mlir::impl::ensureRegionTerminator(region, builder, loc,
                                           buildTerminator);
    }
    static void ensureTerminator(Region &region, OpBuilder &builder,
                                 Location loc) {
      ::mlir::impl::ensureRegionTerminator(region, builder, loc,
                                           buildTerminator);
    }

    //===------------------------------------------------------------------===//
    // Single Region Utilities
    //===------------------------------------------------------------------===//

    using Base::getBody;

    template <typename OpT, typename T = void>
    using enable_if_single_region =
        typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>;

    /// Appends `op` to the body block, placing it just before the terminator.
    template <typename OpT = ConcreteType>
    enable_if_single_region<OpT> push_back(Operation *op) {
      Block::iterator beforeTerminator(getBody()->getTerminator());
      insert(beforeTerminator, op);
    }

    /// Inserts `op` before `insertPt`. An end() insertion point is redirected
    /// to the terminator, so nothing is ever placed after the terminator.
    template <typename OpT = ConcreteType>
    enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) {
      insert(Block::iterator(insertPt), op);
    }
    template <typename OpT = ConcreteType>
    enable_if_single_region<OpT> insert(Block::iterator insertPt,
                                        Operation *op) {
      Block *body = getBody();
      Block::iterator where = insertPt == body->end()
                                  ? Block::iterator(body->getTerminator())
                                  : insertPt;
      body->getOperations().insert(where, op);
    }
  };
};
/// Check is an op defines the `ImplicitTerminatorOpT` member. This is intended
/// to be used with `llvm::is_detected`.
template <class T>
using has_implicit_terminator_t = typename T::ImplicitTerminatorOpT;
/// Support to check if an operation has the SingleBlockImplicitTerminator
/// trait. We can't just use `hasTrait` because this class is templated on a
/// specific terminator op.
template <class Op, bool hasTerminator =
llvm::is_detected<has_implicit_terminator_t, Op>::value>
struct hasSingleBlockImplicitTerminator {
static constexpr bool value = std::is_base_of<
typename OpTrait::SingleBlockImplicitTerminator<
typename Op::ImplicitTerminatorOpT>::template Impl<Op>,
Op>::value;
};
template <class Op>
struct hasSingleBlockImplicitTerminator<Op, false> {
static constexpr bool value = false;
};
//===----------------------------------------------------------------------===//
// Misc Traits
/// Trait verifying that all operands share a shape: the operands are scalars
/// or vectors/tensors of the same shape.
template <typename ConcreteType>
class SameOperandsShape
    : public TraitBase<ConcreteType, SameOperandsShape> {
public:
  /// Delegates to `impl::verifySameOperandsShape`.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifySameOperandsShape(op);
  }
};

/// Trait verifying that operands and results share a shape: they are scalars
/// or vectors/tensors of the same shape.
template <typename ConcreteType>
class SameOperandsAndResultShape
    : public TraitBase<ConcreteType, SameOperandsAndResultShape> {
public:
  /// Delegates to `impl::verifySameOperandsAndResultShape`.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifySameOperandsAndResultShape(op);
  }
};

/// Trait verifying that all operands have the same element type (for a scalar
/// operand, the type itself).
template <typename ConcreteType>
class SameOperandsElementType
    : public TraitBase<ConcreteType, SameOperandsElementType> {
public:
  /// Delegates to `impl::verifySameOperandsElementType`.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifySameOperandsElementType(op);
  }
};

/// Trait verifying that all operands and results have the same element type
/// (for a scalar, the type itself).
template <typename ConcreteType>
class SameOperandsAndResultElementType
    : public TraitBase<ConcreteType, SameOperandsAndResultElementType> {
public:
  /// Delegates to `impl::verifySameOperandsAndResultElementType`.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifySameOperandsAndResultElementType(op);
  }
};

/// Trait verifying that all operands and results have the same type.
///
/// Note: this trait subsumes both SameOperandsAndResultShape and
/// SameOperandsAndResultElementType.
template <typename ConcreteType>
class SameOperandsAndResultType
    : public TraitBase<ConcreteType, SameOperandsAndResultType> {
public:
  /// Delegates to `impl::verifySameOperandsAndResultType`.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifySameOperandsAndResultType(op);
  }
};
/// Trait verifying that every result has a boolean type, or is a vector or
/// tensor of one.
template <typename ConcreteType>
class ResultsAreBoolLike
    : public TraitBase<ConcreteType, ResultsAreBoolLike> {
public:
  /// Delegates to `impl::verifyResultsAreBoolLike`.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyResultsAreBoolLike(op);
  }
};

/// Trait verifying that every result has a floating-point type, or is a
/// vector or tensor of one.
template <typename ConcreteType>
class ResultsAreFloatLike
    : public TraitBase<ConcreteType, ResultsAreFloatLike> {
public:
  /// Delegates to `impl::verifyResultsAreFloatLike`.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyResultsAreFloatLike(op);
  }
};

/// Trait verifying that every result has a signless integer or index type, or
/// is a vector or tensor of one.
template <typename ConcreteType>
class ResultsAreSignlessIntegerLike
    : public TraitBase<ConcreteType, ResultsAreSignlessIntegerLike> {
public:
  /// Delegates to `impl::verifyResultsAreSignlessIntegerLike`.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyResultsAreSignlessIntegerLike(op);
  }
};
/// Marker trait recording that the operation is commutative. It contributes no
/// verifier and no methods of its own.
template <typename ConcreteType>
class IsCommutative
    : public TraitBase<ConcreteType, IsCommutative> {};
/// Trait recording that the operation is an involution: a unary op `f` with
/// f(f(x)) = x. Besides a verifier it contributes a trait-level folder.
template <typename ConcreteType>
class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
public:
  /// The static_asserts require (at instantiation time) a single operand, a
  /// single result, and a type-preserving op. The remaining checks are
  /// delegated to `impl::verifyIsInvolution`.
  static LogicalResult verifyTrait(Operation *op) {
    static_assert(ConcreteType::template hasTrait<OneResult>(),
                  "expected operation to produce one result");
    static_assert(ConcreteType::template hasTrait<OneOperand>(),
                  "expected operation to take one operand");
    static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
                  "expected operation to preserve type");
    // An involution must also be free of side effects; that requirement is
    // still under a FIXME and is not checked yet.
    return impl::verifyIsInvolution(op);
  }

  /// Trait-level fold hook. The constant `operands` are not consulted; the
  /// work is done by `impl::foldInvolution`.
  static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
    return impl::foldInvolution(op);
  }
};
/// Trait recording that the operation is idempotent: either a unary op `f`
/// with f(f(x)) = f(x), or a binary op `g` with g(x, x) = x. Besides a
/// verifier it contributes a trait-level folder.
template <typename ConcreteType>
class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> {
public:
  /// The static_asserts require (at instantiation time) a single result, one
  /// or two operands, and a type-preserving op. The remaining checks are
  /// delegated to `impl::verifyIsIdempotent`.
  static LogicalResult verifyTrait(Operation *op) {
    static_assert(ConcreteType::template hasTrait<OneResult>(),
                  "expected operation to produce one result");
    static_assert(ConcreteType::template hasTrait<OneOperand>() ||
                      ConcreteType::template hasTrait<NOperands<2>::Impl>(),
                  "expected operation to take one or two operands");
    static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
                  "expected operation to preserve type");
    // Idempotence also requires freedom from side effects; that requirement
    // is still under a FIXME and is not checked yet.
    return impl::verifyIsIdempotent(op);
  }

  /// Trait-level fold hook. The constant `operands` are not consulted; the
  /// work is done by `impl::foldIdempotent`.
  static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
    return impl::foldIdempotent(op);
  }
};
/// Trait verifying that every operand has a float type, or is a vector or
/// tensor of one.
template <typename ConcreteType>
class OperandsAreFloatLike
    : public TraitBase<ConcreteType, OperandsAreFloatLike> {
public:
  /// Delegates to `impl::verifyOperandsAreFloatLike`.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyOperandsAreFloatLike(op);
  }
};

/// Trait verifying that every operand has a signless integer or index type,
/// or is a vector or tensor of one.
template <typename ConcreteType>
class OperandsAreSignlessIntegerLike
    : public TraitBase<ConcreteType, OperandsAreSignlessIntegerLike> {
public:
  /// Delegates to `impl::verifyOperandsAreSignlessIntegerLike`.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyOperandsAreSignlessIntegerLike(op);
  }
};

/// Trait verifying that all operands have exactly the same type.
template <typename ConcreteType>
class SameTypeOperands
    : public TraitBase<ConcreteType, SameTypeOperands> {
public:
  /// Delegates to `impl::verifySameTypeOperands`.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifySameTypeOperands(op);
  }
};
/// Trait for constant-like ops: side-effect-free operations with zero operands
/// and one result that can always be folded to a specific attribute value.
template <typename ConcreteType>
class ConstantLike : public TraitBase<ConcreteType, ConstantLike> {
public:
  /// Only structural checks are enforced, and they are static_asserts: one
  /// result and zero operands. The runtime verifier always succeeds.
  static LogicalResult verifyTrait(Operation *op) {
    static_assert(ConcreteType::template hasTrait<OneResult>(),
                  "expected operation to produce one result");
    static_assert(ConcreteType::template hasTrait<ZeroOperands>(),
                  "expected operation to take zero operands");
    // TODO: Verify that the op can always be folded. That needs the op's
    // attributes to be verified first, so trait verification would have to
    // support running "after" the operation verifier.
    return success();
  }
};
/// Trait for ops whose regions are isolated from above.
template <typename ConcreteType>
class IsIsolatedFromAbove
    : public TraitBase<ConcreteType, IsIsolatedFromAbove> {
public:
  /// Delegates to `impl::verifyIsIsolatedFromAbove`.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyIsIsolatedFromAbove(op);
  }
};
/// Trait for region-holding ops that open a new scope for polyhedral
/// optimization. Any SSA value of 'index' type that dominates such an op, or
/// is used at its top level, automatically becomes a valid symbol for the
/// polyhedral scope defined by that op. For details see
/// `Traits.md#AffineScope`.
template <typename ConcreteType>
class AffineScope : public TraitBase<ConcreteType, AffineScope> {
public:
  /// Rejects, at compile time, ops that also carry `ZeroRegion`. The runtime
  /// verifier always succeeds.
  static LogicalResult verifyTrait(Operation *op) {
    static_assert(!ConcreteType::template hasTrait<ZeroRegion>(),
                  "expected operation to have one or more regions");
    return success();
  }
};

/// Trait for region-holding ops that open a new scope for automatic
/// allocations, i.e. allocations freed when control returns from the op's
/// region. Ops performing such allocations (e.g. memref.alloca) have them
/// freed automatically at their closest enclosing op with this trait.
template <typename ConcreteType>
class AutomaticAllocationScope
    : public TraitBase<ConcreteType, AutomaticAllocationScope> {
public:
  /// Rejects, at compile time, ops that also carry `ZeroRegion`. The runtime
  /// verifier always succeeds.
  static LogicalResult verifyTrait(Operation *op) {
    static_assert(!ConcreteType::template hasTrait<ZeroRegion>(),
                  "expected operation to have one or more regions");
    return success();
  }
};
/// This class provides a verifier for ops that are expecting their parent
/// to be one of the given parent ops.
template <typename... ParentOpTypes>
struct HasParent {
  template <typename ConcreteType>
  class Impl : public TraitBase<ConcreteType, Impl> {
  public:
    /// Succeeds iff the op's immediate parent op is one of `ParentOpTypes`.
    /// An op with no parent op (e.g. a top-level or detached op) fails
    /// verification with the same diagnostic instead of crashing.
    static LogicalResult verifyTrait(Operation *op) {
      // getParentOp() yields null for an op that is not nested in another op,
      // and llvm::isa<> asserts on null pointers, so check explicitly first.
      Operation *parentOp = op->getParentOp();
      if (parentOp && llvm::isa<ParentOpTypes...>(parentOp))
        return success();
      return op->emitOpError()
             << "expects parent op "
             << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'")
             << llvm::makeArrayRef({ParentOpTypes::getOperationName()...})
             << "'";
    }
  };
};
/// A trait for operations whose operands are split into segments described by
/// an attribute.
///
/// Some operations have several variadic operands whose sizes are not always
/// known statically. Such ops need a per-instance description that divides
/// the operands into logical groups, or segments. That description is stored
/// in an attribute named `operand_segment_sizes`.
///
/// The verifier checks that this attribute has the expected type (a 1-D
/// vector) and valid values (non-negative), etc.
template <typename ConcreteType>
class AttrSizedOperandSegments
    : public TraitBase<ConcreteType, AttrSizedOperandSegments> {
public:
  /// Name of the attribute holding the per-segment operand counts.
  static StringRef getOperandSegmentSizeAttr() {
    return "operand_segment_sizes";
  }

  /// Checks the segment-size attribute via `impl::verifyOperandSizeAttr`.
  static LogicalResult verifyTrait(Operation *op) {
    StringRef attrName = getOperandSegmentSizeAttr();
    return ::mlir::OpTrait::impl::verifyOperandSizeAttr(op, attrName);
  }
};

/// Result-side counterpart of AttrSizedOperandSegments.
template <typename ConcreteType>
class AttrSizedResultSegments
    : public TraitBase<ConcreteType, AttrSizedResultSegments> {
public:
  /// Name of the attribute holding the per-segment result counts.
  static StringRef getResultSegmentSizeAttr() {
    return "result_segment_sizes";
  }

  /// Checks the segment-size attribute via `impl::verifyResultSizeAttr`.
  static LogicalResult verifyTrait(Operation *op) {
    StringRef attrName = getResultSegmentSizeAttr();
    return ::mlir::OpTrait::impl::verifyResultSizeAttr(op, attrName);
  }
};
/// Trait verifying that none of the op's regions declare arguments.
template <typename ConcreteType>
struct NoRegionArguments
    : public TraitBase<ConcreteType, NoRegionArguments> {
  /// Delegates to `impl::verifyNoRegionArguments`.
  static LogicalResult verifyTrait(Operation *op) {
    return ::mlir::OpTrait::impl::verifyNoRegionArguments(op);
  }
};
/// Marker trait flagging operations that consume or produce `MemRef` values
/// whose references can be 'normalized'. It adds no verifier or methods.
///
/// TODO: Currently all of an operation's operands are either normalizable or
/// not. Some operands might be allowed to be normalizable independently in
/// the future.
template <typename ConcreteType>
struct MemRefsNormalizable
    : public TraitBase<ConcreteType, MemRefsNormalizable> {};
/// This trait tags element-wise ops on vectors or tensors.
///
/// NOTE: Not every op that is "elementwise" in some abstract sense satisfies
/// this trait. In particular, broadcasting behavior is not allowed.
///
/// An `Elementwise` op must satisfy the following properties:
///
/// 1. If any result is a vector/tensor, then at least one operand must also
///    be a vector/tensor.
/// 2. If any operand is a vector/tensor, then there must be at least one
///    result, and all results must be vectors/tensors.
/// 3. All operand and result vector/tensor types must have the same shape.
///    The shape may be dynamic, in which case the op's behaviour is undefined
///    for non-matching shapes.
/// 4. The operation must be elementwise on its vector/tensor operands and
///    results. When applied to single-element vectors/tensors, the result
///    must be the same per element.
///
/// TODO: Avoid hardcoding vector/tensor and generalize this trait into a new
/// interface, `ElementwiseTypeInterface`, that describes the container types
/// for which the operation is elementwise.
///
/// Rationale:
/// - 1. and 2. guarantee a well-defined iteration space. They exclude the
///   cases of 0 non-scalar operands or 0 non-scalar results, which would
///   complicate a generic definition of the iteration space.
/// - 3. guarantees that folding can use the same pattern across
///   scalars/vectors/tensors; otherwise type mismatches would need a lot of
///   special handling.
/// - 4. guarantees that no error handling is needed. Higher-level dialects
///   should reify any needed guards or error handling code before lowering
///   to an `Elementwise` op.
template <typename ConcreteType>
struct Elementwise : public TraitBase<ConcreteType, Elementwise> {
  /// Delegates verification to `impl::verifyElementwise`.
  static LogicalResult verifyTrait(Operation *op) {
    return ::mlir::OpTrait::impl::verifyElementwise(op);
  }
};
/// This trait tags `Elementwise` operations that can be systematically
/// scalarized: every vector/tensor operand and result is replaced by a scalar
/// of the respective element type. Semantically, the result is the operation
/// applied to a single element of the vector/tensor.
///
/// Rationale:
/// This lets the vector/tensor semantics of an elementwise op be defined by
/// the same op's behavior on scalars. It gives IR transformations a
/// constructive procedure, e.g. for building scalar loop bodies from tensor
/// ops.
///
/// Example:
/// ```
/// %tensor_select = "std.select"(%pred_tensor, %true_val, %false_val)
///                  : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
///                  -> tensor<?xf32>
/// ```
/// can be scalarized to
///
/// ```
/// %scalar_select = "std.select"(%pred, %true_val_scalar, %false_val_scalar)
///                  : (i1, f32, f32) -> f32
/// ```
template <typename ConcreteType>
struct Scalarizable : public TraitBase<ConcreteType, Scalarizable> {
  /// Requires, at compile time, that the op is also `Elementwise`. The runtime
  /// verifier always succeeds.
  static LogicalResult verifyTrait(Operation *op) {
    static_assert(
        ConcreteType::template hasTrait<Elementwise>(),
        "`Scalarizable` trait is only applicable to `Elementwise` ops.");
    return success();
  }
};
/// This trait tags `Elementwise` operations that can be systematically
/// vectorized: every scalar operand and result is replaced by a vector with
/// the respective element type. Semantically, the result is the operation
/// applied to multiple elements simultaneously. See also `Tensorizable`.
///
/// Rationale:
/// This is the reverse of `Scalarizable`; chaining the two allows reasoning
/// about the relationship between the tensor and vector cases. It also
/// permits reasoning about promoting scalars to vectors via broadcasting,
/// e.g. a `std.select` whose `i1` predicate stays scalar while its value
/// operands are vectors (compare the `%scalar_pred` example on
/// `Tensorizable`).
template <typename ConcreteType>
struct Vectorizable : public TraitBase<ConcreteType, Vectorizable> {
  /// Requires, at compile time, that the op is also `Elementwise`. The runtime
  /// verifier always succeeds.
  static LogicalResult verifyTrait(Operation *op) {
    static_assert(
        ConcreteType::template hasTrait<Elementwise>(),
        "`Vectorizable` trait is only applicable to `Elementwise` ops.");
    return success();
  }
};
/// This trait tags `Elementwise` operations that can be systematically
/// tensorized: every scalar operand and result is replaced by a tensor with
/// the respective element type. Semantically, the result is the operation
/// applied to multiple elements simultaneously. See also `Vectorizable`.
///
/// Rationale:
/// This is the reverse of `Scalarizable`; chaining the two allows reasoning
/// about the relationship between the tensor and vector cases. It also
/// permits reasoning about promoting scalars to tensors via broadcasting in
/// cases like `%scalar_pred` below.
///
/// Examples:
/// ```
/// %scalar = "arith.addf"(%a, %b) : (f32, f32) -> f32
/// ```
/// can be tensorized to
/// ```
/// %tensor = "arith.addf"(%a, %b) : (tensor<?xf32>, tensor<?xf32>)
///             -> tensor<?xf32>
/// ```
///
/// ```
/// %scalar_pred = "std.select"(%pred, %true_val, %false_val)
///                : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
/// ```
/// can be tensorized to
/// ```
/// %tensor_pred = "std.select"(%pred, %true_val, %false_val)
///                : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
///                -> tensor<?xf32>
/// ```
template <typename ConcreteType>
struct Tensorizable : public TraitBase<ConcreteType, Tensorizable> {
  /// Requires, at compile time, that the op is also `Elementwise`. The runtime
  /// verifier always succeeds.
  static LogicalResult verifyTrait(Operation *op) {
    static_assert(
        ConcreteType::template hasTrait<Elementwise>(),
        "`Tensorizable` trait is only applicable to `Elementwise` ops.");
    return success();
  }
};
/// Together, `Elementwise`, `Scalarizable`, `Vectorizable`, and `Tensorizable`
/// provide an easy way for scalar operations to conveniently generalize their
/// behavior to vectors/tensors, and systematize conversion between these forms.
///
/// Query for the trait combination above. The definition is out of line;
/// presumably it returns true only when `op` carries all four traits.
/// NOTE(review): confirm against the implementation.
bool hasElementwiseMappableTraits(Operation *op);
} // end namespace OpTrait
//===----------------------------------------------------------------------===//
// Internal Trait Utilities
//===----------------------------------------------------------------------===//
namespace op_definition_impl {
//===----------------------------------------------------------------------===//
// Trait Existence
/// Returns true if `traitID` matches the TypeID of any of the trait types
/// `Traits`. Also well-formed for an empty `Traits` pack (an op declared with
/// no traits), in which case it always returns false.
template <template <typename T> class... Traits>
static bool hasTrait(TypeID traitID) {
  // The leading `false` keeps the braced list non-empty when `Traits` is
  // empty; a zero-length array would be ill-formed in ISO C++.
  for (bool matches : {false, (TypeID::get<Traits>() == traitID)...})
    if (matches)
      return true;
  return false;
}
//===----------------------------------------------------------------------===//
// Trait Folding
/// Detection alias for the single-result hook
/// `TraitT::foldTrait(Operation *, ArrayRef<Attribute>)`. It is well-formed
/// only when that static member exists.
template <typename TraitT, typename... Args>
using has_single_result_fold_trait =
    decltype(TraitT::foldTrait(std::declval<Operation *>(),
                               std::declval<ArrayRef<Attribute>>()));
/// True/false trait: does `TraitT` provide the single-result `foldTrait`?
template <typename TraitT>
using detect_has_single_result_fold_trait =
    llvm::is_detected<has_single_result_fold_trait, TraitT>;

/// Detection alias for the general hook
/// `TraitT::foldTrait(Operation *, ArrayRef<Attribute>,
///                    SmallVectorImpl<OpFoldResult> &)`.
template <typename TraitT, typename... Args>
using has_fold_trait = decltype(TraitT::foldTrait(
    std::declval<Operation *>(), std::declval<ArrayRef<Attribute>>(),
    std::declval<SmallVectorImpl<OpFoldResult> &>()));
/// True/false trait: does `TraitT` provide the general `foldTrait`?
template <typename TraitT>
using detect_has_fold_trait = llvm::is_detected<has_fold_trait, TraitT>;

/// True/false trait: does `TraitT` provide either form of `foldTrait`?
/// NOTE: This should use std::disjunction once C++17 is available.
template <typename TraitT>
using detect_has_any_fold_trait = std::conditional_t<
    bool(detect_has_fold_trait<TraitT>::value),
    detect_has_fold_trait<TraitT>,
    detect_has_single_result_fold_trait<TraitT>>;
/// Folds `op` with a trait whose `foldTrait` hook is specialized for
/// single-result operations. A non-null fold result that is not the op's own
/// result is appended to `results`. An in-place fold (the op's own result)
/// also counts as success but appends nothing.
template <typename Trait>
static std::enable_if_t<detect_has_single_result_fold_trait<Trait>::value,
                        LogicalResult>
foldTrait(Operation *op, ArrayRef<Attribute> operands,
          SmallVectorImpl<OpFoldResult> &results) {
  assert(op->hasTrait<OpTrait::OneResult>() &&
         "expected trait on non single-result operation to implement the "
         "general `foldTrait` method");
  // An earlier trait already folded and replaced this operation; don't fold
  // on top of it.
  if (!results.empty())
    return failure();

  OpFoldResult folded = Trait::foldTrait(op, operands);
  if (!folded)
    return failure();

  bool foldedInPlace = folded.template dyn_cast<Value>() == op->getResult(0);
  if (!foldedInPlace)
    results.push_back(folded);
  return success();
}
/// Folds `op` with a trait that implements the generalized `foldTrait` hook,
/// which supports operations with any number of results.
template <typename Trait>
static std::enable_if_t<detect_has_fold_trait<Trait>::value, LogicalResult>
foldTrait(Operation *op, ArrayRef<Attribute> operands,
          SmallVectorImpl<OpFoldResult> &results) {
  // An earlier trait already folded and replaced this operation; don't fold
  // on top of it.
  if (!results.empty())
    return failure();
  return Trait::foldTrait(op, operands, results);
}
/// Worker for `foldTraits`: tries every trait in `Ts`, in order, and succeeds
/// if at least one of them folded. The tuple pointer is only a type tag and
/// is never dereferenced.
template <typename... Ts>
static LogicalResult foldTraitsImpl(Operation *op, ArrayRef<Attribute> operands,
                                    SmallVectorImpl<OpFoldResult> &results,
                                    std::tuple<Ts...> *) {
  bool foldedAny = false;
  // Pre-C++17 pack expansion: braced-init elements are evaluated left to
  // right, so traits are visited in declaration order. The leading 0 keeps
  // the array non-empty.
  int expander[] = {
      0, (foldedAny |= succeeded(foldTrait<Ts>(op, operands, results)), 0)...};
  (void)expander;
  return success(foldedAny);
}
/// Folds `op` using every trait in the tuple type `TraitTupleT`, each of
/// which provides a `foldTrait` method.
template <typename TraitTupleT>
static std::enable_if_t<std::tuple_size<TraitTupleT>::value != 0, LogicalResult>
foldTraits(Operation *op, ArrayRef<Attribute> operands,
           SmallVectorImpl<OpFoldResult> &results) {
  TraitTupleT *typeTag = nullptr;
  return foldTraitsImpl(op, operands, results, typeTag);
}

/// Overload selected when no trait provides `foldTrait`: nothing can fold.
template <typename TraitTupleT>
static std::enable_if_t<std::tuple_size<TraitTupleT>::value == 0, LogicalResult>
foldTraits(Operation *op, ArrayRef<Attribute> operands,
           SmallVectorImpl<OpFoldResult> &results) {
  return failure();
}
//===----------------------------------------------------------------------===//
// Trait Verification
/// Detection alias that is well-formed only when `TraitT` has a static
/// `verifyTrait(Operation *)` member.
template <typename TraitT, typename... Args>
using has_verify_trait =
    decltype(TraitT::verifyTrait(std::declval<Operation *>()));
/// True/false trait: does `TraitT` provide `verifyTrait`?
template <typename TraitT>
using detect_has_verify_trait = llvm::is_detected<has_verify_trait, TraitT>;
/// Worker for `verifyTraits`: runs each trait's `verifyTrait` in declaration
/// order. Once one fails, the remaining verifiers are skipped and failure is
/// returned. The tuple pointer is only a type tag.
template <typename... Ts>
static LogicalResult verifyTraitsImpl(Operation *op, std::tuple<Ts...> *) {
  LogicalResult result = success();
  // Pre-C++17 pack expansion; the leading 0 keeps the array non-empty.
  int expander[] = {
      0, (result = failed(result) ? failure() : Ts::verifyTrait(op), 0)...};
  (void)expander;
  return result;
}

/// Verifies `op` against every trait in the tuple type `TraitTupleT`, each of
/// which provides a `verifyTrait` method.
template <typename TraitTupleT>
static LogicalResult verifyTraits(Operation *op) {
  TraitTupleT *typeTag = nullptr;
  return verifyTraitsImpl(op, typeTag);
}
} // namespace op_definition_impl
//===----------------------------------------------------------------------===//
// Operation Definition classes
//===----------------------------------------------------------------------===//
/// This provides public APIs that all operations should have. The template
/// argument 'ConcreteType' should be the concrete type by CRTP and the others
/// are base classes by the policy pattern.
template <typename ConcreteType, template <typename T> class... Traits>
class Op : public OpState, public Traits<ConcreteType>... {
public:
/// Inherit getOperation from `OpState`.
using OpState::getOperation;
/// Return if this operation contains the provided trait.
template <template <typename T> class Trait>
static constexpr bool hasTrait() {
return llvm::is_one_of<Trait<ConcreteType>, Traits<ConcreteType>...>::value;
}
/// Create a deep copy of this operation.
ConcreteType clone() { return cast<ConcreteType>(getOperation()->clone()); }
/// Create a partial copy of this operation without traversing into attached
/// regions. The new operation will have the same number of regions as the
/// original one, but they will be left empty.
ConcreteType cloneWithoutRegions() {
return cast<ConcreteType>(getOperation()->cloneWithoutRegions());
}
/// Return true if this "op class" can match against the specified operation.
static bool classof(Operation *op) {
if (auto info = op->getRegisteredInfo())
return TypeID::get<ConcreteType>() == info->getTypeID();
#ifndef NDEBUG
if (op->getName().getStringRef() == ConcreteType::getOperationName())
llvm::report_fatal_error(
"classof on '" + ConcreteType::getOperationName() +
"' failed due to the operation not being registered");
#endif
return false;
}
/// Provide `classof` support for other OpBase derived classes, such as
/// Interfaces.
template <typename T>
static std::enable_if_t<std::is_base_of<OpState, T>::value, bool>
classof(const T *op) {
return classof(const_cast<T *>(op)->getOperation());
}
/// Expose the type we are instantiated on to template machinery that may want
/// to introspect traits on this operation.
using ConcreteOpType = ConcreteType;
/// This is a public constructor. Any op can be initialized to null.
explicit Op() : OpState(nullptr) {}
Op(std::nullptr_t) : OpState(nullptr) {}
/// This is a public constructor to enable access via the llvm::cast family of
/// methods. This should not be used directly.
explicit Op(Operation *state) : OpState(state) {}
/// Methods for supporting PointerLikeTypeTraits.
const void *getAsOpaquePointer() const {
return static_cast<const void *>((Operation *)*this);
}
static ConcreteOpType getFromOpaquePointer(const void *pointer) {
return ConcreteOpType(
reinterpret_cast<Operation *>(const_cast<void *>(pointer)));
}
/// Attach the given models as implementations of the corresponding interfaces
/// for the concrete operation.
template <typename... Models>
static void attachInterface(MLIRContext &context) {
Optional<RegisteredOperationName> info = RegisteredOperationName::lookup(
ConcreteType::getOperationName(), &context);
if (!info)
llvm::report_fatal_error(
"Attempting to attach an interface to an unregistered operation " +
ConcreteType::getOperationName() + ".");
info->attachInterface<Models...>();
}
private:
/// Trait to check if T provides a 'fold' method for a single result op.
template <typename T, typename... Args>
using has_single_result_fold =
decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
template <typename T>
using detect_has_single_result_fold =
llvm::is_detected<has_single_result_fold, T>;
/// Trait to check if T provides a general 'fold' method.
template <typename T, typename... Args>
using has_fold = decltype(std::declval<T>().fold(
std::declval<ArrayRef<Attribute>>(),
std::declval<SmallVectorImpl<OpFoldResult> &>()));
template <typename T>
using detect_has_fold = llvm::is_detected<has_fold, T>;
/// Trait to check if T provides a 'print' method.
template <typename T, typename... Args>
using has_print =
decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>()));
template <typename T>
using detect_has_print = llvm::is_detected<has_print, T>;
/// A tuple type containing the traits that have a `foldTrait` function.
using FoldableTraitsTupleT = typename detail::FilterTypes<
op_definition_impl::detect_has_any_fold_trait,
Traits<ConcreteType>...>::type;
/// A tuple type containing the traits that have a verify function.
using VerifiableTraitsTupleT =
typename detail::FilterTypes<op_definition_impl::detect_has_verify_trait,
Traits<ConcreteType>...>::type;
/// Returns an interface map containing the interfaces registered to this
/// operation.
static detail::InterfaceMap getInterfaceMap() {
return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
}
/// Return the internal implementations of each of the OperationName
/// hooks.
/// Implementation of `FoldHookFn` OperationName hook.
static OperationName::FoldHookFn getFoldHookFn() {
return getFoldHookFnImpl<ConcreteType>();
}
/// The internal implementation of `getFoldHookFn` above that is invoked if
/// the operation is single result and defines a `fold` method.
template <typename ConcreteOpT>
static std::enable_if_t<llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
Traits<ConcreteOpT>...>::value &&
detect_has_single_result_fold<ConcreteOpT>::value,
OperationName::FoldHookFn>
getFoldHookFnImpl() {
return [](Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
return foldSingleResultHook<ConcreteOpT>(op, operands, results);
};
}
/// The internal implementation of `getFoldHookFn` above that is invoked if
/// the operation is not single result and defines a `fold` method.
template <typename ConcreteOpT>
static std::enable_if_t<!llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
Traits<ConcreteOpT>...>::value &&
detect_has_fold<ConcreteOpT>::value,
OperationName::FoldHookFn>
getFoldHookFnImpl() {
return [](Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
return foldHook<ConcreteOpT>(op, operands, results);
};
}
/// The internal implementation of `getFoldHookFn` above that is invoked if
/// the operation does not define a `fold` method.
template <typename ConcreteOpT>
static std::enable_if_t<!detect_has_single_result_fold<ConcreteOpT>::value &&
!detect_has_fold<ConcreteOpT>::value,
OperationName::FoldHookFn>
getFoldHookFnImpl() {
return [](Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
// In this case, we only need to fold the traits of the operation.
return op_definition_impl::foldTraits<FoldableTraitsTupleT>(op, operands,
results);
};
}
/// Return the result of folding a single result operation that defines a
/// `fold` method.
template <typename ConcreteOpT>
static LogicalResult
foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
OpFoldResult result = cast<ConcreteOpT>(op).fold(operands);
// If the fold failed or was in-place, try to fold the traits of the
// operation.
if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
op, operands, results)))
return success();
return success(static_cast<bool>(result));
}
results.push_back(result);
return success();
}
/// Fold hook for operations whose `fold(operands, results)` returns a
/// `LogicalResult` and appends replacement values to `results`.
///
/// The op-specific fold runs first. If it fails, or succeeds without leaving
/// anything in `results` (an in-place fold), the foldable traits are tried.
/// A successful trait fold makes the overall outcome success. Otherwise the
/// op-specific fold's status is returned unchanged.
template <typename ConcreteOpT>
static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
                              SmallVectorImpl<OpFoldResult> &results) {
  LogicalResult opFoldStatus = cast<ConcreteOpT>(op).fold(operands, results);

  // The op produced replacement values itself; nothing more to try.
  bool producedResults = succeeded(opFoldStatus) && !results.empty();
  if (producedResults)
    return opFoldStatus;

  if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
          op, operands, results)))
    return success();
  return opFoldStatus;
}
/// `GetCanonicalizationPatternsFn` OperationName hook: forwards to the
/// concrete op's static `getCanonicalizationPatterns`.
static OperationName::GetCanonicalizationPatternsFn
getGetCanonicalizationPatternsFn() {
  return ConcreteType::getCanonicalizationPatterns;
}
/// `HasTraitFn` OperationName hook: reports whether the trait identified by
/// the given TypeID is one of this op's `Traits`.
static OperationName::HasTraitFn getHasTraitFn() {
  return [](TypeID traitID) {
    return op_definition_impl::hasTrait<Traits...>(traitID);
  };
}
/// `ParseAssemblyFn` OperationName hook: forwards to the concrete op's
/// static `parse`.
static OperationName::ParseAssemblyFn getParseAssemblyFn() {
  return ConcreteType::parse;
}
/// `PrintAssemblyFn` OperationName hook. The overload of
/// `getPrintAssemblyFnImpl` that gets picked depends on whether the concrete
/// op defines its own `print` method.
static OperationName::PrintAssemblyFn getPrintAssemblyFn() {
  OperationName::PrintAssemblyFn printFn =
      getPrintAssemblyFnImpl<ConcreteType>();
  return printFn;
}
/// `getPrintAssemblyFn` helper chosen when the concrete op has no `print`
/// method: it falls back to `OpState::print`. The default dialect argument
/// is not used on this path.
template <typename ConcreteOpT>
static std::enable_if_t<!detect_has_print<ConcreteOpT>::value,
                        OperationName::PrintAssemblyFn>
getPrintAssemblyFnImpl() {
  return [](Operation *operation, OpAsmPrinter &printer,
            StringRef /*defaultDialect*/) { OpState::print(operation, printer); };
}
/// `getPrintAssemblyFn` helper chosen when the concrete op defines `print`:
/// printing is routed through `printAssembly` below.
template <typename ConcreteOpT>
static std::enable_if_t<detect_has_print<ConcreteOpT>::value,
                        OperationName::PrintAssemblyFn>
getPrintAssemblyFnImpl() {
  return printAssembly;
}
/// Print the op name (passing `defaultDialect` through to
/// `OpState::printOpName`), then let the concrete op's `print` emit the rest.
static void printAssembly(Operation *op, OpAsmPrinter &p,
                          StringRef defaultDialect) {
  OpState::printOpName(op, p, defaultDialect);
  cast<ConcreteType>(op).print(p);
}
/// `VerifyInvariantsFn` OperationName hook.
static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
  return verifyInvariants;
}
/// Returns true when `ConcreteType` adds no data members on top of its `Op`
/// base. The check compares its size with that of a locally defined op class
/// built from the same trait list.
static constexpr bool hasNoDataMembers() {
  class EmptyOp : public Op<EmptyOp, Traits...> {};
  return sizeof(EmptyOp) == sizeof(ConcreteType);
}
/// Verify `op` by running the verifiable traits first and then the concrete
/// op's own `verify`. The op's `verify` is not called if a trait fails.
static LogicalResult verifyInvariants(Operation *op) {
  static_assert(hasNoDataMembers(),
                "Op class shouldn't define new data members");
  if (failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)))
    return failure();
  return failure(failed(cast<ConcreteType>(op).verify()));
}
/// Grant `RegisteredOperationName` access to this class's internal
/// implementation methods (e.g. the hook accessors defined in this class).
friend RegisteredOperationName;
};
/// Base class of an operation interface. See the definition of
/// `detail::Interface` for the requirements placed on the `Traits` type.
template <typename ConcreteType, typename Traits>
class OpInterface
    : public detail::Interface<ConcreteType, Operation *, Traits,
                               Op<ConcreteType>, OpTrait::TraitBase> {
public:
  using Base = OpInterface<ConcreteType, Traits>;
  using InterfaceBase = detail::Interface<ConcreteType, Operation *, Traits,
                                          Op<ConcreteType>, OpTrait::TraitBase>;

  /// Reuse the constructors of the underlying interface base.
  using InterfaceBase::InterfaceBase;

protected:
  /// Look up the interface implementation (`Concept`) for `op`.
  ///
  /// For a registered op, the interface attached to the op's registered info
  /// is preferred, and the op's dialect is the fallback. For an unregistered
  /// op, only its dialect (if it has one) is asked. Returns null when nothing
  /// provides an implementation.
  static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
    OperationName opName = op->getName();

    Optional<RegisteredOperationName> registered = opName.getRegisteredInfo();
    if (!registered) {
      // Unregistered op: its dialect, if any, is the only possible provider.
      Dialect *dialect = opName.getDialect();
      return dialect
                 ? dialect->getRegisteredInterfaceForOp<ConcreteType>(opName)
                 : nullptr;
    }

    // Prefer an implementation attached directly to the registered op.
    if (auto *attached = registered->getInterface<ConcreteType>())
      return attached;

    // Otherwise give the owning dialect a chance to implement the interface
    // for this operation.
    return registered->getDialect().getRegisteredInterfaceForOp<ConcreteType>(
        opName);
  }

  /// The interface base needs to call `getInterfaceFor`.
  friend InterfaceBase;
};
//===----------------------------------------------------------------------===//
// Common Operation Folders/Parsers/Printers
//===----------------------------------------------------------------------===//
// These functions are out-of-line implementations of the methods in UnaryOp and
// BinaryOp; defining them out of line avoids template-instantiating (and
// thereby duplicating) them.
namespace impl {
// NOTE(review): judging only by its name, this parses the custom form of an
// op with one operand and one result of the same type; confirm against the
// out-of-line definition.
ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser,
OperationState &result);
// Builder helper for binary ops taking `lhs` and `rhs`. How `result` is
// populated (operands, result types) is defined out of line; see the
// definition.
void buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs,
Value rhs);
// NOTE(review): judging only by its name, this parses the custom form of an
// op whose operands share one type with its single result; confirm against
// the out-of-line definition.
ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
OperationState &result);
// Prints the given binary `op` in custom assembly form if its two operands and
// its result all have the same type. Otherwise, prints the generic assembly
// form.
void printOneResultOp(Operation *op, OpAsmPrinter &p);
} // namespace impl
// These functions are out-of-line implementations of the methods in
// CastOpInterface; defining them out of line avoids template-instantiating
// (and thereby duplicating) them.
namespace impl {
/// Attempt to fold the given cast operation.
/// NOTE(review): what is appended to `foldResults`, and when this succeeds,
/// is defined out of line and not visible here.
LogicalResult foldCastInterfaceOp(Operation *op,
ArrayRef<Attribute> attrOperands,
SmallVectorImpl<OpFoldResult> &foldResults);
/// Attempt to verify the given cast operation. `areCastCompatible` decides
/// whether two type lists are cast-compatible (presumably the op's input and
/// output types; confirm in the definition).
LogicalResult verifyCastInterfaceOp(
Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible);
// TODO: Remove the parse/print/build here (new ODS functionality obsoletes the
// need for them, but some older ODS code in `std` still depends on them).
/// Builder helper for a cast of `source` to `destType`.
void buildCastOp(OpBuilder &builder, OperationState &result, Value source,
Type destType);
/// Parser helper for the custom assembly form of a cast op.
ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
/// Printer helper for the custom assembly form of a cast op.
void printCastOp(Operation *op, OpAsmPrinter &p);
// TODO: These methods are deprecated in favor of CastOpInterface. Remove them
// when all uses have been updated. Also, consider adding functionality to
// CastOpInterface to be able to perform the ChainedTensorCast canonicalization
// generically.
/// Deprecated single-value cast folder (see TODO above).
Value foldCastOp(Operation *op);
/// Deprecated cast verifier taking a per-type compatibility predicate (see
/// TODO above).
LogicalResult verifyCastOp(Operation *op,
function_ref<bool(Type, Type)> areCastCompatible);
} // namespace impl
} // end namespace mlir
namespace llvm {
/// DenseMapInfo for any op class derived from `mlir::OpState`.
///
/// Keys are handled through the op's opaque pointer. The empty and tombstone
/// keys reuse the sentinel pointers of `DenseMapInfo<void *>`, and equality
/// uses the op class's own `operator==`.
template <typename T>
struct DenseMapInfo<
    T, std::enable_if_t<std::is_base_of<mlir::OpState, T>::value>> {
  static inline T getEmptyKey() {
    return T::getFromOpaquePointer(llvm::DenseMapInfo<void *>::getEmptyKey());
  }
  static inline T getTombstoneKey() {
    return T::getFromOpaquePointer(
        llvm::DenseMapInfo<void *>::getTombstoneKey());
  }
  static unsigned getHashValue(T val) {
    auto opaque = val.getAsOpaquePointer();
    return hash_value(opaque);
  }
  static bool isEqual(T lhs, T rhs) { return lhs == rhs; }
};
} // end namespace llvm
#endif