blob: a2d0c796c66b54b86bf4b7f75b5bef4dc376740c [file] [log] [blame]
//===- TestFunctionLike.cpp - Pass to test helpers on FunctionLike --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
namespace {
/// Test pass that exercises FuncOp::insertArguments.
///
/// Every FuncOp in the module that carries a non-empty `test.insert_args`
/// array attribute is processed. Each element of that array is itself an
/// array laid out as:
///   [0] IntegerAttr    - index at which the argument is inserted
///   [1] TypeAttr       - type of the new argument
///   [2] DictionaryAttr - (optional) attributes of the new argument
///   [3] LocationAttr   - (optional) location of the new argument
/// The driving attribute is removed from the function before the arguments
/// are inserted.
struct TestFuncInsertArg
    : public PassWrapper<TestFuncInsertArg, OperationPass<ModuleOp>> {
  StringRef getArgument() const final { return "test-func-insert-arg"; }
  StringRef getDescription() const final { return "Test inserting func args."; }

  void runOnOperation() override {
    auto moduleOp = getOperation();
    for (FuncOp funcOp : moduleOp.getOps<FuncOp>()) {
      auto insertSpecs = funcOp->getAttrOfType<ArrayAttr>("test.insert_args");
      // Nothing to do when the attribute is missing or empty.
      if (!insertSpecs || insertSpecs.empty())
        continue;

      // Parallel lists: entry i of each vector describes the i-th insertion.
      SmallVector<unsigned, 4> argIndices;
      SmallVector<Type, 4> argTypes;
      SmallVector<DictionaryAttr, 4> argAttrs;
      SmallVector<Optional<Location>, 4> argLocs;

      for (auto spec : insertSpecs.getAsRange<ArrayAttr>()) {
        unsigned position =
            spec[0].cast<IntegerAttr>().getValue().getZExtValue();
        argIndices.push_back(position);

        Type argType = spec[1].cast<TypeAttr>().getValue();
        argTypes.push_back(argType);

        // Fall back to an empty dictionary when no attributes are given.
        if (spec.size() > 2)
          argAttrs.push_back(spec[2].cast<DictionaryAttr>());
        else
          argAttrs.push_back(DictionaryAttr::get(&getContext()));

        // The location is optional; leave it unset when not provided.
        if (spec.size() > 3)
          argLocs.push_back(
              Optional<Location>(spec[3].cast<LocationAttr>()));
        else
          argLocs.push_back(Optional<Location>{});
      }

      funcOp->removeAttr("test.insert_args");
      funcOp.insertArguments(argIndices, argTypes, argAttrs, argLocs);
    }
  }
};
/// Test pass that exercises FuncOp::insertResults.
///
/// Every FuncOp in the module that carries a non-empty `test.insert_results`
/// array attribute is processed. Each element of that array is itself an
/// array laid out as:
///   [0] IntegerAttr    - index at which the result is inserted
///   [1] TypeAttr       - type of the new result
///   [2] DictionaryAttr - (optional) attributes of the new result
/// The driving attribute is removed from the function before the results are
/// inserted.
struct TestFuncInsertResult
    : public PassWrapper<TestFuncInsertResult, OperationPass<ModuleOp>> {
  StringRef getArgument() const final { return "test-func-insert-result"; }
  StringRef getDescription() const final {
    return "Test inserting func results.";
  }

  void runOnOperation() override {
    auto moduleOp = getOperation();
    for (FuncOp funcOp : moduleOp.getOps<FuncOp>()) {
      auto insertSpecs =
          funcOp->getAttrOfType<ArrayAttr>("test.insert_results");
      // Nothing to do when the attribute is missing or empty.
      if (!insertSpecs || insertSpecs.empty())
        continue;

      // Parallel lists: entry i of each vector describes the i-th insertion.
      SmallVector<unsigned, 4> resultIndices;
      SmallVector<Type, 4> resultTypes;
      SmallVector<DictionaryAttr, 4> resultAttrs;

      for (auto spec : insertSpecs.getAsRange<ArrayAttr>()) {
        unsigned position =
            spec[0].cast<IntegerAttr>().getValue().getZExtValue();
        resultIndices.push_back(position);

        Type resultType = spec[1].cast<TypeAttr>().getValue();
        resultTypes.push_back(resultType);

        // Fall back to an empty dictionary when no attributes are given.
        if (spec.size() > 2)
          resultAttrs.push_back(spec[2].cast<DictionaryAttr>());
        else
          resultAttrs.push_back(DictionaryAttr::get(&getContext()));
      }

      funcOp->removeAttr("test.insert_results");
      funcOp.insertResults(resultIndices, resultTypes, resultAttrs);
    }
  }
};
/// Test pass that exercises FuncOp::eraseArguments.
///
/// For every FuncOp in the module, collects the indices of all arguments
/// carrying a `test.erase_this_arg` attribute and erases them. The index list
/// handed to eraseArguments deliberately contains each index twice and is in
/// descending order, so that duplicate and unsorted indices are covered.
struct TestFuncEraseArg
    : public PassWrapper<TestFuncEraseArg, OperationPass<ModuleOp>> {
  StringRef getArgument() const final { return "test-func-erase-arg"; }
  StringRef getDescription() const final { return "Test erasing func args."; }

  void runOnOperation() override {
    auto moduleOp = getOperation();
    for (FuncOp funcOp : moduleOp.getOps<FuncOp>()) {
      SmallVector<unsigned, 4> doomedArgs;
      for (unsigned idx = 0, numArgs = funcOp.getNumArguments();
           idx != numArgs; ++idx) {
        if (!funcOp.getArgAttr(idx, "test.erase_this_arg"))
          continue;
        // Record the index two times on purpose: eraseArguments must cope
        // with duplicated indices.
        doomedArgs.push_back(idx);
        doomedArgs.push_back(idx);
      }
      // Flip the (ascending) list so eraseArguments is also tested with
      // indices that are not sorted.
      std::reverse(doomedArgs.begin(), doomedArgs.end());
      funcOp.eraseArguments(doomedArgs);
    }
  }
};
/// Test pass that exercises FuncOp::eraseResults.
///
/// For every FuncOp in the module, collects the indices of all results
/// carrying a `test.erase_this_result` attribute and erases them. The index
/// list handed to eraseResults deliberately contains each index twice and is
/// in descending order, so that duplicate and unsorted indices are covered.
struct TestFuncEraseResult
    : public PassWrapper<TestFuncEraseResult, OperationPass<ModuleOp>> {
  StringRef getArgument() const final { return "test-func-erase-result"; }
  StringRef getDescription() const final {
    return "Test erasing func results.";
  }

  void runOnOperation() override {
    auto moduleOp = getOperation();
    for (FuncOp funcOp : moduleOp.getOps<FuncOp>()) {
      SmallVector<unsigned, 4> doomedResults;
      for (unsigned idx = 0, numResults = funcOp.getNumResults();
           idx != numResults; ++idx) {
        if (!funcOp.getResultAttr(idx, "test.erase_this_result"))
          continue;
        // Record the index two times on purpose: eraseResults must cope with
        // duplicated indices.
        doomedResults.push_back(idx);
        doomedResults.push_back(idx);
      }
      // Flip the (ascending) list so eraseResults is also tested with
      // indices that are not sorted.
      std::reverse(doomedResults.begin(), doomedResults.end());
      funcOp.eraseResults(doomedResults);
    }
  }
};
/// Test pass that exercises FuncOp::setType.
///
/// Every FuncOp in the module that carries a `test.set_type_from` flat symbol
/// reference attribute has its type replaced by the type of the FuncOp that
/// the symbol resolves to within the module's symbol table. Functions without
/// the attribute are left untouched. If the symbol does not resolve to a
/// FuncOp, an error is emitted on the offending function and the pass is
/// marked as failed.
struct TestFuncSetType
    : public PassWrapper<TestFuncSetType, OperationPass<ModuleOp>> {
  StringRef getArgument() const final { return "test-func-set-type"; }
  StringRef getDescription() const final { return "Test FuncOp::setType."; }

  void runOnOperation() override {
    auto module = getOperation();
    SymbolTable symbolTable(module);

    for (FuncOp func : module.getOps<FuncOp>()) {
      auto sym = func->getAttrOfType<FlatSymbolRefAttr>("test.set_type_from");
      if (!sym)
        continue;

      // The typed lookup yields a null op when the symbol is unknown or does
      // not name a FuncOp; calling getType() on that would dereference null,
      // so diagnose the bad reference instead of crashing.
      FuncOp source = symbolTable.lookup<FuncOp>(sym.getValue());
      if (!source) {
        func.emitError("'test.set_type_from' does not reference a function: '")
            << sym.getValue() << "'";
        signalPassFailure();
        continue;
      }

      func.setType(source.getType());
    }
  }
};
} // end anonymous namespace
namespace mlir {
/// Instantiates a PassRegistration for each of the five FuncOp test passes
/// defined in this file (insert/erase of arguments and results, and
/// setType). Each pass supplies its own command-line argument and description
/// through getArgument() / getDescription().
void registerTestFunc() {
PassRegistration<TestFuncInsertArg>();
PassRegistration<TestFuncInsertResult>();
PassRegistration<TestFuncEraseArg>();
PassRegistration<TestFuncEraseResult>();
PassRegistration<TestFuncSetType>();
}
} // namespace mlir