| //===- Rewrite.cpp - C API for Rewrite Patterns ---------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
#include "mlir-c/Rewrite.h"

#include "mlir-c/Transforms.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Rewrite.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Wrap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/PDLPatternMatch.h.inc"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/Support/ErrorHandling.h"
| |
| using namespace mlir; |
| |
| //===----------------------------------------------------------------------===// |
| /// RewriterBase API inherited from OpBuilder |
| //===----------------------------------------------------------------------===// |
| |
| MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) { |
| return wrap(unwrap(rewriter)->getContext()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
/// Insertion point methods
| //===----------------------------------------------------------------------===// |
| |
| void mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter) { |
| unwrap(rewriter)->clearInsertionPoint(); |
| } |
| |
| void mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter, |
| MlirOperation op) { |
| unwrap(rewriter)->setInsertionPoint(unwrap(op)); |
| } |
| |
| void mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter, |
| MlirOperation op) { |
| unwrap(rewriter)->setInsertionPointAfter(unwrap(op)); |
| } |
| |
| void mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter, |
| MlirValue value) { |
| unwrap(rewriter)->setInsertionPointAfterValue(unwrap(value)); |
| } |
| |
| void mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter, |
| MlirBlock block) { |
| unwrap(rewriter)->setInsertionPointToStart(unwrap(block)); |
| } |
| |
| void mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter, |
| MlirBlock block) { |
| unwrap(rewriter)->setInsertionPointToEnd(unwrap(block)); |
| } |
| |
| MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter) { |
| return wrap(unwrap(rewriter)->getInsertionBlock()); |
| } |
| |
| MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) { |
| return wrap(unwrap(rewriter)->getBlock()); |
| } |
| |
| MlirOperation |
| mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter) { |
| mlir::RewriterBase *base = unwrap(rewriter); |
| mlir::Block *block = base->getInsertionBlock(); |
| mlir::Block::iterator it = base->getInsertionPoint(); |
| if (it == block->end()) |
| return {nullptr}; |
| |
| return wrap(std::addressof(*it)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| /// Block and operation creation/insertion/cloning |
| //===----------------------------------------------------------------------===// |
| |
| MlirBlock mlirRewriterBaseCreateBlockBefore(MlirRewriterBase rewriter, |
| MlirBlock insertBefore, |
| intptr_t nArgTypes, |
| MlirType const *argTypes, |
| MlirLocation const *locations) { |
| SmallVector<Type, 4> args; |
| ArrayRef<Type> unwrappedArgs = unwrapList(nArgTypes, argTypes, args); |
| SmallVector<Location, 4> locs; |
| ArrayRef<Location> unwrappedLocs = unwrapList(nArgTypes, locations, locs); |
| return wrap(unwrap(rewriter)->createBlock(unwrap(insertBefore), unwrappedArgs, |
| unwrappedLocs)); |
| } |
| |
| MlirOperation mlirRewriterBaseInsert(MlirRewriterBase rewriter, |
| MlirOperation op) { |
| return wrap(unwrap(rewriter)->insert(unwrap(op))); |
| } |
| |
| // Other methods of OpBuilder |
| |
| MlirOperation mlirRewriterBaseClone(MlirRewriterBase rewriter, |
| MlirOperation op) { |
| return wrap(unwrap(rewriter)->clone(*unwrap(op))); |
| } |
| |
| MlirOperation mlirRewriterBaseCloneWithoutRegions(MlirRewriterBase rewriter, |
| MlirOperation op) { |
| return wrap(unwrap(rewriter)->cloneWithoutRegions(*unwrap(op))); |
| } |
| |
| void mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter, |
| MlirRegion region, MlirBlock before) { |
| |
| unwrap(rewriter)->cloneRegionBefore(*unwrap(region), unwrap(before)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| /// RewriterBase API |
| //===----------------------------------------------------------------------===// |
| |
| void mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter, |
| MlirRegion region, MlirBlock before) { |
| unwrap(rewriter)->inlineRegionBefore(*unwrap(region), unwrap(before)); |
| } |
| |
| void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter, |
| MlirOperation op, intptr_t nValues, |
| MlirValue const *values) { |
| SmallVector<Value, 4> vals; |
| ArrayRef<Value> unwrappedVals = unwrapList(nValues, values, vals); |
| unwrap(rewriter)->replaceOp(unwrap(op), unwrappedVals); |
| } |
| |
| void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter, |
| MlirOperation op, |
| MlirOperation newOp) { |
| unwrap(rewriter)->replaceOp(unwrap(op), unwrap(newOp)); |
| } |
| |
| void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op) { |
| unwrap(rewriter)->eraseOp(unwrap(op)); |
| } |
| |
| void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter, MlirBlock block) { |
| unwrap(rewriter)->eraseBlock(unwrap(block)); |
| } |
| |
| void mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter, |
| MlirBlock source, MlirOperation op, |
| intptr_t nArgValues, |
| MlirValue const *argValues) { |
| SmallVector<Value, 4> vals; |
| ArrayRef<Value> unwrappedVals = unwrapList(nArgValues, argValues, vals); |
| |
| unwrap(rewriter)->inlineBlockBefore(unwrap(source), unwrap(op), |
| unwrappedVals); |
| } |
| |
| void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter, MlirBlock source, |
| MlirBlock dest, intptr_t nArgValues, |
| MlirValue const *argValues) { |
| SmallVector<Value, 4> args; |
| ArrayRef<Value> unwrappedArgs = unwrapList(nArgValues, argValues, args); |
| unwrap(rewriter)->mergeBlocks(unwrap(source), unwrap(dest), unwrappedArgs); |
| } |
| |
| void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter, MlirOperation op, |
| MlirOperation existingOp) { |
| unwrap(rewriter)->moveOpBefore(unwrap(op), unwrap(existingOp)); |
| } |
| |
| void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter, MlirOperation op, |
| MlirOperation existingOp) { |
| unwrap(rewriter)->moveOpAfter(unwrap(op), unwrap(existingOp)); |
| } |
| |
| void mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block, |
| MlirBlock existingBlock) { |
| unwrap(rewriter)->moveBlockBefore(unwrap(block), unwrap(existingBlock)); |
| } |
| |
| void mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter, |
| MlirOperation op) { |
| unwrap(rewriter)->startOpModification(unwrap(op)); |
| } |
| |
| void mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter, |
| MlirOperation op) { |
| unwrap(rewriter)->finalizeOpModification(unwrap(op)); |
| } |
| |
| void mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter, |
| MlirOperation op) { |
| unwrap(rewriter)->cancelOpModification(unwrap(op)); |
| } |
| |
| void mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter, |
| MlirValue from, MlirValue to) { |
| unwrap(rewriter)->replaceAllUsesWith(unwrap(from), unwrap(to)); |
| } |
| |
| void mlirRewriterBaseReplaceAllValueRangeUsesWith(MlirRewriterBase rewriter, |
| intptr_t nValues, |
| MlirValue const *from, |
| MlirValue const *to) { |
| SmallVector<Value, 4> fromVals; |
| ArrayRef<Value> unwrappedFromVals = unwrapList(nValues, from, fromVals); |
| SmallVector<Value, 4> toVals; |
| ArrayRef<Value> unwrappedToVals = unwrapList(nValues, to, toVals); |
| unwrap(rewriter)->replaceAllUsesWith(unwrappedFromVals, unwrappedToVals); |
| } |
| |
| void mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter, |
| MlirOperation from, |
| intptr_t nTo, |
| MlirValue const *to) { |
| SmallVector<Value, 4> toVals; |
| ArrayRef<Value> unwrappedToVals = unwrapList(nTo, to, toVals); |
| unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), unwrappedToVals); |
| } |
| |
| void mlirRewriterBaseReplaceAllOpUsesWithOperation(MlirRewriterBase rewriter, |
| MlirOperation from, |
| MlirOperation to) { |
| unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), unwrap(to)); |
| } |
| |
| void mlirRewriterBaseReplaceOpUsesWithinBlock(MlirRewriterBase rewriter, |
| MlirOperation op, |
| intptr_t nNewValues, |
| MlirValue const *newValues, |
| MlirBlock block) { |
| SmallVector<Value, 4> vals; |
| ArrayRef<Value> unwrappedVals = unwrapList(nNewValues, newValues, vals); |
| unwrap(rewriter)->replaceOpUsesWithinBlock(unwrap(op), unwrappedVals, |
| unwrap(block)); |
| } |
| |
| void mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter, |
| MlirValue from, MlirValue to, |
| MlirOperation exceptedUser) { |
| unwrap(rewriter)->replaceAllUsesExcept(unwrap(from), unwrap(to), |
| unwrap(exceptedUser)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| /// IRRewriter API |
| //===----------------------------------------------------------------------===// |
| |
| MlirRewriterBase mlirIRRewriterCreate(MlirContext context) { |
| return wrap(new IRRewriter(unwrap(context))); |
| } |
| |
| MlirRewriterBase mlirIRRewriterCreateFromOp(MlirOperation op) { |
| return wrap(new IRRewriter(unwrap(op))); |
| } |
| |
| void mlirIRRewriterDestroy(MlirRewriterBase rewriter) { |
| delete static_cast<IRRewriter *>(unwrap(rewriter)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| /// RewritePatternSet and FrozenRewritePatternSet API |
| //===----------------------------------------------------------------------===// |
| |
| MlirFrozenRewritePatternSet |
| mlirFreezeRewritePattern(MlirRewritePatternSet set) { |
| auto *m = new mlir::FrozenRewritePatternSet(std::move(*unwrap(set))); |
| set.ptr = nullptr; |
| return wrap(m); |
| } |
| |
| void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet set) { |
| delete unwrap(set); |
| set.ptr = nullptr; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| /// GreedyRewriteDriverConfig API |
| //===----------------------------------------------------------------------===// |
| |
| inline mlir::GreedyRewriteConfig *unwrap(MlirGreedyRewriteDriverConfig config) { |
| assert(config.ptr && "unexpected null config"); |
| return static_cast<mlir::GreedyRewriteConfig *>(config.ptr); |
| } |
| |
| inline MlirGreedyRewriteDriverConfig wrap(mlir::GreedyRewriteConfig *config) { |
| return {config}; |
| } |
| |
| MlirGreedyRewriteDriverConfig mlirGreedyRewriteDriverConfigCreate() { |
| return wrap(new mlir::GreedyRewriteConfig()); |
| } |
| |
| void mlirGreedyRewriteDriverConfigDestroy( |
| MlirGreedyRewriteDriverConfig config) { |
| delete unwrap(config); |
| } |
| |
| void mlirGreedyRewriteDriverConfigSetMaxIterations( |
| MlirGreedyRewriteDriverConfig config, int64_t maxIterations) { |
| unwrap(config)->setMaxIterations(maxIterations); |
| } |
| |
| void mlirGreedyRewriteDriverConfigSetMaxNumRewrites( |
| MlirGreedyRewriteDriverConfig config, int64_t maxNumRewrites) { |
| unwrap(config)->setMaxNumRewrites(maxNumRewrites); |
| } |
| |
| void mlirGreedyRewriteDriverConfigSetUseTopDownTraversal( |
| MlirGreedyRewriteDriverConfig config, bool useTopDownTraversal) { |
| unwrap(config)->setUseTopDownTraversal(useTopDownTraversal); |
| } |
| |
| void mlirGreedyRewriteDriverConfigEnableFolding( |
| MlirGreedyRewriteDriverConfig config, bool enable) { |
| unwrap(config)->enableFolding(enable); |
| } |
| |
| void mlirGreedyRewriteDriverConfigSetStrictness( |
| MlirGreedyRewriteDriverConfig config, |
| MlirGreedyRewriteStrictness strictness) { |
| mlir::GreedyRewriteStrictness cppStrictness; |
| switch (strictness) { |
| case MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP: |
| cppStrictness = mlir::GreedyRewriteStrictness::AnyOp; |
| break; |
| case MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS: |
| cppStrictness = mlir::GreedyRewriteStrictness::ExistingAndNewOps; |
| break; |
| case MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS: |
| cppStrictness = mlir::GreedyRewriteStrictness::ExistingOps; |
| break; |
| } |
| unwrap(config)->setStrictness(cppStrictness); |
| } |
| |
| void mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel( |
| MlirGreedyRewriteDriverConfig config, MlirGreedySimplifyRegionLevel level) { |
| mlir::GreedySimplifyRegionLevel cppLevel; |
| switch (level) { |
| case MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED: |
| cppLevel = mlir::GreedySimplifyRegionLevel::Disabled; |
| break; |
| case MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL: |
| cppLevel = mlir::GreedySimplifyRegionLevel::Normal; |
| break; |
| case MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE: |
| cppLevel = mlir::GreedySimplifyRegionLevel::Aggressive; |
| break; |
| } |
| unwrap(config)->setRegionSimplificationLevel(cppLevel); |
| } |
| |
| void mlirGreedyRewriteDriverConfigEnableConstantCSE( |
| MlirGreedyRewriteDriverConfig config, bool enable) { |
| unwrap(config)->enableConstantCSE(enable); |
| } |
| |
| int64_t mlirGreedyRewriteDriverConfigGetMaxIterations( |
| MlirGreedyRewriteDriverConfig config) { |
| return unwrap(config)->getMaxIterations(); |
| } |
| |
| int64_t mlirGreedyRewriteDriverConfigGetMaxNumRewrites( |
| MlirGreedyRewriteDriverConfig config) { |
| return unwrap(config)->getMaxNumRewrites(); |
| } |
| |
| bool mlirGreedyRewriteDriverConfigGetUseTopDownTraversal( |
| MlirGreedyRewriteDriverConfig config) { |
| return unwrap(config)->getUseTopDownTraversal(); |
| } |
| |
| bool mlirGreedyRewriteDriverConfigIsFoldingEnabled( |
| MlirGreedyRewriteDriverConfig config) { |
| return unwrap(config)->isFoldingEnabled(); |
| } |
| |
| MlirGreedyRewriteStrictness mlirGreedyRewriteDriverConfigGetStrictness( |
| MlirGreedyRewriteDriverConfig config) { |
| mlir::GreedyRewriteStrictness cppStrictness = unwrap(config)->getStrictness(); |
| switch (cppStrictness) { |
| case mlir::GreedyRewriteStrictness::AnyOp: |
| return MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP; |
| case mlir::GreedyRewriteStrictness::ExistingAndNewOps: |
| return MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS; |
| case mlir::GreedyRewriteStrictness::ExistingOps: |
| return MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS; |
| } |
| } |
| |
| MlirGreedySimplifyRegionLevel |
| mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel( |
| MlirGreedyRewriteDriverConfig config) { |
| mlir::GreedySimplifyRegionLevel cppLevel = |
| unwrap(config)->getRegionSimplificationLevel(); |
| switch (cppLevel) { |
| case mlir::GreedySimplifyRegionLevel::Disabled: |
| return MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED; |
| case mlir::GreedySimplifyRegionLevel::Normal: |
| return MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL; |
| case mlir::GreedySimplifyRegionLevel::Aggressive: |
| return MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE; |
| } |
| } |
| |
| bool mlirGreedyRewriteDriverConfigIsConstantCSEEnabled( |
| MlirGreedyRewriteDriverConfig config) { |
| return unwrap(config)->isConstantCSEEnabled(); |
| } |
| |
| MlirLogicalResult |
| mlirApplyPatternsAndFoldGreedily(MlirModule op, |
| MlirFrozenRewritePatternSet patterns, |
| MlirGreedyRewriteDriverConfig config) { |
| return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns), |
| *unwrap(config))); |
| } |
| |
| MlirLogicalResult |
| mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op, |
| MlirFrozenRewritePatternSet patterns, |
| MlirGreedyRewriteDriverConfig config) { |
| return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns), |
| *unwrap(config))); |
| } |
| |
| void mlirWalkAndApplyPatterns(MlirOperation op, |
| MlirFrozenRewritePatternSet patterns) { |
| mlir::walkAndApplyPatterns(unwrap(op), *unwrap(patterns)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| /// PatternRewriter API |
| //===----------------------------------------------------------------------===// |
| |
| MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) { |
| return wrap(static_cast<mlir::RewriterBase *>(unwrap(rewriter))); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| /// RewritePattern API |
| //===----------------------------------------------------------------------===// |
| |
| namespace mlir { |
| |
| class ExternalRewritePattern : public mlir::RewritePattern { |
| public: |
| ExternalRewritePattern(MlirRewritePatternCallbacks callbacks, void *userData, |
| StringRef rootName, PatternBenefit benefit, |
| MLIRContext *context, |
| ArrayRef<StringRef> generatedNames) |
| : RewritePattern(rootName, benefit, context, generatedNames), |
| callbacks(callbacks), userData(userData) { |
| if (callbacks.construct) |
| callbacks.construct(userData); |
| } |
| |
| ~ExternalRewritePattern() { |
| if (callbacks.destruct) |
| callbacks.destruct(userData); |
| } |
| |
| LogicalResult matchAndRewrite(Operation *op, |
| PatternRewriter &rewriter) const override { |
| return unwrap(callbacks.matchAndRewrite( |
| wrap(static_cast<const mlir::RewritePattern *>(this)), wrap(op), |
| wrap(&rewriter), userData)); |
| } |
| |
| private: |
| MlirRewritePatternCallbacks callbacks; |
| void *userData; |
| }; |
| |
| } // namespace mlir |
| |
| MlirRewritePattern mlirOpRewritePatternCreate( |
| MlirStringRef rootName, unsigned benefit, MlirContext context, |
| MlirRewritePatternCallbacks callbacks, void *userData, |
| size_t nGeneratedNames, MlirStringRef *generatedNames) { |
| std::vector<mlir::StringRef> generatedNamesVec; |
| generatedNamesVec.reserve(nGeneratedNames); |
| for (size_t i = 0; i < nGeneratedNames; ++i) { |
| generatedNamesVec.push_back(unwrap(generatedNames[i])); |
| } |
| return wrap(new mlir::ExternalRewritePattern( |
| callbacks, userData, unwrap(rootName), PatternBenefit(benefit), |
| unwrap(context), generatedNamesVec)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| /// RewritePatternSet API |
| //===----------------------------------------------------------------------===// |
| |
| MlirRewritePatternSet mlirRewritePatternSetCreate(MlirContext context) { |
| return wrap(new mlir::RewritePatternSet(unwrap(context))); |
| } |
| |
| void mlirRewritePatternSetDestroy(MlirRewritePatternSet set) { |
| delete unwrap(set); |
| } |
| |
| void mlirRewritePatternSetAdd(MlirRewritePatternSet set, |
| MlirRewritePattern pattern) { |
| std::unique_ptr<mlir::RewritePattern> patternPtr( |
| const_cast<mlir::RewritePattern *>(unwrap(pattern))); |
| pattern.ptr = nullptr; |
| unwrap(set)->add(std::move(patternPtr)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| /// PDLPatternModule API |
| //===----------------------------------------------------------------------===// |
| |
| #if MLIR_ENABLE_PDL_IN_PATTERNMATCH |
| MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) { |
| return wrap(new mlir::PDLPatternModule( |
| mlir::OwningOpRef<mlir::ModuleOp>(unwrap(op)))); |
| } |
| |
| void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op) { |
| delete unwrap(op); |
| op.ptr = nullptr; |
| } |
| |
| MlirRewritePatternSet |
| mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) { |
| auto *m = new mlir::RewritePatternSet(std::move(*unwrap(op))); |
| op.ptr = nullptr; |
| return wrap(m); |
| } |
| |
| MlirValue mlirPDLValueAsValue(MlirPDLValue value) { |
| return wrap(unwrap(value)->dyn_cast<mlir::Value>()); |
| } |
| |
| MlirType mlirPDLValueAsType(MlirPDLValue value) { |
| return wrap(unwrap(value)->dyn_cast<mlir::Type>()); |
| } |
| |
| MlirOperation mlirPDLValueAsOperation(MlirPDLValue value) { |
| return wrap(unwrap(value)->dyn_cast<mlir::Operation *>()); |
| } |
| |
| MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value) { |
| return wrap(unwrap(value)->dyn_cast<mlir::Attribute>()); |
| } |
| |
| void mlirPDLResultListPushBackValue(MlirPDLResultList results, |
| MlirValue value) { |
| unwrap(results)->push_back(unwrap(value)); |
| } |
| |
| void mlirPDLResultListPushBackType(MlirPDLResultList results, MlirType value) { |
| unwrap(results)->push_back(unwrap(value)); |
| } |
| |
| void mlirPDLResultListPushBackOperation(MlirPDLResultList results, |
| MlirOperation value) { |
| unwrap(results)->push_back(unwrap(value)); |
| } |
| |
| void mlirPDLResultListPushBackAttribute(MlirPDLResultList results, |
| MlirAttribute value) { |
| unwrap(results)->push_back(unwrap(value)); |
| } |
| |
| inline std::vector<MlirPDLValue> wrap(ArrayRef<PDLValue> values) { |
| std::vector<MlirPDLValue> mlirValues; |
| mlirValues.reserve(values.size()); |
| for (auto &value : values) { |
| mlirValues.push_back(wrap(&value)); |
| } |
| return mlirValues; |
| } |
| |
| void mlirPDLPatternModuleRegisterRewriteFunction( |
| MlirPDLPatternModule pdlModule, MlirStringRef name, |
| MlirPDLRewriteFunction rewriteFn, void *userData) { |
| unwrap(pdlModule)->registerRewriteFunction( |
| unwrap(name), |
| [userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results, |
| ArrayRef<PDLValue> values) -> LogicalResult { |
| std::vector<MlirPDLValue> mlirValues = wrap(values); |
| return unwrap(rewriteFn(wrap(&rewriter), wrap(&results), |
| mlirValues.size(), mlirValues.data(), |
| userData)); |
| }); |
| } |
| |
| void mlirPDLPatternModuleRegisterConstraintFunction( |
| MlirPDLPatternModule pdlModule, MlirStringRef name, |
| MlirPDLConstraintFunction constraintFn, void *userData) { |
| unwrap(pdlModule)->registerConstraintFunction( |
| unwrap(name), |
| [userData, constraintFn](PatternRewriter &rewriter, |
| PDLResultList &results, |
| ArrayRef<PDLValue> values) -> LogicalResult { |
| std::vector<MlirPDLValue> mlirValues = wrap(values); |
| return unwrap(constraintFn(wrap(&rewriter), wrap(&results), |
| mlirValues.size(), mlirValues.data(), |
| userData)); |
| }); |
| } |
| #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH |