blob: f781ed2d591b42080257f841af5dd154a8462ca7 [file] [log] [blame]
//===- Utils.cpp - Utilities to support the Func dialect ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Func dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallVector.h"
using namespace mlir;
// Replaces `funcOp` with a new func.func of the same name whose arguments and
// results are permuted. Entry `i` of `newArgsOrder` (resp. `newResultsOrder`)
// is the index of the original argument (resp. result) that becomes argument
// (resp. result) `i` of the new function. The argument/result attributes, the
// argument locations, the visibility, and the discardable attributes of both
// the function and its func.return are carried over. The single-block body is
// cloned into the new function, and the original `funcOp` is erased.
//
// Returns failure (through notifyMatchFailure), before touching the IR, when
// the function body does not consist of exactly one block. Otherwise returns
// the newly created function.
FailureOr<func::FuncOp>
func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
                              ArrayRef<unsigned> newArgsOrder,
                              ArrayRef<unsigned> newResultsOrder) {
  assert(funcOp.getNumArguments() == newArgsOrder.size() &&
         "newArgsOrder must match the number of arguments in the function");
  assert(funcOp.getNumResults() == newResultsOrder.size() &&
         "newResultsOrder must match the number of results in the function");
  if (!funcOp.getBody().hasOneBlock())
    return rewriter.notifyMatchFailure(
        funcOp, "expected function to have exactly one block");

  // Compute the permuted signature and the matching block-argument locations.
  // `idx` is already an index into the ORIGINAL function, so it must not be
  // looked up in `newArgsOrder` a second time.
  ArrayRef<Type> origInputTypes = funcOp.getFunctionType().getInputs();
  ArrayRef<Type> origOutputTypes = funcOp.getFunctionType().getResults();
  SmallVector<Type> newInputTypes, newOutputTypes;
  SmallVector<Location> locs;
  newInputTypes.reserve(newArgsOrder.size());
  locs.reserve(newArgsOrder.size());
  for (unsigned idx : newArgsOrder) {
    newInputTypes.push_back(origInputTypes[idx]);
    locs.push_back(funcOp.getArgument(idx).getLoc());
  }
  newOutputTypes.reserve(newResultsOrder.size());
  for (unsigned idx : newResultsOrder)
    newOutputTypes.push_back(origOutputTypes[idx]);

  // Generate an empty new function operation with the same name as the
  // original, inserted right before it.
  rewriter.setInsertionPoint(funcOp);
  auto newFuncOp = func::FuncOp::create(
      rewriter, funcOp.getLoc(), funcOp.getName(),
      rewriter.getFunctionType(newInputTypes, newOutputTypes));
  Region &newRegion = newFuncOp.getBody();
  rewriter.createBlock(&newRegion, newRegion.begin(), newInputTypes, locs);
  newFuncOp.setVisibility(funcOp.getVisibility());
  newFuncOp->setDiscardableAttrs(funcOp->getDiscardableAttrDictionary());

  // Map the arguments of the original function to the new function in
  // the new order and adjust the argument attributes accordingly.
  IRMapping operandMapper;
  SmallVector<DictionaryAttr> argAttrs, resultAttrs;
  funcOp.getAllArgAttrs(argAttrs);
  for (unsigned i = 0, e = newArgsOrder.size(); i < e; ++i) {
    operandMapper.map(funcOp.getArgument(newArgsOrder[i]),
                      newFuncOp.getArgument(i));
    newFuncOp.setArgAttrs(i, argAttrs[newArgsOrder[i]]);
  }
  funcOp.getAllResultAttrs(resultAttrs);
  for (unsigned i = 0, e = newResultsOrder.size(); i < e; ++i)
    newFuncOp.setResultAttrs(i, resultAttrs[newResultsOrder[i]]);

  // Clone the operations from the original function to the new function.
  rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
  for (Operation &op : funcOp.getOps())
    rewriter.clone(op, operandMapper);

  // Handle the return operation: rebuild it with permuted operands, keeping
  // the cloned terminator's own location and discardable attributes.
  auto returnOp = cast<func::ReturnOp>(
      newFuncOp.getFunctionBody().begin()->getTerminator());
  SmallVector<Value> newReturnValues;
  newReturnValues.reserve(newResultsOrder.size());
  for (unsigned idx : newResultsOrder)
    newReturnValues.push_back(returnOp.getOperand(idx));
  rewriter.setInsertionPoint(returnOp);
  auto newReturnOp =
      func::ReturnOp::create(rewriter, returnOp.getLoc(), newReturnValues);
  newReturnOp->setDiscardableAttrs(returnOp->getDiscardableAttrDictionary());
  rewriter.eraseOp(returnOp);
  rewriter.eraseOp(funcOp);
  return newFuncOp;
}
// Replaces `callOp` with a new func.call to the same callee whose operands and
// results are permuted. Operand `i` of the new call is operand
// `newArgsOrder[i]` of the original. Result `i` of the new call is result
// `newResultsOrder[i]` of the original. All uses of the original results are
// redirected to the corresponding new results. The `no_inline` attribute and
// the discardable attributes of the original call are carried over, and
// `callOp` is erased. Returns the newly created call.
func::CallOp
func::replaceCallOpWithNewOrder(RewriterBase &rewriter, func::CallOp callOp,
                                ArrayRef<unsigned> newArgsOrder,
                                ArrayRef<unsigned> newResultsOrder) {
  assert(
      callOp.getNumOperands() == newArgsOrder.size() &&
      "newArgsOrder must match the number of operands in the call operation");
  assert(
      callOp.getNumResults() == newResultsOrder.size() &&
      "newResultsOrder must match the number of results in the call operation");

  // Gather the operands and result types in their new order.
  SmallVector<Value> newArgsOrderValues;
  newArgsOrderValues.reserve(newArgsOrder.size());
  for (unsigned argIdx : newArgsOrder)
    newArgsOrderValues.push_back(callOp.getOperand(argIdx));
  SmallVector<Type> newResultTypes;
  newResultTypes.reserve(newResultsOrder.size());
  for (unsigned resIdx : newResultsOrder)
    newResultTypes.push_back(callOp.getResult(resIdx).getType());

  // Replace the call operation with a new one that has the reordered
  // operands and results.
  rewriter.setInsertionPoint(callOp);
  auto newCallOp =
      func::CallOp::create(rewriter, callOp.getLoc(), callOp.getCallee(),
                           newResultTypes, newArgsOrderValues);
  newCallOp.setNoInlineAttr(callOp.getNoInlineAttr());
  // Keep the discardable attributes of the original call, the same way
  // replaceFuncWithNewOrder keeps those of the function and its terminator.
  newCallOp->setDiscardableAttrs(callOp->getDiscardableAttrDictionary());

  // Original result `origIndex` now lives at position `newIndex`.
  for (auto &&[newIndex, origIndex] : llvm::enumerate(newResultsOrder))
    rewriter.replaceAllUsesWith(callOp.getResult(origIndex),
                                newCallOp.getResult(newIndex));
  rewriter.eraseOp(callOp);
  return newCallOp;
}