blob: 1ed14255bf5e0c71c0a4570fd139cac479afca3b [file] [log] [blame] [edit]
//===- Transform.cpp - C Interface for Transform dialect ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/Support.h"
#include "mlir/CAPI/Dialect/Transform.h"
#include "mlir/CAPI/Interfaces.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/CAPI/Rewrite.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Transform, transform,
transform::TransformDialect)
//===---------------------------------------------------------------------===//
// AnyOpType
//===---------------------------------------------------------------------===//
bool mlirTypeIsATransformAnyOpType(MlirType type) {
return isa<transform::AnyOpType>(unwrap(type));
}
MlirTypeID mlirTransformAnyOpTypeGetTypeID(void) {
return wrap(transform::AnyOpType::getTypeID());
}
MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
return wrap(transform::AnyOpType::get(unwrap(ctx)));
}
MlirStringRef mlirTransformAnyOpTypeGetName(void) {
return wrap(transform::AnyOpType::name);
}
//===---------------------------------------------------------------------===//
// AnyParamType
//===---------------------------------------------------------------------===//
bool mlirTypeIsATransformAnyParamType(MlirType type) {
return isa<transform::AnyParamType>(unwrap(type));
}
MlirTypeID mlirTransformAnyParamTypeGetTypeID(void) {
return wrap(transform::AnyParamType::getTypeID());
}
MlirType mlirTransformAnyParamTypeGet(MlirContext ctx) {
return wrap(transform::AnyParamType::get(unwrap(ctx)));
}
MlirStringRef mlirTransformAnyParamTypeGetName(void) {
return wrap(transform::AnyParamType::name);
}
//===---------------------------------------------------------------------===//
// AnyValueType
//===---------------------------------------------------------------------===//
bool mlirTypeIsATransformAnyValueType(MlirType type) {
return isa<transform::AnyValueType>(unwrap(type));
}
MlirTypeID mlirTransformAnyValueTypeGetTypeID(void) {
return wrap(transform::AnyValueType::getTypeID());
}
MlirType mlirTransformAnyValueTypeGet(MlirContext ctx) {
return wrap(transform::AnyValueType::get(unwrap(ctx)));
}
MlirStringRef mlirTransformAnyValueTypeGetName(void) {
return wrap(transform::AnyValueType::name);
}
//===---------------------------------------------------------------------===//
// OperationType
//===---------------------------------------------------------------------===//
bool mlirTypeIsATransformOperationType(MlirType type) {
return isa<transform::OperationType>(unwrap(type));
}
MlirTypeID mlirTransformOperationTypeGetTypeID(void) {
return wrap(transform::OperationType::getTypeID());
}
MlirType mlirTransformOperationTypeGet(MlirContext ctx,
MlirStringRef operationName) {
return wrap(
transform::OperationType::get(unwrap(ctx), unwrap(operationName)));
}
MlirStringRef mlirTransformOperationTypeGetName(void) {
return wrap(transform::OperationType::name);
}
MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
return wrap(cast<transform::OperationType>(unwrap(type)).getOperationName());
}
//===---------------------------------------------------------------------===//
// ParamType
//===---------------------------------------------------------------------===//
bool mlirTypeIsATransformParamType(MlirType type) {
return isa<transform::ParamType>(unwrap(type));
}
MlirTypeID mlirTransformParamTypeGetTypeID(void) {
return wrap(transform::ParamType::getTypeID());
}
MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type) {
return wrap(transform::ParamType::get(unwrap(ctx), unwrap(type)));
}
MlirStringRef mlirTransformParamTypeGetName(void) {
return wrap(transform::ParamType::name);
}
MlirType mlirTransformParamTypeGetType(MlirType type) {
return wrap(cast<transform::ParamType>(unwrap(type)).getType());
}
//===---------------------------------------------------------------------===//
// TransformRewriter
//===---------------------------------------------------------------------===//
/// Casts a `MlirTransformRewriter` to a `MlirRewriterBase`.
MlirRewriterBase mlirTransformRewriterAsBase(MlirTransformRewriter rewriter) {
mlir::transform::TransformRewriter *t = unwrap(rewriter);
mlir::RewriterBase *base = static_cast<mlir::RewriterBase *>(t);
return wrap(base);
}
//===---------------------------------------------------------------------===//
// TransformResults
//===---------------------------------------------------------------------===//
void mlirTransformResultsSetOps(MlirTransformResults results, MlirValue result,
intptr_t numOps, MlirOperation *ops) {
SmallVector<Operation *> opsVec;
opsVec.reserve(numOps);
for (intptr_t i = 0; i < numOps; ++i)
opsVec.push_back(unwrap(ops[i]));
unwrap(results)->set(cast<OpResult>(unwrap(result)), opsVec);
}
void mlirTransformResultsSetValues(MlirTransformResults results,
MlirValue result, intptr_t numValues,
MlirValue *values) {
SmallVector<Value> valuesVec;
valuesVec.reserve(numValues);
for (intptr_t i = 0; i < numValues; ++i)
valuesVec.push_back(unwrap(values[i]));
unwrap(results)->setValues(cast<OpResult>(unwrap(result)), valuesVec);
}
void mlirTransformResultsSetParams(MlirTransformResults results,
MlirValue result, intptr_t numParams,
MlirAttribute *params) {
SmallVector<Attribute> paramsVec;
paramsVec.reserve(numParams);
for (intptr_t i = 0; i < numParams; ++i)
paramsVec.push_back(unwrap(params[i]));
unwrap(results)->setParams(cast<OpResult>(unwrap(result)), paramsVec);
}
//===---------------------------------------------------------------------===//
// TransformState
//===---------------------------------------------------------------------===//
void mlirTransformStateForEachPayloadOp(MlirTransformState state,
MlirValue value,
MlirOperationCallback callback,
void *userData) {
for (Operation *op : unwrap(state)->getPayloadOps(unwrap(value)))
callback(wrap(op), userData);
}
void mlirTransformStateForEachPayloadValue(MlirTransformState state,
MlirValue value,
MlirValueCallback callback,
void *userData) {
for (Value val : unwrap(state)->getPayloadValues(unwrap(value)))
callback(wrap(val), userData);
}
void mlirTransformStateForEachParam(MlirTransformState state, MlirValue value,
MlirAttributeCallback callback,
void *userData) {
for (Attribute attr : unwrap(state)->getParams(unwrap(value)))
callback(wrap(attr), userData);
}
//===---------------------------------------------------------------------===//
// TransformOpInterface
//===---------------------------------------------------------------------===//
MlirTypeID mlirTransformOpInterfaceTypeID(void) {
return wrap(transform::TransformOpInterface::getInterfaceID());
}
/// Fallback model for the TransformOpInterface that uses C API callbacks.
class TransformOpInterfaceFallbackModel
: public mlir::transform::TransformOpInterface::FallbackModel<
TransformOpInterfaceFallbackModel> {
public:
/// Sets the callbacks that this FallbackModel will use.
/// NB: the callbacks can only be set through this method as the
/// RegisteredOperationName::attachInterface mechanism default-constructs
/// the FallbackModel without being able to provide arguments.
void setCallbacks(MlirTransformOpInterfaceCallbacks callbacks) {
this->callbacks = callbacks;
}
~TransformOpInterfaceFallbackModel() {
if (callbacks.destruct)
callbacks.destruct(callbacks.userData);
}
static TypeID getInterfaceID() {
return transform::TransformOpInterface::getInterfaceID();
}
static bool classof(const mlir::transform::detail::
TransformOpInterfaceInterfaceTraits::Concept *op) {
// Enable casting back to the FallbackModel from the Interface. This is
// necessary as attachInterface(...) default-constructs the FallbackModel
// without being able to pass in the callbacks and returns just the Concept.
return true;
}
::mlir::DiagnosedSilenceableFailure
apply(Operation *op, ::mlir::transform::TransformRewriter &rewriter,
::mlir::transform::TransformResults &transformResults,
::mlir::transform::TransformState &state) const {
assert(callbacks.apply && "apply callback not set");
MlirDiagnosedSilenceableFailure status =
callbacks.apply(wrap(op), wrap(&rewriter), wrap(&transformResults),
wrap(&state), callbacks.userData);
switch (status) {
case MlirDiagnosedSilenceableFailureSuccess:
return DiagnosedSilenceableFailure::success();
case MlirDiagnosedSilenceableFailureSilenceableFailure:
// TODO: enable passing diagnostic info from C API to C++ API.
return DiagnosedSilenceableFailure::silenceableFailure(std::move(
*(op->emitError()
<< "TransformOpInterfaceFallbackModel: silenceable failure")
.getUnderlyingDiagnostic()));
case MlirDiagnosedSilenceableFailureDefiniteFailure:
return DiagnosedSilenceableFailure::definiteFailure();
}
llvm_unreachable("unknown transform status");
}
bool allowsRepeatedHandleOperands(Operation *op) const {
assert(callbacks.allowsRepeatedHandleOperands &&
"allowsRepeatedHandleOperands callback not set");
return callbacks.allowsRepeatedHandleOperands(wrap(op), callbacks.userData);
}
private:
MlirTransformOpInterfaceCallbacks callbacks;
};
/// Attach a TransformOpInterface FallbackModel to the given named operation.
/// The FallbackModel uses the provided callbacks to implement the interface.
void mlirTransformOpInterfaceAttachFallbackModel(
MlirContext ctx, MlirStringRef opName,
MlirTransformOpInterfaceCallbacks callbacks) {
// Look up the operation definition in the context.
std::optional<RegisteredOperationName> opInfo =
RegisteredOperationName::lookup(unwrap(opName), unwrap(ctx));
assert(opInfo.has_value() && "operation not found in context");
// NB: the following default-constructs the FallbackModel _without_ being able
// to provide arguments.
opInfo->attachInterface<TransformOpInterfaceFallbackModel>();
// Cast to get the underlying FallbackModel and set the callbacks.
auto *model = cast<TransformOpInterfaceFallbackModel>(
opInfo->getInterface<TransformOpInterfaceFallbackModel>());
assert(model && "Failed to get TransformOpInterfaceFallbackModel");
model->setCallbacks(callbacks);
}
//===---------------------------------------------------------------------===//
// PatternDescriptorOpInterface
//===---------------------------------------------------------------------===//
MlirTypeID mlirPatternDescriptorOpInterfaceTypeID(void) {
return wrap(transform::PatternDescriptorOpInterface::getInterfaceID());
}
/// Fallback model for the PatternDescriptorOpInterface that uses C API
/// callbacks.
class PatternDescriptorOpInterfaceFallbackModel
: public mlir::transform::PatternDescriptorOpInterface::FallbackModel<
PatternDescriptorOpInterfaceFallbackModel> {
public:
/// Sets the callbacks that this FallbackModel will use.
/// NB: the callbacks can only be set through this method as the
/// RegisteredOperationName::attachInterface mechanism default-constructs
/// the FallbackModel without being able to provide arguments.
void setCallbacks(MlirPatternDescriptorOpInterfaceCallbacks callbacks) {
this->callbacks = callbacks;
}
~PatternDescriptorOpInterfaceFallbackModel() {
if (callbacks.destruct)
callbacks.destruct(callbacks.userData);
}
static TypeID getInterfaceID() {
return transform::PatternDescriptorOpInterface::getInterfaceID();
}
static bool
classof(const mlir::transform::detail::
PatternDescriptorOpInterfaceInterfaceTraits::Concept *op) {
// Enable casting back to the FallbackModel from the Interface. This is
// necessary as attachInterface(...) default-constructs the FallbackModel
// without being able to pass in the callbacks and returns just the Concept.
return true;
}
void populatePatterns(Operation *op, RewritePatternSet &patterns) const {
assert(callbacks.populatePatterns && "populatePatterns callback not set");
callbacks.populatePatterns(wrap(op), wrap(&patterns), callbacks.userData);
}
void populatePatternsWithState(Operation *op, RewritePatternSet &patterns,
transform::TransformState &state) const {
if (callbacks.populatePatternsWithState) {
callbacks.populatePatternsWithState(wrap(op), wrap(&patterns),
wrap(&state), callbacks.userData);
} else {
// Default implementation: call populatePatterns without state.
populatePatterns(op, patterns);
}
}
private:
MlirPatternDescriptorOpInterfaceCallbacks callbacks;
};
/// Attach a PatternDescriptorOpInterface FallbackModel to the given named
/// operation. The FallbackModel uses the provided callbacks to implement the
/// interface.
void mlirPatternDescriptorOpInterfaceAttachFallbackModel(
MlirContext ctx, MlirStringRef opName,
MlirPatternDescriptorOpInterfaceCallbacks callbacks) {
// Look up the operation definition in the context.
std::optional<RegisteredOperationName> opInfo =
RegisteredOperationName::lookup(unwrap(opName), unwrap(ctx));
assert(opInfo.has_value() && "operation not found in context");
// NB: the following default-constructs the FallbackModel _without_ being able
// to provide arguments.
opInfo->attachInterface<PatternDescriptorOpInterfaceFallbackModel>();
// Cast to get the underlying FallbackModel and set the callbacks.
auto *model = cast<PatternDescriptorOpInterfaceFallbackModel>(
opInfo->getInterface<PatternDescriptorOpInterfaceFallbackModel>());
assert(model && "Failed to get PatternDescriptorOpInterfaceFallbackModel");
model->setCallbacks(callbacks);
}
//===---------------------------------------------------------------------===//
// MemoryEffectsOpInterface helpers
//===---------------------------------------------------------------------===//
/// Set the effect for the operands to only read the transform handles.
void mlirTransformOnlyReadsHandle(MlirOpOperand *operands, intptr_t numOperands,
MlirMemoryEffectInstancesList effects) {
MutableArrayRef<OpOperand> operandArray(unwrap(*operands), numOperands);
transform::onlyReadsHandle(operandArray, *unwrap(effects));
}
/// Set the effect for the operands to consuming the transform handles.
void mlirTransformConsumesHandle(MlirOpOperand *operands, intptr_t numOperands,
MlirMemoryEffectInstancesList effects) {
MutableArrayRef<OpOperand> operandArray(unwrap(*operands), numOperands);
transform::consumesHandle(operandArray, *unwrap(effects));
}
/// Set the effect for the results to that they produce transform handles.
void mlirTransformProducesHandle(MlirValue *results, intptr_t numResults,
MlirMemoryEffectInstancesList effects) {
// NB: calling `producesHandle()` `numResults` as we cannot cast array of
// `OpResult`s to a single `ResultRange` (and neither is `ResultRange` exposed
// to Python). `producesHandle` iterates over the given `ResultRange` anyway.
SmallVectorImpl<MemoryEffects::EffectInstance> &effectList = *unwrap(effects);
for (intptr_t i = 0; i < numResults; ++i) {
auto opResult = cast<OpResult>(unwrap(results[i]));
transform::producesHandle(ResultRange(opResult), effectList);
}
}
/// Set the effect of potentially modifying payload IR.
void mlirTransformModifiesPayload(MlirMemoryEffectInstancesList effects) {
transform::modifiesPayload(*unwrap(effects));
}
/// Set the effect of potentially reading payload IR.
void mlirTransformOnlyReadsPayload(MlirMemoryEffectInstancesList effects) {
transform::onlyReadsPayload(*unwrap(effects));
}