blob: 190d9259e45b396266b9223db702c31b9fa60200 [file] [log] [blame] [edit]
//===- Utils.cpp - IR2Vec/MIR2Vec Embedding Generation Tool -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file implements the IR2VecTool and MIR2VecTool classes for
/// IR2Vec/MIR2Vec embedding generation.
///
//===----------------------------------------------------------------------===//
#include "Utils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Analysis/IR2Vec.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassInstrumentation.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Errc.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/CodeGen/MIR2Vec.h"
#include "llvm/CodeGen/MIRParser/MIRParser.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/Target/TargetMachine.h"
#define DEBUG_TYPE "ir2vec"
namespace llvm {
namespace ir2vec {
/// Loads the IR2Vec vocabulary from the file at \p VocabPath and stores it in
/// this tool.
///
/// \returns the error produced by Vocabulary::fromFile when loading fails, an
/// errc::invalid_argument error when the loaded vocabulary reports itself as
/// invalid, and Error::success() otherwise.
Error IR2VecTool::initializeVocabulary(StringRef VocabPath) {
  auto LoadedVocab = Vocabulary::fromFile(VocabPath);
  if (!LoadedVocab)
    return LoadedVocab.takeError();

  Vocab = std::make_unique<Vocabulary>(std::move(*LoadedVocab));
  if (Vocab->isValid())
    return Error::success();

  return createStringError(errc::invalid_argument,
                           "Failed to initialize IR2Vec vocabulary");
}
/// Builds the (head, tail, relation) triplets for a single function.
///
/// Every non-debug instruction contributes:
///  - a "Next" triplet from the previously visited instruction's opcode index
///    to its own (the previous opcode is carried across basic blocks),
///  - a "Type" triplet from its opcode index to the index of its result type,
///  - one "Arg" triplet per operand, whose relation is ArgRelation plus the
///    operand's position.
///
/// \returns the collected triplets together with the largest relation ID that
/// was used (never below NextRelation). A declaration yields a
/// default-constructed result.
TripletResult IR2VecTool::generateTriplets(const Function &F) const {
  // Declarations have no body to walk.
  if (F.isDeclaration())
    return {};

  TripletResult Out;
  unsigned HighestRelation = NextRelation;
  unsigned LastOpcode = 0;
  bool SeenInstruction = false;

  for (const BasicBlock &Block : F) {
    for (const auto &Inst : Block.instructionsWithoutDebug()) {
      unsigned CurOpcode = Vocabulary::getIndex(Inst.getOpcode());
      unsigned CurTypeID = Vocabulary::getIndex(Inst.getType()->getTypeID());

      // "Next": link the previously visited instruction to this one.
      if (SeenInstruction) {
        Out.Triplets.push_back({LastOpcode, CurOpcode, NextRelation});
        LLVM_DEBUG(dbgs() << Vocabulary::getVocabKeyForOpcode(LastOpcode + 1)
                          << '\t'
                          << Vocabulary::getVocabKeyForOpcode(CurOpcode + 1)
                          << '\t' << "Next\n");
      }

      // "Type": link the instruction to its result type.
      Out.Triplets.push_back({CurOpcode, CurTypeID, TypeRelation});
      LLVM_DEBUG(
          dbgs() << Vocabulary::getVocabKeyForOpcode(CurOpcode + 1) << '\t'
                 << Vocabulary::getVocabKeyForTypeID(
                        Inst.getType()->getTypeID())
                 << '\t' << "Type\n");

      // "Arg<N>": link the instruction to each operand, in operand order.
      unsigned OperandPos = 0;
      for (const Use &Operand : Inst.operands()) {
        unsigned OperandEntity = Vocabulary::getIndex(*Operand.get());
        unsigned ArgRel = ArgRelation + OperandPos;
        Out.Triplets.push_back({CurOpcode, OperandEntity, ArgRel});
        LLVM_DEBUG({
          StringRef OperandKey = Vocabulary::getVocabKeyForOperandKind(
              Vocabulary::getOperandKind(Operand.get()));
          dbgs() << Vocabulary::getVocabKeyForOpcode(CurOpcode + 1) << '\t'
                 << OperandKey << '\t' << "Arg" << OperandPos << '\n';
        });
        ++OperandPos;
      }

      // The last operand used relation ArgRelation + OperandPos - 1; an
      // instruction without operands leaves the maximum untouched.
      if (OperandPos != 0)
        HighestRelation =
            std::max(HighestRelation, ArgRelation + OperandPos - 1);

      LastOpcode = CurOpcode;
      SeenInstruction = true;
    }
  }

  Out.MaxRelation = HighestRelation;
  return Out;
}
/// Builds the triplets for every function definition in the module held by
/// this tool.
///
/// Per-function triplets are concatenated in module iteration order.
///
/// \returns the merged triplets and the largest relation ID seen across all
/// functions (NextRelation when nothing larger was produced).
TripletResult IR2VecTool::generateTriplets() const {
  TripletResult Merged;
  Merged.MaxRelation = NextRelation;

  for (const Function &Fn : M.getFunctionDefs()) {
    TripletResult PerFunction = generateTriplets(Fn);
    if (Merged.MaxRelation < PerFunction.MaxRelation)
      Merged.MaxRelation = PerFunction.MaxRelation;
    Merged.Triplets.insert(Merged.Triplets.end(), PerFunction.Triplets.begin(),
                           PerFunction.Triplets.end());
  }

  return Merged;
}
/// Writes the module's triplets to \p OS.
///
/// The first line is "MAX_RELATION=<n>"; each following line holds one
/// triplet as tab-separated head, tail and relation values.
void IR2VecTool::writeTripletsToStream(raw_ostream &OS) const {
  auto Generated = generateTriplets();

  OS << "MAX_RELATION=" << Generated.MaxRelation << '\n';
  for (const auto &Triplet : Generated.Triplets) {
    OS << Triplet.Head << '\t' << Triplet.Tail << '\t' << Triplet.Relation
       << '\n';
  }
}
/// Collects the string key of every entity ID in the range
/// [0, Vocabulary::getCanonicalSize()).
///
/// \returns the keys in ascending entity ID order, so that the position of a
/// key in the list equals its entity ID.
EntityList IR2VecTool::collectEntityMappings() {
  EntityList Keys;
  auto NumEntities = Vocabulary::getCanonicalSize();
  for (unsigned ID = 0; ID < NumEntities; ++ID) {
    Keys.push_back(Vocabulary::getStringKey(ID).str());
  }
  return Keys;
}
/// Writes the entity-to-ID mapping to \p OS.
///
/// The first line is the number of entities; each following line is
/// "<entity key>\t<entity ID>", in ascending ID order.
void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
  auto Keys = collectEntityMappings();
  OS << Keys.size() << "\n";

  unsigned ID = 0;
  for (const auto &Key : Keys) {
    OS << Key << '\t' << ID << '\n';
    ++ID;
  }
}
/// Writes the embeddings of every function definition in the module to
/// \p OS at the requested \p Level.
///
/// If the vocabulary has not been initialized, or is not valid, an error is
/// reported on stderr and nothing is written to \p OS.
void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
                                         EmbeddingLevel Level) const {
  // Vocab stays null until initializeVocabulary() succeeds, so it must be
  // checked before it is dereferenced; this matches the guard used by the
  // per-function overload.
  if (!Vocab || !Vocab->isValid()) {
    WithColor::error(errs(), ToolName)
        << "Vocabulary is not valid. IR2VecTool not initialized.\n";
    return;
  }

  for (const Function &F : M.getFunctionDefs())
    writeEmbeddingsToStream(F, OS, Level);
}
/// Writes the embeddings of the single function \p F to \p OS.
///
/// \p Level selects what is printed after the "Function: <name>" header: the
/// function vector, one vector per basic block (prefixed by the block name),
/// or one vector per instruction (prefixed by the printed instruction).
///
/// An error is reported on stderr, and nothing is written, when the
/// vocabulary is missing or invalid or when no embedder can be created. A
/// declaration only produces a "skipping" note on \p OS.
void IR2VecTool::writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
                                         EmbeddingLevel Level) const {
  const bool VocabReady = Vocab && Vocab->isValid();
  if (!VocabReady) {
    WithColor::error(errs(), ToolName)
        << "Vocabulary is not valid. IR2VecTool not initialized.\n";
    return;
  }

  if (F.isDeclaration()) {
    OS << "Function " << F.getName() << " is a declaration, skipping.\n";
    return;
  }

  // One embedder serves all levels for this function.
  auto FnEmbedder = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
  if (!FnEmbedder) {
    WithColor::error(errs(), ToolName)
        << "Failed to create embedder for function " << F.getName() << "\n";
    return;
  }

  OS << "Function: " << F.getName() << "\n";

  switch (Level) {
  case FunctionLevel: {
    FnEmbedder->getFunctionVector().print(OS);
    break;
  }
  case BasicBlockLevel: {
    for (const BasicBlock &Block : F) {
      OS << Block.getName() << ":";
      FnEmbedder->getBBVector(Block).print(OS);
    }
    break;
  }
  case InstructionLevel: {
    for (const Instruction &Inst : instructions(F)) {
      OS << Inst;
      FnEmbedder->getInstVector(Inst).print(OS);
    }
    break;
  }
  }
}
} // namespace ir2vec
namespace mir2vec {
/// Obtains the MIR2Vec vocabulary for \p M through a MIR2VecVocabProvider
/// built on this tool's MachineModuleInfo and stores it in this tool.
///
/// \returns true on success; on failure the error is reported on stderr and
/// false is returned.
bool MIR2VecTool::initializeVocabulary(const Module &M) {
  MIR2VecVocabProvider Provider(MMI);
  auto LoadedVocab = Provider.getVocabulary(M);

  if (LoadedVocab) {
    Vocab = std::make_unique<MIRVocabulary>(std::move(*LoadedVocab));
    return true;
  }

  WithColor::error(errs(), ToolName)
      << "Failed to load MIR2Vec vocabulary - "
      << toString(LoadedVocab.takeError()) << "\n";
  return false;
}
bool MIR2VecTool::initializeVocabularyForLayout(const Module &M) {
for (const Function &F : M.getFunctionDefs()) {
MachineFunction *MF = MMI.getMachineFunction(F);
if (!MF)
continue;
const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo();
const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo();
const MachineRegisterInfo &MRI = MF->getRegInfo();
auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1);
if (!VocabOrErr) {
WithColor::error(errs(), ToolName)
<< "Failed to create dummy vocabulary - "
<< toString(VocabOrErr.takeError()) << "\n";
return false;
}
Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
return true;
}
WithColor::error(errs(), ToolName)
<< "No machine functions found to initialize vocabulary\n";
return false;
}
/// Builds the (head, tail, relation) triplets for a single machine function.
///
/// Every non-debug machine instruction contributes a "Next" triplet from the
/// previously visited instruction's opcode entity to its own (carried across
/// machine basic blocks) and one "Arg" triplet per operand, whose relation is
/// MIRArgRelation plus the operand's position.
///
/// \returns the collected triplets and the largest relation ID used (never
/// below MIRNextRelation). When the vocabulary is not initialized, an error
/// is reported on stderr and a result without triplets is returned.
TripletResult MIR2VecTool::generateTriplets(const MachineFunction &MF) const {
  TripletResult Out;
  Out.MaxRelation = MIRNextRelation;

  if (!Vocab) {
    WithColor::error(errs(), ToolName)
        << "MIR Vocabulary must be initialized for triplet generation.\n";
    return Out;
  }

  unsigned LastOpcodeID = 0;
  bool SeenInstruction = false;

  for (const MachineBasicBlock &Block : MF) {
    for (const MachineInstr &Inst : Block) {
      // Debug instructions do not take part in the triplets.
      if (Inst.isDebugInstr())
        continue;

      unsigned CurOpcodeID = Vocab->getEntityIDForOpcode(Inst.getOpcode());

      // "Next": link the previously visited instruction to this one.
      if (SeenInstruction) {
        Out.Triplets.push_back({LastOpcodeID, CurOpcodeID, MIRNextRelation});
        LLVM_DEBUG(dbgs() << Vocab->getStringKey(LastOpcodeID) << '\t'
                          << Vocab->getStringKey(CurOpcodeID) << '\t'
                          << "Next\n");
      }

      // "Arg<N>": link the instruction to each operand, in operand order.
      unsigned OperandPos = 0;
      for (const MachineOperand &Operand : Inst.operands()) {
        auto OperandEntity = Vocab->getEntityIDForMachineOperand(Operand);
        unsigned ArgRel = MIRArgRelation + OperandPos;
        Out.Triplets.push_back({CurOpcodeID, OperandEntity, ArgRel});
        LLVM_DEBUG({
          std::string OperandKey = Vocab->getStringKey(OperandEntity);
          dbgs() << Vocab->getStringKey(CurOpcodeID) << '\t' << OperandKey
                 << '\t' << "Arg" << OperandPos << '\n';
        });
        ++OperandPos;
      }

      // The last operand used relation MIRArgRelation + OperandPos - 1; an
      // instruction without operands leaves the maximum untouched.
      if (OperandPos != 0)
        Out.MaxRelation =
            std::max(Out.MaxRelation, MIRArgRelation + OperandPos - 1);

      LastOpcodeID = CurOpcodeID;
      SeenInstruction = true;
    }
  }

  return Out;
}
/// Builds the triplets for every function definition in \p M that has a
/// MachineFunction.
///
/// Functions without a MachineFunction are skipped with a warning on stderr.
/// Per-function triplets are concatenated in module iteration order.
///
/// \returns the merged triplets and the largest relation ID seen across all
/// processed functions (MIRNextRelation when nothing larger was produced).
TripletResult MIR2VecTool::generateTriplets(const Module &M) const {
  TripletResult Merged;
  Merged.MaxRelation = MIRNextRelation;

  for (const Function &Fn : M.getFunctionDefs()) {
    if (MachineFunction *MachineFn = MMI.getMachineFunction(Fn)) {
      TripletResult PerFunction = generateTriplets(*MachineFn);
      if (Merged.MaxRelation < PerFunction.MaxRelation)
        Merged.MaxRelation = PerFunction.MaxRelation;
      Merged.Triplets.insert(Merged.Triplets.end(),
                             PerFunction.Triplets.begin(),
                             PerFunction.Triplets.end());
      continue;
    }

    WithColor::warning(errs(), ToolName)
        << "No MachineFunction for " << Fn.getName() << "\n";
  }

  return Merged;
}
/// Writes the triplets generated for \p M to \p OS.
///
/// The first line is "MAX_RELATION=<n>"; each following line holds one
/// triplet as tab-separated head, tail and relation values.
void MIR2VecTool::writeTripletsToStream(const Module &M,
                                        raw_ostream &OS) const {
  auto Generated = generateTriplets(M);

  OS << "MAX_RELATION=" << Generated.MaxRelation << '\n';
  for (const auto &Triplet : Generated.Triplets) {
    OS << Triplet.Head << '\t' << Triplet.Tail << '\t' << Triplet.Relation
       << '\n';
  }
}
/// Collects the string key of every entity ID in the range
/// [0, Vocab->getCanonicalSize()).
///
/// \returns the keys in ascending entity ID order, so that the position of a
/// key in the list equals its entity ID. When the vocabulary is not
/// initialized, an error is reported on stderr and an empty list is returned.
EntityList MIR2VecTool::collectEntityMappings() const {
  if (!Vocab) {
    WithColor::error(errs(), ToolName)
        << "Vocabulary must be initialized for entity mappings.\n";
    return {};
  }

  EntityList Keys;
  const unsigned NumEntities = Vocab->getCanonicalSize();
  for (unsigned ID = 0; ID < NumEntities; ++ID) {
    Keys.push_back(Vocab->getStringKey(ID));
  }
  return Keys;
}
/// Writes the entity-to-ID mapping to \p OS.
///
/// The first line is the number of entities; each following line is
/// "<entity key>\t<entity ID>", in ascending ID order. Nothing is written
/// when no entities could be collected.
void MIR2VecTool::writeEntitiesToStream(raw_ostream &OS) const {
  auto Keys = collectEntityMappings();
  if (!Keys.empty()) {
    OS << Keys.size() << "\n";

    unsigned ID = 0;
    for (const auto &Key : Keys) {
      OS << Key << '\t' << ID << '\n';
      ++ID;
    }
  }
}
/// Writes the embeddings of every function definition in \p M that has a
/// MachineFunction to \p OS at the requested \p Level.
///
/// Functions without a MachineFunction are skipped with a warning on stderr.
/// When the vocabulary is not initialized, an error is reported on stderr
/// and nothing is written.
void MIR2VecTool::writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
                                          EmbeddingLevel Level) const {
  if (!Vocab) {
    WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
    return;
  }

  for (const Function &Fn : M.getFunctionDefs()) {
    if (MachineFunction *MachineFn = MMI.getMachineFunction(Fn)) {
      writeEmbeddingsToStream(*MachineFn, OS, Level);
      continue;
    }

    WithColor::warning(errs(), ToolName)
        << "No MachineFunction for " << Fn.getName() << "\n";
  }
}
/// Writes the embeddings of the single machine function \p MF to \p OS.
///
/// After a header naming the machine function, \p Level selects what is
/// printed: the function vector, one vector per machine basic block, or one
/// vector per machine instruction.
///
/// An error is reported on stderr, and nothing is written, when the
/// vocabulary is not initialized or when no embedder can be created.
void MIR2VecTool::writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
                                          EmbeddingLevel Level) const {
  if (!Vocab) {
    WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
    return;
  }

  // One embedder serves all levels for this machine function.
  auto MFEmbedder = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
  if (!MFEmbedder) {
    WithColor::error(errs(), ToolName)
        << "Failed to create embedder for " << MF.getName() << "\n";
    return;
  }

  OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";

  switch (Level) {
  case FunctionLevel: {
    OS << "Function vector: ";
    MFEmbedder->getMFunctionVector().print(OS);
    break;
  }
  case BasicBlockLevel: {
    OS << "Basic block vectors:\n";
    for (const MachineBasicBlock &Block : MF) {
      OS << "MBB " << Block.getName() << ": ";
      MFEmbedder->getMBBVector(Block).print(OS);
    }
    break;
  }
  case InstructionLevel: {
    OS << "Instruction vectors:\n";
    for (const MachineBasicBlock &Block : MF)
      for (const MachineInstr &Inst : Block) {
        OS << Inst << " -> ";
        MFEmbedder->getMInstVector(Inst).print(OS);
      }
    break;
  }
  }
}
} // namespace mir2vec
} // namespace llvm