blob: e1e5fad13f413e32702842b30efe7ef79db296f1 [file] [log] [blame]
//===- llvm-ir2vec.cpp - IR2Vec Embedding Generation Tool -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file implements the IR2Vec embedding generation tool.
///
/// This tool provides two main functionalities:
///
/// 1. Triplet Generation Mode (--mode=triplets):
/// Generates triplets (opcode, type, operands) for vocabulary training.
/// Usage: llvm-ir2vec --mode=triplets input.bc -o triplets.txt
///
/// 2. Embedding Generation Mode (--mode=embeddings):
/// Generates IR2Vec embeddings using a trained vocabulary.
/// Usage: llvm-ir2vec --mode=embeddings --ir2vec-vocab-path=vocab.json
/// --level=func input.bc -o embeddings.txt Levels: --level=inst
/// (instructions), --level=bb (basic blocks), --level=func (functions)
/// (See IR2Vec.cpp for more embedding generation options)
///
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/IR2Vec.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassInstrumentation.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Errc.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "ir2vec"
namespace llvm {
namespace ir2vec {
static cl::OptionCategory IR2VecToolCategory("IR2Vec Tool Options");
static cl::opt<std::string>
InputFilename(cl::Positional,
cl::desc("<input bitcode file or '-' for stdin>"),
cl::init("-"), cl::cat(IR2VecToolCategory));
static cl::opt<std::string> OutputFilename("o", cl::desc("Output filename"),
cl::value_desc("filename"),
cl::init("-"),
cl::cat(IR2VecToolCategory));
enum ToolMode {
TripletMode, // Generate triplets for vocabulary training
EmbeddingMode // Generate embeddings using trained vocabulary
};
static cl::opt<ToolMode>
Mode("mode", cl::desc("Tool operation mode:"),
cl::values(clEnumValN(TripletMode, "triplets",
"Generate triplets for vocabulary training"),
clEnumValN(EmbeddingMode, "embeddings",
"Generate embeddings using trained vocabulary")),
cl::init(EmbeddingMode), cl::cat(IR2VecToolCategory));
static cl::opt<std::string>
FunctionName("function", cl::desc("Process specific function only"),
cl::value_desc("name"), cl::Optional, cl::init(""),
cl::cat(IR2VecToolCategory));
enum EmbeddingLevel {
InstructionLevel, // Generate instruction-level embeddings
BasicBlockLevel, // Generate basic block-level embeddings
FunctionLevel // Generate function-level embeddings
};
static cl::opt<EmbeddingLevel>
Level("level", cl::desc("Embedding generation level (for embedding mode):"),
cl::values(clEnumValN(InstructionLevel, "inst",
"Generate instruction-level embeddings"),
clEnumValN(BasicBlockLevel, "bb",
"Generate basic block-level embeddings"),
clEnumValN(FunctionLevel, "func",
"Generate function-level embeddings")),
cl::init(FunctionLevel), cl::cat(IR2VecToolCategory));
namespace {
/// Helper class for collecting IR triplets and generating embeddings
class IR2VecTool {
private:
Module &M;
ModuleAnalysisManager MAM;
const Vocabulary *Vocab = nullptr;
public:
explicit IR2VecTool(Module &M) : M(M) {}
/// Initialize the IR2Vec vocabulary analysis
bool initializeVocabulary() {
// Register and run the IR2Vec vocabulary analysis
// The vocabulary file path is specified via --ir2vec-vocab-path global
// option
MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
MAM.registerPass([&] { return IR2VecVocabAnalysis(); });
Vocab = &MAM.getResult<IR2VecVocabAnalysis>(M);
return Vocab->isValid();
}
/// Generate triplets for the entire module
void generateTriplets(raw_ostream &OS) const {
for (const Function &F : M)
generateTriplets(F, OS);
}
/// Generate triplets for a single function
void generateTriplets(const Function &F, raw_ostream &OS) const {
if (F.isDeclaration())
return;
std::string LocalOutput;
raw_string_ostream LocalOS(LocalOutput);
for (const BasicBlock &BB : F)
traverseBasicBlock(BB, LocalOS);
LocalOS.flush();
OS << LocalOutput;
}
/// Generate embeddings for the entire module
void generateEmbeddings(raw_ostream &OS) const {
if (!Vocab->isValid()) {
OS << "Error: Vocabulary is not valid. IR2VecTool not initialized.\n";
return;
}
for (const Function &F : M)
generateEmbeddings(F, OS);
}
/// Generate embeddings for a single function
void generateEmbeddings(const Function &F, raw_ostream &OS) const {
if (F.isDeclaration()) {
OS << "Function " << F.getName() << " is a declaration, skipping.\n";
return;
}
// Create embedder for this function
assert(Vocab->isValid() && "Vocabulary is not valid");
auto Emb = Embedder::create(IR2VecKind::Symbolic, F, *Vocab);
if (!Emb) {
OS << "Error: Failed to create embedder for function " << F.getName()
<< "\n";
return;
}
OS << "Function: " << F.getName() << "\n";
// Generate embeddings based on the specified level
switch (Level) {
case FunctionLevel: {
Emb->getFunctionVector().print(OS);
break;
}
case BasicBlockLevel: {
const auto &BBVecMap = Emb->getBBVecMap();
for (const BasicBlock &BB : F) {
auto It = BBVecMap.find(&BB);
if (It != BBVecMap.end()) {
OS << BB.getName() << ":";
It->second.print(OS);
}
}
break;
}
case InstructionLevel: {
const auto &InstMap = Emb->getInstVecMap();
for (const BasicBlock &BB : F) {
for (const Instruction &I : BB) {
auto It = InstMap.find(&I);
if (It != InstMap.end()) {
I.print(OS);
It->second.print(OS);
}
}
}
break;
}
}
}
private:
/// Process a single basic block for triplet generation
void traverseBasicBlock(const BasicBlock &BB, raw_string_ostream &OS) const {
// Consider only non-debug and non-pseudo instructions
for (const auto &I : BB.instructionsWithoutDebug()) {
StringRef OpcStr = Vocabulary::getVocabKeyForOpcode(I.getOpcode());
StringRef TypeStr =
Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID());
OS << '\n' << OpcStr << ' ' << TypeStr << ' ';
LLVM_DEBUG({
I.print(dbgs());
dbgs() << "\n";
I.getType()->print(dbgs());
dbgs() << " Type\n";
});
for (const Use &U : I.operands())
OS << Vocabulary::getVocabKeyForOperandKind(
Vocabulary::getOperandKind(U.get()))
<< ' ';
}
}
};
Error processModule(Module &M, raw_ostream &OS) {
IR2VecTool Tool(M);
if (Mode == EmbeddingMode) {
// Initialize vocabulary for embedding generation
// Note: Requires --ir2vec-vocab-path option to be set
if (!Tool.initializeVocabulary())
return createStringError(
errc::invalid_argument,
"Failed to initialize IR2Vec vocabulary. "
"Make sure to specify --ir2vec-vocab-path for embedding mode.");
if (!FunctionName.empty()) {
// Process single function
if (const Function *F = M.getFunction(FunctionName))
Tool.generateEmbeddings(*F, OS);
else
return createStringError(errc::invalid_argument,
"Function '%s' not found",
FunctionName.c_str());
} else {
// Process all functions
Tool.generateEmbeddings(OS);
}
} else {
// Triplet generation mode - no vocabulary needed
if (!FunctionName.empty())
// Process single function
if (const Function *F = M.getFunction(FunctionName))
Tool.generateTriplets(*F, OS);
else
return createStringError(errc::invalid_argument,
"Function '%s' not found",
FunctionName.c_str());
else
// Process all functions
Tool.generateTriplets(OS);
}
return Error::success();
}
} // namespace
} // namespace ir2vec
} // namespace llvm
int main(int argc, char **argv) {
using namespace llvm;
using namespace llvm::ir2vec;
InitLLVM X(argc, argv);
cl::HideUnrelatedOptions(IR2VecToolCategory);
cl::ParseCommandLineOptions(
argc, argv,
"IR2Vec - Embedding Generation Tool\n"
"Generates embeddings for a given LLVM IR and "
"supports triplet generation for vocabulary "
"training and embedding generation.\n\n"
"See https://llvm.org/docs/CommandGuide/llvm-ir2vec.html for more "
"information.\n");
// Validate command line options
if (Mode == TripletMode && Level.getNumOccurrences() > 0)
errs() << "Warning: --level option is ignored in triplet mode\n";
// Parse the input LLVM IR file or stdin
SMDiagnostic Err;
LLVMContext Context;
std::unique_ptr<Module> M = parseIRFile(InputFilename, Err, Context);
if (!M) {
Err.print(argv[0], errs());
return 1;
}
std::error_code EC;
raw_fd_ostream OS(OutputFilename, EC);
if (EC) {
errs() << "Error opening output file: " << EC.message() << "\n";
return 1;
}
if (Error Err = processModule(*M, OS)) {
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EIB) {
errs() << "Error: " << EIB.message() << "\n";
});
return 1;
}
return 0;
}