blob: 19906f15ae85f6eacffa9f135117893badcc54ab [file] [log] [blame]
//===- InterpreterPass.cpp - Transform dialect interpreter pass -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Transforms/Passes.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
using namespace mlir;
namespace mlir {
namespace transform {
#define GEN_PASS_DEF_INTERPRETERPASS
#include "mlir/Dialect/Transform/Transforms/Passes.h.inc"
} // namespace transform
} // namespace mlir
/// Locates the operation to use as the root of the transform payload.
///
/// If `tag` is empty, `passRoot` itself is returned without any lookup.
/// Otherwise, `passRoot` is walked in search of an operation carrying a
/// `StringAttr` named `TransformDialect::kTargetTagAttrName` whose value equals
/// `tag`. That operation must be unique:
///   - if no operation carries the tag, an error is emitted on `passRoot`;
///   - if a second tagged operation is encountered, an error is emitted on it,
///     with a note pointing at the first one, and the walk stops.
/// In both failure cases nullptr is returned.
static Operation *findPayloadRoot(Operation *passRoot, StringRef tag) {
  // No tag requested: the pass root doubles as the payload root.
  if (tag.empty())
    return passRoot;

  StringAttr tagName = StringAttr::get(
      passRoot->getContext(), transform::TransformDialect::kTargetTagAttrName);

  Operation *found = nullptr;
  WalkResult outcome =
      passRoot->walk([&](Operation *candidate) -> WalkResult {
        auto value = candidate->getAttrOfType<StringAttr>(tagName);
        bool matches = value && value.getValue() == tag;
        if (!matches)
          return WalkResult::advance();
        if (found) {
          // A second tagged operation makes the payload root ambiguous.
          InFlightDiagnostic diag =
              candidate->emitError()
              << "repeated operation with the target tag '" << tag << "'";
          diag.attachNote(found->getLoc()) << "previously seen operation";
          return WalkResult::interrupt();
        }
        found = candidate;
        return WalkResult::advance();
      });
  bool ambiguous = outcome.wasInterrupted();

  if (!found) {
    passRoot->emitError()
        << "could not find the operation with transform.target_tag=\"" << tag
        << "\" attribute";
    return nullptr;
  }
  if (ambiguous)
    return nullptr;
  return found;
}
namespace {
class InterpreterPass
    : public transform::impl::InterpreterPassBase<InterpreterPass> {
  /// Builds the values bound to the arguments of the transform entry point.
  ///
  /// The first binding is always `payloadRoot` itself. Each entry of the
  /// `debugBindTrailingArgs` pass option then contributes one more binding,
  /// in order:
  ///   - `#<int>[;<int>...]`: the listed base-10 integers, each bound as an
  ///     i64 integer attribute;
  ///   - `^<op-name>`: the results of every operation named `<op-name>` that
  ///     is visited while walking `payloadRoot`;
  ///   - `<op-name>`: every such operation itself.
  ///
  /// Returns std::nullopt, after emitting an error, if an entry of an integer
  /// list cannot be parsed as an integer.
  std::optional<RaggedArray<transform::MappedValue>>
  parseArguments(Operation *payloadRoot) {
    MLIRContext *context = payloadRoot->getContext();
    Builder builder(context);
    SmallVector<SmallVector<transform::MappedValue>, 2> trailingBindings;
    trailingBindings.resize(debugBindTrailingArgs.size());

    // For each trailing argument, record:
    //   - the operation name to match (std::nullopt for integer literals);
    //   - whether the matched operations' results, rather than the operations
    //     themselves, are bound.
    // The "^" prefix is decided here, once. The payload walk below does not
    // need to re-read it from the option string for every
    // (operation, argument) pair.
    SmallVector<std::optional<OperationName>> debugBindNames;
    SmallVector<bool> bindResults;
    debugBindNames.reserve(debugBindTrailingArgs.size());
    bindResults.reserve(debugBindTrailingArgs.size());
    for (auto &&[position, nameString] :
         llvm::enumerate(debugBindTrailingArgs)) {
      StringRef name = nameString;
      // Parse the ';'-separated integer literals following '#'.
      if (name.starts_with("#")) {
        debugBindNames.push_back(std::nullopt);
        bindResults.push_back(false);
        StringRef lhs = "";
        StringRef rhs = name.drop_front();
        do {
          std::tie(lhs, rhs) = rhs.split(';');
          int64_t value = 0;
          if (lhs.getAsInteger(10, value)) {
            emitError(UnknownLoc::get(context))
                << "couldn't parse integer pass argument " << name;
            return std::nullopt;
          }
          trailingBindings[position].push_back(
              builder.getI64IntegerAttr(value));
        } while (!rhs.empty());
      } else if (name.starts_with("^")) {
        // Bind the results of matching operations.
        debugBindNames.emplace_back(OperationName(name.drop_front(), context));
        bindResults.push_back(true);
      } else {
        // Bind the matching operations themselves.
        debugBindNames.emplace_back(OperationName(name, context));
        bindResults.push_back(false);
      }
    }

    // Collect operations or results for extra bindings.
    payloadRoot->walk([&](Operation *payload) {
      for (auto &&[position, name] : llvm::enumerate(debugBindNames)) {
        if (!name || payload->getName() != *name)
          continue;
        if (bindResults[position])
          llvm::append_range(trailingBindings[position], payload->getResults());
        else
          trailingBindings[position].push_back(payload);
      }
    });

    RaggedArray<transform::MappedValue> bindings;
    bindings.push_back(ArrayRef<Operation *>{payloadRoot});
    for (SmallVector<transform::MappedValue> &trailing : trailingBindings)
      bindings.push_back(std::move(trailing));
    return bindings;
  }

public:
  using Base::Base;

  /// Runs the preloaded transform script on the payload. The steps are:
  ///   1. resolve the payload root (optionally via `debugPayloadRootTag`);
  ///   2. locate the `entryPoint` sequence in the preloaded transform module;
  ///   3. build the argument bindings;
  ///   4. apply the named sequence.
  /// Pass failure is signalled as soon as any step fails.
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp transformModule =
        transform::detail::getPreloadedTransformModule(context);
    Operation *payloadRoot =
        findPayloadRoot(getOperation(), debugPayloadRootTag);
    if (!payloadRoot)
      return signalPassFailure();
    Operation *transformEntryPoint = transform::detail::findTransformEntryPoint(
        getOperation(), transformModule, entryPoint);
    if (!transformEntryPoint)
      return signalPassFailure();
    std::optional<RaggedArray<transform::MappedValue>> bindings =
        parseArguments(payloadRoot);
    if (!bindings)
      return signalPassFailure();
    if (failed(transform::applyTransformNamedSequence(
            *bindings,
            cast<transform::TransformOpInterface>(transformEntryPoint),
            transformModule,
            options.enableExpensiveChecks(!disableExpensiveChecks)))) {
      return signalPassFailure();
    }
  }

private:
  /// Transform interpreter options. The expensive-checks flag is refreshed
  /// from `disableExpensiveChecks` on every run.
  transform::TransformOptions options;
};
} // namespace