blob: b4cb0932ef631085c0decbef5894909f250e8f33 [file] [edit]
//===- Utils.cpp - Utilities to support the Func dialect ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Func dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "func-utils"
using namespace mlir;
/// Inverts an old-index -> new-index table.
///
/// `oldIdxToNewIdx[i] == j` states that old index `i` is sent to new index
/// `j`. The returned vector has one entry per new index (one more than the
/// largest value found in `oldIdxToNewIdx`, or no entries at all for an empty
/// input). Entry `j` lists, in increasing order, every old index sent to `j`.
///
/// Example:
/// ```
/// oldIdxToNewIdx = [0, 1, 2, 2, 3]
/// getInverseMapping(oldIdxToNewIdx) = [[0], [1], [2, 3], [4]]
/// ```
///
static llvm::SmallVector<llvm::SmallVector<int>>
getInverseMapping(ArrayRef<int> oldIdxToNewIdx) {
  // The new index space spans [0, max(oldIdxToNewIdx)].
  const int newIdxCount =
      oldIdxToNewIdx.empty() ? 0 : *llvm::max_element(oldIdxToNewIdx) + 1;
  llvm::SmallVector<llvm::SmallVector<int>> inverse(newIdxCount);
  for (int oldIdx = 0, e = static_cast<int>(oldIdxToNewIdx.size());
       oldIdx != e; ++oldIdx)
    inverse[oldIdxToNewIdx[oldIdx]].push_back(oldIdx);
  return inverse;
}
/// Builds the element list for the new index space described by
/// `newIdxToOldIdxs`.
///
/// The result has one element per entry of `newIdxToOldIdxs`; element `j` is
/// `origElements[newIdxToOldIdxs[j].front()]`.
///
/// Preconditions (checked with assertions only):
///  * every new index has at least one old index mapped to it;
///  * every old index is a valid position in `origElements`;
///  * all old indices mapped to the same new index refer to equal elements in
///    `origElements`.
template <typename Element>
static SmallVector<Element> getMappedElements(
    ArrayRef<Element> origElements,
    const llvm::SmallVector<llvm::SmallVector<int>> &newIdxToOldIdxs) {
  SmallVector<Element> newElements;
  // Exactly one element is produced per new index.
  newElements.reserve(newIdxToOldIdxs.size());
  for (const auto &oldIdxs : newIdxToOldIdxs) {
    assert(llvm::all_of(oldIdxs,
                        [&origElements](int idx) -> bool {
                          return idx >= 0 &&
                                 static_cast<size_t>(idx) < origElements.size();
                        }) &&
           "idx must be less than the number of elements in the original "
           "elements");
    assert(!oldIdxs.empty() && "oldIdx must not be empty");
    // The first old index acts as the representative for this new index.
    const Element &representative = origElements[oldIdxs.front()];
    assert(llvm::all_of(oldIdxs,
                        [&](int idx) -> bool {
                          return origElements[idx] == representative;
                        }) &&
           "all oldIdxs must be equal");
    newElements.push_back(representative);
  }
  return newElements;
}
/// Replaces `funcOp` with a new function of the same name and visibility whose
/// arguments and results are laid out according to the given index mappings.
///
/// `oldArgIdxToNewArgIdx[i]` is the position in the new signature of old
/// argument `i`; `oldResIdxToNewResIdx[i]` is the same for old result `i`.
/// Several old indices may share one new index, in which case the type,
/// location, attributes (and, for results, the returned value) of the first
/// such old index are used.
///
/// The body of `funcOp` is cloned into the new function with the old block
/// arguments remapped, the cloned `func.return` is rebuilt with the remapped
/// operands, and `funcOp` is erased.
///
/// Returns the new function, or a match failure (leaving the IR untouched) if
/// the body does not consist of exactly one block ending in `func.return`.
/// The sizes of both mappings are checked with assertions only. The
/// rewriter's insertion point is restored before returning.
///
/// NOTE(review): only the visibility, argument attributes and result
/// attributes are carried over here; any other attributes on `funcOp` are not
/// copied by this function.
FailureOr<func::FuncOp>
func::replaceFuncWithNewMapping(RewriterBase &rewriter, func::FuncOp funcOp,
                                ArrayRef<int> oldArgIdxToNewArgIdx,
                                ArrayRef<int> oldResIdxToNewResIdx) {
  assert(funcOp.getNumArguments() == oldArgIdxToNewArgIdx.size() &&
         "oldArgIdxToNewArgIdx must match the number of arguments in the "
         "function");
  assert(
      funcOp.getNumResults() == oldResIdxToNewResIdx.size() &&
      "oldResIdxToNewResIdx must match the number of results in the function");
  if (!funcOp.getBody().hasOneBlock())
    return rewriter.notifyMatchFailure(
        funcOp, "expected function to have exactly one block");

  // The result remapping below rewrites the `func.return` terminator, so
  // reject any other terminator up front, before any IR has been modified.
  Block &origBody = funcOp.getBody().front();
  if (origBody.empty() || !isa<func::ReturnOp>(&origBody.back()))
    return rewriter.notifyMatchFailure(
        funcOp, "expected function body to end with a func.return");

  // The insertion point is moved several times below and would otherwise be
  // left on the erased return op; put it back when we are done.
  OpBuilder::InsertionGuard guard(rewriter);

  // We may have some duplicate arguments in the old function, i.e.
  // in the mapping `newArgIdxToOldArgIdxs` for some new argument index
  // there may be multiple old argument indices.
  llvm::SmallVector<llvm::SmallVector<int>> newArgIdxToOldArgIdxs =
      getInverseMapping(oldArgIdxToNewArgIdx);
  SmallVector<Type> newInputTypes = getMappedElements(
      funcOp.getFunctionType().getInputs(), newArgIdxToOldArgIdxs);

  // Each new argument takes the location of the first old argument mapped to
  // it.
  SmallVector<Location> locs;
  locs.reserve(newArgIdxToOldArgIdxs.size());
  for (const auto &oldArgIdxs : newArgIdxToOldArgIdxs)
    locs.push_back(funcOp.getArgument(oldArgIdxs.front()).getLoc());

  llvm::SmallVector<llvm::SmallVector<int>> newResToOldResIdxs =
      getInverseMapping(oldResIdxToNewResIdx);
  SmallVector<Type> newOutputTypes = getMappedElements(
      funcOp.getFunctionType().getResults(), newResToOldResIdxs);

  // Generate an empty new function operation with the same name as the
  // original, placed right before it, and give it an entry block with the
  // remapped argument list.
  rewriter.setInsertionPoint(funcOp);
  auto newFuncOp = func::FuncOp::create(
      rewriter, funcOp.getLoc(), funcOp.getName(),
      rewriter.getFunctionType(newInputTypes, newOutputTypes));
  Region &newRegion = newFuncOp.getBody();
  rewriter.createBlock(&newRegion, newRegion.begin(), newInputTypes, locs);
  newFuncOp.setVisibility(funcOp.getVisibility());

  // Map the arguments of the original function to the new function in
  // the new order and adjust the attributes accordingly.
  IRMapping operandMapper;
  SmallVector<DictionaryAttr> argAttrs, resultAttrs;
  funcOp.getAllArgAttrs(argAttrs);
  for (auto [oldArgIdx, newArgIdx] : llvm::enumerate(oldArgIdxToNewArgIdx))
    operandMapper.map(funcOp.getArgument(oldArgIdx),
                      newFuncOp.getArgument(newArgIdx));
  for (auto [newArgIdx, oldArgIdx] : llvm::enumerate(newArgIdxToOldArgIdxs))
    newFuncOp.setArgAttrs(newArgIdx, argAttrs[oldArgIdx.front()]);
  funcOp.getAllResultAttrs(resultAttrs);
  for (auto [newResIdx, oldResIdx] : llvm::enumerate(newResToOldResIdxs))
    newFuncOp.setResultAttrs(newResIdx, resultAttrs[oldResIdx.front()]);

  // Clone the operations from the original function to the new function.
  rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
  for (Operation &op : funcOp.getOps())
    rewriter.clone(op, operandMapper);

  // Handle the return operation: the clone still returns values in the old
  // result order, so rebuild it with one operand per new result.
  auto returnOp = cast<func::ReturnOp>(
      newFuncOp.getFunctionBody().begin()->getTerminator());
  SmallVector<Value> newReturnValues;
  newReturnValues.reserve(newResToOldResIdxs.size());
  for (const auto &oldResIdxs : newResToOldResIdxs)
    newReturnValues.push_back(returnOp.getOperand(oldResIdxs.front()));
  rewriter.setInsertionPoint(returnOp);
  func::ReturnOp::create(rewriter, newFuncOp.getLoc(), newReturnValues);
  rewriter.eraseOp(returnOp);

  rewriter.eraseOp(funcOp);
  return newFuncOp;
}
/// Replaces `callOp` with a new `func.call` to the same callee whose operands
/// and results are laid out according to the given index mappings.
///
/// `oldArgIdxToNewArgIdx[i]` is the position in the new call of old operand
/// `i`; `oldResIdxToNewResIdx[i]` is the same for old result `i`. Several old
/// indices may share one new index, in which case the first such operand /
/// result type is used.
///
/// All uses of each old result are redirected to the corresponding new
/// result, `callOp` is erased, and the new call is returned. The sizes of
/// both mappings are checked with assertions only. The rewriter's insertion
/// point is restored before returning.
///
/// NOTE(review): of the attributes on `callOp`, only `no_inline` is forwarded
/// to the new call here.
func::CallOp
func::replaceCallOpWithNewMapping(RewriterBase &rewriter, func::CallOp callOp,
                                  ArrayRef<int> oldArgIdxToNewArgIdx,
                                  ArrayRef<int> oldResIdxToNewResIdx) {
  assert(callOp.getNumOperands() == oldArgIdxToNewArgIdx.size() &&
         "oldArgIdxToNewArgIdx must match the number of operands in the call "
         "operation");
  assert(callOp.getNumResults() == oldResIdxToNewResIdx.size() &&
         "oldResIdxToNewResIdx must match the number of results in the call "
         "operation");

  // Without this the rewriter would be left pointing at the erased `callOp`.
  OpBuilder::InsertionGuard guard(rewriter);

  // Remap the operands.
  SmallVector<Value> origOperands = callOp.getOperands();
  SmallVector<llvm::SmallVector<int>> newArgIdxToOldArgIdxs =
      getInverseMapping(oldArgIdxToNewArgIdx);
  SmallVector<Value> newOperandsValues =
      getMappedElements<Value>(origOperands, newArgIdxToOldArgIdxs);

  // Remap the result types.
  SmallVector<llvm::SmallVector<int>> newResToOldResIdxs =
      getInverseMapping(oldResIdxToNewResIdx);
  SmallVector<Type> origResultTypes = llvm::to_vector(callOp.getResultTypes());
  SmallVector<Type> newResultTypes =
      getMappedElements<Type>(origResultTypes, newResToOldResIdxs);

  // Replace the kernel call operation with a new one that has the
  // mapped arguments.
  rewriter.setInsertionPoint(callOp);
  auto newCallOp =
      func::CallOp::create(rewriter, callOp.getLoc(), callOp.getCallee(),
                           newResultTypes, newOperandsValues);
  newCallOp.setNoInlineAttr(callOp.getNoInlineAttr());

  // Redirect every user of an old result to its counterpart on the new call.
  for (auto &&[oldResIdx, newResIdx] : llvm::enumerate(oldResIdxToNewResIdx))
    rewriter.replaceAllUsesWith(callOp.getResult(oldResIdx),
                                newCallOp.getResult(newResIdx));
  rewriter.eraseOp(callOp);
  return newCallOp;
}
/// Removes duplicated arguments from `funcOp` based on its single call site.
///
/// `moduleOp` is walked for `func.call` ops whose callee equals the symbol
/// name of `funcOp`. If exactly one such call exists and it passes the same
/// value as more than one operand, the function and the call are rewritten so
/// that each distinct value is passed once, in order of first appearance.
/// Results are left in their original order.
///
/// Returns the new function and the new call, or failure if there is no call,
/// more than one call, no repeated operand, or the function rewrite fails.
FailureOr<std::pair<func::FuncOp, func::CallOp>>
func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp,
                              ModuleOp moduleOp) {
  // Locate the call site; the walk is cut short once a second one shows up.
  func::CallOp onlyCall;
  WalkResult walkStatus = moduleOp.walk([&](func::CallOp candidate) {
    if (candidate.getCallee() != funcOp.getSymName())
      return WalkResult::advance();
    // Only support one callOp for now
    if (onlyCall)
      return WalkResult::interrupt();
    onlyCall = candidate;
    return WalkResult::advance();
  });

  if (walkStatus.wasInterrupted()) {
    LDBG() << "function " << funcOp.getName() << " has more than one callOp";
    return failure();
  }
  if (!onlyCall) {
    LDBG() << "function " << funcOp.getName() << " does not have any callOp";
    return failure();
  }

  // Assign each distinct operand value the next free new index, in order of
  // first appearance; repeated values reuse the index of their first use.
  const unsigned numOperands = onlyCall.getNumOperands();
  SmallVector<int> oldArgIdxToNewArgIdx;
  oldArgIdxToNewArgIdx.reserve(numOperands);
  llvm::DenseMap<Value, int> firstUseIdx;
  for (Value operand : onlyCall.getOperands()) {
    const int nextFreeIdx = static_cast<int>(firstUseIdx.size());
    auto slot = firstUseIdx.try_emplace(operand, nextFreeIdx).first;
    oldArgIdxToNewArgIdx.push_back(slot->second);
  }

  // Nothing to do when every operand is distinct.
  if (firstUseIdx.size() == numOperands) {
    LDBG() << "function " << funcOp.getName()
           << " does not have duplicate operands";
    return failure();
  }

  // Results keep their positions: build the identity mapping.
  const int numResults = static_cast<int>(onlyCall.getNumResults());
  SmallVector<int> oldResIdxToNewResIdx =
      llvm::to_vector(llvm::seq<int>(0, numResults));

  // Rewrite the function first, then its call site.
  FailureOr<func::FuncOp> rewrittenFunc = replaceFuncWithNewMapping(
      rewriter, funcOp, oldArgIdxToNewArgIdx, oldResIdxToNewResIdx);
  if (failed(rewrittenFunc)) {
    LDBG() << "failed to replace function signature with name "
           << funcOp.getName() << " with new order";
    return failure();
  }
  func::CallOp rewrittenCall = replaceCallOpWithNewMapping(
      rewriter, onlyCall, oldArgIdxToNewArgIdx, oldResIdxToNewResIdx);
  return std::make_pair(*rewrittenFunc, rewrittenCall);
}