blob: b60488b2a1b6baa1f2715567aa9171447bced5c9 [file] [edit]
//===- llvm-ir2vec.cpp - IR2Vec/MIR2Vec Embedding Generation Tool --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file implements the IR2Vec and MIR2Vec embedding generation tool.
///
/// This tool supports two modes:
/// - LLVM IR mode (-mode=llvm): Process LLVM IR
/// - Machine IR mode (-mode=mir): Process Machine IR
///
/// Available subcommands:
///
/// 1. Triplet Generation (triplets):
/// Generates numeric triplets (head, tail, relation) for vocabulary
/// training. Output format: MAX_RELATION=N header followed by
/// head\ttail\trelation lines. Relations: 0=Type, 1=Next, 2+=Arg0,Arg1,...
///
/// For LLVM IR:
/// llvm-ir2vec triplets input.bc -o train2id.txt
///
/// For Machine IR:
/// llvm-ir2vec triplets -mode=mir input.mir -o train2id.txt
///
/// 2. Entity Mappings (entities):
/// Generates entity mappings for vocabulary training.
/// Output format: <total_entities> header followed by entity\tid lines.
///
/// For LLVM IR:
/// llvm-ir2vec entities input.bc -o entity2id.txt
///
/// For Machine IR:
/// llvm-ir2vec entities -mode=mir input.mir -o entity2id.txt
///
/// 3. Embedding Generation (embeddings):
/// Generates IR2Vec/MIR2Vec embeddings using a trained vocabulary.
///
/// For LLVM IR:
/// llvm-ir2vec embeddings --ir2vec-vocab-path=vocab.json
/// --ir2vec-kind=<kind> --level=<level> input.bc -o embeddings.txt
/// Kind: --ir2vec-kind=symbolic (default), --ir2vec-kind=flow-aware
///
/// For Machine IR:
/// llvm-ir2vec embeddings -mode=mir --mir2vec-vocab-path=vocab.json
/// --level=<level> input.mir -o embeddings.txt
///
/// Levels: --level=inst (instructions), --level=bb (basic blocks),
/// --level=func (functions) (See IR2Vec.cpp/MIR2Vec.cpp for more embedding
/// generation options)
///
//===----------------------------------------------------------------------===//
#include "llvm-ir2vec.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Analysis/IR2Vec.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassInstrumentation.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Errc.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/CodeGen/CommandFlags.h"
#include "llvm/CodeGen/MIR2Vec.h"
#include "llvm/CodeGen/MIRParser/MIRParser.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/WithColor.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/TargetParser/Host.h"
#define DEBUG_TYPE "ir2vec"
using namespace llvm;
using namespace llvm::ir2vec;
using namespace llvm::mir2vec;
// Common option category for options shared between IR2Vec and MIR2Vec.
// Every option declared in this file is attached to this category, and main()
// passes it to cl::HideUnrelatedOptions together with the IR2Vec and MIR2Vec
// categories.
static cl::OptionCategory CommonCategory("Common Options",
"Options applicable to both IR2Vec "
"and MIR2Vec modes");
/// Kind of input the tool operates on; selected through the -mode option.
enum IRKind {
LLVMIR = 0, ///< LLVM IR
MIR ///< Machine IR
};
/// -mode=<llvm|mir>: chooses between LLVM IR and Machine IR processing.
/// Initialized to LLVMIR, so LLVM IR is processed when -mode is not given.
static cl::opt<IRKind>
IRMode("mode", cl::desc("Tool operation mode:"),
cl::values(clEnumValN(LLVMIR, "llvm", "Process LLVM IR"),
clEnumValN(MIR, "mir", "Process Machine IR")),
cl::init(LLVMIR), cl::cat(CommonCategory));
// Subcommands
// Each object is tested in a boolean context by the processing code (for
// example `if (EmbeddingsSubCmd)`) to decide which action to perform.
// "triplets": emit (head, tail, relation) triplets.
static cl::SubCommand
TripletsSubCmd("triplets", "Generate triplets for vocabulary training");
// "entities": emit the entity-name to id mapping.
static cl::SubCommand
EntitiesSubCmd("entities",
"Generate entity mappings for vocabulary training");
// "embeddings": emit embeddings computed from a trained vocabulary.
static cl::SubCommand
EmbeddingsSubCmd("embeddings",
"Generate embeddings using trained vocabulary");
// Common options
// Positional input file, registered with all three subcommands. Initialized
// to "-", which the option description documents as stdin.
static cl::opt<std::string> InputFilename(
cl::Positional, cl::desc("<input bitcode/MIR file or '-' for stdin>"),
cl::init("-"), cl::sub(TripletsSubCmd), cl::sub(EntitiesSubCmd),
cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory));
// -o <filename>: destination of the generated output; initialized to "-".
// Unlike InputFilename, this option carries no explicit cl::sub() registration.
static cl::opt<std::string> OutputFilename("o", cl::desc("Output filename"),
cl::value_desc("filename"),
cl::init("-"),
cl::cat(CommonCategory));
// Embedding-specific options
// Both options below are registered only with the "embeddings" subcommand,
// although they are listed under CommonCategory.
// -function=<name>: restrict embedding generation to one function. The empty
// default makes the processing code handle every function in the module.
static cl::opt<std::string>
FunctionName("function", cl::desc("Process specific function only"),
cl::value_desc("name"), cl::Optional, cl::init(""),
cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory));
// -level=<inst|bb|func>: granularity of the emitted embeddings; initialized to
// function level.
static cl::opt<EmbeddingLevel>
Level("level", cl::desc("Embedding generation level:"),
cl::values(clEnumValN(InstructionLevel, "inst",
"Generate instruction-level embeddings"),
clEnumValN(BasicBlockLevel, "bb",
"Generate basic block-level embeddings"),
clEnumValN(FunctionLevel, "func",
"Generate function-level embeddings")),
cl::init(FunctionLevel), cl::sub(EmbeddingsSubCmd),
cl::cat(CommonCategory));
/// Loads the IR2Vec vocabulary for the module owned by this tool.
///
/// Registers the analyses needed by the module analysis manager, runs the
/// vocabulary analysis, and caches a pointer to its result in \c Vocab. The
/// vocabulary file location comes from the global --ir2vec-vocab-path option.
///
/// \returns true if the resulting vocabulary reports itself as valid.
bool IR2VecTool::initializeVocabulary() {
  MAM.registerPass([] { return PassInstrumentationAnalysis(); });
  MAM.registerPass([] { return IR2VecVocabAnalysis(); });

  // Running the analysis is expected to raise an error by itself when the
  // vocabulary cannot be found or is invalid.
  auto &VocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
  Vocab = &VocabResult;
  return VocabResult.isValid();
}
/// Builds the training triplets for a single function.
///
/// For every non-debug instruction this emits:
///  - a "Next" triplet linking the previously visited instruction's opcode to
///    the current one (the previous opcode is carried across basic blocks),
///  - a "Type" triplet linking the opcode to the instruction's type, and
///  - one "Arg<i>" triplet per operand, using relation ArgRelation + i.
///
/// \returns the triplets and the largest relation id used. A declaration
/// yields a value-initialized result.
TripletResult IR2VecTool::generateTriplets(const Function &F) const {
  if (F.isDeclaration())
    return {};

  TripletResult Out;
  unsigned HighestRelation = NextRelation;
  // Opcode index of the most recently visited instruction, if any.
  std::optional<unsigned> LastOpcode;

  for (const BasicBlock &Block : F) {
    for (const auto &Inst : Block.instructionsWithoutDebug()) {
      const unsigned CurOpcode = Vocabulary::getIndex(Inst.getOpcode());
      const unsigned CurTypeID =
          Vocabulary::getIndex(Inst.getType()->getTypeID());

      // "Next": previous instruction -> current instruction.
      if (LastOpcode) {
        Out.Triplets.push_back({*LastOpcode, CurOpcode, NextRelation});
        LLVM_DEBUG(dbgs() << Vocabulary::getVocabKeyForOpcode(*LastOpcode + 1)
                          << '\t'
                          << Vocabulary::getVocabKeyForOpcode(CurOpcode + 1)
                          << '\t' << "Next\n");
      }

      // "Type": instruction -> its result type.
      Out.Triplets.push_back({CurOpcode, CurTypeID, TypeRelation});
      LLVM_DEBUG(
          dbgs() << Vocabulary::getVocabKeyForOpcode(CurOpcode + 1) << '\t'
                 << Vocabulary::getVocabKeyForTypeID(
                        Inst.getType()->getTypeID())
                 << '\t' << "Type\n");

      // "Arg<i>": instruction -> i-th operand. The relation id grows with the
      // operand position, so the running maximum is tracked as we go.
      unsigned Position = 0;
      for (const Use &Operand : Inst.operands()) {
        const unsigned OperandID = Vocabulary::getIndex(*Operand.get());
        const unsigned RelationID = ArgRelation + Position;
        Out.Triplets.push_back({CurOpcode, OperandID, RelationID});
        HighestRelation = std::max(HighestRelation, RelationID);

        LLVM_DEBUG({
          StringRef OperandStr = Vocabulary::getVocabKeyForOperandKind(
              Vocabulary::getOperandKind(Operand.get()));
          dbgs() << Vocabulary::getVocabKeyForOpcode(CurOpcode + 1) << '\t'
                 << OperandStr << '\t' << "Arg" << Position << '\n';
        });

        ++Position;
      }

      LastOpcode = CurOpcode;
    }
  }

  Out.MaxRelation = HighestRelation;
  return Out;
}
/// Builds the training triplets for every function definition in the module.
///
/// Per-function results are appended in module order, and the reported
/// MaxRelation is the largest value over all functions (never smaller than
/// NextRelation).
TripletResult IR2VecTool::generateTriplets() const {
  TripletResult Combined;
  Combined.MaxRelation = NextRelation;

  for (const Function &Fn : M.getFunctionDefs()) {
    const TripletResult PerFunction = generateTriplets(Fn);
    if (Combined.MaxRelation < PerFunction.MaxRelation)
      Combined.MaxRelation = PerFunction.MaxRelation;
    for (const auto &Entry : PerFunction.Triplets)
      Combined.Triplets.push_back(Entry);
  }

  return Combined;
}
/// Writes the module's triplets to \p OS.
///
/// The first line is a "MAX_RELATION=<n>" header; each following line holds
/// one triplet as tab-separated head, tail and relation.
void IR2VecTool::writeTripletsToStream(raw_ostream &OS) const {
  const TripletResult ModuleTriplets = generateTriplets();

  OS << "MAX_RELATION=" << ModuleTriplets.MaxRelation << '\n';
  for (const auto &Entry : ModuleTriplets.Triplets) {
    OS << Entry.Head << '\t' << Entry.Tail << '\t' << Entry.Relation << '\n';
  }
}
/// Returns the string key of every canonical vocabulary entity.
///
/// The entry at position i is the key for entity id i, for ids in
/// [0, Vocabulary::getCanonicalSize()).
EntityList IR2VecTool::collectEntityMappings() {
  EntityList Names;
  const auto NumEntities = Vocabulary::getCanonicalSize();
  unsigned ID = 0;
  while (ID < NumEntities) {
    Names.push_back(Vocabulary::getStringKey(ID).str());
    ++ID;
  }
  return Names;
}
/// Writes the entity-to-id mapping to \p OS.
///
/// The first line is the number of entities; each following line is an
/// entity name and its id separated by a tab.
void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
  const EntityList Names = collectEntityMappings();

  OS << Names.size() << "\n";
  unsigned NextID = 0;
  for (const auto &Name : Names) {
    OS << Name << '\t' << NextID << '\n';
    ++NextID;
  }
}
/// Writes embeddings for every function definition in the module to \p OS.
///
/// \param OS    Stream receiving the textual embeddings.
/// \param Level Granularity (instruction, basic block or function).
///
/// If the vocabulary has not been initialized, or is not valid, an error is
/// printed to stderr and nothing is written to \p OS.
void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
                                         EmbeddingLevel Level) const {
  // Vocab is only set by initializeVocabulary(); guard against a null pointer
  // before asking whether it is valid, exactly as the per-function overload
  // does.
  if (!Vocab || !Vocab->isValid()) {
    WithColor::error(errs(), ToolName)
        << "Vocabulary is not valid. IR2VecTool not initialized.\n";
    return;
  }

  for (const Function &F : M.getFunctionDefs())
    writeEmbeddingsToStream(F, OS, Level);
}
/// Writes the embeddings of function \p F to \p OS.
///
/// \param F     Function to embed; declarations are reported on \p OS and
///              skipped.
/// \param OS    Stream receiving the textual embeddings.
/// \param Level Granularity (instruction, basic block or function).
///
/// Errors (missing/invalid vocabulary, embedder creation failure) are printed
/// to stderr and leave \p OS untouched.
void IR2VecTool::writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
                                         EmbeddingLevel Level) const {
  const bool HaveUsableVocab = Vocab && Vocab->isValid();
  if (!HaveUsableVocab) {
    WithColor::error(errs(), ToolName)
        << "Vocabulary is not valid. IR2VecTool not initialized.\n";
    return;
  }

  if (F.isDeclaration()) {
    OS << "Function " << F.getName() << " is a declaration, skipping.\n";
    return;
  }

  // One embedder per function, using the globally selected IR2Vec kind.
  auto FuncEmbedder = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
  if (!FuncEmbedder) {
    WithColor::error(errs(), ToolName)
        << "Failed to create embedder for function " << F.getName() << "\n";
    return;
  }

  OS << "Function: " << F.getName() << "\n";

  switch (Level) {
  case InstructionLevel:
    // Each instruction is printed, followed by its vector.
    for (const Instruction &Inst : instructions(F)) {
      OS << Inst;
      FuncEmbedder->getInstVector(Inst).print(OS);
    }
    break;
  case BasicBlockLevel:
    // Each block is introduced by "<name>:" followed by its vector.
    for (const BasicBlock &Block : F) {
      OS << Block.getName() << ":";
      FuncEmbedder->getBBVector(Block).print(OS);
    }
    break;
  case FunctionLevel:
    FuncEmbedder->getFunctionVector().print(OS);
    break;
  }
}
/// Process an LLVM IR module and generate output for the selected subcommand.
///
/// \param M  Module to process.
/// \param OS Stream receiving the generated output.
///
/// With the "embeddings" subcommand the IR2Vec vocabulary is initialized
/// (this requires --ir2vec-vocab-path to be set) and embeddings are written
/// either for the function named by -function or for all function
/// definitions. Otherwise triplets are written.
///
/// \returns Error::success() on success; an invalid_argument error if the
/// vocabulary cannot be initialized or the requested function does not exist.
static Error processModule(Module &M, raw_ostream &OS) {
  IR2VecTool Tool(M);

  if (!EmbeddingsSubCmd) {
    // In LLVM IR mode main() serves the "entities" subcommand itself, so
    // anything reaching this point that is not "embeddings" is handled as
    // triplet generation.
    Tool.writeTripletsToStream(OS);
    return Error::success();
  }

  // Report a vocabulary failure as a real error instead of relying on an
  // assertion, which is compiled out in release builds and would let the tool
  // carry on and exit successfully without producing embeddings.
  if (!Tool.initializeVocabulary())
    return createStringError(errc::invalid_argument,
                             "Failed to initialize IR2Vec vocabulary");

  if (FunctionName.empty()) {
    // Process all functions
    Tool.writeEmbeddingsToStream(OS, Level);
    return Error::success();
  }

  // Process single function
  const Function *F = M.getFunction(FunctionName);
  if (!F)
    return createStringError(errc::invalid_argument, "Function '%s' not found",
                             FunctionName.c_str());
  Tool.writeEmbeddingsToStream(*F, OS, Level);
  return Error::success();
}
/// Loads the MIR2Vec vocabulary for module \p M through a
/// MIR2VecVocabProvider.
///
/// \returns true and stores the vocabulary in \c Vocab on success; on failure
/// prints the underlying error to stderr and returns false.
bool MIR2VecTool::initializeVocabulary(const Module &M) {
  MIR2VecVocabProvider Provider(MMI);
  auto LoadedVocab = Provider.getVocabulary(M);

  if (LoadedVocab) {
    Vocab = std::make_unique<MIRVocabulary>(std::move(*LoadedVocab));
    return true;
  }

  WithColor::error(errs(), ToolName)
      << "Failed to load MIR2Vec vocabulary - "
      << toString(LoadedVocab.takeError()) << "\n";
  return false;
}
bool MIR2VecTool::initializeVocabularyForLayout(const Module &M) {
for (const Function &F : M.getFunctionDefs()) {
MachineFunction *MF = MMI.getMachineFunction(F);
if (!MF)
continue;
const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo();
const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo();
const MachineRegisterInfo &MRI = MF->getRegInfo();
auto VocabOrErr = MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1);
if (!VocabOrErr) {
WithColor::error(errs(), ToolName)
<< "Failed to create dummy vocabulary - "
<< toString(VocabOrErr.takeError()) << "\n";
return false;
}
Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
return true;
}
WithColor::error(errs(), ToolName)
<< "No machine functions found to initialize vocabulary\n";
return false;
}
/// Builds the training triplets for a single machine function.
///
/// For every non-debug machine instruction this emits a "Next" triplet from
/// the previously visited instruction (carried across basic blocks) and one
/// "Arg<i>" triplet per operand, using relation MIRArgRelation + i.
///
/// \returns the triplets and the largest relation id used. If the vocabulary
/// is not initialized, an error is printed to stderr and an empty result with
/// MaxRelation set to MIRNextRelation is returned.
TripletResult MIR2VecTool::generateTriplets(const MachineFunction &MF) const {
  TripletResult Out;
  Out.MaxRelation = MIRNextRelation;

  if (!Vocab) {
    WithColor::error(errs(), ToolName)
        << "MIR Vocabulary must be initialized for triplet generation.\n";
    return Out;
  }

  // Entity id of the most recently visited instruction's opcode, if any.
  std::optional<unsigned> LastOpcodeID;

  for (const MachineBasicBlock &Block : MF) {
    for (const MachineInstr &Inst : Block) {
      // Debug instructions do not take part in the triplets.
      if (Inst.isDebugInstr())
        continue;

      unsigned CurOpcodeID = Vocab->getEntityIDForOpcode(Inst.getOpcode());

      // "Next": previous instruction -> current instruction.
      if (LastOpcodeID) {
        Out.Triplets.push_back({*LastOpcodeID, CurOpcodeID, MIRNextRelation});
        LLVM_DEBUG(dbgs() << Vocab->getStringKey(*LastOpcodeID) << '\t'
                          << Vocab->getStringKey(CurOpcodeID) << '\t'
                          << "Next\n");
      }

      // "Arg<i>": instruction -> i-th machine operand.
      unsigned NumArgs = 0;
      for (const MachineOperand &Operand : Inst.operands()) {
        auto OperandID = Vocab->getEntityIDForMachineOperand(Operand);
        unsigned RelationID = MIRArgRelation + NumArgs;
        Out.Triplets.push_back({CurOpcodeID, OperandID, RelationID});
        LLVM_DEBUG({
          std::string OperandStr = Vocab->getStringKey(OperandID);
          dbgs() << Vocab->getStringKey(CurOpcodeID) << '\t' << OperandStr
                 << '\t' << "Arg" << NumArgs << '\n';
        });
        ++NumArgs;
      }

      // The last operand carries the highest relation id of this instruction.
      if (NumArgs != 0)
        Out.MaxRelation =
            std::max(Out.MaxRelation, MIRArgRelation + NumArgs - 1);

      LastOpcodeID = CurOpcodeID;
    }
  }

  return Out;
}
/// Builds the training triplets for every function definition in \p M that
/// has a MachineFunction.
///
/// Functions without a MachineFunction are skipped with a warning on stderr.
/// The reported MaxRelation is the largest value over all processed functions
/// (never smaller than MIRNextRelation).
TripletResult MIR2VecTool::generateTriplets(const Module &M) const {
  TripletResult Combined;
  Combined.MaxRelation = MIRNextRelation;

  for (const Function &Fn : M.getFunctionDefs()) {
    MachineFunction *MachineFn = MMI.getMachineFunction(Fn);
    if (MachineFn) {
      const TripletResult PerFunction = generateTriplets(*MachineFn);
      if (Combined.MaxRelation < PerFunction.MaxRelation)
        Combined.MaxRelation = PerFunction.MaxRelation;
      for (const auto &Entry : PerFunction.Triplets)
        Combined.Triplets.push_back(Entry);
    } else {
      WithColor::warning(errs(), ToolName)
          << "No MachineFunction for " << Fn.getName() << "\n";
    }
  }

  return Combined;
}
void MIR2VecTool::writeTripletsToStream(const Module &M,
raw_ostream &OS) const {
auto Result = generateTriplets(M);
OS << "MAX_RELATION=" << Result.MaxRelation << '\n';
for (const auto &T : Result.Triplets)
OS << T.Head << '\t' << T.Tail << '\t' << T.Relation << '\n';
}
/// Returns the string key of every canonical entity of the MIR vocabulary.
///
/// The entry at position i is the key for entity id i. If the vocabulary is
/// not initialized, an error is printed to stderr and an empty list is
/// returned.
EntityList MIR2VecTool::collectEntityMappings() const {
  EntityList Names;

  if (Vocab) {
    const unsigned NumEntities = Vocab->getCanonicalSize();
    unsigned ID = 0;
    while (ID < NumEntities) {
      Names.push_back(Vocab->getStringKey(ID));
      ++ID;
    }
    return Names;
  }

  WithColor::error(errs(), ToolName)
      << "Vocabulary must be initialized for entity mappings.\n";
  return {};
}
/// Writes the MIR entity-to-id mapping to \p OS.
///
/// The first line is the number of entities; each following line is an
/// entity name and its id separated by a tab. Nothing is written when no
/// entities are available.
void MIR2VecTool::writeEntitiesToStream(raw_ostream &OS) const {
  const EntityList Names = collectEntityMappings();
  if (!Names.empty()) {
    OS << Names.size() << "\n";
    unsigned NextID = 0;
    for (const auto &Name : Names) {
      OS << Name << '\t' << NextID << '\n';
      ++NextID;
    }
  }
}
/// Writes embeddings for every function definition in \p M that has a
/// MachineFunction.
///
/// \param M     Module whose functions are looked up in the machine module
///              info.
/// \param OS    Stream receiving the textual embeddings.
/// \param Level Granularity (instruction, basic block or function).
///
/// Prints an error to stderr and returns if the vocabulary is not
/// initialized; functions lacking a MachineFunction produce a warning.
void MIR2VecTool::writeEmbeddingsToStream(const Module &M, raw_ostream &OS,
                                          EmbeddingLevel Level) const {
  if (!Vocab) {
    WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
    return;
  }

  for (const Function &Fn : M.getFunctionDefs()) {
    if (MachineFunction *MachineFn = MMI.getMachineFunction(Fn)) {
      writeEmbeddingsToStream(*MachineFn, OS, Level);
    } else {
      WithColor::warning(errs(), ToolName)
          << "No MachineFunction for " << Fn.getName() << "\n";
    }
  }
}
/// Writes the embeddings of machine function \p MF to \p OS.
///
/// \param MF    Machine function to embed.
/// \param OS    Stream receiving the textual embeddings.
/// \param Level Granularity (instruction, basic block or function).
///
/// A symbolic MIR embedder is created for the function. Errors (vocabulary
/// not initialized, embedder creation failure) are printed to stderr and
/// leave \p OS untouched.
void MIR2VecTool::writeEmbeddingsToStream(MachineFunction &MF, raw_ostream &OS,
                                          EmbeddingLevel Level) const {
  if (!Vocab) {
    WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n";
    return;
  }

  auto FuncEmbedder = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab);
  if (!FuncEmbedder) {
    WithColor::error(errs(), ToolName)
        << "Failed to create embedder for " << MF.getName() << "\n";
    return;
  }

  OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";

  switch (Level) {
  case InstructionLevel:
    OS << "Instruction vectors:\n";
    for (const MachineBasicBlock &Block : MF)
      for (const MachineInstr &Inst : Block) {
        OS << Inst << " -> ";
        FuncEmbedder->getMInstVector(Inst).print(OS);
      }
    break;
  case BasicBlockLevel:
    OS << "Basic block vectors:\n";
    for (const MachineBasicBlock &Block : MF) {
      OS << "MBB " << Block.getName() << ": ";
      FuncEmbedder->getMBBVector(Block).print(OS);
    }
    break;
  case FunctionLevel:
    OS << "Function vector: ";
    FuncEmbedder->getMFunctionVector().print(OS);
    break;
  }
}
/// Setup MIR context from input file.
///
/// Parses \p InputFile with the MIR parser and fills \p Ctx with the parsed
/// IR module, a target machine created for the module's triple (or the
/// default target triple when the module has none) and the machine module
/// info holding the parsed machine functions.
///
/// \param InputFile Path handed to the MIR parser.
/// \param Ctx       Context whose LLVMContext is used for parsing and whose
///                  TM, M and MMI members are populated.
/// \returns Error::success() on success, otherwise an invalid_argument error
/// naming the step that failed.
static Error setupMIRContext(const std::string &InputFile, MIRContext &Ctx) {
  SMDiagnostic Err;
  auto MIR = createMIRParserFromFile(InputFile, Err, Ctx.Context);
  if (!MIR) {
    Err.print(ToolName, errs());
    return createStringError(errc::invalid_argument,
                             "Failed to parse MIR file");
  }

  // Holds the reason the target machine could not be created. The data layout
  // callback cannot return an Error, so the message is stashed here and
  // turned into a proper error once parsing returns, rather than printing an
  // unrelated diagnostic and terminating the process from inside the callback.
  std::optional<std::string> TMErrorMsg;
  auto SetDataLayout = [&](StringRef DataLayoutTargetTriple,
                           StringRef OldDLStr) -> std::optional<std::string> {
    std::string IRTargetTriple = DataLayoutTargetTriple.str();
    Triple TheTriple = Triple(IRTargetTriple);
    // Fall back to the default target triple when the module specifies none.
    if (TheTriple.getTriple().empty())
      TheTriple.setTriple(sys::getDefaultTargetTriple());
    auto TMOrErr = codegen::createTargetMachineForTriple(TheTriple.str());
    if (!TMOrErr) {
      TMErrorMsg = toString(TMOrErr.takeError());
      return std::nullopt;
    }
    Ctx.TM = std::move(*TMOrErr);
    return Ctx.TM->createDataLayout().getStringRepresentation();
  };

  Ctx.M = MIR->parseIRModule(SetDataLayout);
  if (TMErrorMsg)
    return createStringError(errc::invalid_argument,
                             "Failed to create target machine: %s",
                             TMErrorMsg->c_str());
  if (!Ctx.M) {
    Err.print(ToolName, errs());
    return createStringError(errc::invalid_argument,
                             "Failed to parse IR module");
  }

  // The machine module info is built from the target machine; refuse to
  // continue if the data layout callback never produced one.
  if (!Ctx.TM)
    return createStringError(errc::invalid_argument,
                             "No target machine available for MIR input");

  // std::make_unique never yields a null pointer, so only the parse result
  // needs checking here.
  Ctx.MMI = std::make_unique<MachineModuleInfo>(Ctx.TM.get());
  if (MIR->parseMachineFunctions(*Ctx.M, *Ctx.MMI)) {
    Err.print(ToolName, errs());
    return createStringError(errc::invalid_argument,
                             "Failed to parse machine functions");
  }
  return Error::success();
}
/// Generic vocabulary initialization and processing.
///
/// Creates a MIR2VecTool for the machine module info in \p Ctx, initializes
/// either the layout (dummy) vocabulary or the regular one, and then hands
/// the tool to \p processFn.
///
/// \param Ctx            Parsed MIR context.
/// \param OS             Output stream; not used by this function itself.
/// \param useLayoutVocab Selects initializeVocabularyForLayout() over
///                       initializeVocabulary().
/// \param processFn      Callable taking the tool and returning an Error.
/// \returns the callback's result, or an invalid_argument error if the
/// vocabulary could not be initialized.
template <typename ProcessFunc>
static Error processWithVocabulary(MIRContext &Ctx, raw_ostream &OS,
                                   bool useLayoutVocab, ProcessFunc processFn) {
  MIR2VecTool Tool(*Ctx.MMI);

  bool VocabReady;
  if (useLayoutVocab)
    VocabReady = Tool.initializeVocabularyForLayout(*Ctx.M);
  else
    VocabReady = Tool.initializeVocabulary(*Ctx.M);

  if (!VocabReady) {
    const char *Suffix = useLayoutVocab ? " for layout" : "";
    WithColor::error(errs(), ToolName)
        << "Failed to initialize MIR2Vec vocabulary" << Suffix << ".\n";
    return createStringError(errc::invalid_argument,
                             "Vocabulary initialization failed");
  }

  assert(Tool.getVocabulary() &&
         "MIR2Vec vocabulary should be initialized at this point");
  LLVM_DEBUG(dbgs() << "MIR2Vec vocabulary loaded successfully.\n"
                    << "Vocabulary dimension: "
                    << Tool.getVocabulary()->getDimension() << "\n"
                    << "Vocabulary size: "
                    << Tool.getVocabulary()->getCanonicalSize() << "\n");

  // Run the subcommand-specific work with the initialized tool.
  return processFn(Tool);
}
/// Process module for triplet generation.
///
/// Uses the layout vocabulary and writes the module's triplets to \p OS.
static Error processModuleForTriplets(MIRContext &Ctx, raw_ostream &OS) {
  auto EmitTriplets = [&](MIR2VecTool &Tool) -> Error {
    Tool.writeTripletsToStream(*Ctx.M, OS);
    return Error::success();
  };
  return processWithVocabulary(Ctx, OS, /*useLayoutVocab=*/true, EmitTriplets);
}
/// Process module for entity generation.
///
/// Uses the layout vocabulary and writes the entity mappings to \p OS.
static Error processModuleForEntities(MIRContext &Ctx, raw_ostream &OS) {
  auto EmitEntities = [&](MIR2VecTool &Tool) -> Error {
    Tool.writeEntitiesToStream(OS);
    return Error::success();
  };
  return processWithVocabulary(Ctx, OS, /*useLayoutVocab=*/true, EmitEntities);
}
/// Process module for embedding generation.
///
/// Uses the regular (non-layout) vocabulary. When -function is given, only
/// that function is embedded; otherwise all functions of the module are.
///
/// \returns an invalid_argument error if the named function or its
/// MachineFunction cannot be found, or if vocabulary initialization fails.
static Error processModuleForEmbeddings(MIRContext &Ctx, raw_ostream &OS) {
  auto EmitEmbeddings = [&](MIR2VecTool &Tool) -> Error {
    // No function requested: process all functions.
    if (FunctionName.empty()) {
      Tool.writeEmbeddingsToStream(*Ctx.M, OS, Level);
      return Error::success();
    }

    // Process single function
    Function *Requested = Ctx.M->getFunction(FunctionName);
    if (!Requested) {
      WithColor::error(errs(), ToolName)
          << "Function '" << FunctionName << "' not found\n";
      return createStringError(errc::invalid_argument, "Function not found");
    }

    MachineFunction *RequestedMF = Ctx.MMI->getMachineFunction(*Requested);
    if (!RequestedMF) {
      WithColor::error(errs(), ToolName)
          << "No MachineFunction for " << FunctionName << "\n";
      return createStringError(errc::invalid_argument, "No MachineFunction");
    }

    Tool.writeEmbeddingsToStream(*RequestedMF, OS, Level);
    return Error::success();
  };

  return processWithVocabulary(Ctx, OS, /*useLayoutVocab=*/false,
                               EmitEmbeddings);
}
/// Main entry point for MIR processing.
///
/// Parses \p InputFile into a MIR context and dispatches to the handler of
/// the selected subcommand, writing its output to \p OS.
///
/// \returns the handler's result, the context setup error, or an
/// invalid_argument error when no subcommand was selected.
static Error processModule(const std::string &InputFile, raw_ostream &OS) {
  MIRContext Ctx;

  // Parse the file and set up the target machine and machine module info.
  if (auto SetupErr = setupMIRContext(InputFile, Ctx))
    return SetupErr;

  // Dispatch on the selected subcommand.
  if (TripletsSubCmd)
    return processModuleForTriplets(Ctx, OS);
  if (EntitiesSubCmd)
    return processModuleForEntities(Ctx, OS);
  if (EmbeddingsSubCmd)
    return processModuleForEmbeddings(Ctx, OS);

  WithColor::error(errs(), ToolName)
      << "Please specify a subcommand: triplets, entities, or embeddings\n";
  return createStringError(errc::invalid_argument, "No subcommand specified");
}
int main(int argc, char **argv) {
InitLLVM X(argc, argv);
// Show Common, IR2Vec and MIR2Vec option categories
cl::HideUnrelatedOptions(ArrayRef<const cl::OptionCategory *>{
&CommonCategory, &ir2vec::IR2VecCategory, &mir2vec::MIR2VecCategory});
cl::ParseCommandLineOptions(
argc, argv,
"IR2Vec/MIR2Vec - Embedding Generation Tool\n"
"Generates embeddings for a given LLVM IR or MIR and "
"supports triplet generation for vocabulary "
"training and embedding generation.\n\n"
"See https://llvm.org/docs/CommandGuide/llvm-ir2vec.html for more "
"information.\n");
std::error_code EC;
raw_fd_ostream OS(OutputFilename, EC);
if (EC) {
WithColor::error(errs(), ToolName)
<< "opening output file: " << EC.message() << "\n";
return 1;
}
if (IRMode == IRKind::LLVMIR) {
if (EntitiesSubCmd) {
// Just dump entity mappings without processing any IR
IR2VecTool::writeEntitiesToStream(OS);
return 0;
}
// Parse the input LLVM IR file or stdin
SMDiagnostic Err;
LLVMContext Context;
std::unique_ptr<Module> M = parseIRFile(InputFilename, Err, Context);
if (!M) {
Err.print(ToolName, errs());
return 1;
}
if (Error Err = processModule(*M, OS)) {
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EIB) {
WithColor::error(errs(), ToolName) << EIB.message() << "\n";
});
return 1;
}
return 0;
}
if (IRMode == IRKind::MIR) {
// Initialize targets for Machine IR processing
InitializeAllTargets();
InitializeAllTargetMCs();
InitializeAllAsmParsers();
InitializeAllAsmPrinters();
static codegen::RegisterCodeGenFlags CGF;
if (Error Err = processModule(InputFilename, OS)) {
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EIB) {
WithColor::error(errs(), ToolName) << EIB.message() << "\n";
});
return 1;
}
return 0;
}
return 0;
}