blob: b04003b5090fa5e41c22ebd0c481996ba82922dd [file] [log] [blame] [edit]
//===- MatrixUtils.cpp - Utilities to lower matrix intrinsics ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Utilities for generating tiled loops for matrix operations.
//
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Utils/MatrixUtils.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/CommandLine.h"
using namespace llvm;
namespace llvm {
// Boolean command-line option that is only declared here; its definition is
// not in this file. Within this file it is read by TileInfo::CreateLoop, which
// attaches branch weights to the loop latch branch only when the option is
// false.
extern cl::opt<bool> ProfcheckDisableMetadataFixes;
} // end namespace llvm
// Builds one counted loop between Preheader and Exit and returns its body
// block.
//
// Three blocks named "<Name>.header", "<Name>.body" and "<Name>.latch" are
// created in Preheader's function, placed in front of Exit:
//   - the header holds the i64 induction variable PHI "<Name>.iv", which
//     starts at 0, and branches unconditionally to the body;
//   - the body branches unconditionally to the latch;
//   - the latch computes "<Name>.step" = iv + Step and branches back to the
//     header while that value is not equal to Bound, otherwise to Exit.
// The exit test is an inequality, so the loop is only left once the
// incremented induction variable is exactly equal to Bound.
//
// Preheader must end in an unconditional branch; that branch is redirected to
// the new header. The CFG changes are reported to DTU and the new blocks are
// added to L (and recorded in LI). B's insertion point is left in the latch.
BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
                                 ConstantInt *Bound, ConstantInt *Step,
                                 StringRef Name, IRBuilderBase &B,
                                 DomTreeUpdater &DTU, Loop *L, LoopInfo &LI) {
  LLVMContext &Ctx = Preheader->getContext();

  // Inserting every block before Exit keeps them in header/body/latch order.
  BasicBlock *Header =
      BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit);
  BasicBlock *Body =
      BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit);
  BasicBlock *Latch =
      BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit);

  // The induction variable is 64 bits wide.
  Type *I64Ty = Type::getInt64Ty(Ctx);

  // Straight-line edges: header -> body -> latch.
  UncondBrInst::Create(Body, Header);
  UncondBrInst::Create(Latch, Body);

  // Place the PHI ahead of the header's branch so it is the first instruction
  // of the block.
  PHINode *IV = PHINode::Create(I64Ty, 2, Name + ".iv",
                                Header->getTerminator()->getIterator());
  IV->addIncoming(ConstantInt::get(I64Ty, 0), Preheader);

  // Latch: advance the induction variable and test for exit.
  B.SetInsertPoint(Latch);
  Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
  Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
  auto *BR = B.CreateCondBr(Cond, Header, Exit);
  if (!ProfcheckDisableMetadataFixes) {
    assert(Step->getZExtValue() != 0 &&
           "Expected a non-zero step size. This is chosen by the pass and "
           "should always be non-zero to imply a finite loop.");
    // Weight the back edge (first successor) by Bound / Step against a weight
    // of 1 for the exit edge.
    setFittedBranchWeights(
        *BR, {Bound->getZExtValue() / Step->getZExtValue(), 1},
        /*IsExpected=*/false);
  }
  IV->addIncoming(Inc, Latch);

  // Redirect the preheader into the new loop, remembering its old target so
  // the removed edge can be reported to the dominator tree updater.
  UncondBrInst *PreheaderBr = cast<UncondBrInst>(Preheader->getTerminator());
  BasicBlock *Tmp = PreheaderBr->getSuccessor();
  PreheaderBr->setSuccessor(0, Header);
  DTU.applyUpdatesPermissive({
      {DominatorTree::Delete, Preheader, Tmp},
      {DominatorTree::Insert, Header, Body},
      {DominatorTree::Insert, Body, Latch},
      {DominatorTree::Insert, Latch, Header},
      {DominatorTree::Insert, Latch, Exit},
      {DominatorTree::Insert, Preheader, Header},
  });

  // Register the new blocks with the loop structure.
  L->addBasicBlockToLoop(Header, LI);
  L->addBasicBlockToLoop(Body, LI);
  L->addBasicBlockToLoop(Latch, LI);
  return Body;
}
// Emits a three-deep loop nest between Start and End and returns the body
// block of the innermost loop. From outermost to innermost:
//   "cols"  : bound NumColumns, step TileSize
//   "rows"  : bound NumRows,    step TileSize
//   "inner" : bound NumInner,   step TileSize
// For every level the header block, latch block and induction variable are
// stored in ColumnLoop, RowLoop and KLoop respectively. The new loops are
// registered with LI, nested inside the loop that contains Start when there
// is one and as a top-level loop otherwise.
BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
                                       IRBuilderBase &B, DomTreeUpdater &DTU,
                                       LoopInfo &LI) {
  // Allocate the Loop objects and establish their nesting up front.
  Loop *OuterL = LI.AllocateLoop();
  Loop *MiddleL = LI.AllocateLoop();
  Loop *InnermostL = LI.AllocateLoop();
  MiddleL->addChildLoop(InnermostL);
  OuterL->addChildLoop(MiddleL);

  Loop *Enclosing = LI.getLoopFor(Start);
  if (!Enclosing)
    LI.addTopLevelLoop(OuterL);
  else
    Enclosing->addChildLoop(OuterL);

  // All three loops advance by the same tile size.
  ConstantInt *TileStep = B.getInt64(TileSize);

  // Column loop. The latch has to be read before the next CreateLoop call
  // redirects this body's branch into the nested loop's header.
  BasicBlock *ColsBody = CreateLoop(Start, End, B.getInt64(NumColumns),
                                    TileStep, "cols", B, DTU, OuterL, LI);
  ColumnLoop.Latch = ColsBody->getSingleSuccessor();
  ColumnLoop.Header = ColsBody->getSinglePredecessor();

  // Row loop, nested in the column loop's body and exiting to its latch.
  BasicBlock *RowsBody =
      CreateLoop(ColsBody, ColumnLoop.Latch, B.getInt64(NumRows), TileStep,
                 "rows", B, DTU, MiddleL, LI);
  RowLoop.Latch = RowsBody->getSingleSuccessor();
  RowLoop.Header = RowsBody->getSinglePredecessor();

  // Inner (K) loop, nested in the row loop's body and exiting to its latch.
  BasicBlock *KBody = CreateLoop(RowsBody, RowLoop.Latch, B.getInt64(NumInner),
                                 TileStep, "inner", B, DTU, InnermostL, LI);
  KLoop.Latch = KBody->getSingleSuccessor();
  KLoop.Header = KBody->getSinglePredecessor();

  // The induction variable of each loop is the first instruction of its
  // header block.
  ColumnLoop.Index = &ColumnLoop.Header->front();
  RowLoop.Index = &RowLoop.Header->front();
  KLoop.Index = &KLoop.Header->front();

  return KBody;
}