blob: c864556783a2a91de88f0b8eff2ccdebb17c0143 [file] [log] [blame] [edit]
//===-- SPIRVRegularizer.cpp - regularize IR for SPIR-V ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass implements regularization of LLVM IR for SPIR-V. The prototype of
// the pass was taken from SPIRV-LLVM translator.
//
//===----------------------------------------------------------------------===//
#include "SPIRV.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/PassManager.h"
#include <list>
#define DEBUG_TYPE "spirv-regularizer"
using namespace llvm;
namespace {
struct SPIRVRegularizer : public FunctionPass {
public:
static char ID;
SPIRVRegularizer() : FunctionPass(ID) {}
bool runOnFunction(Function &F) override;
StringRef getPassName() const override { return "SPIR-V Regularizer"; }
void getAnalysisUsage(AnalysisUsage &AU) const override {
FunctionPass::getAnalysisUsage(AU);
}
private:
void runLowerConstExpr(Function &F);
void runLowerI1Comparisons(Function &F);
};
} // namespace
char SPIRVRegularizer::ID = 0;
INITIALIZE_PASS(SPIRVRegularizer, DEBUG_TYPE, "SPIR-V Regularizer", false,
false)
// Since SPIR-V cannot represent constant expression, constant expressions
// in LLVM IR need to be lowered to instructions. For each function,
// the constant expressions used by instructions of the function are replaced
// by instructions placed in the entry block since it dominates all other BBs.
// Each constant expression only needs to be lowered once in each function
// and all uses of it by instructions in that function are replaced by
// one instruction.
// TODO: remove redundant instructions for common subexpression.
void SPIRVRegularizer::runLowerConstExpr(Function &F) {
LLVMContext &Ctx = F.getContext();
std::list<Instruction *> WorkList;
for (auto &II : instructions(F))
WorkList.push_back(&II);
auto FBegin = F.begin();
while (!WorkList.empty()) {
Instruction *II = WorkList.front();
auto LowerOp = [&II, &FBegin, &F](Value *V) -> Value * {
if (isa<Function>(V))
return V;
auto *CE = cast<ConstantExpr>(V);
LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] " << *CE);
auto ReplInst = CE->getAsInstruction();
auto InsPoint = II->getParent() == &*FBegin ? II : &FBegin->back();
ReplInst->insertBefore(InsPoint->getIterator());
LLVM_DEBUG(dbgs() << " -> " << *ReplInst << '\n');
std::vector<Instruction *> Users;
// Do not replace use during iteration of use. Do it in another loop.
for (auto U : CE->users()) {
LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] Use: " << *U << '\n');
auto InstUser = dyn_cast<Instruction>(U);
// Only replace users in scope of current function.
if (InstUser && InstUser->getParent()->getParent() == &F)
Users.push_back(InstUser);
}
for (auto &User : Users) {
if (ReplInst->getParent() == User->getParent() &&
User->comesBefore(ReplInst))
ReplInst->moveBefore(User->getIterator());
User->replaceUsesOfWith(CE, ReplInst);
}
return ReplInst;
};
WorkList.pop_front();
auto LowerConstantVec = [&II, &LowerOp, &WorkList,
&Ctx](ConstantVector *Vec,
unsigned NumOfOp) -> Value * {
if (std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) {
return isa<ConstantExpr>(V) || isa<Function>(V);
})) {
// Expand a vector of constexprs and construct it back with
// series of insertelement instructions.
std::list<Value *> OpList;
std::transform(Vec->op_begin(), Vec->op_end(),
std::back_inserter(OpList),
[LowerOp](Value *V) { return LowerOp(V); });
Value *Repl = nullptr;
unsigned Idx = 0;
auto *PhiII = dyn_cast<PHINode>(II);
Instruction *InsPoint =
PhiII ? &PhiII->getIncomingBlock(NumOfOp)->back() : II;
std::list<Instruction *> ReplList;
for (auto V : OpList) {
if (auto *Inst = dyn_cast<Instruction>(V))
ReplList.push_back(Inst);
Repl = InsertElementInst::Create(
(Repl ? Repl : PoisonValue::get(Vec->getType())), V,
ConstantInt::get(Type::getInt32Ty(Ctx), Idx++), "",
InsPoint->getIterator());
}
WorkList.splice(WorkList.begin(), ReplList);
return Repl;
}
return nullptr;
};
for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) {
auto *Op = II->getOperand(OI);
if (auto *Vec = dyn_cast<ConstantVector>(Op)) {
Value *ReplInst = LowerConstantVec(Vec, OI);
if (ReplInst)
II->replaceUsesOfWith(Op, ReplInst);
} else if (auto CE = dyn_cast<ConstantExpr>(Op)) {
WorkList.push_front(cast<Instruction>(LowerOp(CE)));
} else if (auto MDAsVal = dyn_cast<MetadataAsValue>(Op)) {
auto ConstMD = dyn_cast<ConstantAsMetadata>(MDAsVal->getMetadata());
if (!ConstMD)
continue;
Constant *C = ConstMD->getValue();
Value *ReplInst = nullptr;
if (auto *Vec = dyn_cast<ConstantVector>(C))
ReplInst = LowerConstantVec(Vec, OI);
if (auto *CE = dyn_cast<ConstantExpr>(C))
ReplInst = LowerOp(CE);
if (!ReplInst)
continue;
Metadata *RepMD = ValueAsMetadata::get(ReplInst);
Value *RepMDVal = MetadataAsValue::get(Ctx, RepMD);
II->setOperand(OI, RepMDVal);
WorkList.push_front(cast<Instruction>(ReplInst));
}
}
}
}
// Lower i1 comparisons with certain predicates to logical operations.
// The backend treats i1 as boolean values, and SPIR-V only allows logical
// operations for boolean values. This function lowers i1 comparisons with
// certain predicates to logical operations to generate valid SPIR-V.
void SPIRVRegularizer::runLowerI1Comparisons(Function &F) {
for (auto &I : make_early_inc_range(instructions(F))) {
auto *Cmp = dyn_cast<ICmpInst>(&I);
if (!Cmp)
continue;
bool IsI1 = Cmp->getOperand(0)->getType()->isIntegerTy(1);
if (!IsI1)
continue;
auto Pred = Cmp->getPredicate();
bool IsTargetPred =
Pred >= ICmpInst::ICMP_UGT && Pred <= ICmpInst::ICMP_SLE;
if (!IsTargetPred)
continue;
Value *P = Cmp->getOperand(0);
Value *Q = Cmp->getOperand(1);
IRBuilder<> Builder(Cmp);
Value *Result = nullptr;
switch (Pred) {
case ICmpInst::ICMP_UGT:
case ICmpInst::ICMP_SLT:
// Result = p & !q
Result = Builder.CreateAnd(P, Builder.CreateNot(Q));
break;
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_SGT:
// Result = q & !p
Result = Builder.CreateAnd(Q, Builder.CreateNot(P));
break;
case ICmpInst::ICMP_ULE:
case ICmpInst::ICMP_SGE:
// Result = q | !p
Result = Builder.CreateOr(Q, Builder.CreateNot(P));
break;
case ICmpInst::ICMP_UGE:
case ICmpInst::ICMP_SLE:
// Result = p | !q
Result = Builder.CreateOr(P, Builder.CreateNot(Q));
break;
default:
llvm_unreachable("Unexpected predicate");
}
Result->takeName(Cmp);
Cmp->replaceAllUsesWith(Result);
Cmp->eraseFromParent();
}
}
bool SPIRVRegularizer::runOnFunction(Function &F) {
runLowerI1Comparisons(F);
runLowerConstExpr(F);
return true;
}
FunctionPass *llvm::createSPIRVRegularizerPass() {
return new SPIRVRegularizer();
}