blob: 51ab796021e4b472658ce71b592c4a9a83d4f98a [file] [log] [blame]
//===- ACCUseDeviceCanonicalizer.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass canonicalizes the use_device clause on a host_data construct such
// that use_device(x) can be lowered to a simple runtime call that takes the
// actual host pointer as argument.
//
// For a use_device operand that is a box type or a reference to a box, the
// pass:
// 1. Extracts the host base address for mapping to a device address using
// acc.use_device.
// 2. Creates a new boxed descriptor with the device address as the base
// address for use inside the host_data region.
//
// The pass also removes unused use_device clauses, reducing the number of
// runtime calls.
//
// Supported use_device operand types:
//
// Scalars:
// - !fir.ref<i32>, !fir.ref<f64>, etc.
//
// Arrays:
// - Explicit shape (no descriptor): !fir.ref<!fir.array<100xi32>>
// - Adjustable size: !fir.ref<!fir.array<?xi32>>
// - Assumed shape (handled by hoistBox): !fir.box<!fir.array<?xi32>>
// - Assumed size: !fir.ref<!fir.array<?xi32>>
// - Deferred shape (handled by hoistRefToBox):
// - Allocatable: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
// - Pointer: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
// - Subarray specification (handled by hoistBox):
// !fir.box<!fir.array<?xi32>>
//
// Not yet supported:
// - Assumed rank arrays
// - Composite variables: !fir.ref<!fir.type<...>>
// - Array elements (device pointer arithmetic in host_data region)
// - Composite variable members
// - Fortran common blocks: use_device(/cm_block/)
//
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/OpenACC/Passes.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
#include <cassert>
namespace fir::acc {
#define GEN_PASS_DEF_ACCUSEDEVICECANONICALIZER
#include "flang/Optimizer/OpenACC/Passes.h.inc"
} // namespace fir::acc
#define DEBUG_TYPE "acc-use-device-canonicalizer"
using namespace mlir;
namespace {
struct UseDeviceHostDataHoisting : public OpRewritePattern<acc::HostDataOp> {
using OpRewritePattern<acc::HostDataOp>::OpRewritePattern;
LogicalResult matchAndRewrite(acc::HostDataOp op,
PatternRewriter &rewriter) const override {
SmallVector<Value> usedOperands;
SmallVector<Value> unusedUseDeviceOperands;
SmallVector<acc::UseDeviceOp> refToBoxUseDeviceOps;
SmallVector<acc::UseDeviceOp> boxUseDeviceOps;
for (Value operand : op.getDataClauseOperands()) {
if (acc::UseDeviceOp useDeviceOp =
operand.getDefiningOp<acc::UseDeviceOp>()) {
if (fir::isBoxAddress(useDeviceOp.getVar().getType())) {
if (!llvm::hasSingleElement(useDeviceOp->getUsers()))
refToBoxUseDeviceOps.push_back(useDeviceOp);
} else if (isa<fir::BoxType>(useDeviceOp.getVar().getType())) {
if (!llvm::hasSingleElement(useDeviceOp->getUsers()))
boxUseDeviceOps.push_back(useDeviceOp);
}
// host_data is the only user of this use_device operand - mark for
// removal
if (llvm::hasSingleElement(useDeviceOp->getUsers()))
unusedUseDeviceOperands.push_back(useDeviceOp.getResult());
else
usedOperands.push_back(useDeviceOp.getResult());
} else {
// Operand is not an `acc.use_device` result, keep it as is.
usedOperands.push_back(operand);
}
}
assert(!usedOperands.empty() && "Host_data operation has no used operands");
if (!unusedUseDeviceOperands.empty()) {
LLVM_DEBUG(llvm::dbgs()
<< "ACCUseDeviceCanonicalizer: Removing "
<< unusedUseDeviceOperands.size()
<< " unused use_device operands from host_data operation\n");
// Update the host_data operation to have only used operands
rewriter.modifyOpInPlace(op, [&]() {
op.getDataClauseOperandsMutable().assign(usedOperands);
});
// Remove unused use_device operations
for (Value operand : unusedUseDeviceOperands) {
acc::UseDeviceOp useDeviceOp =
operand.getDefiningOp<acc::UseDeviceOp>();
LLVM_DEBUG(llvm::dbgs() << "ACCUseDeviceCanonicalizer: Erasing: "
<< *useDeviceOp << "\n");
rewriter.eraseOp(useDeviceOp);
}
return success();
}
// Handle references to box types
bool modified = false;
for (acc::UseDeviceOp useDeviceOp : refToBoxUseDeviceOps)
modified |=
hoistRefToBox(rewriter, useDeviceOp.getResult(), useDeviceOp, op);
// Handle box types
for (acc::UseDeviceOp useDeviceOp : boxUseDeviceOps)
modified |= hoistBox(rewriter, useDeviceOp.getResult(), useDeviceOp, op);
return modified ? success() : failure();
}
private:
/// Collect users of `acc.use_device` operation inside the `acc.host_data`
/// region that need to be updated with the final replacement value.
void collectUseDeviceUsersToUpdate(
acc::UseDeviceOp useDeviceOp, acc::HostDataOp hostDataOp,
SmallVectorImpl<Operation *> &usersToUpdate) const {
for (mlir::Operation *user : useDeviceOp->getUsers())
if (hostDataOp.getRegion().isAncestor(user->getParentRegion()))
usersToUpdate.push_back(user);
}
/// Create new `acc.use_device` operation with the given box address as
/// operand. Updates the `acc.host_data` operation to use the new
/// `acc.use_device` result.
acc::UseDeviceOp createNewUseDeviceOp(PatternRewriter &rewriter,
acc::UseDeviceOp useDeviceOp,
acc::HostDataOp hostDataOp,
fir::BoxAddrOp boxAddr) const {
// Create use_device on the raw pointer
acc::UseDeviceOp newUseDeviceOp = acc::UseDeviceOp::create(
rewriter, useDeviceOp.getLoc(), boxAddr.getType(), boxAddr.getResult(),
useDeviceOp.getVarTypeAttr(), useDeviceOp.getVarPtrPtr(),
useDeviceOp.getBounds(), useDeviceOp.getAsyncOperands(),
useDeviceOp.getAsyncOperandsDeviceTypeAttr(),
useDeviceOp.getAsyncOnlyAttr(), useDeviceOp.getDataClauseAttr(),
useDeviceOp.getStructuredAttr(), useDeviceOp.getImplicitAttr(),
useDeviceOp.getModifiersAttr(), useDeviceOp.getNameAttr(),
useDeviceOp.getRecipeAttr());
LLVM_DEBUG(llvm::dbgs() << "Created new hoisted pattern for box access:\n"
<< " box_addr: " << *boxAddr << "\n"
<< " new use_device: " << *newUseDeviceOp << "\n");
// Replace the old `acc.use_device` operand in the `acc.host_data` operation
// with the new one
rewriter.modifyOpInPlace(hostDataOp, [&]() {
hostDataOp->replaceUsesOfWith(useDeviceOp.getResult(),
newUseDeviceOp.getResult());
});
return newUseDeviceOp;
}
/// Canonicalize use_device operand that is a reference to a box.
/// Transforms:
/// %3 = fir.address_of(@_QFEtgt) : !fir.ref<i32>
/// %5 = fir.embox %3 : (!fir.ref<i32>) -> !fir.box<!fir.ptr<i32>>
/// fir.store %5 to %0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
/// %9 = acc.use_device varPtr(%0 : !fir.ref<!fir.box<!fir.ptr<i32>>>)
/// -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "ptr"}
/// acc.host_data dataOperands(%9 : !fir.ref<!fir.box<!fir.ptr<i32>>>) {
/// %loaded = fir.load %9 : !fir.ref<!fir.box<!fir.ptr<i32>>>
/// %addr = fir.box_addr %loaded : (!fir.box<!fir.ptr<i32>>) ->
/// !fir.ptr<i32> %conv = fir.convert %addr : (!fir.ptr<i32>) -> i64
/// fir.call @foo(%conv) : (i64) -> ()
/// acc.terminator
/// }
/// into:
/// %loaded = fir.load %0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
/// %addr = fir.box_addr %loaded : (!fir.box<!fir.ptr<i32>>) ->
/// !fir.ptr<i32>
/// %dev_ptr = acc.use_device varPtr(%addr : !fir.ptr<i32>) ->
/// !fir.ptr<i32>
/// -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "ptr"}
/// acc.host_data dataOperands(%dev_ptr : !fir.ref<!fir.box<!fir.ptr<i32>>>)
/// {
/// %embox = fir.embox %dev_ptr : (!fir.ptr<i32>) ->
/// !fir.box<!fir.ptr<i32>> %alloca = fir.alloca !fir.box<!fir.ptr<i32>>
/// fir.store %embox to %alloca : !fir.ref<!fir.box<!fir.ptr<i32>>>
/// %loaded2 = fir.load %alloca : !fir.ref<!fir.box<!fir.ptr<i32>>>
/// %addr2 = fir.box_addr %loaded2 : (!fir.box<!fir.ptr<i32>>) ->
/// !fir.ptr<i32> %conv = fir.convert %addr2 : (!fir.ptr<i32>) -> i64
/// fir.call @foo(%conv) : (i64) -> ()
/// acc.terminator
/// }
bool hoistRefToBox(PatternRewriter &rewriter, Value operand,
acc::UseDeviceOp useDeviceOp,
acc::HostDataOp hostDataOp) const {
// Safety check: if the use_device operation is already using a box_addr
// result, it means it has already been processed, so skip to avoid infinite
// loop
if (useDeviceOp.getVar().getDefiningOp<fir::BoxAddrOp>()) {
LLVM_DEBUG(llvm::dbgs() << "ACCUseDeviceCanonicalizer: Skipping "
"already processed use_device operation\n");
return false;
}
// Get the ModuleOp before we erase useDeviceOp to avoid invalid reference
ModuleOp mod = useDeviceOp->getParentOfType<ModuleOp>();
// Collect users of the original `acc.use_device` operation that need to be
// updated
SmallVector<Operation *> usersToUpdate;
collectUseDeviceUsersToUpdate(useDeviceOp, hostDataOp, usersToUpdate);
rewriter.setInsertionPoint(useDeviceOp);
// Create a load operation to get the box from the variable
fir::LoadOp box = fir::LoadOp::create(rewriter, useDeviceOp.getLoc(),
useDeviceOp.getVar());
// Create a box_addr operation to get the address from the box
fir::BoxAddrOp boxAddr =
fir::BoxAddrOp::create(rewriter, useDeviceOp.getLoc(), box);
acc::UseDeviceOp newUseDeviceOp =
createNewUseDeviceOp(rewriter, useDeviceOp, hostDataOp, boxAddr);
LLVM_DEBUG(llvm::dbgs()
<< "Created new hoisted pattern for pointer access:\n"
<< " load box: " << *box << "\n"
<< " box_addr: " << *boxAddr << "\n"
<< " new use_device: " << *newUseDeviceOp << "\n");
// Set insertion point to the first op inside the host_data region
rewriter.setInsertionPoint(&hostDataOp.getRegion().front().front());
// Create a FirOpBuilder from the PatternRewriter using the module we got
// earlier
fir::FirOpBuilder builder(rewriter, mod);
Value newBoxwithDevicePtr = fir::factory::getDescriptorWithNewBaseAddress(
builder, useDeviceOp.getLoc(), box.getResult(),
newUseDeviceOp.getResult());
// Create new memory location and store the newBoxwithDevicePtr into new
// memory location
fir::AllocaOp newMemLoc = fir::AllocaOp::create(
rewriter, useDeviceOp.getLoc(), newBoxwithDevicePtr.getType());
[[maybe_unused]] fir::StoreOp newStoreOp = fir::StoreOp::create(
rewriter, useDeviceOp.getLoc(), newBoxwithDevicePtr, newMemLoc);
LLVM_DEBUG(llvm::dbgs()
<< "host_data region updated with new host descriptor "
"containing device pointer:\n"
<< " box with device pointer: "
<< *newBoxwithDevicePtr.getDefiningOp() << "\n"
<< " mem loc: " << *newMemLoc << "\n"
<< " store op: " << *newStoreOp << "\n");
// Replace all uses of the original `acc.use_device` operation inside the
// `acc.host_data` region with the new memory location containing the box
// with device pointer
for (mlir::Operation *user : usersToUpdate)
user->replaceUsesOfWith(useDeviceOp.getResult(), newMemLoc);
assert(useDeviceOp.getResult().use_empty() &&
"expected all uses of use_device to be replaced");
rewriter.eraseOp(useDeviceOp);
return true;
}
/// Canonicalize use_device operand that is a box type.
/// Transforms:
/// %box = ... : !fir.box<!fir.array<?xi32>>
/// %dev_box = acc.use_device varPtr(%box : !fir.box<!fir.array<?xi32>>)
/// -> !fir.box<!fir.array<?xi32>>
/// acc.host_data dataOperands(%dev_box : !fir.box<!fir.array<?xi32>>) {
/// %addr = fir.box_addr %dev_box : (!fir.box<!fir.array<?xi32>>) ->
/// !fir.heap<!fir.array<?xi32>>
/// // use %addr
/// }
/// into:
/// %box = ... : !fir.box<!fir.array<?xi32>>
/// %addr = fir.box_addr %box : (!fir.box<!fir.array<?xi32>>) ->
/// !fir.heap<!fir.array<?xi32>>
/// %dev_ptr = acc.use_device varPtr(%addr : !fir.heap<!fir.array<?xi32>>)
/// -> !fir.heap<!fir.array<?xi32>>
/// acc.host_data dataOperands(%dev_ptr : !fir.heap<!fir.array<?xi32>>) {
/// %new_box = fir.embox %dev_ptr ... : !fir.box<!fir.array<?xi32>>
/// %new_addr = fir.box_addr %new_box : (!fir.box<!fir.array<?xi32>>) ->
/// !fir.heap<!fir.array<?xi32>>
/// // use %new_addr instead of %addr
/// }
bool hoistBox(PatternRewriter &rewriter, Value operand,
acc::UseDeviceOp useDeviceOp,
acc::HostDataOp hostDataOp) const {
// Safety check: if the use_device operation is already using a box_addr
// result, it means it has already been processed, so skip to avoid infinite
// loop
if (useDeviceOp.getVar().getDefiningOp<fir::BoxAddrOp>()) {
LLVM_DEBUG(llvm::dbgs()
<< "ACCUseDeviceCanonicalizer: Skipping "
"already processed box use_device operation\n");
return false;
}
// Collect users of the original `acc.use_device` operation that need to be
// updated
SmallVector<Operation *> usersToUpdate;
collectUseDeviceUsersToUpdate(useDeviceOp, hostDataOp, usersToUpdate);
// Get the ModuleOp before we erase useDeviceOp to avoid invalid reference
ModuleOp mod = useDeviceOp->getParentOfType<ModuleOp>();
rewriter.setInsertionPoint(useDeviceOp);
// Extract the raw pointer from the box descriptor
fir::BoxAddrOp boxAddr = fir::BoxAddrOp::create(
rewriter, useDeviceOp.getLoc(), useDeviceOp.getVar());
acc::UseDeviceOp newUseDeviceOp =
createNewUseDeviceOp(rewriter, useDeviceOp, hostDataOp, boxAddr);
// Set insertion point to the first op inside the host_data region
rewriter.setInsertionPoint(&hostDataOp.getRegion().front().front());
// Create a FirOpBuilder from the PatternRewriter using the module we got
// earlier
fir::FirOpBuilder builder(rewriter, mod);
// Create a new host descriptor at the start of the host_data region
// with the device pointer as the base address
Value newBoxWithDevicePtr = fir::factory::getDescriptorWithNewBaseAddress(
builder, useDeviceOp.getLoc(), useDeviceOp.getVar(),
newUseDeviceOp.getResult());
LLVM_DEBUG(llvm::dbgs()
<< "host_data region updated with new host descriptor "
"containing device pointer:\n"
<< " box with device pointer: "
<< *newBoxWithDevicePtr.getDefiningOp() << "\n");
// Replace all uses of the original `acc.use_device` operation inside the
// `acc.host_data` region with the new box containing device pointer
for (mlir::Operation *user : usersToUpdate)
user->replaceUsesOfWith(useDeviceOp.getResult(), newBoxWithDevicePtr);
assert(useDeviceOp.getResult().use_empty() &&
"expected all uses of use_device to be replaced");
rewriter.eraseOp(useDeviceOp);
return true;
}
};
class ACCUseDeviceCanonicalizer
: public fir::acc::impl::ACCUseDeviceCanonicalizerBase<
ACCUseDeviceCanonicalizer> {
public:
void runOnOperation() override {
MLIRContext *context = getOperation()->getContext();
RewritePatternSet patterns(context);
// Add the custom use_device canonicalization patterns
patterns.insert<UseDeviceHostDataHoisting>(context);
// Apply patterns greedily
GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks.
config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Disabled);
config.setUseTopDownTraversal(true);
(void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
}
};
} // namespace
std::unique_ptr<mlir::Pass> fir::acc::createACCUseDeviceCanonicalizerPass() {
return std::make_unique<ACCUseDeviceCanonicalizer>();
}