// blob: 005e84cb8e9f9a2a307094f4b54f58e62e04b422 [file] [log] [blame]
//===- OMPFunctionFiltering.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements transforms to filter out functions intended for the host
// when compiling for the device and vice versa.
//
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/SmallVector.h"
namespace fir {
#define GEN_PASS_DEF_OMPFUNCTIONFILTERING
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir
using namespace mlir;
namespace {
/// Pass that filters host-only functions out of a module that is being
/// compiled for an OpenMP target device.
///
/// The pass does nothing unless the root operation implements
/// omp::OffloadModuleInterface and reports that it is a target-device module.
/// Otherwise every func.func that is not marked declare target, or whose
/// declare-target device type is `host`, is processed as follows:
///   * users of its symbol that are themselves func.func ops lose their host
///     symbol attribute and are made public;
///   * every other user is erased, with any used results replaced by
///     fir.undefined values;
///   * the function itself is erased if it contains no omp.target region;
///     otherwise it is kept and explicitly marked declare target.
class OMPFunctionFilteringPass
    : public fir::impl::OMPFunctionFilteringBase<OMPFunctionFilteringPass> {
public:
  OMPFunctionFilteringPass() = default;

  /// Entry point of the pass. Signals pass failure if the symbol uses of a
  /// function that has to be filtered cannot be computed.
  void runOnOperation() override {
    auto op = dyn_cast<omp::OffloadModuleInterface>(getOperation());
    if (!op || !op.getIsTargetDevice())
      return;

    OpBuilder opBuilder(&getContext());

    op->walk<WalkOrder::PreOrder>([&](func::FuncOp funcOp) {
      // Do not filter functions with target regions inside, because they have
      // to be available for both host and device so that regular and reverse
      // offloading can be supported.
      bool hasTargetRegion =
          funcOp
              ->walk<WalkOrder::PreOrder>(
                  [&](omp::TargetOp) { return WalkResult::interrupt(); })
              .wasInterrupted();

      // Functions without a declare target marking are treated as host-only.
      omp::DeclareTargetDeviceType declareType =
          omp::DeclareTargetDeviceType::host;
      auto declareTargetOp =
          dyn_cast<omp::DeclareTargetInterface>(funcOp.getOperation());
      if (declareTargetOp && declareTargetOp.isDeclareTarget())
        declareType = declareTargetOp.getDeclareTargetDeviceType();

      // Filtering a function here means deleting it if it doesn't contain a
      // target region. Else we explicitly set the omp.declare_target
      // attribute. The second stage of function filtering at the MLIR to LLVM
      // IR translation level will remove functions that contain the target
      // region from the generated llvm IR.
      if (declareType == omp::DeclareTargetDeviceType::host) {
        // getSymbolUses() yields an empty optional when the uses cannot be
        // determined; dereferencing it unchecked would be undefined behavior,
        // so report the problem and fail the pass instead.
        auto funcUses = funcOp.getSymbolUses(op);
        if (!funcUses) {
          funcOp.emitError("unable to compute the symbol uses of this "
                           "function while filtering host functions");
          signalPassFailure();
          return WalkResult::interrupt();
        }

        for (SymbolTable::SymbolUse use : *funcUses) {
          Operation *callOp = use.getUser();
          if (auto internalFunc = mlir::dyn_cast<func::FuncOp>(callOp)) {
            // Do not delete internal procedures holding the symbol of their
            // Fortran host procedure as attribute.
            internalFunc->removeAttr(fir::getHostSymbolAttrName());
            // Set public visibility so that the function is not deleted by
            // MLIR because unused. Changing it is OK here because the function
            // will be deleted anyway in the second filtering phase.
            internalFunc.setVisibility(mlir::SymbolTable::Visibility::Public);
            continue;
          }

          // If the callOp has users then replace them with Undef values.
          if (!callOp->use_empty()) {
            SmallVector<Value> undefResults;
            // All replacement values are created right before the user that
            // is about to be erased.
            opBuilder.setInsertionPoint(callOp);
            for (Value res : callOp->getResults())
              undefResults.emplace_back(
                  opBuilder.create<fir::UndefOp>(res.getLoc(), res.getType()));
            callOp->replaceAllUsesWith(undefResults);
          }

          // Remove the callOp
          callOp->erase();
        }

        if (!hasTargetRegion) {
          funcOp.erase();
          // The function is gone, so its body must not be visited.
          return WalkResult::skip();
        }

        if (declareTargetOp)
          declareTargetOp.setDeclareTarget(declareType,
                                           omp::DeclareTargetCaptureClause::to);
      }
      return WalkResult::advance();
    });
  }
};
} // namespace
/// Factory for the OpenMP function filtering pass: returns a newly allocated
/// OMPFunctionFilteringPass owned by the caller.
std::unique_ptr<Pass> fir::createOMPFunctionFilteringPass() {
  auto filteringPass = std::make_unique<OMPFunctionFilteringPass>();
  return filteringPass;
}