//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store instructions.
/// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
/// provides simple operations on x86_amx; basic elementwise operations are
/// not supported by AMX. Since x86_amx is bitcast from vector <256 x i32> and
/// only AMX intrinsics can operate on the type, we need to transform
/// <256 x i32> load/store instructions into AMX load/store intrinsics. If the
/// bitcast cannot be combined with the load/store, we transform the bitcast
/// into an AMX load/store plus a <256 x i32> store/load.
///
/// If the front end does not use O0 but the mid/back end does (e.g. "Clang
/// -O2 -S -emit-llvm t.c" + "llc t.ll"), we must make the AMX data volatile,
/// because that is necessary for AMX fast register allocation. (In fast
/// register allocation, registers are allocated before spill/reload, so there
/// is no additional register for AMX to use in the spill/reload steps.)
/// volatileTileData() handles this case.
/// e.g.
/// ----------------------------------------------------------
/// | def %td = ... |
/// | ... |
/// | "use %td" |
/// ----------------------------------------------------------
/// will transfer to -->
/// ----------------------------------------------------------
/// | def %td = ... |
/// | call void @llvm.x86.tilestored64.internal(mem, %td) |
/// | ... |
/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
/// | "use %td2" |
/// ----------------------------------------------------------
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;
using namespace PatternMatch;
#define DEBUG_TYPE "lower-amx-type"
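// Return true if \p II is one of the two AMX cast intrinsics
// (x86_cast_vector_to_tile or x86_cast_tile_to_vector).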
static bool isAMXCast(Instruction *II) {
return match(II,
m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
}
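// Create an alloca of type \p Ty in the entry block of \p BB's function,
// aligned to the preferred alignment of the x86_amx type.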
static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
Type *Ty) {
Function &F = *BB->getParent();
Module *M = BB->getModule();
const DataLayout &DL = M->getDataLayout();
LLVMContext &Ctx = Builder.getContext();
auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
unsigned AllocaAS = DL.getAllocaAddrSpace();
AllocaInst *AllocaRes =
new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
AllocaRes->setAlignment(AllocaAlignment);
return AllocaRes;
}
static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
for (Instruction &I : F.getEntryBlock())
if (!isa<AllocaInst>(&I))
return &I;
llvm_unreachable("No terminator in the entry block!");
}
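// Return the tile shape {Row, Col} for operand \p OpNo of the AMX intrinsic
// \p II, derived from the intrinsic's explicit shape arguments.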
static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
IRBuilder<> Builder(II);
Value *Row = nullptr, *Col = nullptr;
switch (II->getIntrinsicID()) {
default:
llvm_unreachable("Expect amx intrinsics");
case Intrinsic::x86_tileloadd64_internal:
case Intrinsic::x86_tileloaddt164_internal:
case Intrinsic::x86_tilestored64_internal: {
Row = II->getArgOperand(0);
Col = II->getArgOperand(1);
break;
}
// a * b + c
// The shape depends on which operand the tile is: operand 3 is the
// accumulator c with shape {arg0, arg1}, operand 4 is a with shape
// {arg0, arg2}, and operand 5 is b with shape {arg2/4, arg1}.
case Intrinsic::x86_tdpbssd_internal:
case Intrinsic::x86_tdpbsud_internal:
case Intrinsic::x86_tdpbusd_internal:
case Intrinsic::x86_tdpbuud_internal:
case Intrinsic::x86_tdpbf16ps_internal: {
switch (OpNo) {
case 3:
Row = II->getArgOperand(0);
Col = II->getArgOperand(1);
break;
case 4:
Row = II->getArgOperand(0);
Col = II->getArgOperand(2);
break;
case 5:
if (isa<ConstantInt>(II->getArgOperand(2)))
Row = Builder.getInt16(
(cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
else if (isa<Instruction>(II->getArgOperand(2))) {
// When it is not a const value and it is not a function argument, we
// create Row after the definition of II->getOperand(2) instead of
// before II. For example, II is %118, we try to getshape for %117:
// %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
// i32> %115).
// %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
// %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
// %117).
// If we create %row = udiv i16 %106, 4 before %118 (a.k.a. II), then its
// definition is after its user (the new tileload for %117).
// So, the best choice is to create %row right after the definition of
// %106.
Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
} else {
// When it is not a const value and it is a function argument, we create
// Row at the entry bb.
IRBuilder<> NewBuilder(
getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
}
Col = II->getArgOperand(1);
break;
}
break;
}
}
return std::make_pair(Row, Col);
}
namespace {
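// Rewrite bitcasts between <256 x i32> and x86_amx: combine them with the
// adjacent vector load/store into AMX tile load/store intrinsics when
// possible, otherwise lower the bitcast through a stack slot (see
// transformBitcast).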
class X86LowerAMXType {
Function &Func;
// In AMX intrinsics we let Shape = {Row, Col}, but the real number of
// columns is Col / ElementSize. We may use this real column count as a
// new Row for other newly created AMX intrinsics.
std::map<Value *, Value *> Col2Row;
public:
X86LowerAMXType(Function &F) : Func(F) {}
bool visit();
void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
bool transformBitcast(BitCastInst *Bitcast);
};
// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
// i8* %addr, i64 %stride64)
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
Value *Row = nullptr, *Col = nullptr;
Use &U = *(Bitcast->use_begin());
unsigned OpNo = U.getOperandNo();
auto *II = cast<IntrinsicInst>(U.getUser());
std::tie(Row, Col) = getShape(II, OpNo);
IRBuilder<> Builder(Bitcast);
// Use the maximum column size (64 bytes) as the stride.
Value *Stride = Builder.getInt64(64);
Value *I8Ptr =
Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
Value *NewInst =
Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
Bitcast->replaceAllUsesWith(NewInst);
}
// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
// %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
// %stride64, %13)
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
Value *Tile = Bitcast->getOperand(0);
auto *II = cast<IntrinsicInst>(Tile);
// Tile is output from AMX intrinsic. The first operand of the
// intrinsic is row, the second operand of the intrinsic is column.
Value *Row = II->getOperand(0);
Value *Col = II->getOperand(1);
IRBuilder<> Builder(ST);
// Use the maximum column size (64 bytes) as the stride. It must be the
// same as the load stride.
Value *Stride = Builder.getInt64(64);
Value *I8Ptr =
Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
if (Bitcast->hasOneUse())
return;
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// %add = <256 x i32> %13, <256 x i32> %src2
// -->
// %13 = bitcast x86_amx %src to <256 x i32>
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
// %stride64, %13)
// %14 = load <256 x i32>, %addr
// %add = <256 x i32> %14, <256 x i32> %src2
Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
Bitcast->replaceAllUsesWith(Vec);
}
// Transform a bitcast into a <store, load> pair through a stack slot.
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
IRBuilder<> Builder(Bitcast);
AllocaInst *AllocaAddr;
Value *I8Ptr, *Stride;
auto *Src = Bitcast->getOperand(0);
auto Prepare = [&](Type *MemTy) {
AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
Stride = Builder.getInt64(64);
};
if (Bitcast->getType()->isX86_AMXTy()) {
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %addr = alloca <256 x i32>, align 64
// store <256 x i32> %src, <256 x i32>* %addr, align 64
// %addr2 = bitcast <256 x i32>* to i8*
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
// i8* %addr2,
// i64 64)
Use &U = *(Bitcast->use_begin());
unsigned OpNo = U.getOperandNo();
auto *II = dyn_cast<IntrinsicInst>(U.getUser());
if (!II)
return false; // May be bitcast from x86amx to <256 x i32>.
Prepare(Bitcast->getOperand(0)->getType());
Builder.CreateStore(Src, AllocaAddr);
// TODO: we can pick a constant operand for the shape.
Value *Row = nullptr, *Col = nullptr;
std::tie(Row, Col) = getShape(II, OpNo);
std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
Value *NewInst = Builder.CreateIntrinsic(
Intrinsic::x86_tileloadd64_internal, None, Args);
Bitcast->replaceAllUsesWith(NewInst);
} else {
// %2 = bitcast x86_amx %src to <256 x i32>
// -->
// %addr = alloca <256 x i32>, align 64
// %addr2 = bitcast <256 x i32>* to i8*
// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
// i8* %addr2, i64 %stride)
// %2 = load <256 x i32>, <256 x i32>* %addr, align 64
auto *II = dyn_cast<IntrinsicInst>(Src);
if (!II)
return false; // May be bitcast from <256 x i32> to x86amx.
Prepare(Bitcast->getType());
Value *Row = II->getOperand(0);
Value *Col = II->getOperand(1);
std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
Bitcast->replaceAllUsesWith(NewInst);
}
return true;
}
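// Walk the function and lower every bitcast that produces or consumes
// x86_amx, collecting the dead instructions and erasing them at the end.
// Returns true if any change was made.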
bool X86LowerAMXType::visit() {
SmallVector<Instruction *, 8> DeadInsts;
Col2Row.clear();
for (BasicBlock *BB : post_order(&Func)) {
for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
if (!Bitcast)
continue;
Value *Src = Bitcast->getOperand(0);
if (Bitcast->getType()->isX86_AMXTy()) {
if (Bitcast->user_empty()) {
DeadInsts.push_back(Bitcast);
continue;
}
LoadInst *LD = dyn_cast<LoadInst>(Src);
if (!LD) {
if (transformBitcast(Bitcast))
DeadInsts.push_back(Bitcast);
continue;
}
// If the load has multiple users, the vector load is kept and the data is
// loaded again by a new AMX tile load for the bitcast user.
// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// %add = add <256 x i32> %src, <256 x i32> %src2
// -->
// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
// i8* %addr, i64 %stride64)
// %add = add <256 x i32> %src, <256 x i32> %src2
// If the load has a single user, the load itself is also eliminated, since
// it becomes dead once the bitcast is removed.
// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
// i8* %addr, i64 %stride64)
combineLoadBitcast(LD, Bitcast);
DeadInsts.push_back(Bitcast);
if (LD->hasOneUse())
DeadInsts.push_back(LD);
} else if (Src->getType()->isX86_AMXTy()) {
if (Bitcast->user_empty()) {
DeadInsts.push_back(Bitcast);
continue;
}
StoreInst *ST = nullptr;
for (Use &U : Bitcast->uses()) {
ST = dyn_cast<StoreInst>(U.getUser());
if (ST)
break;
}
if (!ST) {
if (transformBitcast(Bitcast))
DeadInsts.push_back(Bitcast);
continue;
}
// If the bitcast (%13) has a single use, combine the bitcast and the store
// into an AMX store.
// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
// %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
// %stride64, %13)
//
// If the bitcast (%13) has multiple uses, transform it as below.
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// %add = <256 x i32> %13, <256 x i32> %src2
// -->
// %13 = bitcast x86_amx %src to <256 x i32>
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
// %stride64, %13)
// %14 = load <256 x i32>, %addr
// %add = <256 x i32> %14, <256 x i32> %src2
//
combineBitcastStore(Bitcast, ST);
// Delete user first.
DeadInsts.push_back(ST);
DeadInsts.push_back(Bitcast);
}
}
}
bool C = !DeadInsts.empty();
for (auto *Inst : DeadInsts)
Inst->eraseFromParent();
return C;
}
} // anonymous namespace
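// Allocate a <256 x i32> stack slot in the entry block and return its address
// as an i8* for use by the tile load/store intrinsics.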
static Value *getAllocaPos(BasicBlock *BB) {
Module *M = BB->getModule();
Function *F = BB->getParent();
IRBuilder<> Builder(&F->getEntryBlock().front());
const DataLayout &DL = M->getDataLayout();
unsigned AllocaAS = DL.getAllocaAddrSpace();
Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
AllocaInst *AllocaRes =
new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
BasicBlock::iterator Iter = AllocaRes->getIterator();
++Iter;
Builder.SetInsertPoint(&*Iter);
Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy());
return I8Ptr;
}
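// Store the tile defined by \p TileDef to \p Ptr right after its definition,
// reusing the shape operands of the defining intrinsic.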
static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
auto *II = cast<IntrinsicInst>(TileDef);
assert(II && "Not tile intrinsic!");
Value *Row = II->getOperand(0);
Value *Col = II->getOperand(1);
BasicBlock *BB = TileDef->getParent();
BasicBlock::iterator Iter = TileDef->getIterator();
IRBuilder<> Builder(BB, ++Iter);
Value *Stride = Builder.getInt64(64);
std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};
Instruction *TileStore =
Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
return TileStore;
}
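// Replace the tile use \p U with a tile load from \p Ptr inserted right
// before the user. For PHI users the shape is taken from the PHI's first
// incoming value.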
static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
Value *V = U.get();
assert(V->getType()->isX86_AMXTy() && "Not define tile!");
// Get tile shape.
IntrinsicInst *II = nullptr;
if (IsPHI) {
Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0);
II = cast<IntrinsicInst>(PhiOp);
} else {
II = cast<IntrinsicInst>(V);
}
Value *Row = II->getOperand(0);
Value *Col = II->getOperand(1);
Instruction *UserI = dyn_cast<Instruction>(U.getUser());
IRBuilder<> Builder(UserI);
Value *Stride = Builder.getInt64(64);
std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
Value *TileLoad =
Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
UserI->replaceUsesOfWith(V, TileLoad);
}
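// Return true if \p I is used as an incoming value of some PHI node.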
static bool isIncomingOfPHI(Instruction *I) {
for (Use &U : I->uses()) {
User *V = U.getUser();
if (isa<PHINode>(V))
return true;
}
return false;
}
// Make all AMX tile data volatile so that the live range of each tile
// register is shortened before fast register allocation.
namespace {
class X86VolatileTileData {
Function &F;
public:
X86VolatileTileData(Function &Func) : F(Func) {}
Value *updatePhiIncomings(BasicBlock *BB,
SmallVector<Instruction *, 2> &Incomings);
void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
bool volatileTileData();
void volatileTilePHI(PHINode *Inst);
void volatileTileNonPHI(Instruction *I);
};
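// Store each incoming tile of a PHI to a common stack slot right after its
// definition and make its non-PHI uses reload from that slot. Returns the
// slot address as an i8*.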
Value *X86VolatileTileData::updatePhiIncomings(
BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
Value *I8Ptr = getAllocaPos(BB);
for (auto *I : Incomings) {
User *Store = createTileStore(I, I8Ptr);
// All its uses (except phi) should load from stored mem.
for (Use &U : I->uses()) {
User *V = U.getUser();
if (isa<PHINode>(V) || V == Store)
continue;
replaceWithTileLoad(U, I8Ptr);
}
}
return I8Ptr;
}
void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
Value *StorePtr) {
for (Use &U : PHI->uses())
replaceWithTileLoad(U, StorePtr, true);
PHI->eraseFromParent();
}
// Similar to volatileTileNonPHI, but this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def is changed to a tileload.
// 2) Each PHI incoming value is tilestored right after its def.
// 3) These tileloads and tilestores use the same memory.
// e.g.
// ------------------------------------------------------
// bb_dom:
// ...
// br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
// def %t0 = ...
// ...
// use %t0
// ...
// br label %if.end
//
// if.else:
// def %t1 = ...
// br label %if.end
//
// if.end:
// %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
// ...
// use %td
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// bb_entry:
// %mem = alloca <256 x i32>, align 1024 *
// ...
// bb_dom:
// ...
// br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
// def %t0 = ...
// call void @llvm.x86.tilestored64.internal(mem, %t0) *
// ...
// %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
// use %t0` *
// ...
// br label %if.end
//
// if.else:
// def %t1 = ...
// call void @llvm.x86.tilestored64.internal(mem, %t1) *
// br label %if.end
//
// if.end:
// ...
// %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
// use %td
// ------------------------------------------------------
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
BasicBlock *BB = PHI->getParent();
SmallVector<Instruction *, 2> Incomings;
for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
Value *Op = PHI->getIncomingValue(I);
Instruction *Inst = dyn_cast<Instruction>(Op);
assert(Inst && "We shouldn't fold AMX instruction!");
Incomings.push_back(Inst);
}
Value *StorePtr = updatePhiIncomings(BB, Incomings);
replacePhiDefWithLoad(PHI, StorePtr);
}
// Store the defined tile and load it back before each use.
// None of its users is a PHI node.
// e.g.
// ------------------------------------------------------
// def %td = ...
// ...
// "use %td"
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// def %td = ...
// call void @llvm.x86.tilestored64.internal(mem, %td)
// ...
// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
// "use %td2"
// ------------------------------------------------------
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
BasicBlock *BB = I->getParent();
Value *I8Ptr = getAllocaPos(BB);
User *Store = createTileStore(I, I8Ptr);
// All its uses should load from stored mem.
for (Use &U : I->uses()) {
User *V = U.getUser();
assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
if (V != Store)
replaceWithTileLoad(U, I8Ptr);
}
}
// Volatile Tile Model:
// 1) Every use of tile data comes from a tileload issued right before it.
// 2) Every def of tile data is tilestored to memory immediately.
// For example:
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td) area
// --------------------------------------------------------------------------
// 3) There is no terminator, call, or other AMX instruction in the key amx
//    area.
bool X86VolatileTileData::volatileTileData() {
bool Changed = false;
for (BasicBlock &BB : F) {
SmallVector<Instruction *, 2> PHIInsts;
SmallVector<Instruction *, 8> AMXDefInsts;
for (Instruction &I : BB) {
if (!I.getType()->isX86_AMXTy())
continue;
if (isa<PHINode>(&I))
PHIInsts.push_back(&I);
else
AMXDefInsts.push_back(&I);
}
// First we "volatile" the non-phi related amx intrinsics.
for (Instruction *I : AMXDefInsts) {
if (isIncomingOfPHI(I))
continue;
volatileTileNonPHI(I);
Changed = true;
}
for (Instruction *I : PHIInsts) {
volatileTilePHI(dyn_cast<PHINode>(I));
Changed = true;
}
}
return Changed;
}
} // anonymous namespace
namespace {
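// Lower the AMX cast intrinsics (x86_cast_vector_to_tile and
// x86_cast_tile_to_vector): cancel out cast pairs, optimize casts across PHI
// webs, and lower any remaining casts through a stack slot.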
class X86LowerAMXCast {
Function &Func;
public:
X86LowerAMXCast(Function &F) : Func(F) {}
bool combineAMXcast(TargetLibraryInfo *TLI);
bool transformAMXCast(IntrinsicInst *AMXCast);
bool transformAllAMXCast();
bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
SmallSetVector<Instruction *, 16> &DeadInst);
};
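// Erase \p I if it is trivially dead, and queue any operands that become
// trivially dead as a result on \p WorkList. Returns true if \p I was erased.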
static bool DCEInstruction(Instruction *I,
SmallSetVector<Instruction *, 16> &WorkList,
const TargetLibraryInfo *TLI) {
if (isInstructionTriviallyDead(I, TLI)) {
salvageDebugInfo(*I);
salvageKnowledge(I);
// Null out all of the instruction's operands to see if any operand becomes
// dead as we go.
for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
Value *OpV = I->getOperand(i);
I->setOperand(i, nullptr);
if (!OpV->use_empty() || I == OpV)
continue;
// If the operand is an instruction that became dead as we nulled out the
// operand, and if it is 'trivially' dead, delete it in a future loop
// iteration.
if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
if (isInstructionTriviallyDead(OpI, TLI)) {
WorkList.insert(OpI);
}
}
}
I->eraseFromParent();
return true;
}
return false;
}
/// This function handles the following case:
///
/// A -> B amxcast
/// PHI
/// B -> A amxcast
///
/// All the related PHI nodes can be replaced by new PHI nodes with type A.
/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
bool X86LowerAMXCast::optimizeAMXCastFromPhi(
IntrinsicInst *CI, PHINode *PN,
SmallSetVector<Instruction *, 16> &DeadInst) {
IRBuilder<> Builder(CI);
Value *Src = CI->getOperand(0);
Type *SrcTy = Src->getType(); // Type B
Type *DestTy = CI->getType(); // Type A
SmallVector<PHINode *, 4> PhiWorklist;
SmallSetVector<PHINode *, 4> OldPhiNodes;
// Find all of the A->B casts and PHI nodes.
// We need to inspect all related PHI nodes, but PHIs can be cyclic, so
// OldPhiNodes is used to track all known PHI nodes; before adding a new
// PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
PhiWorklist.push_back(PN);
OldPhiNodes.insert(PN);
while (!PhiWorklist.empty()) {
auto *OldPN = PhiWorklist.pop_back_val();
for (Value *IncValue : OldPN->incoming_values()) {
// TODO: Currently we ignore cases where the incoming value is a
// constant. In the future, we might support constants.
if (isa<Constant>(IncValue))
return false;
if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
if (OldPhiNodes.insert(PNode))
PhiWorklist.push_back(PNode);
continue;
}
Instruction *ACI = dyn_cast<Instruction>(IncValue);
if (ACI && isAMXCast(ACI)) {
// Verify it's an A->B cast.
Type *TyA = ACI->getOperand(0)->getType();
Type *TyB = ACI->getType();
if (TyA != DestTy || TyB != SrcTy)
return false;
continue;
}
return false;
}
}
// Check that each user of each old PHI node is something that we can
// rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
for (auto *OldPN : OldPhiNodes) {
for (User *V : OldPN->users()) {
Instruction *ACI = dyn_cast<Instruction>(V);
if (ACI && isAMXCast(ACI)) {
// Verify it's a B->A cast.
Type *TyB = ACI->getOperand(0)->getType();
Type *TyA = ACI->getType();
if (TyA != DestTy || TyB != SrcTy)
return false;
} else if (auto *PHI = dyn_cast<PHINode>(V)) {
// As long as the user is another old PHI node, then even if we don't
// rewrite it, the PHI web we're considering won't have any users
// outside itself, so it'll be dead.
// example:
// bb.0:
// %0 = amxcast ...
// bb.1:
// %1 = amxcast ...
// bb.2:
// %goodphi = phi %0, %1
// %3 = amxcast %goodphi
// bb.3:
// %goodphi2 = phi %0, %goodphi
// %4 = amxcast %goodphi2
// When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
// outside the phi-web, so the combination stops. When
// optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
// will be done.
if (OldPhiNodes.count(PHI) == 0)
return false;
} else
return false;
}
}
// For each old PHI node, create a corresponding new PHI node with type A.
SmallDenseMap<PHINode *, PHINode *> NewPNodes;
for (auto *OldPN : OldPhiNodes) {
Builder.SetInsertPoint(OldPN);
PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
NewPNodes[OldPN] = NewPN;
}
// Fill in the operands of new PHI nodes.
for (auto *OldPN : OldPhiNodes) {
PHINode *NewPN = NewPNodes[OldPN];
for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
Value *V = OldPN->getOperand(j);
Value *NewV = nullptr;
Instruction *ACI = dyn_cast<Instruction>(V);
// There should not be an AMX cast from a constant.
if (ACI && isAMXCast(ACI))
NewV = ACI->getOperand(0);
else if (auto *PrevPN = dyn_cast<PHINode>(V))
NewV = NewPNodes[PrevPN];
assert(NewV);
NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
}
}
// Traverse all accumulated PHI nodes and process their users, which are
// AMX casts and other old PHI nodes. Without this processing, new PHI
// nodes could be replicated and could lead to extra moves generated after
// DeSSA.
// Replace users of the B->A casts with the new PHI. This helps later to
// get rid of the closure formed by the old PHI nodes.
for (auto *OldPN : OldPhiNodes) {
PHINode *NewPN = NewPNodes[OldPN];
for (User *V : make_early_inc_range(OldPN->users())) {
Instruction *ACI = dyn_cast<Instruction>(V);
if (ACI && isAMXCast(ACI)) {
Type *TyB = ACI->getOperand(0)->getType();
Type *TyA = ACI->getType();
assert(TyA == DestTy && TyB == SrcTy);
(void)TyA;
(void)TyB;
ACI->replaceAllUsesWith(NewPN);
DeadInst.insert(ACI);
} else if (auto *PHI = dyn_cast<PHINode>(V)) {
// We don't need to push the PHI nodes into DeadInst since they are
// operands of rootPN; DCE can safely delete rootPN's operands if rootPN
// is dead.
assert(OldPhiNodes.contains(PHI));
(void)PHI;
} else
llvm_unreachable("all uses should be handled");
}
}
return true;
}
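// Fold vector->tile->vector (and tile->vector->tile) cast pairs so that the
// second cast's users take the original value directly, erase the casts that
// become dead, optimize casts that go through PHI nodes, and finally run DCE
// on everything that died along the way. Returns true on change.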
bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
bool Change = false;
// Collect the tile cast instructions.
SmallVector<Instruction *, 8> Vec2TileInsts;
SmallVector<Instruction *, 8> Tile2VecInsts;
SmallVector<Instruction *, 8> PhiCastWorkList;
SmallSetVector<Instruction *, 16> DeadInst;
for (BasicBlock &BB : Func) {
for (Instruction &I : BB) {
Value *Vec;
if (match(&I,
m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
Vec2TileInsts.push_back(&I);
else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
m_Value(Vec))))
Tile2VecInsts.push_back(&I);
}
}
auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
for (auto *Inst : Insts) {
for (User *U : Inst->users()) {
IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
if (!II || II->getIntrinsicID() != IID)
continue;
// T1 = vec2tile V0
// V2 = tile2vec T1
// V3 = OP V2
// -->
// T1 = vec2tile V0
// V2 = tile2vec T1
// V3 = OP V0
II->replaceAllUsesWith(Inst->getOperand(0));
Change = true;
}
}
};
Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
for (auto *Inst : Insts) {
if (Inst->use_empty()) {
Inst->eraseFromParent();
Change = true;
}
}
};
EraseInst(Vec2TileInsts);
EraseInst(Tile2VecInsts);
// Handle A->B->A cast chains that have an intervening PHI node.
for (BasicBlock &BB : Func) {
for (Instruction &I : BB) {
if (isAMXCast(&I)) {
if (isa<PHINode>(I.getOperand(0)))
PhiCastWorkList.push_back(&I);
}
}
}
for (auto *I : PhiCastWorkList) {
// Skip dead AMX casts.
if (DeadInst.contains(I))
continue;
PHINode *PN = cast<PHINode>(I->getOperand(0));
if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
DeadInst.insert(PN);
Change = true;
}
}
// Since we create new PHIs and merge AMX casts, some old PHIs and AMX casts
// might have no uses. Run dead code elimination on them.
while (!DeadInst.empty()) {
Instruction *I = DeadInst.pop_back_val();
Change |= DCEInstruction(I, DeadInst, TLI);
}
return Change;
}
// There might be AMX casts remaining after combineAMXcast, and they should be
// handled elegantly.
bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
IRBuilder<> Builder(AMXCast);
AllocaInst *AllocaAddr;
Value *I8Ptr, *Stride;
auto *Src = AMXCast->getOperand(0);
auto Prepare = [&](Type *MemTy) {
AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
Stride = Builder.getInt64(64);
};
if (AMXCast->getType()->isX86_AMXTy()) {
// %2 = amxcast <225 x i32> %src to x86_amx
// call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
// i8* %addr3, i64 60, x86_amx %2)
// -->
// %addr = alloca <225 x i32>, align 64
// store <225 x i32> %src, <225 x i32>* %addr, align 64
// %addr2 = bitcast <225 x i32>* %addr to i8*
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
// i8* %addr2,
// i64 60)
// call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
// i8* %addr3, i64 60, x86_amx %2)
Use &U = *(AMXCast->use_begin());
unsigned OpNo = U.getOperandNo();
auto *II = dyn_cast<IntrinsicInst>(U.getUser());
if (!II)
return false; // May be bitcast from x86amx to <256 x i32>.
Prepare(AMXCast->getOperand(0)->getType());
Builder.CreateStore(Src, AllocaAddr);
// TODO: we can pick a constant operand for the shape.
Value *Row = nullptr, *Col = nullptr;
std::tie(Row, Col) = getShape(II, OpNo);
std::array<Value *, 4> Args = {
Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
Value *NewInst = Builder.CreateIntrinsic(
Intrinsic::x86_tileloadd64_internal, None, Args);
AMXCast->replaceAllUsesWith(NewInst);
AMXCast->eraseFromParent();
} else {
// %2 = amxcast x86_amx %src to <225 x i32>
// -->
// %addr = alloca <225 x i32>, align 64
// %addr2 = bitcast <225 x i32>* to i8*
// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
// i8* %addr2, i64 %stride)
// %2 = load <225 x i32>, <225 x i32>* %addr, align 64
auto *II = dyn_cast<IntrinsicInst>(Src);
if (!II)
return false; // May be bitcast from <256 x i32> to x86amx.
Prepare(AMXCast->getType());
Value *Row = II->getOperand(0);
Value *Col = II->getOperand(1);
std::array<Value *, 5> Args = {
Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
AMXCast->replaceAllUsesWith(NewInst);
AMXCast->eraseFromParent();
}
return true;
}
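// Lower every remaining AMX cast in the function via transformAMXCast.
// Returns true on change.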
bool X86LowerAMXCast::transformAllAMXCast() {
bool Change = false;
// Collect the tile cast instructions.
SmallVector<Instruction *, 8> WorkLists;
for (BasicBlock &BB : Func) {
for (Instruction &I : BB) {
if (isAMXCast(&I))
WorkLists.push_back(&I);
}
}
for (auto *Inst : WorkLists) {
Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
}
return Change;
}
} // anonymous namespace
namespace {
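// Legacy-PM wrapper that runs the AMX cast combining/lowering, the
// <256 x i32> bitcast lowering, and, at CodeGenOpt::None, the volatile tile
// transformation. It is typically added by the X86 target to its codegen IR
// pipeline via createX86LowerAMXTypePass() below.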
class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
static char ID;
X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
}
bool runOnFunction(Function &F) override {
bool C = false;
TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
TargetLibraryInfo *TLI =
&getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
X86LowerAMXCast LAC(F);
C |= LAC.combineAMXcast(TLI);
// There might be AMX casts remaining after combineAMXcast, and they
// should be handled elegantly.
C |= LAC.transformAllAMXCast();
X86LowerAMXType LAT(F);
C |= LAT.visit();
// Prepare for fast register allocation at O0.
// TODO: It may be better to check the volatile model of the AMX code, not
// just Attribute::OptimizeNone and CodeGenOpt::None.
if (TM->getOptLevel() == CodeGenOpt::None) {
// If the front end does not use O0 but the mid/back end does (e.g.
// "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the
// AMX data is volatile; that is necessary for AMX fast register
// allocation.
if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
X86VolatileTileData VTD(F);
C = VTD.volatileTileData() || C;
}
}
return C;
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesCFG();
AU.addRequired<TargetPassConfig>();
AU.addRequired<TargetLibraryInfoWrapperPass>();
}
};
} // anonymous namespace
static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
false)
FunctionPass *llvm::createX86LowerAMXTypePass() {
return new X86LowerAMXTypeLegacyPass();
}