| //===- OpDefinition.h - Classes for defining concrete Op types --*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements helper classes for implementing the "Op" types. This |
| // includes the Op type, which is the base class for Op class definitions, |
// as well as a number of traits in the OpTrait namespace that provide a
| // declarative way to specify properties of Ops. |
| // |
// The purpose of these types is to allow lightweight implementation of
| // concrete ops (like DimOp) with very little boilerplate. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #ifndef MLIR_IR_OPDEFINITION_H |
| #define MLIR_IR_OPDEFINITION_H |
| |
| #include "mlir/IR/Dialect.h" |
| #include "mlir/IR/Operation.h" |
| #include "llvm/Support/PointerLikeTypeTraits.h" |
| |
| #include <type_traits> |
| |
| namespace mlir { |
| class Builder; |
| class OpBuilder; |
| |
| /// This class represents success/failure for operation parsing. It is |
| /// essentially a simple wrapper class around LogicalResult that allows for |
| /// explicit conversion to bool. This allows for the parser to chain together |
| /// parse rules without the clutter of "failed/succeeded". |
| class ParseResult : public LogicalResult { |
| public: |
| ParseResult(LogicalResult result = success()) : LogicalResult(result) {} |
| |
| // Allow diagnostics emitted during parsing to be converted to failure. |
| ParseResult(const InFlightDiagnostic &) : LogicalResult(failure()) {} |
| ParseResult(const Diagnostic &) : LogicalResult(failure()) {} |
| |
| /// Failure is true in a boolean context. |
| explicit operator bool() const { return failed(); } |
| }; |
| /// This class implements `Optional` functionality for ParseResult. We don't |
| /// directly use Optional here, because it provides an implicit conversion |
| /// to 'bool' which we want to avoid. This class is used to implement tri-state |
| /// 'parseOptional' functions that may have a failure mode when parsing that |
| /// shouldn't be attributed to "not present". |
| class OptionalParseResult { |
| public: |
| OptionalParseResult() = default; |
| OptionalParseResult(LogicalResult result) : impl(result) {} |
| OptionalParseResult(ParseResult result) : impl(result) {} |
| OptionalParseResult(const InFlightDiagnostic &) |
| : OptionalParseResult(failure()) {} |
| OptionalParseResult(llvm::NoneType) : impl(llvm::None) {} |
| |
| /// Returns true if we contain a valid ParseResult value. |
| bool hasValue() const { return impl.hasValue(); } |
| |
| /// Access the internal ParseResult value. |
| ParseResult getValue() const { return impl.getValue(); } |
| ParseResult operator*() const { return getValue(); } |
| |
| private: |
| Optional<ParseResult> impl; |
| }; |
| |
| // These functions are out-of-line utilities, which avoids them being template |
| // instantiated/duplicated. |
| namespace impl { |
| /// Insert an operation, generated by `buildTerminatorOp`, at the end of the |
| /// region's only block if it does not have a terminator already. If the region |
| /// is empty, insert a new block first. `buildTerminatorOp` should return the |
| /// terminator operation to insert. |
| void ensureRegionTerminator( |
| Region ®ion, OpBuilder &builder, Location loc, |
| function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp); |
| void ensureRegionTerminator( |
| Region ®ion, Builder &builder, Location loc, |
| function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp); |
| |
| } // namespace impl |
| |
| /// This is the concrete base class that holds the operation pointer and has |
| /// non-generic methods that only depend on State (to avoid having them |
| /// instantiated on template types that don't affect them. |
| /// |
| /// This also has the fallback implementations of customization hooks for when |
| /// they aren't customized. |
| class OpState { |
| public: |
| /// Ops are pointer-like, so we allow conversion to bool. |
| explicit operator bool() { return getOperation() != nullptr; } |
| |
| /// This implicitly converts to Operation*. |
| operator Operation *() const { return state; } |
| |
| /// Shortcut of `->` to access a member of Operation. |
| Operation *operator->() const { return state; } |
| |
| /// Return the operation that this refers to. |
| Operation *getOperation() { return state; } |
| |
| /// Return the context this operation belongs to. |
| MLIRContext *getContext() { return getOperation()->getContext(); } |
| |
| /// Print the operation to the given stream. |
| void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) { |
| state->print(os, flags); |
| } |
| void print(raw_ostream &os, AsmState &asmState, |
| OpPrintingFlags flags = llvm::None) { |
| state->print(os, asmState, flags); |
| } |
| |
| /// Dump this operation. |
| void dump() { state->dump(); } |
| |
| /// The source location the operation was defined or derived from. |
| Location getLoc() { return state->getLoc(); } |
| |
| /// Return true if there are no users of any results of this operation. |
| bool use_empty() { return state->use_empty(); } |
| |
| /// Remove this operation from its parent block and delete it. |
| void erase() { state->erase(); } |
| |
| /// Emit an error with the op name prefixed, like "'dim' op " which is |
| /// convenient for verifiers. |
| InFlightDiagnostic emitOpError(const Twine &message = {}); |
| |
| /// Emit an error about fatal conditions with this operation, reporting up to |
| /// any diagnostic handlers that may be listening. |
| InFlightDiagnostic emitError(const Twine &message = {}); |
| |
| /// Emit a warning about this operation, reporting up to any diagnostic |
| /// handlers that may be listening. |
| InFlightDiagnostic emitWarning(const Twine &message = {}); |
| |
| /// Emit a remark about this operation, reporting up to any diagnostic |
| /// handlers that may be listening. |
| InFlightDiagnostic emitRemark(const Twine &message = {}); |
| |
| /// Walk the operation by calling the callback for each nested operation |
| /// (including this one), block or region, depending on the callback provided. |
| /// Regions, blocks and operations at the same nesting level are visited in |
| /// lexicographical order. The walk order for enclosing regions, blocks and |
| /// operations with respect to their nested ones is specified by 'Order' |
| /// (post-order by default). A callback on a block or operation is allowed to |
| /// erase that block or operation if either: |
| /// * the walk is in post-order, or |
| /// * the walk is in pre-order and the walk is skipped after the erasure. |
| /// See Operation::walk for more details. |
| template <WalkOrder Order = WalkOrder::PostOrder, typename FnT, |
| typename RetT = detail::walkResultType<FnT>> |
| RetT walk(FnT &&callback) { |
| return state->walk<Order>(std::forward<FnT>(callback)); |
| } |
| |
| // These are default implementations of customization hooks. |
| public: |
| /// This hook returns any canonicalization pattern rewrites that the operation |
| /// supports, for use by the canonicalization pass. |
| static void getCanonicalizationPatterns(RewritePatternSet &results, |
| MLIRContext *context) {} |
| |
| protected: |
| /// If the concrete type didn't implement a custom verifier hook, just fall |
| /// back to this one which accepts everything. |
| LogicalResult verify() { return success(); } |
| |
| /// Unless overridden, the custom assembly form of an op is always rejected. |
| /// Op implementations should implement this to return failure. |
| /// On success, they should fill in result with the fields to use. |
| static ParseResult parse(OpAsmParser &parser, OperationState &result); |
| |
| // The fallback for the printer is to print it the generic assembly form. |
| static void print(Operation *op, OpAsmPrinter &p); |
| static void printOpName(Operation *op, OpAsmPrinter &p, |
| StringRef defaultDialect); |
| |
| /// Mutability management is handled by the OpWrapper/OpConstWrapper classes, |
| /// so we can cast it away here. |
| explicit OpState(Operation *state) : state(state) {} |
| |
| private: |
| Operation *state; |
| |
| /// Allow access to internal hook implementation methods. |
| friend RegisteredOperationName; |
| }; |
| |
| // Allow comparing operators. |
| inline bool operator==(OpState lhs, OpState rhs) { |
| return lhs.getOperation() == rhs.getOperation(); |
| } |
| inline bool operator!=(OpState lhs, OpState rhs) { |
| return lhs.getOperation() != rhs.getOperation(); |
| } |
| |
| raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr); |
| |
| /// This class represents a single result from folding an operation. |
| class OpFoldResult : public PointerUnion<Attribute, Value> { |
| using PointerUnion<Attribute, Value>::PointerUnion; |
| |
| public: |
| void dump() { llvm::errs() << *this << "\n"; } |
| }; |
| |
| /// Allow printing to a stream. |
| inline raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr) { |
| if (Value value = ofr.dyn_cast<Value>()) |
| value.print(os); |
| else |
| ofr.dyn_cast<Attribute>().print(os); |
| return os; |
| } |
| |
| /// Allow printing to a stream. |
| inline raw_ostream &operator<<(raw_ostream &os, OpState op) { |
| op.print(os, OpPrintingFlags().useLocalScope()); |
| return os; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Operation Trait Types |
| //===----------------------------------------------------------------------===// |
| |
| namespace OpTrait { |
| |
| // These functions are out-of-line implementations of the methods in the |
| // corresponding trait classes. This avoids them being template |
| // instantiated/duplicated. |
| namespace impl { |
| OpFoldResult foldIdempotent(Operation *op); |
| OpFoldResult foldInvolution(Operation *op); |
| LogicalResult verifyZeroOperands(Operation *op); |
| LogicalResult verifyOneOperand(Operation *op); |
| LogicalResult verifyNOperands(Operation *op, unsigned numOperands); |
| LogicalResult verifyIsIdempotent(Operation *op); |
| LogicalResult verifyIsInvolution(Operation *op); |
| LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands); |
| LogicalResult verifyOperandsAreFloatLike(Operation *op); |
| LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op); |
| LogicalResult verifySameTypeOperands(Operation *op); |
| LogicalResult verifyZeroRegion(Operation *op); |
| LogicalResult verifyOneRegion(Operation *op); |
| LogicalResult verifyNRegions(Operation *op, unsigned numRegions); |
| LogicalResult verifyAtLeastNRegions(Operation *op, unsigned numRegions); |
| LogicalResult verifyZeroResult(Operation *op); |
| LogicalResult verifyOneResult(Operation *op); |
| LogicalResult verifyNResults(Operation *op, unsigned numOperands); |
| LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands); |
| LogicalResult verifySameOperandsShape(Operation *op); |
| LogicalResult verifySameOperandsAndResultShape(Operation *op); |
| LogicalResult verifySameOperandsElementType(Operation *op); |
| LogicalResult verifySameOperandsAndResultElementType(Operation *op); |
| LogicalResult verifySameOperandsAndResultType(Operation *op); |
| LogicalResult verifyResultsAreBoolLike(Operation *op); |
| LogicalResult verifyResultsAreFloatLike(Operation *op); |
| LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op); |
| LogicalResult verifyIsTerminator(Operation *op); |
| LogicalResult verifyZeroSuccessor(Operation *op); |
| LogicalResult verifyOneSuccessor(Operation *op); |
| LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors); |
| LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors); |
| LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName, |
| StringRef valueGroupName, |
| size_t expectedCount); |
| LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName); |
| LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName); |
| LogicalResult verifyNoRegionArguments(Operation *op); |
| LogicalResult verifyElementwise(Operation *op); |
| LogicalResult verifyIsIsolatedFromAbove(Operation *op); |
| } // namespace impl |
| |
| /// Helper class for implementing traits. Clients are not expected to interact |
| /// with this directly, so its members are all protected. |
| template <typename ConcreteType, template <typename> class TraitType> |
| class TraitBase { |
| protected: |
| /// Return the ultimate Operation being worked on. |
| Operation *getOperation() { |
| // We have to cast up to the trait type, then to the concrete type, then to |
| // the BaseState class in explicit hops because the concrete type will |
| // multiply derive from the (content free) TraitBase class, and we need to |
| // be able to disambiguate the path for the C++ compiler. |
| auto *trait = static_cast<TraitType<ConcreteType> *>(this); |
| auto *concrete = static_cast<ConcreteType *>(trait); |
| auto *base = static_cast<OpState *>(concrete); |
| return base->getOperation(); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // Operand Traits |
| |
| namespace detail { |
| /// Utility trait base that provides accessors for derived traits that have |
| /// multiple operands. |
| template <typename ConcreteType, template <typename> class TraitType> |
| struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> { |
| using operand_iterator = Operation::operand_iterator; |
| using operand_range = Operation::operand_range; |
| using operand_type_iterator = Operation::operand_type_iterator; |
| using operand_type_range = Operation::operand_type_range; |
| |
| /// Return the number of operands. |
| unsigned getNumOperands() { return this->getOperation()->getNumOperands(); } |
| |
| /// Return the operand at index 'i'. |
| Value getOperand(unsigned i) { return this->getOperation()->getOperand(i); } |
| |
| /// Set the operand at index 'i' to 'value'. |
| void setOperand(unsigned i, Value value) { |
| this->getOperation()->setOperand(i, value); |
| } |
| |
| /// Operand iterator access. |
| operand_iterator operand_begin() { |
| return this->getOperation()->operand_begin(); |
| } |
| operand_iterator operand_end() { return this->getOperation()->operand_end(); } |
| operand_range getOperands() { return this->getOperation()->getOperands(); } |
| |
| /// Operand type access. |
| operand_type_iterator operand_type_begin() { |
| return this->getOperation()->operand_type_begin(); |
| } |
| operand_type_iterator operand_type_end() { |
| return this->getOperation()->operand_type_end(); |
| } |
| operand_type_range getOperandTypes() { |
| return this->getOperation()->getOperandTypes(); |
| } |
| }; |
| } // end namespace detail |
| |
| /// This class provides the API for ops that are known to have no |
| /// SSA operand. |
| template <typename ConcreteType> |
| class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifyZeroOperands(op); |
| } |
| |
| private: |
| // Disable these. |
| void getOperand() {} |
| void setOperand() {} |
| }; |
| |
| /// This class provides the API for ops that are known to have exactly one |
| /// SSA operand. |
| template <typename ConcreteType> |
| class OneOperand : public TraitBase<ConcreteType, OneOperand> { |
| public: |
| Value getOperand() { return this->getOperation()->getOperand(0); } |
| |
| void setOperand(Value value) { this->getOperation()->setOperand(0, value); } |
| |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifyOneOperand(op); |
| } |
| }; |
| |
| /// This class provides the API for ops that are known to have a specified |
| /// number of operands. This is used as a trait like this: |
| /// |
| /// class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> { |
| /// |
| template <unsigned N> |
| class NOperands { |
| public: |
| static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2"); |
| |
| template <typename ConcreteType> |
| class Impl |
| : public detail::MultiOperandTraitBase<ConcreteType, NOperands<N>::Impl> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifyNOperands(op, N); |
| } |
| }; |
| }; |
| |
| /// This class provides the API for ops that are known to have a at least a |
| /// specified number of operands. This is used as a trait like this: |
| /// |
| /// class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> { |
| /// |
| template <unsigned N> |
| class AtLeastNOperands { |
| public: |
| template <typename ConcreteType> |
| class Impl : public detail::MultiOperandTraitBase<ConcreteType, |
| AtLeastNOperands<N>::Impl> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifyAtLeastNOperands(op, N); |
| } |
| }; |
| }; |
| |
| /// This class provides the API for ops which have an unknown number of |
| /// SSA operands. |
| template <typename ConcreteType> |
| class VariadicOperands |
| : public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {}; |
| |
| //===----------------------------------------------------------------------===// |
| // Region Traits |
| |
| /// This class provides verification for ops that are known to have zero |
| /// regions. |
| template <typename ConcreteType> |
| class ZeroRegion : public TraitBase<ConcreteType, ZeroRegion> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifyZeroRegion(op); |
| } |
| }; |
| |
| namespace detail { |
| /// Utility trait base that provides accessors for derived traits that have |
| /// multiple regions. |
| template <typename ConcreteType, template <typename> class TraitType> |
| struct MultiRegionTraitBase : public TraitBase<ConcreteType, TraitType> { |
| using region_iterator = MutableArrayRef<Region>; |
| using region_range = RegionRange; |
| |
| /// Return the number of regions. |
| unsigned getNumRegions() { return this->getOperation()->getNumRegions(); } |
| |
| /// Return the region at `index`. |
| Region &getRegion(unsigned i) { return this->getOperation()->getRegion(i); } |
| |
| /// Region iterator access. |
| region_iterator region_begin() { |
| return this->getOperation()->region_begin(); |
| } |
| region_iterator region_end() { return this->getOperation()->region_end(); } |
| region_range getRegions() { return this->getOperation()->getRegions(); } |
| }; |
| } // end namespace detail |
| |
| /// This class provides APIs for ops that are known to have a single region. |
| template <typename ConcreteType> |
| class OneRegion : public TraitBase<ConcreteType, OneRegion> { |
| public: |
| Region &getRegion() { return this->getOperation()->getRegion(0); } |
| |
| /// Returns a range of operations within the region of this operation. |
| auto getOps() { return getRegion().getOps(); } |
| template <typename OpT> |
| auto getOps() { |
| return getRegion().template getOps<OpT>(); |
| } |
| |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifyOneRegion(op); |
| } |
| }; |
| |
| /// This class provides the API for ops that are known to have a specified |
| /// number of regions. |
| template <unsigned N> |
| class NRegions { |
| public: |
| static_assert(N > 1, "use ZeroRegion/OneRegion for N < 2"); |
| |
| template <typename ConcreteType> |
| class Impl |
| : public detail::MultiRegionTraitBase<ConcreteType, NRegions<N>::Impl> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifyNRegions(op, N); |
| } |
| }; |
| }; |
| |
| /// This class provides APIs for ops that are known to have at least a specified |
| /// number of regions. |
| template <unsigned N> |
| class AtLeastNRegions { |
| public: |
| template <typename ConcreteType> |
| class Impl : public detail::MultiRegionTraitBase<ConcreteType, |
| AtLeastNRegions<N>::Impl> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifyAtLeastNRegions(op, N); |
| } |
| }; |
| }; |
| |
| /// This class provides the API for ops which have an unknown number of |
| /// regions. |
| template <typename ConcreteType> |
| class VariadicRegions |
| : public detail::MultiRegionTraitBase<ConcreteType, VariadicRegions> {}; |
| |
| //===----------------------------------------------------------------------===// |
| // Result Traits |
| |
| /// This class provides return value APIs for ops that are known to have |
| /// zero results. |
| template <typename ConcreteType> |
| class ZeroResult : public TraitBase<ConcreteType, ZeroResult> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifyZeroResult(op); |
| } |
| }; |
| |
| namespace detail { |
| /// Utility trait base that provides accessors for derived traits that have |
| /// multiple results. |
| template <typename ConcreteType, template <typename> class TraitType> |
| struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> { |
| using result_iterator = Operation::result_iterator; |
| using result_range = Operation::result_range; |
| using result_type_iterator = Operation::result_type_iterator; |
| using result_type_range = Operation::result_type_range; |
| |
| /// Return the number of results. |
| unsigned getNumResults() { return this->getOperation()->getNumResults(); } |
| |
| /// Return the result at index 'i'. |
| Value getResult(unsigned i) { return this->getOperation()->getResult(i); } |
| |
| /// Replace all uses of results of this operation with the provided 'values'. |
| /// 'values' may correspond to an existing operation, or a range of 'Value'. |
| template <typename ValuesT> |
| void replaceAllUsesWith(ValuesT &&values) { |
| this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values)); |
| } |
| |
| /// Return the type of the `i`-th result. |
| Type getType(unsigned i) { return getResult(i).getType(); } |
| |
| /// Result iterator access. |
| result_iterator result_begin() { |
| return this->getOperation()->result_begin(); |
| } |
| result_iterator result_end() { return this->getOperation()->result_end(); } |
| result_range getResults() { return this->getOperation()->getResults(); } |
| |
| /// Result type access. |
| result_type_iterator result_type_begin() { |
| return this->getOperation()->result_type_begin(); |
| } |
| result_type_iterator result_type_end() { |
| return this->getOperation()->result_type_end(); |
| } |
| result_type_range getResultTypes() { |
| return this->getOperation()->getResultTypes(); |
| } |
| }; |
| } // end namespace detail |
| |
| /// This class provides return value APIs for ops that are known to have a |
| /// single result. ResultType is the concrete type returned by getType(). |
| template <typename ConcreteType> |
| class OneResult : public TraitBase<ConcreteType, OneResult> { |
| public: |
| Value getResult() { return this->getOperation()->getResult(0); } |
| |
| /// If the operation returns a single value, then the Op can be implicitly |
| /// converted to an Value. This yields the value of the only result. |
| operator Value() { return getResult(); } |
| |
| /// Replace all uses of 'this' value with the new value, updating anything |
| /// in the IR that uses 'this' to use the other value instead. When this |
| /// returns there are zero uses of 'this'. |
| void replaceAllUsesWith(Value newValue) { |
| getResult().replaceAllUsesWith(newValue); |
| } |
| |
| /// Replace all uses of 'this' value with the result of 'op'. |
| void replaceAllUsesWith(Operation *op) { |
| this->getOperation()->replaceAllUsesWith(op); |
| } |
| |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifyOneResult(op); |
| } |
| }; |
| |
| /// This trait is used for return value APIs for ops that are known to have a |
| /// specific type other than `Type`. This allows the "getType()" member to be |
| /// more specific for an op. This should be used in conjunction with OneResult, |
| /// and occur in the trait list before OneResult. |
| template <typename ResultType> |
| class OneTypedResult { |
| public: |
| /// This class provides return value APIs for ops that are known to have a |
| /// single result. ResultType is the concrete type returned by getType(). |
| template <typename ConcreteType> |
| class Impl |
| : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> { |
| public: |
| ResultType getType() { |
| auto resultTy = this->getOperation()->getResult(0).getType(); |
| return resultTy.template cast<ResultType>(); |
| } |
| }; |
| }; |
| |
| /// This class provides the API for ops that are known to have a specified |
| /// number of results. This is used as a trait like this: |
| /// |
| /// class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> { |
| /// |
| template <unsigned N> |
| class NResults { |
| public: |
| static_assert(N > 1, "use ZeroResult/OneResult for N < 2"); |
| |
| template <typename ConcreteType> |
| class Impl |
| : public detail::MultiResultTraitBase<ConcreteType, NResults<N>::Impl> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifyNResults(op, N); |
| } |
| }; |
| }; |
| |
| /// This class provides the API for ops that are known to have at least a |
| /// specified number of results. This is used as a trait like this: |
| /// |
| /// class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> { |
| /// |
| template <unsigned N> |
| class AtLeastNResults { |
| public: |
| template <typename ConcreteType> |
| class Impl : public detail::MultiResultTraitBase<ConcreteType, |
| AtLeastNResults<N>::Impl> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifyAtLeastNResults(op, N); |
| } |
| }; |
| }; |
| |
| /// This class provides the API for ops which have an unknown number of |
| /// results. |
| template <typename ConcreteType> |
| class VariadicResults |
| : public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {}; |
| |
| //===----------------------------------------------------------------------===// |
| // Terminator Traits |
| |
| /// This class indicates that the regions associated with this op don't have |
| /// terminators. |
| template <typename ConcreteType> |
| class NoTerminator : public TraitBase<ConcreteType, NoTerminator> {}; |
| |
| /// This class provides the API for ops that are known to be terminators. |
| template <typename ConcreteType> |
| class IsTerminator : public TraitBase<ConcreteType, IsTerminator> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifyIsTerminator(op); |
| } |
| }; |
| |
| /// This class provides verification for ops that are known to have zero |
| /// successors. |
| template <typename ConcreteType> |
| class ZeroSuccessor : public TraitBase<ConcreteType, ZeroSuccessor> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifyZeroSuccessor(op); |
| } |
| }; |
| |
| namespace detail { |
| /// Utility trait base that provides accessors for derived traits that have |
| /// multiple successors. |
| template <typename ConcreteType, template <typename> class TraitType> |
| struct MultiSuccessorTraitBase : public TraitBase<ConcreteType, TraitType> { |
| using succ_iterator = Operation::succ_iterator; |
| using succ_range = SuccessorRange; |
| |
| /// Return the number of successors. |
| unsigned getNumSuccessors() { |
| return this->getOperation()->getNumSuccessors(); |
| } |
| |
| /// Return the successor at `index`. |
| Block *getSuccessor(unsigned i) { |
| return this->getOperation()->getSuccessor(i); |
| } |
| |
| /// Set the successor at `index`. |
| void setSuccessor(Block *block, unsigned i) { |
| return this->getOperation()->setSuccessor(block, i); |
| } |
| |
| /// Successor iterator access. |
| succ_iterator succ_begin() { return this->getOperation()->succ_begin(); } |
| succ_iterator succ_end() { return this->getOperation()->succ_end(); } |
| succ_range getSuccessors() { return this->getOperation()->getSuccessors(); } |
| }; |
| } // end namespace detail |
| |
| /// This class provides APIs for ops that are known to have a single successor. |
| template <typename ConcreteType> |
| class OneSuccessor : public TraitBase<ConcreteType, OneSuccessor> { |
| public: |
| Block *getSuccessor() { return this->getOperation()->getSuccessor(0); } |
| void setSuccessor(Block *succ) { |
| this->getOperation()->setSuccessor(succ, 0); |
| } |
| |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifyOneSuccessor(op); |
| } |
| }; |
| |
| /// This class provides the API for ops that are known to have a specified |
| /// number of successors. |
| template <unsigned N> |
| class NSuccessors { |
| public: |
| static_assert(N > 1, "use ZeroSuccessor/OneSuccessor for N < 2"); |
| |
| template <typename ConcreteType> |
| class Impl : public detail::MultiSuccessorTraitBase<ConcreteType, |
| NSuccessors<N>::Impl> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifyNSuccessors(op, N); |
| } |
| }; |
| }; |
| |
| /// This class provides APIs for ops that are known to have at least a specified |
| /// number of successors. |
| template <unsigned N> |
| class AtLeastNSuccessors { |
| public: |
| template <typename ConcreteType> |
| class Impl |
| : public detail::MultiSuccessorTraitBase<ConcreteType, |
| AtLeastNSuccessors<N>::Impl> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifyAtLeastNSuccessors(op, N); |
| } |
| }; |
| }; |
| |
| /// This class provides the API for ops which have an unknown number of |
| /// successors. |
| template <typename ConcreteType> |
| class VariadicSuccessors |
| : public detail::MultiSuccessorTraitBase<ConcreteType, VariadicSuccessors> { |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // SingleBlock |
| |
| /// This class provides APIs and verifiers for ops with regions having a single |
| /// block. |
| template <typename ConcreteType> |
| struct SingleBlock : public TraitBase<ConcreteType, SingleBlock> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { |
| Region ®ion = op->getRegion(i); |
| |
| // Empty regions are fine. |
| if (region.empty()) |
| continue; |
| |
| // Non-empty regions must contain a single basic block. |
| if (!llvm::hasSingleElement(region)) |
| return op->emitOpError("expects region #") |
| << i << " to have 0 or 1 blocks"; |
| |
| if (!ConcreteType::template hasTrait<NoTerminator>()) { |
| Block &block = region.front(); |
| if (block.empty()) |
| return op->emitOpError() << "expects a non-empty block"; |
| } |
| } |
| return success(); |
| } |
| |
| Block *getBody(unsigned idx = 0) { |
| Region ®ion = this->getOperation()->getRegion(idx); |
| assert(!region.empty() && "unexpected empty region"); |
| return ®ion.front(); |
| } |
| Region &getBodyRegion(unsigned idx = 0) { |
| return this->getOperation()->getRegion(idx); |
| } |
| |
| //===------------------------------------------------------------------===// |
| // Single Region Utilities |
| //===------------------------------------------------------------------===// |
| |
| /// The following are a set of methods only enabled when the parent |
| /// operation has a single region. Each of these methods take an additional |
| /// template parameter that represents the concrete operation so that we |
| /// can use SFINAE to disable the methods for non-single region operations. |
| template <typename OpT, typename T = void> |
| using enable_if_single_region = |
| typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>; |
| |
| template <typename OpT = ConcreteType> |
| enable_if_single_region<OpT, Block::iterator> begin() { |
| return getBody()->begin(); |
| } |
| template <typename OpT = ConcreteType> |
| enable_if_single_region<OpT, Block::iterator> end() { |
| return getBody()->end(); |
| } |
| template <typename OpT = ConcreteType> |
| enable_if_single_region<OpT, Operation &> front() { |
| return *begin(); |
| } |
| |
| /// Insert the operation into the back of the body. |
| template <typename OpT = ConcreteType> |
| enable_if_single_region<OpT> push_back(Operation *op) { |
| insert(Block::iterator(getBody()->end()), op); |
| } |
| |
| /// Insert the operation at the given insertion point. |
| template <typename OpT = ConcreteType> |
| enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) { |
| insert(Block::iterator(insertPt), op); |
| } |
| template <typename OpT = ConcreteType> |
| enable_if_single_region<OpT> insert(Block::iterator insertPt, Operation *op) { |
| getBody()->getOperations().insert(insertPt, op); |
| } |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // SingleBlockImplicitTerminator |
| |
| /// This class provides APIs and verifiers for ops with regions having a single |
| /// block that must terminate with `TerminatorOpType`. |
| template <typename TerminatorOpType> |
| struct SingleBlockImplicitTerminator { |
| template <typename ConcreteType> |
| class Impl : public SingleBlock<ConcreteType> { |
| private: |
| using Base = SingleBlock<ConcreteType>; |
| /// Builds a terminator operation without relying on OpBuilder APIs to avoid |
| /// cyclic header inclusion. |
| static Operation *buildTerminator(OpBuilder &builder, Location loc) { |
| OperationState state(loc, TerminatorOpType::getOperationName()); |
| TerminatorOpType::build(builder, state); |
| return Operation::create(state); |
| } |
| |
| public: |
| /// The type of the operation used as the implicit terminator type. |
| using ImplicitTerminatorOpT = TerminatorOpType; |
| |
| static LogicalResult verifyTrait(Operation *op) { |
| if (failed(Base::verifyTrait(op))) |
| return failure(); |
| for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { |
| Region ®ion = op->getRegion(i); |
| // Empty regions are fine. |
| if (region.empty()) |
| continue; |
| Operation &terminator = region.front().back(); |
| if (isa<TerminatorOpType>(terminator)) |
| continue; |
| |
| return op->emitOpError("expects regions to end with '" + |
| TerminatorOpType::getOperationName() + |
| "', found '" + |
| terminator.getName().getStringRef() + "'") |
| .attachNote() |
| << "in custom textual format, the absence of terminator implies " |
| "'" |
| << TerminatorOpType::getOperationName() << '\''; |
| } |
| |
| return success(); |
| } |
| |
| /// Ensure that the given region has the terminator required by this trait. |
| /// If OpBuilder is provided, use it to build the terminator and notify the |
| /// OpBuilder listeners accordingly. If only a Builder is provided, locally |
| /// construct an OpBuilder with no listeners; this should only be used if no |
| /// OpBuilder is available at the call site, e.g., in the parser. |
| static void ensureTerminator(Region ®ion, Builder &builder, |
| Location loc) { |
| ::mlir::impl::ensureRegionTerminator(region, builder, loc, |
| buildTerminator); |
| } |
| static void ensureTerminator(Region ®ion, OpBuilder &builder, |
| Location loc) { |
| ::mlir::impl::ensureRegionTerminator(region, builder, loc, |
| buildTerminator); |
| } |
| |
| //===------------------------------------------------------------------===// |
| // Single Region Utilities |
| //===------------------------------------------------------------------===// |
| using Base::getBody; |
| |
| template <typename OpT, typename T = void> |
| using enable_if_single_region = |
| typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>; |
| |
| /// Insert the operation into the back of the body, before the terminator. |
| template <typename OpT = ConcreteType> |
| enable_if_single_region<OpT> push_back(Operation *op) { |
| insert(Block::iterator(getBody()->getTerminator()), op); |
| } |
| |
| /// Insert the operation at the given insertion point. Note: The operation |
| /// is never inserted after the terminator, even if the insertion point is |
| /// end(). |
| template <typename OpT = ConcreteType> |
| enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) { |
| insert(Block::iterator(insertPt), op); |
| } |
| template <typename OpT = ConcreteType> |
| enable_if_single_region<OpT> insert(Block::iterator insertPt, |
| Operation *op) { |
| auto *body = getBody(); |
| if (insertPt == body->end()) |
| insertPt = Block::iterator(body->getTerminator()); |
| body->getOperations().insert(insertPt, op); |
| } |
| }; |
| }; |
| |
/// Detector alias naming `OpTy::ImplicitTerminatorOpT`. It is well-formed only
/// when `OpTy` declares that member, which makes it suitable for
/// `llvm::is_detected`.
template <typename OpTy>
using has_implicit_terminator_t = typename OpTy::ImplicitTerminatorOpT;
| |
| /// Support to check if an operation has the SingleBlockImplicitTerminator |
| /// trait. We can't just use `hasTrait` because this class is templated on a |
| /// specific terminator op. |
| template <class Op, bool hasTerminator = |
| llvm::is_detected<has_implicit_terminator_t, Op>::value> |
| struct hasSingleBlockImplicitTerminator { |
| static constexpr bool value = std::is_base_of< |
| typename OpTrait::SingleBlockImplicitTerminator< |
| typename Op::ImplicitTerminatorOpT>::template Impl<Op>, |
| Op>::value; |
| }; |
| template <class Op> |
| struct hasSingleBlockImplicitTerminator<Op, false> { |
| static constexpr bool value = false; |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // Misc Traits |
| |
| /// This class provides verification for ops that are known to have the same |
| /// operand shape: all operands are scalars, vectors/tensors of the same |
| /// shape. |
| template <typename ConcreteType> |
| class SameOperandsShape : public TraitBase<ConcreteType, SameOperandsShape> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifySameOperandsShape(op); |
| } |
| }; |
| |
| /// This class provides verification for ops that are known to have the same |
| /// operand and result shape: both are scalars, vectors/tensors of the same |
| /// shape. |
| template <typename ConcreteType> |
| class SameOperandsAndResultShape |
| : public TraitBase<ConcreteType, SameOperandsAndResultShape> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifySameOperandsAndResultShape(op); |
| } |
| }; |
| |
| /// This class provides verification for ops that are known to have the same |
| /// operand element type (or the type itself if it is scalar). |
| /// |
| template <typename ConcreteType> |
| class SameOperandsElementType |
| : public TraitBase<ConcreteType, SameOperandsElementType> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifySameOperandsElementType(op); |
| } |
| }; |
| |
| /// This class provides verification for ops that are known to have the same |
| /// operand and result element type (or the type itself if it is scalar). |
| /// |
| template <typename ConcreteType> |
| class SameOperandsAndResultElementType |
| : public TraitBase<ConcreteType, SameOperandsAndResultElementType> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifySameOperandsAndResultElementType(op); |
| } |
| }; |
| |
| /// This class provides verification for ops that are known to have the same |
| /// operand and result type. |
| /// |
| /// Note: this trait subsumes the SameOperandsAndResultShape and |
| /// SameOperandsAndResultElementType traits. |
| template <typename ConcreteType> |
| class SameOperandsAndResultType |
| : public TraitBase<ConcreteType, SameOperandsAndResultType> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifySameOperandsAndResultType(op); |
| } |
| }; |
| |
| /// This class verifies that any results of the specified op have a boolean |
| /// type, a vector thereof, or a tensor thereof. |
| template <typename ConcreteType> |
| class ResultsAreBoolLike : public TraitBase<ConcreteType, ResultsAreBoolLike> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifyResultsAreBoolLike(op); |
| } |
| }; |
| |
| /// This class verifies that any results of the specified op have a floating |
| /// point type, a vector thereof, or a tensor thereof. |
| template <typename ConcreteType> |
| class ResultsAreFloatLike |
| : public TraitBase<ConcreteType, ResultsAreFloatLike> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifyResultsAreFloatLike(op); |
| } |
| }; |
| |
| /// This class verifies that any results of the specified op have a signless |
| /// integer or index type, a vector thereof, or a tensor thereof. |
| template <typename ConcreteType> |
| class ResultsAreSignlessIntegerLike |
| : public TraitBase<ConcreteType, ResultsAreSignlessIntegerLike> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifyResultsAreSignlessIntegerLike(op); |
| } |
| }; |
| |
| /// This class adds property that the operation is commutative. |
| template <typename ConcreteType> |
| class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {}; |
| |
| /// This class adds property that the operation is an involution. |
| /// This means a unary to unary operation "f" that satisfies f(f(x)) = x |
| template <typename ConcreteType> |
| class IsInvolution : public TraitBase<ConcreteType, IsInvolution> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| static_assert(ConcreteType::template hasTrait<OneResult>(), |
| "expected operation to produce one result"); |
| static_assert(ConcreteType::template hasTrait<OneOperand>(), |
| "expected operation to take one operand"); |
| static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(), |
| "expected operation to preserve type"); |
| // Involution requires the operation to be side effect free as well |
| // but currently this check is under a FIXME and is not actually done. |
| return impl::verifyIsInvolution(op); |
| } |
| |
| static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) { |
| return impl::foldInvolution(op); |
| } |
| }; |
| |
| /// This class adds property that the operation is idempotent. |
| /// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x) |
| template <typename ConcreteType> |
| class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| static_assert(ConcreteType::template hasTrait<OneResult>(), |
| "expected operation to produce one result"); |
| static_assert(ConcreteType::template hasTrait<OneOperand>(), |
| "expected operation to take one operand"); |
| static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(), |
| "expected operation to preserve type"); |
| // Idempotent requires the operation to be side effect free as well |
| // but currently this check is under a FIXME and is not actually done. |
| return impl::verifyIsIdempotent(op); |
| } |
| |
| static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) { |
| return impl::foldIdempotent(op); |
| } |
| }; |
| |
| /// This class verifies that all operands of the specified op have a float type, |
| /// a vector thereof, or a tensor thereof. |
| template <typename ConcreteType> |
| class OperandsAreFloatLike |
| : public TraitBase<ConcreteType, OperandsAreFloatLike> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifyOperandsAreFloatLike(op); |
| } |
| }; |
| |
| /// This class verifies that all operands of the specified op have a signless |
| /// integer or index type, a vector thereof, or a tensor thereof. |
| template <typename ConcreteType> |
| class OperandsAreSignlessIntegerLike |
| : public TraitBase<ConcreteType, OperandsAreSignlessIntegerLike> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifyOperandsAreSignlessIntegerLike(op); |
| } |
| }; |
| |
| /// This class verifies that all operands of the specified op have the same |
| /// type. |
| template <typename ConcreteType> |
| class SameTypeOperands : public TraitBase<ConcreteType, SameTypeOperands> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifySameTypeOperands(op); |
| } |
| }; |
| |
| /// This class provides the API for a sub-set of ops that are known to be |
| /// constant-like. These are non-side effecting operations with one result and |
| /// zero operands that can always be folded to a specific attribute value. |
| template <typename ConcreteType> |
| class ConstantLike : public TraitBase<ConcreteType, ConstantLike> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| static_assert(ConcreteType::template hasTrait<OneResult>(), |
| "expected operation to produce one result"); |
| static_assert(ConcreteType::template hasTrait<ZeroOperands>(), |
| "expected operation to take zero operands"); |
| // TODO: We should verify that the operation can always be folded, but this |
| // requires that the attributes of the op already be verified. We should add |
| // support for verifying traits "after" the operation to enable this use |
| // case. |
| return success(); |
| } |
| }; |
| |
| /// This class provides the API for ops that are known to be isolated from |
| /// above. |
| template <typename ConcreteType> |
| class IsIsolatedFromAbove |
| : public TraitBase<ConcreteType, IsIsolatedFromAbove> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| return impl::verifyIsIsolatedFromAbove(op); |
| } |
| }; |
| |
| /// A trait of region holding operations that defines a new scope for polyhedral |
| /// optimization purposes. Any SSA values of 'index' type that either dominate |
| /// such an operation or are used at the top-level of such an operation |
| /// automatically become valid symbols for the polyhedral scope defined by that |
| /// operation. For more details, see `Traits.md#AffineScope`. |
| template <typename ConcreteType> |
| class AffineScope : public TraitBase<ConcreteType, AffineScope> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| static_assert(!ConcreteType::template hasTrait<ZeroRegion>(), |
| "expected operation to have one or more regions"); |
| return success(); |
| } |
| }; |
| |
| /// A trait of region holding operations that define a new scope for automatic |
| /// allocations, i.e., allocations that are freed when control is transferred |
| /// back from the operation's region. Any operations performing such allocations |
| /// (for eg. memref.alloca) will have their allocations automatically freed at |
| /// their closest enclosing operation with this trait. |
| template <typename ConcreteType> |
| class AutomaticAllocationScope |
| : public TraitBase<ConcreteType, AutomaticAllocationScope> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| static_assert(!ConcreteType::template hasTrait<ZeroRegion>(), |
| "expected operation to have one or more regions"); |
| return success(); |
| } |
| }; |
| |
| /// This class provides a verifier for ops that are expecting their parent |
| /// to be one of the given parent ops |
| template <typename... ParentOpTypes> |
| struct HasParent { |
| template <typename ConcreteType> |
| class Impl : public TraitBase<ConcreteType, Impl> { |
| public: |
| static LogicalResult verifyTrait(Operation *op) { |
| if (llvm::isa<ParentOpTypes...>(op->getParentOp())) |
| return success(); |
| |
| return op->emitOpError() |
| << "expects parent op " |
| << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'") |
| << llvm::makeArrayRef({ParentOpTypes::getOperationName()...}) |
| << "'"; |
| } |
| }; |
| }; |
| |
| /// A trait for operations that have an attribute specifying operand segments. |
| /// |
| /// Certain operations can have multiple variadic operands and their size |
| /// relationship is not always known statically. For such cases, we need |
| /// a per-op-instance specification to divide the operands into logical groups |
| /// or segments. This can be modeled by attributes. The attribute will be named |
| /// as `operand_segment_sizes`. |
| /// |
| /// This trait verifies the attribute for specifying operand segments has |
| /// the correct type (1D vector) and values (non-negative), etc. |
| template <typename ConcreteType> |
| class AttrSizedOperandSegments |
| : public TraitBase<ConcreteType, AttrSizedOperandSegments> { |
| public: |
| static StringRef getOperandSegmentSizeAttr() { |
| return "operand_segment_sizes"; |
| } |
| |
| static LogicalResult verifyTrait(Operation *op) { |
| return ::mlir::OpTrait::impl::verifyOperandSizeAttr( |
| op, getOperandSegmentSizeAttr()); |
| } |
| }; |
| |
| /// Similar to AttrSizedOperandSegments but used for results. |
| template <typename ConcreteType> |
| class AttrSizedResultSegments |
| : public TraitBase<ConcreteType, AttrSizedResultSegments> { |
| public: |
| static StringRef getResultSegmentSizeAttr() { return "result_segment_sizes"; } |
| |
| static LogicalResult verifyTrait(Operation *op) { |
| return ::mlir::OpTrait::impl::verifyResultSizeAttr( |
| op, getResultSegmentSizeAttr()); |
| } |
| }; |
| |
| /// This trait provides a verifier for ops that are expecting their regions to |
| /// not have any arguments |
| template <typename ConcrentType> |
| struct NoRegionArguments : public TraitBase<ConcrentType, NoRegionArguments> { |
| static LogicalResult verifyTrait(Operation *op) { |
| return ::mlir::OpTrait::impl::verifyNoRegionArguments(op); |
| } |
| }; |
| |
| // This trait is used to flag operations that consume or produce |
| // values of `MemRef` type where those references can be 'normalized'. |
| // TODO: Right now, the operands of an operation are either all normalizable, |
| // or not. In the future, we may want to allow some of the operands to be |
| // normalizable. |
| template <typename ConcrentType> |
| struct MemRefsNormalizable |
| : public TraitBase<ConcrentType, MemRefsNormalizable> {}; |
| |
| /// This trait tags element-wise ops on vectors or tensors. |
| /// |
| /// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this |
| /// trait. In particular, broadcasting behavior is not allowed. |
| /// |
| /// An `Elementwise` op must satisfy the following properties: |
| /// |
| /// 1. If any result is a vector/tensor then at least one operand must also be a |
| /// vector/tensor. |
| /// 2. If any operand is a vector/tensor then there must be at least one result |
| /// and all results must be vectors/tensors. |
| /// 3. All operand and result vector/tensor types must be of the same shape. The |
| /// shape may be dynamic in which case the op's behaviour is undefined for |
| /// non-matching shapes. |
| /// 4. The operation must be elementwise on its vector/tensor operands and |
| /// results. When applied to single-element vectors/tensors, the result must |
| /// be the same per elememnt. |
| /// |
| /// TODO: Avoid hardcoding vector/tensor, and generalize this trait to a new |
| /// interface `ElementwiseTypeInterface` that describes the container types for |
| /// which the operation is elementwise. |
| /// |
| /// Rationale: |
| /// - 1. and 2. guarantee a well-defined iteration space and exclude the cases |
| /// of 0 non-scalar operands or 0 non-scalar results, which complicate a |
| /// generic definition of the iteration space. |
| /// - 3. guarantees that folding can be done across scalars/vectors/tensors with |
| /// the same pattern, as otherwise lots of special handling for type |
| /// mismatches would be needed. |
| /// - 4. guarantees that no error handling is needed. Higher-level dialects |
| /// should reify any needed guards or error handling code before lowering to |
| /// an `Elementwise` op. |
| template <typename ConcreteType> |
| struct Elementwise : public TraitBase<ConcreteType, Elementwise> { |
| static LogicalResult verifyTrait(Operation *op) { |
| return ::mlir::OpTrait::impl::verifyElementwise(op); |
| } |
| }; |
| |
| /// This trait tags `Elementwise` operatons that can be systematically |
| /// scalarized. All vector/tensor operands and results are then replaced by |
| /// scalars of the respective element type. Semantically, this is the operation |
| /// on a single element of the vector/tensor. |
| /// |
| /// Rationale: |
| /// Allow to define the vector/tensor semantics of elementwise operations based |
| /// on the same op's behavior on scalars. This provides a constructive procedure |
| /// for IR transformations to, e.g., create scalar loop bodies from tensor ops. |
| /// |
| /// Example: |
| /// ``` |
| /// %tensor_select = "std.select"(%pred_tensor, %true_val, %false_val) |
| /// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) |
| /// -> tensor<?xf32> |
| /// ``` |
| /// can be scalarized to |
| /// |
| /// ``` |
| /// %scalar_select = "std.select"(%pred, %true_val_scalar, %false_val_scalar) |
| /// : (i1, f32, f32) -> f32 |
| /// ``` |
| template <typename ConcreteType> |
| struct Scalarizable : public TraitBase<ConcreteType, Scalarizable> { |
| static LogicalResult verifyTrait(Operation *op) { |
| static_assert( |
| ConcreteType::template hasTrait<Elementwise>(), |
| "`Scalarizable` trait is only applicable to `Elementwise` ops."); |
| return success(); |
| } |
| }; |
| |
| /// This trait tags `Elementwise` operatons that can be systematically |
| /// vectorized. All scalar operands and results are then replaced by vectors |
| /// with the respective element type. Semantically, this is the operation on |
| /// multiple elements simultaneously. See also `Tensorizable`. |
| /// |
| /// Rationale: |
| /// Provide the reverse to `Scalarizable` which, when chained together, allows |
| /// reasoning about the relationship between the tensor and vector case. |
| /// Additionally, it permits reasoning about promoting scalars to vectors via |
| /// broadcasting in cases like `%select_scalar_pred` below. |
| template <typename ConcreteType> |
| struct Vectorizable : public TraitBase<ConcreteType, Vectorizable> { |
| static LogicalResult verifyTrait(Operation *op) { |
| static_assert( |
| ConcreteType::template hasTrait<Elementwise>(), |
| "`Vectorizable` trait is only applicable to `Elementwise` ops."); |
| return success(); |
| } |
| }; |
| |
| /// This trait tags `Elementwise` operatons that can be systematically |
| /// tensorized. All scalar operands and results are then replaced by tensors |
| /// with the respective element type. Semantically, this is the operation on |
| /// multiple elements simultaneously. See also `Vectorizable`. |
| /// |
| /// Rationale: |
| /// Provide the reverse to `Scalarizable` which, when chained together, allows |
| /// reasoning about the relationship between the tensor and vector case. |
| /// Additionally, it permits reasoning about promoting scalars to tensors via |
| /// broadcasting in cases like `%select_scalar_pred` below. |
| /// |
| /// Examples: |
| /// ``` |
| /// %scalar = "arith.addf"(%a, %b) : (f32, f32) -> f32 |
| /// ``` |
| /// can be tensorized to |
| /// ``` |
| /// %tensor = "arith.addf"(%a, %b) : (tensor<?xf32>, tensor<?xf32>) |
| /// -> tensor<?xf32> |
| /// ``` |
| /// |
| /// ``` |
| /// %scalar_pred = "std.select"(%pred, %true_val, %false_val) |
| /// : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> |
| /// ``` |
| /// can be tensorized to |
| /// ``` |
| /// %tensor_pred = "std.select"(%pred, %true_val, %false_val) |
| /// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) |
| /// -> tensor<?xf32> |
| /// ``` |
| template <typename ConcreteType> |
| struct Tensorizable : public TraitBase<ConcreteType, Tensorizable> { |
| static LogicalResult verifyTrait(Operation *op) { |
| static_assert( |
| ConcreteType::template hasTrait<Elementwise>(), |
| "`Tensorizable` trait is only applicable to `Elementwise` ops."); |
| return success(); |
| } |
| }; |
| |
| /// Together, `Elementwise`, `Scalarizable`, `Vectorizable`, and `Tensorizable` |
| /// provide an easy way for scalar operations to conveniently generalize their |
| /// behavior to vectors/tensors, and systematize conversion between these forms. |
| bool hasElementwiseMappableTraits(Operation *op); |
| |
| } // end namespace OpTrait |
| |
| //===----------------------------------------------------------------------===// |
| // Internal Trait Utilities |
| //===----------------------------------------------------------------------===// |
| |
| namespace op_definition_impl { |
| //===----------------------------------------------------------------------===// |
| // Trait Existence |
| |
| /// Returns true if this given Trait ID matches the IDs of any of the provided |
| /// trait types `Traits`. |
| template <template <typename T> class... Traits> |
| static bool hasTrait(TypeID traitID) { |
| TypeID traitIDs[] = {TypeID::get<Traits>()...}; |
| for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i) |
| if (traitIDs[i] == traitID) |
| return true; |
| return false; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Trait Folding |
| |
| /// Trait to check if T provides a 'foldTrait' method for single result |
| /// operations. |
| template <typename T, typename... Args> |
| using has_single_result_fold_trait = decltype(T::foldTrait( |
| std::declval<Operation *>(), std::declval<ArrayRef<Attribute>>())); |
| template <typename T> |
| using detect_has_single_result_fold_trait = |
| llvm::is_detected<has_single_result_fold_trait, T>; |
| /// Trait to check if T provides a general 'foldTrait' method. |
| template <typename T, typename... Args> |
| using has_fold_trait = |
| decltype(T::foldTrait(std::declval<Operation *>(), |
| std::declval<ArrayRef<Attribute>>(), |
| std::declval<SmallVectorImpl<OpFoldResult> &>())); |
| template <typename T> |
| using detect_has_fold_trait = llvm::is_detected<has_fold_trait, T>; |
| /// Trait to check if T provides any `foldTrait` method. |
| /// NOTE: This should use std::disjunction when C++17 is available. |
| template <typename T> |
| using detect_has_any_fold_trait = |
| std::conditional_t<bool(detect_has_fold_trait<T>::value), |
| detect_has_fold_trait<T>, |
| detect_has_single_result_fold_trait<T>>; |
| |
| /// Returns the result of folding a trait that implements a `foldTrait` function |
| /// that is specialized for operations that have a single result. |
| template <typename Trait> |
| static std::enable_if_t<detect_has_single_result_fold_trait<Trait>::value, |
| LogicalResult> |
| foldTrait(Operation *op, ArrayRef<Attribute> operands, |
| SmallVectorImpl<OpFoldResult> &results) { |
| assert(op->hasTrait<OpTrait::OneResult>() && |
| "expected trait on non single-result operation to implement the " |
| "general `foldTrait` method"); |
| // If a previous trait has already been folded and replaced this operation, we |
| // fail to fold this trait. |
| if (!results.empty()) |
| return failure(); |
| |
| if (OpFoldResult result = Trait::foldTrait(op, operands)) { |
| if (result.template dyn_cast<Value>() != op->getResult(0)) |
| results.push_back(result); |
| return success(); |
| } |
| return failure(); |
| } |
| /// Returns the result of folding a trait that implements a generalized |
| /// `foldTrait` function that is supports any operation type. |
| template <typename Trait> |
| static std::enable_if_t<detect_has_fold_trait<Trait>::value, LogicalResult> |
| foldTrait(Operation *op, ArrayRef<Attribute> operands, |
| SmallVectorImpl<OpFoldResult> &results) { |
| // If a previous trait has already been folded and replaced this operation, we |
| // fail to fold this trait. |
| return results.empty() ? Trait::foldTrait(op, operands, results) : failure(); |
| } |
| |
| /// The internal implementation of `foldTraits` below that returns the result of |
| /// folding a set of trait types `Ts` that implement a `foldTrait` method. |
| template <typename... Ts> |
| static LogicalResult foldTraitsImpl(Operation *op, ArrayRef<Attribute> operands, |
| SmallVectorImpl<OpFoldResult> &results, |
| std::tuple<Ts...> *) { |
| bool anyFolded = false; |
| (void)std::initializer_list<int>{ |
| (anyFolded |= succeeded(foldTrait<Ts>(op, operands, results)), 0)...}; |
| return success(anyFolded); |
| } |
| |
| /// Given a tuple type containing a set of traits that contain a `foldTrait` |
| /// method, return the result of folding the given operation. |
| template <typename TraitTupleT> |
| static std::enable_if_t<std::tuple_size<TraitTupleT>::value != 0, LogicalResult> |
| foldTraits(Operation *op, ArrayRef<Attribute> operands, |
| SmallVectorImpl<OpFoldResult> &results) { |
| return foldTraitsImpl(op, operands, results, (TraitTupleT *)nullptr); |
| } |
| /// A variant of the method above that is specialized when there are no traits |
| /// that contain a `foldTrait` method. |
| template <typename TraitTupleT> |
| static std::enable_if_t<std::tuple_size<TraitTupleT>::value == 0, LogicalResult> |
| foldTraits(Operation *op, ArrayRef<Attribute> operands, |
| SmallVectorImpl<OpFoldResult> &results) { |
| return failure(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Trait Verification |
| |
| /// Trait to check if T provides a `verifyTrait` method. |
| template <typename T, typename... Args> |
| using has_verify_trait = decltype(T::verifyTrait(std::declval<Operation *>())); |
| template <typename T> |
| using detect_has_verify_trait = llvm::is_detected<has_verify_trait, T>; |
| |
| /// The internal implementation of `verifyTraits` below that returns the result |
| /// of verifying the current operation with all of the provided trait types |
| /// `Ts`. |
| template <typename... Ts> |
| static LogicalResult verifyTraitsImpl(Operation *op, std::tuple<Ts...> *) { |
| LogicalResult result = success(); |
| (void)std::initializer_list<int>{ |
| (result = succeeded(result) ? Ts::verifyTrait(op) : failure(), 0)...}; |
| return result; |
| } |
| |
| /// Given a tuple type containing a set of traits that contain a |
| /// `verifyTrait` method, return the result of verifying the given operation. |
| template <typename TraitTupleT> |
| static LogicalResult verifyTraits(Operation *op) { |
| return verifyTraitsImpl(op, (TraitTupleT *)nullptr); |
| } |
| } // namespace op_definition_impl |
| |
| //===----------------------------------------------------------------------===// |
| // Operation Definition classes |
| //===----------------------------------------------------------------------===// |
| |
/// This provides public APIs that all operations should have. The template
/// argument 'ConcreteType' should be the concrete type by CRTP and the others
/// are base classes by the policy pattern: every template in `Traits` is
/// instantiated as `Traits<ConcreteType>` and publicly inherited, so each trait
/// can contribute methods to the op. Derived op classes must not add data
/// members; this is enforced by a `static_assert` in `verifyInvariants`.
template <typename ConcreteType, template <typename T> class... Traits>
class Op : public OpState, public Traits<ConcreteType>... {
public:
  /// Inherit getOperation from `OpState`.
  using OpState::getOperation;

  /// Return true if `Trait<ConcreteType>` is one of the traits this op was
  /// declared with. Usable in constant expressions.
  template <template <typename T> class Trait>
  static constexpr bool hasTrait() {
    return llvm::is_one_of<Trait<ConcreteType>, Traits<ConcreteType>...>::value;
  }

  /// Create a deep copy of this operation (via `Operation::clone`) and return
  /// it as a `ConcreteType`.
  ConcreteType clone() { return cast<ConcreteType>(getOperation()->clone()); }

  /// Create a partial copy of this operation without traversing into attached
  /// regions. The new operation will have the same number of regions as the
  /// original one, but they will be left empty.
  ConcreteType cloneWithoutRegions() {
    return cast<ConcreteType>(getOperation()->cloneWithoutRegions());
  }

  /// Return true if this "op class" can match against the specified operation,
  /// i.e. `op` is registered and its registered TypeID is the TypeID of
  /// `ConcreteType`. Unregistered operations never match.
  static bool classof(Operation *op) {
    if (auto info = op->getRegisteredInfo())
      return TypeID::get<ConcreteType>() == info->getTypeID();
#ifndef NDEBUG
    // In debug builds, an unregistered op that carries this op's name is
    // reported as a fatal error instead of silently failing the match.
    if (op->getName().getStringRef() == ConcreteType::getOperationName())
      llvm::report_fatal_error(
          "classof on '" + ConcreteType::getOperationName() +
          "' failed due to the operation not being registered");
#endif
    return false;
  }
  /// Provide `classof` support for other OpBase derived classes, such as
  /// Interfaces. Forwards to the `Operation *` overload above; the
  /// `const_cast` is presumably needed because `getOperation()` is not
  /// const-qualified.
  template <typename T>
  static std::enable_if_t<std::is_base_of<OpState, T>::value, bool>
  classof(const T *op) {
    return classof(const_cast<T *>(op)->getOperation());
  }

  /// Expose the type we are instantiated on to template machinery that may want
  /// to introspect traits on this operation.
  using ConcreteOpType = ConcreteType;

  /// This is a public constructor. Any op can be initialized to null.
  explicit Op() : OpState(nullptr) {}
  Op(std::nullptr_t) : OpState(nullptr) {}

  /// This is a public constructor to enable access via the llvm::cast family of
  /// methods. This should not be used directly.
  explicit Op(Operation *state) : OpState(state) {}

  /// Methods for supporting PointerLikeTypeTraits. The opaque pointer is the
  /// underlying `Operation *`; `getFromOpaquePointer` wraps it back into a
  /// `ConcreteOpType` without any checking.
  const void *getAsOpaquePointer() const {
    return static_cast<const void *>((Operation *)*this);
  }
  static ConcreteOpType getFromOpaquePointer(const void *pointer) {
    return ConcreteOpType(
        reinterpret_cast<Operation *>(const_cast<void *>(pointer)));
  }

  /// Attach the given models as implementations of the corresponding interfaces
  /// for the concrete operation. The operation must already be registered in
  /// `context`; otherwise this calls `llvm::report_fatal_error`.
  template <typename... Models>
  static void attachInterface(MLIRContext &context) {
    Optional<RegisteredOperationName> info = RegisteredOperationName::lookup(
        ConcreteType::getOperationName(), &context);
    if (!info)
      llvm::report_fatal_error(
          "Attempting to attach an interface to an unregistered operation " +
          ConcreteType::getOperationName() + ".");
    info->attachInterface<Models...>();
  }

private:
  /// Trait to check if T provides a 'fold' method for a single result op, i.e.
  /// `fold(ArrayRef<Attribute>)`. (The `Args` pack is unused.)
  template <typename T, typename... Args>
  using has_single_result_fold =
      decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
  template <typename T>
  using detect_has_single_result_fold =
      llvm::is_detected<has_single_result_fold, T>;
  /// Trait to check if T provides a general 'fold' method, i.e.
  /// `fold(ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &)`.
  template <typename T, typename... Args>
  using has_fold = decltype(std::declval<T>().fold(
      std::declval<ArrayRef<Attribute>>(),
      std::declval<SmallVectorImpl<OpFoldResult> &>()));
  template <typename T>
  using detect_has_fold = llvm::is_detected<has_fold, T>;
  /// Trait to check if T provides a 'print(OpAsmPrinter &)' method.
  template <typename T, typename... Args>
  using has_print =
      decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>()));
  template <typename T>
  using detect_has_print = llvm::is_detected<has_print, T>;
  /// A tuple type containing the traits that have a `foldTrait` function.
  using FoldableTraitsTupleT = typename detail::FilterTypes<
      op_definition_impl::detect_has_any_fold_trait,
      Traits<ConcreteType>...>::type;
  /// A tuple type containing the traits that have a verify function.
  using VerifiableTraitsTupleT =
      typename detail::FilterTypes<op_definition_impl::detect_has_verify_trait,
                                   Traits<ConcreteType>...>::type;

  /// Returns an interface map containing the interfaces registered to this
  /// operation, built from all of its `Traits<ConcreteType>`.
  static detail::InterfaceMap getInterfaceMap() {
    return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
  }

  /// Return the internal implementations of each of the OperationName
  /// hooks.
  /// Implementation of `FoldHookFn` OperationName hook. The three
  /// `getFoldHookFnImpl` overloads below are selected by SFINAE, and their
  /// conditions are mutually exclusive.
  /// NOTE(review): none of them is viable for (a) an op with the `OneResult`
  /// trait that only defines the general `fold(operands, results)` form, or
  /// (b) an op without `OneResult` that only defines the single-result
  /// `fold(operands)` form. Such ops fail to compile here.
  static OperationName::FoldHookFn getFoldHookFn() {
    return getFoldHookFnImpl<ConcreteType>();
  }
  /// The internal implementation of `getFoldHookFn` above that is invoked if
  /// the operation is single result and defines a `fold` method.
  template <typename ConcreteOpT>
  static std::enable_if_t<llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
                                          Traits<ConcreteOpT>...>::value &&
                              detect_has_single_result_fold<ConcreteOpT>::value,
                          OperationName::FoldHookFn>
  getFoldHookFnImpl() {
    return [](Operation *op, ArrayRef<Attribute> operands,
              SmallVectorImpl<OpFoldResult> &results) {
      return foldSingleResultHook<ConcreteOpT>(op, operands, results);
    };
  }
  /// The internal implementation of `getFoldHookFn` above that is invoked if
  /// the operation is not single result and defines a `fold` method.
  template <typename ConcreteOpT>
  static std::enable_if_t<!llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
                                           Traits<ConcreteOpT>...>::value &&
                              detect_has_fold<ConcreteOpT>::value,
                          OperationName::FoldHookFn>
  getFoldHookFnImpl() {
    return [](Operation *op, ArrayRef<Attribute> operands,
              SmallVectorImpl<OpFoldResult> &results) {
      return foldHook<ConcreteOpT>(op, operands, results);
    };
  }
  /// The internal implementation of `getFoldHookFn` above that is invoked if
  /// the operation does not define a `fold` method.
  template <typename ConcreteOpT>
  static std::enable_if_t<!detect_has_single_result_fold<ConcreteOpT>::value &&
                              !detect_has_fold<ConcreteOpT>::value,
                          OperationName::FoldHookFn>
  getFoldHookFnImpl() {
    return [](Operation *op, ArrayRef<Attribute> operands,
              SmallVectorImpl<OpFoldResult> &results) {
      // In this case, we only need to fold the traits of the operation.
      return op_definition_impl::foldTraits<FoldableTraitsTupleT>(op, operands,
                                                                  results);
    };
  }
  /// Return the result of folding a single result operation that defines a
  /// `fold` method. A non-null folded value that differs from the op's own
  /// result is appended to `results`. If `fold` returned null, or returned the
  /// op's own result (an in-place fold), the traits' folders are tried
  /// instead. The hook then succeeds if they succeed or if the fold was in
  /// place.
  template <typename ConcreteOpT>
  static LogicalResult
  foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
                       SmallVectorImpl<OpFoldResult> &results) {
    OpFoldResult result = cast<ConcreteOpT>(op).fold(operands);

    // If the fold failed or was in-place, try to fold the traits of the
    // operation.
    if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
      if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
              op, operands, results)))
        return success();
      // Non-null here means the op folded in place, which counts as success.
      return success(static_cast<bool>(result));
    }
    results.push_back(result);
    return success();
  }
  /// Return the result of folding an operation that defines a `fold` method.
  /// If `fold` fails, or succeeds without producing any results (an in-place
  /// fold), the traits' folders are tried as well. Their success takes
  /// precedence; otherwise the status returned by `fold` is propagated.
  template <typename ConcreteOpT>
  static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
                                SmallVectorImpl<OpFoldResult> &results) {
    LogicalResult result = cast<ConcreteOpT>(op).fold(operands, results);

    // If the fold failed or was in-place, try to fold the traits of the
    // operation.
    if (failed(result) || results.empty()) {
      if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
              op, operands, results)))
        return success();
    }
    return result;
  }

  /// Implementation of `GetCanonicalizationPatternsFn` OperationName hook.
  static OperationName::GetCanonicalizationPatternsFn
  getGetCanonicalizationPatternsFn() {
    return &ConcreteType::getCanonicalizationPatterns;
  }
  /// Implementation of `HasTraitFn` OperationName hook: checks a trait TypeID
  /// against `Traits...` via `op_definition_impl::hasTrait`.
  static OperationName::HasTraitFn getHasTraitFn() {
    return
        [](TypeID id) { return op_definition_impl::hasTrait<Traits...>(id); };
  }
  /// Implementation of `ParseAssemblyFn` OperationName hook.
  static OperationName::ParseAssemblyFn getParseAssemblyFn() {
    return &ConcreteType::parse;
  }
  /// Implementation of `PrintAssemblyFn` OperationName hook.
  static OperationName::PrintAssemblyFn getPrintAssemblyFn() {
    return getPrintAssemblyFnImpl<ConcreteType>();
  }
  /// The internal implementation of `getPrintAssemblyFn` that is invoked when
  /// the concrete operation does not define a `print` method. Falls back to
  /// `OpState::print`; `defaultDialect` is unused in this case.
  template <typename ConcreteOpT>
  static std::enable_if_t<!detect_has_print<ConcreteOpT>::value,
                          OperationName::PrintAssemblyFn>
  getPrintAssemblyFnImpl() {
    return [](Operation *op, OpAsmPrinter &printer, StringRef defaultDialect) {
      return OpState::print(op, printer);
    };
  }
  /// The internal implementation of `getPrintAssemblyFn` that is invoked when
  /// the concrete operation defines a `print` method.
  template <typename ConcreteOpT>
  static std::enable_if_t<detect_has_print<ConcreteOpT>::value,
                          OperationName::PrintAssemblyFn>
  getPrintAssemblyFnImpl() {
    return &printAssembly;
  }
  /// Print the op name (via `OpState::printOpName`, which receives
  /// `defaultDialect`), then delegate to the op's own `print` method.
  static void printAssembly(Operation *op, OpAsmPrinter &p,
                            StringRef defaultDialect) {
    OpState::printOpName(op, p, defaultDialect);
    return cast<ConcreteType>(op).print(p);
  }
  /// Implementation of `VerifyInvariantsFn` OperationName hook.
  static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
    return &verifyInvariants;
  }

  /// Return true if `ConcreteType` has the same size as an otherwise empty op
  /// class with the same traits, i.e. it declares no extra data members.
  static constexpr bool hasNoDataMembers() {
    // Checking that the derived class does not define any member by comparing
    // its size to an ad-hoc EmptyOp.
    class EmptyOp : public Op<EmptyOp, Traits...> {};
    return sizeof(ConcreteType) == sizeof(EmptyOp);
  }

  /// Verify `op`. All traits that provide a verifier are checked first, then
  /// `ConcreteType::verify`. Because `||` short-circuits, the op's own `verify`
  /// is only called if every trait verifier succeeded.
  static LogicalResult verifyInvariants(Operation *op) {
    static_assert(hasNoDataMembers(),
                  "Op class shouldn't define new data members");
    return failure(
        failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)) ||
        failed(cast<ConcreteType>(op).verify()));
  }

  /// Allow access to internal implementation methods.
  friend RegisteredOperationName;
};
| |
| /// This class represents the base of an operation interface. See the definition |
| /// of `detail::Interface` for requirements on the `Traits` type. |
| template <typename ConcreteType, typename Traits> |
| class OpInterface |
| : public detail::Interface<ConcreteType, Operation *, Traits, |
| Op<ConcreteType>, OpTrait::TraitBase> { |
| public: |
| using Base = OpInterface<ConcreteType, Traits>; |
| using InterfaceBase = detail::Interface<ConcreteType, Operation *, Traits, |
| Op<ConcreteType>, OpTrait::TraitBase>; |
| |
| /// Inherit the base class constructor. |
| using InterfaceBase::InterfaceBase; |
| |
| protected: |
| /// Returns the impl interface instance for the given operation. |
| static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) { |
| OperationName name = op->getName(); |
| |
| // Access the raw interface from the operation info. |
| if (Optional<RegisteredOperationName> rInfo = name.getRegisteredInfo()) { |
| if (auto *opIface = rInfo->getInterface<ConcreteType>()) |
| return opIface; |
| // Fallback to the dialect to provide it with a chance to implement this |
| // interface for this operation. |
| return rInfo->getDialect().getRegisteredInterfaceForOp<ConcreteType>( |
| op->getName()); |
| } |
| // Fallback to the dialect to provide it with a chance to implement this |
| // interface for this operation. |
| if (Dialect *dialect = name.getDialect()) |
| return dialect->getRegisteredInterfaceForOp<ConcreteType>(name); |
| return nullptr; |
| } |
| |
| /// Allow access to `getInterfaceFor`. |
| friend InterfaceBase; |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // Common Operation Folders/Parsers/Printers |
| //===----------------------------------------------------------------------===// |
| |
// These functions are out-of-line implementations of the methods in UnaryOp and
// BinaryOp, which avoids them being template instantiated/duplicated.
namespace impl {
// Custom-assembly parser for ops with one result and a single operand type
// (presumably the UnaryOp form; confirm against the implementation).
ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser,
                                           OperationState &result);

// Builder that populates `result` from the two operands `lhs` and `rhs`.
void buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs,
                   Value rhs);
// Custom-assembly parser for ops with one result whose operands share a type
// (presumably the BinaryOp form; confirm against the implementation).
ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
                                            OperationState &result);

// Prints the given binary `op` in custom assembly form if both operands and
// the result have the same type. Otherwise, prints the generic assembly form.
void printOneResultOp(Operation *op, OpAsmPrinter &p);
} // namespace impl
| |
// These functions are out-of-line implementations of the methods in
// CastOpInterface, which avoids them being template instantiated/duplicated.
namespace impl {
/// Attempt to fold the given cast operation. `attrOperands` are the constant
/// operand attributes; any folded values are appended to `foldResults`.
LogicalResult foldCastInterfaceOp(Operation *op,
                                  ArrayRef<Attribute> attrOperands,
                                  SmallVectorImpl<OpFoldResult> &foldResults);
/// Attempt to verify the given cast operation. `areCastCompatible` is the
/// type-compatibility predicate (presumably called with the op's input and
/// output types; confirm against the implementation).
LogicalResult verifyCastInterfaceOp(
    Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible);

// TODO: Remove the parse/print/build here (new ODS functionality obsoletes the
// need for them, but some older ODS code in `std` still depends on them).
void buildCastOp(OpBuilder &builder, OperationState &result, Value source,
                 Type destType);
ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
void printCastOp(Operation *op, OpAsmPrinter &p);
// TODO: These methods are deprecated in favor of CastOpInterface. Remove them
// when all uses have been updated. Also, consider adding functionality to
// CastOpInterface to be able to perform the ChainedTensorCast canonicalization
// generically.
Value foldCastOp(Operation *op);
// Single-type variant of `verifyCastInterfaceOp`: `areCastCompatible` compares
// one source type with one destination type.
LogicalResult verifyCastOp(Operation *op,
                           function_ref<bool(Type, Type)> areCastCompatible);
} // namespace impl
| } // end namespace mlir |
| |
| namespace llvm { |
| |
| template <typename T> |
| struct DenseMapInfo< |
| T, std::enable_if_t<std::is_base_of<mlir::OpState, T>::value>> { |
| static inline T getEmptyKey() { |
| auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); |
| return T::getFromOpaquePointer(pointer); |
| } |
| static inline T getTombstoneKey() { |
| auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey(); |
| return T::getFromOpaquePointer(pointer); |
| } |
| static unsigned getHashValue(T val) { |
| return hash_value(val.getAsOpaquePointer()); |
| } |
| static bool isEqual(T lhs, T rhs) { return lhs == rhs; } |
| }; |
| |
| } // end namespace llvm |
| |
| #endif |