| //===- MatrixUtils.cpp - Utilities to lower matrix intrinsics ---*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Utilities for generating tiled loops for matrix operations. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "llvm/Transforms/Utils/MatrixUtils.h" |
| #include "llvm/Analysis/DomTreeUpdater.h" |
| #include "llvm/Analysis/LoopInfo.h" |
| #include "llvm/IR/BasicBlock.h" |
| #include "llvm/IR/Dominators.h" |
| #include "llvm/IR/IRBuilder.h" |
| #include "llvm/IR/MDBuilder.h" |
| #include "llvm/IR/ProfDataUtils.h" |
| #include "llvm/IR/Type.h" |
| #include "llvm/Support/CommandLine.h" |
| |
| using namespace llvm; |
| |
namespace llvm {
// Command-line option defined in another translation unit. CreateLoop below
// consults it: when it is set, no branch weights are attached to the loop
// latch's conditional branch.
extern cl::opt<bool> ProfcheckDisableMetadataFixes;
} // end namespace llvm
| |
| BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit, |
| ConstantInt *Bound, ConstantInt *Step, |
| StringRef Name, IRBuilderBase &B, |
| DomTreeUpdater &DTU, Loop *L, LoopInfo &LI) { |
| LLVMContext &Ctx = Preheader->getContext(); |
| BasicBlock *Header = BasicBlock::Create( |
| Preheader->getContext(), Name + ".header", Preheader->getParent(), Exit); |
| BasicBlock *Body = BasicBlock::Create(Header->getContext(), Name + ".body", |
| Header->getParent(), Exit); |
| BasicBlock *Latch = BasicBlock::Create(Header->getContext(), Name + ".latch", |
| Header->getParent(), Exit); |
| |
| Type *I32Ty = Type::getInt64Ty(Ctx); |
| UncondBrInst::Create(Body, Header); |
| UncondBrInst::Create(Latch, Body); |
| PHINode *IV = |
| PHINode::Create(I32Ty, 2, Name + ".iv", Header->getTerminator()->getIterator()); |
| IV->addIncoming(ConstantInt::get(I32Ty, 0), Preheader); |
| |
| B.SetInsertPoint(Latch); |
| Value *Inc = B.CreateAdd(IV, Step, Name + ".step"); |
| Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond"); |
| auto *BR = B.CreateCondBr(Cond, Header, Exit); |
| if (!ProfcheckDisableMetadataFixes) { |
| assert(Step->getZExtValue() != 0 && |
| "Expected a non-zero step size. This is chosen by the pass and " |
| "should always be non-zero to imply a finite loop."); |
| MDBuilder MDB(Preheader->getContext()); |
| setFittedBranchWeights( |
| *BR, {Bound->getZExtValue() / Step->getZExtValue(), 1}, false); |
| } |
| IV->addIncoming(Inc, Latch); |
| |
| UncondBrInst *PreheaderBr = cast<UncondBrInst>(Preheader->getTerminator()); |
| BasicBlock *Tmp = PreheaderBr->getSuccessor(); |
| PreheaderBr->setSuccessor(0, Header); |
| DTU.applyUpdatesPermissive({ |
| {DominatorTree::Delete, Preheader, Tmp}, |
| {DominatorTree::Insert, Header, Body}, |
| {DominatorTree::Insert, Body, Latch}, |
| {DominatorTree::Insert, Latch, Header}, |
| {DominatorTree::Insert, Latch, Exit}, |
| {DominatorTree::Insert, Preheader, Header}, |
| }); |
| |
| L->addBasicBlockToLoop(Header, LI); |
| L->addBasicBlockToLoop(Body, LI); |
| L->addBasicBlockToLoop(Latch, LI); |
| return Body; |
| } |
| |
| // Creates the following loop nest skeleton: |
| // for C = 0; C < NumColumns; C += TileSize |
| // for R = 0; R < NumRows; R += TileSize |
| // for K = 0; K < Inner ; K += TileSize |
| BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End, |
| IRBuilderBase &B, DomTreeUpdater &DTU, |
| LoopInfo &LI) { |
| Loop *ColumnLoopInfo = LI.AllocateLoop(); |
| Loop *RowLoopInfo = LI.AllocateLoop(); |
| Loop *KLoopInfo = LI.AllocateLoop(); |
| RowLoopInfo->addChildLoop(KLoopInfo); |
| ColumnLoopInfo->addChildLoop(RowLoopInfo); |
| if (Loop *ParentL = LI.getLoopFor(Start)) |
| ParentL->addChildLoop(ColumnLoopInfo); |
| else |
| LI.addTopLevelLoop(ColumnLoopInfo); |
| |
| BasicBlock *ColBody = |
| CreateLoop(Start, End, B.getInt64(NumColumns), B.getInt64(TileSize), |
| "cols", B, DTU, ColumnLoopInfo, LI); |
| ColumnLoop.Latch = ColBody->getSingleSuccessor(); |
| BasicBlock *RowBody = |
| CreateLoop(ColBody, ColumnLoop.Latch, B.getInt64(NumRows), |
| B.getInt64(TileSize), "rows", B, DTU, RowLoopInfo, LI); |
| RowLoop.Latch = RowBody->getSingleSuccessor(); |
| |
| BasicBlock *InnerBody = |
| CreateLoop(RowBody, RowLoop.Latch, B.getInt64(NumInner), |
| B.getInt64(TileSize), "inner", B, DTU, KLoopInfo, LI); |
| KLoop.Latch = InnerBody->getSingleSuccessor(); |
| ColumnLoop.Header = ColBody->getSinglePredecessor(); |
| RowLoop.Header = RowBody->getSinglePredecessor(); |
| KLoop.Header = InnerBody->getSinglePredecessor(); |
| RowLoop.Index = &*RowLoop.Header->begin(); |
| ColumnLoop.Index = &*ColumnLoop.Header->begin(); |
| KLoop.Index = &*KLoop.Header->begin(); |
| |
| return InnerBody; |
| } |