//===- ACCSpecializeForDevice.cpp -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass strips OpenACC constructs that are invalid or unnecessary inside
// device code (specialized acc routines or compute construct regions).
//
// Overview:
// ---------
// In a specialized acc routine or compute construct, many OpenACC operations
// do not make sense because they are host-side constructs. This pass removes
// or transforms these operations appropriately:
//
// - Data operations that manage device memory from the host's perspective
// - Compute constructs that launch kernels (we're already on device)
// - Runtime operations like init/shutdown/set/wait
//
// Transformations:
// ----------------
// The pass applies the following transformations:
//
// 1. Data Entry Ops (replaced with var operand):
// acc.attach, acc.copyin, acc.create, acc.declare_device_resident,
// acc.declare_link, acc.deviceptr, acc.get_deviceptr, acc.nocreate,
// acc.present, acc.update_device, acc.use_device
//
// 2. Data Exit Ops (erased):
// acc.copyout, acc.delete, acc.detach, acc.update_host
//
// 3. Structured Data/Compute Constructs (region inlined):
// acc.data, acc.host_data, acc.kernel_environment, acc.parallel,
// acc.serial, acc.kernels
//
// 4. Unstructured Data Ops (erased):
// acc.enter_data, acc.exit_data, acc.update, acc.declare_enter,
// acc.declare_exit
//
// 5. Runtime Ops (erased):
// acc.init, acc.shutdown, acc.set, acc.wait
//
// Scope of Application:
// ---------------------
// - For functions with `acc.specialized_routine` attribute: patterns are
// applied to the entire function body.
// - For non-specialized functions: patterns are applied only to ACC
// operations INSIDE compute constructs (parallel, serial, kernels),
// not to the compute constructs themselves or their data operands.
//
// Note: acc.cache, acc.private, acc.reduction, acc.firstprivate are NOT
// transformed by this pass as they are valid in device code.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenACC/Transforms/ACCSpecializePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace acc {
#define GEN_PASS_DEF_ACCSPECIALIZEFORDEVICE
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
} // namespace acc
} // namespace mlir
using namespace mlir;
using namespace mlir::acc;
namespace {
/// Function pass that removes or unwraps OpenACC operations that are not
/// meaningful in device code.
///
/// The rewrite patterns come from populateACCSpecializeForDevicePatterns. How
/// widely they are applied depends on the function being processed:
///  - a specialized acc routine: the whole function is rewritten;
///  - any other function: only OpenACC dialect operations nested inside a
///    compute construct are rewritten. The construct that encloses them is
///    not added to the work list by its own walk (a compute construct nested
///    in another one is still collected by the outer construct's walk).
class ACCSpecializeForDevice
    : public acc::impl::ACCSpecializeForDeviceBase<ACCSpecializeForDevice> {
public:
  using ACCSpecializeForDeviceBase<
      ACCSpecializeForDevice>::ACCSpecializeForDeviceBase;

  /// Builds the pattern set and applies it with the greedy driver using a
  /// top-down traversal. The driver's result is intentionally discarded in
  /// both paths, so the pass never signals failure from here.
  void runOnOperation() override {
    func::FuncOp func = getOperation();

    RewritePatternSet patterns(&getContext());
    acc::populateACCSpecializeForDevicePatterns(patterns);

    GreedyRewriteConfig config;
    config.setUseTopDownTraversal(true);

    // Specialized acc routines: apply the patterns to the entire function.
    if (acc::isSpecializedAccRoutine(func)) {
      (void)applyPatternsGreedily(func, std::move(patterns), config);
      return;
    }

    // Non-specialized functions: only ACC operations inside compute
    // constructs are rewritten. ExistingOps strictness keeps the greedy
    // driver from expanding its worklist to parent ops, which would
    // accidentally unwrap the compute construct itself (e.g. after inlining
    // acc routines that carry their own data regions).
    config.setStrictness(GreedyRewriteStrictness::ExistingOps);

    SmallVector<Operation *> opsToTransform;
    func.walk([&](Operation *op) {
      if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(op))
        return;

      // Collect every OpenACC dialect op nested in this compute construct.
      op->walk([&](Operation *innerOp) {
        // The walk also visits `op`; leave the construct itself alone.
        if (innerOp == op)
          return;
        // getDialect() is null when the op's dialect is not loaded (for
        // example an unregistered op), and plain isa<> must not be handed a
        // null pointer, so use the null-tolerant check.
        if (llvm::isa_and_present<acc::OpenACCDialect>(innerOp->getDialect()))
          opsToTransform.push_back(innerOp);
      });
    });

    if (!opsToTransform.empty())
      (void)applyOpPatternsGreedily(opsToTransform, std::move(patterns),
                                    config);
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// Pattern population functions
//===----------------------------------------------------------------------===//
/// Registers in `patterns` the rewrites used to specialize OpenACC IR for
/// device code. Each pattern is constructed with the context owned by
/// `patterns`.
///
/// Registration goes through RewritePatternSet::add; the insert<> spelling is
/// only kept upstream as a soft-deprecated alias. The pattern types and their
/// registration order are unchanged.
///
/// Note: acc.cache, acc.private, acc.reduction and acc.firstprivate get no
/// pattern here because they are valid in device code.
void mlir::acc::populateACCSpecializeForDevicePatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();

  // Declare patterns - erase declare_enter and its associated declare_exit.
  patterns.add<ACCDeclareEnterOpConversion>(context);

  // Data entry ops - replaced with their var operand.
  patterns.add<ACCOpReplaceWithVarConversion<acc::AttachOp>,
               ACCOpReplaceWithVarConversion<acc::CopyinOp>,
               ACCOpReplaceWithVarConversion<acc::CreateOp>,
               ACCOpReplaceWithVarConversion<acc::DeclareDeviceResidentOp>,
               ACCOpReplaceWithVarConversion<acc::DeclareLinkOp>,
               ACCOpReplaceWithVarConversion<acc::DevicePtrOp>,
               ACCOpReplaceWithVarConversion<acc::GetDevicePtrOp>,
               ACCOpReplaceWithVarConversion<acc::NoCreateOp>,
               ACCOpReplaceWithVarConversion<acc::PresentOp>,
               ACCOpReplaceWithVarConversion<acc::UpdateDeviceOp>,
               ACCOpReplaceWithVarConversion<acc::UseDeviceOp>>(context);

  // Data exit ops - simply erased (they produce no results).
  patterns.add<ACCOpEraseConversion<acc::CopyoutOp>,
               ACCOpEraseConversion<acc::DeleteOp>,
               ACCOpEraseConversion<acc::DetachOp>,
               ACCOpEraseConversion<acc::UpdateHostOp>>(context);

  // Structured data constructs - unwrap their regions.
  patterns.add<ACCRegionUnwrapConversion<acc::DataOp>,
               ACCRegionUnwrapConversion<acc::HostDataOp>,
               ACCRegionUnwrapConversion<acc::KernelEnvironmentOp>>(context);

  // Compute constructs - unwrap their regions.
  patterns.add<ACCRegionUnwrapConversion<acc::ParallelOp>,
               ACCRegionUnwrapConversion<acc::SerialOp>,
               ACCRegionUnwrapConversion<acc::KernelsOp>>(context);

  // Unstructured data operations - erase them.
  patterns.add<ACCOpEraseConversion<acc::EnterDataOp>,
               ACCOpEraseConversion<acc::ExitDataOp>,
               ACCOpEraseConversion<acc::UpdateOp>>(context);

  // Runtime operations - erase them.
  patterns.add<ACCOpEraseConversion<acc::InitOp>,
               ACCOpEraseConversion<acc::ShutdownOp>,
               ACCOpEraseConversion<acc::SetOp>,
               ACCOpEraseConversion<acc::WaitOp>>(context);
}