blob: 1b6eba94c5eac9da8575d4e685d0f4a4156232d4 [file] [log] [blame]
//====- LowerCIRLoopToSCF.cpp - Lowering from CIR Loop to SCF -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of CIR loop operations to SCF.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/IR/CIRTypes.h"
#include "clang/CIR/LowerToMLIR.h"
#include "clang/CIR/Passes.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace cir;
using namespace llvm;
namespace cir {
class SCFLoop {
public:
SCFLoop(mlir::cir::ForOp op, mlir::ConversionPatternRewriter *rewriter)
: forOp(op), rewriter(rewriter) {}
int64_t getStep() { return step; }
mlir::Value getLowerBound() { return lowerBound; }
mlir::Value getUpperBound() { return upperBound; }
int64_t findStepAndIV(mlir::Value &addr);
mlir::cir::CmpOp findCmpOp();
mlir::Value findIVInitValue();
void analysis();
mlir::Value plusConstant(mlir::Value V, mlir::Location loc, int addend);
void transferToSCFForOp();
private:
mlir::cir::ForOp forOp;
mlir::cir::CmpOp cmpOp;
mlir::Value IVAddr, lowerBound = nullptr, upperBound = nullptr;
mlir::ConversionPatternRewriter *rewriter;
int64_t step = 0;
};
class SCFWhileLoop {
public:
SCFWhileLoop(mlir::cir::WhileOp op, mlir::cir::WhileOp::Adaptor adaptor,
mlir::ConversionPatternRewriter *rewriter)
: whileOp(op), adaptor(adaptor), rewriter(rewriter) {}
void transferToSCFWhileOp();
private:
mlir::cir::WhileOp whileOp;
mlir::cir::WhileOp::Adaptor adaptor;
mlir::ConversionPatternRewriter *rewriter;
};
class SCFDoLoop {
public:
SCFDoLoop(mlir::cir::DoWhileOp op, mlir::cir::DoWhileOp::Adaptor adaptor,
mlir::ConversionPatternRewriter *rewriter)
: DoOp(op), adaptor(adaptor), rewriter(rewriter) {}
void transferToSCFWhileOp();
private:
mlir::cir::DoWhileOp DoOp;
mlir::cir::DoWhileOp::Adaptor adaptor;
mlir::ConversionPatternRewriter *rewriter;
};
static int64_t getConstant(mlir::cir::ConstantOp op) {
auto attr = op->getAttrs().front().getValue();
const auto IntAttr = mlir::dyn_cast<mlir::cir::IntAttr>(attr);
return IntAttr.getValue().getSExtValue();
}
int64_t SCFLoop::findStepAndIV(mlir::Value &addr) {
auto *stepBlock =
(forOp.maybeGetStep() ? &forOp.maybeGetStep()->front() : nullptr);
assert(stepBlock && "Can not find step block");
int64_t step = 0;
mlir::Value IV = nullptr;
// Try to match "IV load addr; ++IV; store IV, addr" to find step.
for (mlir::Operation &op : *stepBlock)
if (auto loadOp = dyn_cast<mlir::cir::LoadOp>(op)) {
addr = loadOp.getAddr();
IV = loadOp.getResult();
} else if (auto cop = dyn_cast<mlir::cir::ConstantOp>(op)) {
if (step)
llvm_unreachable(
"Not support multiple constant in step calculation yet");
step = getConstant(cop);
} else if (auto bop = dyn_cast<mlir::cir::BinOp>(op)) {
if (bop.getLhs() != IV)
llvm_unreachable("Find BinOp not operate on IV");
if (bop.getKind() != mlir::cir::BinOpKind::Add)
llvm_unreachable(
"Not support BinOp other than Add in step calculation yet");
} else if (auto uop = dyn_cast<mlir::cir::UnaryOp>(op)) {
if (uop.getInput() != IV)
llvm_unreachable("Find UnaryOp not operate on IV");
if (uop.getKind() == mlir::cir::UnaryOpKind::Inc)
step = 1;
else if (uop.getKind() == mlir::cir::UnaryOpKind::Dec)
llvm_unreachable("Not support decrement step yet");
} else if (auto storeOp = dyn_cast<mlir::cir::StoreOp>(op)) {
assert(storeOp.getAddr() == addr && "Can't find IV when lowering ForOp");
}
assert(step && "Can't find step when lowering ForOp");
return step;
}
static bool isIVLoad(mlir::Operation *op, mlir::Value IVAddr) {
if (!op)
return false;
if (isa<mlir::cir::LoadOp>(op)) {
if (!op->getOperand(0))
return false;
if (op->getOperand(0) == IVAddr)
return true;
}
return false;
}
mlir::cir::CmpOp SCFLoop::findCmpOp() {
cmpOp = nullptr;
for (auto *user : IVAddr.getUsers()) {
if (user->getParentRegion() != &forOp.getCond())
continue;
if (auto loadOp = dyn_cast<mlir::cir::LoadOp>(*user)) {
if (!loadOp->hasOneUse())
continue;
if (auto op = dyn_cast<mlir::cir::CmpOp>(*loadOp->user_begin())) {
cmpOp = op;
break;
}
}
}
if (!cmpOp)
llvm_unreachable("Can't find loop CmpOp");
auto type = cmpOp.getLhs().getType();
if (!mlir::isa<mlir::cir::IntType>(type))
llvm_unreachable("Non-integer type IV is not supported");
auto lhsDefOp = cmpOp.getLhs().getDefiningOp();
if (!lhsDefOp)
llvm_unreachable("Can't find IV load");
if (!isIVLoad(lhsDefOp, IVAddr))
llvm_unreachable("cmpOp LHS is not IV");
if (cmpOp.getKind() != mlir::cir::CmpOpKind::le &&
cmpOp.getKind() != mlir::cir::CmpOpKind::lt)
llvm_unreachable("Not support lowering other than le or lt comparison");
return cmpOp;
}
mlir::Value SCFLoop::plusConstant(mlir::Value V, mlir::Location loc,
int addend) {
auto type = V.getType();
auto c1 = rewriter->create<mlir::arith::ConstantOp>(
loc, type, mlir::IntegerAttr::get(type, addend));
return rewriter->create<mlir::arith::AddIOp>(loc, V, c1);
}
// Return IV initial value by searching the store before the loop.
// The operations before the loop have been transferred to MLIR.
// So we need to go through getRemappedValue to find the value.
mlir::Value SCFLoop::findIVInitValue() {
auto remapAddr = rewriter->getRemappedValue(IVAddr);
if (!remapAddr)
return nullptr;
if (!remapAddr.hasOneUse())
return nullptr;
auto memrefStore = dyn_cast<mlir::memref::StoreOp>(*remapAddr.user_begin());
if (!memrefStore)
return nullptr;
return memrefStore->getOperand(0);
}
void SCFLoop::analysis() {
step = findStepAndIV(IVAddr);
cmpOp = findCmpOp();
auto IVInit = findIVInitValue();
// The loop end value should be hoisted out of loop by -cir-mlir-scf-prepare.
// So we could get the value by getRemappedValue.
auto IVEndBound = rewriter->getRemappedValue(cmpOp.getRhs());
// If the loop end bound is not loop invariant and can't be hoisted.
// The following assertion will be triggerred.
assert(IVEndBound && "can't find IV end boundary");
if (step > 0) {
lowerBound = IVInit;
if (cmpOp.getKind() == mlir::cir::CmpOpKind::lt)
upperBound = IVEndBound;
else if (cmpOp.getKind() == mlir::cir::CmpOpKind::le)
upperBound = plusConstant(IVEndBound, cmpOp.getLoc(), 1);
}
assert(lowerBound && "can't find loop lower bound");
assert(upperBound && "can't find loop upper bound");
}
void SCFLoop::transferToSCFForOp() {
auto ub = getUpperBound();
auto lb = getLowerBound();
auto loc = forOp.getLoc();
auto type = lb.getType();
auto step = rewriter->create<mlir::arith::ConstantOp>(
loc, type, mlir::IntegerAttr::get(type, getStep()));
auto scfForOp = rewriter->create<mlir::scf::ForOp>(loc, lb, ub, step);
SmallVector<mlir::Value> bbArg;
rewriter->eraseOp(&scfForOp.getBody()->back());
rewriter->inlineBlockBefore(&forOp.getBody().front(), scfForOp.getBody(),
scfForOp.getBody()->end(), bbArg);
scfForOp->walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {
if (isa<mlir::cir::BreakOp>(op) || isa<mlir::cir::ContinueOp>(op) ||
isa<mlir::cir::IfOp>(op))
llvm_unreachable(
"Not support lowering loop with break, continue or if yet");
// Replace the IV usage to scf loop induction variable.
if (isIVLoad(op, IVAddr)) {
// Replace CIR IV load with arith.addi scf.IV, 0.
// The replacement makes the SCF IV can be automatically propogated
// by OpAdaptor for individual IV user lowering.
// The redundant arith.addi can be removed by later MLIR passes.
rewriter->setInsertionPoint(op);
auto newIV = plusConstant(scfForOp.getInductionVar(), loc, 0);
rewriter->replaceOp(op, newIV.getDefiningOp());
}
return mlir::WalkResult::advance();
});
}
void SCFWhileLoop::transferToSCFWhileOp() {
auto scfWhileOp = rewriter->create<mlir::scf::WhileOp>(
whileOp->getLoc(), whileOp->getResultTypes(), adaptor.getOperands());
rewriter->createBlock(&scfWhileOp.getBefore());
rewriter->createBlock(&scfWhileOp.getAfter());
rewriter->inlineBlockBefore(&whileOp.getCond().front(),
scfWhileOp.getBeforeBody(),
scfWhileOp.getBeforeBody()->end());
rewriter->inlineBlockBefore(&whileOp.getBody().front(),
scfWhileOp.getAfterBody(),
scfWhileOp.getAfterBody()->end());
}
void SCFDoLoop::transferToSCFWhileOp() {
auto beforeBuilder = [&](mlir::OpBuilder &builder, mlir::Location loc,
mlir::ValueRange args) {
auto *newBlock = builder.getBlock();
rewriter->mergeBlocks(&DoOp.getBody().front(), newBlock);
auto *yieldOp = newBlock->getTerminator();
rewriter->mergeBlocks(&DoOp.getCond().front(), newBlock,
yieldOp->getResults());
rewriter->eraseOp(yieldOp);
};
auto afterBuilder = [&](mlir::OpBuilder &builder, mlir::Location loc,
mlir::ValueRange args) {
rewriter->create<mlir::scf::YieldOp>(loc, args);
};
rewriter->create<mlir::scf::WhileOp>(DoOp.getLoc(), DoOp->getResultTypes(),
adaptor.getOperands(), beforeBuilder,
afterBuilder);
}
class CIRForOpLowering : public mlir::OpConversionPattern<mlir::cir::ForOp> {
public:
using OpConversionPattern<mlir::cir::ForOp>::OpConversionPattern;
mlir::LogicalResult
matchAndRewrite(mlir::cir::ForOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
SCFLoop loop(op, &rewriter);
loop.analysis();
loop.transferToSCFForOp();
rewriter.eraseOp(op);
return mlir::success();
}
};
class CIRWhileOpLowering
: public mlir::OpConversionPattern<mlir::cir::WhileOp> {
public:
using OpConversionPattern<mlir::cir::WhileOp>::OpConversionPattern;
mlir::LogicalResult
matchAndRewrite(mlir::cir::WhileOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
SCFWhileLoop loop(op, adaptor, &rewriter);
loop.transferToSCFWhileOp();
rewriter.eraseOp(op);
return mlir::success();
}
};
class CIRDoOpLowering : public mlir::OpConversionPattern<mlir::cir::DoWhileOp> {
public:
using OpConversionPattern<mlir::cir::DoWhileOp>::OpConversionPattern;
mlir::LogicalResult
matchAndRewrite(mlir::cir::DoWhileOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
SCFDoLoop loop(op, adaptor, &rewriter);
loop.transferToSCFWhileOp();
rewriter.eraseOp(op);
return mlir::success();
}
};
class CIRConditionOpLowering
: public mlir::OpConversionPattern<mlir::cir::ConditionOp> {
public:
using OpConversionPattern<mlir::cir::ConditionOp>::OpConversionPattern;
mlir::LogicalResult
matchAndRewrite(mlir::cir::ConditionOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto *parentOp = op->getParentOp();
return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(parentOp)
.Case<mlir::scf::WhileOp>([&](auto) {
auto condition = adaptor.getCondition();
auto i1Condition = rewriter.create<mlir::arith::TruncIOp>(
op.getLoc(), rewriter.getI1Type(), condition);
rewriter.replaceOpWithNewOp<mlir::scf::ConditionOp>(
op, i1Condition, parentOp->getOperands());
return mlir::success();
})
.Default([](auto) { return mlir::failure(); });
}
};
void populateCIRLoopToSCFConversionPatterns(mlir::RewritePatternSet &patterns,
mlir::TypeConverter &converter) {
patterns.add<CIRForOpLowering, CIRWhileOpLowering, CIRConditionOpLowering,
CIRDoOpLowering>(converter, patterns.getContext());
}
} // namespace cir